fix: memory summarization issue

This commit is contained in:
martin legrand 2025-04-04 11:41:47 +02:00
parent 80a3391b84
commit 92e2e8c0d6

View File

@ -15,7 +15,7 @@ from sources.utility import timer_decorator, pretty_print
class Memory(): class Memory():
""" """
Memory is a class for managing the conversation memory Memory is a class for managing the conversation memory
It provides a method to compress the memory (experimental, use with caution). It provides a method to compress the memory using summarization model.
""" """
def __init__(self, system_prompt: str, def __init__(self, system_prompt: str,
recover_last_session: bool = False, recover_last_session: bool = False,
@ -38,6 +38,7 @@ class Memory():
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model) self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model)
def get_filename(self) -> str: def get_filename(self) -> str:
"""Get the filename for the save file."""
return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt" return f"memory_{self.session_time.strftime('%Y-%m-%d_%H-%M-%S')}.txt"
def save_memory(self, agent_type: str = "casual_agent") -> None: def save_memory(self, agent_type: str = "casual_agent") -> None:
@ -103,6 +104,7 @@ class Memory():
self.memory = [] self.memory = []
def clear_section(self, start: int, end: int) -> None: def clear_section(self, start: int, end: int) -> None:
"""Clear a section of the memory."""
self.memory = self.memory[:start] + self.memory[end:] self.memory = self.memory[:start] + self.memory[end:]
def get(self) -> list: def get(self) -> list:
@ -134,23 +136,23 @@ class Memory():
inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
summary_ids = self.model.generate( summary_ids = self.model.generate(
inputs['input_ids'], inputs['input_ids'],
max_length=max_length, # Maximum length of the summary max_length=max_length,
min_length=min_length, # Minimum length of the summary min_length=min_length,
length_penalty=1.0, # Adjusts length preference length_penalty=1.0,
num_beams=4, # Beam search for better quality num_beams=4,
early_stopping=True # Stop when all beams finish early_stopping=True
) )
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True) summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summary.replace('summary:', '') summary.replace('summary:', '')
return summary return summary
@timer_decorator #@timer_decorator
def compress(self) -> str: def compress(self) -> str:
""" """
Compress the memory using the AI model. Compress the memory using the AI model.
""" """
for i in range(len(self.memory)): for i in range(len(self.memory)):
if i < 3: if i < 2:
continue continue
if self.memory[i]['role'] == 'system': if self.memory[i]['role'] == 'system':
continue continue
@ -179,10 +181,12 @@ Use the -I flag to specify the directory containing helper_functions.h.
Ensure the file exists in the specified location. Ensure the file exists in the specified location.
""" """
memory.push('user', "why do i get this error?") memory.push('user', "hello")
memory.push('assistant', "how can i help you?")
memory.push('user', "why do i get this cuda error?")
memory.push('assistant', sample_text) memory.push('assistant', sample_text)
print("\n---\nmemory before:", memory.get()) print("\n---\nmemory before:", memory.get())
memory.compress() memory.compress()
print("\n---\nmemory after:", memory.get()) print("\n---\nmemory after:", memory.get())
memory.save_memory() #memory.save_memory()