From eca688baba5c490efd52340c7005ef1a8f15b6a3 Mon Sep 17 00:00:00 2001 From: martin legrand Date: Thu, 6 Mar 2025 11:58:54 +0100 Subject: [PATCH] Docs: better documentation --- sources/interaction.py | 21 ++++++++++++++++----- sources/memory.py | 15 +++++++++++++++ sources/router.py | 25 +++++++++++++++++++++---- sources/speech_to_text.py | 39 ++++++++++++++++++++++++++++----------- sources/text_to_speech.py | 36 ++++++++++++++++++++++++++++-------- sources/utility.py | 21 ++++++++++++++++++++- 6 files changed, 128 insertions(+), 29 deletions(-) diff --git a/sources/interaction.py b/sources/interaction.py index d35deac..836b0fe 100644 --- a/sources/interaction.py +++ b/sources/interaction.py @@ -5,6 +5,9 @@ from sources.router import AgentRouter from sources.speech_to_text import AudioTranscriber, AudioRecorder class Interaction: + """ + Interaction is a class that handles the interaction between the user and the agents. + """ def __init__(self, agents, tts_enabled: bool = True, stt_enabled: bool = True, @@ -29,6 +32,7 @@ class Interaction: self.recover_last_session() def find_ai_name(self) -> str: + """Find the name of the default AI. It is required for STT as a trigger word.""" ai_name = "jarvis" for agent in self.agents: if agent.role == "talking": @@ -37,17 +41,20 @@ class Interaction: return ai_name def recover_last_session(self): + """Recover the last session.""" for agent in self.agents: agent.memory.load_memory() def save_session(self): + """Save the current session.""" for agent in self.agents: agent.memory.save_memory() - def is_active(self): + def is_active(self) -> bool: return self.is_active def read_stdin(self) -> str: + """Read the input from the user.""" buffer = "" while buffer == "" or buffer.isascii() == False: @@ -59,7 +66,8 @@ class Interaction: return None return buffer - def transcription_job(self): + def transcription_job(self) -> str: + """Transcribe the audio from the microphone.""" self.recorder = AudioRecorder(verbose=True) self.transcriber = AudioTranscriber(self.ai_name, verbose=True) self.transcriber.start() @@ -69,7 +77,8 @@ class Interaction: query = self.transcriber.get_transcript() return query - def get_user(self): + def get_user_input(self) -> str: + """Get the user input from the microphone or the keyboard.""" if self.stt_enabled: query = "TTS transcription of user: " + self.transcription_job() else: @@ -81,7 +90,8 @@ class Interaction: self.last_query = query return query - def think(self): + def think(self) -> None: + """Request AI agents to process the user input.""" if self.last_query is None or len(self.last_query) == 0: return agent = self.router.select_agent(self.last_query) @@ -93,7 +103,8 @@ class Interaction: self.current_agent.memory.push('user', self.last_query) self.last_answer, _ = agent.process(self.last_query, self.speech) - def show_answer(self): + def show_answer(self) -> None: + """Show the answer to the user.""" if self.last_query is None: return self.current_agent.show_answer() diff --git a/sources/memory.py b/sources/memory.py index 3a8507a..af2c4f0 100644 --- a/sources/memory.py +++ b/sources/memory.py @@ -39,6 +39,7 @@ class Memory(): return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt" def save_memory(self) -> None: + """Save the session memory to a file.""" if not os.path.exists(self.conversation_folder): os.makedirs(self.conversation_folder) filename = self.get_filename() @@ -48,6 +49,7 @@ class Memory(): f.write(json_memory) def find_last_session_path(self) -> str: + """Find the last session path.""" saved_sessions = [] for filename in os.listdir(self.conversation_folder): if filename.startswith('memory_'): @@ -59,6 +61,7 @@ class Memory(): return None def load_memory(self) -> None: + """Load the memory from the last session.""" if not os.path.exists(self.conversation_folder): return filename = self.find_last_session_path() @@ -72,6 +75,7 @@ class Memory(): self.memory = memory def push(self, role: str, content: str) -> None: + """Push a message to the memory.""" self.memory.append({'role': role, 'content': content}) # EXPERIMENTAL if self.memory_compression and role == 'assistant': @@ -92,6 +96,14 @@ class Memory(): return "cpu" def summarize(self, text: str, min_length: int = 64) -> str: + """ + Summarize the text using the AI model. + Args: + text (str): The text to summarize + min_length (int, optional): The minimum length of the summary. Defaults to 64. + Returns: + str: The summarized text + """ if self.tokenizer is None or self.model is None: return text max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2 @@ -110,6 +122,9 @@ class Memory(): @timer_decorator def compress(self) -> str: + """ + Compress the memory using the AI model. + """ if not self.memory_compression: return for i in range(len(self.memory)): diff --git a/sources/router.py b/sources/router.py index 63727c4..094489f 100644 --- a/sources/router.py +++ b/sources/router.py @@ -6,14 +6,17 @@ from sources.casual_agent import CasualAgent from sources.utility import pretty_print class AgentRouter: - def __init__(self, agents: list, model_name="facebook/bart-large-mnli"): + """ + AgentRouter is a class that selects the appropriate agent based on the user query. + """ + def __init__(self, agents: list, model_name: str = "facebook/bart-large-mnli"): self.model = model_name self.pipeline = pipeline("zero-shot-classification", model=self.model) self.agents = agents self.labels = [agent.role for agent in agents] - def get_device(self): + def get_device(self) -> str: if torch.backends.mps.is_available(): return "mps" elif torch.cuda.is_available(): @@ -21,10 +24,17 @@ class AgentRouter: else: return "cpu" - def classify_text(self, text, threshold=0.5): + def classify_text(self, text: str, threshold: float = 0.5) -> list: + """ + Classify the text into labels (agent roles). + Args: + text (str): The text to classify + threshold (float, optional): The threshold for the classification. + Returns: + list: The list of agents and their scores + """ first_sentence = None for line in text.split("\n"): - if line.strip() != "": first_sentence = line.strip() break if first_sentence is None: @@ -33,6 +43,13 @@ class AgentRouter: return result def select_agent(self, text: str) -> Agent: + """ + Select the appropriate agent based on the text. + Args: + text (str): The text to select the agent from + Returns: + Agent: The selected agent + """ if len(self.agents) == 0 or len(self.labels) == 0: return self.agents[0] result = self.classify_text(text) diff --git a/sources/speech_to_text.py b/sources/speech_to_text.py index c24771f..b9b9983 100644 --- a/sources/speech_to_text.py +++ b/sources/speech_to_text.py @@ -12,7 +12,10 @@ audio_queue = queue.Queue() done = False class AudioRecorder: - def __init__(self, format=pyaudio.paInt16, channels=1, rate=4096, chunk=8192, record_seconds=5, verbose=False): + """ + AudioRecorder is a class that records audio from the microphone and adds it to the audio queue. + """ + def __init__(self, format: int = pyaudio.paInt16, channels: int = 1, rate: int = 4096, chunk: int = 8192, record_seconds: int = 5, verbose: bool = False): self.format = format self.channels = channels self.rate = rate @@ -22,7 +25,10 @@ class AudioRecorder: self.audio = pyaudio.PyAudio() self.thread = threading.Thread(target=self._record, daemon=True) - def _record(self): + def _record(self) -> None: + """ + Record audio from the microphone and add it to the audio queue. + """ stream = self.audio.open(format=self.format, channels=self.channels, rate=self.rate, input=True, frames_per_buffer=self.chunk) if self.verbose: @@ -49,16 +55,19 @@ class AudioRecorder: if self.verbose: print(Fore.GREEN + "AudioRecorder: Stopped" + Fore.RESET) - def start(self): + def start(self) -> None: """Start the recording thread.""" self.thread.start() - def join(self): + def join(self) -> None: """Wait for the recording thread to finish.""" self.thread.join() class Transcript: - def __init__(self) -> None: + """ + Transcript is a class that transcribes audio from the audio queue and adds it to the transcript. + """ + def __init__(self): self.last_read = None device = self.get_device() torch_dtype = torch.float16 if device == "cuda" else torch.float32 @@ -80,7 +89,7 @@ class Transcript: device=device, ) - def get_device(self): + def get_device(self) -> str: if torch.backends.mps.is_available(): return "mps" if torch.cuda.is_available(): @@ -88,14 +97,16 @@ class Transcript: else: return "cpu" - def remove_hallucinations(self, text: str): + def remove_hallucinations(self, text: str) -> str: + """Remove model hallucinations from the text.""" # TODO find a better way to do this common_hallucinations = ['Okay.', 'Thank you.', 'Thank you for watching.', 'You\'re', 'Oh', 'you', 'Oh.', 'Uh', 'Oh,', 'Mh-hmm', 'Hmm.', 'going to.', 'not.'] for hallucination in common_hallucinations: text = text.replace(hallucination, "") return text - def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000): + def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000) -> str: + """Transcribe the audio data.""" if audio_data.dtype != np.float32: audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max if len(audio_data.shape) > 1: @@ -106,7 +117,10 @@ class Transcript: return self.remove_hallucinations(result["text"]) class AudioTranscriber: - def __init__(self, ai_name: str, verbose=False): + """ + AudioTranscriber is a class that transcribes audio from the audio queue and adds it to the transcript. + """ + def __init__(self, ai_name: str, verbose: bool = False): self.verbose = verbose self.ai_name = ai_name self.transcriptor = Transcript() @@ -126,14 +140,17 @@ class AudioTranscriber: } self.recorded = "" - def get_transcript(self): + def get_transcript(self) -> str: global done buffer = self.recorded self.recorded = "" done = False return buffer - def _transcribe(self): + def _transcribe(self) -> None: + """ + Transcribe the audio data using AI stt model. + """ global done if self.verbose: print(Fore.BLUE + "AudioTranscriber: Started processing..." + Fore.RESET) diff --git a/sources/text_to_speech.py b/sources/text_to_speech.py index 0a90742..13ba1a4 100644 --- a/sources/text_to_speech.py +++ b/sources/text_to_speech.py @@ -9,7 +9,7 @@ class Speech(): """ Speech is a class for generating speech from text. """ - def __init__(self, language = "english") -> None: + def __init__(self, language: str = "english") -> None: self.lang_map = { "english": 'a', "chinese": 'z', @@ -24,9 +24,13 @@ class Speech(): self.voice = self.voice_map[language][2] self.speed = 1.2 - def speak(self, sentence, voice_number = 1): + def speak(self, sentence: str, voice_number: int = 1): """ - Use AI model to generate speech from text after pre-processing the text. + Convert text to speech using an AI model and play the audio. + + Args: + sentence (str): The text to convert to speech. Will be pre-processed. + voice_number (int, optional): Index of the voice to use from the voice map. """ sentence = self.clean_sentence(sentence) self.voice = self.voice_map["english"][voice_number] @@ -45,18 +49,26 @@ class Speech(): import winsound winsound.PlaySound(audio_file, winsound.SND_FILENAME) - def replace_url(self, m): + def replace_url(self, url: re.Match) -> str: """ - Replace URL with empty string. + Replace URL with domain name or empty string if IP address. + Args: + url (re.Match): Match object containing the URL pattern match + Returns: + str: The domain name from the URL, or empty string if IP address """ - domain = m.group(1) + domain = url.group(1) if re.match(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', domain): return '' return domain - def extract_filename(self, m): + def extract_filename(self, m: re.Match) -> str: """ Extract filename from path. + Args: + m (re.Match): Match object containing the path pattern match + Returns: + str: The filename from the path """ path = m.group() parts = re.split(r'/|\\', path) @@ -65,6 +77,10 @@ class Speech(): def shorten_paragraph(self, sentence): """ Shorten paragraph like **explaination**: by keeping only the first sentence. + Args: + sentence (str): The sentence to shorten + Returns: + str: The shortened sentence """ lines = sentence.split('\n') lines_edited = [] @@ -77,7 +93,11 @@ class Speech(): def clean_sentence(self, sentence): """ - Clean sentence by removing URLs, filenames, and other non-alphanumeric characters. + Clean and normalize text for speech synthesis by removing technical elements. + Args: + sentence (str): The input text to clean + Returns: + str: The cleaned text with URLs replaced by domain names, code blocks removed, etc.. """ lines = sentence.split('\n') filtered_lines = [line for line in lines if re.match(r'^\s*[a-zA-Z]', line)] diff --git a/sources/utility.py b/sources/utility.py index 6445c5d..6f053c9 100644 --- a/sources/utility.py +++ b/sources/utility.py @@ -6,7 +6,19 @@ import platform def pretty_print(text, color = "info"): """ - print text with color + Print text with color formatting. + + Args: + text (str): The text to print + color (str, optional): The color to use. Defaults to "info". + Valid colors are: + - "success": Green + - "failure": Red + - "status": Light green + - "code": Light blue + - "warning": Yellow + - "output": Cyan + - "default": Black (Windows only) """ if platform.system().lower() != "windows": color_map = { @@ -37,6 +49,13 @@ def pretty_print(text, color = "info"): print(colored(text, color_map[color])) def timer_decorator(func): + """ + Decorator to measure the execution time of a function. + Usage: + @timer_decorator + def my_function(): + # code to execute + """ from time import time def wrapper(*args, **kwargs): start_time = time()