mirror of
https://github.com/tcsenpai/youlama.git
synced 2025-06-07 03:35:41 +00:00
adjusted context size management
This commit is contained in:
parent
96ad736734
commit
629cc5b881
@ -11,6 +11,13 @@ class OllamaClient:
|
|||||||
def __init__(self, base_url, model):
|
def __init__(self, base_url, model):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.context_size_table = {"llama3.1": 128000, "mistral-nemo": 128000}
|
||||||
|
self.context_size = 2048
|
||||||
|
if self.model not in self.context_size_table:
|
||||||
|
print(f"Model {self.model} not found in context size table: using default {self.context_size}")
|
||||||
|
else:
|
||||||
|
self.context_size = self.context_size_table[self.model]
|
||||||
|
print(f"Using context size {self.context_size} for model {self.model}")
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
url = f"{self.base_url}/api/tags"
|
url = f"{self.base_url}/api/tags"
|
||||||
@ -19,7 +26,7 @@ class OllamaClient:
|
|||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
all_models = response_json["models"]
|
all_models = response_json["models"]
|
||||||
for model in all_models:
|
for model in all_models:
|
||||||
models.append(model["name"])
|
models.append(model["name"])
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def generate(self, prompt):
|
def generate(self, prompt):
|
||||||
@ -28,6 +35,7 @@ class OllamaClient:
|
|||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
|
"num_ctx": self.context_size,
|
||||||
}
|
}
|
||||||
response = requests.post(url, json=data)
|
response = requests.post(url, json=data)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user