mirror of
https://github.com/tcsenpai/agenticSeek.git
synced 2025-06-10 12:57:14 +00:00
feat : task complexity routing
This commit is contained in:
parent
32bc096d9a
commit
757a9b1e3e
@ -210,6 +210,7 @@ class BrowserAgent(Agent):
|
|||||||
You: "search: Recent space missions news, {self.date}"
|
You: "search: Recent space missions news, {self.date}"
|
||||||
|
|
||||||
Do not explain, do not write anything beside the search query.
|
Do not explain, do not write anything beside the search query.
|
||||||
|
If the query does not make any sense for a web search explain why and say REQUEST_EXIT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def process(self, user_prompt, speech_module) -> str:
|
def process(self, user_prompt, speech_module) -> str:
|
||||||
@ -218,6 +219,9 @@ class BrowserAgent(Agent):
|
|||||||
animate_thinking(f"Thinking...", color="status")
|
animate_thinking(f"Thinking...", color="status")
|
||||||
self.memory.push('user', self.search_prompt(user_prompt))
|
self.memory.push('user', self.search_prompt(user_prompt))
|
||||||
ai_prompt, _ = self.llm_request()
|
ai_prompt, _ = self.llm_request()
|
||||||
|
if "REQUEST_EXIT" in ai_prompt:
|
||||||
|
# request make no sense, maybe wrong agent was allocated?
|
||||||
|
return ai_prompt, ""
|
||||||
animate_thinking(f"Searching...", color="status")
|
animate_thinking(f"Searching...", color="status")
|
||||||
search_result_raw = self.tools["web_search"].execute([ai_prompt], False)
|
search_result_raw = self.tools["web_search"].execute([ai_prompt], False)
|
||||||
search_result = self.jsonify_search_results(search_result_raw)[:7] # until futher improvement
|
search_result = self.jsonify_search_results(search_result_raw)[:7] # until futher improvement
|
||||||
|
@ -103,12 +103,13 @@ class Interaction:
|
|||||||
self.current_agent.memory.push('user', self.last_query)
|
self.current_agent.memory.push('user', self.last_query)
|
||||||
self.current_agent.memory.push('assistant', self.last_answer)
|
self.current_agent.memory.push('assistant', self.last_answer)
|
||||||
self.current_agent = agent
|
self.current_agent = agent
|
||||||
#self.last_answer, _ = agent.process(self.last_query, self.speech)
|
self.last_answer, _ = agent.process(self.last_query, self.speech)
|
||||||
|
|
||||||
def show_answer(self) -> None:
|
def show_answer(self) -> None:
|
||||||
"""Show the answer to the user."""
|
"""Show the answer to the user."""
|
||||||
if self.last_query is None:
|
if self.last_query is None:
|
||||||
return
|
return
|
||||||
|
if self.current_agent is None:
|
||||||
self.current_agent.show_answer()
|
self.current_agent.show_answer()
|
||||||
if self.tts_enabled and self.last_answer:
|
if self.tts_enabled and self.last_answer:
|
||||||
self.speech.speak(self.last_answer)
|
self.speech.speak(self.last_answer)
|
||||||
|
@ -20,15 +20,16 @@ class AgentRouter:
|
|||||||
"""
|
"""
|
||||||
AgentRouter is a class that selects the appropriate agent based on the user query.
|
AgentRouter is a class that selects the appropriate agent based on the user query.
|
||||||
"""
|
"""
|
||||||
# TODO add adaptive-classifier==0.0.10 to requirements.txt
|
|
||||||
def __init__(self, agents: list):
|
def __init__(self, agents: list):
|
||||||
self.agents = agents
|
self.agents = agents
|
||||||
self.lang_analysis = LanguageUtility()
|
self.lang_analysis = LanguageUtility()
|
||||||
self.pipelines = {
|
self.pipelines = {
|
||||||
"bart": pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
"bart": pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
||||||
}
|
}
|
||||||
self.classifier = self.load_llm_router()
|
self.talk_classifier = self.load_llm_router()
|
||||||
self.learn_few_shots_en()
|
self.complexity_classifier = self.load_llm_router()
|
||||||
|
self.learn_few_shots_tasks()
|
||||||
|
self.learn_few_shots_complexity()
|
||||||
|
|
||||||
def load_llm_router(self) -> AdaptiveClassifier:
|
def load_llm_router(self) -> AdaptiveClassifier:
|
||||||
"""
|
"""
|
||||||
@ -40,10 +41,10 @@ class AgentRouter:
|
|||||||
"""
|
"""
|
||||||
path = "../llm_router" if __name__ == "__main__" else "./llm_router"
|
path = "../llm_router" if __name__ == "__main__" else "./llm_router"
|
||||||
try:
|
try:
|
||||||
classifier = AdaptiveClassifier.from_pretrained(path)
|
talk_classifier = AdaptiveClassifier.from_pretrained(path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("Failed to load the routing model. Please run the dl_safetensors.sh script inside llm_router/ directory to download the model.")
|
raise Exception("Failed to load the routing model. Please run the dl_safetensors.sh script inside llm_router/ directory to download the model.")
|
||||||
return classifier
|
return talk_classifier
|
||||||
|
|
||||||
def get_device(self) -> str:
|
def get_device(self) -> str:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
@ -53,10 +54,79 @@ class AgentRouter:
|
|||||||
else:
|
else:
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
def learn_few_shots_en(self) -> None:
|
def learn_few_shots_complexity(self) -> None:
|
||||||
"""
|
"""
|
||||||
Few shot learning for the LLM router.
|
Few shot learning for complexity estimation.
|
||||||
Use the build in add_examples method of the AdaptiveClassifier.
|
Use the build in add_examples method of the Adaptive_classifier.
|
||||||
|
"""
|
||||||
|
few_shots = [
|
||||||
|
("can you find api and build a python web app with it ?", "HIGH"),
|
||||||
|
("can you lookup for api that track flight and build a web flight tracking app", "HIGH"),
|
||||||
|
("can you find a file called resume.docx on my drive?", "LOW"),
|
||||||
|
("can you write a python script to check if the device on my network is connected to the internet", "LOW"),
|
||||||
|
("can you debug this Java code? It’s not working.", "LOW"),
|
||||||
|
("can you browse the web and find me a 4090 for cheap?", "LOW"),
|
||||||
|
("can you find the old_project.zip file somewhere on my drive?", "LOW"),
|
||||||
|
("can you locate the backup folder I created last month on my system?", "LOW"),
|
||||||
|
("could you check if the presentation.pdf file exists in my downloads?", "LOW"),
|
||||||
|
("search my drive for a file called vacation_photos_2023.jpg.", "LOW"),
|
||||||
|
("help me organize my desktop files into folders by type.", "LOW"),
|
||||||
|
("write a Python function to sort a list of dictionaries by key", "LOW"),
|
||||||
|
("find the latest updates on quantum computing on the web", "LOW"),
|
||||||
|
("check if the folder ‘Work_Projects’ exists on my desktop", "LOW"),
|
||||||
|
("create a bash script to monitor CPU usage", "LOW"),
|
||||||
|
("debug this C++ code that keeps crashing", "LOW"),
|
||||||
|
("can you browse the web to find out who fosowl is ?", "LOW"),
|
||||||
|
("find the file ‘important_notes.txt’", "LOW"),
|
||||||
|
("search the web for the best ways to learn a new language", "LOW"),
|
||||||
|
("locate the file ‘presentation.pptx’ in my Documents folder", "LOW"),
|
||||||
|
("Make a 3d game in javascript using three.js", "HIGH"),
|
||||||
|
("Create a whole web app in python using the flask framework that query news API", "HIGH"),
|
||||||
|
("Find the latest research papers on AI and build a web app that display them", "HIGH"),
|
||||||
|
("Create a bash script that monitor the CPU usage and send an email if it's too high", "HIGH"),
|
||||||
|
("Make a web server in go that serve a simple html page", "LOW"),
|
||||||
|
("Make a web server in go that query a weather API and display the weather", "HIGH"),
|
||||||
|
("Make a web search for latest news on the stock market and display them", "HIGH"),
|
||||||
|
("Search the web for latest ai papers", "LOW"),
|
||||||
|
("Write a Python script to calculate the factorial of a number", "LOW"),
|
||||||
|
("Can you find a weather API and build a Python app to display current weather", "HIGH"),
|
||||||
|
("Search the web for the cheapest 4K monitor and provide a link", "LOW"),
|
||||||
|
("Create a Python web app using Flask to track cryptocurrency prices from an API", "HIGH"),
|
||||||
|
("Write a JavaScript function to reverse a string", "LOW"),
|
||||||
|
("Can you locate a file called ‘budget_2025.xlsx’ on my system?", "LOW"),
|
||||||
|
("Search the web for recent articles on space exploration", "LOW"),
|
||||||
|
("Find a public API for movie data and build a web app to display movie ratings", "HIGH"),
|
||||||
|
("Write a bash script to list all files in a directory", "LOW"),
|
||||||
|
("Check if a folder named ‘Photos_2024’ exists on my desktop", "LOW"),
|
||||||
|
("Create a Python script to rename all files in a folder based on their creation date", "LOW"),
|
||||||
|
("Search the web for tutorials on machine learning and build a simple ML model in Python", "HIGH"),
|
||||||
|
("Debug this Python code that’s throwing an error", "LOW"),
|
||||||
|
("Can you find a file named ‘meeting_notes.txt’ in my Downloads folder?", "LOW"),
|
||||||
|
("Create a JavaScript game using Phaser.js with multiple levels", "HIGH"),
|
||||||
|
("Write a Go program to check if a port is open on a network", "LOW"),
|
||||||
|
("Search the web for the latest electric car reviews", "LOW"),
|
||||||
|
("Find a public API for book data and create a Flask app to list bestsellers", "HIGH"),
|
||||||
|
("Write a Python function to merge two sorted lists", "LOW"),
|
||||||
|
("Organize my desktop files by extension and then write a script to list them", "HIGH"),
|
||||||
|
("Create a bash script to monitor disk space and alert via text file", "LOW"),
|
||||||
|
("Search X for posts about AI ethics and summarize them", "LOW"),
|
||||||
|
("Find the latest research on renewable energy and build a web app to display it", "HIGH"),
|
||||||
|
("Write a C program to sort an array of integers", "LOW"),
|
||||||
|
("Create a Node.js server that queries a public API for traffic data and displays it", "HIGH"),
|
||||||
|
("Check if a file named ‘project_proposal.pdf’ exists in my Documents", "LOW"),
|
||||||
|
("Search the web for tips on improving coding skills", "LOW"),
|
||||||
|
("Write a Python script to count words in a text file", "LOW"),
|
||||||
|
("Find a public API for sports scores and build a web app to show live updates", "HIGH"),
|
||||||
|
("Create a simple HTML page with CSS styling", "LOW"),
|
||||||
|
]
|
||||||
|
texts = [text for text, _ in few_shots]
|
||||||
|
labels = [label for _, label in few_shots]
|
||||||
|
self.complexity_classifier.add_examples(texts, labels)
|
||||||
|
|
||||||
|
def learn_few_shots_tasks(self) -> None:
|
||||||
|
"""
|
||||||
|
Few shot learning for tasks classification.
|
||||||
|
Use the build in add_examples method of the Adaptive_classifier.
|
||||||
"""
|
"""
|
||||||
few_shots = [
|
few_shots = [
|
||||||
("Write a python script to check if the device on my network is connected to the internet", "coding"),
|
("Write a python script to check if the device on my network is connected to the internet", "coding"),
|
||||||
@ -193,8 +263,7 @@ class AgentRouter:
|
|||||||
]
|
]
|
||||||
texts = [text for text, _ in few_shots]
|
texts = [text for text, _ in few_shots]
|
||||||
labels = [label for _, label in few_shots]
|
labels = [label for _, label in few_shots]
|
||||||
self.classifier.clear_memory()
|
self.talk_classifier.add_examples(texts, labels)
|
||||||
self.classifier.add_examples(texts, labels)
|
|
||||||
|
|
||||||
def llm_router(self, text: str) -> tuple:
|
def llm_router(self, text: str) -> tuple:
|
||||||
"""
|
"""
|
||||||
@ -202,7 +271,7 @@ class AgentRouter:
|
|||||||
Args:
|
Args:
|
||||||
text: The input text
|
text: The input text
|
||||||
"""
|
"""
|
||||||
predictions = self.classifier.predict(text)
|
predictions = self.talk_classifier.predict(text)
|
||||||
predictions = [pred for pred in predictions if pred[0] not in ["HIGH", "LOW"]]
|
predictions = [pred for pred in predictions if pred[0] not in ["HIGH", "LOW"]]
|
||||||
predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
|
predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
|
||||||
return predictions[0]
|
return predictions[0]
|
||||||
@ -238,14 +307,50 @@ class AgentRouter:
|
|||||||
if first_sentence is None:
|
if first_sentence is None:
|
||||||
first_sentence = text
|
first_sentence = text
|
||||||
try:
|
try:
|
||||||
lang = self.lang_analysis.detect_language(first_sentence)
|
#lang = self.lang_analysis.detect_language(first_sentence)
|
||||||
# only use english role labels for now, we don't have a multilingual router yet
|
lang = "en" # NOTE only use english role labels for now, we don't have a multilingual router yet
|
||||||
labels = [agent.role["en"] for agent in self.agents]
|
labels = [agent.role[lang] for agent in self.agents]
|
||||||
result = self.router_vote(first_sentence, labels, log_confidence=True)
|
result = self.router_vote(first_sentence, labels, log_confidence=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
return result, lang
|
return result, lang
|
||||||
|
|
||||||
|
def estimate_complexity(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
Estimate the complexity of the text.
|
||||||
|
Args:
|
||||||
|
text: The input text
|
||||||
|
Returns:
|
||||||
|
str: The estimated complexity
|
||||||
|
"""
|
||||||
|
predictions = self.complexity_classifier.predict(text)
|
||||||
|
predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
|
||||||
|
if len(predictions) == 0:
|
||||||
|
return "LOW"
|
||||||
|
complexity, confidence = predictions[0][0], predictions[0][1]
|
||||||
|
if confidence < 0.4:
|
||||||
|
return "LOW"
|
||||||
|
if complexity == "HIGH" and len(text) < 64:
|
||||||
|
return None # ask for more info
|
||||||
|
if complexity == "HIGH":
|
||||||
|
return "HIGH"
|
||||||
|
elif complexity == "LOW":
|
||||||
|
return "LOW"
|
||||||
|
pretty_print(f"Failed to estimate the complexity of the text. Confidence: {confidence}", color="failure")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_planner_agent(self) -> Agent:
|
||||||
|
"""
|
||||||
|
Find the planner agent.
|
||||||
|
Returns:
|
||||||
|
Agent: The planner agent
|
||||||
|
"""
|
||||||
|
for agent in self.agents:
|
||||||
|
if agent.type == "planner_agent":
|
||||||
|
return agent
|
||||||
|
pretty_print(f"Error finding planner agent. Please add a planner agent to the list of agents.", color="failure")
|
||||||
|
return None
|
||||||
|
|
||||||
def select_agent(self, text: str) -> Agent:
|
def select_agent(self, text: str) -> Agent:
|
||||||
"""
|
"""
|
||||||
Select the appropriate agent based on the text.
|
Select the appropriate agent based on the text.
|
||||||
@ -256,9 +361,16 @@ class AgentRouter:
|
|||||||
"""
|
"""
|
||||||
if len(self.agents) == 0:
|
if len(self.agents) == 0:
|
||||||
return self.agents[0]
|
return self.agents[0]
|
||||||
result, lang = self.classify_text(text)
|
complexity = self.estimate_complexity(text)
|
||||||
|
if complexity == None:
|
||||||
|
pretty_print(f"Humm, the task seem complex but you gave very little information. can you clarify?", color="info")
|
||||||
|
return None
|
||||||
|
if complexity == "HIGH":
|
||||||
|
pretty_print(f"Complex task detected, routing to planner agent.", color="info")
|
||||||
|
return self.find_planner_agent()
|
||||||
|
best_agent, lang = self.classify_text(text)
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
if result == agent.role[lang]:
|
if best_agent == agent.role[lang]:
|
||||||
pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role[lang]})", color="warning")
|
pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role[lang]})", color="warning")
|
||||||
return agent
|
return agent
|
||||||
pretty_print(f"Error choosing agent. Routing system is not multilingual yet.", color="failure")
|
pretty_print(f"Error choosing agent. Routing system is not multilingual yet.", color="failure")
|
||||||
@ -266,7 +378,6 @@ class AgentRouter:
|
|||||||
pretty_print(f"エージェントの選択エラー。ルーティングシステムはまだ多言語に対応していません", color="failure")
|
pretty_print(f"エージェントの選択エラー。ルーティングシステムはまだ多言語に対応していません", color="failure")
|
||||||
pretty_print(f"Erreur lors du choix de l'agent. Le système de routage n'est pas encore multilingue.", color="failure")
|
pretty_print(f"Erreur lors du choix de l'agent. Le système de routage n'est pas encore multilingue.", color="failure")
|
||||||
pretty_print(f"Error al elegir agente. El sistema de enrutamiento aún no es multilingüe.", color="failure")
|
pretty_print(f"Error al elegir agente. El sistema de enrutamiento aún no es multilingüe.", color="failure")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user