refactor: self.role for router

This commit is contained in:
martin legrand 2025-04-12 11:50:48 +02:00
parent 06e6b2798b
commit f4feb42dda
8 changed files with 27 additions and 35 deletions

View File

@ -2,4 +2,5 @@
pip3 install --upgrade packaging
pip3 install --upgrade pip setuptools
curl -fsSL https://ollama.com/install.sh | sh
pip3 install -r requirements.txt

View File

@ -26,11 +26,7 @@ class BrowserAgent(Agent):
self.tools = {
"web_search": searxSearch(),
}
self.role = {
"en": "web",
"fr": "web",
"zh": "网络",
}
self.role = "web"
self.type = "browser_agent"
self.browser = browser
self.current_page = ""

View File

@ -14,11 +14,7 @@ class CasualAgent(Agent):
super().__init__(name, prompt_path, provider, verbose, None)
self.tools = {
} # No tools for the casual agent
self.role = {
"en": "talk",
"fr": "discuter",
"zh": "聊天",
}
self.role = "talk"
self.type = "casual_agent"
def process(self, prompt, speech_module) -> str:

View File

@ -24,11 +24,7 @@ class CoderAgent(Agent):
"file_finder": FileFinder()
}
self.work_dir = self.tools["file_finder"].get_work_dir()
self.role = {
"en": "code",
"fr": "codage",
"zh": "编码",
}
self.role = "code"
self.type = "code_agent"

View File

@ -15,11 +15,7 @@ class FileAgent(Agent):
"bash": BashInterpreter()
}
self.work_dir = self.tools["file_finder"].get_work_dir()
self.role = {
"en": "files",
"fr": "fichiers",
"zh": "文件",
}
self.role = "files"
self.type = "file_agent"
def process(self, prompt, speech_module) -> str:

View File

@ -24,11 +24,7 @@ class PlannerAgent(Agent):
"file": FileAgent(name, "prompts/base/file_agent.txt", provider, verbose=False),
"web": BrowserAgent(name, "prompts/base/browser_agent.txt", provider, verbose=False, browser=browser)
}
self.role = {
"en": "Complex Task",
"fr": "Tache complexe",
"zh": "复杂任务",
}
self.role = "planification"
self.type = "planner_agent"
def parse_agent_tasks(self, text):

View File

@ -129,13 +129,24 @@ class Provider:
requests.post(route_gen, json={"messages": history})
is_complete = False
while not is_complete:
response = requests.get(f"http://{self.server_ip}/get_updated_sentence")
if "error" in response.json():
pretty_print(response.json()["error"], color="failure")
try:
response = requests.get(f"http://{self.server_ip}/get_updated_sentence")
print(response)
if "error" in response.json():
pretty_print(response.json()["error"], color="failure")
break
thought = response.json()["sentence"]
is_complete = bool(response.json()["is_complete"])
time.sleep(2)
except requests.exceptions.RequestException as e:
pretty_print(f"HTTP request failed: {str(e)}", color="failure")
break
except ValueError as e:
pretty_print(f"Failed to parse JSON response: {str(e)}", color="failure")
break
except Exception as e:
pretty_print(f"An error occurred: {str(e)}", color="failure")
break
thought = response.json()["sentence"]
is_complete = bool(response.json()["is_complete"])
time.sleep(2)
except KeyError as e:
raise Exception(f"{str(e)}\nError occured with server route. Are you using the correct address for the config.ini provider?") from e
except Exception as e:
@ -300,6 +311,6 @@ class Provider:
return thought
if __name__ == "__main__":
provider = Provider("server", "deepseek-r1:14b", "192.168.1.20:3333")
provider = Provider("server", "deepseek-r1:32b", " 172.81.127.6:8080")
res = provider.respond(["user", "Hello, how are you?"])
print("Response:", res)

View File

@ -436,7 +436,7 @@ class AgentRouter:
lang = self.lang_analysis.detect_language(text)
text = self.find_first_sentence(text)
text = self.lang_analysis.translate(text, lang)
labels = [agent.role["en"] for agent in self.agents]
labels = [agent.role for agent in self.agents]
complexity = self.estimate_complexity(text)
if complexity == None:
pretty_print(f"Humm, the task seem complex but you gave very little information. can you clarify?", color="info")
@ -449,8 +449,8 @@ class AgentRouter:
except Exception as e:
raise e
for agent in self.agents:
if best_agent == agent.role["en"]:
role_name = agent.role[lang] if lang in agent.role else agent.role["en"]
if best_agent == agent.role:
role_name = agent.role
pretty_print(f"Selected agent: {agent.agent_name} (roles: {role_name})", color="warning")
return agent
pretty_print(f"Error choosing agent.", color="failure")