mirror of
https://github.com/yihong0618/bilingual_book_maker.git
synced 2025-06-05 19:15:34 +00:00
feat: add support for --model_list option to gemini
This commit is contained in:
parent
5ad87bca4f
commit
0e13949b19
@ -372,7 +372,7 @@ So you are close to reaching the limit. You have to choose your own value, there
|
|||||||
API_KEY = options.custom_api or env.get("BBM_CUSTOM_API")
|
API_KEY = options.custom_api or env.get("BBM_CUSTOM_API")
|
||||||
if not API_KEY:
|
if not API_KEY:
|
||||||
raise Exception("Please provide custom translate api")
|
raise Exception("Please provide custom translate api")
|
||||||
elif options.model == "gemini":
|
elif options.model in ["gemini", "geminipro"]:
|
||||||
API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY")
|
API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY")
|
||||||
elif options.model == "groq":
|
elif options.model == "groq":
|
||||||
API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY")
|
API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY")
|
||||||
@ -488,6 +488,15 @@ So you are close to reaching the limit. You have to choose your own value, there
|
|||||||
if options.batch_use_flag:
|
if options.batch_use_flag:
|
||||||
e.batch_use_flag = options.batch_use_flag
|
e.batch_use_flag = options.batch_use_flag
|
||||||
|
|
||||||
|
if options.model == "gemini":
|
||||||
|
if options.model_list:
|
||||||
|
e.translate_model.set_model_list(options.model_list.split(","))
|
||||||
|
else:
|
||||||
|
e.translate_model.set_geminiflash_models()
|
||||||
|
|
||||||
|
if options.model == "geminipro":
|
||||||
|
e.translate_model.set_geminipro_models()
|
||||||
|
|
||||||
e.make_bilingual_book()
|
e.make_bilingual_book()
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ MODEL_DICT = {
|
|||||||
"deeplfree": DeepLFree,
|
"deeplfree": DeepLFree,
|
||||||
"claude": Claude,
|
"claude": Claude,
|
||||||
"gemini": Gemini,
|
"gemini": Gemini,
|
||||||
|
"geminipro": Gemini,
|
||||||
"groq": GroqClient,
|
"groq": GroqClient,
|
||||||
"tencentransmart": TencentTranSmart,
|
"tencentransmart": TencentTranSmart,
|
||||||
"customapi": CustomAPI,
|
"customapi": CustomAPI,
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from os import environ
|
||||||
|
from itertools import cycle
|
||||||
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
from google.generativeai.types.generation_types import (
|
from google.generativeai.types.generation_types import (
|
||||||
@ -28,6 +30,19 @@ safety_settings = [
|
|||||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
|
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
GEMINIPRO_MODEL_LIST = [
|
||||||
|
"gemini-1.5-pro",
|
||||||
|
"gemini-1.5-pro-latest",
|
||||||
|
"gemini-1.5-pro-001",
|
||||||
|
"gemini-1.5-pro-002",
|
||||||
|
]
|
||||||
|
|
||||||
|
GEMINIFLASH_MODEL_LIST = [
|
||||||
|
"gemini-1.5-flash",
|
||||||
|
"gemini-1.5-flash-latest",
|
||||||
|
"gemini-1.5-flash-001",
|
||||||
|
"gemini-1.5-flash-002",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -50,11 +65,17 @@ class Gemini(Base):
|
|||||||
self.interval = interval
|
self.interval = interval
|
||||||
generation_config["temperature"] = temperature
|
generation_config["temperature"] = temperature
|
||||||
model = genai.GenerativeModel(
|
model = genai.GenerativeModel(
|
||||||
model_name="gemini-pro",
|
model_name=self.model,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
safety_settings=safety_settings,
|
safety_settings=safety_settings,
|
||||||
)
|
)
|
||||||
self.convo = model.start_chat()
|
self.convo = model.start_chat()
|
||||||
|
# print(model) # Uncomment to debug and inspect the model details.
|
||||||
|
|
||||||
|
def rotate_model(self):
|
||||||
|
self.model = next(self.model_list)
|
||||||
|
self.create_convo()
|
||||||
|
print(f"Using model {self.model}")
|
||||||
|
|
||||||
def rotate_key(self):
|
def rotate_key(self):
|
||||||
pass
|
pass
|
||||||
@ -97,3 +118,28 @@ class Gemini(Base):
|
|||||||
if num:
|
if num:
|
||||||
t_text = str(num) + "\n" + t_text
|
t_text = str(num) + "\n" + t_text
|
||||||
return t_text
|
return t_text
|
||||||
|
|
||||||
|
def set_geminipro_models(self):
|
||||||
|
self.set_models(GEMINIPRO_MODEL_LIST)
|
||||||
|
|
||||||
|
def set_geminiflash_models(self):
|
||||||
|
self.set_models(GEMINIFLASH_MODEL_LIST)
|
||||||
|
|
||||||
|
def set_models(self, allowed_models):
|
||||||
|
available_models = [
|
||||||
|
re.sub(r"^models/", "", i.name) for i in genai.list_models()
|
||||||
|
]
|
||||||
|
model_list = sorted(
|
||||||
|
list(set(available_models) & set(allowed_models)),
|
||||||
|
key=allowed_models.index,
|
||||||
|
)
|
||||||
|
print(f"Using model list {model_list}")
|
||||||
|
self.model_list = cycle(model_list)
|
||||||
|
self.rotate_model()
|
||||||
|
|
||||||
|
def set_model_list(self, model_list):
|
||||||
|
# keep the order of input
|
||||||
|
model_list = sorted(list(set(model_list)), key=model_list.index)
|
||||||
|
print(f"Using model list {model_list}")
|
||||||
|
self.model_list = cycle(model_list)
|
||||||
|
self.rotate_model()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user