diff --git a/book_maker/cli.py b/book_maker/cli.py
index b02930f..3e87360 100644
--- a/book_maker/cli.py
+++ b/book_maker/cli.py
@@ -304,6 +304,18 @@ So you are close to reaching the limit. You have to choose your own value, there
         dest="model_list",
         help="Rather than using our preset lists of models, specify exactly the models you want as a comma separated list `gpt-4-32k,gpt-3.5-turbo-0125` (Currently only supports: `openai`)",
     )
+    parser.add_argument(
+        "--batch",
+        dest="batch_flag",
+        action="store_true",
+        help="Enable batch translation using ChatGPT's batch API for improved efficiency",
+    )
+    parser.add_argument(
+        "--batch-use",
+        dest="batch_use_flag",
+        action="store_true",
+        help="Use pre-generated batch translations to create files. Run with --batch first before using this option",
+    )
 
     options = parser.parse_args()
 
@@ -461,6 +473,10 @@ So you are close to reaching the limit. You have to choose your own value, there
         e.translate_model.set_gpt4omini_models()
     if options.block_size > 0:
         e.block_size = options.block_size
+    if options.batch_flag:
+        e.batch_flag = options.batch_flag
+    if options.batch_use_flag:
+        e.batch_use_flag = options.batch_use_flag
 
     e.make_bilingual_book()
 
diff --git a/book_maker/config.py b/book_maker/config.py
new file mode 100644
index 0000000..a186a4a
--- /dev/null
+++ b/book_maker/config.py
@@ -0,0 +1,8 @@
+config = {
+    "translator": {
+        "chatgptapi": {
+            "context_paragraph_limit": 3,
+            "batch_context_update_interval": 50,
+        }
+    },
+}
diff --git a/book_maker/loader/epub_loader.py b/book_maker/loader/epub_loader.py
index f26d783..c2b1437 100644
--- a/book_maker/loader/epub_loader.py
+++ b/book_maker/loader/epub_loader.py
@@ -2,6 +2,7 @@ import os
 import pickle
 import string
 import sys
+import time
 from copy import copy
 from pathlib import Path
 
@@ -65,6 +66,8 @@ class EPUBBookLoader(BaseBookLoader):
         self.only_filelist = ""
         self.single_translate = single_translate
         self.block_size = -1
+        self.batch_use_flag = False
+        self.batch_flag = False
 
         # monkey patch for #173
         def _write_items_patch(obj):
@@ -142,11 +145,18 @@ class EPUBBookLoader(BaseBookLoader):
         if self.resume and index < p_to_save_len:
             p.string = self.p_to_save[index]
         else:
+            t_text = ""
+            if self.batch_flag:
+                # --batch only queues the text; the real result is fetched later by --batch-use
+                self.translate_model.add_to_batch_translate_queue(index, new_p.text)
+            elif self.batch_use_flag:
+                t_text = self.translate_model.batch_translate(index)
+            else:
+                t_text = self.translate_model.translate(new_p.text)
             if type(p) == NavigableString:
-                new_p = self.translate_model.translate(new_p.text)
+                new_p = t_text
                 self.p_to_save.append(new_p)
             else:
-                new_p.string = self.translate_model.translate(new_p.text)
+                new_p.string = t_text
                 self.p_to_save.append(new_p.text)
 
         self.helper.insert_trans(
@@ -456,6 +466,18 @@ class EPUBBookLoader(BaseBookLoader):
 
         return index
 
+    def batch_init_then_wait(self):
+        name, _ = os.path.splitext(self.epub_name)
+        if self.batch_flag or self.batch_use_flag:
+            self.translate_model.batch_init(name)
+            if self.batch_use_flag:
+                start_time = time.time()
+                while not self.translate_model.is_completed_batch():
+                    print("Batch translation is not completed yet")
+                    time.sleep(2)
+                    if time.time() - start_time > 300:  # 5 minutes
+                        raise Exception("Batch translation timed out after 5 minutes")
+
     def make_bilingual_book(self):
         self.helper = EPUBBookLoaderHelper(
             self.translate_model,
@@ -463,6 +485,7 @@ class EPUBBookLoader(BaseBookLoader):
             self.translation_style,
             self.context_flag,
         )
+        self.batch_init_then_wait()
         new_book = self._make_new_book(self.origin_book)
         all_items = list(self.origin_book.get_items())
         trans_taglist = self.translate_tags.split(",")
@@ -520,7 +543,10 @@ class EPUBBookLoader(BaseBookLoader):
                         name, _ = os.path.splitext(self.epub_name)
                         epub.write_epub(f"{name}_bilingual.epub", new_book, {})
             name, _ = os.path.splitext(self.epub_name)
-            epub.write_epub(f"{name}_bilingual.epub", new_book, {})
+            if self.batch_flag:
+                self.translate_model.batch()
+            else:
+                epub.write_epub(f"{name}_bilingual.epub", new_book, {})
             if self.accumulated_num == 1:
                 pbar.close()
         except (KeyboardInterrupt, Exception) as e:
diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py
index 94ecfae..c97121d 100644
--- a/book_maker/translator/chatgptapi_translator.py
+++ b/book_maker/translator/chatgptapi_translator.py
@@ -1,13 +1,19 @@
 import re
 import time
+import os
+import shutil
 from copy import copy
 from os import environ
 from itertools import cycle
+import json
 
 from openai import AzureOpenAI, OpenAI, RateLimitError
 from rich import print
 
 from .base_translator import Base
+from ..config import config
+
+CHATGPT_CONFIG = config["translator"]["chatgptapi"]
 
 PROMPT_ENV_MAP = {
     "user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
@@ -36,7 +42,6 @@ GPT4oMINI_MODEL_LIST = [
     "gpt-4o-mini",
     "gpt-4o-mini-2024-07-18",
 ]
-CONTEXT_PARAGRAPH_LIMIT = 3
 
 
 class ChatGPTAPI(Base):
@@ -84,7 +89,10 @@ class ChatGPTAPI(Base):
             self.context_paragraph_limit = context_paragraph_limit
         else:
             # set by user, use user's value
-            self.context_paragraph_limit = CONTEXT_PARAGRAPH_LIMIT
+            self.context_paragraph_limit = CHATGPT_CONFIG["context_paragraph_limit"]
+        self.batch_text_list = []
+        self.batch_info_cache = None
+        self.result_content_cache = {}
 
     def rotate_key(self):
         self.openai_client.api_key = next(self.keys)
@@ -92,14 +100,24 @@
     def rotate_model(self):
         self.model = next(self.model_list)
 
-    def create_chat_completion(self, text):
+    def create_messages(self, text, intermediate_messages=None):
         content = self.prompt_template.format(
             text=text, language=self.language, crlf="\n"
         )
+
         sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
         messages = [
             {"role": "system", "content": sys_content},
         ]
+
+        if intermediate_messages:
+            messages.extend(intermediate_messages)
+
+        messages.append({"role": "user", "content": content})
+        return messages
+
+    def create_context_messages(self):
+        messages = []
         if self.context_flag:
             messages.append({"role": "user", "content": "\n".join(self.context_list)})
             messages.append(
@@ -108,8 +126,10 @@
                 {
                     "role": "assistant",
                     "content": "\n".join(self.context_translated_list),
                 }
             )
-        messages.append({"role": "user", "content": content})
+        return messages
 
+    def create_chat_completion(self, text):
+        messages = self.create_messages(text, self.create_context_messages())
         completion = self.openai_client.chat.completions.create(
             model=self.model,
             messages=messages,
@@ -388,3 +408,224 @@
         model_list = list(set(model_list))
         print(f"Using model list {model_list}")
         self.model_list = cycle(model_list)
+
+    def batch_init(self, book_name):
+        self.book_name = self.sanitize_book_name(book_name)
+
+    def add_to_batch_translate_queue(self, book_index, text):
+        self.batch_text_list.append({"book_index": book_index, "text": text})
+
+    def sanitize_book_name(self, book_name):
+        # Replace any characters that are not alphanumeric, underscore, hyphen, or dot with an underscore
+        sanitized_book_name = re.sub(r"[^\w\-_\.]", "_", book_name)
+        # Remove leading and trailing underscores and dots
sanitized_book_name = sanitized_book_name.strip("._") + return sanitized_book_name + + def batch_metadata_file_path(self): + return os.path.join(os.getcwd(), "batch_files", f"{self.book_name}_info.json") + + def batch_dir(self): + return os.path.join(os.getcwd(), "batch_files", self.book_name) + + def custom_id(self, book_index): + return f"{self.book_name}-{book_index}" + + def is_completed_batch(self): + batch_metadata_file_path = self.batch_metadata_file_path() + + if not os.path.exists(batch_metadata_file_path): + print("Batch result file does not exist") + raise Exception("Batch result file does not exist") + + with open(batch_metadata_file_path, "r", encoding="utf-8") as f: + batch_info = json.load(f) + + for batch_file in batch_info["batch_files"]: + batch_status = self.check_batch_status(batch_file["batch_id"]) + if batch_status.status != "completed": + return False + + return True + + def batch_translate(self, book_index): + if self.batch_info_cache is None: + batch_metadata_file_path = self.batch_metadata_file_path() + with open(batch_metadata_file_path, "r", encoding="utf-8") as f: + self.batch_info_cache = json.load(f) + + batch_info = self.batch_info_cache + target_batch = None + for batch in batch_info["batch_files"]: + if batch["start_index"] <= book_index < batch["end_index"]: + target_batch = batch + break + + if not target_batch: + raise ValueError(f"No batch found for book_index {book_index}") + + if target_batch["batch_id"] in self.result_content_cache: + result_content = self.result_content_cache[target_batch["batch_id"]] + else: + batch_status = self.check_batch_status(target_batch["batch_id"]) + if batch_status.output_file_id is None: + raise ValueError(f"Batch {target_batch['batch_id']} is not completed") + result_content = self.get_batch_result(batch_status.output_file_id) + self.result_content_cache[target_batch["batch_id"]] = result_content + + result_lines = result_content.text.split("\n") + custom_id = self.custom_id(book_index) + for line in result_lines: + if line.strip(): + result = json.loads(line) + if result["custom_id"] == custom_id: + return result["response"]["body"]["choices"][0]["message"][ + "content" + ] + + raise ValueError(f"No result found for custom_id {custom_id}") + + def create_batch_context_messages(self, index): + messages = [] + if self.context_flag: + if index % CHATGPT_CONFIG[ + "batch_context_update_interval" + ] == 0 or not hasattr(self, "cached_context_messages"): + context_messages = [] + for i in range(index - 1, -1, -1): + item = self.batch_text_list[i] + if len(item["text"].split()) >= 100: + context_messages.append(item["text"]) + if len(context_messages) == self.context_paragraph_limit: + break + + if len(context_messages) == self.context_paragraph_limit: + print("Creating cached context messages") + self.cached_context_messages = [ + {"role": "user", "content": "\n".join(context_messages)}, + { + "role": "assistant", + "content": self.get_translation( + "\n".join(context_messages) + ), + }, + ] + + if hasattr(self, "cached_context_messages"): + messages.extend(self.cached_context_messages) + + return messages + + def make_batch_request(self, book_index, text): + messages = self.create_messages( + text, self.create_batch_context_messages(book_index) + ) + return { + "custom_id": self.custom_id(book_index), + "method": "POST", + "url": "/v1/chat/completions", + "body": { + # model shuould not be rotate + "model": self.batch_model, + "messages": messages, + "temperature": self.temperature, + }, + } + + def 
+    def create_batch_files(self, dest_file_path):
+        file_paths = []
+        # Batch API limits: max 50,000 requests per input file, max file size 100 MB
+        lines_per_file = 40000
+        current_file = 0
+
+        for i in range(0, len(self.batch_text_list), lines_per_file):
+            current_file += 1
+            file_path = os.path.join(dest_file_path, f"{current_file}.jsonl")
+            start_index = i
+            end_index = i + lines_per_file
+
+            # TODO: Split the file if it exceeds 100MB
+            with open(file_path, "w", encoding="utf-8") as f:
+                for text in self.batch_text_list[i : i + lines_per_file]:
+                    batch_req = self.make_batch_request(
+                        text["book_index"], text["text"]
+                    )
+                    json.dump(batch_req, f, ensure_ascii=False)
+                    f.write("\n")
+            file_paths.append(
+                {
+                    "file_path": file_path,
+                    "start_index": start_index,
+                    "end_index": end_index,
+                }
+            )
+
+        return file_paths
+
+    def batch(self):
+        self.rotate_model()
+        self.batch_model = self.model
+        # everything lives under ./batch_files in the current working directory
+        batch_dir = self.batch_dir()
+        batch_metadata_file_path = self.batch_metadata_file_path()
+        # clean up the batch dir and metadata file left over from any previous run
+        if os.path.exists(batch_dir):
+            shutil.rmtree(batch_dir)
+        if os.path.exists(batch_metadata_file_path):
+            os.remove(batch_metadata_file_path)
+        os.makedirs(batch_dir, exist_ok=True)
+        # upload each request file and start its batch job
+        batch_files = self.create_batch_files(batch_dir)
+        batch_info = []
+        for batch_file in batch_files:
+            file_id = self.upload_batch_file(batch_file["file_path"])
+            batch = self.batch_execute(file_id)
+            batch_info.append(
+                self.create_batch_info(
+                    file_id, batch, batch_file["start_index"], batch_file["end_index"]
+                )
+            )
+        # persist the batch metadata for the later --batch-use run
+        batch_info_json = {
+            "book_id": self.book_name,
+            "batch_date": time.strftime("%Y-%m-%d %H:%M:%S"),
+            "batch_files": batch_info,
+        }
+        with open(batch_metadata_file_path, "w", encoding="utf-8") as f:
+            json.dump(batch_info_json, f, ensure_ascii=False, indent=2)
+
+    def create_batch_info(self, file_id, batch, start_index, end_index):
+        return {
+            "input_file_id": file_id,
+            "batch_id": batch.id,
+            "start_index": start_index,
+            "end_index": end_index,
+            "prefix": self.book_name,
+        }
+
+    def upload_batch_file(self, file_path):
+        # use a context manager so the request file handle is closed after upload
+        with open(file_path, "rb") as f:
+            batch_input_file = self.openai_client.files.create(
+                file=f, purpose="batch"
+            )
+        return batch_input_file.id
+
+    def batch_execute(self, file_id):
+        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
+        res = self.openai_client.batches.create(
+            input_file_id=file_id,
+            endpoint="/v1/chat/completions",
+            completion_window="24h",
+            metadata={
+                "description": f"Batch job for {self.book_name} at {current_time}"
+            },
+        )
+        if res.errors:
+            print(res.errors)
+            raise Exception(f"Batch execution failed: {res.errors}")
+        return res
+
+    def check_batch_status(self, batch_id):
+        return self.openai_client.batches.retrieve(batch_id)
+
+    def get_batch_result(self, output_file_id):
+        return self.openai_client.files.content(output_file_id)
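
A minimal usage sketch of the two new flags, assuming the repo's usual `bbook_maker` entry point (the command name, key variable, and book path below are illustrative, not part of this patch). The first run only uploads the request files and records the job ids; the second run, issued once the jobs are done, reads the results back and writes the bilingual epub:

    # step 1: queue every paragraph and submit the OpenAI batch jobs
    bbook_maker --book_name test_books/animal_farm.epub --openai_key "$OPENAI_API_KEY" --batch
    # step 2: after the jobs complete (the completion window is 24h), build the book
    bbook_maker --book_name test_books/animal_farm.epub --openai_key "$OPENAI_API_KEY" --batch-use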
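
Between the two runs, the only state on disk is the request files under `batch_files/<book>/` and the metadata file `batch_files/<book>_info.json` written by `batch()`. Per `create_batch_info()` above it looks roughly like this (ids invented for illustration):

    {
      "book_id": "animal_farm",
      "batch_date": "2024-07-01 12:00:00",
      "batch_files": [
        {
          "input_file_id": "file-abc123",
          "batch_id": "batch_xyz789",
          "start_index": 0,
          "end_index": 40000,
          "prefix": "animal_farm"
        }
      ]
    }

`is_completed_batch()` polls `batches.retrieve()` for each recorded `batch_id`, and `batch_translate()` maps a paragraph index to the right job via the `start_index`/`end_index` ranges.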