Feat : agent router system + casual agent

This commit is contained in:
martin legrand 2025-03-02 15:21:36 +01:00
parent dcb5724e28
commit ca49a7f0a5
9 changed files with 164 additions and 41 deletions

2
.gitignore vendored
View File

@ -1,6 +1,8 @@
*.wav
config.ini
experimental/
.env
*/.env
# Byte-compiled / optimized / DLL files

16
main.py
View File

@ -8,11 +8,11 @@ import configparser
from sources.llm_provider import Provider
from sources.interaction import Interaction
from sources.code_agent import CoderAgent
from sources.casual_agent import CasualAgent
parser = argparse.ArgumentParser(description='Deepseek AI assistant')
parser.add_argument('--speak', action='store_true',
help='Make AI use text-to-speech')
parser.add_argument('--no-speak', action='store_true',
help='Make AI not use text-to-speech')
args = parser.parse_args()
config = configparser.ConfigParser()
@ -31,12 +31,18 @@ def main():
model=config["MAIN"]["provider_model"],
server_address=config["MAIN"]["provider_server_address"])
agent = CoderAgent(model=config["MAIN"]["provider_model"],
agents = [
CoderAgent(model=config["MAIN"]["provider_model"],
name=config["MAIN"]["agent_name"],
prompt_path="prompts/coder_agent.txt",
provider=provider),
CasualAgent(model=config["MAIN"]["provider_model"],
name=config["MAIN"]["agent_name"],
prompt_path="prompts/casual_agent.txt",
provider=provider)
]
interaction = Interaction([agent], tts_enabled=config.getboolean('MAIN', 'speak'),
interaction = Interaction(agents, tts_enabled=config.getboolean('MAIN', 'speak'),
recover_last_session=config.getboolean('MAIN', 'recover_last_session'))
while interaction.is_active:
interaction.get_user()

View File

@ -12,8 +12,8 @@ kokoro==0.7.12
flask==3.1.0
soundfile==0.13.1
protobuf==3.20.3
termcolor==2.3.0
openai==1.13.3
termcolor==2.5.0
gliclass==0.1.8
# if use chinese
ordered_set
pypinyin

View File

@ -2,6 +2,10 @@ from typing import Tuple, Callable
from abc import abstractmethod
import os
import random
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sources.memory import Memory
from sources.utility import pretty_print
@ -25,6 +29,7 @@ class Agent():
provider,
recover_last_session=False) -> None:
self.agent_name = name
self.role = None
self.current_directory = os.getcwd()
self.model = model
self.llm = provider
@ -60,7 +65,14 @@ class Agent():
raise e
@abstractmethod
def answer(self, prompt, speech_module) -> str:
def show_answer(self):
"""
abstract method, implementation in child class.
"""
pass
@abstractmethod
def process(self, prompt, speech_module) -> str:
"""
abstract method, implementation in child class.
"""
@ -78,7 +90,7 @@ class Agent():
end_idx = text.rfind(end_tag)+8
return text[start_idx:end_idx]
def llm_request(self, verbose = True) -> Tuple[str, str]:
def llm_request(self, verbose = False) -> Tuple[str, str]:
memory = self.memory.get()
thought = self.llm.respond(memory, verbose)
@ -95,16 +107,30 @@ class Agent():
"Working on it sir, please let me think."]
speech_module.speak(messages[random.randint(0, len(messages)-1)])
def print_code_blocks(self, blocks: list, name: str):
for block in blocks:
pretty_print(f"Executing {name} code...\n", color="output")
pretty_print("-"*100, color="output")
pretty_print(block, color="code")
pretty_print("-"*100, color="output")
def get_blocks_result(self) -> list:
return self.blocks_result
def remove_blocks(self, text: str) -> str:
"""
Remove all code/query blocks within a tag from the answer text.
"""
tag = f'```'
lines = text.split('\n')
post_lines = []
in_block = False
block_idx = 0
for line in lines:
if tag in line and not in_block:
in_block = True
continue
if not in_block:
post_lines.append(line)
if tag in line:
in_block = False
post_lines.append(f"block:{block_idx}")
block_idx += 1
return "\n".join(post_lines)
def execute_modules(self, answer: str) -> Tuple[bool, str]:
feedback = ""
success = False

36
sources/casual_agent.py Normal file
View File

