From 757a9b1e3eaebe9ae6702b467b00188aa783e666 Mon Sep 17 00:00:00 2001
From: martin legrand
Date: Wed, 26 Mar 2025 10:54:25 +0100
Subject: [PATCH] feat : task complexity routing

---
Review notes (this section is ignored when the patch is applied):
 - interaction.py, show_answer(): the new guard must be `is not None`.
   As submitted (`is None`) it called show_answer() on a None agent and
   skipped it whenever an agent was actually selected.
 - router.py, select_agent(): compare with `is None`, not `== None`.
 - browser_agent.py: grammar fix in the added comment only.
 - Not changed here (context lines, needs a follow-up):
   * select_agent() still does `return self.agents[0]` when
     `len(self.agents) == 0`, which raises IndexError.
   * estimate_complexity() is annotated `-> str` but can return None.
   * BrowserAgent.process() is annotated `-> str` but the new early exit
     returns a 2-tuple; interaction.py unpacks two values from process().
 - The `index` blob ids below are those of the original submission.

 sources/agents/browser_agent.py |   4 +
 sources/interaction.py          |   5 +-
 sources/router.py               | 147 ++++++++++++++++++++++++++++----
 3 files changed, 136 insertions(+), 20 deletions(-)

diff --git a/sources/agents/browser_agent.py b/sources/agents/browser_agent.py
index d4dd8fe..1447a37 100644
--- a/sources/agents/browser_agent.py
+++ b/sources/agents/browser_agent.py
@@ -210,6 +210,7 @@ class BrowserAgent(Agent):
         You: "search: Recent space missions news, {self.date}"
 
         Do not explain, do not write anything beside the search query.
+        If the query does not make any sense for a web search explain why and say REQUEST_EXIT
         """
 
     def process(self, user_prompt, speech_module) -> str:
@@ -218,6 +219,9 @@ class BrowserAgent(Agent):
         animate_thinking(f"Thinking...", color="status")
         self.memory.push('user', self.search_prompt(user_prompt))
         ai_prompt, _ = self.llm_request()
+        if "REQUEST_EXIT" in ai_prompt:
+            # request makes no sense, maybe wrong agent was allocated?
+            return ai_prompt, ""
         animate_thinking(f"Searching...", color="status")
         search_result_raw = self.tools["web_search"].execute([ai_prompt], False)
         search_result = self.jsonify_search_results(search_result_raw)[:7] # until futher improvement
diff --git a/sources/interaction.py b/sources/interaction.py
index 6613119..6c5374e 100644
--- a/sources/interaction.py
+++ b/sources/interaction.py
@@ -103,13 +103,14 @@ class Interaction:
             self.current_agent.memory.push('user', self.last_query)
             self.current_agent.memory.push('assistant', self.last_answer)
         self.current_agent = agent
-        #self.last_answer, _ = agent.process(self.last_query, self.speech)
+        self.last_answer, _ = agent.process(self.last_query, self.speech)
 
     def show_answer(self) -> None:
         """Show the answer to the user."""
         if self.last_query is None:
             return
-        self.current_agent.show_answer()
+        if self.current_agent is not None:
+            self.current_agent.show_answer()
         if self.tts_enabled and self.last_answer:
             self.speech.speak(self.last_answer)
 
diff --git a/sources/router.py b/sources/router.py
index 06d6eaa..3152915 100644
--- a/sources/router.py
+++ b/sources/router.py
@@ -20,15 +20,16 @@ class AgentRouter:
     """
     AgentRouter is a class that selects the appropriate agent based on the user query.
     """
-    # TODO add adaptive-classifier==0.0.10 to requirements.txt
     def __init__(self, agents: list):
         self.agents = agents
         self.lang_analysis = LanguageUtility()
         self.pipelines = {
             "bart": pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
         }
-        self.classifier = self.load_llm_router()
-        self.learn_few_shots_en()
+        self.talk_classifier = self.load_llm_router()
+        self.complexity_classifier = self.load_llm_router()
+        self.learn_few_shots_tasks()
+        self.learn_few_shots_complexity()
 
     def load_llm_router(self) -> AdaptiveClassifier:
         """
@@ -40,10 +41,10 @@ class AgentRouter:
         """
         path = "../llm_router" if __name__ == "__main__" else "./llm_router"
         try:
-            classifier = AdaptiveClassifier.from_pretrained(path)
+            talk_classifier = AdaptiveClassifier.from_pretrained(path)
         except Exception as e:
             raise Exception("Failed to load the routing model. Please run the dl_safetensors.sh script inside llm_router/ directory to download the model.")
-        return classifier
+        return talk_classifier
 
     def get_device(self) -> str:
         if torch.backends.mps.is_available():
@@ -52,11 +53,80 @@ class AgentRouter:
             return "cuda:0"
         else:
             return "cpu"
-    
-    def learn_few_shots_en(self) -> None:
+
+    def learn_few_shots_complexity(self) -> None:
         """
