feat: server response cache

martin legrand 2025-04-05 16:36:58 +02:00
parent 688e94d97c
commit 97708c7947
7 changed files with 50 additions and 28 deletions

server/sources/cache.py (new file)

@@ -0,0 +1,36 @@
import os
import json
from pathlib import Path

class Cache:
    def __init__(self, cache_dir='.cache', cache_file='messages.json'):
        self.cache_dir = Path(cache_dir)
        self.cache_file = self.cache_dir / cache_file
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        if not self.cache_file.exists():
            with open(self.cache_file, 'w') as f:
                json.dump([], f)
        with open(self.cache_file, 'r') as f:
            # Load as a list of {"user": ..., "assistant": ...} dicts;
            # a set cannot hold dicts and has no append().
            self.cache = json.load(f)

    def add_message_pair(self, user_message: str, assistant_message: str):
        """Add a user/assistant pair to the cache if not already present."""
        if not any(entry["user"] == user_message for entry in self.cache):
            self.cache.append({"user": user_message, "assistant": assistant_message})
            self._save()

    def is_cached(self, user_message: str) -> bool:
        """Check whether a user message is already cached."""
        return any(entry["user"] == user_message for entry in self.cache)

    def get_cached_response(self, user_message: str) -> str | None:
        """Return the assistant response for a user message if cached, else None."""
        for entry in self.cache:
            if entry["user"] == user_message:
                return entry["assistant"]
        return None

    def _save(self):
        with open(self.cache_file, 'w') as f:
            json.dump(self.cache, f, indent=2)
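
A minimal usage sketch of the Cache API above; the import path is assumed from the new file's location and the messages are illustrative:

from server.sources.cache import Cache

cache = Cache()
cache.add_message_pair("Hello, how are you?", "Doing well, thanks!")
# A duplicate user message is ignored, so the first stored reply wins
cache.add_message_pair("Hello, how are you?", "A different reply")

print(cache.is_cached("Hello, how are you?"))            # True
print(cache.get_cached_response("Hello, how are you?"))  # Doing well, thanks!
print(cache.get_cached_response("Never asked before"))   # None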


@@ -2,6 +2,7 @@
import threading
import logging
from abc import abstractmethod
from .cache import Cache

class GenerationState:
    def __init__(self):
@@ -29,6 +30,7 @@ class GeneratorLLM():
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        # Per-instance handle to the shared on-disk response cache
        self.cache = Cache()

    def set_model(self, model: str) -> None:
        self.logger.info(f"Model set to {model}")


@@ -13,6 +13,10 @@ class OllamaLLM(GeneratorLLM):

    def generate(self, history):
        self.logger.info(f"Using {self.model} for generation with Ollama")
        if self.cache.is_cached(history[-1]['content']):
            # Cache hit: reuse the stored reply and skip the model call entirely
            self.state.current_buffer = self.cache.get_cached_response(history[-1]['content'])
            self.state.is_generating = False
            return
        try:
            with self.state.lock:
                self.state.is_generating = True
@@ -43,6 +47,7 @@ class OllamaLLM(GeneratorLLM):
        self.logger.info("Generation complete")
        with self.state.lock:
            self.state.is_generating = False
            # Remember the prompt/answer pair so the next identical request is a cache hit
            self.cache.add_message_pair(history[-1]['content'], self.state.current_buffer)

if __name__ == "__main__":
    generator = OllamaLLM()
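
The OllamaLLM change above is a straightforward look-up-then-store pattern: serve the last user message from the cache when possible, otherwise generate and remember the new pair. A condensed, self-contained sketch of that flow; answer_with_cache and generate_fn are illustrative names, not part of the commit, and the import path is assumed:

from server.sources.cache import Cache

def answer_with_cache(cache: Cache, history: list, generate_fn):
    """Return a cached reply for the last user message, or generate and store one."""
    last_user_message = history[-1]['content']

    if cache.is_cached(last_user_message):
        # Cache hit: skip the model call, mirroring the early return in generate()
        return cache.get_cached_response(last_user_message)

    # Cache miss: call the real model (Ollama in this commit), then persist the pair
    answer = generate_fn(history)
    cache.add_message_pair(last_user_message, answer)
    return answer

# Example with a stand-in generator:
# answer_with_cache(Cache(), [{"role": "user", "content": "hi"}], lambda h: "hello")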


@@ -29,10 +29,4 @@ class CasualAgent(Agent):
        return answer, reasoning

if __name__ == "__main__":
    from llm_provider import Provider
    #local_provider = Provider("ollama", "deepseek-r1:14b", None)
    server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
    agent = CasualAgent("deepseek-r1:14b", "jarvis", "prompts/casual_agent.txt", server_provider)
    ans = agent.process("Hello, how are you?")
    print(ans)
    pass


@@ -68,10 +68,4 @@ class CoderAgent(Agent):
        return answer, reasoning

if __name__ == "__main__":
    from llm_provider import Provider
    #local_provider = Provider("ollama", "deepseek-r1:14b", None)
    server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
    agent = CoderAgent("deepseek-r1:14b", "jarvis", "prompts/coder_agent.txt", server_provider)
    ans = agent.process("What is the output of 5+5 in python ?")
    print(ans)
    pass


@@ -36,10 +36,4 @@ class FileAgent(Agent):
        return answer, reasoning

if __name__ == "__main__":
    from llm_provider import Provider
    #local_provider = Provider("ollama", "deepseek-r1:14b", None)
    server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
    agent = FileAgent("deepseek-r1:14b", "jarvis", "prompts/file_agent.txt", server_provider)
    ans = agent.process("What is the content of the file toto.py ?")
    print(ans)
    pass


@@ -76,10 +76,10 @@ class PlannerAgent(Agent):
        agents_tasks = self.parse_agent_tasks(json_plan)
        if agents_tasks == (None, None):
            return
-       pretty_print("▂▘ P L A N ▝▂", color="output")
+       pretty_print("▂▘ P L A N ▝▂", color="status")
        for task_name, task in agents_tasks:
-           pretty_print(f"{task['agent']} -> {task['task']}", color="info")
-       pretty_print("▔▗ E N D ▖▔", color="output")
+           pretty_print(f"{task['agent']} -> {task['task']}", color="info")
+       pretty_print("▔▗ E N D ▖▔", color="status")

    def process(self, prompt, speech_module) -> str:
        ok = False
@@ -117,7 +117,4 @@ class PlannerAgent(Agent):
        return prev_agent_answer, ""
if __name__ == "__main__":
    from llm_provider import Provider
    server_provider = Provider("server", "deepseek-r1:14b", "192.168.1.100:5000")
    agent = PlannerAgent("deepseek-r1:14b", "jarvis", "prompts/planner_agent.txt", server_provider)
    ans = agent.process("Make a cool game to illustrate the current relation between USA and europe")
    pass