Mirror of https://github.com/tcsenpai/agenticSeek.git (synced 2025-07-23 09:50:30 +00:00)
Docs: better documentation
commit eca688baba (parent ff1af3b6a9)
@@ -5,6 +5,9 @@ from sources.router import AgentRouter
 from sources.speech_to_text import AudioTranscriber, AudioRecorder

 class Interaction:
+    """
+    Interaction is a class that handles the interaction between the user and the agents.
+    """
     def __init__(self, agents,
                  tts_enabled: bool = True,
                  stt_enabled: bool = True,
@@ -29,6 +32,7 @@ class Interaction:
         self.recover_last_session()

     def find_ai_name(self) -> str:
+        """Find the name of the default AI. It is required for STT as a trigger word."""
         ai_name = "jarvis"
         for agent in self.agents:
             if agent.role == "talking":
@@ -37,17 +41,20 @@ class Interaction:
         return ai_name

     def recover_last_session(self):
+        """Recover the last session."""
         for agent in self.agents:
             agent.memory.load_memory()

     def save_session(self):
+        """Save the current session."""
         for agent in self.agents:
             agent.memory.save_memory()

-    def is_active(self):
+    def is_active(self) -> bool:
         return self.is_active

     def read_stdin(self) -> str:
+        """Read the input from the user."""
         buffer = ""

         while buffer == "" or buffer.isascii() == False:
@@ -59,7 +66,8 @@ class Interaction:
             return None
         return buffer

-    def transcription_job(self):
+    def transcription_job(self) -> str:
+        """Transcribe the audio from the microphone."""
         self.recorder = AudioRecorder(verbose=True)
         self.transcriber = AudioTranscriber(self.ai_name, verbose=True)
         self.transcriber.start()
@@ -69,7 +77,8 @@ class Interaction:
         query = self.transcriber.get_transcript()
         return query

-    def get_user(self):
+    def get_user_input(self) -> str:
+        """Get the user input from the microphone or the keyboard."""
         if self.stt_enabled:
             query = "TTS transcription of user: " + self.transcription_job()
         else:
@@ -81,7 +90,8 @@ class Interaction:
         self.last_query = query
         return query

-    def think(self):
+    def think(self) -> None:
+        """Request AI agents to process the user input."""
         if self.last_query is None or len(self.last_query) == 0:
             return
         agent = self.router.select_agent(self.last_query)
@@ -93,7 +103,8 @@ class Interaction:
         self.current_agent.memory.push('user', self.last_query)
         self.last_answer, _ = agent.process(self.last_query, self.speech)

-    def show_answer(self):
+    def show_answer(self) -> None:
+        """Show the answer to the user."""
         if self.last_query is None:
             return
         self.current_agent.show_answer()
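For orientation, a minimal sketch of how an Interaction instance might be driven once these methods are in place. The module path, the agents list, and the main-loop shape are assumptions, not part of this diff.

# Hypothetical driver loop; `agents` is assumed to be a list of already-configured agent instances.
from sources.interaction import Interaction  # assumed module path

interaction = Interaction(agents, tts_enabled=True, stt_enabled=False)
while interaction.is_active():
    interaction.get_user_input()   # keyboard or microphone, depending on stt_enabled
    interaction.think()            # route the query to an agent and process it
    interaction.show_answer()      # display the selected agent's answer
interaction.save_session()         # persist each agent's memory for the next run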
@@ -39,6 +39,7 @@ class Memory():
         return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt"

     def save_memory(self) -> None:
+        """Save the session memory to a file."""
         if not os.path.exists(self.conversation_folder):
             os.makedirs(self.conversation_folder)
         filename = self.get_filename()
@@ -48,6 +49,7 @@ class Memory():
             f.write(json_memory)

     def find_last_session_path(self) -> str:
+        """Find the last session path."""
         saved_sessions = []
         for filename in os.listdir(self.conversation_folder):
             if filename.startswith('memory_'):
@@ -59,6 +61,7 @@ class Memory():
         return None

     def load_memory(self) -> None:
+        """Load the memory from the last session."""
         if not os.path.exists(self.conversation_folder):
             return
         filename = self.find_last_session_path()
@@ -72,6 +75,7 @@ class Memory():
         self.memory = memory

     def push(self, role: str, content: str) -> None:
+        """Push a message to the memory."""
         self.memory.append({'role': role, 'content': content})
         # EXPERIMENTAL
         if self.memory_compression and role == 'assistant':
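The save/load pair above serializes the conversation to JSON in timestamped memory_* files. A standalone sketch of the same persistence idea, assuming a conversations/ folder (the actual folder name is not shown in this diff):

import json
import os
import datetime

def save_session(memory: list, folder: str = "conversations") -> str:
    # memory is a list of {'role': ..., 'content': ...} dicts, like the ones Memory.push() appends
    os.makedirs(folder, exist_ok=True)
    filename = f"memory_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
    path = os.path.join(folder, filename)
    with open(path, "w") as f:
        f.write(json.dumps(memory))
    return path

def load_last_session(folder: str = "conversations") -> list:
    # Pick the newest memory_* file, mirroring what find_last_session_path() does
    if not os.path.isdir(folder):
        return []
    sessions = sorted(f for f in os.listdir(folder) if f.startswith("memory_"))
    if not sessions:
        return []
    with open(os.path.join(folder, sessions[-1])) as f:
        return json.loads(f.read())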
@@ -92,6 +96,14 @@ class Memory():
             return "cpu"

     def summarize(self, text: str, min_length: int = 64) -> str:
+        """
+        Summarize the text using the AI model.
+        Args:
+            text (str): The text to summarize
+            min_length (int, optional): The minimum length of the summary. Defaults to 64.
+        Returns:
+            str: The summarized text
+        """
         if self.tokenizer is None or self.model is None:
             return text
         max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2
@@ -110,6 +122,9 @@ class Memory():

     @timer_decorator
     def compress(self) -> str:
+        """
+        Compress the memory using the AI model.
+        """
         if not self.memory_compression:
             return
         for i in range(len(self.memory)):
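summarize() relies on a tokenizer and seq2seq model stored on the instance; the diff does not show which checkpoint is loaded. A rough sketch of this kind of memory compression with Hugging Face transformers, assuming a generic summarization checkpoint and illustrative generation settings:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "facebook/bart-large-cnn"  # assumed checkpoint, not taken from this diff
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def summarize(text: str, min_length: int = 64) -> str:
    # Leave short texts untouched, mirroring the guard in Memory.summarize()
    if len(text) <= min_length * 2:
        return text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        ids = model.generate(**inputs, min_length=min_length, max_length=min_length * 4)
    return tokenizer.decode(ids[0], skip_special_tokens=True)

print(summarize("some long assistant answer " * 40))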
@@ -6,14 +6,17 @@ from sources.casual_agent import CasualAgent
 from sources.utility import pretty_print

 class AgentRouter:
-    def __init__(self, agents: list, model_name="facebook/bart-large-mnli"):
+    """
+    AgentRouter is a class that selects the appropriate agent based on the user query.
+    """
+    def __init__(self, agents: list, model_name: str = "facebook/bart-large-mnli"):
         self.model = model_name
         self.pipeline = pipeline("zero-shot-classification",
                                  model=self.model)
         self.agents = agents
         self.labels = [agent.role for agent in agents]

-    def get_device(self):
+    def get_device(self) -> str:
         if torch.backends.mps.is_available():
             return "mps"
         elif torch.cuda.is_available():
@@ -21,10 +24,17 @@ class AgentRouter:
         else:
             return "cpu"

-    def classify_text(self, text, threshold=0.5):
+    def classify_text(self, text: str, threshold: float = 0.5) -> list:
+        """
+        Classify the text into labels (agent roles).
+        Args:
+            text (str): The text to classify
+            threshold (float, optional): The threshold for the classification.
+        Returns:
+            list: The list of agents and their scores
+        """
         first_sentence = None
         for line in text.split("\n"):
             if line.strip() != "":
                 first_sentence = line.strip()
                 break
         if first_sentence is None:
@@ -33,6 +43,13 @@ class AgentRouter:
         return result

     def select_agent(self, text: str) -> Agent:
+        """
+        Select the appropriate agent based on the text.
+        Args:
+            text (str): The text to select the agent from
+        Returns:
+            Agent: The selected agent
+        """
         if len(self.agents) == 0 or len(self.labels) == 0:
             return self.agents[0]
         result = self.classify_text(text)
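classify_text() leans on the zero-shot-classification pipeline built in __init__: the candidate labels are the agent roles and the scores decide the route. A compact sketch of that routing idea; the role labels below (other than "talking") and the threshold fallback are placeholders, not taken from the diff:

from transformers import pipeline

# Same pipeline type and checkpoint as AgentRouter.__init__
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

def pick_role(text: str, labels: list, threshold: float = 0.5) -> str:
    # Classify only the first non-empty line, as classify_text() does
    first_sentence = next((line.strip() for line in text.split("\n") if line.strip()), text)
    result = classifier(first_sentence, candidate_labels=labels)
    # result["labels"] / result["scores"] come back sorted best-first
    return result["labels"][0] if result["scores"][0] >= threshold else labels[0]

print(pick_role("Write a script that renames my files", ["coding", "talking", "web search"]))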
@@ -12,7 +12,10 @@ audio_queue = queue.Queue()
 done = False

 class AudioRecorder:
-    def __init__(self, format=pyaudio.paInt16, channels=1, rate=4096, chunk=8192, record_seconds=5, verbose=False):
+    """
+    AudioRecorder is a class that records audio from the microphone and adds it to the audio queue.
+    """
+    def __init__(self, format: int = pyaudio.paInt16, channels: int = 1, rate: int = 4096, chunk: int = 8192, record_seconds: int = 5, verbose: bool = False):
         self.format = format
         self.channels = channels
         self.rate = rate
@@ -22,7 +25,10 @@ class AudioRecorder:
         self.audio = pyaudio.PyAudio()
         self.thread = threading.Thread(target=self._record, daemon=True)

-    def _record(self):
+    def _record(self) -> None:
+        """
+        Record audio from the microphone and add it to the audio queue.
+        """
         stream = self.audio.open(format=self.format, channels=self.channels, rate=self.rate,
                                  input=True, frames_per_buffer=self.chunk)
         if self.verbose:
@@ -49,16 +55,19 @@ class AudioRecorder:
         if self.verbose:
             print(Fore.GREEN + "AudioRecorder: Stopped" + Fore.RESET)

-    def start(self):
+    def start(self) -> None:
+        """Start the recording thread."""
         self.thread.start()

-    def join(self):
+    def join(self) -> None:
+        """Wait for the recording thread to finish."""
         self.thread.join()

 class Transcript:
-    def __init__(self) -> None:
+    """
+    Transcript is a class that transcribes audio from the audio queue and adds it to the transcript.
+    """
+    def __init__(self):
         self.last_read = None
         device = self.get_device()
         torch_dtype = torch.float16 if device == "cuda" else torch.float32
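AudioRecorder records on a daemon thread and feeds raw frames into the module-level audio_queue. A stripped-down sketch of that producer pattern with pyaudio; the rate, chunk size, and duration here are illustrative rather than the project's defaults:

import queue
import threading
import pyaudio

audio_queue = queue.Queue()

def record_once(seconds: int = 5, rate: int = 16000, chunk: int = 1024) -> None:
    # Producer: read frames from the default microphone and push one buffer onto the queue
    pa = pyaudio.PyAudio()
    stream = pa.open(format=pyaudio.paInt16, channels=1, rate=rate,
                     input=True, frames_per_buffer=chunk)
    frames = [stream.read(chunk) for _ in range(int(rate / chunk * seconds))]
    audio_queue.put(b"".join(frames))
    stream.stop_stream()
    stream.close()
    pa.terminate()

recorder = threading.Thread(target=record_once, daemon=True)
recorder.start()
recorder.join()
print(f"recorded {audio_queue.qsize()} buffer(s)")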
@@ -80,7 +89,7 @@ class Transcript:
             device=device,
         )

-    def get_device(self):
+    def get_device(self) -> str:
         if torch.backends.mps.is_available():
             return "mps"
         if torch.cuda.is_available():
@@ -88,14 +97,16 @@ class Transcript:
         else:
             return "cpu"

-    def remove_hallucinations(self, text: str):
+    def remove_hallucinations(self, text: str) -> str:
+        """Remove model hallucinations from the text."""
         # TODO find a better way to do this
         common_hallucinations = ['Okay.', 'Thank you.', 'Thank you for watching.', 'You\'re', 'Oh', 'you', 'Oh.', 'Uh', 'Oh,', 'Mh-hmm', 'Hmm.', 'going to.', 'not.']
         for hallucination in common_hallucinations:
             text = text.replace(hallucination, "")
         return text

-    def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000):
+    def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000) -> str:
+        """Transcribe the audio data."""
         if audio_data.dtype != np.float32:
             audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
         if len(audio_data.shape) > 1:
@@ -106,7 +117,10 @@ class Transcript:
         return self.remove_hallucinations(result["text"])

 class AudioTranscriber:
-    def __init__(self, ai_name: str, verbose=False):
+    """
+    AudioTranscriber is a class that transcribes audio from the audio queue and adds it to the transcript.
+    """
+    def __init__(self, ai_name: str, verbose: bool = False):
         self.verbose = verbose
         self.ai_name = ai_name
         self.transcriptor = Transcript()
@@ -126,14 +140,17 @@ class AudioTranscriber:
         }
         self.recorded = ""

-    def get_transcript(self):
+    def get_transcript(self) -> str:
         global done
         buffer = self.recorded
         self.recorded = ""
         done = False
         return buffer

-    def _transcribe(self):
+    def _transcribe(self) -> None:
+        """
+        Transcribe the audio data using AI stt model.
+        """
         global done
         if self.verbose:
             print(Fore.BLUE + "AudioTranscriber: Started processing..." + Fore.RESET)
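transcript_job() converts raw integer PCM into float32 in the [-1, 1] range before transcription. Here is that conversion in isolation, on made-up samples; the mono mixdown at the end is an assumption, since the hunk is cut off right after the shape check:

import numpy as np

# Fake 16-bit PCM samples standing in for a recorded buffer
audio_int16 = np.array([0, 16384, -16384, 32767], dtype=np.int16)

# Same normalization as Transcript.transcript_job(): divide by the dtype's maximum value
audio_float = audio_int16.astype(np.float32) / np.iinfo(audio_int16.dtype).max

# Multi-channel input: one common choice is to average channels down to mono (assumed, not shown in the diff)
if audio_float.ndim > 1:
    audio_float = audio_float.mean(axis=1)

print(audio_float)  # values now lie within [-1.0, 1.0]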
@@ -9,7 +9,7 @@ class Speech():
     """
     Speech is a class for generating speech from text.
     """
-    def __init__(self, language = "english") -> None:
+    def __init__(self, language: str = "english") -> None:
         self.lang_map = {
             "english": 'a',
             "chinese": 'z',
@@ -24,9 +24,13 @@ class Speech():
         self.voice = self.voice_map[language][2]
         self.speed = 1.2

-    def speak(self, sentence, voice_number = 1):
+    def speak(self, sentence: str, voice_number: int = 1):
         """
-        Use AI model to generate speech from text after pre-processing the text.
+        Convert text to speech using an AI model and play the audio.
+
+        Args:
+            sentence (str): The text to convert to speech. Will be pre-processed.
+            voice_number (int, optional): Index of the voice to use from the voice map.
         """
         sentence = self.clean_sentence(sentence)
         self.voice = self.voice_map["english"][voice_number]
@@ -45,18 +49,26 @@ class Speech():
             import winsound
             winsound.PlaySound(audio_file, winsound.SND_FILENAME)

-    def replace_url(self, m):
+    def replace_url(self, url: re.Match) -> str:
         """
-        Replace URL with empty string.
+        Replace URL with domain name or empty string if IP address.
+        Args:
+            url (re.Match): Match object containing the URL pattern match
+        Returns:
+            str: The domain name from the URL, or empty string if IP address
         """
-        domain = m.group(1)
+        domain = url.group(1)
         if re.match(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', domain):
             return ''
         return domain

-    def extract_filename(self, m):
+    def extract_filename(self, m: re.Match) -> str:
         """
         Extract filename from path.
+        Args:
+            m (re.Match): Match object containing the path pattern match
+        Returns:
+            str: The filename from the path
         """
         path = m.group()
         parts = re.split(r'/|\\', path)
@@ -65,6 +77,10 @@ class Speech():
     def shorten_paragraph(self, sentence):
         """
         Shorten paragraph like **explaination**: <long text> by keeping only the first sentence.
+        Args:
+            sentence (str): The sentence to shorten
+        Returns:
+            str: The shortened sentence
         """
         lines = sentence.split('\n')
         lines_edited = []
@@ -77,7 +93,11 @@ class Speech():

     def clean_sentence(self, sentence):
         """
-        Clean sentence by removing URLs, filenames, and other non-alphanumeric characters.
+        Clean and normalize text for speech synthesis by removing technical elements.
+        Args:
+            sentence (str): The input text to clean
+        Returns:
+            str: The cleaned text with URLs replaced by domain names, code blocks removed, etc..
         """
         lines = sentence.split('\n')
         filtered_lines = [line for line in lines if re.match(r'^\s*[a-zA-Z]', line)]
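replace_url() and extract_filename() are written as re.sub replacement callbacks: each receives a re.Match and returns the text to substitute. A self-contained example of that pattern; the URL regex below is a simplified stand-in, not the one clean_sentence() actually uses:

import re

def replace_url(url: re.Match) -> str:
    # Keep the domain, drop bare IP addresses, as in Speech.replace_url()
    domain = url.group(1)
    if re.match(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', domain):
        return ''
    return domain

text = "See https://example.com/docs/page and http://127.0.0.1/admin for details"
cleaned = re.sub(r'https?://([^\s/]+)\S*', replace_url, text)  # simplified pattern (assumption)
print(cleaned)  # -> "See example.com and  for details"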
@@ -6,7 +6,19 @@ import platform

 def pretty_print(text, color = "info"):
     """
-    print text with color
+    Print text with color formatting.
+
+    Args:
+        text (str): The text to print
+        color (str, optional): The color to use. Defaults to "info".
+            Valid colors are:
+            - "success": Green
+            - "failure": Red
+            - "status": Light green
+            - "code": Light blue
+            - "warning": Yellow
+            - "output": Cyan
+            - "default": Black (Windows only)
     """
     if platform.system().lower() != "windows":
         color_map = {
@@ -37,6 +49,13 @@ def pretty_print(text, color = "info"):
         print(colored(text, color_map[color]))

 def timer_decorator(func):
+    """
+    Decorator to measure the execution time of a function.
+    Usage:
+        @timer_decorator
+        def my_function():
+            # code to execute
+    """
     from time import time
     def wrapper(*args, **kwargs):
         start_time = time()
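The diff is truncated inside wrapper(), so the decorator's tail is not shown. A complete version following the documented usage might look like this; the return-value handling and the printed message are guesses, not the project's code:

from time import time

def timer_decorator(func):
    """Measure and print how long func takes to run."""
    def wrapper(*args, **kwargs):
        start_time = time()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time() - start_time:.2f}s")  # message format is an assumption
        return result
    return wrapper

@timer_decorator
def my_function():
    sum(range(1_000_000))  # stand-in workload

my_function()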