diff --git a/sources/agents/agent.py b/sources/agents/agent.py
index 6f1e93d..d662e55 100644
--- a/sources/agents/agent.py
+++ b/sources/agents/agent.py
@@ -39,9 +39,7 @@ class Agent():
         self.type = None
         self.current_directory = os.getcwd()
         self.llm = provider
-        self.memory = Memory(self.load_prompt(prompt_path),
-                             recover_last_session=False, # session recovery in handled by the interaction class
-                             memory_compression=False)
+        self.memory = None
         self.tools = {}
         self.blocks_result = []
         self.success = True
diff --git a/sources/agents/browser_agent.py b/sources/agents/browser_agent.py
index 268262e..648d700 100644
--- a/sources/agents/browser_agent.py
+++ b/sources/agents/browser_agent.py
@@ -10,6 +10,7 @@ from sources.agents.agent import Agent
 from sources.tools.searxSearch import searxSearch
 from sources.browser import Browser
 from sources.logger import Logger
+from sources.memory import Memory
 
 class Action(Enum):
     REQUEST_EXIT = "REQUEST_EXIT"
@@ -37,6 +38,10 @@ class BrowserAgent(Agent):
         self.notes = []
         self.date = self.get_today_date()
         self.logger = Logger("browser_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     def get_today_date(self) -> str:
         """Get the date"""
diff --git a/sources/agents/casual_agent.py b/sources/agents/casual_agent.py
index f756f27..a3219a6 100644
--- a/sources/agents/casual_agent.py
+++ b/sources/agents/casual_agent.py
@@ -6,6 +6,7 @@ from sources.tools.searxSearch import searxSearch
 from sources.tools.flightSearch import FlightSearch
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory
 
 class CasualAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):
@@ -17,6 +18,10 @@ class CasualAgent(Agent):
         } # No tools for the casual agent
         self.role = "talk"
         self.type = "casual_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     async def process(self, prompt, speech_module) -> str:
         self.memory.push('user', prompt)
diff --git a/sources/agents/code_agent.py b/sources/agents/code_agent.py
index e291057..fd030de 100644
--- a/sources/agents/code_agent.py
+++ b/sources/agents/code_agent.py
@@ -10,6 +10,7 @@ from sources.tools.BashInterpreter import BashInterpreter
 from sources.tools.JavaInterpreter import JavaInterpreter
 from sources.tools.fileFinder import FileFinder
 from sources.logger import Logger
+from sources.memory import Memory
 
 class CoderAgent(Agent):
     """
@@ -29,6 +30,10 @@
         self.role = "code"
         self.type = "code_agent"
         self.logger = Logger("code_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     def add_sys_info_prompt(self, prompt):
         """Add system information to the prompt."""
diff --git a/sources/agents/file_agent.py b/sources/agents/file_agent.py
index f3169b1..8dbd8e8 100644
--- a/sources/agents/file_agent.py
+++ b/sources/agents/file_agent.py
@@ -4,6 +4,7 @@ from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory
 
 class FileAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):
@@ -18,6 +19,10 @@
         self.work_dir = self.tools["file_finder"].get_work_dir()
         self.role = "files"
         self.type = "file_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     async def process(self, prompt, speech_module) -> str:
         exec_success = False
diff --git a/sources/agents/mcp_agent.py b/sources/agents/mcp_agent.py
index b51b76d..528cd40 100644
--- a/sources/agents/mcp_agent.py
+++ b/sources/agents/mcp_agent.py
@@ -4,6 +4,7 @@ import asyncio
 from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.mcpFinder import MCP_finder
+from sources.memory import Memory
 
 # NOTE MCP agent is an active work in progress, not functional yet.
 
@@ -22,6 +23,10 @@
         }
         self.role = "mcp"
         self.type = "mcp_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.enabled = True
 
     def get_api_keys(self) -> dict:
diff --git a/sources/agents/planner_agent.py b/sources/agents/planner_agent.py
index 2b6a34c..7955e32 100644
--- a/sources/agents/planner_agent.py
+++ b/sources/agents/planner_agent.py
@@ -9,6 +9,7 @@ from sources.agents.casual_agent import CasualAgent
 from sources.text_to_speech import Speech
 from sources.tools.tools import Tools
 from sources.logger import Logger
+from sources.memory import Memory
 
 class PlannerAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False, browser=None):
@@ -29,6 +30,10 @@
         }
         self.role = "planification"
         self.type = "planner_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.logger = Logger("planner_agent.log")
 
     def get_task_names(self, text: str) -> List[str]:
diff --git a/sources/llm_provider.py b/sources/llm_provider.py
index 32b8279..cde9d6d 100644
--- a/sources/llm_provider.py
+++ b/sources/llm_provider.py
@@ -44,6 +44,9 @@
             self.api_key = self.get_api_key(self.provider_name)
         elif self.provider_name != "ollama":
             pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
+
+    def get_model_name(self) -> str:
+        return self.model
 
     def get_api_key(self, provider):
         load_dotenv()
diff --git a/sources/memory.py b/sources/memory.py
index 2814120..66ed83d 100644
--- a/sources/memory.py
+++ b/sources/memory.py
@@ -8,7 +8,7 @@ from typing import List, Tuple, Type, Dict
 
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from sources.utility import timer_decorator, pretty_print
+from sources.utility import timer_decorator, pretty_print, animate_thinking
 from sources.logger import Logger
 
 class Memory():
     """
@@ -18,7 +18,8 @@
     """
     def __init__(self, system_prompt: str,
                  recover_last_session: bool = False,
-                 memory_compression: bool = True):
+                 memory_compression: bool = True,
+                 model_provider: str = "deepseek-r1:14b"):
        self.memory = []
        self.memory = [{'role': 'system', 'content': system_prompt}]
 
@@ -31,21 +32,42 @@
             self.load_memory()
             self.session_recovered = True
         # memory compression system
-        self.model = "pszemraj/led-base-book-summary"
+        self.model = None
+        self.tokenizer = None
         self.device = self.get_cuda_device()
         self.memory_compression = memory_compression
-        self.tokenizer = None
-        self.model = None
+        self.model_provider = model_provider
         if self.memory_compression:
             self.download_model()
+
+    def get_ideal_ctx(self, model_name: str) -> int:
+        """
+        Estimate context size based on the model name.
+        """
+        import re
+        import math
+
+        def extract_number_before_b(sentence: str) -> int:
+            match = re.search(r'(\d+)b', sentence, re.IGNORECASE)
+            return int(match.group(1)) if match else None
+
+        model_size = extract_number_before_b(model_name)
+        if not model_size:
+            return None
+        base_size = 7 # Base model size in billions
+        base_context = 4096 # Base context size in tokens
+        scaling_factor = 1.5 # Approximate scaling factor for context size growth
+        context_size = int(base_context * (model_size / base_size) ** scaling_factor)
+        context_size = 2 ** round(math.log2(context_size))
+        self.logger.info(f"Estimated context size for {model_name}: {context_size} tokens.")
+        return context_size
 
     def download_model(self):
         """Download the model if not already downloaded."""
-        pretty_print("Downloading memory compression model...", color="status")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
+        animate_thinking("Loading memory compression model...", color="status")
+        self.tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-base-book-summary")
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("pszemraj/led-base-book-summary")
         self.logger.info("Memory compression system initialized.")
-
     def get_filename(self) -> str:
         """Get the filename for the save file."""
@@ -106,13 +128,15 @@
 
     def push(self, role: str, content: str) -> int:
         """Push a message to the memory."""
-        if self.memory_compression and role == 'assistant':
-            self.logger.info("Compressing memories on message push.")
+        ideal_ctx = self.get_ideal_ctx(self.model_provider) # None when the model name carries no parameter count
+        if self.memory_compression and ideal_ctx is not None and len(content) > ideal_ctx:
+            self.logger.info(f"Compressing memory: content length {len(content)} exceeds estimated context size {ideal_ctx}.")
             self.compress()
         curr_idx = len(self.memory)
         if self.memory[curr_idx-1]['content'] == content:
             pretty_print("Warning: same message have been pushed twice to memory", color="error")
-        self.memory.append({'role': role, 'content': content})
+        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.memory.append({'role': role, 'content': content, 'time': time_str, 'model_used': self.model_provider})
         return curr_idx-1
 
     def clear(self) -> None:
@@ -182,11 +206,9 @@
             self.logger.warning("No tokenizer or model to perform memory compression.")
             return
         for i in range(len(self.memory)):
-            if i < 2:
-                continue
             if self.memory[i]['role'] == 'system':
                 continue
-            if len(self.memory[i]['content']) > 128:
+            if len(self.memory[i]['content']) > 2048:
                 self.memory[i]['content'] = self.summarize(self.memory[i]['content'])
 
 if __name__ == "__main__":
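
The scaling heuristic in get_ideal_ctx deserves a worked example, because the int() truncation interacts with the final power-of-two rounding in a non-obvious way. Below is a standalone sketch of the same formula; estimate_ctx is a hypothetical name used only for illustration, the real method lives on Memory:

    import math
    import re
    from typing import Optional

    def estimate_ctx(model_name: str) -> Optional[int]:
        # Same formula as Memory.get_ideal_ctx: scale the 4096-token baseline
        # of a 7B model by (size / 7) ** 1.5, then snap to a power of two.
        match = re.search(r'(\d+)b', model_name, re.IGNORECASE)
        if not match:
            return None  # no "<N>b" parameter count in the name -> no estimate
        model_size = int(match.group(1))
        context_size = int(4096 * (model_size / 7) ** 1.5)
        return 2 ** round(math.log2(context_size))

    print(estimate_ctx("deepseek-r1:7b"))   # 4096  -> log2 is exactly 12
    print(estimate_ctx("deepseek-r1:14b"))  # 8192  -> int() yields 11585, log2 ~ 13.49997, rounds to 13
    print(estimate_ctx("deepseek-r1:32b"))  # 32768 -> log2 ~ 15.29
    print(estimate_ctx("gpt-4o"))           # None  -> hence the guard in push()

The None return is why push() must check ideal_ctx before comparing; note also that the comparison pits a character count (len(content)) against a token budget, so compression triggers early rather than precisely.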
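
Since the base Agent now starts with self.memory = None, each subclass builds its own Memory and threads the provider's model name through. A minimal sketch of the new wiring outside any agent (the system prompt string here is a placeholder; in the agents it comes from self.load_prompt(prompt_path), and model_provider from provider.get_model_name()):

    from sources.memory import Memory

    memory = Memory("You are a helpful assistant.",
                    recover_last_session=False,  # recovery is handled by the interaction class
                    memory_compression=True,
                    model_provider="deepseek-r1:14b")
    idx = memory.push('user', 'hello')
    # entries now carry 'time' and 'model_used' alongside 'role' and 'content'

With memory_compression=True the constructor loads pszemraj/led-base-book-summary via transformers, so this sketch assumes the weights are cached locally or the network is reachable.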