-        Few shot learning for the LLM router.
-        Use the build in add_examples method of the AdaptiveClassifier.
+        Few shot learning for complexity estimation.
+        Use the build in add_examples method of the Adaptive_classifier.
+        """
+        few_shots = [
+            ("can you find api and build a python web app with it ?", "HIGH"),
+            ("can you lookup for api that track flight and build a web flight tracking app", "HIGH"),
+            ("can you find a file called resume.docx on my drive?", "LOW"),
+            ("can you write a python script to check if the device on my network is connected to the internet", "LOW"),
+            ("can you debug this Java code? It’s not working.", "LOW"),
+            ("can you browse the web and find me a 4090 for cheap?", "LOW"),
+            ("can you find the old_project.zip file somewhere on my drive?", "LOW"),
+            ("can you locate the backup folder I created last month on my system?", "LOW"),
+            ("could you check if the presentation.pdf file exists in my downloads?", "LOW"),
+            ("search my drive for a file called vacation_photos_2023.jpg.", "LOW"),
+            ("help me organize my desktop files into folders by type.", "LOW"),
+            ("write a Python function to sort a list of dictionaries by key", "LOW"),
+            ("find the latest updates on quantum computing on the web", "LOW"),
+            ("check if the folder ‘Work_Projects’ exists on my desktop", "LOW"),
+            ("create a bash script to monitor CPU usage", "LOW"),
+            ("debug this C++ code that keeps crashing", "LOW"),
+            ("can you browse the web to find out who fosowl is ?", "LOW"),
+            ("find the file ‘important_notes.txt’", "LOW"),
+            ("search the web for the best ways to learn a new language", "LOW"),
+            ("locate the file ‘presentation.pptx’ in my Documents folder", "LOW"),
+            ("Make a 3d game in javascript using three.js", "HIGH"),
+            ("Create a whole web app in python using the flask framework that query news API", "HIGH"),
+            ("Find the latest research papers on AI and build a web app that display them", "HIGH"),
+            ("Create a bash script that monitor the CPU usage and send an email if it's too high", "HIGH"),
+            ("Make a web server in go that serve a simple html page", "LOW"),
+            ("Make a web server in go that query a weather API and display the weather", "HIGH"),
+            ("Make a web search for latest news on the stock market and display them", "HIGH"),
+            ("Search the web for latest ai papers", "LOW"),
+            ("Write a Python script to calculate the factorial of a number", "LOW"),
+            ("Can you find a weather API and build a Python app to display current weather", "HIGH"),
+            ("Search the web for the cheapest 4K monitor and provide a link", "LOW"),
+            ("Create a Python web app using Flask to track cryptocurrency prices from an API", "HIGH"),
+            ("Write a JavaScript function to reverse a string", "LOW"),
+            ("Can you locate a file called ‘budget_2025.xlsx’ on my system?", "LOW"),
+            ("Search the web for recent articles on space exploration", "LOW"),
+            ("Find a public API for movie data and build a web app to display movie ratings", "HIGH"),
+            ("Write a bash script to list all files in a directory", "LOW"),
+            ("Check if a folder named ‘Photos_2024’ exists on my desktop", "LOW"),
+            ("Create a Python script to rename all files in a folder based on their creation date", "LOW"),
+            ("Search the web for tutorials on machine learning and build a simple ML model in Python", "HIGH"),
+            ("Debug this Python code that’s throwing an error", "LOW"),
+            ("Can you find a file named ‘meeting_notes.txt’ in my Downloads folder?", "LOW"),
+            ("Create a JavaScript game using Phaser.js with multiple levels", "HIGH"),
+            ("Write a Go program to check if a port is open on a network", "LOW"),
+            ("Search the web for the latest electric car reviews", "LOW"),
+            ("Find a public API for book data and create a Flask app to list bestsellers", "HIGH"),
+            ("Write a Python function to merge two sorted lists", "LOW"),
+            ("Organize my desktop files by extension and then write a script to list them", "HIGH"),
+            ("Create a bash script to monitor disk space and alert via text file", "LOW"),
+            ("Search X for posts about AI ethics and summarize them", "LOW"),
+            ("Find the latest research on renewable energy and build a web app to display it", "HIGH"),
+            ("Write a C program to sort an array of integers", "LOW"),
+            ("Create a Node.js server that queries a public API for traffic data and displays it", "HIGH"),
+            ("Check if a file named ‘project_proposal.pdf’ exists in my Documents", "LOW"),
+            ("Search the web for tips on improving coding skills", "LOW"),
+            ("Write a Python script to count words in a text file", "LOW"),
+            ("Find a public API for sports scores and build a web app to show live updates", "HIGH"),
+            ("Create a simple HTML page with CSS styling", "LOW"),
+        ]
+        texts = [text for text, _ in few_shots]
+        labels = [label for _, label in few_shots]
+        self.complexity_classifier.add_examples(texts, labels)
+
+    def learn_few_shots_tasks(self) -> None:
+        """
+        Few shot learning for tasks classification.
+        Use the build in add_examples method of the Adaptive_classifier.
         """
         few_shots = [
             ("Write a python script to check if the device on my network is connected to the internet", "coding"),
@@ -193,8 +263,7 @@ class AgentRouter:
         ]
         texts = [text for text, _ in few_shots]
         labels = [label for _, label in few_shots]
-        self.classifier.clear_memory()
-        self.classifier.add_examples(texts, labels)
+        self.talk_classifier.add_examples(texts, labels)
 
     def llm_router(self, text: str) -> tuple:
         """
