add ollama support (#396)

* add ollama support

* fix format

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg 2024-04-22 16:16:33 +08:00 committed by GitHub
parent 205c482fd8
commit 1ca0e78558
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 2 deletions

View File

@ -137,6 +137,14 @@ def main():
metavar="MODEL",
help="model to use, available: {%(choices)s}",
)
parser.add_argument(
"--ollama_model",
dest="ollama_model",
type=str,
default="ollama_model",
metavar="MODEL",
help="use ollama",
)
parser.add_argument(
"--language",
type=str,
@ -308,6 +316,9 @@ So you are close to reaching the limit. You have to choose your own value, there
):
API_KEY = OPENAI_API_KEY
# patch
elif options.ollama_model:
# any string is ok, can't be empty
API_KEY = "ollama"
else:
raise Exception(
"OpenAI API key not provided, please google how to obtain it",
@ -365,6 +376,10 @@ So you are close to reaching the limit. You have to choose your own value, there
# change api_base for issue #42
model_api_base = options.api_base
if options.ollama_model and not model_api_base:
# ollama default api_base
model_api_base = "http://localhost:11434/v1"
e = book_loader(
options.book_name,
translate_model,
@ -418,7 +433,10 @@ So you are close to reaching the limit. You have to choose your own value, there
)
# TODO refactor, quick fix for gpt4 model
if options.model == "chatgptapi":
e.translate_model.set_gpt35_models()
if options.ollama_model:
e.translate_model.set_gpt35_models(ollama_model=options.ollama_model)
else:
e.translate_model.set_gpt35_models()
if options.model == "gpt4":
e.translate_model.set_gpt4_models()
if options.block_size > 0:

View File

@ -307,7 +307,10 @@ class ChatGPTAPI(Base):
azure_deployment=self.deployment_id,
)
def set_gpt35_models(self):
def set_gpt35_models(self, ollama_model=""):
if ollama_model:
self.model_list = cycle([ollama_model])
return
# gpt3 all models for save the limit
if self.deployment_id:
self.model_list = cycle(["gpt-35-turbo"])