bilingual_book_maker/book_maker/translator/chatgptapi_translator.py


import re
import time
import os
import shutil
from copy import copy
from os import environ
from itertools import cycle
import json
from openai import AzureOpenAI, OpenAI, RateLimitError
from rich import print
from .base_translator import Base
from ..config import config
CHATGPT_CONFIG = config["translator"]["chatgptapi"]
PROMPT_ENV_MAP = {
"user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
"system": "BBM_CHATGPTAPI_SYS_MSG",
}
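# These env vars override the default prompts, e.g. (illustrative values):
#   export BBM_CHATGPTAPI_USER_MSG_TEMPLATE='Translate `{text}` into {language}.'
#   export BBM_CHATGPTAPI_SYS_MSG='You are a professional translator.'
# `{text}` and `{language}` are the placeholders consumed by create_messages().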
GPT35_MODEL_LIST = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0125",
]
GPT4_MODEL_LIST = [
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-32k",
"gpt-4o-2024-05-13",
"gpt-4-0613",
"gpt-4-32k-0613",
]
GPT4oMINI_MODEL_LIST = [
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
]
GPT4o_MODEL_LIST = [
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"chatgpt-4o-latest",
]
O1PREVIEW_MODEL_LIST = [
"o1-preview",
"o1-preview-2024-09-12",
]
O1_MODEL_LIST = [
"o1",
"o1-2024-12-17",
]
O1MINI_MODEL_LIST = [
"o1-mini",
"o1-mini-2024-09-12",
]
O3MINI_MODEL_LIST = [
"o3-mini",
]
class ChatGPTAPI(Base):
DEFAULT_PROMPT = "Please help me to translate `{text}` to {language}. Please return only the translated content, without including the original text"
def __init__(
self,
key,
language,
api_base=None,
prompt_template=None,
prompt_sys_msg=None,
temperature=1.0,
context_flag=False,
context_paragraph_limit=0,
**kwargs,
) -> None:
super().__init__(key, language)
self.key_len = len(key.split(","))
self.openai_client = OpenAI(api_key=next(self.keys), base_url=api_base)
self.api_base = api_base
self.prompt_template = (
prompt_template
or environ.get(PROMPT_ENV_MAP["user"])
or self.DEFAULT_PROMPT
)
self.prompt_sys_msg = (
prompt_sys_msg
or environ.get(
"OPENAI_API_SYS_MSG",
) # XXX: for backward compatibility, deprecate soon
or environ.get(PROMPT_ENV_MAP["system"])
or ""
)
self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
self.deployment_id = None
self.temperature = temperature
self.model_list = None
self.context_flag = context_flag
self.context_list = []
self.context_translated_list = []
if context_paragraph_limit > 0:
# set by user, use the user's value
self.context_paragraph_limit = context_paragraph_limit
else:
# not set by user, fall back to the config default
self.context_paragraph_limit = CHATGPT_CONFIG["context_paragraph_limit"]
self.batch_text_list = []
self.batch_info_cache = None
self.result_content_cache = {}
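# Example (hypothetical key values): multiple comma-separated keys are
# rotated per request by rotate_key():
#   translator = ChatGPTAPI("sk-aaa,sk-bbb", "Simplified Chinese",
#                           temperature=0.7, context_flag=True)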
def rotate_key(self):
self.openai_client.api_key = next(self.keys)
def rotate_model(self):
self.model = next(self.model_list)
def create_messages(self, text, intermediate_messages=None):
content = self.prompt_template.format(
text=text, language=self.language, crlf="\n"
)
sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
messages = [
{"role": "system", "content": sys_content},
]
if intermediate_messages:
messages.extend(intermediate_messages)
messages.append({"role": "user", "content": content})
return messages
def create_context_messages(self):
messages = []
if self.context_flag:
messages.append({"role": "user", "content": "\n".join(self.context_list)})
messages.append(
{
"role": "assistant",
"content": "\n".join(self.context_translated_list),
}
)
return messages
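# Resulting message layout (sketch), with context_flag enabled:
#   [{"role": "system", "content": <sys prompt>},
#    {"role": "user", "content": <previous source paragraphs>},
#    {"role": "assistant", "content": <their translations>},
#    {"role": "user", "content": <current prompt with {text} filled in>}]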
def create_chat_completion(self, text):
messages = self.create_messages(text, self.create_context_messages())
completion = self.openai_client.chat.completions.create(
model=self.model,
messages=messages,
temperature=self.temperature,
)
return completion
def get_translation(self, text):
self.rotate_key()
self.rotate_model()  # rotate through the model list to spread the rate limit
completion = self.create_chat_completion(text)
# TODO: handle completions truncated by the length limit (finish_reason == "length")
# check that content is not None before using it
if completion.choices[0].message.content is not None:
t_text = completion.choices[0].message.content
else:
t_text = ""
if self.context_flag:
self.save_context(text, t_text)
return t_text
def save_context(self, text, t_text):
if self.context_paragraph_limit > 0:
self.context_list.append(text)
self.context_translated_list.append(t_text)
# Remove the oldest context
if len(self.context_list) > self.context_paragraph_limit:
self.context_list.pop(0)
self.context_translated_list.pop(0)
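# Sliding-window behaviour (illustrative), with context_paragraph_limit = 3:
# after translating paragraphs p1..p4, context_list holds [p2, p3, p4] and
# context_translated_list holds their translations; the oldest pair is dropped.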
def translate(self, text, needprint=True):
start_time = time.time()
# todo: Determine whether to print according to the cli option
if needprint:
print(re.sub("\n{3,}", "\n\n", text))
attempt_count = 0
max_attempts = 3
t_text = ""
while attempt_count < max_attempts:
try:
t_text = self.get_translation(text)
break
except RateLimitError as e:
# TODO: better sleep time? why is the sleep always tied to key_len?
# 1. openai server error or own network interruption: sleep for a fixed time
# 2. an api key has no credit or has hit its limit: don't sleep, just switch to another api key
# 3. all api keys have hit their limit: then use the current sleep
# 3. all apikey reach limit, then use current sleep
sleep_time = int(60 / self.key_len)
print(e, f"will sleep {sleep_time} seconds")
time.sleep(sleep_time)
attempt_count += 1
if attempt_count == max_attempts:
print(f"Get {attempt_count} consecutive exceptions")
raise
except Exception as e:
print(str(e))
return
# todo: Determine whether to print according to the cli option
if needprint:
print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
elapsed_time = time.time() - start_time
# print(f"translation time: {elapsed_time:.1f}s")
return t_text
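# Rate-limit sleep arithmetic (from get_translation/rotate_key above): with
# key_len = 3 keys, a RateLimitError sleeps int(60 / 3) = 20 seconds before
# the next attempt, up to max_attempts = 3 attempts in total.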
def translate_and_split_lines(self, text):
result_str = self.translate(text, False)
lines = result_str.splitlines()
lines = [line.strip() for line in lines if line.strip() != ""]
return lines
def get_best_result_list(
self,
plist_len,
new_str,
sleep_dur,
result_list,
max_retries=15,
):
if len(result_list) == plist_len:
return result_list, 0
best_result_list = result_list
retry_count = 0
while retry_count < max_retries and len(result_list) != plist_len:
print(
f"bug: paragraph count changed from {plist_len} to {len(result_list)} after translation",
)
print(f"sleep for {sleep_dur}s and retry {retry_count + 1} ...")
time.sleep(sleep_dur)
retry_count += 1
result_list = self.translate_and_split_lines(new_str)
if (
len(result_list) == plist_len
or len(best_result_list) < len(result_list) <= plist_len
or (
len(result_list) < len(best_result_list)
and len(best_result_list) > plist_len
)
):
best_result_list = result_list
return best_result_list, retry_count
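# Selection rule above (illustrative): for plist_len = 10, a retry returning
# 9 lines beats a previous best of 7 (closer from below), and 11 lines beats
# a previous best of 12 (closer from above); an exact 10 stops the loop.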
def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
if retry_count == 0:
return
print(f"retry {state}")
with open(log_path, "a", encoding="utf-8") as f:
print(
f"retry {state}, count = {retry_count}, time = {elapsed_time:.1f}s",
file=f,
)
def log_translation_mismatch(
self,
plist_len,
result_list,
new_str,
sep,
log_path="log/buglog.txt",
):
if len(result_list) == plist_len:
return
newlist = new_str.split(sep)
with open(log_path, "a", encoding="utf-8") as f:
print(f"problem size: {plist_len - len(result_list)}", file=f)
for i in range(len(newlist)):
print(newlist[i], file=f)
print(file=f)
if i < len(result_list):
print("............................................", file=f)
print(result_list[i], file=f)
print(file=f)
print("=============================", file=f)
print(
f"bug: {plist_len} paragraphs of text translated into {len(result_list)} paragraphs",
)
print("continue")
def join_lines(self, text):
lines = text.splitlines()
new_lines = []
temp_line = []
# join
for line in lines:
if line.strip():
temp_line.append(line.strip())
else:
if temp_line:
new_lines.append(" ".join(temp_line))
temp_line = []
new_lines.append(line)
if temp_line:
new_lines.append(" ".join(temp_line))
text = "\n".join(new_lines)
# try to fix #372
if not text:
return ""
# del ^M
text = text.replace("^M", "\r")
lines = text.splitlines()
filtered_lines = [line for line in lines if line.strip() != "\r"]
new_text = "\n".join(filtered_lines)
return new_text
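# join_lines example: "line one\nline two\n\nnext para" ->
# "line one line two\n\nnext para": soft-wrapped lines are merged and blank
# lines are kept as paragraph breaks; literal "^M" sequences are then
# normalised away (see issue #372).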
def translate_list(self, plist):
sep = "\n\n\n\n\n"
# new_str = sep.join([item.text for item in plist])
new_str = ""
for i, p in enumerate(plist, start=1):
temp_p = copy(p)
# drop footnote markers (<sup>) before extracting the text
for sup in temp_p.find_all("sup"):
sup.extract()
new_str += f"({i}) {temp_p.get_text().strip()}{sep}"
if new_str.endswith(sep):
new_str = new_str[: -len(sep)]
new_str = self.join_lines(new_str)
plist_len = len(plist)
print(f"plist len = {len(plist)}")
result_list = self.translate_and_split_lines(new_str)
start_time = time.time()
result_list, retry_count = self.get_best_result_list(
plist_len,
new_str,
6,  # sleep_dur: arbitrary retry delay in seconds (magic number, untuned)
result_list,
)
end_time = time.time()
state = "fail" if len(result_list) != plist_len else "success"
log_path = "log/buglog.txt"
self.log_retry(state, retry_count, end_time - start_time, log_path)
self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)
# strip leading "(num)" or "num." markers; sometimes "(num)" gets translated into "num."
result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list]
return result_list
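# Payload sketch sent to translate_and_split_lines (sep = five newlines):
#   (1) First paragraph text
#
#   (2) Second paragraph text
# The "(n) " prefixes help the model keep paragraph boundaries; they are
# stripped from the results by the regex above.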
def set_deployment_id(self, deployment_id):
self.deployment_id = deployment_id
self.openai_client = AzureOpenAI(
api_key=next(self.keys),
azure_endpoint=self.api_base,
api_version="2023-07-01-preview",
azure_deployment=self.deployment_id,
)
def set_gpt35_models(self, ollama_model=""):
if ollama_model:
self.model_list = cycle([ollama_model])
return
# rotate across all available gpt-3.5 models to spread the rate limit
if self.deployment_id:
self.model_list = cycle(["gpt-35-turbo"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT35_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_gpt4_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["gpt-4"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT4_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_gpt4omini_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["gpt-4o-mini"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT4oMINI_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_gpt4o_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["gpt-4o"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT4o_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_o1preview_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["o1-preview"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(O1PREVIEW_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_o1_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["o1"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(O1_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_o1mini_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["o1-mini"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(O1MINI_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_o3mini_models(self):
# Azure cannot enumerate models (see issue #375)
if self.deployment_id:
self.model_list = cycle(["o3-mini"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(O3MINI_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_model_list(self, model_list):
model_list = list(set(model_list))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
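# Example (hypothetical model names; must be accepted by your account):
#   translator.set_model_list(["gpt-4o", "gpt-4o-mini"])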
def batch_init(self, book_name):
self.book_name = self.sanitize_book_name(book_name)
def add_to_batch_translate_queue(self, book_index, text):
self.batch_text_list.append({"book_index": book_index, "text": text})
def sanitize_book_name(self, book_name):
# Replace any characters that are not alphanumeric, underscore, hyphen, or dot with an underscore
sanitized_book_name = re.sub(r"[^\w\-_\.]", "_", book_name)
# Remove leading and trailing underscores and dots
sanitized_book_name = sanitized_book_name.strip("._")
return sanitized_book_name
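# sanitize_book_name example: "My Book: Vol. 1" -> "My_Book__Vol._1"
# (spaces and ":" become "_"; leading/trailing dots and underscores are stripped).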
def batch_metadata_file_path(self):
return os.path.join(os.getcwd(), "batch_files", f"{self.book_name}_info.json")
def batch_dir(self):
return os.path.join(os.getcwd(), "batch_files", self.book_name)
def custom_id(self, book_index):
return f"{self.book_name}-{book_index}"
def is_completed_batch(self):
batch_metadata_file_path = self.batch_metadata_file_path()
if not os.path.exists(batch_metadata_file_path):
print("Batch result file does not exist")
raise Exception("Batch result file does not exist")
with open(batch_metadata_file_path, "r", encoding="utf-8") as f:
batch_info = json.load(f)
for batch_file in batch_info["batch_files"]:
batch_status = self.check_batch_status(batch_file["batch_id"])
if batch_status.status != "completed":
return False
return True
def batch_translate(self, book_index):
if self.batch_info_cache is None:
batch_metadata_file_path = self.batch_metadata_file_path()
with open(batch_metadata_file_path, "r", encoding="utf-8") as f:
self.batch_info_cache = json.load(f)
batch_info = self.batch_info_cache
target_batch = None
for batch in batch_info["batch_files"]:
if batch["start_index"] <= book_index < batch["end_index"]:
target_batch = batch
break
if not target_batch:
raise ValueError(f"No batch found for book_index {book_index}")
if target_batch["batch_id"] in self.result_content_cache:
result_content = self.result_content_cache[target_batch["batch_id"]]
else:
batch_status = self.check_batch_status(target_batch["batch_id"])
if batch_status.output_file_id is None:
raise ValueError(f"Batch {target_batch['batch_id']} is not completed")
result_content = self.get_batch_result(batch_status.output_file_id)
self.result_content_cache[target_batch["batch_id"]] = result_content
result_lines = result_content.text.split("\n")
custom_id = self.custom_id(book_index)
for line in result_lines:
if line.strip():
result = json.loads(line)
if result["custom_id"] == custom_id:
return result["response"]["body"]["choices"][0]["message"][
"content"
]
raise ValueError(f"No result found for custom_id {custom_id}")
def create_batch_context_messages(self, index):
messages = []
if self.context_flag:
interval = CHATGPT_CONFIG["batch_context_update_interval"]
if index % interval == 0 or not hasattr(self, "cached_context_messages"):
context_messages = []
for i in range(index - 1, -1, -1):
item = self.batch_text_list[i]
if len(item["text"].split()) >= 100:
context_messages.append(item["text"])
if len(context_messages) == self.context_paragraph_limit:
break
if len(context_messages) == self.context_paragraph_limit:
print("Creating cached context messages")
self.cached_context_messages = [
{"role": "user", "content": "\n".join(context_messages)},
{
"role": "assistant",
"content": self.get_translation(
"\n".join(context_messages)
),
},
]
if hasattr(self, "cached_context_messages"):
messages.extend(self.cached_context_messages)
return messages
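# Cache cadence (illustrative): with batch_context_update_interval = 50, the
# cached context pair is rebuilt at indexes 0, 50, 100, ... and reused for
# the requests in between.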
def make_batch_request(self, book_index, text):
messages = self.create_messages(
text, self.create_batch_context_messages(book_index)
)
return {
"custom_id": self.custom_id(book_index),
"method": "POST",
"url": "/v1/chat/completions",
"body": {
# the model should not be rotated within a single batch
"model": self.batch_model,
"messages": messages,
"temperature": self.temperature,
},
}
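# Each JSONL line produced for the Batch API looks like (sketch):
#   {"custom_id": "<book>-<index>", "method": "POST",
#    "url": "/v1/chat/completions",
#    "body": {"model": ..., "messages": [...], "temperature": ...}}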
def create_batch_files(self, dest_file_path):
file_paths = []
# OpenAI Batch API limits: at most 50,000 requests and 100 MB per input file
lines_per_file = 40000
current_file = 0
for i in range(0, len(self.batch_text_list), lines_per_file):
current_file += 1
file_path = os.path.join(dest_file_path, f"{current_file}.jsonl")
start_index = i
end_index = i + lines_per_file
# TODO: Split the file if it exceeds 100MB
with open(file_path, "w", encoding="utf-8") as f:
for text in self.batch_text_list[i : i + lines_per_file]:
batch_req = self.make_batch_request(
text["book_index"], text["text"]
)
json.dump(batch_req, f, ensure_ascii=False)
f.write("\n")
file_paths.append(
{
"file_path": file_path,
"start_index": start_index,
"end_index": end_index,
}
)
return file_paths
def batch(self):
self.rotate_model()
self.batch_model = self.model
# current working directory
batch_dir = self.batch_dir()
batch_metadata_file_path = self.batch_metadata_file_path()
# cleanup batch dir and result file
if os.path.exists(batch_dir):
shutil.rmtree(batch_dir)
if os.path.exists(batch_metadata_file_path):
os.remove(batch_metadata_file_path)
os.makedirs(batch_dir, exist_ok=True)
# batch execute
batch_files = self.create_batch_files(batch_dir)
batch_info = []
for batch_file in batch_files:
file_id = self.upload_batch_file(batch_file["file_path"])
batch = self.batch_execute(file_id)
batch_info.append(
self.create_batch_info(
file_id, batch, batch_file["start_index"], batch_file["end_index"]
)
)
# save batch info
batch_info_json = {
"book_id": self.book_name,
"batch_date": time.strftime("%Y-%m-%d %H:%M:%S"),
"batch_files": batch_info,
}
with open(batch_metadata_file_path, "w", encoding="utf-8") as f:
json.dump(batch_info_json, f, ensure_ascii=False, indent=2)
def create_batch_info(self, file_id, batch, start_index, end_index):
return {
"input_file_id": file_id,
"batch_id": batch.id,
"start_index": start_index,
"end_index": end_index,
"prefix": self.book_name,
}
def upload_batch_file(self, file_path):
with open(file_path, "rb") as f:
batch_input_file = self.openai_client.files.create(file=f, purpose="batch")
return batch_input_file.id
def batch_execute(self, file_id):
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
res = self.openai_client.batches.create(
input_file_id=file_id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"description": f"Batch job for {self.book_name} at {current_time}"
},
)
if res.errors:
print(res.errors)
raise Exception(f"Batch execution failed: {res.errors}")
return res
def check_batch_status(self, batch_id):
return self.openai_client.batches.retrieve(batch_id)
def get_batch_result(self, output_file_id):
return self.openai_client.files.content(output_file_id)
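# End-to-end batch workflow (sketch; the key and paragraph texts are
# hypothetical):
#   t = ChatGPTAPI("sk-aaa", "Simplified Chinese")
#   t.set_gpt4omini_models()
#   t.batch_init("my_book.epub")
#   for i, text in enumerate(paragraphs):
#       t.add_to_batch_translate_queue(i, text)
#   t.batch()  # upload JSONL files and start the 24h batch jobs
#   # ... later, once is_completed_batch() returns True:
#   translated = t.batch_translate(0)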