@@ -202,7 +271,7 @@ class AgentRouter:
         Args:
             text: The input text
         """
-        predictions = self.classifier.predict(text)
+        predictions = self.talk_classifier.predict(text)
         predictions = [pred for pred in predictions if pred[0] not in ["HIGH", "LOW"]]
         predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
         return predictions[0]
@@ -238,14 +307,50 @@ class AgentRouter:
         if first_sentence is None:
             first_sentence = text
         try:
-            lang = self.lang_analysis.detect_language(first_sentence)
-            # only use english role labels for now, we don't have a multilingual router yet
-            labels = [agent.role["en"] for agent in self.agents]
+            #lang = self.lang_analysis.detect_language(first_sentence)
+            lang = "en" # NOTE only use english role labels for now, we don't have a multilingual router yet
+            labels = [agent.role[lang] for agent in self.agents]
             result = self.router_vote(first_sentence, labels, log_confidence=True)
         except Exception as e:
             raise e
         return result, lang
 
+    def estimate_complexity(self, text: str) -> str:
+        """
+        Estimate the complexity of the text.
+        Args:
+            text: The input text
+        Returns:
+            str: The estimated complexity
+        """
+        predictions = self.complexity_classifier.predict(text)
+        predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
+        if len(predictions) == 0:
+            return "LOW"
+        complexity, confidence = predictions[0][0], predictions[0][1]
+        if confidence < 0.4:
+            return "LOW"
+        if complexity == "HIGH" and len(text) < 64:
+            return None # ask for more info
+        if complexity == "HIGH":
+            return "HIGH"
+        elif complexity == "LOW":
+            return "LOW"
+        pretty_print(f"Failed to estimate the complexity of the text. Confidence: {confidence}", color="failure")
+        return None
+
+    def find_planner_agent(self) -> Agent:
+        """
+        Find the planner agent.
+        Returns:
+            Agent: The planner agent
+        """
+        for agent in self.agents:
+            if agent.type == "planner_agent":
+                return agent
+        pretty_print(f"Error finding planner agent. Please add a planner agent to the list of agents.", color="failure")
+        return None
+
     def select_agent(self, text: str) -> Agent:
         """
         Select the appropriate agent based on the text.
@@ -256,9 +361,16 @@ class AgentRouter:
         """
         if len(self.agents) == 0:
             return self.agents[0]
-        result, lang = self.classify_text(text)
+        complexity = self.estimate_complexity(text)
+        if complexity is None:
+            pretty_print(f"Humm, the task seem complex but you gave very little information. can you clarify?", color="info")
+            return None
+        if complexity == "HIGH":
+            pretty_print(f"Complex task detected, routing to planner agent.", color="info")
+            return self.find_planner_agent()
+        best_agent, lang = self.classify_text(text)
         for agent in self.agents:
-            if result == agent.role[lang]:
+            if best_agent == agent.role[lang]:
                 pretty_print(f"Selected agent: {agent.agent_name} (roles: {agent.role[lang]})", color="warning")
                 return agent
         pretty_print(f"Error choosing agent. Routing system is not multilingual yet.", color="failure")
@@ -266,7 +378,6 @@ class AgentRouter:
         pretty_print(f"エージェントの選択エラー。ルーティングシステムはまだ多言語に対応していません", color="failure")
         pretty_print(f"Erreur lors du choix de l'agent. Le système de routage n'est pas encore multilingue.", color="failure")
         pretty_print(f"Error al elegir agente. El sistema de enrutamiento aún no es multilingüe.", color="failure")
-        return None
 
 
 if __name__ == "__main__":