bilingual_book_maker/book_maker/translator/chatgptapi_translator.py
yihong0618 eae40a6a7b fix: typo
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2023-12-26 17:53:39 +08:00

305 lines
9.3 KiB
Python

import re
import time
from copy import copy
from os import environ
from itertools import cycle
from openai import AzureOpenAI, OpenAI
from rich import print
from .base_translator import Base
PROMPT_ENV_MAP = {
"user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
"system": "BBM_CHATGPTAPI_SYS_MSG",
}
GPT35_MODEL_LIST = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
GPT4_MODEL_LIST = [
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-4-0314",
]
class ChatGPTAPI(Base):
DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"
def __init__(
self,
key,
language,
api_base=None,
prompt_template=None,
prompt_sys_msg=None,
temperature=1.0,
**kwargs,
) -> None:
super().__init__(key, language)
self.key_len = len(key.split(","))
self.openai_client = OpenAI(api_key=key, base_url=api_base)
self.prompt_template = (
prompt_template
or environ.get(PROMPT_ENV_MAP["user"])
or self.DEFAULT_PROMPT
)
self.prompt_sys_msg = (
prompt_sys_msg
or environ.get(
"OPENAI_API_SYS_MSG",
) # XXX: for backward compatibility, deprecate soon
or environ.get(PROMPT_ENV_MAP["system"])
or ""
)
self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
self.deployment_id = None
self.temperature = temperature
# gpt3 all models for save the limit
self.model_list = cycle(GPT35_MODEL_LIST)
def rotate_key(self):
self.openai_client.api_key = next(self.keys)
def rotate_model(self):
# TODO
self.model = next(self.model_list)
def create_chat_completion(self, text):
content = self.prompt_template.format(
text=text, language=self.language, crlf="\n"
)
sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
messages = [
{"role": "system", "content": sys_content},
{"role": "user", "content": content},
]
completion = self.openai_client.chat.completions.create(
model=self.model,
messages=messages,
temperature=self.temperature,
)
return completion
def get_translation(self, text):
self.rotate_key()
self.rotate_model() # rotate all the model to avoid the limit
try:
completion = self.create_chat_completion(text)
except Exception as e:
print(e)
pass
# TODO work well or exception finish by length limit
t_text = completion.choices[0].message.content.encode("utf8").decode() or ""
return t_text
def translate(self, text, needprint=True):
start_time = time.time()
# todo: Determine whether to print according to the cli option
if needprint:
print(re.sub("\n{3,}", "\n\n", text))
attempt_count = 0
max_attempts = 3
t_text = ""
while attempt_count < max_attempts:
try:
t_text = self.get_translation(text)
break
except Exception as e:
# todo: better sleep time? why sleep alawys about key_len
# 1. openai server error or own network interruption, sleep for a fixed time
# 2. an apikey has no money or reach limit, don`t sleep, just replace it with another apikey
# 3. all apikey reach limit, then use current sleep
sleep_time = int(60 / self.key_len)
print(e, f"will sleep {sleep_time} seconds")
time.sleep(sleep_time)
attempt_count += 1
if attempt_count == max_attempts:
print(f"Get {attempt_count} consecutive exceptions")
raise
# todo: Determine whether to print according to the cli option
if needprint:
print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
time.time() - start_time
# print(f"translation time: {elapsed_time:.1f}s")
return t_text
def translate_and_split_lines(self, text):
result_str = self.translate(text, False)
lines = result_str.splitlines()
lines = [line.strip() for line in lines if line.strip() != ""]
return lines
def get_best_result_list(
self,
plist_len,
new_str,
sleep_dur,
result_list,
max_retries=15,
):
if len(result_list) == plist_len:
return result_list, 0
best_result_list = result_list
retry_count = 0
while retry_count < max_retries and len(result_list) != plist_len:
print(
f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
)
print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...")
time.sleep(sleep_dur)
retry_count += 1
result_list = self.translate_and_split_lines(new_str)
if (
len(result_list) == plist_len
or len(best_result_list) < len(result_list) <= plist_len
or (
len(result_list) < len(best_result_list)
and len(best_result_list) > plist_len
)
):
best_result_list = result_list
return best_result_list, retry_count
def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
if retry_count == 0:
return
print(f"retry {state}")
with open(log_path, "a", encoding="utf-8") as f:
print(
f"retry {state}, count = {retry_count}, time = {elapsed_time:.1f}s",
file=f,
)
def log_translation_mismatch(
self,
plist_len,
result_list,
new_str,
sep,
log_path="log/buglog.txt",
):
if len(result_list) == plist_len:
return
newlist = new_str.split(sep)
with open(log_path, "a", encoding="utf-8") as f:
print(f"problem size: {plist_len - len(result_list)}", file=f)
for i in range(len(newlist)):
print(newlist[i], file=f)
print(file=f)
if i < len(result_list):
print("............................................", file=f)
print(result_list[i], file=f)
print(file=f)
print("=============================", file=f)
print(
f"bug: {plist_len} paragraphs of text translated into {len(result_list)} paragraphs",
)
print("continue")
def join_lines(self, text):
lines = text.splitlines()
new_lines = []
temp_line = []
# join
for line in lines:
if line.strip():
temp_line.append(line.strip())
else:
if temp_line:
new_lines.append(" ".join(temp_line))
temp_line = []
new_lines.append(line)
if temp_line:
new_lines.append(" ".join(temp_line))
text = "\n".join(new_lines)
# del ^M
text = text.replace("^M", "\r")
lines = text.splitlines()
filtered_lines = [line for line in lines if line.strip() != "\r"]
new_text = "\n".join(filtered_lines)
return new_text
def translate_list(self, plist):
sep = "\n\n\n\n\n"
# new_str = sep.join([item.text for item in plist])
new_str = ""
i = 1
for p in plist:
temp_p = copy(p)
for sup in temp_p.find_all("sup"):
sup.extract()
new_str += f"({i}) {temp_p.get_text().strip()}{sep}"
i = i + 1
if new_str.endswith(sep):
new_str = new_str[: -len(sep)]
new_str = self.join_lines(new_str)
plist_len = len(plist)
print(f"plist len = {len(plist)}")
result_list = self.translate_and_split_lines(new_str)
start_time = time.time()
result_list, retry_count = self.get_best_result_list(
plist_len,
new_str,
6, # WTF this magic number here?
result_list,
)
end_time = time.time()
state = "fail" if len(result_list) != plist_len else "success"
log_path = "log/buglog.txt"
self.log_retry(state, retry_count, end_time - start_time, log_path)
self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)
# del (num), num. sometime (num) will translated to num.
result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list]
return result_list
def set_deployment_id(self, deployment_id):
self.deployment_id = deployment_id
self.openai_client = AzureOpenAI(
api_key=next(self.keys),
azure_endpoint=self.api_base,
api_version="2023-07-01-preview",
azure_deployment=self.deployment_id,
)
def set_gpt4_models(self, model="gpt4"):
self.model_list = cycle(GPT4_MODEL_LIST)