mirror of
https://github.com/tcsenpai/agenticSeek.git
synced 2025-06-05 02:25:27 +00:00
Feat : agent router system + casual agent
This commit is contained in:
parent
dcb5724e28
commit
ca49a7f0a5
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,6 +1,8 @@
|
||||
*.wav
|
||||
config.ini
|
||||
experimental/
|
||||
.env
|
||||
*/.env
|
||||
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
16
main.py
16
main.py
@ -8,11 +8,11 @@ import configparser
|
||||
from sources.llm_provider import Provider
|
||||
from sources.interaction import Interaction
|
||||
from sources.code_agent import CoderAgent
|
||||
|
||||
from sources.casual_agent import CasualAgent
|
||||
|
||||
parser = argparse.ArgumentParser(description='Deepseek AI assistant')
|
||||
parser.add_argument('--speak', action='store_true',
|
||||
help='Make AI use text-to-speech')
|
||||
parser.add_argument('--no-speak', action='store_true',
|
||||
help='Make AI not use text-to-speech')
|
||||
args = parser.parse_args()
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
@ -31,12 +31,18 @@ def main():
|
||||
model=config["MAIN"]["provider_model"],
|
||||
server_address=config["MAIN"]["provider_server_address"])
|
||||
|
||||
agent = CoderAgent(model=config["MAIN"]["provider_model"],
|
||||
agents = [
|
||||
CoderAgent(model=config["MAIN"]["provider_model"],
|
||||
name=config["MAIN"]["agent_name"],
|
||||
prompt_path="prompts/coder_agent.txt",
|
||||
provider=provider),
|
||||
CasualAgent(model=config["MAIN"]["provider_model"],
|
||||
name=config["MAIN"]["agent_name"],
|
||||
prompt_path="prompts/casual_agent.txt",
|
||||
provider=provider)
|
||||
]
|
||||
|
||||
interaction = Interaction([agent], tts_enabled=config.getboolean('MAIN', 'speak'),
|
||||
interaction = Interaction(agents, tts_enabled=config.getboolean('MAIN', 'speak'),
|
||||
recover_last_session=config.getboolean('MAIN', 'recover_last_session'))
|
||||
while interaction.is_active:
|
||||
interaction.get_user()
|
||||
|
@ -12,8 +12,8 @@ kokoro==0.7.12
|
||||
flask==3.1.0
|
||||
soundfile==0.13.1
|
||||
protobuf==3.20.3
|
||||
termcolor==2.3.0
|
||||
openai==1.13.3
|
||||
termcolor==2.5.0
|
||||
gliclass==0.1.8
|
||||
# if use chinese
|
||||
ordered_set
|
||||
pypinyin
|
||||
|
@ -2,6 +2,10 @@ from typing import Tuple, Callable
|
||||
from abc import abstractmethod
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sources.memory import Memory
|
||||
from sources.utility import pretty_print
|
||||
|
||||
@ -25,6 +29,7 @@ class Agent():
|
||||
provider,
|
||||
recover_last_session=False) -> None:
|
||||
self.agent_name = name
|
||||
self.role = None
|
||||
self.current_directory = os.getcwd()
|
||||
self.model = model
|
||||
self.llm = provider
|
||||
@ -60,7 +65,14 @@ class Agent():
|
||||
raise e
|
||||
|
||||
@abstractmethod
|
||||
def answer(self, prompt, speech_module) -> str:
|
||||
def show_answer(self):
|
||||
"""
|
||||
abstract method, implementation in child class.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def process(self, prompt, speech_module) -> str:
|
||||
"""
|
||||
abstract method, implementation in child class.
|
||||
"""
|
||||
@ -78,7 +90,7 @@ class Agent():
|
||||
end_idx = text.rfind(end_tag)+8
|
||||
return text[start_idx:end_idx]
|
||||
|
||||
def llm_request(self, verbose = True) -> Tuple[str, str]:
|
||||
def llm_request(self, verbose = False) -> Tuple[str, str]:
|
||||
memory = self.memory.get()
|
||||
thought = self.llm.respond(memory, verbose)
|
||||
|
||||
@ -95,16 +107,30 @@ class Agent():
|
||||
"Working on it sir, please let me think."]
|
||||
speech_module.speak(messages[random.randint(0, len(messages)-1)])
|
||||
|
||||
def print_code_blocks(self, blocks: list, name: str):
|
||||
for block in blocks:
|
||||
pretty_print(f"Executing {name} code...\n", color="output")
|
||||
pretty_print("-"*100, color="output")
|
||||
pretty_print(block, color="code")
|
||||
pretty_print("-"*100, color="output")
|
||||
|
||||
def get_blocks_result(self) -> list:
|
||||
return self.blocks_result
|
||||
|
||||
def remove_blocks(self, text: str) -> str:
|
||||
"""
|
||||
Remove all code/query blocks within a tag from the answer text.
|
||||
"""
|
||||
tag = f'```'
|
||||
lines = text.split('\n')
|
||||
post_lines = []
|
||||
in_block = False
|
||||
block_idx = 0
|
||||
for line in lines:
|
||||
if tag in line and not in_block:
|
||||
in_block = True
|
||||
continue
|
||||
if not in_block:
|
||||
post_lines.append(line)
|
||||
if tag in line:
|
||||
in_block = False
|
||||
post_lines.append(f"block:{block_idx}")
|
||||
block_idx += 1
|
||||
return "\n".join(post_lines)
|
||||
|
||||
def execute_modules(self, answer: str) -> Tuple[bool, str]:
|
||||
feedback = ""
|
||||
success = False
|
||||
|
36
sources/casual_agent.py
Normal file
36
sources/casual_agent.py
Normal file
@ -0,0 +1,36 @@
|
||||
|
||||
from sources.utility import pretty_print
|
||||
from sources.agent import Agent
|
||||
|
||||
class CasualAgent(Agent):
|
||||
def __init__(self, model, name, prompt_path, provider):
|
||||
"""
|
||||
The casual agent is a special for casual talk to the user without specific tasks.
|
||||
"""
|
||||
super().__init__(model, name, prompt_path, provider)
|
||||
self.tools = {
|
||||
} # TODO implement casual tools like basic web search, basic file search, basic image search, basic knowledge search
|
||||
self.role = "talking"
|
||||
|
||||
def show_answer(self):
|
||||
lines = self.last_answer.split("\n")
|
||||
for line in lines:
|
||||
pretty_print(line, color="output")
|
||||
|
||||
def process(self, prompt, speech_module) -> str:
|
||||
self.memory.push('user', prompt)
|
||||
|
||||
pretty_print("Thinking...", color="status")
|
||||
self.wait_message(speech_module)
|
||||
answer, reasoning = self.llm_request()
|
||||
self.last_answer = answer
|
||||
return answer, reasoning
|
||||
|
||||
if __name__ == "__main__":
|
||||
from llm_provider import Provider
|
||||
|
||||
#local_provider = Provider("ollama", "deepseek-r1:14b", None)
|
||||
server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
|
||||
agent = CasualAgent("deepseek-r1:14b", "jarvis", "prompts/casual_agent.txt", server_provider)
|
||||
ans = agent.process("Hello, how are you?")
|
||||
print(ans)
|
@ -4,36 +4,17 @@ from sources.agent import Agent, executorResult
|
||||
from sources.tools import PyInterpreter, BashInterpreter, CInterpreter, GoInterpreter
|
||||
|
||||
class CoderAgent(Agent):
|
||||
"""
|
||||
The code agent is a special for writing code and shell commands.
|
||||
"""
|
||||
def __init__(self, model, name, prompt_path, provider):
|
||||
super().__init__(model, name, prompt_path, provider)
|
||||
self.tools = {
|
||||
"bash": BashInterpreter(),
|
||||
"python": PyInterpreter()
|
||||
"C": CInterpreter(),
|
||||
"go": GoInterpreter()
|
||||
}
|
||||
self.role = "coding"
|
||||
|
||||
def remove_blocks(self, text: str) -> str:
|
||||
"""
|
||||
Remove all code/query blocks within a tag from the answer text.
|
||||
"""
|
||||
tag = f'```'
|
||||
lines = text.split('\n')
|
||||
post_lines = []
|
||||
in_block = False
|
||||
block_idx = 0
|
||||
for line in lines:
|
||||
if tag in line and not in_block:
|
||||
in_block = True
|
||||
continue
|
||||
if not in_block:
|
||||
post_lines.append(line)
|
||||
if tag in line:
|
||||
in_block = False
|
||||
post_lines.append(f"block:{block_idx}")
|
||||
block_idx += 1
|
||||
return "\n".join(post_lines)
|
||||
|
||||
def show_answer(self):
|
||||
lines = self.last_answer.split("\n")
|
||||
for line in lines:
|
||||
|
@ -1,11 +1,14 @@
|
||||
|
||||
from sources.text_to_speech import Speech
|
||||
from sources.utility import pretty_print
|
||||
from sources.router import AgentRouter
|
||||
|
||||
class Interaction:
|
||||
def __init__(self, agents, tts_enabled: bool = False, recover_last_session: bool = False):
|
||||
self.tts_enabled = tts_enabled
|
||||
self.agents = agents
|
||||
self.current_agent = None
|
||||
self.router = AgentRouter(self.agents)
|
||||
self.speech = Speech()
|
||||
self.is_active = True
|
||||
self.last_query = None
|
||||
@ -44,10 +47,15 @@ class Interaction:
|
||||
return query
|
||||
|
||||
def think(self):
|
||||
self.last_answer, _ = self.agents[0].process(self.last_query, self.speech)
|
||||
agent = self.router.select_agent(self.last_query)
|
||||
if self.current_agent != agent:
|
||||
self.current_agent = agent
|
||||
# get history from previous agent
|
||||
self.current_agent.memory.push('user', self.last_query)
|
||||
self.last_answer, _ = agent.process(self.last_query, self.speech)
|
||||
|
||||
def show_answer(self):
|
||||
self.agents[0].show_answer()
|
||||
self.current_agent.show_answer()
|
||||
if self.tts_enabled:
|
||||
self.speech.speak(self.last_answer)
|
||||
|
||||
|
@ -4,7 +4,11 @@ import time
|
||||
import datetime
|
||||
import uuid
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sources.utility import timer_decorator
|
||||
|
||||
class Memory():
|
||||
|
60
sources/router.py
Normal file
60
sources/router.py
Normal file
@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
from sources.agent import Agent
|
||||
from sources.code_agent import CoderAgent
|
||||
from sources.casual_agent import CasualAgent
|
||||
from sources.utility import pretty_print
|
||||
|
||||
class AgentRouter:
|
||||
def __init__(self, agents: list, model_name="facebook/bart-large-mnli"):
|
||||
self.model = model_name
|
||||
self.pipeline = pipeline("zero-shot-classification",
|
||||
model=self.model)
|
||||
self.agents = agents
|
||||
self.labels = [agent.role for agent in agents]
|
||||
|
||||
def get_device(self):
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
elif torch.cuda.is_available():
|
||||
return "cuda:0"
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
def classify_text(self, text, threshold=0.5):
|
||||
result = self.pipeline(text, self.labels, threshold=threshold)
|
||||
return result
|
||||
|
||||
def select_agent(self, text: str) -> Agent:
|
||||
result = self.classify_text(text)
|
||||
for agent in self.agents:
|
||||
if result["labels"][0] == agent.role:
|
||||
pretty_print(f"Selected agent role: {agent.role}", color="warning")
|
||||
return agent
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
agents = [
|
||||
CoderAgent("deepseek-r1:14b", "agent1", "../prompts/coder_agent.txt", "server"),
|
||||
CasualAgent("deepseek-r1:14b", "agent2", "../prompts/casual_agent.txt", "server")
|
||||
]
|
||||
router = AgentRouter(agents)
|
||||
|
||||
texts = ["""
|
||||
Write a python script to check if the device on my network is connected to the internet
|
||||
""",
|
||||
"""
|
||||
Hey could you search the web for the latest news on the stock market ?
|
||||
""",
|
||||
"""
|
||||
hey can you give dating advice ?
|
||||
"""
|
||||
]
|
||||
|
||||
for text in texts:
|
||||
print(text)
|
||||
results = router.classify_text(text)
|
||||
for result in results:
|
||||
print(result["label"], "=>", result["score"])
|
||||
agent = router.select_agent(text)
|
||||
print("Selected agent role:", agent.role)
|
Loading…
x
Reference in New Issue
Block a user