feat: support batch api

mkXultra 2024-08-16 13:47:45 +09:00
parent f155e2075f
commit 15d80dd177
3 changed files with 232 additions and 3 deletions

View File

@@ -304,6 +304,18 @@ So you are close to reaching the limit. You have to choose your own value, there
         dest="model_list",
         help="Rather than using our preset lists of models, specify exactly the models you want as a comma separated list `gpt-4-32k,gpt-3.5-turbo-0125` (Currently only supports: `openai`)",
     )
+    parser.add_argument(
+        "--batch",
+        dest="batch_flag",
+        action="store_true",
+        help="Enable batch translation using ChatGPT's batch API for improved efficiency",
+    )
+    parser.add_argument(
+        "--batch-use",
+        dest="batch_use_flag",
+        action="store_true",
+        help="Use pre-generated batch translations to create files. Run with --batch first before using this option",
+    )
     options = parser.parse_args()
@@ -461,6 +473,10 @@ So you are close to reaching the limit. You have to choose your own value, there
         e.translate_model.set_gpt4omini_models()
     if options.block_size > 0:
         e.block_size = options.block_size
+    if options.batch_flag:
+        e.batch_flag = options.batch_flag
+    if options.batch_use_flag:
+        e.batch_use_flag = options.batch_use_flag
     e.make_bilingual_book()
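
Taken together, these two flags define a two-pass workflow: a run with --batch queues every paragraph, submits the whole book to the Batch API, and exits without writing the output EPUB; a later run with --batch-use (once the batch has completed, which can take up to the API's 24-hour completion window) assembles the bilingual book from the stored results. An illustrative invocation, assuming the project's usual make_book.py entry point (adjust paths and options to your setup):

python3 make_book.py --book_name test_books/animal_farm.epub --openai_key ${OPENAI_API_KEY} --language zh-hans --batch
# later, once the batch has completed:
python3 make_book.py --book_name test_books/animal_farm.epub --openai_key ${OPENAI_API_KEY} --language zh-hans --batch-use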

View File

@@ -2,6 +2,7 @@ import os
 import pickle
 import string
 import sys
+import time
 from copy import copy
 from pathlib import Path
@@ -65,6 +66,8 @@ class EPUBBookLoader(BaseBookLoader):
         self.only_filelist = ""
         self.single_translate = single_translate
         self.block_size = -1
+        self.batch_use_flag = False
+        self.batch_flag = False

         # monkey patch for # 173
         def _write_items_patch(obj):
@@ -142,11 +145,18 @@ class EPUBBookLoader(BaseBookLoader):
         if self.resume and index < p_to_save_len:
             p.string = self.p_to_save[index]
         else:
+            t_text = ""
+            if self.batch_flag:
+                self.translate_model.add_to_batch_translate_queue(index, new_p.text)
+            elif self.batch_use_flag:
+                t_text = self.translate_model.batch_translate(index)
+            else:
+                t_text = self.translate_model.translate(new_p.text)
             if type(p) == NavigableString:
-                new_p = self.translate_model.translate(new_p.text)
+                new_p = t_text
                 self.p_to_save.append(new_p)
             else:
-                new_p.string = self.translate_model.translate(new_p.text)
+                new_p.string = t_text
                 self.p_to_save.append(new_p.text)

         self.helper.insert_trans(
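
The hunk above turns the per-paragraph translate() call into a three-way dispatch: with batch_flag the paragraph is only queued (t_text stays empty, so this pass produces no usable book), with batch_use_flag the stored batch result is looked up by paragraph index, and otherwise the text is translated synchronously as before. A minimal self-contained sketch of the enqueue-then-replay pattern (class and attribute names here are illustrative, not the project's):

# Sketch of the enqueue-then-replay pattern behind the two flags.
class QueueingTranslator:
    def __init__(self):
        self.queue = []    # filled during the --batch pass
        self.results = {}  # filled from batch output, read on the --batch-use pass

    def add_to_batch_translate_queue(self, index, text):
        self.queue.append({"book_index": index, "text": text})

    def batch_translate(self, index):
        return self.results[index]

t = QueueingTranslator()
t.add_to_batch_translate_queue(0, "Hello")  # pass 1: defer the API call
t.results[0] = "Bonjour"                    # pretend the completed batch came back
print(t.batch_translate(0))                 # pass 2: serve the stored translation

Keying results by paragraph position is what lets the two passes line up: both runs must walk the book in the same order.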
@@ -456,6 +466,18 @@ class EPUBBookLoader(BaseBookLoader):

         return index

+    def batch_init_then_wait(self):
+        name, _ = os.path.splitext(self.epub_name)
+        if self.batch_flag or self.batch_use_flag:
+            self.translate_model.batch_init(name)
+        if self.batch_use_flag:
+            start_time = time.time()
+            while not self.translate_model.is_completed_batch():
+                print("Batch translation is not completed yet")
+                time.sleep(2)
+                if time.time() - start_time > 300:  # 5 minutes
+                    raise Exception("Batch translation timed out after 5 minutes")
+
     def make_bilingual_book(self):
         self.helper = EPUBBookLoaderHelper(
             self.translate_model,
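
batch_init_then_wait polls the translator's is_completed_batch every two seconds and aborts after five minutes. Since the Batch API's completion window is 24 hours (see completion_window="24h" below), the timeout effectively makes --batch-use fail fast when results are not ready yet, rather than wait out a fresh batch. The same loop in isolation, runnable with a stubbed status check (time.monotonic is used here as a minor hardening against clock changes; the commit itself uses time.time):

import itertools
import time

_statuses = itertools.chain([False, False], itertools.repeat(True))

def is_completed_batch():
    # Stub standing in for the translator's real status check:
    # reports "completed" on the third poll.
    return next(_statuses)

start = time.monotonic()
while not is_completed_batch():
    print("Batch translation is not completed yet")
    time.sleep(2)
    if time.monotonic() - start > 300:  # 5 minutes
        raise Exception("Batch translation timed out after 5 minutes")
print("Batch completed")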
@@ -463,6 +485,7 @@ class EPUBBookLoader(BaseBookLoader):
             self.translation_style,
             self.context_flag,
         )
+        self.batch_init_then_wait()
         new_book = self._make_new_book(self.origin_book)
         all_items = list(self.origin_book.get_items())
         trans_taglist = self.translate_tags.split(",")
@@ -520,6 +543,10 @@ class EPUBBookLoader(BaseBookLoader):
             name, _ = os.path.splitext(self.epub_name)
             epub.write_epub(f"{name}_bilingual.epub", new_book, {})
         name, _ = os.path.splitext(self.epub_name)
-        epub.write_epub(f"{name}_bilingual.epub", new_book, {})
+        if self.batch_flag:
+            self.translate_model.batch()
+        else:
+            epub.write_epub(f"{name}_bilingual.epub", new_book, {})
         if self.accumulated_num == 1:
             pbar.close()

View File

@@ -1,8 +1,11 @@
 import re
 import time
+import os
+import shutil
 from copy import copy
 from os import environ
 from itertools import cycle
+import json

 from openai import AzureOpenAI, OpenAI, RateLimitError
 from rich import print
@@ -85,6 +88,8 @@ class ChatGPTAPI(Base):
         else:
             # set by user, use user's value
             self.context_paragraph_limit = CONTEXT_PARAGRAPH_LIMIT
+        self.batch_text_list = []
+        self.batch_info_cache = None

     def rotate_key(self):
         self.openai_client.api_key = next(self.keys)
@@ -92,10 +97,11 @@ class ChatGPTAPI(Base):
     def rotate_model(self):
         self.model = next(self.model_list)

-    def create_chat_completion(self, text):
+    def create_messages(self, text):
         content = self.prompt_template.format(
             text=text, language=self.language, crlf="\n"
         )
         sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
         messages = [
             {"role": "system", "content": sys_content},
@@ -109,7 +115,10 @@
                 }
             )
         messages.append({"role": "user", "content": content})
+        return messages
+
+    def create_chat_completion(self, text):
+        messages = self.create_messages(text)
         completion = self.openai_client.chat.completions.create(
             model=self.model,
             messages=messages,
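
Splitting create_messages out of create_chat_completion lets the same prompt construction feed both the synchronous chat call and the request bodies written to the Batch API input file (see make_batch_request further below). Each line of that JSONL file is one self-describing request; a sketch of a single line, with placeholder prompt and model values:

import json

# Placeholder prompt standing in for prompt_template / prompt_sys_msg output.
messages = [
    {"role": "system", "content": "You are a translator."},
    {"role": "user", "content": "Please translate into French:\n\nHello"},
]

# Shape of one Batch API input line, mirroring make_batch_request below
# ("gpt-4o-mini" and the "my_book" prefix are illustrative).
batch_request = {
    "custom_id": "my_book-0",  # sanitized book name + paragraph index
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {"model": "gpt-4o-mini", "messages": messages, "max_tokens": 1000},
}
print(json.dumps(batch_request, ensure_ascii=False))  # one JSONL line

The custom_id is the only way to correlate an output line with its input, which is why the commit derives it deterministically from the sanitized book name and the paragraph index.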
@@ -388,3 +397,180 @@
         model_list = list(set(model_list))
         print(f"Using model list {model_list}")
         self.model_list = cycle(model_list)
+
+    def add_to_batch_translate_queue(self, book_index, text):
+        self.batch_text_list.append({"book_index": book_index, "text": text})
+
+    def batch_init(self, book_name):
+        self.book_name = self.sanitize_book_name(book_name)
+
+    def sanitize_book_name(self, book_name):
+        # Replace any characters that are not alphanumeric, underscore, hyphen, or dot with an underscore
+        sanitized_book_name = re.sub(r"[^\w\-_\.]", "_", book_name)
+        # Remove leading and trailing underscores and dots
+        sanitized_book_name = sanitized_book_name.strip("._")
+        return sanitized_book_name
+
+    def batch_result_file_path(self):
+        return f"{os.getcwd()}/batch_files/{self.book_name}_info.json"
+
+    def batch_dir(self):
+        return f"{os.getcwd()}/batch_files/{self.book_name}"
+
+    def custom_id(self, book_index):
+        return f"{self.book_name}-{book_index}"
+
+    def is_completed_batch(self):
+        batch_result_file_path = self.batch_result_file_path()
+        if not os.path.exists(batch_result_file_path):
+            print("Batch result file does not exist")
+            raise Exception("Batch result file does not exist")
+        with open(batch_result_file_path, "r", encoding="utf-8") as f:
+            batch_info = json.load(f)
+        for batch_file in batch_info["batch_files"]:
+            batch_status = self.check_batch_status(batch_file["batch_id"])
+            if batch_status.status != "completed":
+                return False
+        return True
+
+    def batch_translate(self, book_index):
+        if self.batch_info_cache is None:
+            batch_result_file_path = self.batch_result_file_path()
+            with open(batch_result_file_path, "r", encoding="utf-8") as f:
+                self.batch_info_cache = json.load(f)
+        batch_info = self.batch_info_cache
+        target_batch = None
+        for batch in batch_info["batch_files"]:
+            if batch["start_index"] <= book_index < batch["end_index"]:
+                target_batch = batch
+                break
+        if not target_batch:
+            raise ValueError(f"No batch found for book_index {book_index}")
+
+        batch_status = self.check_batch_status(target_batch["batch_id"])
+        if batch_status.output_file_id is None:
+            raise ValueError(f"Batch {target_batch['batch_id']} is not completed")
+
+        result_content = self.get_batch_result(batch_status.output_file_id)
+        result_lines = result_content.text.split("\n")
+        custom_id = self.custom_id(book_index)
+        for line in result_lines:
+            if line.strip():
+                result = json.loads(line)
+                if result["custom_id"] == custom_id:
+                    return result["response"]["body"]["choices"][0]["message"][
+                        "content"
+                    ]
+        raise ValueError(f"No result found for custom_id {custom_id}")
+
+    def make_batch_request(self, book_index, text):
+        messages = self.create_messages(text)
+        return {
+            "custom_id": self.custom_id(book_index),
+            "method": "POST",
+            "url": "/v1/chat/completions",
+            "body": {"model": self.model, "messages": messages, "max_tokens": 1000},
+        }
+
+    def create_batch_files(self, dest_file_path):
+        file_paths = []
+        # Batch API limits: at most 50,000 requests and 100 MB per input file
+        lines_per_file = 40000
+        current_file = 0
+        for i in range(0, len(self.batch_text_list), lines_per_file):
+            current_file += 1
+            file_path = f"{dest_file_path}/{current_file}.jsonl"
+            start_index = i
+            end_index = i + lines_per_file
+            # TODO: Split the file if it exceeds 100MB
+            with open(file_path, "w", encoding="utf-8") as f:
+                for text in self.batch_text_list[i : i + lines_per_file]:
+                    batch_req = self.make_batch_request(
+                        text["book_index"], text["text"]
+                    )
+                    json.dump(batch_req, f, ensure_ascii=False)
+                    f.write("\n")
+            file_paths.append(
+                {
+                    "file_path": file_path,
+                    "start_index": start_index,
+                    "end_index": end_index,
+                }
+            )
+        return file_paths
+
+    def upload_batch_file(self, file_path):
+        batch_input_file = self.openai_client.files.create(
+            file=open(file_path, "rb"), purpose="batch"
+        )
+        return batch_input_file.id
+
+    def batch(self):
+        self.rotate_model()
+        batch_dir = self.batch_dir()
+        batch_result_file_path = self.batch_result_file_path()
+        # cleanup batch dir and result file
+        if os.path.exists(batch_dir):
+            shutil.rmtree(batch_dir)
+        if os.path.exists(batch_result_file_path):
+            os.remove(batch_result_file_path)
+        os.makedirs(batch_dir, exist_ok=True)
+        # batch execute
+        batch_files = self.create_batch_files(batch_dir)
+        batch_info = []
+        for batch_file in batch_files:
+            file_id = self.upload_batch_file(batch_file["file_path"])
+            batch = self.batch_execute(file_id)
+            batch_info.append(
+                self.create_batch_info(
+                    file_id, batch, batch_file["start_index"], batch_file["end_index"]
+                )
+            )
+        # save batch info
+        batch_info_json = {
+            "book_id": self.book_name,
+            "batch_date": time.strftime("%Y-%m-%d %H:%M:%S"),
+            "batch_files": batch_info,
+        }
+        with open(batch_result_file_path, "w", encoding="utf-8") as f:
+            json.dump(batch_info_json, f, ensure_ascii=False, indent=2)
+
+    def create_batch_info(self, file_id, batch, start_index, end_index):
+        return {
+            "input_file_id": file_id,
+            "batch_id": batch.id,
+            "start_index": start_index,
+            "end_index": end_index,
+            "prefix": self.book_name,
+        }
+
+    def batch_execute(self, file_id):
+        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
+        res = self.openai_client.batches.create(
+            input_file_id=file_id,
+            endpoint="/v1/chat/completions",
+            completion_window="24h",
+            metadata={
+                "description": f"Batch job for {self.book_name} at {current_time}"
+            },
+        )
+        if res.errors:
+            print(res.errors)
+            raise Exception(f"Batch execution failed: {res.errors}")
+        return res
+
+    def check_batch_status(self, batch_id):
+        return self.openai_client.batches.retrieve(batch_id)
+
+    def get_batch_result(self, output_file_id):
+        return self.openai_client.files.content(output_file_id)
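
End to end, the methods above implement the standard OpenAI Batch API lifecycle: chunk the queued requests into JSONL files (the API allows at most 50,000 requests and 100 MB per input file, hence the conservative 40,000-line chunks), upload each file with purpose="batch", create one batch per file, poll until its status is "completed", then download the output file and match result lines by custom_id. A condensed, runnable sketch of that lifecycle against the official openai v1 Python client (assumes OPENAI_API_KEY is set; the model and file names are illustrative):

import json
import time

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY in the environment

# A few illustrative /v1/chat/completions batch requests.
requests = [
    {
        "custom_id": f"demo-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": f"Translate to French: line {i}"}],
            "max_tokens": 1000,
        },
    }
    for i in range(3)
]

# 1. Write the requests as JSONL, one request per line.
with open("demo.jsonl", "w", encoding="utf-8") as f:
    for req in requests:
        json.dump(req, f, ensure_ascii=False)
        f.write("\n")

# 2. Upload the file and create a batch against it.
input_file = client.files.create(file=open("demo.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# 3. Poll until the batch reaches a terminal status (may take hours).
while (status := client.batches.retrieve(batch.id)).status not in (
    "completed",
    "failed",
    "expired",
    "cancelled",
):
    time.sleep(30)

# 4. Download the output and index the results by custom_id.
if status.status == "completed" and status.output_file_id:
    content = client.files.content(status.output_file_id)
    results = {}
    for line in content.text.split("\n"):
        if not line.strip():
            continue
        obj = json.loads(line)
        results[obj["custom_id"]] = obj["response"]["body"]["choices"][0]["message"]["content"]
    print(results.get("demo-0"))

Output lines are not guaranteed to arrive in input order, so matching on custom_id rather than line position, as batch_translate does above, is the correct approach.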