Docs: better documentation

This commit is contained in:
martin legrand 2025-03-06 11:58:54 +01:00
parent ff1af3b6a9
commit eca688baba
6 changed files with 128 additions and 29 deletions

View File

@ -5,6 +5,9 @@ from sources.router import AgentRouter
from sources.speech_to_text import AudioTranscriber, AudioRecorder from sources.speech_to_text import AudioTranscriber, AudioRecorder
class Interaction: class Interaction:
"""
Interaction is a class that handles the interaction between the user and the agents.
"""
def __init__(self, agents, def __init__(self, agents,
tts_enabled: bool = True, tts_enabled: bool = True,
stt_enabled: bool = True, stt_enabled: bool = True,
@ -29,6 +32,7 @@ class Interaction:
self.recover_last_session() self.recover_last_session()
def find_ai_name(self) -> str: def find_ai_name(self) -> str:
"""Find the name of the default AI. It is required for STT as a trigger word."""
ai_name = "jarvis" ai_name = "jarvis"
for agent in self.agents: for agent in self.agents:
if agent.role == "talking": if agent.role == "talking":
@ -37,17 +41,20 @@ class Interaction:
return ai_name return ai_name
def recover_last_session(self): def recover_last_session(self):
"""Recover the last session."""
for agent in self.agents: for agent in self.agents:
agent.memory.load_memory() agent.memory.load_memory()
def save_session(self): def save_session(self):
"""Save the current session."""
for agent in self.agents: for agent in self.agents:
agent.memory.save_memory() agent.memory.save_memory()
def is_active(self): def is_active(self) -> bool:
return self.is_active return self.is_active
def read_stdin(self) -> str: def read_stdin(self) -> str:
"""Read the input from the user."""
buffer = "" buffer = ""
while buffer == "" or buffer.isascii() == False: while buffer == "" or buffer.isascii() == False:
@ -59,7 +66,8 @@ class Interaction:
return None return None
return buffer return buffer
def transcription_job(self): def transcription_job(self) -> str:
"""Transcribe the audio from the microphone."""
self.recorder = AudioRecorder(verbose=True) self.recorder = AudioRecorder(verbose=True)
self.transcriber = AudioTranscriber(self.ai_name, verbose=True) self.transcriber = AudioTranscriber(self.ai_name, verbose=True)
self.transcriber.start() self.transcriber.start()
@ -69,7 +77,8 @@ class Interaction:
query = self.transcriber.get_transcript() query = self.transcriber.get_transcript()
return query return query
def get_user(self): def get_user_input(self) -> str:
"""Get the user input from the microphone or the keyboard."""
if self.stt_enabled: if self.stt_enabled:
query = "TTS transcription of user: " + self.transcription_job() query = "TTS transcription of user: " + self.transcription_job()
else: else:
@ -81,7 +90,8 @@ class Interaction:
self.last_query = query self.last_query = query
return query return query
def think(self): def think(self) -> None:
"""Request AI agents to process the user input."""
if self.last_query is None or len(self.last_query) == 0: if self.last_query is None or len(self.last_query) == 0:
return return
agent = self.router.select_agent(self.last_query) agent = self.router.select_agent(self.last_query)
@ -93,7 +103,8 @@ class Interaction:
self.current_agent.memory.push('user', self.last_query) self.current_agent.memory.push('user', self.last_query)
self.last_answer, _ = agent.process(self.last_query, self.speech) self.last_answer, _ = agent.process(self.last_query, self.speech)
def show_answer(self): def show_answer(self) -> None:
"""Show the answer to the user."""
if self.last_query is None: if self.last_query is None:
return return
self.current_agent.show_answer() self.current_agent.show_answer()

View File

@ -39,6 +39,7 @@ class Memory():
return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt" return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt"
def save_memory(self) -> None: def save_memory(self) -> None:
"""Save the session memory to a file."""
if not os.path.exists(self.conversation_folder): if not os.path.exists(self.conversation_folder):
os.makedirs(self.conversation_folder) os.makedirs(self.conversation_folder)
filename = self.get_filename() filename = self.get_filename()
@ -48,6 +49,7 @@ class Memory():
f.write(json_memory) f.write(json_memory)
def find_last_session_path(self) -> str: def find_last_session_path(self) -> str:
"""Find the last session path."""
saved_sessions = [] saved_sessions = []
for filename in os.listdir(self.conversation_folder): for filename in os.listdir(self.conversation_folder):
if filename.startswith('memory_'): if filename.startswith('memory_'):
@ -59,6 +61,7 @@ class Memory():
return None return None
def load_memory(self) -> None: def load_memory(self) -> None:
"""Load the memory from the last session."""
if not os.path.exists(self.conversation_folder): if not os.path.exists(self.conversation_folder):
return return
filename = self.find_last_session_path() filename = self.find_last_session_path()
@ -72,6 +75,7 @@ class Memory():
self.memory = memory self.memory = memory
def push(self, role: str, content: str) -> None: def push(self, role: str, content: str) -> None:
"""Push a message to the memory."""
self.memory.append({'role': role, 'content': content}) self.memory.append({'role': role, 'content': content})
# EXPERIMENTAL # EXPERIMENTAL
if self.memory_compression and role == 'assistant': if self.memory_compression and role == 'assistant':
@ -92,6 +96,14 @@ class Memory():
return "cpu" return "cpu"
def summarize(self, text: str, min_length: int = 64) -> str: def summarize(self, text: str, min_length: int = 64) -> str:
"""
Summarize the text using the AI model.
Args:
text (str): The text to summarize
min_length (int, optional): The minimum length of the summary. Defaults to 64.
Returns:
str: The summarized text
"""
if self.tokenizer is None or self.model is None: if self.tokenizer is None or self.model is None:
return text return text
max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2 max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2
@ -110,6 +122,9 @@ class Memory():
@timer_decorator @timer_decorator
def compress(self) -> str: def compress(self) -> str:
"""
Compress the memory using the AI model.
"""
if not self.memory_compression: if not self.memory_compression:
return return
for i in range(len(self.memory)): for i in range(len(self.memory)):

View File

@ -6,14 +6,17 @@ from sources.casual_agent import CasualAgent
from sources.utility import pretty_print from sources.utility import pretty_print
class AgentRouter: class AgentRouter:
def __init__(self, agents: list, model_name="facebook/bart-large-mnli"): """
AgentRouter is a class that selects the appropriate agent based on the user query.
"""
def __init__(self, agents: list, model_name: str = "facebook/bart-large-mnli"):
self.model = model_name self.model = model_name
self.pipeline = pipeline("zero-shot-classification", self.pipeline = pipeline("zero-shot-classification",
model=self.model) model=self.model)
self.agents = agents self.agents = agents
self.labels = [agent.role for agent in agents] self.labels = [agent.role for agent in agents]
def get_device(self): def get_device(self) -> str:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
return "mps" return "mps"
elif torch.cuda.is_available(): elif torch.cuda.is_available():
@ -21,10 +24,17 @@ class AgentRouter:
else: else:
return "cpu" return "cpu"
def classify_text(self, text, threshold=0.5): def classify_text(self, text: str, threshold: float = 0.5) -> list:
"""
Classify the text into labels (agent roles).
Args:
text (str): The text to classify
threshold (float, optional): The threshold for the classification.
Returns:
list: The list of agents and their scores
"""
first_sentence = None first_sentence = None
for line in text.split("\n"): for line in text.split("\n"):
if line.strip() != "":
first_sentence = line.strip() first_sentence = line.strip()
break break
if first_sentence is None: if first_sentence is None:
@ -33,6 +43,13 @@ class AgentRouter:
return result return result
def select_agent(self, text: str) -> Agent: def select_agent(self, text: str) -> Agent:
"""
Select the appropriate agent based on the text.
Args:
text (str): The text to select the agent from
Returns:
Agent: The selected agent
"""
if len(self.agents) == 0 or len(self.labels) == 0: if len(self.agents) == 0 or len(self.labels) == 0:
return self.agents[0] return self.agents[0]
result = self.classify_text(text) result = self.classify_text(text)

View File

@ -12,7 +12,10 @@ audio_queue = queue.Queue()
done = False done = False
class AudioRecorder: class AudioRecorder:
def __init__(self, format=pyaudio.paInt16, channels=1, rate=4096, chunk=8192, record_seconds=5, verbose=False): """
AudioRecorder is a class that records audio from the microphone and adds it to the audio queue.
"""
def __init__(self, format: int = pyaudio.paInt16, channels: int = 1, rate: int = 4096, chunk: int = 8192, record_seconds: int = 5, verbose: bool = False):
self.format = format self.format = format
self.channels = channels self.channels = channels
self.rate = rate self.rate = rate
@ -22,7 +25,10 @@ class AudioRecorder:
self.audio = pyaudio.PyAudio() self.audio = pyaudio.PyAudio()
self.thread = threading.Thread(target=self._record, daemon=True) self.thread = threading.Thread(target=self._record, daemon=True)
def _record(self): def _record(self) -> None:
"""
Record audio from the microphone and add it to the audio queue.
"""
stream = self.audio.open(format=self.format, channels=self.channels, rate=self.rate, stream = self.audio.open(format=self.format, channels=self.channels, rate=self.rate,
input=True, frames_per_buffer=self.chunk) input=True, frames_per_buffer=self.chunk)
if self.verbose: if self.verbose:
@ -49,16 +55,19 @@ class AudioRecorder:
if self.verbose: if self.verbose:
print(Fore.GREEN + "AudioRecorder: Stopped" + Fore.RESET) print(Fore.GREEN + "AudioRecorder: Stopped" + Fore.RESET)
def start(self): def start(self) -> None:
"""Start the recording thread.""" """Start the recording thread."""
self.thread.start() self.thread.start()
def join(self): def join(self) -> None:
"""Wait for the recording thread to finish.""" """Wait for the recording thread to finish."""
self.thread.join() self.thread.join()
class Transcript: class Transcript:
def __init__(self) -> None: """
Transcript is a class that transcribes audio from the audio queue and adds it to the transcript.
"""
def __init__(self):
self.last_read = None self.last_read = None
device = self.get_device() device = self.get_device()
torch_dtype = torch.float16 if device == "cuda" else torch.float32 torch_dtype = torch.float16 if device == "cuda" else torch.float32
@ -80,7 +89,7 @@ class Transcript:
device=device, device=device,
) )
def get_device(self): def get_device(self) -> str:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
return "mps" return "mps"
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -88,14 +97,16 @@ class Transcript:
else: else:
return "cpu" return "cpu"
def remove_hallucinations(self, text: str): def remove_hallucinations(self, text: str) -> str:
"""Remove model hallucinations from the text."""
# TODO find a better way to do this # TODO find a better way to do this
common_hallucinations = ['Okay.', 'Thank you.', 'Thank you for watching.', 'You\'re', 'Oh', 'you', 'Oh.', 'Uh', 'Oh,', 'Mh-hmm', 'Hmm.', 'going to.', 'not.'] common_hallucinations = ['Okay.', 'Thank you.', 'Thank you for watching.', 'You\'re', 'Oh', 'you', 'Oh.', 'Uh', 'Oh,', 'Mh-hmm', 'Hmm.', 'going to.', 'not.']
for hallucination in common_hallucinations: for hallucination in common_hallucinations:
text = text.replace(hallucination, "") text = text.replace(hallucination, "")
return text return text
def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000): def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000) -> str:
"""Transcribe the audio data."""
if audio_data.dtype != np.float32: if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
if len(audio_data.shape) > 1: if len(audio_data.shape) > 1:
@ -106,7 +117,10 @@ class Transcript:
return self.remove_hallucinations(result["text"]) return self.remove_hallucinations(result["text"])
class AudioTranscriber: class AudioTranscriber:
def __init__(self, ai_name: str, verbose=False): """
AudioTranscriber is a class that transcribes audio from the audio queue and adds it to the transcript.
"""
def __init__(self, ai_name: str, verbose: bool = False):
self.verbose = verbose self.verbose = verbose
self.ai_name = ai_name self.ai_name = ai_name
self.transcriptor = Transcript() self.transcriptor = Transcript()
@ -126,14 +140,17 @@ class AudioTranscriber:
} }
self.recorded = "" self.recorded = ""
def get_transcript(self): def get_transcript(self) -> str:
global done global done
buffer = self.recorded buffer = self.recorded
self.recorded = "" self.recorded = ""
done = False done = False
return buffer return buffer
def _transcribe(self): def _transcribe(self) -> None:
"""
Transcribe the audio data using AI stt model.
"""
global done global done
if self.verbose: if self.verbose:
print(Fore.BLUE + "AudioTranscriber: Started processing..." + Fore.RESET) print(Fore.BLUE + "AudioTranscriber: Started processing..." + Fore.RESET)

