feat: add support for --model_list option to gemini

This commit is contained in:
Risin 2024-10-15 18:44:33 +09:00
parent 5ad87bca4f
commit 0e13949b19
3 changed files with 58 additions and 2 deletions

View File

@ -372,7 +372,7 @@ So you are close to reaching the limit. You have to choose your own value, there
API_KEY = options.custom_api or env.get("BBM_CUSTOM_API") API_KEY = options.custom_api or env.get("BBM_CUSTOM_API")
if not API_KEY: if not API_KEY:
raise Exception("Please provide custom translate api") raise Exception("Please provide custom translate api")
elif options.model == "gemini": elif options.model in ["gemini", "geminipro"]:
API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY") API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY")
elif options.model == "groq": elif options.model == "groq":
API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY") API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY")
@ -488,6 +488,15 @@ So you are close to reaching the limit. You have to choose your own value, there
if options.batch_use_flag: if options.batch_use_flag:
e.batch_use_flag = options.batch_use_flag e.batch_use_flag = options.batch_use_flag
if options.model == "gemini":
if options.model_list:
e.translate_model.set_model_list(options.model_list.split(","))
else:
e.translate_model.set_geminiflash_models()
if options.model == "geminipro":
e.translate_model.set_geminipro_models()
e.make_bilingual_book() e.make_bilingual_book()

View File

@ -21,6 +21,7 @@ MODEL_DICT = {
"deeplfree": DeepLFree, "deeplfree": DeepLFree,
"claude": Claude, "claude": Claude,
"gemini": Gemini, "gemini": Gemini,
"geminipro": Gemini,
"groq": GroqClient, "groq": GroqClient,
"tencentransmart": TencentTranSmart, "tencentransmart": TencentTranSmart,
"customapi": CustomAPI, "customapi": CustomAPI,

View File

@ -1,5 +1,7 @@
import re import re
import time import time
from os import environ
from itertools import cycle
import google.generativeai as genai import google.generativeai as genai
from google.generativeai.types.generation_types import ( from google.generativeai.types.generation_types import (
@ -28,6 +30,19 @@ safety_settings = [
"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_MEDIUM_AND_ABOVE", "threshold": "BLOCK_MEDIUM_AND_ABOVE",
}, },
GEMINIPRO_MODEL_LIST = [
"gemini-1.5-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-pro-001",
"gemini-1.5-pro-002",
]
GEMINIFLASH_MODEL_LIST = [
"gemini-1.5-flash",
"gemini-1.5-flash-latest",
"gemini-1.5-flash-001",
"gemini-1.5-flash-002",
] ]
@ -50,11 +65,17 @@ class Gemini(Base):
self.interval = interval self.interval = interval
generation_config["temperature"] = temperature generation_config["temperature"] = temperature
model = genai.GenerativeModel( model = genai.GenerativeModel(
model_name="gemini-pro", model_name=self.model,
generation_config=generation_config, generation_config=generation_config,
safety_settings=safety_settings, safety_settings=safety_settings,
) )
self.convo = model.start_chat() self.convo = model.start_chat()
# print(model) # Uncomment to debug and inspect the model details.
def rotate_model(self):
self.model = next(self.model_list)
self.create_convo()
print(f"Using model {self.model}")
def rotate_key(self): def rotate_key(self):
pass pass
@ -97,3 +118,28 @@ class Gemini(Base):
if num: if num:
t_text = str(num) + "\n" + t_text t_text = str(num) + "\n" + t_text
return t_text return t_text
def set_geminipro_models(self):
self.set_models(GEMINIPRO_MODEL_LIST)
def set_geminiflash_models(self):
self.set_models(GEMINIFLASH_MODEL_LIST)
def set_models(self, allowed_models):
available_models = [
re.sub(r"^models/", "", i.name) for i in genai.list_models()
]
model_list = sorted(
list(set(available_models) & set(allowed_models)),
key=allowed_models.index,
)
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
self.rotate_model()
def set_model_list(self, model_list):
# keep the order of input
model_list = sorted(list(set(model_list)), key=model_list.index)
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
self.rotate_model()