add ollama support (#396)

* add ollama support

* fix format

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg 2024-04-22 16:16:33 +08:00 committed by GitHub
parent 205c482fd8
commit 1ca0e78558
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 2 deletions

View File

@ -137,6 +137,14 @@ def main():
metavar="MODEL", metavar="MODEL",
help="model to use, available: {%(choices)s}", help="model to use, available: {%(choices)s}",
) )
parser.add_argument(
"--ollama_model",
dest="ollama_model",
type=str,
default="ollama_model",
metavar="MODEL",
help="use ollama",
)
parser.add_argument( parser.add_argument(
"--language", "--language",
type=str, type=str,
@ -308,6 +316,9 @@ So you are close to reaching the limit. You have to choose your own value, there
): ):
API_KEY = OPENAI_API_KEY API_KEY = OPENAI_API_KEY
# patch # patch
elif options.ollama_model:
# any string is ok, can't be empty
API_KEY = "ollama"
else: else:
raise Exception( raise Exception(
"OpenAI API key not provided, please google how to obtain it", "OpenAI API key not provided, please google how to obtain it",
@ -365,6 +376,10 @@ So you are close to reaching the limit. You have to choose your own value, there
# change api_base for issue #42 # change api_base for issue #42
model_api_base = options.api_base model_api_base = options.api_base
if options.ollama_model and not model_api_base:
# ollama default api_base
model_api_base = "http://localhost:11434/v1"
e = book_loader( e = book_loader(
options.book_name, options.book_name,
translate_model, translate_model,
@ -418,6 +433,9 @@ So you are close to reaching the limit. You have to choose your own value, there
) )
# TODO refactor, quick fix for gpt4 model # TODO refactor, quick fix for gpt4 model
if options.model == "chatgptapi": if options.model == "chatgptapi":
if options.ollama_model:
e.translate_model.set_gpt35_models(ollama_model=options.ollama_model)
else:
e.translate_model.set_gpt35_models() e.translate_model.set_gpt35_models()
if options.model == "gpt4": if options.model == "gpt4":
e.translate_model.set_gpt4_models() e.translate_model.set_gpt4_models()

View File

@ -307,7 +307,10 @@ class ChatGPTAPI(Base):
azure_deployment=self.deployment_id, azure_deployment=self.deployment_id,
) )
def set_gpt35_models(self): def set_gpt35_models(self, ollama_model=""):
if ollama_model:
self.model_list = cycle([ollama_model])
return
# gpt3 all models for save the limit # gpt3 all models for save the limit
if self.deployment_id: if self.deployment_id:
self.model_list = cycle(["gpt-35-turbo"]) self.model_list = cycle(["gpt-35-turbo"])