Fix: server script

martin legrand 2025-03-23 21:07:44 +01:00
parent 8b5bb28c94
commit 9448ac1012


@@ -1,47 +1,30 @@
-#!/usr/bin python3
 from flask import Flask, jsonify, request
 import threading
 import ollama
 import logging
-import json

 log = logging.getLogger('werkzeug')
 log.setLevel(logging.ERROR)

 app = Flask(__name__)
+model = 'deepseek-r1:14b'

 # Shared state with thread-safe locks
-class Config:
-    def __init__(self):
-        self.model = None
-        self.known_models = []
-        self.allowed_models = []
-        self.model_name = None
-    def load(self):
-        with open('config.json', 'r') as f:
-            data = json.load(f)
-        self.known_models = data['known_models']
-        self.model_name = data['model_name']
-    def validate_model(self, model):
-        if model not in self.known_models:
-            raise ValueError(f"Model {model} is not known")
 class GenerationState:
     def __init__(self):
         self.lock = threading.Lock()
         self.last_complete_sentence = ""
         self.current_buffer = ""
         self.is_generating = False
-        self.model = None

 state = GenerationState()

-def generate_response_vllm(history):
-    pass
-def generate_response_ollama(history): # Only takes history as an argument
+def generate_response(history, model):
     global state
+    print("using model:::::::", model)
     try:
         with state.lock:
             state.is_generating = True
@@ -49,18 +32,21 @@ def generate_response_ollama(history): # Only takes history as an argument
             state.current_buffer = ""
         stream = ollama.chat(
-            model=state.model, # Access state.model directly
+            model=model,
             messages=history,
             stream=True,
         )
         for chunk in stream:
             content = chunk['message']['content']
             print(content, end='', flush=True)
             with state.lock:
                 state.current_buffer += content
     except ollama.ResponseError as e:
         if e.status_code == 404:
-            ollama.pull(state.model)
+            ollama.pull(model)
         with state.lock:
             state.is_generating = False
         print(f"Error: {e}")
@@ -78,8 +64,8 @@ def start_generation():
        return jsonify({"error": "Generation already in progress"}), 400
    history = data.get('messages', [])
-    # Pass only history to the thread
-    threading.Thread(target=generate_response, args=(history,)).start() # Note the comma to make it a single-element tuple
+    # Start generation in background thread
+    threading.Thread(target=generate_response, args=(history, model)).start()
    return jsonify({"message": "Generation started"}), 202

 @app.route('/get_updated_sentence')
@@ -92,8 +78,4 @@ def get_updated_sentence():
     })

 if __name__ == '__main__':
-    config = Config()
-    config.load()
-    config.validate_model(config.model_name)
-    state.model = config.model_name
-    app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
+    app.run(host='0.0.0.0', threaded=True, debug=True, port=5000)
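
For reference, a minimal client sketch of how the patched server could be driven. It assumes the server runs on localhost:5000 (per the app.run line above) and that start_generation() is bound to POST /generate; that route path and the JSON keys returned by /get_updated_sentence are not visible in this diff and are assumptions.

# Hypothetical client for the server above, for illustration only.
# Assumed: server on localhost:5000; start_generation() mapped to
# POST /generate (its route decorator is not shown in this diff).
import time
import requests

BASE = "http://localhost:5000"
history = [{"role": "user", "content": "Hello, how are you?"}]

# Kick off generation; the endpoint returns 202 and the server streams
# tokens into state.current_buffer on a background thread.
resp = requests.post(f"{BASE}/generate", json={"messages": history})
print(resp.status_code, resp.json())

# Poll the buffered output while the worker thread appends chunks.
for _ in range(10):
    update = requests.get(f"{BASE}/get_updated_sentence")
    print(update.json())  # exact keys are whatever get_updated_sentence() returns
    time.sleep(1)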