support openai batch api (#423)

* feat: support batch api
mkXultra 2024-08-20 17:17:48 +09:00 committed by GitHub
parent f155e2075f
commit 9e4e7b59c7
4 changed files with 298 additions and 7 deletions

book_maker/cli.py

@@ -304,6 +304,18 @@ So you are close to reaching the limit. You have to choose your own value, there
dest="model_list",
help="Rather than using our preset lists of models, specify exactly the models you want as a comma separated list `gpt-4-32k,gpt-3.5-turbo-0125` (Currently only supports: `openai`)",
)
parser.add_argument(
"--batch",
dest="batch_flag",
action="store_true",
help="Enable batch translation using ChatGPT's batch API for improved efficiency",
)
parser.add_argument(
"--batch-use",
dest="batch_use_flag",
action="store_true",
help="Use pre-generated batch translations to create files. Run with --batch first before using this option",
)
options = parser.parse_args()
@@ -461,6 +473,10 @@ So you are close to reaching the limit. You have to choose your own value, there
e.translate_model.set_gpt4omini_models()
if options.block_size > 0:
e.block_size = options.block_size
if options.batch_flag:
e.batch_flag = options.batch_flag
if options.batch_use_flag:
e.batch_use_flag = options.batch_use_flag
e.make_bilingual_book()
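Together the two flags split translation into two phases: a first run with `--batch` queues every paragraph and submits it to the OpenAI Batch API instead of translating inline; once those jobs finish (the completion window is 24h), a second run of the same command with `--batch-use` assembles the bilingual EPUB from the stored results, e.g. `python3 make_book.py --book_name test.epub --openai_key ${OPENAI_API_KEY} --batch`, then the same invocation with `--batch-use`.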

book_maker/config.py (new file, +8)

@@ -0,0 +1,8 @@
config = {
"translator": {
"chatgptapi": {
"context_paragraph_limit": 3,
"batch_context_update_interval": 50,
}
},
}
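The new config module centralizes two tunables used by the translator below: `context_paragraph_limit` caps how many earlier long paragraphs are carried along as translation context, and `batch_context_update_interval` controls how often (every 50 queued requests) the cached context messages are rebuilt by `create_batch_context_messages`.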

book_maker/loader/epub_loader.py

@@ -2,6 +2,7 @@ import os
import pickle
import string
import sys
import time
from copy import copy
from pathlib import Path
@@ -65,6 +66,8 @@ class EPUBBookLoader(BaseBookLoader):
self.only_filelist = ""
self.single_translate = single_translate
self.block_size = -1
self.batch_use_flag = False
self.batch_flag = False
# monkey patch for #173
def _write_items_patch(obj):
@@ -142,11 +145,18 @@ class EPUBBookLoader(BaseBookLoader):
if self.resume and index < p_to_save_len:
p.string = self.p_to_save[index]
else:
t_text = ""
if self.batch_flag:
self.translate_model.add_to_batch_translate_queue(index, new_p.text)
elif self.batch_use_flag:
t_text = self.translate_model.batch_translate(index)
else:
t_text = self.translate_model.translate(new_p.text)
if type(p) == NavigableString:
new_p = self.translate_model.translate(new_p.text)
new_p = t_text
self.p_to_save.append(new_p)
else:
new_p.string = self.translate_model.translate(new_p.text)
new_p.string = t_text
self.p_to_save.append(new_p.text)
self.helper.insert_trans(
@@ -456,6 +466,18 @@ class EPUBBookLoader(BaseBookLoader):
return index
def batch_init_then_wait(self):
name, _ = os.path.splitext(self.epub_name)
if self.batch_flag or self.batch_use_flag:
self.translate_model.batch_init(name)
if self.batch_use_flag:
start_time = time.time()
while not self.translate_model.is_completed_batch():
print("Batch translation is not completed yet")
time.sleep(2)
if time.time() - start_time > 300: # 5 minutes
raise Exception("Batch translation timed out after 5 minutes")
def make_bilingual_book(self):
self.helper = EPUBBookLoaderHelper(
self.translate_model,
@@ -463,6 +485,7 @@ class EPUBBookLoader(BaseBookLoader):
self.translation_style,
self.context_flag,
)
self.batch_init_then_wait()
new_book = self._make_new_book(self.origin_book)
all_items = list(self.origin_book.get_items())
trans_taglist = self.translate_tags.split(",")
@@ -520,7 +543,10 @@
name, _ = os.path.splitext(self.epub_name)
epub.write_epub(f"{name}_bilingual.epub", new_book, {})
name, _ = os.path.splitext(self.epub_name)
epub.write_epub(f"{name}_bilingual.epub", new_book, {})
if self.batch_flag:
self.translate_model.batch()
else:
epub.write_epub(f"{name}_bilingual.epub", new_book, {})
if self.accumulated_num == 1:
pbar.close()
except (KeyboardInterrupt, Exception) as e:
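Note that `batch_init_then_wait` only polls for five minutes, while the Batch API's completion window is 24 hours, so a `--batch-use` run started too early will time out and simply has to be retried later. A minimal sketch of the status check it relies on, using the OpenAI Python SDK (v1+); the batch id here is hypothetical and would normally come from the saved metadata file:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
batch = client.batches.retrieve("batch_abc123")  # hypothetical id
# Statuses include "validating", "in_progress", "finalizing", "completed",
# "failed", "expired", and "cancelled"; is_completed_batch() requires every
# stored batch to report "completed".
print(batch.status, batch.output_file_id)
```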

book_maker/translators/chatgptapi_translator.py

@@ -1,13 +1,19 @@
import re
import time
import os
import shutil
from copy import copy
from os import environ
from itertools import cycle
import json
from openai import AzureOpenAI, OpenAI, RateLimitError
from rich import print
from .base_translator import Base
from ..config import config
CHATGPT_CONFIG = config["translator"]["chatgptapi"]
PROMPT_ENV_MAP = {
"user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
@@ -36,7 +42,6 @@ GPT4oMINI_MODEL_LIST = [
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
]
CONTEXT_PARAGRAPH_LIMIT = 3
class ChatGPTAPI(Base):
@@ -84,7 +89,10 @@ class ChatGPTAPI(Base):
self.context_paragraph_limit = context_paragraph_limit
else:
# set by user, use user's value
self.context_paragraph_limit = CONTEXT_PARAGRAPH_LIMIT
self.context_paragraph_limit = CHATGPT_CONFIG["context_paragraph_limit"]
self.batch_text_list = []
self.batch_info_cache = None
self.result_content_cache = {}
def rotate_key(self):
self.openai_client.api_key = next(self.keys)
@@ -92,14 +100,24 @@
def rotate_model(self):
self.model = next(self.model_list)
def create_chat_completion(self, text):
def create_messages(self, text, intermediate_messages=None):
content = self.prompt_template.format(
text=text, language=self.language, crlf="\n"
)
sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
messages = [
{"role": "system", "content": sys_content},
]
if intermediate_messages:
messages.extend(intermediate_messages)
messages.append({"role": "user", "content": content})
return messages
def create_context_messages(self):
messages = []
if self.context_flag:
messages.append({"role": "user", "content": "\n".join(self.context_list)})
messages.append(
@@ -108,8 +126,10 @@
"content": "\n".join(self.context_translated_list),
}
)
messages.append({"role": "user", "content": content})
return messages
def create_chat_completion(self, text):
messages = self.create_messages(text, self.create_context_messages())
completion = self.openai_client.chat.completions.create(
model=self.model,
messages=messages,
@@ -388,3 +408,224 @@ class ChatGPTAPI(Base):
model_list = list(set(model_list))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def batch_init(self, book_name):
self.book_name = self.sanitize_book_name(book_name)
def add_to_batch_translate_queue(self, book_index, text):
self.batch_text_list.append({"book_index": book_index, "text": text})
def sanitize_book_name(self, book_name):
# Replace any characters that are not alphanumeric, underscore, hyphen, or dot with an underscore
sanitized_book_name = re.sub(r"[^\w\-_\.]", "_", book_name)
# Remove leading and trailing underscores and dots
sanitized_book_name = sanitized_book_name.strip("._")
return sanitized_book_name
def batch_metadata_file_path(self):
return os.path.join(os.getcwd(), "batch_files", f"{self.book_name}_info.json")
def batch_dir(self):
return os.path.join(os.getcwd(), "batch_files", self.book_name)
def custom_id(self, book_index):
return f"{self.book_name}-{book_index}"
def is_completed_batch(self):
batch_metadata_file_path = self.batch_metadata_file_path()
if not os.path.exists(batch_metadata_file_path):
print("Batch result file does not exist")
raise Exception("Batch result file does not exist")
with open(batch_metadata_file_path, "r", encoding="utf-8") as f:
batch_info = json.load(f)
for batch_file in batch_info["batch_files"]:
batch_status = self.check_batch_status(batch_file["batch_id"])
if batch_status.status != "completed":
return False
return True
def batch_translate(self, book_index):
if self.batch_info_cache is None:
batch_metadata_file_path = self.batch_metadata_file_path()
with open(batch_metadata_file_path, "r", encoding="utf-8") as f:
self.batch_info_cache = json.load(f)
batch_info = self.batch_info_cache
target_batch = None
for batch in batch_info["batch_files"]:
if batch["start_index"] <= book_index < batch["end_index"]:
target_batch = batch
break
if not target_batch:
raise ValueError(f"No batch found for book_index {book_index}")
if target_batch["batch_id"] in self.result_content_cache:
result_content = self.result_content_cache[target_batch["batch_id"]]
else:
batch_status = self.check_batch_status(target_batch["batch_id"])
if batch_status.output_file_id is None:
raise ValueError(f"Batch {target_batch['batch_id']} is not completed")
result_content = self.get_batch_result(batch_status.output_file_id)
self.result_content_cache[target_batch["batch_id"]] = result_content
result_lines = result_content.text.split("\n")
custom_id = self.custom_id(book_index)
for line in result_lines:
if line.strip():
result = json.loads(line)
if result["custom_id"] == custom_id:
return result["response"]["body"]["choices"][0]["message"][
"content"
]
raise ValueError(f"No result found for custom_id {custom_id}")
def create_batch_context_messages(self, index):
messages = []
if self.context_flag:
if index % CHATGPT_CONFIG[
"batch_context_update_interval"
] == 0 or not hasattr(self, "cached_context_messages"):
context_messages = []
for i in range(index - 1, -1, -1):
item = self.batch_text_list[i]
if len(item["text"].split()) >= 100:
context_messages.append(item["text"])
if len(context_messages) == self.context_paragraph_limit:
break
if len(context_messages) == self.context_paragraph_limit:
print("Creating cached context messages")
self.cached_context_messages = [
{"role": "user", "content": "\n".join(context_messages)},
{
"role": "assistant",
"content": self.get_translation(
"\n".join(context_messages)
),
},
]
if hasattr(self, "cached_context_messages"):
messages.extend(self.cached_context_messages)
return messages
def make_batch_request(self, book_index, text):
messages = self.create_messages(
text, self.create_batch_context_messages(book_index)
)
return {
"custom_id": self.custom_id(book_index),
"method": "POST",
"url": "/v1/chat/completions",
"body": {
# model should not be rotated
"model": self.batch_model,
"messages": messages,
"temperature": self.temperature,
},
}
def create_batch_files(self, dest_file_path):
file_paths = []
# max 50,000 requests and max file size 100MB per batch file
lines_per_file = 40000
current_file = 0
for i in range(0, len(self.batch_text_list), lines_per_file):
current_file += 1
file_path = os.path.join(dest_file_path, f"{current_file}.jsonl")
start_index = i
end_index = i + lines_per_file
# TODO: Split the file if it exceeds 100MB
with open(file_path, "w", encoding="utf-8") as f:
for text in self.batch_text_list[i : i + lines_per_file]:
batch_req = self.make_batch_request(
text["book_index"], text["text"]
)
json.dump(batch_req, f, ensure_ascii=False)
f.write("\n")
file_paths.append(
{
"file_path": file_path,
"start_index": start_index,
"end_index": end_index,
}
)
return file_paths
def batch(self):
self.rotate_model()
self.batch_model = self.model
# batch files and metadata live under the current working directory
batch_dir = self.batch_dir()
batch_metadata_file_path = self.batch_metadata_file_path()
# cleanup batch dir and result file
if os.path.exists(batch_dir):
shutil.rmtree(batch_dir)
if os.path.exists(batch_metadata_file_path):
os.remove(batch_metadata_file_path)
os.makedirs(batch_dir, exist_ok=True)
# batch execute
batch_files = self.create_batch_files(batch_dir)
batch_info = []
for batch_file in batch_files:
file_id = self.upload_batch_file(batch_file["file_path"])
batch = self.batch_execute(file_id)
batch_info.append(
self.create_batch_info(
file_id, batch, batch_file["start_index"], batch_file["end_index"]
)
)
# save batch info
batch_info_json = {
"book_id": self.book_name,
"batch_date": time.strftime("%Y-%m-%d %H:%M:%S"),
"batch_files": batch_info,
}
with open(batch_metadata_file_path, "w", encoding="utf-8") as f:
json.dump(batch_info_json, f, ensure_ascii=False, indent=2)
def create_batch_info(self, file_id, batch, start_index, end_index):
return {
"input_file_id": file_id,
"batch_id": batch.id,
"start_index": start_index,
"end_index": end_index,
"prefix": self.book_name,
}
def upload_batch_file(self, file_path):
batch_input_file = self.openai_client.files.create(
file=open(file_path, "rb"), purpose="batch"
)
return batch_input_file.id
def batch_execute(self, file_id):
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
res = self.openai_client.batches.create(
input_file_id=file_id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"description": f"Batch job for {self.book_name} at {current_time}"
},
)
if res.errors:
print(res.errors)
raise Exception(f"Batch execution failed: {res.errors}")
return res
def check_batch_status(self, batch_id):
return self.openai_client.batches.retrieve(batch_id)
def get_batch_result(self, output_file_id):
return self.openai_client.files.content(output_file_id)
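For reference, each line `create_batch_files` writes is one self-contained chat-completion request keyed by `custom_id`; a generated `.jsonl` line looks roughly like this (model, index, and message content are illustrative):

```json
{"custom_id": "my_book-42", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "Please translate ..."}], "temperature": 1.0}}
```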
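`batch()` then records enough in `batch_files/<book>_info.json` for a later `--batch-use` run to map a paragraph index back to its batch (values illustrative):

```json
{
  "book_id": "my_book",
  "batch_date": "2024-08-20 17:17:48",
  "batch_files": [
    {
      "input_file_id": "file-abc123",
      "batch_id": "batch_abc123",
      "start_index": 0,
      "end_index": 40000,
      "prefix": "my_book"
    }
  ]
}
```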
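Finally, a minimal sketch of reading a completed batch's output back, mirroring what `batch_translate` does (OpenAI SDK v1+; ids hypothetical):

```python
import json

from openai import OpenAI

client = OpenAI()
batch = client.batches.retrieve("batch_abc123")  # id from <book>_info.json
assert batch.status == "completed" and batch.output_file_id
# files.content() returns the output .jsonl; one JSON object per request
for line in client.files.content(batch.output_file_id).text.splitlines():
    if not line.strip():
        continue
    result = json.loads(line)
    translated = result["response"]["body"]["choices"][0]["message"]["content"]
    print(result["custom_id"], translated[:60])
```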