diff --git a/book_maker/cli.py b/book_maker/cli.py index e95de94..bc3a147 100644 --- a/book_maker/cli.py +++ b/book_maker/cli.py @@ -137,6 +137,14 @@ def main(): metavar="MODEL", help="model to use, available: {%(choices)s}", ) + parser.add_argument( + "--ollama_model", + dest="ollama_model", + type=str, + default="ollama_model", + metavar="MODEL", + help="use ollama", + ) parser.add_argument( "--language", type=str, @@ -308,6 +316,9 @@ So you are close to reaching the limit. You have to choose your own value, there ): API_KEY = OPENAI_API_KEY # patch + elif options.ollama_model: + # any string is ok, can't be empty + API_KEY = "ollama" else: raise Exception( "OpenAI API key not provided, please google how to obtain it", @@ -365,6 +376,10 @@ So you are close to reaching the limit. You have to choose your own value, there # change api_base for issue #42 model_api_base = options.api_base + if options.ollama_model and not model_api_base: + # ollama default api_base + model_api_base = "http://localhost:11434/v1" + e = book_loader( options.book_name, translate_model, @@ -418,7 +433,10 @@ So you are close to reaching the limit. You have to choose your own value, there ) # TODO refactor, quick fix for gpt4 model if options.model == "chatgptapi": - e.translate_model.set_gpt35_models() + if options.ollama_model: + e.translate_model.set_gpt35_models(ollama_model=options.ollama_model) + else: + e.translate_model.set_gpt35_models() if options.model == "gpt4": e.translate_model.set_gpt4_models() if options.block_size > 0: diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py index 8133ed1..bc83d8c 100644 --- a/book_maker/translator/chatgptapi_translator.py +++ b/book_maker/translator/chatgptapi_translator.py @@ -307,7 +307,10 @@ class ChatGPTAPI(Base): azure_deployment=self.deployment_id, ) - def set_gpt35_models(self): + def set_gpt35_models(self, ollama_model=""): + if ollama_model: + self.model_list = cycle([ollama_model]) + return # gpt3 all models for save the limit if self.deployment_id: self.model_list = cycle(["gpt-35-turbo"])