Mirror of https://github.com/tcsenpai/agenticSeek.git, synced 2025-06-06 11:05:26 +00:00
Feat: agent router system + casual agent
Commit ca49a7f0a5, parent dcb5724e28
.gitignore (vendored): 2 additions

@@ -1,6 +1,8 @@
 *.wav
 config.ini
 experimental/
+.env
+*/.env


 # Byte-compiled / optimized / DLL files
main.py: 16 changes

@@ -8,11 +8,11 @@ import configparser
 from sources.llm_provider import Provider
 from sources.interaction import Interaction
 from sources.code_agent import CoderAgent
+from sources.casual_agent import CasualAgent

 parser = argparse.ArgumentParser(description='Deepseek AI assistant')
-parser.add_argument('--speak', action='store_true',
-                    help='Make AI use text-to-speech')
+parser.add_argument('--no-speak', action='store_true',
+                    help='Make AI not use text-to-speech')
 args = parser.parse_args()

 config = configparser.ConfigParser()
@@ -31,12 +31,18 @@ def main():
                         model=config["MAIN"]["provider_model"],
                         server_address=config["MAIN"]["provider_server_address"])

-    agent = CoderAgent(model=config["MAIN"]["provider_model"],
+    agents = [
+        CoderAgent(model=config["MAIN"]["provider_model"],
                    name=config["MAIN"]["agent_name"],
                    prompt_path="prompts/coder_agent.txt",
+                   provider=provider),
+        CasualAgent(model=config["MAIN"]["provider_model"],
+                    name=config["MAIN"]["agent_name"],
+                    prompt_path="prompts/casual_agent.txt",
                     provider=provider)
+    ]

-    interaction = Interaction([agent], tts_enabled=config.getboolean('MAIN', 'speak'),
+    interaction = Interaction(agents, tts_enabled=config.getboolean('MAIN', 'speak'),
                               recover_last_session=config.getboolean('MAIN', 'recover_last_session'))
     while interaction.is_active:
         interaction.get_user()
requirements.txt

@@ -12,8 +12,8 @@ kokoro==0.7.12
 flask==3.1.0
 soundfile==0.13.1
 protobuf==3.20.3
-termcolor==2.3.0
+termcolor==2.5.0
-openai==1.13.3
+gliclass==0.1.8
 # if use chinese
 ordered_set
 pypinyin
sources/agent.py

@@ -2,6 +2,10 @@ from typing import Tuple, Callable
 from abc import abstractmethod
 import os
 import random
+import sys
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
 from sources.memory import Memory
 from sources.utility import pretty_print

@@ -25,6 +29,7 @@ class Agent():
                  provider,
                  recover_last_session=False) -> None:
         self.agent_name = name
+        self.role = None
         self.current_directory = os.getcwd()
         self.model = model
         self.llm = provider
@@ -60,7 +65,14 @@ class Agent():
             raise e

     @abstractmethod
-    def answer(self, prompt, speech_module) -> str:
+    def show_answer(self):
+        """
+        abstract method, implementation in child class.
+        """
+        pass
+
+    @abstractmethod
+    def process(self, prompt, speech_module) -> str:
         """
         abstract method, implementation in child class.
         """
@@ -78,7 +90,7 @@ class Agent():
         end_idx = text.rfind(end_tag)+8
         return text[start_idx:end_idx]

-    def llm_request(self, verbose = True) -> Tuple[str, str]:
+    def llm_request(self, verbose = False) -> Tuple[str, str]:
         memory = self.memory.get()
         thought = self.llm.respond(memory, verbose)

@@ -95,16 +107,30 @@ class Agent():
                     "Working on it sir, please let me think."]
         speech_module.speak(messages[random.randint(0, len(messages)-1)])

-    def print_code_blocks(self, blocks: list, name: str):
-        for block in blocks:
-            pretty_print(f"Executing {name} code...\n", color="output")
-            pretty_print("-"*100, color="output")
-            pretty_print(block, color="code")
-            pretty_print("-"*100, color="output")
-
     def get_blocks_result(self) -> list:
         return self.blocks_result

+    def remove_blocks(self, text: str) -> str:
+        """
+        Remove all code/query blocks within a tag from the answer text.
+        """
+        tag = f'```'
+        lines = text.split('\n')
+        post_lines = []
+        in_block = False
+        block_idx = 0
+        for line in lines:
+            if tag in line and not in_block:
+                in_block = True
+                continue
+            if not in_block:
+                post_lines.append(line)
+            if tag in line:
+                in_block = False
+                post_lines.append(f"block:{block_idx}")
+                block_idx += 1
+        return "\n".join(post_lines)
+
     def execute_modules(self, answer: str) -> Tuple[bool, str]:
         feedback = ""
         success = False
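For reference, a minimal, self-contained sketch of what the new Agent.remove_blocks() produces; the sample answer text is illustrative, and the function body is copied standalone from the diff above so the snippet runs on its own:

def remove_blocks(text: str) -> str:
    # Same logic as the method added to Agent above, standalone for illustration.
    tag = '```'
    post_lines, in_block, block_idx = [], False, 0
    for line in text.split('\n'):
        if tag in line and not in_block:
            in_block = True          # opening fence: start skipping lines
            continue
        if not in_block:
            post_lines.append(line)  # keep prose outside fenced blocks
        if tag in line:
            in_block = False         # closing fence: replace the block with a placeholder
            post_lines.append(f"block:{block_idx}")
            block_idx += 1
    return "\n".join(post_lines)

answer = "Here is the script:\n```python\nprint('hi')\n```\nLet me know if it works."
print(remove_blocks(answer))
# Here is the script:
# block:0
# Let me know if it works.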
sources/casual_agent.py: new file, 36 additions

@@ -0,0 +1,36 @@
+
+from sources.utility import pretty_print
+from sources.agent import Agent
+
+class CasualAgent(Agent):
+    def __init__(self, model, name, prompt_path, provider):
+        """
+        The casual agent is a special agent for casual talk with the user, without specific tasks.
+        """
+        super().__init__(model, name, prompt_path, provider)
+        self.tools = {
+        } # TODO implement casual tools like basic web search, basic file search, basic image search, basic knowledge search
+        self.role = "talking"
+
+    def show_answer(self):
+        lines = self.last_answer.split("\n")
+        for line in lines:
+            pretty_print(line, color="output")
+
+    def process(self, prompt, speech_module) -> str:
+        self.memory.push('user', prompt)
+
+        pretty_print("Thinking...", color="status")
+        self.wait_message(speech_module)
+        answer, reasoning = self.llm_request()
+        self.last_answer = answer
+        return answer, reasoning
+
+if __name__ == "__main__":
+    from llm_provider import Provider
+
+    #local_provider = Provider("ollama", "deepseek-r1:14b", None)
+    server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
+    agent = CasualAgent("deepseek-r1:14b", "jarvis", "prompts/casual_agent.txt", server_provider)
+    ans = agent.process("Hello, how are you?")
+    print(ans)
sources/code_agent.py

@@ -4,36 +4,17 @@ from sources.agent import Agent, executorResult
 from sources.tools import PyInterpreter, BashInterpreter, CInterpreter, GoInterpreter

 class CoderAgent(Agent):
+    """
+    The code agent is a special agent for writing code and shell commands.
+    """
     def __init__(self, model, name, prompt_path, provider):
         super().__init__(model, name, prompt_path, provider)
         self.tools = {
             "bash": BashInterpreter(),
             "python": PyInterpreter()
-            "C": CInterpreter(),
-            "go": GoInterpreter()
         }
+        self.role = "coding"

-    def remove_blocks(self, text: str) -> str:
-        """
-        Remove all code/query blocks within a tag from the answer text.
-        """
-        tag = f'```'
-        lines = text.split('\n')
-        post_lines = []
-        in_block = False
-        block_idx = 0
-        for line in lines:
-            if tag in line and not in_block:
-                in_block = True
-                continue
-            if not in_block:
-                post_lines.append(line)
-            if tag in line:
-                in_block = False
-                post_lines.append(f"block:{block_idx}")
-                block_idx += 1
-        return "\n".join(post_lines)
-
     def show_answer(self):
         lines = self.last_answer.split("\n")
         for line in lines:
sources/interaction.py

@@ -1,11 +1,14 @@

 from sources.text_to_speech import Speech
 from sources.utility import pretty_print
+from sources.router import AgentRouter

 class Interaction:
     def __init__(self, agents, tts_enabled: bool = False, recover_last_session: bool = False):
         self.tts_enabled = tts_enabled
         self.agents = agents
+        self.current_agent = None
+        self.router = AgentRouter(self.agents)
         self.speech = Speech()
         self.is_active = True
         self.last_query = None
@@ -44,10 +47,15 @@ class Interaction:
         return query

     def think(self):
-        self.last_answer, _ = self.agents[0].process(self.last_query, self.speech)
+        agent = self.router.select_agent(self.last_query)
+        if self.current_agent != agent:
+            self.current_agent = agent
+            # get history from previous agent
+            self.current_agent.memory.push('user', self.last_query)
+        self.last_answer, _ = agent.process(self.last_query, self.speech)

     def show_answer(self):
-        self.agents[0].show_answer()
+        self.current_agent.show_answer()
         if self.tts_enabled:
             self.speech.speak(self.last_answer)
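For clarity, a self-contained sketch of the routing pattern think() now follows, using stand-in classes rather than the repository's Agent, AgentRouter, and Memory: route each query, remember the active agent, and hand the query to the newly selected agent only when the router switches.

class StubAgent:
    def __init__(self, role):
        self.role = role
        self.memory = []          # stands in for sources.memory.Memory

    def process(self, query):
        return f"[{self.role}] handled: {query}"

class StubRouter:
    def select_agent(self, query):
        # trivial stand-in for the zero-shot router
        return coder if "script" in query else casual

coder, casual = StubAgent("coding"), StubAgent("talking")
router, current_agent = StubRouter(), None

for query in ["write a script to ping my router", "how was your day ?"]:
    agent = router.select_agent(query)
    if current_agent != agent:
        current_agent = agent
        current_agent.memory.append(("user", query))  # push the query to the newly selected agent
    print(current_agent.process(query))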
sources/memory.py

@@ -4,7 +4,11 @@ import time
 import datetime
 import uuid
 import os
+import sys
 import json
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
 from sources.utility import timer_decorator

 class Memory():
sources/router.py: new file, 60 additions

@@ -0,0 +1,60 @@
+import torch
+from transformers import pipeline
+from sources.agent import Agent
+from sources.code_agent import CoderAgent
+from sources.casual_agent import CasualAgent
+from sources.utility import pretty_print
+
+class AgentRouter:
+    def __init__(self, agents: list, model_name="facebook/bart-large-mnli"):
+        self.model = model_name
+        self.pipeline = pipeline("zero-shot-classification",
+                                 model=self.model)
+        self.agents = agents
+        self.labels = [agent.role for agent in agents]
+
+    def get_device(self):
+        if torch.backends.mps.is_available():
+            return "mps"
+        elif torch.cuda.is_available():
+            return "cuda:0"
+        else:
+            return "cpu"
+
+    def classify_text(self, text, threshold=0.5):
+        result = self.pipeline(text, self.labels, threshold=threshold)
+        return result
+
+    def select_agent(self, text: str) -> Agent:
+        result = self.classify_text(text)
+        for agent in self.agents:
+            if result["labels"][0] == agent.role:
+                pretty_print(f"Selected agent role: {agent.role}", color="warning")
+                return agent
+        return None
+
+if __name__ == "__main__":
+    agents = [
+        CoderAgent("deepseek-r1:14b", "agent1", "../prompts/coder_agent.txt", "server"),
+        CasualAgent("deepseek-r1:14b", "agent2", "../prompts/casual_agent.txt", "server")
+    ]
+    router = AgentRouter(agents)
+
+    texts = ["""
+    Write a python script to check if the device on my network is connected to the internet
+    """,
+    """
+    Hey could you search the web for the latest news on the stock market ?
+    """,
+    """
+    hey can you give dating advice ?
+    """
+    ]
+
+    for text in texts:
+        print(text)
+        results = router.classify_text(text)
+        for result in results:
+            print(result["label"], "=>", result["score"])
+        agent = router.select_agent(text)
+        print("Selected agent role:", agent.role)
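A note on the zero-shot output select_agent() relies on: for a single input string, the transformers zero-shot pipeline returns one dict whose "labels" list is sorted by descending score (not a list of per-label entries), which is why the top label is compared against each agent.role. A minimal, self-contained sketch, assuming the same model and two illustrative roles; the example query and printed scores are illustrative:

from transformers import pipeline

# Same model name as AgentRouter uses; weights download on first run.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

labels = ["coding", "talking"]   # one label per agent role, as in AgentRouter.labels
result = classifier("Write a python script to ping every host on my network", labels)

# result is a dict: {"sequence": ..., "labels": [...], "scores": [...]},
# with labels sorted by descending score.
for label, score in zip(result["labels"], result["scores"]):
    print(f"{label}: {score:.2f}")

top_role = result["labels"][0]   # what select_agent() matches against agent.role
print("routed to role:", top_role)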