feat: improve compression/web agent memory management

martin legrand 2025-05-05 14:15:32 +02:00
parent a7deffedec
commit e4591ea1b4
4 changed files with 74 additions and 15 deletions

View File

@@ -100,7 +100,7 @@ class Agent():
         """
         description = ""
         for name in self.get_tools_name():
-            description += f"{tool}: {self.tools[name].description}\n"
+            description += f"{name}: {self.tools[name].description}\n"
         return description

     def load_prompt(self, file_path: str) -> str:

View File

@@ -243,6 +243,14 @@ class BrowserAgent(Agent):
             self.logger.warning("No link selected.")
         return None

+    def get_page_text(self, compression = False) -> str:
+        """Get the text content of the current page."""
+        page_text = self.browser.get_text()
+        if compression:
+            #page_text = self.memory.compress_text_to_max_ctx(page_text)
+            page_text = self.memory.trim_text_to_max_ctx(page_text)
+        return page_text
+
     def conclude_prompt(self, user_query: str) -> str:
         annotated_notes = [f"{i+1}: {note.lower()}" for i, note in enumerate(self.notes)]
         search_note = '\n'.join(annotated_notes)
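The new get_page_text wrapper gives BrowserAgent a single choke point for page text, so trimming to the model's context budget can be toggled in one place instead of at every browser.get_text() call site. A minimal standalone sketch of the flow, with MAX_CTX standing in for whatever Memory.get_ideal_ctx returns (an assumed figure, not a value from this commit):

MAX_CTX = 32768  # assumed character budget; the real code asks Memory.get_ideal_ctx()

def get_page_text(raw_text: str, compression: bool = False) -> str:
    """Return page text, optionally trimmed to the context budget."""
    if compression:
        # trim_text_to_max_ctx is a hard character cut; the commented-out
        # compress_text_to_max_ctx alternative would summarize instead.
        return raw_text[:MAX_CTX]
    return raw_text

print(len(get_page_text("x" * 50000, compression=True)))  # 32768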
@@ -357,13 +365,13 @@ class BrowserAgent(Agent):
             self.status_message = "Filling web form..."
             pretty_print("Filling form inputs...", color="status")
             fill_success = self.browser.fill_form(extracted_form)
-            page_text = self.browser.get_text()
+            page_text = self.get_page_text()
             prompt = self.handle_update_prompt(user_prompt, page_text, fill_success)
             answer, reasoning = await self.llm_decide(prompt)

             if Action.FORM_FILLED.value in answer:
                 pretty_print("Filled form. Handling page update.", color="status")
-                page_text = self.browser.get_text()
+                page_text = self.get_page_text()
                 self.navigable_links = self.browser.get_navigable()
                 prompt = self.make_navigation_prompt(user_prompt, page_text)
                 continue
@@ -399,7 +407,7 @@ class BrowserAgent(Agent):
                 prompt = self.make_newsearch_prompt(user_prompt, unvisited)
                 continue
             self.current_page = link
-            page_text = self.browser.get_text()
+            page_text = self.get_page_text()
             self.navigable_links = self.browser.get_navigable()
             prompt = self.make_navigation_prompt(user_prompt, page_text)
             self.status_message = "Navigating..."

View File

@@ -266,7 +266,7 @@ class Browser:
             result = re.sub(r'!\[(.*?)\]\(.*?\)', r'[IMAGE: \1]', result)
             self.logger.info(f"Extracted text: {result[:100]}...")
             self.logger.info(f"Extracted text length: {len(result)}")
-            return result[:8192]
+            return result[:32768]
         except Exception as e:
             self.logger.error(f"Error getting text: {str(e)}")
             return None
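Raising the truncation cap from 8192 to 32768 characters keeps four times as much page content for the agent. As a rough sanity check of what that costs in context, assuming the common ~4 characters-per-token heuristic for English text (an approximation, not something this commit measures):

CHARS_PER_TOKEN = 4  # rough heuristic for English prose; varies by tokenizer

for cap in (8192, 32768):
    print(f"{cap} chars ~= {cap // CHARS_PER_TOKEN} tokens")
# 8192 chars ~= 2048 tokens
# 32768 chars ~= 8192 tokens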

View File

@@ -20,7 +20,6 @@ class Memory():
                  recover_last_session: bool = False,
                  memory_compression: bool = True,
                  model_provider: str = "deepseek-r1:14b"):
-        self.memory = []
         self.memory = [{'role': 'system', 'content': system_prompt}]
         self.logger = Logger("memory.log")
@@ -40,9 +39,10 @@ class Memory():
         if self.memory_compression:
             self.download_model()

-    def get_ideal_ctx(self, model_name: str) -> int:
+    def get_ideal_ctx(self, model_name: str) -> int | None:
         """
         Estimate context size based on the model name.
+        EXPERIMENTAL: used for memory compression.
         """
         import re
         import math
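The body of get_ideal_ctx is elided in this hunk; only the new signature and the re/math imports are visible. A plausible sketch of a name-based estimate, offered purely as an illustration (the regex and the size tiers below are assumptions, not the commit's actual logic):

import re

def get_ideal_ctx(model_name: str) -> int | None:
    """Guess a context budget (in characters) from a name like 'deepseek-r1:14b'."""
    match = re.search(r"(\d+(?:\.\d+)?)\s*[bB]", model_name)
    if match is None:
        return None  # unrecognized naming scheme; callers must handle None
    params_b = float(match.group(1))
    # Illustrative tiers: assume larger models earn a larger budget.
    if params_b >= 70:
        return 32768
    if params_b >= 14:
        return 16384
    return 8192

print(get_ideal_ctx("deepseek-r1:14b"))  # 16384
print(get_ideal_ctx("mistral-large"))    # None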
@@ -100,6 +100,32 @@ class Memory():
             self.logger.info(f"Last session found at {saved_sessions[0][0]}")
             return saved_sessions[0][0]
         return None

+    def save_json_file(self, path: str, json_memory: dict) -> None:
+        """Save a JSON file."""
+        try:
+            with open(path, 'w') as f:
+                json.dump(json_memory, f)
+            self.logger.info(f"Saved memory json at {path}")
+        except Exception as e:
+            self.logger.warning(f"Error saving file {path}: {e}")
+
+    def load_json_file(self, path: str) -> dict | None:
+        """Load a JSON file, returning None on failure."""
+        json_memory = {}
+        try:
+            with open(path, 'r') as f:
+                json_memory = json.load(f)
+        except FileNotFoundError:
+            self.logger.warning(f"File not found: {path}")
+            return None
+        except json.JSONDecodeError:
+            self.logger.warning(f"Error decoding JSON from file: {path}")
+            return None
+        except Exception as e:
+            self.logger.warning(f"Error loading file {path}: {e}")
+            return None
+        return json_memory
+
     def load_memory(self, agent_type: str = "casual_agent") -> None:
         """Load the memory from the last session."""
@@ -115,8 +141,9 @@ class Memory():
             pretty_print("Last session memory not found.", color="warning")
             return
         path = os.path.join(save_path, filename)
-        with open(path, 'r') as f:
-            self.memory = json.load(f)
+        self.memory = self.load_json_file(path)
+        if not self.memory:
+            return
         if self.memory[-1]['role'] == 'user':
             self.memory.pop()
         self.compress()
@@ -129,9 +154,10 @@ class Memory():
     def push(self, role: str, content: str) -> int:
         """Push a message to the memory."""
         ideal_ctx = self.get_ideal_ctx(self.model_provider)
-        if self.memory_compression and len(content) > ideal_ctx:
-            self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} model context.")
-            self.compress()
+        if ideal_ctx is not None:
+            if self.memory_compression and len(content) > ideal_ctx * 1.5:
+                self.logger.info(f"Compressing memory: content length {len(content)} exceeds 1.5x the {ideal_ctx} model context.")
+                self.compress()
         curr_idx = len(self.memory)
         if self.memory[curr_idx-1]['content'] == content:
             pretty_print("Warning: the same message has been pushed twice to memory.", color="error")
@@ -194,13 +220,14 @@ class Memory():
         )
         summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         summary = summary.replace('summary:', '')
-        self.logger.info(f"Memory summarization success from len {len(text)} to {len(summary)}.")
+        self.logger.info(f"Memory summarized from len {len(text)} to {len(summary)}.")
+        self.logger.info(f"Summarized text:\n{summary}")
         return summary

     #@timer_decorator
     def compress(self) -> str:
         """
-        Compress the memory using the AI model.
+        Compress (summarize) the memory using the model.
         """
@@ -208,8 +235,32 @@ class Memory():
         for i in range(len(self.memory)):
             if self.memory[i]['role'] == 'system':
                 continue
-            if len(self.memory[i]['content']) > 2048:
+            if len(self.memory[i]['content']) > 1024:
                 self.memory[i]['content'] = self.summarize(self.memory[i]['content'])

+    def trim_text_to_max_ctx(self, text: str) -> str:
+        """
+        Truncate a text to fit within the maximum context size of the model.
+        """
+        ideal_ctx = self.get_ideal_ctx(self.model_provider)
+        return text[:ideal_ctx] if ideal_ctx is not None else text
+
+    #@timer_decorator
+    def compress_text_to_max_ctx(self, text: str) -> str:
+        """
+        Compress a text to fit within the maximum context size of the model.
+        """
+        if self.tokenizer is None or self.model is None:
+            self.logger.warning("No tokenizer or model to perform memory compression.")
+            return text
+        ideal_ctx = self.get_ideal_ctx(self.model_provider)
+        if ideal_ctx is None:
+            self.logger.warning("No ideal context size found.")
+            return text
+        while len(text) > ideal_ctx:
+            self.logger.info(f"Compressing text: {len(text)} > {ideal_ctx} model context.")
+            text = self.summarize(text)
+        return text

 if __name__ == "__main__":
     sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
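One property of compress_text_to_max_ctx worth noting: its while loop terminates only if summarize keeps shrinking the text. A defensive variant, offered as a suggestion rather than part of this commit, caps the number of passes and falls back to hard truncation (summarize here is any callable with the same contract):

from typing import Callable

def compress_to_budget(text: str, budget: int,
                       summarize: Callable[[str], str], max_passes: int = 5) -> str:
    """Summarize repeatedly until text fits budget, with a hard cap on passes."""
    for _ in range(max_passes):
        if len(text) <= budget:
            return text
        shorter = summarize(text)
        if len(shorter) >= len(text):  # summarizer failed to shrink; stop looping
            break
        text = shorter
    return text[:budget]  # last resort: hard truncation

# Toy summarizer that halves its input:
print(len(compress_to_budget("x" * 40000, 8192, lambda t: t[: len(t) // 2])))  # 5000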