diff --git a/sources/agents/agent.py b/sources/agents/agent.py
index d662e55..a154867 100644
--- a/sources/agents/agent.py
+++ b/sources/agents/agent.py
@@ -100,7 +100,7 @@ class Agent():
         """
         description = ""
         for name in self.get_tools_name():
-            description += f"{tool}: {self.tools[name].description}\n"
+            description += f"{name}: {self.tools[name].description}\n"
         return description
 
     def load_prompt(self, file_path: str) -> str:
diff --git a/sources/agents/browser_agent.py b/sources/agents/browser_agent.py
index 648d700..5a86f4e 100644
--- a/sources/agents/browser_agent.py
+++ b/sources/agents/browser_agent.py
@@ -243,6 +243,14 @@ class BrowserAgent(Agent):
             self.logger.warning("No link selected.")
         return None
 
+    def get_page_text(self, compression: bool = False) -> str:
+        """Get the text content of the current page."""
+        page_text = self.browser.get_text()
+        if compression:
+            #page_text = self.memory.compress_text_to_max_ctx(page_text)
+            page_text = self.memory.trim_text_to_max_ctx(page_text)
+        return page_text
+
     def conclude_prompt(self, user_query: str) -> str:
         annotated_notes = [f"{i+1}: {note.lower()}" for i, note in enumerate(self.notes)]
         search_note = '\n'.join(annotated_notes)
@@ -357,13 +365,13 @@ class BrowserAgent(Agent):
                     self.status_message = "Filling web form..."
                     pretty_print(f"Filling inputs form...", color="status")
                     fill_success = self.browser.fill_form(extracted_form)
-                    page_text = self.browser.get_text()
+                    page_text = self.get_page_text()
                     answer = self.handle_update_prompt(user_prompt, page_text, fill_success)
                     answer, reasoning = await self.llm_decide(prompt)
 
                 if Action.FORM_FILLED.value in answer:
                     pretty_print(f"Filled form. Handling page update.", color="status")
-                    page_text = self.browser.get_text()
+                    page_text = self.get_page_text()
                     self.navigable_links = self.browser.get_navigable()
                     prompt = self.make_navigation_prompt(user_prompt, page_text)
                     continue
@@ -399,7 +407,7 @@ class BrowserAgent(Agent):
                 prompt = self.make_newsearch_prompt(user_prompt, unvisited)
                 continue
             self.current_page = link
-            page_text = self.browser.get_text()
+            page_text = self.get_page_text()
             self.navigable_links = self.browser.get_navigable()
             prompt = self.make_navigation_prompt(user_prompt, page_text)
             self.status_message = "Navigating..."
diff --git a/sources/browser.py b/sources/browser.py
index be511ad..c5f8309 100644
--- a/sources/browser.py
+++ b/sources/browser.py
@@ -266,7 +266,7 @@ class Browser:
             result = re.sub(r'!\[(.*?)\]\(.*?\)', r'[IMAGE: \1]', result)
             self.logger.info(f"Extracted text: {result[:100]}...")
             self.logger.info(f"Extracted text length: {len(result)}")
-            return result[:8192]
+            return result[:32768]
         except Exception as e:
             self.logger.error(f"Error getting text: {str(e)}")
             return None
diff --git a/sources/memory.py b/sources/memory.py
index 66ed83d..86535aa 100644
--- a/sources/memory.py
+++ b/sources/memory.py
@@ -20,7 +20,6 @@ class Memory():
                  recover_last_session: bool = False,
                  memory_compression: bool = True,
                  model_provider: str = "deepseek-r1:14b"):
-        self.memory = []
         self.memory = [{'role': 'system', 'content': system_prompt}]
         self.logger = Logger("memory.log")
 
@@ -40,9 +39,10 @@ class Memory():
         if self.memory_compression:
             self.download_model()
 
-    def get_ideal_ctx(self, model_name: str) -> int:
+    def get_ideal_ctx(self, model_name: str) -> int | None:
         """
         Estimate context size based on the model name.
