diff --git a/README.md b/README.md index ee4d194..9a790de 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ Note: For Windows or macOS, use ipconfig or ifconfig respectively to find the IP Clone the repository and then, run the script `stream_llm.py` in `server/` ```sh -python3 stream_llm.py +python3 server_ollama.py ``` ### 2️⃣ **Run it** diff --git a/server/server_ollama.py b/server/server_ollama.py index 39327b2..b89f025 100644 --- a/server/server_ollama.py +++ b/server/server_ollama.py @@ -10,10 +10,9 @@ log.setLevel(logging.ERROR) app = Flask(__name__) # Shared state with thread-safe locks - class Config: def __init__(self): - self.model = None + self.model = None self.known_models = [] self.allowed_models = [] self.model_name = None @@ -23,7 +22,7 @@ class Config: data = json.load(f) self.known_models = data['known_models'] self.model_name = data['model_name'] - + def validate_model(self, model): if model not in self.known_models: raise ValueError(f"Model {model} is not known") @@ -37,8 +36,8 @@ class GenerationState: self.model = None state = GenerationState() - -def generate_response(history): + +def generate_response(history): # Only takes history as an argument global state try: with state.lock: @@ -47,7 +46,7 @@ def generate_response(history): state.current_buffer = "" stream = ollama.chat( - model=state.model, + model=state.model, # Access state.model directly messages=history, stream=True, ) @@ -70,13 +69,14 @@ def generate_response(history): def start_generation(): global state data = request.get_json() - + with state.lock: if state.is_generating: return jsonify({"error": "Generation already in progress"}), 400 - + history = data.get('messages', []) - threading.Thread(target=generate_response, args=(history, state.model)).start() + # Pass only history to the thread + threading.Thread(target=generate_response, args=(history,)).start() # Note the comma to make it a single-element tuple return jsonify({"message": "Generation started"}), 202 @app.route('/get_updated_sentence') diff --git a/sources/utility.py b/sources/utility.py index a9a29c0..ad1c161 100644 --- a/sources/utility.py +++ b/sources/utility.py @@ -44,7 +44,7 @@ def pretty_print(text, color = "info"): "failure": "red", "status": "light_green", "code": "light_blue", - "warning": "yello", + "warning": "yellow", "output": "cyan", "default": "black" }