fix: memory compression on session loading

This commit is contained in:
martin legrand 2025-03-28 15:41:54 +01:00
parent fcda0abc21
commit dae2c224e5
2 changed files with 10 additions and 3 deletions

View File

@ -66,13 +66,16 @@ class Memory():
def load_memory(self, agent_type: str = "casual_agent") -> None: def load_memory(self, agent_type: str = "casual_agent") -> None:
"""Load the memory from the last session.""" """Load the memory from the last session."""
pretty_print(f"Loading {agent_type} past memories... ", color="status")
if self.session_recovered == True: if self.session_recovered == True:
return return
save_path = os.path.join(self.conversation_folder, agent_type) save_path = os.path.join(self.conversation_folder, agent_type)
if not os.path.exists(save_path): if not os.path.exists(save_path):
pretty_print("No memory to load.", color="success")
return return
filename = self.find_last_session_path(save_path) filename = self.find_last_session_path(save_path)
if filename is None: if filename is None:
pretty_print("Last session memory not found.", color="warning")
return return
path = os.path.join(save_path, filename) path = os.path.join(save_path, filename)
with open(path, 'r') as f: with open(path, 'r') as f:
@ -80,6 +83,7 @@ class Memory():
if self.memory[-1]['role'] == 'user': if self.memory[-1]['role'] == 'user':
self.memory.pop() self.memory.pop()
self.compress() self.compress()
pretty_print("Session recovered successfully", color="success")
def reset(self, memory: list) -> None: def reset(self, memory: list) -> None:
self.memory = memory self.memory = memory
@ -118,6 +122,8 @@ class Memory():
""" """
if self.tokenizer is None or self.model is None: if self.tokenizer is None or self.model is None:
return text return text
if len(text) < min_length*1.5:
return text
max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2 max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2
input_text = "summarize: " + text input_text = "summarize: " + text
inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
@ -141,7 +147,9 @@ class Memory():
for i in range(len(self.memory)): for i in range(len(self.memory)):
if i < 3: if i < 3:
continue continue
if len(self.memory[i]['content']) > 1024: if self.memory[i]['role'] == 'system':
continue
if len(self.memory[i]['content']) > 128:
self.memory[i]['content'] = self.summarize(self.memory[i]['content']) self.memory[i]['content'] = self.summarize(self.memory[i]['content'])
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -76,7 +76,6 @@ def animate_thinking(text, color="status", duration=2):
print(colored(f"{symbol} {text}", term_color), flush=True) print(colored(f"{symbol} {text}", term_color), flush=True)
time.sleep(0.1) time.sleep(0.1)
print("\033[1A\033[K", end="", flush=True) print("\033[1A\033[K", end="", flush=True)
print()
animation_thread = threading.Thread(target=_animate) animation_thread = threading.Thread(target=_animate)
animation_thread.start() animation_thread.start()
animation_thread.join() animation_thread.join()
@ -94,6 +93,6 @@ def timer_decorator(func):
start_time = time() start_time = time()
result = func(*args, **kwargs) result = func(*args, **kwargs)
end_time = time() end_time = time()
print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute") pretty_print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute", "status")
return result return result
return wrapper return wrapper