From 8991aaae8d7d70acd9b2434982d5c8ccc44f245d Mon Sep 17 00:00:00 2001 From: martin legrand Date: Sat, 5 Apr 2025 16:39:16 +0200 Subject: [PATCH] feat : planner agent improvement --- sources/agents/planner_agent.py | 56 +++++++++++++++++++++++---------- sources/browser.py | 2 +- sources/interaction.py | 2 +- sources/language.py | 2 +- sources/llm_provider.py | 17 +--------- sources/logger.py | 2 +- sources/memory.py | 2 +- sources/router.py | 2 +- sources/text_to_speech.py | 2 +- 9 files changed, 47 insertions(+), 40 deletions(-) diff --git a/sources/agents/planner_agent.py b/sources/agents/planner_agent.py index 9d28b88..0b71793 100644 --- a/sources/agents/planner_agent.py +++ b/sources/agents/planner_agent.py @@ -1,9 +1,11 @@ import json +from typing import List, Tuple, Type, Dict from sources.utility import pretty_print, animate_thinking from sources.agents.agent import Agent from sources.agents.code_agent import CoderAgent from sources.agents.file_agent import FileAgent from sources.agents.browser_agent import BrowserAgent +from sources.text_to_speech import Speech from sources.tools.tools import Tools class PlannerAgent(Agent): @@ -61,18 +63,22 @@ class PlannerAgent(Agent): return zip(names, tasks) return zip(tasks_names, tasks) - def make_prompt(self, task, needed_infos): - if needed_infos is None: - needed_infos = "No needed informations." + def make_prompt(self, task: dict, agent_infos_dict: dict): + infos = "" + if agent_infos_dict is None or len(agent_infos_dict) == 0: + infos = "No needed informations." + else: + for agent_id, info in agent_infos_dict: + infos += f"\t- According to agent {agent_id}:\n{info}\n\n" prompt = f""" - You are given the following informations: - {needed_infos} + You are given informations from your AI friends work: + {infos} Your task is: {task} """ return prompt - def show_plan(self, json_plan): + def show_plan(self, json_plan: dict) -> None: agents_tasks = self.parse_agent_tasks(json_plan) if agents_tasks == (None, None): return @@ -81,11 +87,10 @@ class PlannerAgent(Agent): pretty_print(f"{task['agent']} -> {task['task']}", color="info") pretty_print("▔▗ E N D ▖▔", color="status") - def process(self, prompt, speech_module) -> str: + def make_plan(self, prompt: str) -> str: ok = False - agents_tasks = (None, None) + answer = None while not ok: - self.wait_message(speech_module) animate_thinking("Thinking...", color="status") self.memory.push('user', prompt) answer, _ = self.llm_request() @@ -96,25 +101,42 @@ class PlannerAgent(Agent): ok = True else: prompt = input("Please reformulate: ") + return answer + + def start_agent_process(self, task: str, required_infos: dict | None) -> str: + agent_prompt = self.make_prompt(task['task'], required_infos) + pretty_print(f"Agent {task['agent']} started working...", color="status") + agent_answer, _ = self.agents[task['agent'].lower()].process(agent_prompt, None) + self.agents[task['agent'].lower()].show_answer() + pretty_print(f"Agent {task['agent']} completed task.", color="status") + return agent_answer + + def get_work_result_agent(self, task_needs, agents_work_result): + return {k: agents_work_result[k] for k in task_needs if k in agents_work_result} + def process(self, prompt: str, speech_module: Speech) -> str: + agents_tasks = (None, None) + required_infos = None + agents_work_result = dict() + + answer = self.make_plan(prompt) agents_tasks = self.parse_agent_tasks(answer) + if agents_tasks == (None, None): return "Failed to parse the tasks", reasoning - prev_agent_answer = None for task_name, task in agents_tasks: pretty_print(f"I will {task_name}.", color="info") - agent_prompt = self.make_prompt(task['task'], prev_agent_answer) pretty_print(f"Assigned agent {task['agent']} to {task_name}", color="info") if speech_module: speech_module.speak(f"I will {task_name}. I assigned the {task['agent']} agent to the task.") + + if agents_work_result is not None: + required_infos = self.get_work_result_agent(task['need'], agents_work_result) try: - prev_agent_answer, _ = self.agents[task['agent'].lower()].process(agent_prompt, speech_module) - pretty_print(f"-- Agent answer ---\n\n", color="output") - self.agents[task['agent'].lower()].show_answer() - pretty_print(f"\n\n", color="output") + self.last_answer = self.start_agent_process(task, required_infos) except Exception as e: raise e - self.last_answer = prev_agent_answer - return prev_agent_answer, "" + agents_work_result[task['id']] = self.last_answer + return self.last_answer, "" if __name__ == "__main__": pass \ No newline at end of file diff --git a/sources/browser.py b/sources/browser.py index 6885eab..76c402d 100644 --- a/sources/browser.py +++ b/sources/browser.py @@ -6,7 +6,7 @@ from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC from selenium.common.exceptions import TimeoutException, WebDriverException from selenium.webdriver.common.action_chains import ActionChains -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict from bs4 import BeautifulSoup from urllib.parse import urlparse from fake_useragent import UserAgent diff --git a/sources/interaction.py b/sources/interaction.py index 4fef15f..0a75a47 100644 --- a/sources/interaction.py +++ b/sources/interaction.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict from sources.text_to_speech import Speech from sources.utility import pretty_print, animate_thinking diff --git a/sources/language.py b/sources/language.py index a7ed91f..457debb 100644 --- a/sources/language.py +++ b/sources/language.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict import re import langid import nltk diff --git a/sources/llm_provider.py b/sources/llm_provider.py index ab52af3..85e6f60 100644 --- a/sources/llm_provider.py +++ b/sources/llm_provider.py @@ -294,22 +294,7 @@ class Provider: This function is used to conduct tests. """ thought = """ -hello! -```python -print("Hello world from python") -``` - -This is ls -la from bash. -```bash -ls -la -``` - -This is pwd from bash. -```bash -pwd -``` - -goodbye! +\n\n```json\n{\n \"plan\": [\n {\n \"agent\": \"Web\",\n \"id\": \"1\",\n \"need\": null,\n \"task\": \"Conduct a comprehensive web search to identify at least five AI startups located in Osaka. Use reliable sources and websites such as Crunchbase, TechCrunch, or local Japanese business directories. Capture the company names, their websites, areas of expertise, and any other relevant details.\"\n },\n {\n \"agent\": \"Web\",\n \"id\": \"2\",\n \"need\": null,\n \"task\": \"Perform a similar search to find at least five AI startups in Tokyo. Again, use trusted sources like Crunchbase, TechCrunch, or Japanese business news websites. Gather the same details as for Osaka: company names, websites, areas of focus, and additional information.\"\n },\n {\n \"agent\": \"File\",\n \"id\": \"3\",\n \"need\": [\"1\", \"2\"],\n \"task\": \"Create a new text file named research_japan.txt in the user's home directory. Organize the data collected from both searches into this file, ensuring it is well-structured and formatted for readability. Include headers for Osaka and Tokyo sections, followed by the details of each startup found.\"\n }\n ]\n}\n``` """ return thought diff --git a/sources/logger.py b/sources/logger.py index af0d0f0..e22ca16 100644 --- a/sources/logger.py +++ b/sources/logger.py @@ -1,5 +1,5 @@ import os, sys -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict import datetime import logging diff --git a/sources/memory.py b/sources/memory.py index e48ff57..532bcc2 100644 --- a/sources/memory.py +++ b/sources/memory.py @@ -4,7 +4,7 @@ import uuid import os import sys import json -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM diff --git a/sources/router.py b/sources/router.py index 56c0afe..ed7f849 100644 --- a/sources/router.py +++ b/sources/router.py @@ -1,7 +1,7 @@ import os import sys import torch -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict from transformers import pipeline from adaptive_classifier import AdaptiveClassifier diff --git a/sources/text_to_speech.py b/sources/text_to_speech.py index 32a96b9..e881f50 100644 --- a/sources/text_to_speech.py +++ b/sources/text_to_speech.py @@ -3,7 +3,7 @@ import re import platform import subprocess from sys import modules -from typing import List, Tuple, Type, Dict, Tuple +from typing import List, Tuple, Type, Dict from kokoro import KPipeline from IPython.display import display, Audio