diff --git a/book_maker/cli.py b/book_maker/cli.py
index 4f1437f..fea8470 100644
--- a/book_maker/cli.py
+++ b/book_maker/cli.py
@@ -372,7 +372,7 @@ So you are close to reaching the limit. You have to choose your own value, there
         API_KEY = options.custom_api or env.get("BBM_CUSTOM_API")
         if not API_KEY:
             raise Exception("Please provide custom translate api")
-    elif options.model == "gemini":
+    elif options.model in ["gemini", "geminipro"]:
         API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY")
     elif options.model == "groq":
         API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY")
@@ -488,6 +488,15 @@ So you are close to reaching the limit. You have to choose your own value, there
         if options.batch_use_flag:
             e.batch_use_flag = options.batch_use_flag
 
+    if options.model == "gemini":
+        if options.model_list:
+            e.translate_model.set_model_list(options.model_list.split(","))
+        else:
+            e.translate_model.set_geminiflash_models()
+
+    if options.model == "geminipro":
+        e.translate_model.set_geminipro_models()
+
     e.make_bilingual_book()
diff --git a/book_maker/translator/__init__.py b/book_maker/translator/__init__.py
index 8f848d5..3044027 100644
--- a/book_maker/translator/__init__.py
+++ b/book_maker/translator/__init__.py
@@ -21,6 +21,7 @@ MODEL_DICT = {
     "deeplfree": DeepLFree,
     "claude": Claude,
     "gemini": Gemini,
+    "geminipro": Gemini,
     "groq": GroqClient,
     "tencentransmart": TencentTranSmart,
     "customapi": CustomAPI,
diff --git a/book_maker/translator/gemini_translator.py b/book_maker/translator/gemini_translator.py
index fad32ae..0241d03 100644
--- a/book_maker/translator/gemini_translator.py
+++ b/book_maker/translator/gemini_translator.py
@@ -1,5 +1,7 @@
 import re
 import time
+from os import environ
+from itertools import cycle
 
 import google.generativeai as genai
 from google.generativeai.types.generation_types import (
@@ -28,6 +30,19 @@ safety_settings = [
         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
         "threshold": "BLOCK_MEDIUM_AND_ABOVE",
     },
 ]
+
+GEMINIPRO_MODEL_LIST = [
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+]
+
+GEMINIFLASH_MODEL_LIST = [
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-002",
+]
@@ -50,11 +65,17 @@ class Gemini(Base):
         self.interval = interval
         generation_config["temperature"] = temperature
+
+    def create_convo(self):
         model = genai.GenerativeModel(
-            model_name="gemini-pro",
+            model_name=self.model,
             generation_config=generation_config,
             safety_settings=safety_settings,
         )
         self.convo = model.start_chat()
+        # print(model)  # Uncomment to debug and inspect the model details.
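+
+    # The helpers below (rotate_model, set_models, set_model_list) keep a
+    # cycle of Gemini model names and rebuild the chat session whenever the
+    # active model is switched, instead of hard-coding a single model name.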
+
+    def rotate_model(self):
+        self.model = next(self.model_list)
+        self.create_convo()
+        print(f"Using model {self.model}")
 
     def rotate_key(self):
         pass
@@ -97,3 +118,28 @@ class Gemini(Base):
         if num:
             t_text = str(num) + "\n" + t_text
         return t_text
+
+    def set_geminipro_models(self):
+        self.set_models(GEMINIPRO_MODEL_LIST)
+
+    def set_geminiflash_models(self):
+        self.set_models(GEMINIFLASH_MODEL_LIST)
+
+    def set_models(self, allowed_models):
+        available_models = [
+            re.sub(r"^models/", "", i.name) for i in genai.list_models()
+        ]
+        model_list = sorted(
+            list(set(available_models) & set(allowed_models)),
+            key=allowed_models.index,
+        )
+        print(f"Using model list {model_list}")
+        self.model_list = cycle(model_list)
+        self.rotate_model()
+
+    def set_model_list(self, model_list):
+        # keep the order of input
+        model_list = sorted(list(set(model_list)), key=model_list.index)
+        print(f"Using model list {model_list}")
+        self.model_list = cycle(model_list)
+        self.rotate_model()