mirror of
https://github.com/yihong0618/bilingual_book_maker.git
synced 2025-06-02 09:30:24 +00:00
331 lines
12 KiB
Python
import re
|
|
import time
|
|
from copy import copy
|
|
from os import environ, linesep
|
|
from rich import print
|
|
|
|
import openai
|
|
|
|
from .base_translator import Base
|
|
|
|
# Environment variables consulted for prompt overrides, keyed by chat role:
# "user" names the variable holding the user-message template and "system"
# the one holding the system message.
PROMPT_ENV_MAP = dict(
    user="BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
    system="BBM_CHATGPTAPI_SYS_MSG",
)
|
|
|
|
|
|
class GPT4(Base):
    """Book translator backed by OpenAI's ``gpt-4`` chat model or an Azure deployment.

    When ``context_flag`` is enabled, a rolling story summary is sent with every
    request. The model is asked to return an updated summary wrapped in
    ``<summary>...</summary>``. That summary is removed from the translation and
    kept in ``self.context`` for the next call.
    """

    DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"

    # System-message addendum, sent only when context_flag is enabled.
    CONTEXT_SYS_MSG = "For each passage given, you may be provided a summary of the story up until this point (wrapped in tags '<summary>' and '</summary>') for context within the query, to provide background context of the story up until this point. If it's provided, use the context summary to aid you in translation with deeper comprehension, and write a new summary above the returned translation, wrapped in '<summary>' HTML-like tags, including important details (if relevant) from the new passage, retaining the most important key details from the existing summary, and dropping out less important details. If the summary is blank, assume it is the start of the story and write a summary from scratch. Do not make the summary longer than a paragraph, and smaller details can be replaced based on the relative importance of new details. The summary should be formatted in straightforward, inornate text, briefly summarising the entire story (from the start, including information before the given passage, leading up to the given passage) to act as an instructional payload for a Large-Language AI Model to fully understand the context of the passage."

    # Paragraph-number prefixes to drop from translated lines. translate_list()
    # numbers paragraphs "(1) ", and the model may echo that back as "(1)",
    # "1." or full-width "（1）". "(?!\d)" keeps decimals such as "3.14" intact,
    # and bare leading numbers (e.g. a line starting with a year) are kept.
    _INDEX_PREFIX_RE = re.compile(r"^(\(\d+\)|\d+\.(?!\d)|（\d+）)\s*")

    # First rolling-summary block in a model reply (may span several lines).
    _SUMMARY_RE = re.compile(r"<summary>.*?</summary>", re.DOTALL)

    def __init__(
        self,
        key,
        language,
        api_base=None,
        prompt_template=None,
        prompt_sys_msg=None,
        context_flag=False,
        temperature=1.0,
        **kwargs,
    ) -> None:
        """Configure the translator.

        Args:
            key: API key string. A comma-separated list counts as several keys;
                the count also scales the retry back-off in translate().
            language: Target language, substituted into the prompt template.
            api_base: Optional endpoint override. It is written to the global
                ``openai.api_base``.
            prompt_template: User-message template with ``{text}``,
                ``{language}`` and ``{crlf}`` placeholders. Falls back to
                $BBM_CHATGPTAPI_USER_MSG_TEMPLATE, then DEFAULT_PROMPT.
            prompt_sys_msg: System-message template with a ``{crlf}``
                placeholder. Falls back to $OPENAI_API_SYS_MSG, then
                $BBM_CHATGPTAPI_SYS_MSG, then "".
            context_flag: Send and maintain a rolling story summary.
            temperature: Sampling temperature passed to the API.
            **kwargs: Accepted and ignored. Presumably this keeps a uniform
                constructor signature across translators (TODO confirm).
        """
        super().__init__(key, language)
        self.context_flag = context_flag
        self.context = "<summary>The start of the story.</summary>"
        self.key_len = len(key.split(","))

        if api_base:
            openai.api_base = api_base
        self.prompt_template = (
            prompt_template
            or environ.get(PROMPT_ENV_MAP["user"])
            or self.DEFAULT_PROMPT
        )
        self.prompt_sys_msg = (
            prompt_sys_msg
            or environ.get(
                "OPENAI_API_SYS_MSG",
            )  # XXX: for backward compatibility, deprecate soon
            or environ.get(PROMPT_ENV_MAP["system"])
            or ""
        )
        self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
        self.deployment_id = None
        self.temperature = temperature

    def rotate_key(self):
        """Point the global ``openai`` client at the next API key."""
        # self.keys is presumably an iterator set up by Base (TODO confirm).
        openai.api_key = next(self.keys)

    def create_chat_completion(self, text):
        """Send *text* for translation and return the raw ChatCompletion response.

        The user message is the formatted prompt template. When ``context_flag``
        is set, the current story summary is prepended to it. The system message
        is the configured system prompt, followed by CONTEXT_SYS_MSG when
        ``context_flag`` is set. Requests go to the Azure deployment when one is
        configured, otherwise to the ``gpt-4`` model.
        """
        prompt = self.prompt_template.format(
            text=text, language=self.language, crlf=linesep
        )
        content = f"{self.context if self.context_flag else ''} {prompt}"

        # NOTE(review): a non-empty $OPENAI_API_SYS_MSG (self.system_content)
        # overrides an explicit prompt_sys_msg here, unlike the precedence in
        # __init__. Confirm this is intended.
        base_sys = self.system_content or self.prompt_sys_msg.format(crlf=linesep)
        context_sys = self.CONTEXT_SYS_MSG if self.context_flag else ""
        sys_content = f"{base_sys} {context_sys} "

        messages = [
            {"role": "system", "content": sys_content},
            {"role": "user", "content": content},
        ]

        # Azure routes by deployment name ("engine"); OpenAI routes by model name.
        if self.deployment_id:
            return openai.ChatCompletion.create(
                engine=self.deployment_id,
                messages=messages,
                temperature=self.temperature,
            )

        return openai.ChatCompletion.create(
            model="gpt-4",
            messages=messages,
            temperature=self.temperature,
        )

    @staticmethod
    def _append_log(path, message):
        """Append *message* and a newline to the UTF-8 text file at *path*.

        The parent directory is created on demand, so a missing ``log/`` folder
        cannot turn an otherwise successful call into an exception. Plain file
        I/O is used instead of the module's rich ``print``, so bracketed book
        text is written verbatim rather than parsed as console markup or
        re-wrapped.
        """
        from pathlib import Path  # local import: the module does not import pathlib

        log_file = Path(path)
        log_file.parent.mkdir(parents=True, exist_ok=True)
        with log_file.open("a", encoding="utf-8") as f:
            f.write(f"{message}\n")

    def get_translation(self, text):
        """Translate *text* with one API call, after rotating to the next key.

        Returns the reply text, or "" if the reply has no content. A reply cut
        off by the token limit (``finish_reason == "length"``) is still returned,
        and the source text is appended to ``log/long_text.txt`` for review.
        API errors propagate to the caller.
        """
        self.rotate_key()
        completion = self.create_chat_completion(text)

        choice = completion["choices"][0]
        # "content" may be absent or None; treat both as an empty translation.
        message = choice.get("message") or {}
        t_text = message.get("content") or ""

        if choice["finish_reason"] == "length":
            self._append_log(
                "log/long_text.txt",
                f"""==================================================
The total token is too long and cannot be completely translated\n
{text}
""",
            )

        return t_text

    def _pop_summary(self, t_text):
        """Move the first ``<summary>...</summary>`` block of *t_text* into self.context.

        Returns *t_text* with that block removed and surrounding whitespace
        stripped. When there is no block, *t_text* is returned unchanged and the
        previous summary is kept.
        """
        match = self._SUMMARY_RE.search(t_text)
        if match is None:
            return t_text
        self.context = match.group(0)
        return (t_text[: match.start()] + t_text[match.end() :]).strip()

    def translate(self, text, needprint=True):
        """Translate *text*, retrying failed API calls.

        The call is tried up to 3 times. Between failed attempts it sleeps
        ``int(60 / key_len)`` seconds; it does not sleep after the final failure.
        When ``context_flag`` is set, the reply's summary block replaces
        ``self.context`` and is removed from the returned text. Source and
        translation are printed when *needprint* is true.

        Raises:
            The last exception, after 3 consecutive failures.
        """
        # todo: Determine whether to print according to the cli option
        if needprint:
            print(re.sub("\n{3,}", "\n\n", text))

        max_attempts = 3
        t_text = ""
        for attempt in range(1, max_attempts + 1):
            try:
                t_text = self.get_translation(text)
                break
            except Exception as e:
                # todo: better sleep time? why always scale by key_len?
                # 1. openai server error or own network interruption: sleep a fixed time
                # 2. an api key has no credit or hit its limit: don't sleep, just rotate
                # 3. all api keys hit their limit: use the current sleep
                if attempt == max_attempts:
                    print(e)
                    print(f"Get {attempt} consecutive exceptions")
                    raise
                sleep_time = int(60 / self.key_len)
                print(e, f"will sleep {sleep_time} seconds")
                time.sleep(sleep_time)

        # A summary is only requested when context_flag is set, so only then
        # is the reply rewritten.
        if self.context_flag:
            t_text = self._pop_summary(t_text)

        # todo: Determine whether to print according to the cli option
        if needprint:
            print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")

        return t_text

    def translate_and_split_lines(self, text):
        """Translate *text* quietly and return its non-blank lines, stripped."""
        translated = self.translate(text, False)
        return [line.strip() for line in translated.splitlines() if line.strip()]

    def get_best_result_list(
        self,
        plist_len,
        new_str,
        sleep_dur,
        result_list,
        max_retries=15,
    ):
        """Re-translate *new_str* until it yields exactly *plist_len* lines.

        Sleeps *sleep_dur* seconds before each of at most *max_retries* retries.
        A new result replaces the best one so far when it has exactly
        *plist_len* lines, or is longer than the best without exceeding
        *plist_len*, or is shorter than a best that exceeds *plist_len*.

        Returns:
            ``(best_result_list, retry_count)``.
        """
        if len(result_list) == plist_len:
            return result_list, 0

        best_result_list = result_list
        retry_count = 0

        while retry_count < max_retries and len(result_list) != plist_len:
            print(
                f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
            )
            print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...")
            time.sleep(sleep_dur)
            retry_count += 1
            result_list = self.translate_and_split_lines(new_str)

            best_len = len(best_result_list)
            new_len = len(result_list)
            if (
                new_len == plist_len
                or best_len < new_len <= plist_len
                or (new_len < best_len and best_len > plist_len)
            ):
                best_result_list = result_list

        return best_result_list, retry_count

    def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
        """Report the outcome of paragraph-count retries. Does nothing if none ran."""
        if retry_count == 0:
            return
        print(f"retry {state}")
        self._append_log(
            log_path,
            f"retry {state}, count = {retry_count}, time = {elapsed_time:.1f}s",
        )

    def log_translation_mismatch(
        self,
        plist_len,
        result_list,
        new_str,
        sep,
        log_path="log/buglog.txt",
    ):
        """Log source/translation pairs side by side when the line counts differ.

        Does nothing when *result_list* has exactly *plist_len* entries.
        Otherwise each *sep*-separated source paragraph of *new_str* is written
        to *log_path*, followed by the translated line at the same index if
        there is one.
        """
        if len(result_list) == plist_len:
            return

        report = [f"problem size: {plist_len - len(result_list)}"]
        for i, source_par in enumerate(new_str.split(sep)):
            report += [source_par, ""]
            if i < len(result_list):
                report += [
                    "............................................",
                    result_list[i],
                    "",
                ]
            report.append("=============================")
        self._append_log(log_path, "\n".join(report))

        print(
            f"bug: {plist_len} paragraphs of text translated into {len(result_list)} paragraphs",
        )
        print("continue")

    def join_lines(self, text):
        """Merge each run of non-blank lines into one space-joined line.

        Blank lines are kept as they are, so paragraph separators survive.
        Afterwards, every literal two-character "^M" sequence becomes a line
        break.
        """
        merged = []
        run = []
        for line in text.splitlines():
            if line.strip():
                run.append(line.strip())
                continue
            if run:
                merged.append(" ".join(run))
                run = []
            merged.append(line)
        if run:
            merged.append(" ".join(run))

        # NOTE(review): the intent ("del ^M") was presumably to delete carriage
        # returns. In practice a literal "^M" becomes "\r", and splitlines()
        # treats that as a line boundary. Real "\r" characters were already
        # consumed by the first splitlines() above.
        text = "\n".join(merged).replace("^M", "\r")
        return "\n".join(text.splitlines())

    def translate_list(self, plist, context_flag):
        """Translate several paragraphs in one request and return one string each.

        Args:
            plist: Paragraph elements exposing ``find_all``/``get_text``
                (presumably BeautifulSoup tags; TODO confirm). ``<sup>`` children
                are removed from a ``copy()`` of each element before extracting
                its text.
            context_flag: Unused. Summary handling follows ``self.context_flag``.

        Each paragraph is numbered "(n) " and separated by blank lines. If the
        reply has the wrong line count, it is retried via get_best_result_list(),
        and any remaining mismatch is logged to log/buglog.txt. The returned list
        may therefore still differ in length from *plist*.
        """
        if not plist:
            # Nothing to translate; avoid an API call and the retry/sleep loop.
            return []

        sep = "\n\n\n\n\n"
        numbered = []
        for i, p in enumerate(plist, 1):
            temp_p = copy(p)
            for sup in temp_p.find_all("sup"):
                sup.extract()
            numbered.append(f"({i}) {temp_p.get_text().strip()}")
        new_str = self.join_lines(sep.join(numbered))

        plist_len = len(plist)
        print(f"plist len = {len(plist)}")

        result_list = self.translate_and_split_lines(new_str)

        start_time = time.time()
        result_list, retry_count = self.get_best_result_list(
            plist_len,
            new_str,
            6,
            result_list,
        )
        end_time = time.time()

        state = "fail" if len(result_list) != plist_len else "success"
        log_path = "log/buglog.txt"
        self.log_retry(state, retry_count, end_time - start_time, log_path)
        self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)

        # Drop the "(n)" numbering, which the model sometimes rewrites as "n."
        # or "（n）".
        return [self._INDEX_PREFIX_RE.sub("", s) for s in result_list]

    def set_deployment_id(self, deployment_id):
        """Switch the global openai client to Azure and route requests to *deployment_id*."""
        openai.api_type = "azure"
        openai.api_version = "2023-03-15-preview"
        self.deployment_id = deployment_id
|