diff --git a/.gitignore b/.gitignore
index bd4ab5e..cf95114 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
 *.wav
 config.ini
 experimental/
+.env
+*/.env
 
 # Byte-compiled / optimized / DLL files
diff --git a/main.py b/main.py
index e40fab6..88478d0 100755
--- a/main.py
+++ b/main.py
@@ -8,11 +8,11 @@ import configparser
 from sources.llm_provider import Provider
 from sources.interaction import Interaction
 from sources.code_agent import CoderAgent
-
+from sources.casual_agent import CasualAgent
 
 parser = argparse.ArgumentParser(description='Deepseek AI assistant')
-parser.add_argument('--speak', action='store_true',
-                help='Make AI use text-to-speech')
+parser.add_argument('--no-speak', action='store_true',
+                help='Make AI not use text-to-speech')
 args = parser.parse_args()
 
 config = configparser.ConfigParser()
@@ -31,12 +31,18 @@ def main():
                         model=config["MAIN"]["provider_model"],
                         server_address=config["MAIN"]["provider_server_address"])
 
-    agent = CoderAgent(model=config["MAIN"]["provider_model"],
+    agents = [
+        CoderAgent(model=config["MAIN"]["provider_model"],
                    name=config["MAIN"]["agent_name"],
                    prompt_path="prompts/coder_agent.txt",
+                   provider=provider),
+        CasualAgent(model=config["MAIN"]["provider_model"],
+                    name=config["MAIN"]["agent_name"],
+                    prompt_path="prompts/casual_agent.txt",
                    provider=provider)
+    ]
 
-    interaction = Interaction([agent], tts_enabled=config.getboolean('MAIN', 'speak'),
+    interaction = Interaction(agents, tts_enabled=config.getboolean('MAIN', 'speak'),
                               recover_last_session=config.getboolean('MAIN', 'recover_last_session'))
     while interaction.is_active:
         interaction.get_user()
diff --git a/requirements.txt b/requirements.txt
index ddbc9f5..e122ccd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,8 +12,8 @@ kokoro==0.7.12
 flask==3.1.0
 soundfile==0.13.1
 protobuf==3.20.3
-termcolor==2.3.0
-openai==1.13.3
+termcolor==2.5.0
+gliclass==0.1.8
 # if use chinese
 ordered_set
 pypinyin
diff --git a/sources/agent.py b/sources/agent.py
index 4b3e8d0..7963fa4 100644
--- a/sources/agent.py
+++ b/sources/agent.py
@@ -2,6 +2,10 @@ from typing import Tuple, Callable
 from abc import abstractmethod
 import os
 import random
+import sys
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
 from sources.memory import Memory
 from sources.utility import pretty_print
@@ -25,6 +29,7 @@ class Agent():
                  provider,
                  recover_last_session=False) -> None:
         self.agent_name = name
+        self.role = None
         self.current_directory = os.getcwd()
         self.model = model
         self.llm = provider
@@ -60,7 +65,14 @@ class Agent():
             raise e
 
     @abstractmethod
-    def answer(self, prompt, speech_module) -> str:
+    def show_answer(self):
+        """
+        abstract method, implementation in child class.
+        """
+        pass
+
+    @abstractmethod
+    def process(self, prompt, speech_module) -> str:
         """
         abstract method, implementation in child class.
""" @@ -78,7 +90,7 @@ class Agent(): end_idx = text.rfind(end_tag)+8 return text[start_idx:end_idx] - def llm_request(self, verbose = True) -> Tuple[str, str]: + def llm_request(self, verbose = False) -> Tuple[str, str]: memory = self.memory.get() thought = self.llm.respond(memory, verbose) @@ -95,16 +107,30 @@ class Agent(): "Working on it sir, please let me think."] speech_module.speak(messages[random.randint(0, len(messages)-1)]) - def print_code_blocks(self, blocks: list, name: str): - for block in blocks: - pretty_print(f"Executing {name} code...\n", color="output") - pretty_print("-"*100, color="output") - pretty_print(block, color="code") - pretty_print("-"*100, color="output") - def get_blocks_result(self) -> list: return self.blocks_result + def remove_blocks(self, text: str) -> str: + """ + Remove all code/query blocks within a tag from the answer text. + """ + tag = f'```' + lines = text.split('\n') + post_lines = [] + in_block = False + block_idx = 0 + for line in lines: + if tag in line and not in_block: + in_block = True + continue + if not in_block: + post_lines.append(line) + if tag in line: + in_block = False + post_lines.append(f"block:{block_idx}") + block_idx += 1 + return "\n".join(post_lines) + def execute_modules(self, answer: str) -> Tuple[bool, str]: feedback = "" success = False diff --git a/sources/casual_agent.py b/sources/casual_agent.py new file mode 100644 index 0000000..8e12c9d --- /dev/null +++ b/sources/casual_agent.py @@ -0,0 +1,36 @@ + +from sources.utility import pretty_print +from sources.agent import Agent + +class CasualAgent(Agent): + def __init__(self, model, name, prompt_path, provider): + """ + The casual agent is a special for casual talk to the user without specific tasks. + """ + super().__init__(model, name, prompt_path, provider) + self.tools = { + } # TODO implement casual tools like basic web search, basic file search, basic image search, basic knowledge search + self.role = "talking" + + def show_answer(self): + lines = self.last_answer.split("\n") + for line in lines: + pretty_print(line, color="output") + + def process(self, prompt, speech_module) -> str: + self.memory.push('user', prompt) + + pretty_print("Thinking...", color="status") + self.wait_message(speech_module) + answer, reasoning = self.llm_request() + self.last_answer = answer + return answer, reasoning + +if __name__ == "__main__": + from llm_provider import Provider + + #local_provider = Provider("ollama", "deepseek-r1:14b", None) + server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000") + agent = CasualAgent("deepseek-r1:14b", "jarvis", "prompts/casual_agent.txt", server_provider) + ans = agent.process("Hello, how are you?") + print(ans) \ No newline at end of file diff --git a/sources/code_agent.py b/sources/code_agent.py index dc5967b..bc4096c 100644 --- a/sources/code_agent.py +++ b/sources/code_agent.py @@ -4,36 +4,17 @@ from sources.agent import Agent, executorResult from sources.tools import PyInterpreter, BashInterpreter, CInterpreter, GoInterpreter class CoderAgent(Agent): + """ + The code agent is a special for writing code and shell commands. + """ def __init__(self, model, name, prompt_path, provider): super().__init__(model, name, prompt_path, provider) self.tools = { "bash": BashInterpreter(), "python": PyInterpreter() - "C": CInterpreter(), - "go": GoInterpreter() } + self.role = "coding" - def remove_blocks(self, text: str) -> str: - """ - Remove all code/query blocks within a tag from the answer text. 
- """ - tag = f'```' - lines = text.split('\n') - post_lines = [] - in_block = False - block_idx = 0 - for line in lines: - if tag in line and not in_block: - in_block = True - continue - if not in_block: - post_lines.append(line) - if tag in line: - in_block = False - post_lines.append(f"block:{block_idx}") - block_idx += 1 - return "\n".join(post_lines) - def show_answer(self): lines = self.last_answer.split("\n") for line in lines: diff --git a/sources/interaction.py b/sources/interaction.py index fe4ebd5..2d61ac1 100644 --- a/sources/interaction.py +++ b/sources/interaction.py @@ -1,11 +1,14 @@ from sources.text_to_speech import Speech from sources.utility import pretty_print +from sources.router import AgentRouter class Interaction: def __init__(self, agents, tts_enabled: bool = False, recover_last_session: bool = False): self.tts_enabled = tts_enabled self.agents = agents + self.current_agent = None + self.router = AgentRouter(self.agents) self.speech = Speech() self.is_active = True self.last_query = None @@ -44,10 +47,15 @@ class Interaction: return query def think(self): - self.last_answer, _ = self.agents[0].process(self.last_query, self.speech) + agent = self.router.select_agent(self.last_query) + if self.current_agent != agent: + self.current_agent = agent + # get history from previous agent + self.current_agent.memory.push('user', self.last_query) + self.last_answer, _ = agent.process(self.last_query, self.speech) def show_answer(self): - self.agents[0].show_answer() + self.current_agent.show_answer() if self.tts_enabled: self.speech.speak(self.last_answer) diff --git a/sources/memory.py b/sources/memory.py index 53df96c..8790de1 100644 --- a/sources/memory.py +++ b/sources/memory.py @@ -4,7 +4,11 @@ import time import datetime import uuid import os +import sys import json + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from sources.utility import timer_decorator class Memory(): diff --git a/sources/router.py b/sources/router.py new file mode 100644 index 0000000..876c507 --- /dev/null +++ b/sources/router.py @@ -0,0 +1,60 @@ +import torch +from transformers import pipeline +from sources.agent import Agent +from sources.code_agent import CoderAgent +from sources.casual_agent import CasualAgent +from sources.utility import pretty_print + +class AgentRouter: + def __init__(self, agents: list, model_name="facebook/bart-large-mnli"): + self.model = model_name + self.pipeline = pipeline("zero-shot-classification", + model=self.model) + self.agents = agents + self.labels = [agent.role for agent in agents] + + def get_device(self): + if torch.backends.mps.is_available(): + return "mps" + elif torch.cuda.is_available(): + return "cuda:0" + else: + return "cpu" + + def classify_text(self, text, threshold=0.5): + result = self.pipeline(text, self.labels, threshold=threshold) + return result + + def select_agent(self, text: str) -> Agent: + result = self.classify_text(text) + for agent in self.agents: + if result["labels"][0] == agent.role: + pretty_print(f"Selected agent role: {agent.role}", color="warning") + return agent + return None + +if __name__ == "__main__": + agents = [ + CoderAgent("deepseek-r1:14b", "agent1", "../prompts/coder_agent.txt", "server"), + CasualAgent("deepseek-r1:14b", "agent2", "../prompts/casual_agent.txt", "server") + ] + router = AgentRouter(agents) + + texts = [""" + Write a python script to check if the device on my network is connected to the internet + """, + """ + Hey could you search the web for the latest news on the 
+    """,
+    """
+    hey can you give dating advice ?
+    """
+    ]
+
+    for text in texts:
+        print(text)
+        results = router.classify_text(text)
+        for label, score in zip(results["labels"], results["scores"]):
+            print(label, "=>", score)
+        agent = router.select_agent(text)
+        print("Selected agent role:", agent.role)
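
Note (not part of the diff): the new AgentRouter relies on Hugging Face's zero-shot-classification pipeline, which for a single input string returns a dict whose "labels" and "scores" lists are sorted by descending score; select_agent matches the top label against each agent's role. Below is a minimal standalone sketch of that routing step, assuming transformers and a torch backend are installed, reusing the facebook/bart-large-mnli checkpoint and the "coding"/"talking" roles set in the diff.

from transformers import pipeline

# Sketch of the routing step only; the Agent classes from the diff are not needed here.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
labels = ["coding", "talking"]  # mirrors the role values set on CoderAgent and CasualAgent

query = "Write a python script to check if the device on my network is connected to the internet"
result = classifier(query, labels)

# result["labels"] is sorted by descending score; the first entry is the role the router would pick.
for label, score in zip(result["labels"], result["scores"]):
    print(f"{label} => {score:.3f}")
print("routed to:", result["labels"][0])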