Mirror of https://github.com/tcsenpai/agenticSeek.git, synced 2025-06-03 01:30:11 +00:00.
feat: improve memory system
commit a7deffedec (parent 5949540007)
@@ -39,9 +39,7 @@ class Agent():
         self.type = None
         self.current_directory = os.getcwd()
         self.llm = provider
-        self.memory = Memory(self.load_prompt(prompt_path),
-                             recover_last_session=False, # session recovery is handled by the interaction class
-                             memory_compression=False)
+        self.memory = None
         self.tools = {}
         self.blocks_result = []
         self.success = True
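The base Agent no longer owns a Memory instance; each concrete agent below now builds its own, passing the provider's model name so the memory system can estimate a context window for that model. For orientation, the shared pattern from the hunks that follow, condensed (the surrounding __init__ bodies are elided):

    # Inside each agent's __init__; `provider` is the LLM Provider instance.
    self.memory = Memory(self.load_prompt(prompt_path),
                         recover_last_session=False,  # session recovery is handled by the interaction class
                         memory_compression=False,
                         model_provider=provider.get_model_name())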
@@ -10,6 +10,7 @@ from sources.agents.agent import Agent
 from sources.tools.searxSearch import searxSearch
 from sources.browser import Browser
 from sources.logger import Logger
+from sources.memory import Memory
 
 class Action(Enum):
     REQUEST_EXIT = "REQUEST_EXIT"
@@ -37,6 +38,10 @@ class BrowserAgent(Agent):
         self.notes = []
         self.date = self.get_today_date()
         self.logger = Logger("browser_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     def get_today_date(self) -> str:
         """Get the date"""
@@ -6,6 +6,7 @@ from sources.tools.searxSearch import searxSearch
 from sources.tools.flightSearch import FlightSearch
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory
 
 class CasualAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):
@@ -17,6 +18,10 @@ class CasualAgent(Agent):
         } # No tools for the casual agent
         self.role = "talk"
         self.type = "casual_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     async def process(self, prompt, speech_module) -> str:
         self.memory.push('user', prompt)
@@ -10,6 +10,7 @@ from sources.tools.BashInterpreter import BashInterpreter
 from sources.tools.JavaInterpreter import JavaInterpreter
 from sources.tools.fileFinder import FileFinder
 from sources.logger import Logger
+from sources.memory import Memory
 
 class CoderAgent(Agent):
     """
@@ -29,6 +30,10 @@ class CoderAgent(Agent):
         self.role = "code"
         self.type = "code_agent"
         self.logger = Logger("code_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     def add_sys_info_prompt(self, prompt):
         """Add system information to the prompt."""
@@ -4,6 +4,7 @@ from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory
 
 class FileAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):
@@ -18,6 +19,10 @@ class FileAgent(Agent):
         self.work_dir = self.tools["file_finder"].get_work_dir()
         self.role = "files"
         self.type = "file_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
 
     async def process(self, prompt, speech_module) -> str:
         exec_success = False
@@ -4,6 +4,7 @@ import asyncio
 from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.mcpFinder import MCP_finder
+from sources.memory import Memory
 
 # NOTE MCP agent is an active work in progress, not functional yet.
 
@@ -22,6 +23,10 @@ class McpAgent(Agent):
         }
         self.role = "mcp"
         self.type = "mcp_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.enabled = True
 
     def get_api_keys(self) -> dict:
@@ -9,6 +9,7 @@ from sources.agents.casual_agent import CasualAgent
 from sources.text_to_speech import Speech
 from sources.tools.tools import Tools
 from sources.logger import Logger
+from sources.memory import Memory
 
 class PlannerAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False, browser=None):
@@ -29,6 +30,10 @@ class PlannerAgent(Agent):
         }
         self.role = "planification"
         self.type = "planner_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.logger = Logger("planner_agent.log")
 
     def get_task_names(self, text: str) -> List[str]:
@@ -44,6 +44,9 @@ class Provider:
             self.api_key = self.get_api_key(self.provider_name)
         elif self.provider_name != "ollama":
             pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
 
+    def get_model_name(self) -> str:
+        return self.model
+
     def get_api_key(self, provider):
         load_dotenv()
@@ -8,7 +8,7 @@ from typing import List, Tuple, Type, Dict
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-from sources.utility import timer_decorator, pretty_print
+from sources.utility import timer_decorator, pretty_print, animate_thinking
 from sources.logger import Logger
 
 class Memory():
@@ -18,7 +18,8 @@ class Memory():
     """
     def __init__(self, system_prompt: str,
                  recover_last_session: bool = False,
-                 memory_compression: bool = True):
+                 memory_compression: bool = True,
+                 model_provider: str = "deepseek-r1:14b"):
         self.memory = []
         self.memory = [{'role': 'system', 'content': system_prompt}]
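Note the new default: callers that omit model_provider get "deepseek-r1:14b" even if a different model is in use. A hypothetical direct construction, for illustration only (the agents above instead pass provider.get_model_name()):

    memory = Memory("You are a helpful assistant.",
                    recover_last_session=False,
                    memory_compression=True,
                    model_provider="deepseek-r1:14b")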
@@ -31,21 +32,42 @@ class Memory():
             self.load_memory()
             self.session_recovered = True
         # memory compression system
-        self.model = "pszemraj/led-base-book-summary"
+        self.model = None
+        self.tokenizer = None
         self.device = self.get_cuda_device()
         self.memory_compression = memory_compression
-        self.tokenizer = None
-        self.model = None
+        self.model_provider = model_provider
         if self.memory_compression:
             self.download_model()
 
+    def get_ideal_ctx(self, model_name: str) -> int:
+        """
+        Estimate context size based on the model name.
+        """
+        import re
+        import math
+
+        def extract_number_before_b(sentence: str) -> int:
+            match = re.search(r'(\d+)b', sentence, re.IGNORECASE)
+            return int(match.group(1)) if match else None
+
+        model_size = extract_number_before_b(model_name)
+        if not model_size:
+            return None
+        base_size = 7 # Base model size in billions
+        base_context = 4096 # Base context size in tokens
+        scaling_factor = 1.5 # Approximate scaling factor for context size growth
+        context_size = int(base_context * (model_size / base_size) ** scaling_factor)
+        context_size = 2 ** round(math.log2(context_size))
+        self.logger.info(f"Estimated context size for {model_name}: {context_size} tokens.")
+        return context_size
+
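To make the heuristic concrete, here is the same arithmetic as a standalone sketch (the example values are mine, not part of the commit). For the default "deepseek-r1:14b", the regex extracts 14, 4096 * (14/7) ** 1.5 truncates to 11585, and rounding its log2 lands just below 13.5, giving 2 ** 13:

    import math
    import re

    def ideal_ctx(model_name: str) -> int:
        # Mirrors Memory.get_ideal_ctx: parse the parameter count before 'b'.
        size = int(re.search(r'(\d+)b', model_name, re.IGNORECASE).group(1))
        raw = int(4096 * (size / 7) ** 1.5)   # "14b" -> 11585
        return 2 ** round(math.log2(raw))     # log2(11585) ~= 13.49997 -> 2**13

    print(ideal_ctx("deepseek-r1:14b"))  # 8192
    print(ideal_ctx("llama3:70b"))       # 131072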
     def download_model(self):
         """Download the model if not already downloaded."""
-        pretty_print("Downloading memory compression model...", color="status")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
+        animate_thinking("Loading memory compression model...", color="status")
+        self.tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-base-book-summary")
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("pszemraj/led-base-book-summary")
+        self.logger.info("Memory compression system initialized.")
 
 
     def get_filename(self) -> str:
         """Get the filename for the save file."""
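Since self.model now starts as None, the old from_pretrained(self.model) calls would no longer receive a checkpoint id, which is presumably why the "pszemraj/led-base-book-summary" identifier is inlined here instead.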
@@ -106,13 +128,15 @@ class Memory():
 
     def push(self, role: str, content: str) -> int:
         """Push a message to the memory."""
-        if self.memory_compression and role == 'assistant':
-            self.logger.info("Compressing memories on message push.")
+        ideal_ctx = self.get_ideal_ctx(self.model_provider)
+        if self.memory_compression and len(content) > ideal_ctx:
+            self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} model context.")
             self.compress()
         curr_idx = len(self.memory)
         if self.memory[curr_idx-1]['content'] == content:
             pretty_print("Warning: same message has been pushed twice to memory", color="error")
-        self.memory.append({'role': role, 'content': content})
+        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.memory.append({'role': role, 'content': content, 'time': time_str, 'model_used': self.model_provider})
         return curr_idx-1
 
     def clear(self) -> None:
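Two caveats with the new trigger: get_ideal_ctx returns None for model names without a parseable parameter count (for example an API model name with no "14b"-style suffix), which would make the > comparison raise a TypeError, and len(content) counts characters while the estimate is in tokens. A defensive variant of the check, as a sketch rather than anything this commit ships:

    ideal_ctx = self.get_ideal_ctx(self.model_provider)
    # Skip compression when no estimate is available; the character count
    # is only a rough (over-)estimate of the token count.
    if self.memory_compression and ideal_ctx is not None and len(content) > ideal_ctx:
        self.compress()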
@@ -182,11 +206,9 @@ class Memory():
             self.logger.warning("No tokenizer or model to perform memory compression.")
             return
         for i in range(len(self.memory)):
-            if i < 2:
-                continue
             if self.memory[i]['role'] == 'system':
                 continue
-            if len(self.memory[i]['content']) > 128:
+            if len(self.memory[i]['content']) > 2048:
                 self.memory[i]['content'] = self.summarize(self.memory[i]['content'])
 
 if __name__ == "__main__":
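Net effect of this last hunk: the first two messages are no longer exempt from compression (the system prompt is still skipped by the role check), and only messages longer than 2048 characters, rather than 128, get summarized, so short exchanges pass through untouched.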