agenticSeek/sources/router.py

63 lines
2.0 KiB
Python

import torch
from transformers import pipeline
from sources.agent import Agent
from sources.code_agent import CoderAgent
from sources.casual_agent import CasualAgent
from sources.utility import pretty_print
class AgentRouter:
def __init__(self, agents: list, model_name="facebook/bart-large-mnli"):
self.model = model_name
self.pipeline = pipeline("zero-shot-classification",
model=self.model)
self.agents = agents
self.labels = [agent.role for agent in agents]
def get_device(self):
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda:0"
else:
return "cpu"
def classify_text(self, text, threshold=0.5):
result = self.pipeline(text, self.labels, threshold=threshold)
return result
def select_agent(self, text: str) -> Agent:
if len(self.agents) == 0 or len(self.labels) == 0:
return self.agents[0]
result = self.classify_text(text)
for agent in self.agents:
if result["labels"][0] == agent.role:
pretty_print(f"Selected agent: {agent.agent_name}", color="warning")
return agent
return None
if __name__ == "__main__":
agents = [
CoderAgent("deepseek-r1:14b", "agent1", "../prompts/coder_agent.txt", "server"),
CasualAgent("deepseek-r1:14b", "agent2", "../prompts/casual_agent.txt", "server")
]
router = AgentRouter(agents)
texts = ["""
Write a python script to check if the device on my network is connected to the internet
""",
"""
Hey could you search the web for the latest news on the stock market ?
""",
"""
hey can you give dating advice ?
"""
]
for text in texts:
print(text)
results = router.classify_text(text)
for result in results:
print(result["label"], "=>", result["score"])
agent = router.select_agent(text)
print("Selected agent role:", agent.role)