Docs: better documentation

This commit is contained in:
martin legrand 2025-03-06 11:58:54 +01:00
parent ff1af3b6a9
commit eca688baba
6 changed files with 128 additions and 29 deletions

View File

@ -5,6 +5,9 @@ from sources.router import AgentRouter
from sources.speech_to_text import AudioTranscriber, AudioRecorder
class Interaction:
"""
Interaction is a class that handles the interaction between the user and the agents.
"""
def __init__(self, agents,
tts_enabled: bool = True,
stt_enabled: bool = True,
@ -29,6 +32,7 @@ class Interaction:
self.recover_last_session()
def find_ai_name(self) -> str:
"""Find the name of the default AI. It is required for STT as a trigger word."""
ai_name = "jarvis"
for agent in self.agents:
if agent.role == "talking":
@ -37,17 +41,20 @@ class Interaction:
return ai_name
def recover_last_session(self):
"""Recover the last session."""
for agent in self.agents:
agent.memory.load_memory()
def save_session(self):
"""Save the current session."""
for agent in self.agents:
agent.memory.save_memory()
def is_active(self):
def is_active(self) -> bool:
return self.is_active
def read_stdin(self) -> str:
"""Read the input from the user."""
buffer = ""
while buffer == "" or buffer.isascii() == False:
@ -59,7 +66,8 @@ class Interaction:
return None
return buffer
def transcription_job(self):
def transcription_job(self) -> str:
"""Transcribe the audio from the microphone."""
self.recorder = AudioRecorder(verbose=True)
self.transcriber = AudioTranscriber(self.ai_name, verbose=True)
self.transcriber.start()
@ -69,7 +77,8 @@ class Interaction:
query = self.transcriber.get_transcript()
return query
def get_user(self):
def get_user_input(self) -> str:
"""Get the user input from the microphone or the keyboard."""
if self.stt_enabled:
query = "TTS transcription of user: " + self.transcription_job()
else:
@ -81,7 +90,8 @@ class Interaction:
self.last_query = query
return query
def think(self):
def think(self) -> None:
"""Request AI agents to process the user input."""
if self.last_query is None or len(self.last_query) == 0:
return
agent = self.router.select_agent(self.last_query)
@ -93,7 +103,8 @@ class Interaction:
self.current_agent.memory.push('user', self.last_query)
self.last_answer, _ = agent.process(self.last_query, self.speech)
def show_answer(self):
def show_answer(self) -> None:
"""Show the answer to the user."""
if self.last_query is None:
return
self.current_agent.show_answer()

View File

@ -39,6 +39,7 @@ class Memory():
return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt"
def save_memory(self) -> None:
"""Save the session memory to a file."""
if not os.path.exists(self.conversation_folder):
os.makedirs(self.conversation_folder)
filename = self.get_filename()
@ -48,6 +49,7 @@ class Memory():
f.write(json_memory)
def find_last_session_path(self) -> str:
"""Find the last session path."""
saved_sessions = []
for filename in os.listdir(self.conversation_folder):
if filename.startswith('memory_'):
@ -59,6 +61,7 @@ class Memory():
return None
def load_memory(self) -> None:
"""Load the memory from the last session."""
if not os.path.exists(self.conversation_folder):
return
filename = self.find_last_session_path()
@ -72,6 +75,7 @@ class Memory():
self.memory = memory
def push(self, role: str, content: str) -> None:
"""Push a message to the memory."""
self.memory.append({'role': role, 'content': content})
# EXPERIMENTAL
if self.memory_compression and role == 'assistant':
@ -92,6 +96,14 @@ class Memory():
return "cpu"
def summarize(self, text: str, min_length: int = 64) -> str:
"""
Summarize the text using the AI model.
Args:
text (str): The text to summarize
min_length (int, optional): The minimum length of the summary. Defaults to 64.
Returns:
str: The summarized text
"""
if self.tokenizer is None or self.model is None:
return text
max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2
@ -110,6 +122,9 @@ class Memory():
@timer_decorator
def compress(self) -> str:
"""
Compress the memory using the AI model.
"""
if not self.memory_compression:
return
for i in range(len(self.memory)):

View File

@ -6,14 +6,17 @@ from sources.casual_agent import CasualAgent
from sources.utility import pretty_print
class AgentRouter:
def __init__(self, agents: list, model_name="facebook/bart-large-mnli"):
"""
AgentRouter is a class that selects the appropriate agent based on the user query.
"""
def __init__(self, agents: list, model_name: str = "facebook/bart-large-mnli"):
self.model = model_name
self.pipeline = pipeline("zero-shot-classification",
model=self.model)
self.agents = agents
self.labels = [agent.role for agent in agents]
def get_device(self):
def get_device(self) -> str:
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
@ -21,10 +24,17 @@ class AgentRouter:
else:
return "cpu"
def classify_text(self, text, threshold=0.5):
def classify_text(self, text: str, threshold: float = 0.5) -> list:
"""
Classify the text into labels (agent roles).
Args:
text (str): The text to classify
threshold (float, optional): The threshold for the classification.
Returns:
list: The list of agents and their scores
"""
first_sentence = None
for line in text.split("\n"):
if line.strip() != "":
first_sentence = line.strip()
break
if first_sentence is None:
@ -33,6 +43,13 @@ class AgentRouter:
return result
def select_agent(self, text: str) -> Agent:
"""
Select the appropriate agent based on the text.
Args:
text (str): The text to select the agent from
Returns:
Agent: The selected agent
"""
if len(self.agents) == 0 or len(self.labels) == 0:
return self.agents[0]
result = self.classify_text(text)

View File

@ -12,7 +12,10 @@ audio_queue = queue.Queue()
done = False
class AudioRecorder:
def __init__(self, format=pyaudio.paInt16, channels=1, rate=4096, chunk=8192, record_seconds=5, verbose=False):
"""
AudioRecorder is a class that records audio from the microphone and adds it to the audio queue.
"""
def __init__(self, format: int = pyaudio.paInt16, channels: int = 1, rate: int = 4096, chunk: int = 8192, record_seconds: int = 5, verbose: bool = False):
self.format = format
self.channels = channels
self.rate = rate
@ -22,7 +25,10 @@ class AudioRecorder:
self.audio = pyaudio.PyAudio()
self.thread = threading.Thread(target=self._record, daemon=True)
def _record(self):
def _record(self) -> None:
"""
Record audio from the microphone and add it to the audio queue.
"""
stream = self.audio.open(format=self.format, channels=self.channels, rate=self.rate,
input=True, frames_per_buffer=self.chunk)
if self.verbose:
@ -49,16 +55,19 @@ class AudioRecorder:
if self.verbose:
print(Fore.GREEN + "AudioRecorder: Stopped" + Fore.RESET)
def start(self):
def start(self) -> None:
"""Start the recording thread."""
self.thread.start()
def join(self):
def join(self) -> None:
"""Wait for the recording thread to finish."""
self.thread.join()
class Transcript:
def __init__(self) -> None:
"""
Transcript is a class that transcribes audio from the audio queue and adds it to the transcript.
"""
def __init__(self):
self.last_read = None
device = self.get_device()
torch_dtype = torch.float16 if device == "cuda" else torch.float32
@ -80,7 +89,7 @@ class Transcript:
device=device,
)
def get_device(self):
def get_device(self) -> str:
if torch.backends.mps.is_available():
return "mps"
if torch.cuda.is_available():
@ -88,14 +97,16 @@ class Transcript:
else:
return "cpu"
def remove_hallucinations(self, text: str):
def remove_hallucinations(self, text: str) -> str:
"""Remove model hallucinations from the text."""
# TODO find a better way to do this
common_hallucinations = ['Okay.', 'Thank you.', 'Thank you for watching.', 'You\'re', 'Oh', 'you', 'Oh.', 'Uh', 'Oh,', 'Mh-hmm', 'Hmm.', 'going to.', 'not.']
for hallucination in common_hallucinations:
text = text.replace(hallucination, "")
return text
def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000):
def transcript_job(self, audio_data: np.ndarray, sample_rate: int = 16000) -> str:
"""Transcribe the audio data."""
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
if len(audio_data.shape) > 1:
@ -106,7 +117,10 @@ class Transcript:
return self.remove_hallucinations(result["text"])
class AudioTranscriber:
def __init__(self, ai_name: str, verbose=False):
"""
AudioTranscriber is a class that transcribes audio from the audio queue and adds it to the transcript.
"""
def __init__(self, ai_name: str, verbose: bool = False):
self.verbose = verbose
self.ai_name = ai_name
self.transcriptor = Transcript()
@ -126,14 +140,17 @@ class AudioTranscriber:
}
self.recorded = ""
def get_transcript(self):
def get_transcript(self) -> str:
global done
buffer = self.recorded
self.recorded = ""
done = False
return buffer
def _transcribe(self):
def _transcribe(self) -> None:
"""
Transcribe the audio data using AI stt model.
"""
global done
if self.verbose:
print(Fore.BLUE + "AudioTranscriber: Started processing..." + Fore.RESET)

