mirror of
https://github.com/tcsenpai/agenticSeek.git
synced 2025-06-03 01:30:11 +00:00
feat : zero shot router for other lang
This commit is contained in:
parent
a289ddf1fd
commit
af168d2c57
@ -17,7 +17,12 @@ class BrowserAgent(Agent):
|
||||
self.tools = {
|
||||
"web_search": searxSearch(),
|
||||
}
|
||||
self.role = "Web search and navigation"
|
||||
self.role = {
|
||||
"en": "Web search and navigation",
|
||||
"fr": "Recherche et navigation web",
|
||||
"zh": "网络搜索和导航",
|
||||
"es": "Búsqueda y navegación web"
|
||||
}
|
||||
self.type = "browser_agent"
|
||||
self.browser = Browser()
|
||||
self.current_page = ""
|
||||
|
@ -18,7 +18,12 @@ class CasualAgent(Agent):
|
||||
"file_finder": FileFinder(),
|
||||
"bash": BashInterpreter()
|
||||
}
|
||||
self.role = "talk"
|
||||
self.role = {
|
||||
"en": "talking",
|
||||
"fr": "discuter",
|
||||
"zh": "聊天",
|
||||
"es": "discutir"
|
||||
}
|
||||
self.type = "casual_agent"
|
||||
|
||||
def process(self, prompt, speech_module) -> str:
|
||||
|
@ -20,7 +20,12 @@ class CoderAgent(Agent):
|
||||
"go": GoInterpreter(),
|
||||
"file_finder": FileFinder()
|
||||
}
|
||||
self.role = "Coding task"
|
||||
self.role = {
|
||||
"en": "Coding task",
|
||||
"fr": "Tâche de codage",
|
||||
"zh": "编码任务",
|
||||
"es": "Tarea de codificación"
|
||||
}
|
||||
self.type = "code_agent"
|
||||
|
||||
def process(self, prompt, speech_module) -> str:
|
||||
|
@ -14,7 +14,12 @@ class FileAgent(Agent):
|
||||
"file_finder": FileFinder(),
|
||||
"bash": BashInterpreter()
|
||||
}
|
||||
self.role = "find and read files"
|
||||
self.role = {
|
||||
"en": "find and read files",
|
||||
"fr": "trouver et lire des fichiers",
|
||||
"zh": "查找和读取文件",
|
||||
"es": "encontrar y leer archivos"
|
||||
}
|
||||
self.type = "file_agent"
|
||||
|
||||
def process(self, prompt, speech_module) -> str:
|
||||
|
@ -21,7 +21,12 @@ class PlannerAgent(Agent):
|
||||
"file": FileAgent(model, name, prompt_path, provider),
|
||||
"web": BrowserAgent(model, name, prompt_path, provider)
|
||||
}
|
||||
self.role = "Research, setup and code"
|
||||
self.role = {
|
||||
"en": "Research, setup and code",
|
||||
"fr": "Recherche, configuration et codage",
|
||||
"zh": "研究,设置和编码",
|
||||
"es": "Investigación, configuración y code"
|
||||
}
|
||||
self.type = "planner_agent"
|
||||
|
||||
def parse_agent_tasks(self, text):
|
||||
|
@ -19,7 +19,7 @@ class LanguageUtility:
|
||||
text: string to analyze
|
||||
Returns: ISO639-1 language code
|
||||
"""
|
||||
langid.set_languages(['fr', 'en', 'zh', 'es']) # ISO 639-1 codes
|
||||
langid.set_languages(['fr', 'en', 'zh', 'es'])
|
||||
lang, score = langid.classify(text)
|
||||
return lang
|
||||
|
||||
|
@ -10,18 +10,23 @@ from sources.agents.code_agent import CoderAgent
|
||||
from sources.agents.casual_agent import CasualAgent
|
||||
from sources.agents.planner_agent import PlannerAgent
|
||||
from sources.agents.browser_agent import BrowserAgent
|
||||
from sources.language import LanguageUtility
|
||||
from sources.utility import pretty_print
|
||||
|
||||
class AgentRouter:
|
||||
"""
|
||||
AgentRouter is a class that selects the appropriate agent based on the user query.
|
||||
"""
|
||||
def __init__(self, agents: list, model_name: str = "facebook/bart-large-mnli"):
|
||||
def __init__(self, agents: list):
|
||||
self.model = model_name
|
||||
self.pipeline = pipeline("zero-shot-classification",
|
||||
model=self.model)
|
||||
self.lang_analysis = LanguageUtility()
|
||||
self.pipelines = {
|
||||
"fr": pipeline("zero-shot-classification", model="facebook/bart-large-mnli"),
|
||||
"zh": pipeline("zero-shot-classification", model="morit/chinese_xlm_xnli"),
|
||||
"es": pipeline("zero-shot-classification", model="facebook/bart-large-mnli"),
|
||||
"en": pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
||||
}
|
||||
self.agents = agents
|
||||
self.labels = [agent.role for agent in agents]
|
||||
|
||||
def get_device(self) -> str:
|
||||
if torch.backends.mps.is_available():
|
||||
@ -46,8 +51,23 @@ class AgentRouter:
|
||||
break
|
||||
if first_sentence is None:
|
||||
first_sentence = text
|
||||
result = self.pipeline(first_sentence, self.labels, threshold=threshold)
|
||||
return result
|
||||
try:
|
||||
lang = lang_analysis.detect_language(first_sentence)
|
||||
assert lang in ["en", "fr", "zh", "es"]
|
||||
labels = [agent.role[lang] for agent in agents]
|
||||
if lang == "en":
|
||||
result = self.pipelines['en'](first_sentence, labels, threshold=threshold)
|
||||
elif lang == "fr":
|
||||
result = self.pipelines['fr'](first_sentence, labels, threshold=threshold)
|
||||
elif lang == "zh":
|
||||
result = self.pipelines['zh'](first_sentence, labels, threshold=threshold)
|
||||
elif lang == "es":
|
||||
result = self.pipelines['es'](first_sentence, labels, threshold=threshold)
|
||||
else:
|
||||
result = None
|
||||
except Exception as e:
|
||||
return None, lang
|
||||
return result, lang
|
||||
|
||||
def select_agent(self, text: str) -> Agent:
|
||||
"""
|
||||
@ -57,13 +77,14 @@ class AgentRouter:
|
||||
Returns:
|
||||
Agent: The selected agent
|
||||
"""
|
||||
if len(self.agents) == 0 or len(self.labels) == 0:
|
||||
if len(self.agents) == 0:
|
||||
return self.agents[0]
|
||||
result = self.classify_text(text)
|
||||
result, lang = self.classify_text(text)
|
||||
for agent in self.agents:
|
||||
if result["labels"][0] == agent.role:
|
||||
pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role})", color="warning")
|
||||
if result["labels"][0] == agent.role[lang]:
|
||||
pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role[lang]})", color="warning")
|
||||
return agent
|
||||
pretty_print(f"Error choosing agent.", color="error")
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user