+        EXPERIMENTAL for memory compression
         """
         import re
         import math
@@ -100,6 +100,32 @@ class Memory():
             self.logger.info(f"Last session found at {saved_sessions[0][0]}")
             return saved_sessions[0][0]
         return None
+
+    def save_json_file(self, path: str, json_memory: dict) -> None:
+        """Save a JSON file."""
+        try:
+            with open(path, 'w') as f:
+                json.dump(json_memory, f)
+            self.logger.info(f"Saved memory json at {path}")
+        except Exception as e:
+            self.logger.warning(f"Error saving file {path}: {e}")
+
+    def load_json_file(self, path: str) -> dict | None:
+        """Load a JSON file."""
+        json_memory = {}
+        try:
+            with open(path, 'r') as f:
+                json_memory = json.load(f)
+        except FileNotFoundError:
+            self.logger.warning(f"File not found: {path}")
+            return None
+        except json.JSONDecodeError:
+            self.logger.warning(f"Error decoding JSON from file: {path}")
+            return None
+        except Exception as e:
+            self.logger.warning(f"Error loading file {path}: {e}")
+            return None
+        return json_memory
 
     def load_memory(self, agent_type: str = "casual_agent") -> None:
         """Load the memory from the last session."""
@@ -115,8 +141,10 @@ class Memory():
             pretty_print("Last session memory not found.", color="warning")
             return
         path = os.path.join(save_path, filename)
-        with open(path, 'r') as f:
-            self.memory = json.load(f)
+        json_memory = self.load_json_file(path)
+        if json_memory is None:
+            return
+        self.memory = json_memory
         if self.memory[-1]['role'] == 'user':
             self.memory.pop()
         self.compress()
@@ -129,9 +157,10 @@ class Memory():
     def push(self, role: str, content: str) -> int:
         """Push a message to the memory."""
        ideal_ctx = self.get_ideal_ctx(self.model_provider)
-        if self.memory_compression and len(content) > ideal_ctx:
-            self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} model context.")
-            self.compress()
+        if ideal_ctx is not None:
+            if self.memory_compression and len(content) > ideal_ctx * 1.5:
+                self.logger.info(f"Compressing memory: Content {len(content)} > {ideal_ctx} * 1.5 model context.")
+                self.compress()
         curr_idx = len(self.memory)
         if self.memory[curr_idx-1]['content'] == content:
             pretty_print("Warning: same message have been pushed twice to memory", color="error")
@@ -194,13 +223,14 @@ class Memory():
         )
         summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         summary.replace('summary:', '')
-        self.logger.info(f"Memory summarization success from len {len(text)} to {len(summary)}.")
+        self.logger.info(f"Memory summarized from len {len(text)} to {len(summary)}.")
+        self.logger.info(f"Summarized text:\n{summary}")
         return summary
 
     #@timer_decorator
     def compress(self) -> str:
         """
-        Compress the memory using the AI model.
+        Compress (summarize) the memory using the model.
         """
         if self.tokenizer is None or self.model is None:
             self.logger.warning("No tokenizer or model to perform memory compression.")
@@ -208,8 +238,32 @@ class Memory():
         for i in range(len(self.memory)):
             if self.memory[i]['role'] == 'system':
                 continue
-            if len(self.memory[i]['content']) > 2048:
+            if len(self.memory[i]['content']) > 1024:
                 self.memory[i]['content'] = self.summarize(self.memory[i]['content'])
+
+    def trim_text_to_max_ctx(self, text: str) -> str:
+        """
+        Truncate a text to fit within the maximum context size of the model.
+        """
+        ideal_ctx = self.get_ideal_ctx(self.model_provider)
+        return text[:ideal_ctx] if ideal_ctx is not None else text
+
+    #@timer_decorator
+    def compress_text_to_max_ctx(self, text: str) -> str:
+        """
+        Compress a text to fit within the maximum context size of the model.
+ """ + if self.tokenizer is None or self.model is None: + self.logger.warning("No tokenizer or model to perform memory compression.") + return text + ideal_ctx = self.get_ideal_ctx(self.model_provider) + if ideal_ctx is None: + self.logger.warning("No ideal context size found.") + return text + while len(text) > ideal_ctx: + self.logger.info(f"Compressing text: {len(text)} > {ideal_ctx} model context.") + text = self.summarize(text) + return text if __name__ == "__main__": sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))