View File

@ -9,7 +9,7 @@ class Speech():
"""
Speech is a class for generating speech from text.
"""
def __init__(self, language = "english") -> None:
def __init__(self, language: str = "english") -> None:
self.lang_map = {
"english": 'a',
"chinese": 'z',
@ -24,9 +24,13 @@ class Speech():
self.voice = self.voice_map[language][2]
self.speed = 1.2
def speak(self, sentence, voice_number = 1):
def speak(self, sentence: str, voice_number: int = 1):
"""
Use AI model to generate speech from text after pre-processing the text.
Convert text to speech using an AI model and play the audio.
Args:
sentence (str): The text to convert to speech. Will be pre-processed.
voice_number (int, optional): Index of the voice to use from the voice map.
"""
sentence = self.clean_sentence(sentence)
self.voice = self.voice_map["english"][voice_number]
@ -45,18 +49,26 @@ class Speech():
import winsound
winsound.PlaySound(audio_file, winsound.SND_FILENAME)
def replace_url(self, m):
def replace_url(self, url: re.Match) -> str:
"""
Replace URL with empty string.
Replace URL with domain name or empty string if IP address.
Args:
url (re.Match): Match object containing the URL pattern match
Returns:
str: The domain name from the URL, or empty string if IP address
"""
domain = m.group(1)
domain = url.group(1)
if re.match(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', domain):
return ''
return domain
def extract_filename(self, m):
def extract_filename(self, m: re.Match) -> str:
"""
Extract filename from path.
Args:
m (re.Match): Match object containing the path pattern match
Returns:
str: The filename from the path
"""
path = m.group()
parts = re.split(r'/|\\', path)
@ -65,6 +77,10 @@ class Speech():
def shorten_paragraph(self, sentence):
"""
Shorten paragraph like **explaination**: <long text> by keeping only the first sentence.
Args:
sentence (str): The sentence to shorten
Returns:
str: The shortened sentence
"""
lines = sentence.split('\n')
lines_edited = []
@ -77,7 +93,11 @@ class Speech():
def clean_sentence(self, sentence):
"""
Clean sentence by removing URLs, filenames, and other non-alphanumeric characters.
Clean and normalize text for speech synthesis by removing technical elements.
Args:
sentence (str): The input text to clean
Returns:
str: The cleaned text with URLs replaced by domain names, code blocks removed, etc..
"""
lines = sentence.split('\n')
filtered_lines = [line for line in lines if re.match(r'^\s*[a-zA-Z]', line)]

View File

@ -6,7 +6,19 @@ import platform
def pretty_print(text, color = "info"):
"""
print text with color
Print text with color formatting.
Args:
text (str): The text to print
color (str, optional): The color to use. Defaults to "info".
Valid colors are:
- "success": Green
- "failure": Red
- "status": Light green
- "code": Light blue
- "warning": Yellow
- "output": Cyan
- "default": Black (Windows only)
"""
if platform.system().lower() != "windows":
color_map = {
@ -37,6 +49,13 @@ def pretty_print(text, color = "info"):
print(colored(text, color_map[color]))
def timer_decorator(func):
"""
Decorator to measure the execution time of a function.
Usage:
@timer_decorator
def my_function():
# code to execute
"""
from time import time
def wrapper(*args, **kwargs):
start_time = time()