mirror of https://github.com/tcsenpai/agenticSeek.git
synced 2025-06-05 02:25:27 +00:00

commit a7deffedec (parent 5949540007)
feat: improve memory system
sources/agents/agent.py
@@ -39,9 +39,7 @@ class Agent():
         self.type = None
         self.current_directory = os.getcwd()
         self.llm = provider
-        self.memory = Memory(self.load_prompt(prompt_path),
-                             recover_last_session=False, # session recovery is handled by the interaction class
-                             memory_compression=False)
+        self.memory = None
         self.tools = {}
         self.blocks_result = []
         self.success = True
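
Note: with this change the Agent base class no longer owns a Memory instance; it leaves self.memory as None and each concrete agent builds its own, so the provider's model name can be threaded through to the memory system. A minimal sketch of the pattern the per-agent hunks below repeat (the super() call and elided class bodies are assumptions, not shown in this diff):

    from sources.memory import Memory

    class CasualAgent(Agent):
        def __init__(self, name, prompt_path, provider, verbose=False):
            super().__init__(name, prompt_path, provider, verbose)  # assumed signature
            # Each agent now owns its Memory and passes the model name along.
            self.memory = Memory(self.load_prompt(prompt_path),
                                 recover_last_session=False,
                                 memory_compression=False,
                                 model_provider=provider.get_model_name())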
sources/agents/browser_agent.py
@@ -10,6 +10,7 @@ from sources.agents.agent import Agent
 from sources.tools.searxSearch import searxSearch
 from sources.browser import Browser
 from sources.logger import Logger
+from sources.memory import Memory

 class Action(Enum):
     REQUEST_EXIT = "REQUEST_EXIT"
@@ -37,6 +38,10 @@ class BrowserAgent(Agent):
         self.notes = []
         self.date = self.get_today_date()
         self.logger = Logger("browser_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     def get_today_date(self) -> str:
         """Get the date"""
sources/agents/casual_agent.py
@@ -6,6 +6,7 @@ from sources.tools.searxSearch import searxSearch
 from sources.tools.flightSearch import FlightSearch
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory

 class CasualAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):
@@ -17,6 +18,10 @@ class CasualAgent(Agent):
         } # No tools for the casual agent
         self.role = "talk"
         self.type = "casual_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     async def process(self, prompt, speech_module) -> str:
         self.memory.push('user', prompt)
sources/agents/code_agent.py
@@ -10,6 +10,7 @@ from sources.tools.BashInterpreter import BashInterpreter
 from sources.tools.JavaInterpreter import JavaInterpreter
 from sources.tools.fileFinder import FileFinder
 from sources.logger import Logger
+from sources.memory import Memory

 class CoderAgent(Agent):
     """
@@ -29,6 +30,10 @@ class CoderAgent(Agent):
         self.role = "code"
         self.type = "code_agent"
         self.logger = Logger("code_agent.log")
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     def add_sys_info_prompt(self, prompt):
         """Add system information to the prompt."""
sources/agents/file_agent.py
@@ -4,6 +4,7 @@ from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.fileFinder import FileFinder
 from sources.tools.BashInterpreter import BashInterpreter
+from sources.memory import Memory

 class FileAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False):
@@ -18,6 +19,10 @@ class FileAgent(Agent):
         self.work_dir = self.tools["file_finder"].get_work_dir()
         self.role = "files"
         self.type = "file_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())

     async def process(self, prompt, speech_module) -> str:
         exec_success = False
sources/agents/mcp_agent.py
@@ -4,6 +4,7 @@ import asyncio
 from sources.utility import pretty_print, animate_thinking
 from sources.agents.agent import Agent
 from sources.tools.mcpFinder import MCP_finder
+from sources.memory import Memory

 # NOTE MCP agent is an active work in progress, not functional yet.

@@ -22,6 +23,10 @@ class McpAgent(Agent):
         }
         self.role = "mcp"
         self.type = "mcp_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.enabled = True

     def get_api_keys(self) -> dict:
sources/agents/planner_agent.py
@@ -9,6 +9,7 @@ from sources.agents.casual_agent import CasualAgent
 from sources.text_to_speech import Speech
 from sources.tools.tools import Tools
 from sources.logger import Logger
+from sources.memory import Memory

 class PlannerAgent(Agent):
     def __init__(self, name, prompt_path, provider, verbose=False, browser=None):
@@ -29,6 +30,10 @@ class PlannerAgent(Agent):
         }
         self.role = "planification"
         self.type = "planner_agent"
+        self.memory = Memory(self.load_prompt(prompt_path),
+                             recover_last_session=False, # session recovery is handled by the interaction class
+                             memory_compression=False,
+                             model_provider=provider.get_model_name())
         self.logger = Logger("planner_agent.log")

     def get_task_names(self, text: str) -> List[str]:
sources/llm_provider.py
@@ -44,6 +44,9 @@ class Provider:
             self.api_key = self.get_api_key(self.provider_name)
         elif self.provider_name != "ollama":
             pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")

+    def get_model_name(self) -> str:
+        return self.model
+
     def get_api_key(self, provider):
         load_dotenv()
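
Note: get_model_name() is the one-line accessor the agent hunks above depend on; it exposes Provider.model so Memory can derive a context-window estimate from the model name. A hedged usage sketch (the Provider constructor arguments are an assumption, not shown in this diff):

    provider = Provider(provider_name="ollama", model="deepseek-r1:14b")  # hypothetical args
    memory = Memory(system_prompt,
                    recover_last_session=False,
                    memory_compression=False,
                    model_provider=provider.get_model_name())  # -> "deepseek-r1:14b"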
sources/memory.py
@@ -8,7 +8,7 @@ from typing import List, Tuple, Type, Dict
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

-from sources.utility import timer_decorator, pretty_print
+from sources.utility import timer_decorator, pretty_print, animate_thinking
 from sources.logger import Logger

 class Memory():
@@ -18,7 +18,8 @@ class Memory():
     """
     def __init__(self, system_prompt: str,
                  recover_last_session: bool = False,
-                 memory_compression: bool = True):
+                 memory_compression: bool = True,
+                 model_provider: str = "deepseek-r1:14b"):
         self.memory = []
         self.memory = [{'role': 'system', 'content': system_prompt}]

@@ -31,21 +32,42 @@ class Memory():
             self.load_memory()
             self.session_recovered = True
         # memory compression system
-        self.model = "pszemraj/led-base-book-summary"
+        self.model = None
+        self.tokenizer = None
         self.device = self.get_cuda_device()
         self.memory_compression = memory_compression
-        self.tokenizer = None
-        self.model = None
+        self.model_provider = model_provider
         if self.memory_compression:
             self.download_model()

+    def get_ideal_ctx(self, model_name: str) -> int:
+        """
+        Estimate context size based on the model name.
+        """
+        import re
+        import math
+
+        def extract_number_before_b(sentence: str) -> int:
+            match = re.search(r'(\d+)b', sentence, re.IGNORECASE)
+            return int(match.group(1)) if match else None
+
+        model_size = extract_number_before_b(model_name)
+        if not model_size:
+            return None
+        base_size = 7 # Base model size in billions
+        base_context = 4096 # Base context size in tokens
+        scaling_factor = 1.5 # Approximate scaling factor for context size growth
+        context_size = int(base_context * (model_size / base_size) ** scaling_factor)
+        context_size = 2 ** round(math.log2(context_size))
+        self.logger.info(f"Estimated context size for {model_name}: {context_size} tokens.")
+        return context_size
+
     def download_model(self):
         """Download the model if not already downloaded."""
-        pretty_print("Downloading memory compression model...", color="status")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
+        animate_thinking("Loading memory compression model...", color="status")
+        self.tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-base-book-summary")
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("pszemraj/led-base-book-summary")
         self.logger.info("Memory compression system initialized.")

     def get_filename(self) -> str:
         """Get the filename for the save file."""
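
Note: get_ideal_ctx() parses the parameter count out of names like "deepseek-r1:14b", scales the 4096-token base by (size/7)^1.5, and snaps the result to a power of two; names with no "<N>b" suffix yield None. A standalone check of the same arithmetic (the helper name here is mine, not the repo's):

    import math
    import re

    def estimate_ctx(model_name: str):
        # Mirrors get_ideal_ctx above: parse "<N>b", scale, snap to 2**n.
        match = re.search(r'(\d+)b', model_name, re.IGNORECASE)
        if not match:
            return None                       # e.g. "gpt-4" carries no size
        size_b = int(match.group(1))
        ctx = int(4096 * (size_b / 7) ** 1.5)
        return 2 ** round(math.log2(ctx))

    print(estimate_ctx("deepseek-r1:7b"))   # 4096
    print(estimate_ctx("deepseek-r1:14b"))  # 8192: int() puts 11585 just under the 2**13.5 midpoint
    print(estimate_ctx("llama3:70b"))       # 131072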
@@ -106,13 +128,15 @@ class Memory():

     def push(self, role: str, content: str) -> int:
         """Push a message to the memory."""
-        if self.memory_compression and role == 'assistant':
-            self.logger.info("Compressing memories on message push.")
+        ideal_ctx = self.get_ideal_ctx(self.model_provider)
+        if self.memory_compression and len(content) > ideal_ctx:
+            self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} model context.")
             self.compress()
         curr_idx = len(self.memory)
         if self.memory[curr_idx-1]['content'] == content:
             pretty_print("Warning: same message has been pushed twice to memory", color="error")
-        self.memory.append({'role': role, 'content': content})
+        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.memory.append({'role': role, 'content': content, 'time': time_str, 'model_used': self.model_provider})
         return curr_idx-1

     def clear(self) -> None:
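
Note: push() now gates compression on message size instead of compressing after every assistant reply, and each entry records a timestamp plus the model that produced it. One caveat: get_ideal_ctx() returns None when the model name carries no parameter count, and `len(content) > None` raises TypeError, so a defensive variant of the new check would guard it (this guard is a suggestion, not part of the commit):

    ideal_ctx = self.get_ideal_ctx(self.model_provider)
    # Hypothetical guard: skip compression when no context size could be estimated.
    if self.memory_compression and ideal_ctx is not None and len(content) > ideal_ctx:
        self.compress()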
@@ -182,11 +206,9 @@ class Memory():
            self.logger.warning("No tokenizer or model to perform memory compression.")
            return
        for i in range(len(self.memory)):
-           if i < 2:
-               continue
            if self.memory[i]['role'] == 'system':
                continue
-           if len(self.memory[i]['content']) > 128:
+           if len(self.memory[i]['content']) > 2048:
                self.memory[i]['content'] = self.summarize(self.memory[i]['content'])

 if __name__ == "__main__":
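
Note: the compression pass now keeps system messages verbatim and summarizes any other entry only once it exceeds 2048 characters; the old version skipped the first two entries by index and, with its 128-character floor, summarized nearly everything else. Illustrative only (data is made up):

    memory = [
        {'role': 'system',    'content': 'x' * 5000},  # kept: system role
        {'role': 'user',      'content': 'x' * 500},   # kept: under 2048 chars
        {'role': 'assistant', 'content': 'x' * 3000},  # summarized: over 2048
    ]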