Gemini Enhancements (#428)

* chore: Bump google-generativeai and related dependencies

* feat: add support for --temperature option to gemini

* feat: add support for --interval option to gemini

* feat: add support for --model_list option to gemini

* feat: add support for --prompt option to gemini

* modify: model settings

* feat: add support for --use_context option to gemini

* feat: add support for rotate_key to gemini

* feat: add exponential backoff to gemini

* Update README.md

* fix: typos and apply black formatting

* Update make_test_ebook.yaml

* fix: cli

* fix: interval option implementation

* fix: interval for geminipro

* fix: recreate convo after rotating key
Authored by risin42 on 2024-10-21 14:42:33 +09:00; committed by GitHub.
parent 6912206cb1
commit 9261d92e20
6 changed files with 171 additions and 52 deletions

File: .github/workflows/make_test_ebook.yaml

@@ -71,7 +71,7 @@ jobs:
       - name: Rename and Upload ePub
         if: env.OPENAI_API_KEY != null
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v3
         with:
           name: epub_output
           path: "test_books/lemo_bilingual.epub"

File: README.md

@@ -30,7 +30,8 @@ Find more info here for using liteLLM: https://github.com/BerriAI/litellm/blob/m
 - If using chatgptapi, you can add `--use_context` to add a context paragraph to each passage sent to the model for translation (see below).
 - Support DeepL model [DeepL Translator](https://rapidapi.com/splintPRO/api/dpl-translator) need pay to get the token use `--model deepl --deepl_key ${deepl_key}`
 - Support DeepL free model `--model deeplfree`
-- Support Google [Gemini](https://makersuite.google.com/app/apikey) model `--model gemini --gemini_key ${gemini_key}`
+- Support Google [Gemini](https://aistudio.google.com/app/apikey) model, use `--model gemini` for Gemini Flash or `--model geminipro` for Gemini Pro. `--gemini_key ${gemini_key}`
+- If you want to use a specific model alias with Gemini (eg `gemini-1.5-flash-002` or `gemini-1.5-flash-8b-exp-0924`), you can use `--model gemini --model_list gemini-1.5-flash-002,gemini-1.5-flash-8b-exp-0924`. `--model_list` takes a comma-separated list of model aliases.
 - Support [Claude](https://console.anthropic.com/docs) model, use `--model claude --claude_key ${claude_key}`
 - Support [Tencent TranSmart](https://transmart.qq.com) model (Free), use `--model tencentransmart`
 - Support [Ollama](https://github.com/ollama/ollama) self-host models, use `--ollama_model ${ollama_model_name}`
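
Note: a minimal sketch of what `--model_list` parsing amounts to, mirroring the `set_model_list` logic added below in gemini_translator.py (the helper name `parse_model_list` is illustrative only, not part of the codebase):

    from itertools import cycle

    def parse_model_list(raw):
        # split the comma-separated aliases, drop duplicates, keep input order
        aliases = raw.split(",")
        ordered = sorted(set(aliases), key=aliases.index)
        return cycle(ordered)  # endless iterator; next() yields the next alias

    models = parse_model_list("gemini-1.5-flash-002,gemini-1.5-flash-8b-exp-0924")
    next(models)  # 'gemini-1.5-flash-002'
    next(models)  # 'gemini-1.5-flash-8b-exp-0924'
    next(models)  # wraps around to 'gemini-1.5-flash-002'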
@@ -57,7 +58,7 @@ Find more info here for using liteLLM: https://github.com/BerriAI/litellm/blob/m
 - `--accumulated_num` Wait for how many tokens have been accumulated before starting the translation. gpt3.5 limits the total_token to 4090. For example, if you use --accumulated_num 1600, maybe openai will
 output 2200 tokens and maybe 200 tokens for other messages in the system messages user messages, 1600+2200+200=4000, So you are close to reaching the limit. You have to choose your own
 value, there is no way to know if the limit is reached before sending
-- `--use_context` prompts the model to create a three-paragraph summary. If it's the beginning of the translation, it will summarize the entire passage sent (the size depending on `--accumulated_num`). For subsequent passages, it will amend the summary to include details from the most recent passage, creating a running one-paragraph context payload of the important details of the entire translated work. This improves consistency of flow and tone throughout the translation. This option is available for all ChatGPT-compatible models.
+- `--use_context` prompts the model to create a three-paragraph summary. If it's the beginning of the translation, it will summarize the entire passage sent (the size depending on `--accumulated_num`). For subsequent passages, it will amend the summary to include details from the most recent passage, creating a running one-paragraph context payload of the important details of the entire translated work. This improves consistency of flow and tone throughout the translation. This option is available for all ChatGPT-compatible models and Gemini models.
 - Use `--context_paragraph_limit` to set a limit on the number of context paragraphs when using the `--use_context` option.
 - Use `--temperature` to set the temperature parameter for `chatgptapi`/`gpt4`/`claude` models. For example: `--temperature 0.7`.
 - Use `--block_size` to merge multiple paragraphs into one block. This may increase accuracy and speed up the process but can disturb the original format. Must be used with `--single_translate`. For example: `--block_size 5`.
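
Note: the running context described for `--use_context` is kept as chat history and trimmed to a bounded window; a minimal sketch of the trimming rule applied in the Gemini diff below (each exchange adds a user/model pair, hence the slice of two):

    def trim_history(history, context_flag, max_len=10):
        # with context on, drop the oldest user/model pair past the window
        if context_flag:
            return history[2:] if len(history) > max_len else history
        return []  # context off: reset between passages

    trim_history(list(range(12)), context_flag=True)   # 10 entries remain
    trim_history(list(range(12)), context_flag=False)  # []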
@@ -82,9 +83,12 @@ python3 make_book.py --book_name test_books/Lex_Fridman_episode_322.srt --openai
 # Or translate the whole book
 python3 make_book.py --book_name test_books/animal_farm.epub --openai_key ${openai_key} --language zh-hans
 
-# Or translate the whole book using Gemini
+# Or translate the whole book using Gemini flash
 python3 make_book.py --book_name test_books/animal_farm.epub --gemini_key ${gemini_key} --model gemini
 
+# Use a specific list of Gemini model aliases
+python3 make_book.py --book_name test_books/animal_farm.epub --gemini_key ${gemini_key} --model gemini --model_list gemini-1.5-flash-002,gemini-1.5-flash-8b-exp-0924
+
 # Set env OPENAI_API_KEY to ignore option --openai_key
 export OPENAI_API_KEY=${your_api_key}

File: book_maker/cli.py

@@ -290,7 +290,7 @@ So you are close to reaching the limit. You have to choose your own value, there
         "--temperature",
         type=float,
         default=1.0,
-        help="temperature parameter for `chatgptapi`/`gpt4`/`claude`",
+        help="temperature parameter for `chatgptapi`/`gpt4`/`claude`/`gemini`",
     )
     parser.add_argument(
         "--block_size",
@@ -316,6 +316,12 @@ So you are close to reaching the limit. You have to choose your own value, there
         action="store_true",
         help="Use pre-generated batch translations to create files. Run with --batch first before using this option",
     )
+    parser.add_argument(
+        "--interval",
+        type=float,
+        default=0.01,
+        help="Request interval in seconds (e.g., 0.1 for 100ms). Currently only supported for Gemini models. Default: 0.01",
+    )
 
     options = parser.parse_args()
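
Note: the interval is expressed in seconds; to derive one from a requests-per-minute quota, divide 60 by the RPM (a hypothetical helper for illustration, not part of the CLI):

    def interval_for_rpm(rpm):
        # seconds to wait between requests to stay under an RPM quota
        return 60.0 / rpm

    interval_for_rpm(15)  # 4.0 -> pass --interval 4
    interval_for_rpm(60)  # 1.0 -> pass --interval 1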
@@ -366,7 +372,7 @@ So you are close to reaching the limit. You have to choose your own value, there
         API_KEY = options.custom_api or env.get("BBM_CUSTOM_API")
         if not API_KEY:
             raise Exception("Please provide custom translate api")
-    elif options.model == "gemini":
+    elif options.model in ["gemini", "geminipro"]:
         API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY")
     elif options.model == "groq":
         API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY")
@@ -481,6 +487,16 @@ So you are close to reaching the limit. You have to choose your own value, there
     if options.batch_use_flag:
         e.batch_use_flag = options.batch_use_flag
 
+    if options.model in ("gemini", "geminipro"):
+        e.translate_model.set_interval(options.interval)
+    if options.model == "gemini":
+        if options.model_list:
+            e.translate_model.set_model_list(options.model_list.split(","))
+        else:
+            e.translate_model.set_geminiflash_models()
+    if options.model == "geminipro":
+        e.translate_model.set_geminipro_models()
+
     e.make_bilingual_book()

File: book_maker/translator/__init__.py

@@ -21,6 +21,7 @@ MODEL_DICT = {
     "deeplfree": DeepLFree,
     "claude": Claude,
     "gemini": Gemini,
+    "geminipro": Gemini,
     "groq": GroqClient,
     "tencentransmart": TencentTranSmart,
     "customapi": CustomAPI,

File: book_maker/translator/gemini_translator.py

@@ -1,5 +1,7 @@
 import re
 import time
+from os import environ
+from itertools import cycle
 
 import google.generativeai as genai
 from google.generativeai.types.generation_types import (
@@ -11,23 +13,36 @@ from rich import print
 from .base_translator import Base
 
 generation_config = {
-    "temperature": 0.7,
+    "temperature": 1.0,
     "top_p": 1,
     "top_k": 1,
-    "max_output_tokens": 2048,
+    "max_output_tokens": 8192,
 }
 
-safety_settings = [
-    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-    {
-        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
-    },
-    {
-        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
-    },
-]
+safety_settings = {
+    "HATE": "BLOCK_NONE",
+    "HARASSMENT": "BLOCK_NONE",
+    "SEXUAL": "BLOCK_NONE",
+    "DANGEROUS": "BLOCK_NONE",
+}
+
+PROMPT_ENV_MAP = {
+    "user": "BBM_GEMINIAPI_USER_MSG_TEMPLATE",
+    "system": "BBM_GEMINIAPI_SYS_MSG",
+}
+
+GEMINIPRO_MODEL_LIST = [
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+]
+
+GEMINIFLASH_MODEL_LIST = [
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-002",
+]
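
Note: these allow-lists are later intersected with whatever the API actually reports as available, preserving the allow-list order; the filtering step in isolation (stand-in values, not live `genai.list_models()` output):

    available = ["gemini-1.5-flash-002", "gemini-1.0-pro", "gemini-1.5-flash"]
    allowed = ["gemini-1.5-flash", "gemini-1.5-flash-latest",
               "gemini-1.5-flash-001", "gemini-1.5-flash-002"]

    usable = sorted(set(available) & set(allowed), key=allowed.index)
    # ['gemini-1.5-flash', 'gemini-1.5-flash-002'] -- allow-list order kept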
@@ -38,20 +53,57 @@ class Gemini(Base):
     DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"
 
-    def __init__(self, key, language, **kwargs) -> None:
-        genai.configure(api_key=key)
+    def __init__(
+        self,
+        key,
+        language,
+        prompt_template=None,
+        prompt_sys_msg=None,
+        context_flag=False,
+        temperature=1.0,
+        **kwargs,
+    ) -> None:
         super().__init__(key, language)
+        self.context_flag = context_flag
+        self.prompt = (
+            prompt_template
+            or environ.get(PROMPT_ENV_MAP["user"])
+            or self.DEFAULT_PROMPT
+        )
+        self.prompt_sys_msg = (
+            prompt_sys_msg
+            or environ.get(PROMPT_ENV_MAP["system"])
+            or None  # Allow None, but not empty string
+        )
+        genai.configure(api_key=next(self.keys))
+        generation_config["temperature"] = temperature
+
+    def create_convo(self):
         model = genai.GenerativeModel(
-            model_name="gemini-pro",
+            model_name=self.model,
             generation_config=generation_config,
             safety_settings=safety_settings,
+            system_instruction=self.prompt_sys_msg,
         )
         self.convo = model.start_chat()
+        # print(model)  # Uncomment to debug and inspect the model details.
+
+    def rotate_model(self):
+        self.model = next(self.model_list)
+        self.create_convo()
+        print(f"Using model {self.model}")
 
     def rotate_key(self):
-        pass
+        genai.configure(api_key=next(self.keys))
+        self.create_convo()
 
     def translate(self, text):
+        delay = 1
+        exponential_base = 2
+        attempt_count = 0
+        max_attempts = 7
+
         t_text = ""
         print(text)
         # same for caiyun translate src issue #279 gemini for #374
@@ -60,32 +112,78 @@ class Gemini(Base):
         if len(text_list) > 1:
             if text_list[0].isdigit():
                 num = text_list[0]
+        while attempt_count < max_attempts:
             try:
                 self.convo.send_message(
-                    self.DEFAULT_PROMPT.format(text=text, language=self.language)
+                    self.prompt.format(text=text, language=self.language)
                 )
-                print(text)
                 t_text = self.convo.last.text.strip()
+                break
             except StopCandidateException as e:
-                match = re.search(r'content\s*{\s*parts\s*{\s*text:\s*"([^"]+)"', str(e))
-                if match:
-                    t_text = match.group(1)
-                    t_text = re.sub(r"\\n", "\n", t_text)
-                else:
-                    t_text = "Can not translate"
+                print(
+                    f"Translation failed due to StopCandidateException: {e} Attempting to switch model..."
+                )
+                self.rotate_model()
             except BlockedPromptException as e:
-                print(str(e))
-                t_text = "Can not translate by SAFETY reason.(因安全问题不能翻译)"
+                print(
+                    f"Translation failed due to BlockedPromptException: {e} Attempting to switch model..."
+                )
+                self.rotate_model()
             except Exception as e:
-                print(str(e))
-                t_text = "Can not translate by other reason.(因安全问题不能翻译)"
+                print(
+                    f"Translation failed due to {type(e).__name__}: {e} Will sleep {delay} seconds"
+                )
+                time.sleep(delay)
+                delay *= exponential_base
+                self.rotate_key()
+                if attempt_count >= 1:
+                    self.rotate_model()
+            attempt_count += 1
+
+        if attempt_count == max_attempts:
+            print(f"Translation failed after {max_attempts} attempts.")
+            return
+
+        if self.context_flag:
             if len(self.convo.history) > 10:
                 self.convo.history = self.convo.history[2:]
+        else:
+            self.convo.history = []
 
         print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
-        # for limit
-        time.sleep(0.5)
+        # for rate limit(RPM)
+        time.sleep(self.interval)
         if num:
             t_text = str(num) + "\n" + t_text
         return t_text
+
+    def set_interval(self, interval):
+        self.interval = interval
+
+    def set_geminipro_models(self):
+        self.set_models(GEMINIPRO_MODEL_LIST)
+
+    def set_geminiflash_models(self):
+        self.set_models(GEMINIFLASH_MODEL_LIST)
+
+    def set_models(self, allowed_models):
+        available_models = [
+            re.sub(r"^models/", "", i.name) for i in genai.list_models()
+        ]
+        model_list = sorted(
+            list(set(available_models) & set(allowed_models)),
+            key=allowed_models.index,
+        )
+        print(f"Using model list {model_list}")
+        self.model_list = cycle(model_list)
+        self.rotate_model()
+
+    def set_model_list(self, model_list):
+        # keep the order of input
+        model_list = sorted(list(set(model_list)), key=model_list.index)
+        print(f"Using model list {model_list}")
+        self.model_list = cycle(model_list)
+        self.rotate_model()
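
Note: the new translate() loop combines exponential backoff with key and model rotation. A self-contained sketch of the same pattern under the same constants (1s delay doubling each retry, up to 7 attempts); `send` and `flaky` are stand-ins for the real convo.send_message call, not part of the codebase:

    import time
    from itertools import cycle

    def call_with_backoff(send, keys, models, max_attempts=7, base=2):
        delay = 1
        key, model = next(keys), next(models)
        for attempt in range(max_attempts):
            try:
                return send(key, model)
            except Exception as exc:  # real code also special-cases safety blocks
                print(f"attempt {attempt + 1} failed ({exc}); sleeping {delay}s")
                time.sleep(delay)
                delay *= base      # 1, 2, 4, 8, ... seconds
                key = next(keys)   # rotate to the next API key
                if attempt >= 1:   # after repeated failures, switch model too
                    model = next(models)
        print(f"giving up after {max_attempts} attempts")

    def flaky(key, model):
        raise RuntimeError("quota exceeded")  # stub that always fails

    call_with_backoff(flaky, cycle(["KEY_A", "KEY_B"]),
                      cycle(["gemini-1.5-flash", "gemini-1.5-pro"]),
                      max_attempts=3)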

File: requirements.txt

@@ -25,13 +25,13 @@ exceptiongroup==1.2.1; python_version < "3.11"
 filelock==3.14.0
 frozenlist==1.4.1
 fsspec==2024.3.1
-google-ai-generativelanguage==0.6.4
-google-api-core==2.19.0
-google-api-python-client==2.127.0
-google-auth==2.29.0
+google-ai-generativelanguage==0.6.10
+google-api-core==2.21.0
+google-api-python-client==2.149.0
+google-auth==2.35.0
 google-auth-httplib2==0.2.0
-google-generativeai==0.5.4
-googleapis-common-protos==1.63.0
+google-generativeai==0.8.3
+googleapis-common-protos==1.65.0
 groq==0.8.0
 grpcio==1.63.0
 grpcio-status==1.62.2