From c0172ba9863ec1d21dc8291b9dd2909af42e569e Mon Sep 17 00:00:00 2001 From: martin legrand Date: Fri, 7 Mar 2025 11:01:14 +0100 Subject: [PATCH] Feat : llm server config --- server/stream_llm.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/server/stream_llm.py b/server/stream_llm.py index 2160cdf..cccbf58 100644 --- a/server/stream_llm.py +++ b/server/stream_llm.py @@ -2,25 +2,43 @@ from flask import Flask, jsonify, request import threading import ollama import logging +import json log = logging.getLogger('werkzeug') log.setLevel(logging.ERROR) app = Flask(__name__) -model = 'deepseek-r1:14b' - # Shared state with thread-safe locks + +class Config: + def __init__(self): + self.model = None + self.known_models = [] + self.allowed_models = [] + self.model_name = None + + def load(self): + with open('config.json', 'r') as f: + data = json.load(f) + self.known_models = data['known_models'] + self.model_name = data['model_name'] + + def validate_model(self, model): + if model not in self.known_models: + raise ValueError(f"Model {model} is not known") + class GenerationState: def __init__(self): self.lock = threading.Lock() self.last_complete_sentence = "" self.current_buffer = "" self.is_generating = False + self.model = None state = GenerationState() -def generate_response(history, model): +def generate_response(history): global state try: with state.lock: @@ -29,21 +47,18 @@ def generate_response(history, model): state.current_buffer = "" stream = ollama.chat( - model=model, + model=state.model, messages=history, stream=True, ) - for chunk in stream: content = chunk['message']['content'] print(content, end='', flush=True) - with state.lock: state.current_buffer += content - except ollama.ResponseError as e: if e.status_code == 404: - ollama.pull(model) + ollama.pull(state.model) with state.lock: state.is_generating = False print(f"Error: {e}") @@ -62,7 +77,7 @@ def start_generation(): history = 
data.get('messages', []) # Start generation in background thread - threading.Thread(target=generate_response, args=(history, model)).start() + threading.Thread(target=generate_response, args=(history,)).start() return jsonify({"message": "Generation started"}), 202 @app.route('/get_updated_sentence') def get_updated_sentence(): global state with state.lock: return jsonify({ "sentence": state.current_buffer, "last_complete_sentence": state.last_complete_sentence, "is_generating": state.is_generating }) if __name__ == '__main__': + config = Config() + config.load() + config.validate_model(config.model_name) + state.model = config.model_name app.run(host='0.0.0.0', port=5000, debug=False, threaded=True) \ No newline at end of file