adjusted context size management

This commit is contained in:
tcsenpai 2024-10-14 11:56:55 +02:00
parent 96ad736734
commit 629cc5b881

View File

@@ -11,6 +11,13 @@ class OllamaClient:
def __init__(self, base_url, model):
    """Store the connection settings and pick a context size for the model.

    Args:
        base_url: Base URL of the server; stored unchanged.
        model: Model name, optionally carrying a ``:tag`` suffix.

    The context size is looked up in ``context_size_table`` by the exact
    model name first, then by the name with any ``:tag`` suffix removed,
    so ``llama3.1:latest`` resolves to the ``llama3.1`` entry. Models
    that match neither way keep the 2048 default. The outcome is
    reported on stdout in both cases.
    """
    self.base_url = base_url
    self.model = model
    self.context_size_table = {"llama3.1": 128000, "mistral-nemo": 128000}
    self.context_size = 2048

    # Model names may carry a tag suffix (e.g. "llama3.1:latest") while
    # the table is keyed by the bare name, so try the untagged name as a
    # fallback. Non-string values are looked up as-is, as before.
    if isinstance(self.model, str):
        untagged = self.model.partition(":")[0]
    else:
        untagged = self.model

    for candidate in (self.model, untagged):
        if candidate in self.context_size_table:
            self.context_size = self.context_size_table[candidate]
            print(f"Using context size {self.context_size} for model {self.model}")
            break
    else:
        # for/else: reached only when no candidate matched the table.
        print(f"Model {self.model} not found in context size table: using default {self.context_size}")
def get_models(self): def get_models(self):
url = f"{self.base_url}/api/tags" url = f"{self.base_url}/api/tags"
@@ -19,7 +26,7 @@ class OllamaClient:
response_json = response.json() response_json = response.json()
all_models = response_json["models"] all_models = response_json["models"]
for model in all_models: for model in all_models:
models.append(model["name"]) models.append(model["name"])
return models return models
def generate(self, prompt): def generate(self, prompt):
@@ -28,6 +35,7 @@ class OllamaClient:
"model": self.model, "model": self.model,
"prompt": prompt, "prompt": prompt,
"stream": False, "stream": False,
"num_ctx": self.context_size,
} }
response = requests.post(url, json=data) response = requests.post(url, json=data)
if response.status_code == 200: if response.status_code == 200: