fix: memory summarization issue

This commit is contained in:
martin legrand 2025-04-04 11:41:47 +02:00
parent 80a3391b84
commit 92e2e8c0d6

View File

@ -15,7 +15,7 @@ from sources.utility import timer_decorator, pretty_print
class Memory():
"""
Memory is a class for managing the conversation memory
It provides a method to compress the memory (experimental, use with caution).
It provides a method to compress the memory using a summarization model.
"""
def __init__(self, system_prompt: str,
recover_last_session: bool = False,
@ -38,6 +38,7 @@ class Memory():
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
def get_filename(self) -> str:
    """Build the name of the save file from ``self.session_time``.

    Returns:
        A name of the form ``memory_YYYY-MM-DD_HH-MM-SS.txt``.
    """
    # Dashes and underscores only, so the timestamp stays filename-friendly.
    timestamp = self.session_time.strftime('%Y-%m-%d_%H-%M-%S')
    return f"memory_{timestamp}.txt"
def save_memory(self, agent_type: str = "casual_agent") -> None:
@ -103,6 +104,7 @@ class Memory():
self.memory = []
def clear_section(self, start: int, end: int) -> None:
    """Remove the entries in the half-open range ``[start, end)`` from memory.

    Bounds follow ordinary list-slice rules, so negative and out-of-range
    values are accepted. An empty or inverted range removes nothing.

    Args:
        start: Index of the first entry to remove.
        end: Index one past the last entry to remove.
    """
    # Slice deletion treats an inverted range (start > end) as empty, whereas
    # concatenating memory[:start] + memory[end:] would duplicate the entries
    # lying between the two bounds.
    # Work on a copy and rebind, as before, so the previous list object is
    # left unmodified.
    remaining = list(self.memory)
    del remaining[start:end]
    self.memory = remaining
def get(self) -> list:
@ -134,23 +136,23 @@ class Memory():
inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
summary_ids = self.model.generate(
inputs['input_ids'],
max_length=max_length, # Maximum length of the summary
min_length=min_length, # Minimum length of the summary
length_penalty=1.0, # Adjusts length preference
num_beams=4, # Beam search for better quality
early_stopping=True # Stop when all beams finish
max_length=max_length,
min_length=min_length,
length_penalty=1.0,
num_beams=4,
early_stopping=True
)
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summary.replace('summary:', '')
return summary
@timer_decorator
#@timer_decorator
def compress(self) -> str:
"""
Compress the memory using the AI model.
"""
for i in range(len(self.memory)):
if i < 3:
if i < 2:
continue
if self.memory[i]['role'] == 'system':
continue
@ -179,10 +181,12 @@ Use the -I flag to specify the directory containing helper_functions.h.
Ensure the file exists in the specified location.
"""
memory.push('user', "why do i get this error?")
memory.push('user', "hello")
memory.push('assistant', "how can i help you?")
memory.push('user', "why do i get this cuda error?")
memory.push('assistant', sample_text)
print("\n---\nmemory before:", memory.get())
memory.compress()
print("\n---\nmemory after:", memory.get())
memory.save_memory()
#memory.save_memory()