fix: memory compression on session loading

This commit is contained in:
martin legrand 2025-03-28 15:41:54 +01:00
parent fcda0abc21
commit dae2c224e5
2 changed files with 10 additions and 3 deletions

View File

@ -66,13 +66,16 @@ class Memory():
def load_memory(self, agent_type: str = "casual_agent") -> None:
"""Load the memory from the last session."""
pretty_print(f"Loading {agent_type} past memories... ", color="status")
if self.session_recovered == True:
return
save_path = os.path.join(self.conversation_folder, agent_type)
if not os.path.exists(save_path):
pretty_print("No memory to load.", color="success")
return
filename = self.find_last_session_path(save_path)
if filename is None:
pretty_print("Last session memory not found.", color="warning")
return
path = os.path.join(save_path, filename)
with open(path, 'r') as f:
@ -80,6 +83,7 @@ class Memory():
if self.memory[-1]['role'] == 'user':
self.memory.pop()
self.compress()
pretty_print("Session recovered successfully", color="success")
def reset(self, memory: list) -> None:
self.memory = memory
@ -118,6 +122,8 @@ class Memory():
"""
if self.tokenizer is None or self.model is None:
return text
if len(text) < min_length*1.5:
return text
max_length = len(text) // 2 if len(text) > min_length*2 else min_length*2
input_text = "summarize: " + text
inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
@ -141,7 +147,9 @@ class Memory():
for i in range(len(self.memory)):
if i < 3:
continue
if len(self.memory[i]['content']) > 1024:
if self.memory[i]['role'] == 'system':
continue
if len(self.memory[i]['content']) > 128:
self.memory[i]['content'] = self.summarize(self.memory[i]['content'])
if __name__ == "__main__":

View File

@ -76,7 +76,6 @@ def animate_thinking(text, color="status", duration=2):
print(colored(f"{symbol} {text}", term_color), flush=True)
time.sleep(0.1)
print("\033[1A\033[K", end="", flush=True)
print()
animation_thread = threading.Thread(target=_animate)
animation_thread.start()
animation_thread.join()
@ -94,6 +93,6 @@ def timer_decorator(func):
start_time = time()
result = func(*args, **kwargs)
end_time = time()
print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute")
pretty_print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute", "status")
return result
return wrapper