bilingual_book_maker/book_maker/translator/chatgptapi_translator.py

302 lines
9.4 KiB
Python

import re
import time
from copy import copy
from os import environ
from rich import print
import openai
from .base_translator import Base
PROMPT_ENV_MAP = {
"user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
"system": "BBM_CHATGPTAPI_SYS_MSG",
}
class ChatGPTAPI(Base):
DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"
def __init__(
self,
key,
language,
api_base=None,
prompt_template=None,
prompt_sys_msg=None,
temperature=1.0,
**kwargs,
) -> None:
super().__init__(key, language)
self.key_len = len(key.split(","))
if api_base:
openai.api_base = api_base
self.prompt_template = (
prompt_template
or environ.get(PROMPT_ENV_MAP["user"])
or self.DEFAULT_PROMPT
)
self.prompt_sys_msg = (
prompt_sys_msg
or environ.get(
"OPENAI_API_SYS_MSG",
) # XXX: for backward compatibility, deprecate soon
or environ.get(PROMPT_ENV_MAP["system"])
or ""
)
self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
self.deployment_id = None
self.temperature = temperature
def rotate_key(self):
openai.api_key = next(self.keys)
def create_chat_completion(self, text):
content = self.prompt_template.format(
text=text, language=self.language, crlf="\n"
)
sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
messages = [
{"role": "system", "content": sys_content},
{"role": "user", "content": content},
]
if self.deployment_id:
return openai.ChatCompletion.create(
engine=self.deployment_id,
messages=messages,
temperature=self.temperature,
)
return openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=self.temperature,
)
def get_translation(self, text):
self.rotate_key()
completion = {}
try:
completion = self.create_chat_completion(text)
except Exception:
if (
"choices" not in completion
or not isinstance(completion["choices"], list)
or len(completion["choices"]) == 0
):
raise
if completion["choices"][0]["finish_reason"] != "length":
raise
# work well or exception finish by length limit
choice = completion["choices"][0]
t_text = choice.get("message").get("content", "").encode("utf8").decode()
if choice["finish_reason"] == "length":
with open("log/long_text.txt", "a") as f:
print(
f"""==================================================
The total token is too long and cannot be completely translated\n
{text}
""",
file=f,
)
return t_text
def translate(self, text, needprint=True):
# print("=================================================")
start_time = time.time()
# todo: Determine whether to print according to the cli option
if needprint:
print(re.sub("\n{3,}", "\n\n", text))
attempt_count = 0
max_attempts = 3
t_text = ""
while attempt_count < max_attempts:
try:
t_text = self.get_translation(text)
break
except Exception as e:
# todo: better sleep time? why sleep alawys about key_len
# 1. openai server error or own network interruption, sleep for a fixed time
# 2. an apikey has no money or reach limit, don`t sleep, just replace it with another apikey
# 3. all apikey reach limit, then use current sleep
sleep_time = int(60 / self.key_len)
print(e, f"will sleep {sleep_time} seconds")
time.sleep(sleep_time)
attempt_count += 1
if attempt_count == max_attempts:
print(f"Get {attempt_count} consecutive exceptions")
raise
# todo: Determine whether to print according to the cli option
if needprint:
print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
time.time() - start_time
# print(f"translation time: {elapsed_time:.1f}s")
return t_text
def translate_and_split_lines(self, text):
result_str = self.translate(text, False)
lines = result_str.splitlines()
lines = [line.strip() for line in lines if line.strip() != ""]
return lines
def get_best_result_list(
self,
plist_len,
new_str,
sleep_dur,
result_list,
max_retries=15,
):
if len(result_list) == plist_len:
return result_list, 0
best_result_list = result_list
retry_count = 0
while retry_count < max_retries and len(result_list) != plist_len:
print(
f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
)
print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...")
time.sleep(sleep_dur)
retry_count += 1
result_list = self.translate_and_split_lines(new_str)
if (
len(result_list) == plist_len
or len(best_result_list) < len(result_list) <= plist_len
or (
len(result_list) < len(best_result_list)
and len(best_result_list) > plist_len
)
):
best_result_list = result_list
return best_result_list, retry_count
def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
if retry_count == 0:
return
print(f"retry {state}")
with open(log_path, "a", encoding="utf-8") as f:
print(
f"retry {state}, count = {retry_count}, time = {elapsed_time:.1f}s",
file=f,
)
def log_translation_mismatch(
self,
plist_len,
result_list,
new_str,
sep,
log_path="log/buglog.txt",
):
if len(result_list) == plist_len:
return
newlist = new_str.split(sep)
with open(log_path, "a", encoding="utf-8") as f:
print(f"problem size: {plist_len - len(result_list)}", file=f)
for i in range(len(newlist)):
print(newlist[i], file=f)
print(file=f)
if i < len(result_list):
print("............................................", file=f)
print(result_list[i], file=f)
print(file=f)
print("=============================", file=f)
print(
f"bug: {plist_len} paragraphs of text translated into {len(result_list)} paragraphs",
)
print("continue")
def join_lines(self, text):
lines = text.splitlines()
new_lines = []
temp_line = []
# join
for line in lines:
if line.strip():
temp_line.append(line.strip())
else:
if temp_line:
new_lines.append(" ".join(temp_line))
temp_line = []
new_lines.append(line)
if temp_line:
new_lines.append(" ".join(temp_line))
text = "\n".join(new_lines)
# del ^M
text = text.replace("^M", "\r")
lines = text.splitlines()
filtered_lines = [line for line in lines if line.strip() != "\r"]
new_text = "\n".join(filtered_lines)
return new_text
def translate_list(self, plist):
sep = "\n\n\n\n\n"
# new_str = sep.join([item.text for item in plist])
new_str = ""
i = 1
for p in plist:
temp_p = copy(p)
for sup in temp_p.find_all("sup"):
sup.extract()
new_str += f"({i}) {temp_p.get_text().strip()}{sep}"
i = i + 1
if new_str.endswith(sep):
new_str = new_str[: -len(sep)]
new_str = self.join_lines(new_str)
plist_len = len(plist)
print(f"plist len = {len(plist)}")
result_list = self.translate_and_split_lines(new_str)
start_time = time.time()
result_list, retry_count = self.get_best_result_list(
plist_len,
new_str,
6,
result_list,
)
end_time = time.time()
state = "fail" if len(result_list) != plist_len else "success"
log_path = "log/buglog.txt"
self.log_retry(state, retry_count, end_time - start_time, log_path)
self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)
# del (num), num. sometime (num) will translated to num.
result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list]
return result_list
def set_deployment_id(self, deployment_id):
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
self.deployment_id = deployment_id