View File

@ -9,7 +9,7 @@ class Speech():
""" """
Speech is a class for generating speech from text. Speech is a class for generating speech from text.
""" """
def __init__(self, language = "english") -> None: def __init__(self, language: str = "english") -> None:
self.lang_map = { self.lang_map = {
"english": 'a', "english": 'a',
"chinese": 'z', "chinese": 'z',
@ -24,9 +24,13 @@ class Speech():
self.voice = self.voice_map[language][2] self.voice = self.voice_map[language][2]
self.speed = 1.2 self.speed = 1.2
def speak(self, sentence, voice_number = 1): def speak(self, sentence: str, voice_number: int = 1):
""" """
Use AI model to generate speech from text after pre-processing the text. Convert text to speech using an AI model and play the audio.
Args:
sentence (str): The text to convert to speech. Will be pre-processed.
voice_number (int, optional): Index of the voice to use from the voice map.
""" """
sentence = self.clean_sentence(sentence) sentence = self.clean_sentence(sentence)
self.voice = self.voice_map["english"][voice_number] self.voice = self.voice_map["english"][voice_number]
@ -45,18 +49,26 @@ class Speech():
import winsound import winsound
winsound.PlaySound(audio_file, winsound.SND_FILENAME) winsound.PlaySound(audio_file, winsound.SND_FILENAME)
def replace_url(self, m): def replace_url(self, url: re.Match) -> str:
""" """
Replace URL with empty string. Replace URL with domain name or empty string if IP address.
Args:
url (re.Match): Match object containing the URL pattern match
Returns:
str: The domain name from the URL, or empty string if IP address
""" """
domain = m.group(1) domain = url.group(1)
if re.match(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', domain): if re.match(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', domain):
return '' return ''
return domain return domain
def extract_filename(self, m): def extract_filename(self, m: re.Match) -> str:
""" """
Extract filename from path. Extract filename from path.
Args:
m (re.Match): Match object containing the path pattern match
Returns:
str: The filename from the path
""" """
path = m.group() path = m.group()
parts = re.split(r'/|\\', path) parts = re.split(r'/|\\', path)
@ -65,6 +77,10 @@ class Speech():
def shorten_paragraph(self, sentence): def shorten_paragraph(self, sentence):
""" """
Shorten paragraph like **explaination**: <long text> by keeping only the first sentence. Shorten paragraph like **explaination**: <long text> by keeping only the first sentence.
Args:
sentence (str): The sentence to shorten
Returns:
str: The shortened sentence
""" """
lines = sentence.split('\n') lines = sentence.split('\n')
lines_edited = [] lines_edited = []
@ -77,7 +93,11 @@ class Speech():
def clean_sentence(self, sentence): def clean_sentence(self, sentence):
""" """
Clean sentence by removing URLs, filenames, and other non-alphanumeric characters. Clean and normalize text for speech synthesis by removing technical elements.
Args:
sentence (str): The input text to clean
Returns:
str: The cleaned text with URLs replaced by domain names, code blocks removed, etc..
""" """
lines = sentence.split('\n') lines = sentence.split('\n')
filtered_lines = [line for line in lines if re.match(r'^\s*[a-zA-Z]', line)] filtered_lines = [line for line in lines if re.match(r'^\s*[a-zA-Z]', line)]

View File

@ -6,7 +6,19 @@ import platform
def pretty_print(text, color = "info"): def pretty_print(text, color = "info"):
""" """
print text with color Print text with color formatting.
Args:
text (str): The text to print
color (str, optional): The color to use. Defaults to "info".
Valid colors are:
- "success": Green
- "failure": Red
- "status": Light green
- "code": Light blue
- "warning": Yellow
- "output": Cyan
- "default": Black (Windows only)
""" """
if platform.system().lower() != "windows": if platform.system().lower() != "windows":
color_map = { color_map = {
@ -37,6 +49,13 @@ def pretty_print(text, color = "info"):
print(colored(text, color_map[color])) print(colored(text, color_map[color]))
def timer_decorator(func): def timer_decorator(func):
"""
Decorator to measure the execution time of a function.
Usage:
@timer_decorator
def my_function():
# code to execute
"""
from time import time from time import time
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start_time = time() start_time = time()