feat : improve memory system

This commit is contained in:
martin legrand 2025-05-04 21:54:01 +02:00
parent 5949540007
commit a7deffedec
9 changed files with 71 additions and 18 deletions

View File

@ -39,9 +39,7 @@ class Agent():
self.type = None
self.current_directory = os.getcwd()
self.llm = provider
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False)
self.memory = None
self.tools = {}
self.blocks_result = []
self.success = True

View File

@ -10,6 +10,7 @@ from sources.agents.agent import Agent
from sources.tools.searxSearch import searxSearch
from sources.browser import Browser
from sources.logger import Logger
from sources.memory import Memory
class Action(Enum):
REQUEST_EXIT = "REQUEST_EXIT"
@ -37,6 +38,10 @@ class BrowserAgent(Agent):
self.notes = []
self.date = self.get_today_date()
self.logger = Logger("browser_agent.log")
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False,
model_provider=provider.get_model_name())
def get_today_date(self) -> str:
"""Get the date"""

View File

@ -6,6 +6,7 @@ from sources.tools.searxSearch import searxSearch
from sources.tools.flightSearch import FlightSearch
from sources.tools.fileFinder import FileFinder
from sources.tools.BashInterpreter import BashInterpreter
from sources.memory import Memory
class CasualAgent(Agent):
def __init__(self, name, prompt_path, provider, verbose=False):
@ -17,6 +18,10 @@ class CasualAgent(Agent):
} # No tools for the casual agent
self.role = "talk"
self.type = "casual_agent"
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False,
model_provider=provider.get_model_name())
async def process(self, prompt, speech_module) -> str:
self.memory.push('user', prompt)

View File

@ -10,6 +10,7 @@ from sources.tools.BashInterpreter import BashInterpreter
from sources.tools.JavaInterpreter import JavaInterpreter
from sources.tools.fileFinder import FileFinder
from sources.logger import Logger
from sources.memory import Memory
class CoderAgent(Agent):
"""
@ -29,6 +30,10 @@ class CoderAgent(Agent):
self.role = "code"
self.type = "code_agent"
self.logger = Logger("code_agent.log")
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False,
model_provider=provider.get_model_name())
def add_sys_info_prompt(self, prompt):
"""Add system information to the prompt."""

View File

@ -4,6 +4,7 @@ from sources.utility import pretty_print, animate_thinking
from sources.agents.agent import Agent
from sources.tools.fileFinder import FileFinder
from sources.tools.BashInterpreter import BashInterpreter
from sources.memory import Memory
class FileAgent(Agent):
def __init__(self, name, prompt_path, provider, verbose=False):
@ -18,6 +19,10 @@ class FileAgent(Agent):
self.work_dir = self.tools["file_finder"].get_work_dir()
self.role = "files"
self.type = "file_agent"
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False,
model_provider=provider.get_model_name())
async def process(self, prompt, speech_module) -> str:
exec_success = False

View File

@ -4,6 +4,7 @@ import asyncio
from sources.utility import pretty_print, animate_thinking
from sources.agents.agent import Agent
from sources.tools.mcpFinder import MCP_finder
from sources.memory import Memory
# NOTE MCP agent is an active work in progress, not functional yet.
@ -22,6 +23,10 @@ class McpAgent(Agent):
}
self.role = "mcp"
self.type = "mcp_agent"
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False,
model_provider=provider.get_model_name())
self.enabled = True
def get_api_keys(self) -> dict:

View File

@ -9,6 +9,7 @@ from sources.agents.casual_agent import CasualAgent
from sources.text_to_speech import Speech
from sources.tools.tools import Tools
from sources.logger import Logger
from sources.memory import Memory
class PlannerAgent(Agent):
def __init__(self, name, prompt_path, provider, verbose=False, browser=None):
@ -29,6 +30,10 @@ class PlannerAgent(Agent):
}
self.role = "planification"
self.type = "planner_agent"
self.memory = Memory(self.load_prompt(prompt_path),
recover_last_session=False, # session recovery is handled by the interaction class
memory_compression=False,
model_provider=provider.get_model_name())
self.logger = Logger("planner_agent.log")
def get_task_names(self, text: str) -> List[str]:

View File

@ -44,6 +44,9 @@ class Provider:
self.api_key = self.get_api_key(self.provider_name)
elif self.provider_name != "ollama":
pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
def get_model_name(self) -> str:
    """Return the model identifier stored on this provider (``self.model``)."""
    model_name = self.model
    return model_name
def get_api_key(self, provider):
load_dotenv()

View File

@ -8,7 +8,7 @@ from typing import List, Tuple, Type, Dict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sources.utility import timer_decorator, pretty_print
from sources.utility import timer_decorator, pretty_print, animate_thinking
from sources.logger import Logger
class Memory():
@ -18,7 +18,8 @@ class Memory():
"""
def __init__(self, system_prompt: str,
recover_last_session: bool = False,
memory_compression: bool = True):
memory_compression: bool = True,
model_provider: str = "deepseek-r1:14b"):
self.memory = []
self.memory = [{'role': 'system', 'content': system_prompt}]
@ -31,21 +32,42 @@ class Memory():
self.load_memory()
self.session_recovered = True
# memory compression system
self.model = "pszemraj/led-base-book-summary"
self.model = None
self.tokenizer = None
self.device = self.get_cuda_device()
self.memory_compression = memory_compression
self.tokenizer = None
self.model = None
self.model_provider = model_provider
if self.memory_compression:
self.download_model()
def get_ideal_ctx(self, model_name: str) -> int:
    """
    Estimate context size based on the model name.

    The parameter count is read from the first "<number>b" token found in the
    name (for example "14b" or "1.5B"). It is scaled against a 7B model with a
    4096-token context using a 1.5 power law, then snapped to the nearest
    power of two.

    Args:
        model_name: Model identifier, e.g. "deepseek-r1:14b".

    Returns:
        The estimated context size in tokens, or None when no parameter count
        can be read from the name. Callers must handle the None case.
    """
    import re
    import math

    # Accept fractional sizes ("1.5b") and reject a "b" that merely starts a
    # word ("8bit"), which the plain `(\d+)b` pattern would misread.
    size_match = re.search(r'(\d+(?:\.\d+)?)b(?![a-z])', model_name or "", re.IGNORECASE)
    model_size = float(size_match.group(1)) if size_match else None
    if not model_size:
        return None

    base_size = 7  # Base model size in billions
    base_context = 4096  # Base context size in tokens
    scaling_factor = 1.5  # Approximate scaling factor for context size growth

    raw_context = int(base_context * (model_size / base_size) ** scaling_factor)
    # Clamp to 1 so that a very small size cannot reach log2(0).
    raw_context = max(1, raw_context)
    context_size = 2 ** round(math.log2(raw_context))
    self.logger.info(f"Estimated context size for {model_name}: {context_size} tokens.")
    return context_size
