From af168d2c572f33e1c60527de4f092d8f3d0f0a57 Mon Sep 17 00:00:00 2001 From: martin legrand Date: Mon, 24 Mar 2025 20:22:30 +0100 Subject: [PATCH] feat : zero shot router for other lang --- sources/agents/browser_agent.py | 7 +++++- sources/agents/casual_agent.py | 7 +++++- sources/agents/code_agent.py | 7 +++++- sources/agents/file_agent.py | 7 +++++- sources/agents/planner_agent.py | 7 +++++- sources/language.py | 2 +- sources/router.py | 41 +++++++++++++++++++++++++-------- 7 files changed, 62 insertions(+), 16 deletions(-) diff --git a/sources/agents/browser_agent.py b/sources/agents/browser_agent.py index fa0d298..c6edff7 100644 --- a/sources/agents/browser_agent.py +++ b/sources/agents/browser_agent.py @@ -17,7 +17,12 @@ class BrowserAgent(Agent): self.tools = { "web_search": searxSearch(), } - self.role = "Web search and navigation" + self.role = { + "en": "Web search and navigation", + "fr": "Recherche et navigation web", + "zh": "网络搜索和导航", + "es": "Búsqueda y navegación web" + } self.type = "browser_agent" self.browser = Browser() self.current_page = "" diff --git a/sources/agents/casual_agent.py b/sources/agents/casual_agent.py index 114d617..db67f78 100644 --- a/sources/agents/casual_agent.py +++ b/sources/agents/casual_agent.py @@ -18,7 +18,12 @@ class CasualAgent(Agent): "file_finder": FileFinder(), "bash": BashInterpreter() } - self.role = "talk" + self.role = { + "en": "talking", + "fr": "discuter", + "zh": "聊天", + "es": "discutir" + } self.type = "casual_agent" def process(self, prompt, speech_module) -> str: diff --git a/sources/agents/code_agent.py b/sources/agents/code_agent.py index c454e63..7cc400b 100644 --- a/sources/agents/code_agent.py +++ b/sources/agents/code_agent.py @@ -20,7 +20,12 @@ class CoderAgent(Agent): "go": GoInterpreter(), "file_finder": FileFinder() } - self.role = "Coding task" + self.role = { + "en": "Coding task", + "fr": "Tâche de codage", + "zh": "编码任务", + "es": "Tarea de codificación" + } self.type = "code_agent" def process(self, prompt, speech_module) -> str: diff --git a/sources/agents/file_agent.py b/sources/agents/file_agent.py index e493605..123eea2 100644 --- a/sources/agents/file_agent.py +++ b/sources/agents/file_agent.py @@ -14,7 +14,12 @@ class FileAgent(Agent): "file_finder": FileFinder(), "bash": BashInterpreter() } - self.role = "find and read files" + self.role = { + "en": "find and read files", + "fr": "trouver et lire des fichiers", + "zh": "查找和读取文件", + "es": "encontrar y leer archivos" + } self.type = "file_agent" def process(self, prompt, speech_module) -> str: diff --git a/sources/agents/planner_agent.py b/sources/agents/planner_agent.py index 6bcca80..56b68c0 100644 --- a/sources/agents/planner_agent.py +++ b/sources/agents/planner_agent.py @@ -21,7 +21,12 @@ class PlannerAgent(Agent): "file": FileAgent(model, name, prompt_path, provider), "web": BrowserAgent(model, name, prompt_path, provider) } - self.role = "Research, setup and code" + self.role = { + "en": "Research, setup and code", + "fr": "Recherche, configuration et codage", + "zh": "研究,设置和编码", + "es": "Investigación, configuración y code" + } self.type = "planner_agent" def parse_agent_tasks(self, text): diff --git a/sources/language.py b/sources/language.py index a4cc98d..16c80d2 100644 --- a/sources/language.py +++ b/sources/language.py @@ -19,7 +19,7 @@ class LanguageUtility: text: string to analyze Returns: ISO639-1 language code """ - langid.set_languages(['fr', 'en', 'zh', 'es']) # ISO 639-1 codes + langid.set_languages(['fr', 'en', 'zh', 'es']) lang, score = langid.classify(text) return lang diff --git a/sources/router.py b/sources/router.py index 1fc6781..23444ae 100644 --- a/sources/router.py +++ b/sources/router.py @@ -10,18 +10,23 @@ from sources.agents.code_agent import CoderAgent from sources.agents.casual_agent import CasualAgent from sources.agents.planner_agent import PlannerAgent from sources.agents.browser_agent import BrowserAgent +from sources.language import LanguageUtility from sources.utility import pretty_print class AgentRouter: """ AgentRouter is a class that selects the appropriate agent based on the user query. """ - def __init__(self, agents: list, model_name: str = "facebook/bart-large-mnli"): + def __init__(self, agents: list): self.model = model_name - self.pipeline = pipeline("zero-shot-classification", - model=self.model) + self.lang_analysis = LanguageUtility() + self.pipelines = { + "fr": pipeline("zero-shot-classification", model="facebook/bart-large-mnli"), + "zh": pipeline("zero-shot-classification", model="morit/chinese_xlm_xnli"), + "es": pipeline("zero-shot-classification", model="facebook/bart-large-mnli"), + "en": pipeline("zero-shot-classification", model="facebook/bart-large-mnli") + } self.agents = agents - self.labels = [agent.role for agent in agents] def get_device(self) -> str: if torch.backends.mps.is_available(): @@ -46,8 +51,23 @@ class AgentRouter: break if first_sentence is None: first_sentence = text - result = self.pipeline(first_sentence, self.labels, threshold=threshold) - return result + try: + lang = lang_analysis.detect_language(first_sentence) + assert lang in ["en", "fr", "zh", "es"] + labels = [agent.role[lang] for agent in agents] + if lang == "en": + result = self.pipelines['en'](first_sentence, labels, threshold=threshold) + elif lang == "fr": + result = self.pipelines['fr'](first_sentence, labels, threshold=threshold) + elif lang == "zh": + result = self.pipelines['zh'](first_sentence, labels, threshold=threshold) + elif lang == "es": + result = self.pipelines['es'](first_sentence, labels, threshold=threshold) + else: + result = None + except Exception as e: + return None, lang + return result, lang def select_agent(self, text: str) -> Agent: """ @@ -57,13 +77,14 @@ class AgentRouter: Returns: Agent: The selected agent """ - if len(self.agents) == 0 or len(self.labels) == 0: + if len(self.agents) == 0: return self.agents[0] - result = self.classify_text(text) + result, lang = self.classify_text(text) for agent in self.agents: - if result["labels"][0] == agent.role: - pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role})", color="warning") + if result["labels"][0] == agent.role[lang]: + pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role[lang]})", color="warning") return agent + pretty_print(f"Error choosing agent.", color="error") return None if __name__ == "__main__":