feat : task complexity routing

This commit is contained in:
martin legrand 2025-03-26 10:54:25 +01:00
parent 32bc096d9a
commit 757a9b1e3e
3 changed files with 136 additions and 20 deletions

View File

@ -210,6 +210,7 @@ class BrowserAgent(Agent):
You: "search: Recent space missions news, {self.date}"
Do not explain, do not write anything beside the search query.
If the query does not make any sense for a web search explain why and say REQUEST_EXIT
"""
def process(self, user_prompt, speech_module) -> str:
@ -218,6 +219,9 @@ class BrowserAgent(Agent):
animate_thinking(f"Thinking...", color="status")
self.memory.push('user', self.search_prompt(user_prompt))
ai_prompt, _ = self.llm_request()
if "REQUEST_EXIT" in ai_prompt:
# request make no sense, maybe wrong agent was allocated?
return ai_prompt, ""
animate_thinking(f"Searching...", color="status")
search_result_raw = self.tools["web_search"].execute([ai_prompt], False)
search_result = self.jsonify_search_results(search_result_raw)[:7] # until futher improvement

View File

@ -103,13 +103,14 @@ class Interaction:
self.current_agent.memory.push('user', self.last_query)
self.current_agent.memory.push('assistant', self.last_answer)
self.current_agent = agent
#self.last_answer, _ = agent.process(self.last_query, self.speech)
self.last_answer, _ = agent.process(self.last_query, self.speech)
def show_answer(self) -> None:
"""Show the answer to the user."""
if self.last_query is None:
return
self.current_agent.show_answer()
if self.current_agent is None:
self.current_agent.show_answer()
if self.tts_enabled and self.last_answer:
self.speech.speak(self.last_answer)

View File

@ -20,15 +20,16 @@ class AgentRouter:
"""
AgentRouter is a class that selects the appropriate agent based on the user query.
"""
# TODO add adaptive-classifier==0.0.10 to requirements.txt
def __init__(self, agents: list):
self.agents = agents
self.lang_analysis = LanguageUtility()
self.pipelines = {
"bart": pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
}
self.classifier = self.load_llm_router()
self.learn_few_shots_en()
self.talk_classifier = self.load_llm_router()
self.complexity_classifier = self.load_llm_router()
self.learn_few_shots_tasks()
self.learn_few_shots_complexity()
def load_llm_router(self) -> AdaptiveClassifier:
"""
@ -40,10 +41,10 @@ class AgentRouter:
"""
path = "../llm_router" if __name__ == "__main__" else "./llm_router"
try:
classifier = AdaptiveClassifier.from_pretrained(path)
talk_classifier = AdaptiveClassifier.from_pretrained(path)
except Exception as e:
raise Exception("Failed to load the routing model. Please run the dl_safetensors.sh script inside llm_router/ directory to download the model.")
return classifier
return talk_classifier
def get_device(self) -> str:
if torch.backends.mps.is_available():
@ -52,11 +53,80 @@ class AgentRouter:
return "cuda:0"
else:
return "cpu"
def learn_few_shots_en(self) -> None:
def learn_few_shots_complexity(self) -> None:
"""
Few shot learning for the LLM router.
Use the build in add_examples method of the AdaptiveClassifier.
Few shot learning for complexity estimation.
Use the build in add_examples method of the Adaptive_classifier.
"""
few_shots = [
("can you find api and build a python web app with it ?", "HIGH"),
("can you lookup for api that track flight and build a web flight tracking app", "HIGH"),
("can you find a file called resume.docx on my drive?", "LOW"),
("can you write a python script to check if the device on my network is connected to the internet", "LOW"),
("can you debug this Java code? Its not working.", "LOW"),
("can you browse the web and find me a 4090 for cheap?", "LOW"),
("can you find the old_project.zip file somewhere on my drive?", "LOW"),
("can you locate the backup folder I created last month on my system?", "LOW"),
("could you check if the presentation.pdf file exists in my downloads?", "LOW"),
("search my drive for a file called vacation_photos_2023.jpg.", "LOW"),
("help me organize my desktop files into folders by type.", "LOW"),
("write a Python function to sort a list of dictionaries by key", "LOW"),
("find the latest updates on quantum computing on the web", "LOW"),
("check if the folder Work_Projects exists on my desktop", "LOW"),
("create a bash script to monitor CPU usage", "LOW"),
("debug this C++ code that keeps crashing", "LOW"),
("can you browse the web to find out who fosowl is ?", "LOW"),
("find the file important_notes.txt", "LOW"),
("search the web for the best ways to learn a new language", "LOW"),
("locate the file presentation.pptx in my Documents folder", "LOW"),
("Make a 3d game in javascript using three.js", "HIGH"),
("Create a whole web app in python using the flask framework that query news API", "HIGH"),
("Find the latest research papers on AI and build a web app that display them", "HIGH"),
("Create a bash script that monitor the CPU usage and send an email if it's too high", "HIGH"),
("Make a web server in go that serve a simple html page", "LOW"),
("Make a web server in go that query a weather API and display the weather", "HIGH"),
("Make a web search for latest news on the stock market and display them", "HIGH"),
("Search the web for latest ai papers", "LOW"),
("Write a Python script to calculate the factorial of a number", "LOW"),
("Can you find a weather API and build a Python app to display current weather", "HIGH"),
("Search the web for the cheapest 4K monitor and provide a link", "LOW"),
("Create a Python web app using Flask to track cryptocurrency prices from an API", "HIGH"),
("Write a JavaScript function to reverse a string", "LOW"),
("Can you locate a file called budget_2025.xlsx on my system?", "LOW"),
("Search the web for recent articles on space exploration", "LOW"),
("Find a public API for movie data and build a web app to display movie ratings", "HIGH"),
("Write a bash script to list all files in a directory", "LOW"),
("Check if a folder named Photos_2024 exists on my desktop", "LOW"),
("Create a Python script to rename all files in a folder based on their creation date", "LOW"),
("Search the web for tutorials on machine learning and build a simple ML model in Python", "HIGH"),
("Debug this Python code thats throwing an error", "LOW"),
("Can you find a file named meeting_notes.txt in my Downloads folder?", "LOW"),
("Create a JavaScript game using Phaser.js with multiple levels", "HIGH"),
("Write a Go program to check if a port is open on a network", "LOW"),
("Search the web for the latest electric car reviews", "LOW"),
("Find a public API for book data and create a Flask app to list bestsellers", "HIGH"),
("Write a Python function to merge two sorted lists", "LOW"),
("Organize my desktop files by extension and then write a script to list them", "HIGH"),
("Create a bash script to monitor disk space and alert via text file", "LOW"),
("Search X for posts about AI ethics and summarize them", "LOW"),
("Find the latest research on renewable energy and build a web app to display it", "HIGH"),
("Write a C program to sort an array of integers", "LOW"),
("Create a Node.js server that queries a public API for traffic data and displays it", "HIGH"),
("Check if a file named project_proposal.pdf exists in my Documents", "LOW"),
("Search the web for tips on improving coding skills", "LOW"),
("Write a Python script to count words in a text file", "LOW"),
("Find a public API for sports scores and build a web app to show live updates", "HIGH"),
("Create a simple HTML page with CSS styling", "LOW"),
]
texts = [text for text, _ in few_shots]
labels = [label for _, label in few_shots]
self.complexity_classifier.add_examples(texts, labels)
def learn_few_shots_tasks(self) -> None:
"""
Few shot learning for tasks classification.
Use the build in add_examples method of the Adaptive_classifier.
"""
few_shots = [
("Write a python script to check if the device on my network is connected to the internet", "coding"),
@ -193,8 +263,7 @@ class AgentRouter:
]
texts = [text for text, _ in few_shots]
labels = [label for _, label in few_shots]
self.classifier.clear_memory()
self.classifier.add_examples(texts, labels)
self.talk_classifier.add_examples(texts, labels)
def llm_router(self, text: str) -> tuple:
"""
@ -202,7 +271,7 @@ class AgentRouter:
Args:
text: The input text
"""
predictions = self.classifier.predict(text)
predictions = self.talk_classifier.predict(text)
predictions = [pred for pred in predictions if pred[0] not in ["HIGH", "LOW"]]
predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
return predictions[0]
@ -238,14 +307,50 @@ class AgentRouter:
if first_sentence is None:
first_sentence = text
try:
lang = self.lang_analysis.detect_language(first_sentence)
# only use english role labels for now, we don't have a multilingual router yet
labels = [agent.role["en"] for agent in self.agents]
#lang = self.lang_analysis.detect_language(first_sentence)
lang = "en" # NOTE only use english role labels for now, we don't have a multilingual router yet
labels = [agent.role[lang] for agent in self.agents]
result = self.router_vote(first_sentence, labels, log_confidence=True)
except Exception as e:
raise e
return result, lang
def estimate_complexity(self, text: str) -> str:
"""
Estimate the complexity of the text.
Args:
text: The input text
Returns:
str: The estimated complexity
"""
predictions = self.complexity_classifier.predict(text)
predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
if len(predictions) == 0:
return "LOW"
complexity, confidence = predictions[0][0], predictions[0][1]
if confidence < 0.4:
return "LOW"
if complexity == "HIGH" and len(text) < 64:
return None # ask for more info
if complexity == "HIGH":
return "HIGH"
elif complexity == "LOW":
return "LOW"
pretty_print(f"Failed to estimate the complexity of the text. Confidence: {confidence}", color="failure")
return None
def find_planner_agent(self) -> Agent:
"""
Find the planner agent.
Returns:
Agent: The planner agent
"""
for agent in self.agents:
if agent.type == "planner_agent":
return agent
pretty_print(f"Error finding planner agent. Please add a planner agent to the list of agents.", color="failure")
return None
def select_agent(self, text: str) -> Agent:
"""
Select the appropriate agent based on the text.
@ -256,9 +361,16 @@ class AgentRouter:
"""
if len(self.agents) == 0:
return self.agents[0]
result, lang = self.classify_text(text)
complexity = self.estimate_complexity(text)
if complexity == None:
pretty_print(f"Humm, the task seem complex but you gave very little information. can you clarify?", color="info")
return None
if complexity == "HIGH":
pretty_print(f"Complex task detected, routing to planner agent.", color="info")
return self.find_planner_agent()
best_agent, lang = self.classify_text(text)
for agent in self.agents:
if result == agent.role[lang]:
if best_agent == agent.role[lang]:
pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role[lang]})", color="warning")
return agent
pretty_print(f"Error choosing agent. Routing system is not multilingual yet.", color="failure")
@ -266,7 +378,6 @@ class AgentRouter:
pretty_print(f"エージェントの選択エラー。ルーティングシステムはまだ多言語に対応していません", color="failure")
pretty_print(f"Erreur lors du choix de l'agent. Le système de routage n'est pas encore multilingue.", color="failure")
pretty_print(f"Error al elegir agente. El sistema de enrutamiento aún no es multilingüe.", color="failure")
return None
if __name__ == "__main__":