feat: improve memory system

martin legrand 2025-05-04 21:54:01 +02:00
parent 5949540007
commit a7deffedec
9 changed files with 71 additions and 18 deletions

View File

@@ -39,9 +39,7 @@ class Agent():
         self.type = None
         self.current_directory = os.getcwd()
         self.llm = provider
-        self.memory = Memory(self.load_prompt(prompt_path),
-                             recover_last_session=False, # session recovery is handled by the interaction class
-                             memory_compression=False)
+        self.memory = None
         self.tools = {}
         self.blocks_result = []
         self.success = True
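Every agent file below repeats the same move: the base Agent gives up building Memory (self.memory is now None there) and each subclass constructs its own, passing the provider's model name so Memory can size its context per model. A minimal sketch of that pattern, with a hypothetical subclass name and a super().__init__ signature assumed from the constructors in this diff:

    from sources.agents.agent import Agent
    from sources.memory import Memory

    class SomeAgent(Agent):  # hypothetical agent, for illustration only
        def __init__(self, name, prompt_path, provider, verbose=False):
            super().__init__(name, prompt_path, provider, verbose)
            # Each agent now owns its Memory; session recovery stays with the interaction class.
            self.memory = Memory(self.load_prompt(prompt_path),
                                 recover_last_session=False,
                                 memory_compression=False,
                                 model_provider=provider.get_model_name())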

View File

@@ -10,6 +10,7 @@ from sources.agents.agent import Agent
 from sources.tools.searxSearch import searxSearch
 from sources.browser import Browser
 from sources.logger import Logger
+from sources.memory import Memory

 class Action(Enum):
     REQUEST_EXIT = "REQUEST_EXIT"

@@ -37,6 +38,10 @@ class BrowserAgent(Agent):
         self.notes = []
         self.date = self.get_today_date()
         self.logger = Logger("browser_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     def get_today_date(self) -> str:
         """Get the date"""

View File

@@ -6,6 +6,7 @@ from sources.tools.searxSearch import searxSearch
 from sources.tools.flightSearch import FlightSearch
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory

 class CasualAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):

@@ -17,6 +18,10 @@ class CasualAgent(Agent):
         } # No tools for the casual agent
         self.role = "talk"
         self.type = "casual_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     async def process(self, prompt, speech_module) -> str:
         self.memory.push('user', prompt)

View File

@@ -10,6 +10,7 @@ from sources.tools.BashInterpreter import BashInterpreter
 from sources.tools.JavaInterpreter import JavaInterpreter
 from sources.tools.fileFinder import FileFinder
 from sources.logger import Logger
+from sources.memory import Memory

 class CoderAgent(Agent):
     """

@@ -29,6 +30,10 @@ class CoderAgent(Agent):
         self.role = "code"
         self.type = "code_agent"
         self.logger = Logger("code_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     def add_sys_info_prompt(self, prompt):
         """Add system information to the prompt."""

View File

@@ -4,6 +4,7 @@ from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory

 class FileAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):

@@ -18,6 +19,10 @@ class FileAgent(Agent):
         self.work_dir = self.tools["file_finder"].get_work_dir()
         self.role = "files"
         self.type = "file_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     async def process(self, prompt, speech_module) -> str:
         exec_success = False

View File

@@ -4,6 +4,7 @@ import asyncio
 from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.mcpFinder import MCP_finder
+from sources.memory import Memory

 # NOTE MCP agent is an active work in progress, not functional yet.

@@ -22,6 +23,10 @@ class McpAgent(Agent):
         }
         self.role = "mcp"
         self.type = "mcp_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.enabled = True

     def get_api_keys(self) -> dict:

View File

@@ -9,6 +9,7 @@ from sources.agents.casual_agent import CasualAgent
 from sources.text_to_speech import Speech
 from sources.tools.tools import Tools
 from sources.logger import Logger
+from sources.memory import Memory

 class PlannerAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False, browser=None):

@@ -29,6 +30,10 @@ class PlannerAgent(Agent):
         }
         self.role = "planification"
         self.type = "planner_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.logger = Logger("planner_agent.log")

     def get_task_names(self, text: str) -> List[str]:

View File

@@ -44,6 +44,9 @@ class Provider:
             self.api_key = self.get_api_key(self.provider_name)
         elif self.provider_name != "ollama":
             pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
+
+    def get_model_name(self) -> str:
+        return self.model

     def get_api_key(self, provider):
         load_dotenv()
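get_model_name() is the accessor the agent constructors above rely on. A quick sketch of the flow from provider to memory; the Provider constructor arguments here are assumptions for illustration, not shown in this diff:

    provider = Provider("ollama", "deepseek-r1:14b")  # constructor args assumed
    memory = Memory("You are a helpful assistant.",
                    recover_last_session=False,
                    memory_compression=False,
                    model_provider=provider.get_model_name())  # -> "deepseek-r1:14b"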

View File

@@ -8,7 +8,7 @@ from typing import List, Tuple, Type, Dict
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from sources.utility import timer_decorator, pretty_print
+from sources.utility import timer_decorator, pretty_print, animate_thinking
 from sources.logger import Logger


 class Memory():
@@ -18,7 +18,8 @@ class Memory():
     """
     def __init__(self, system_prompt: str,
                  recover_last_session: bool = False,
-                 memory_compression: bool = True):
+                 memory_compression: bool = True,
+                 model_provider: str = "deepseek-r1:14b"):
         self.memory = []
         self.memory = [{'role': 'system', 'content': system_prompt}]

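Since model_provider defaults to "deepseek-r1:14b", call sites that predate this commit keep working unchanged. For instance (prompt and model names illustrative; with compression left at its default of True, construction also downloads the summarizer):

    memory = Memory("You are a helpful assistant.")  # model_provider defaults to "deepseek-r1:14b"
    memory = Memory("You are a helpful assistant.",
                    memory_compression=False,
                    model_provider="qwen2.5:32b")    # hypothetical model name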
@@ -31,21 +32,42 @@ class Memory():
             self.load_memory()
             self.session_recovered = True

         # memory compression system
-        self.model = "pszemraj/led-base-book-summary"
+        self.model = None
+        self.tokenizer = None
         self.device = self.get_cuda_device()
         self.memory_compression = memory_compression
-        self.tokenizer = None
-        self.model = None
+        self.model_provider = model_provider
         if self.memory_compression:
             self.download_model()

+    def get_ideal_ctx(self, model_name: str) -> int:
+        """
+        Estimate context size based on the model name.
+        """
+        import re
+        import math
+        def extract_number_before_b(sentence: str) -> int:
+            match = re.search(r'(\d+)b', sentence, re.IGNORECASE)
+            return int(match.group(1)) if match else None
+
+        model_size = extract_number_before_b(model_name)
+        if not model_size:
+            return None
+        base_size = 7  # Base model size in billions
+        base_context = 4096  # Base context size in tokens
+        scaling_factor = 1.5  # Approximate scaling factor for context size growth
+        context_size = int(base_context * (model_size / base_size) ** scaling_factor)
+        context_size = 2 ** round(math.log2(context_size))
+        self.logger.info(f"Estimated context size for {model_name}: {context_size} tokens.")
+        return context_size
+
     def download_model(self):
         """Download the model if not already downloaded."""
-        pretty_print("Downloading memory compression model...", color="status")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
+        animate_thinking("Loading memory compression model...", color="status")
+        self.tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-base-book-summary")
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("pszemraj/led-base-book-summary")
         self.logger.info("Memory compression system initialized.")

     def get_filename(self) -> str:
         """Get the filename for the save file."""
@@ -106,13 +128,15 @@ class Memory():

     def push(self, role: str, content: str) -> int:
         """Push a message to the memory."""
-        if self.memory_compression and role == 'assistant':
-            self.logger.info("Compressing memories on message push.")
+        ideal_ctx = self.get_ideal_ctx(self.model_provider)
+        if self.memory_compression and len(content) > ideal_ctx:
+            self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} model context.")
             self.compress()
         curr_idx = len(self.memory)
         if self.memory[curr_idx-1]['content'] == content:
             pretty_print("Warning: same message have been pushed twice to memory", color="error")
-        self.memory.append({'role': role, 'content': content})
+        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.memory.append({'role': role, 'content': content, 'time': time_str, 'model_used': self.model_provider})
         return curr_idx-1

     def clear(self) -> None:
@@ -182,11 +206,9 @@ class Memory():
             self.logger.warning("No tokenizer or model to perform memory compression.")
             return
         for i in range(len(self.memory)):
-            if i < 2:
-                continue
             if self.memory[i]['role'] == 'system':
                 continue
-            if len(self.memory[i]['content']) > 128:
+            if len(self.memory[i]['content']) > 2048:
                 self.memory[i]['content'] = self.summarize(self.memory[i]['content'])

 if __name__ == "__main__":
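Two consequences of the push() changes are worth noting. Each stored message now carries provenance, so a saved session records when a message arrived and which model produced it; an entry looks roughly like this (values illustrative):

    {'role': 'assistant',
     'content': '...',
     'time': '2025-05-04 21:54:01',
     'model_used': 'deepseek-r1:14b'}

Also, the trigger compares len(content), a character count, against a budget estimated in tokens, and get_ideal_ctx returns None for model names without a size marker, which would make that comparison raise a TypeError; both appear to be approximations this commit accepts for now.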