Give OllamaClient a per-model context size and send it with generate requests.

The context size is looked up by model name in __init__ (default 2048) and
passed to POST /api/generate as options.num_ctx; Ollama reads model parameters
from the nested "options" object, not from top-level request keys.

NOTE(review): this patch was rebuilt from text whose line breaks and
indentation had been collapsed. 4-space indentation is assumed, and the
whitespace-only change in the second hunk could not be recovered - confirm
against the target file before applying.

diff --git a/src/ollama_client.py b/src/ollama_client.py
index 3d0a150..d234885 100644
--- a/src/ollama_client.py
+++ b/src/ollama_client.py
@@ -11,6 +11,13 @@ class OllamaClient:
     def __init__(self, base_url, model):
         self.base_url = base_url
         self.model = model
+        self.context_size_table = {"llama3.1": 128000, "mistral-nemo": 128000}
+        self.context_size = 2048
+        if self.model not in self.context_size_table:
+            print(f"Model {self.model} not found in context size table: using default {self.context_size}")
+        else:
+            self.context_size = self.context_size_table[self.model]
+            print(f"Using context size {self.context_size} for model {self.model}")
 
     def get_models(self):
         url = f"{self.base_url}/api/tags"
@@ -19,7 +26,7 @@ class OllamaClient:
         response_json = response.json()
         all_models = response_json["models"]
         for model in all_models:
-            models.append(model["name"])
+            models.append(model["name"])
         return models
 
     def generate(self, prompt):
@@ -28,6 +35,7 @@ class OllamaClient:
             "model": self.model,
             "prompt": prompt,
             "stream": False,
+            "options": {"num_ctx": self.context_size},
         }
         response = requests.post(url, json=data)
         if response.status_code == 200: