feat : better ensemble logic for router

This commit is contained in:
martin legrand 2025-03-26 09:58:27 +01:00
parent d871c378fe
commit 32bc096d9a
2 changed files with 26 additions and 13 deletions

View File

@ -99,11 +99,11 @@ class Interaction:
if agent is None:
return
if self.current_agent != agent and self.last_answer is not None:
## get history from previous agent, good ?
## get last history from previous agent
self.current_agent.memory.push('user', self.last_query)
self.current_agent.memory.push('assistant', self.last_answer)
self.current_agent = agent
self.last_answer, _ = agent.process(self.last_query, self.speech)
#self.last_answer, _ = agent.process(self.last_query, self.speech)
def show_answer(self) -> None:
"""Show the answer to the user."""

View File

@ -176,6 +176,20 @@ class AgentRouter:
("Whats your favorite thing about space?", "talk"),
("Browse the web for the latest fitness trends", "web"),
("Move all .docx files to a Work folder", "files"),
("I would like to make a new project called 'new_project'", "files"),
("I would like to setup a new project index as mark2", "files"),
("can you create a 3d js game that run in the browser", "code"),
("can you make a web app in python that use the flask framework", "code"),
("can you build a web server in go that serve a simple html page", "code"),
("can you find out who Jacky yougouri is ?", "web"),
("Setup a new flutter project called 'new_flutter_project'", "files"),
("can you create a new project called 'new_project'", "files"),
("can you make a simple web app that display a list of files in my dir", "code"),
("can you build a simple web server in python that serve a html page", "code"),
("find and buy me the latest rtx 4090", "web"),
("What are some good netflix show like Altered Carbon ?", "web"),
("can you find the latest research paper on AI", "web"),
("can you find research.pdf in my drive", "files"),
]
texts = [text for text, _ in few_shots]
labels = [label for _, label in few_shots]
@ -193,7 +207,7 @@ class AgentRouter:
predictions = sorted(predictions, key=lambda x: x[1], reverse=True)
return predictions[0]
def router_vote(self, text: str, labels: list) -> str:
def router_vote(self, text: str, labels: list, log_confidence:bool = False) -> str:
"""
Vote between the LLM router and BART model.
Args:
@ -202,17 +216,15 @@ class AgentRouter:
Returns:
str: The selected label
"""
result_bart = self.pipelines['bart'](text, labels, threshold=0.3)
result_bart = self.pipelines['bart'](text, labels)
result_llm_router = self.llm_router(text)
bart, confidence_bart = result_bart['labels'][0], result_bart['scores'][0]
llm_router, confidence_llm_router = result_llm_router[0], result_llm_router[1]
confidence_bart *= 0.8 # was always a bit too confident
print("BART:", bart, "LLM Router:", llm_router)
print("Confidence BART:", confidence_bart, "Confidence LLM Router:", confidence_llm_router)
if confidence_bart > confidence_llm_router:
return bart
else:
return llm_router
final_score_bart = confidence_bart / (confidence_bart + confidence_llm_router)
final_score_llm = confidence_llm_router / (confidence_bart + confidence_llm_router)
if log_confidence:
pretty_print(f"Agent choice -> BART: {bart} ({final_score_bart}) LLM-router: {llm_router} ({final_score_llm})")
return bart if final_score_bart > final_score_llm else llm_router
def classify_text(self, text: str, threshold: float = 0.4) -> list:
"""
@ -227,8 +239,9 @@ class AgentRouter:
first_sentence = text
try:
lang = self.lang_analysis.detect_language(first_sentence)
labels = [agent.role[lang] for agent in self.agents]
result = self.router_vote(first_sentence, labels)
# only use english role labels for now, we don't have a multilingual router yet
labels = [agent.role["en"] for agent in self.agents]
result = self.router_vote(first_sentence, labels, log_confidence=True)
except Exception as e:
raise e
return result, lang