@ -0,0 +1,36 @@
from sources.utility import pretty_print
from sources.agent import Agent
class CasualAgent(Agent):
def __init__(self, model, name, prompt_path, provider):
"""
The casual agent is a special for casual talk to the user without specific tasks.
"""
super().__init__(model, name, prompt_path, provider)
self.tools = {
} # TODO implement casual tools like basic web search, basic file search, basic image search, basic knowledge search
self.role = "talking"
def show_answer(self):
lines = self.last_answer.split("\n")
for line in lines:
pretty_print(line, color="output")
def process(self, prompt, speech_module) -> str:
self.memory.push('user', prompt)
pretty_print("Thinking...", color="status")
self.wait_message(speech_module)
answer, reasoning = self.llm_request()
self.last_answer = answer
return answer, reasoning
if __name__ == "__main__":
from llm_provider import Provider
#local_provider = Provider("ollama", "deepseek-r1:14b", None)
server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
agent = CasualAgent("deepseek-r1:14b", "jarvis", "prompts/casual_agent.txt", server_provider)
ans = agent.process("Hello, how are you?")
print(ans)

View File

@ -4,36 +4,17 @@ from sources.agent import Agent, executorResult
from sources.tools import PyInterpreter, BashInterpreter, CInterpreter, GoInterpreter
class CoderAgent(Agent):
"""
The code agent is a special for writing code and shell commands.
"""
def __init__(self, model, name, prompt_path, provider):
super().__init__(model, name, prompt_path, provider)
self.tools = {
"bash": BashInterpreter(),
"python": PyInterpreter()
"C": CInterpreter(),
"go": GoInterpreter()
}
self.role = "coding"
def remove_blocks(self, text: str) -> str:
"""
Remove all code/query blocks within a tag from the answer text.
"""
tag = f'```'
lines = text.split('\n')
post_lines = []
in_block = False
block_idx = 0
for line in lines:
if tag in line and not in_block:
in_block = True
continue
if not in_block:
post_lines.append(line)
if tag in line:
in_block = False
post_lines.append(f"block:{block_idx}")
block_idx += 1
return "\n".join(post_lines)
def show_answer(self):
lines = self.last_answer.split("\n")
for line in lines:

View File

@ -1,11 +1,14 @@
from sources.text_to_speech import Speech
from sources.utility import pretty_print
from sources.router import AgentRouter
class Interaction:
def __init__(self, agents, tts_enabled: bool = False, recover_last_session: bool = False):
self.tts_enabled = tts_enabled
self.agents = agents
self.current_agent = None
self.router = AgentRouter(self.agents)
self.speech = Speech()
self.is_active = True
self.last_query = None
@ -44,10 +47,15 @@ class Interaction:
return query
def think(self):
self.last_answer, _ = self.agents[0].process(self.last_query, self.speech)
agent = self.router.select_agent(self.last_query)
if self.current_agent != agent:
self.current_agent = agent
# get history from previous agent
self.current_agent.memory.push('user', self.last_query)
self.last_answer, _ = agent.process(self.last_query, self.speech)
def show_answer(self):
self.agents[0].show_answer()
self.current_agent.show_answer()
if self.tts_enabled:
self.speech.speak(self.last_answer)

View File

@ -4,7 +4,11 @@ import time
import datetime
import uuid
import os
import sys
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sources.utility import timer_decorator
class Memory():

60
sources/router.py Normal file
View File

@ -0,0 +1,60 @@
import torch
from transformers import pipeline
from sources.agent import Agent
from sources.code_agent import CoderAgent
from sources.casual_agent import CasualAgent
from sources.utility import pretty_print
class AgentRouter:
def __init__(self, agents: list, model_name="facebook/bart-large-mnli"):
self.model = model_name
self.pipeline = pipeline("zero-shot-classification",
model=self.model)
self.agents = agents
self.labels = [agent.role for agent in agents]
def get_device(self):
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda:0"
else:
return "cpu"
def classify_text(self, text, threshold=0.5):
result = self.pipeline(text, self.labels, threshold=threshold)
return result
def select_agent(self, text: str) -> Agent:
result = self.classify_text(text)
for agent in self.agents:
if result["labels"][0] == agent.role:
pretty_print(f"Selected agent role: {agent.role}", color="warning")
return agent
return None
if __name__ == "__main__":
agents = [
CoderAgent("deepseek-r1:14b", "agent1", "../prompts/coder_agent.txt", "server"),
CasualAgent("deepseek-r1:14b", "agent2", "../prompts/casual_agent.txt", "server")
]
router = AgentRouter(agents)
texts = ["""
Write a python script to check if the device on my network is connected to the internet
""",
"""
Hey could you search the web for the latest news on the stock market ?
""",
"""
hey can you give dating advice ?
"""
]
for text in texts:
print(text)
results = router.classify_text(text)
for result in results:
print(result["label"], "=>", result["score"])
agent = router.select_agent(text)
print("Selected agent role:", agent.role)