Merge pull request #28 from Fosowl/dev

Fix #23 #25 #26
This commit is contained in:
Martin 2025-03-15 18:50:59 +01:00 committed by GitHub
commit fda7b47faf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 11 deletions

View File

@ -133,7 +133,7 @@ Note: For Windows or macOS, use ipconfig or ifconfig respectively to find the IP
Clone the repository and then, run the script `stream_llm.py` in `server/`
```sh
python3 stream_llm.py
python3 server_ollama.py
```
### 2**Run it**

View File

@ -10,10 +10,9 @@ log.setLevel(logging.ERROR)
app = Flask(__name__)
# Shared state with thread-safe locks
class Config:
def __init__(self):
self.model = None
self.model = None
self.known_models = []
self.allowed_models = []
self.model_name = None
@ -23,7 +22,7 @@ class Config:
data = json.load(f)
self.known_models = data['known_models']
self.model_name = data['model_name']
def validate_model(self, model):
if model not in self.known_models:
raise ValueError(f"Model {model} is not known")
@ -37,8 +36,8 @@ class GenerationState:
self.model = None
state = GenerationState()
def generate_response(history):
def generate_response(history): # Only takes history as an argument
global state
try:
with state.lock:
@ -47,7 +46,7 @@ def generate_response(history):
state.current_buffer = ""
stream = ollama.chat(
model=state.model,
model=state.model, # Access state.model directly
messages=history,
stream=True,
)
@ -70,13 +69,14 @@ def generate_response(history):
def start_generation():
global state
data = request.get_json()
with state.lock:
if state.is_generating:
return jsonify({"error": "Generation already in progress"}), 400
history = data.get('messages', [])
threading.Thread(target=generate_response, args=(history, state.model)).start()
# Pass only history to the thread
threading.Thread(target=generate_response, args=(history,)).start() # Note the comma to make it a single-element tuple
return jsonify({"message": "Generation started"}), 202
@app.route('/get_updated_sentence')

View File

@ -44,7 +44,7 @@ def pretty_print(text, color = "info"):
"failure": "red",
"status": "light_green",
"code": "light_blue",
"warning": "yello",
"warning": "yellow",
"output": "cyan",
"default": "black"
}