def download_model(self):
    """
    Load the tokenizer and seq2seq model used for memory compression.

    Both are fetched with ``from_pretrained`` from the same checkpoint id and
    stored on ``self.tokenizer`` and ``self.model``.
    """
    # Name the checkpoint once so tokenizer and model cannot drift apart.
    # The superseded load via `self.model` is dropped: that attribute is
    # None at this point and is only populated below.
    checkpoint = "pszemraj/led-base-book-summary"
    animate_thinking("Loading memory compression model...", color="status")
    self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    self.logger.info("Memory compression system initialized.")
def get_filename(self) -> str:
"""Get the filename for the save file."""
@ -106,13 +128,15 @@ class Memory():
def push(self, role: str, content: str) -> int:
    """
    Push a message to the memory.

    When memory compression is enabled and the message is longer than the
    context size estimated for ``self.model_provider``, the stored memory is
    compressed before the new entry is appended. Each entry records the role,
    the content, a timestamp and the model name.

    Args:
        role: Message role, e.g. 'user' or 'assistant'.
        content: Message text.

    Returns:
        ``len(self.memory) - 1`` as measured before the append, i.e. the index
        just before the newly added entry.
    """
    # Only estimate the context size when compression can actually run; this
    # also avoids logging an estimate on every push when it is disabled.
    if self.memory_compression:
        ideal_ctx = self.get_ideal_ctx(self.model_provider)
        # get_ideal_ctx returns None for names without a size token; comparing
        # an int with None would raise TypeError, so skip compression instead.
        # NOTE(review): len(content) counts characters while ideal_ctx is an
        # estimate in tokens — confirm this comparison is intended.
        if ideal_ctx is not None and len(content) > ideal_ctx:
            self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} model context.")
            self.compress()
    curr_idx = len(self.memory)
    # Guard against an empty memory so the duplicate check cannot raise IndexError.
    if self.memory and self.memory[-1]['content'] == content:
        pretty_print("Warning: same message have been pushed twice to memory", color="error")
    time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    self.memory.append({'role': role, 'content': content, 'time': time_str, 'model_used': self.model_provider})
    return curr_idx-1
def clear(self) -> None:
@ -182,11 +206,9 @@ class Memory():
self.logger.warning("No tokenizer or model to perform memory compression.")
return
for i in range(len(self.memory)):
if i < 2:
continue
if self.memory[i]['role'] == 'system':
continue
if len(self.memory[i]['content']) > 128:
if len(self.memory[i]['content']) > 2048:
self.memory[i]['content'] = self.summarize(self.memory[i]['content'])
if __name__ == "__main__":