From 15d80dd1770d8042cd9cdfd8a7aa8c20425be145 Mon Sep 17 00:00:00 2001
From: mkXultra
Date: Fri, 16 Aug 2024 13:47:45 +0900
Subject: [PATCH] feat: support batch api
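
Add batch translation on top of OpenAI's batch API as a two-step
workflow: run once with --batch to queue every paragraph, write the
requests out as JSONL files, and submit them as batch jobs (batch
metadata is saved under batch_files/); once the jobs have completed,
run the same command again with --batch-use, which polls the batch
status (raising after five minutes if the jobs are still running) and
assembles the bilingual epub from the stored results.

Example invocation (assuming the repo's usual make_book.py entry
point; the --batch and --batch-use flags come from this patch):

  python3 make_book.py --book_name test.epub --openai_key ${openai_key} --batch
  # ... wait for the OpenAI batch jobs to finish ...
  python3 make_book.py --book_name test.epub --openai_key ${openai_key} --batch-use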
---
 book_maker/cli.py                             |  16 ++
 book_maker/loader/epub_loader.py              |  32 ++-
 .../translator/chatgptapi_translator.py       | 195 +++++++++++++++++-
 3 files changed, 239 insertions(+), 4 deletions(-)

diff --git a/book_maker/cli.py b/book_maker/cli.py
index b02930f..3e87360 100644
--- a/book_maker/cli.py
+++ b/book_maker/cli.py
@@ -304,6 +304,18 @@ So you are close to reaching the limit. You have to choose your own value, there
         dest="model_list",
         help="Rather than using our preset lists of models, specify exactly the models you want as a comma separated list `gpt-4-32k,gpt-3.5-turbo-0125` (Currently only supports: `openai`)",
     )
+    parser.add_argument(
+        "--batch",
+        dest="batch_flag",
+        action="store_true",
+        help="Enable batch translation using ChatGPT's batch API for improved efficiency",
+    )
+    parser.add_argument(
+        "--batch-use",
+        dest="batch_use_flag",
+        action="store_true",
+        help="Use pre-generated batch translations to create files. Run with --batch first before using this option",
+    )
 
     options = parser.parse_args()
 
@@ -461,6 +473,10 @@ So you are close to reaching the limit. You have to choose your own value, there
            e.translate_model.set_gpt4omini_models()
        if options.block_size > 0:
            e.block_size = options.block_size
+       if options.batch_flag:
+           e.batch_flag = options.batch_flag
+       if options.batch_use_flag:
+           e.batch_use_flag = options.batch_use_flag
 
     e.make_bilingual_book()
 
diff --git a/book_maker/loader/epub_loader.py b/book_maker/loader/epub_loader.py
index f26d783..c065379 100644
--- a/book_maker/loader/epub_loader.py
+++ b/book_maker/loader/epub_loader.py
@@ -2,6 +2,7 @@ import os
 import pickle
 import string
 import sys
+import time
 from copy import copy
 from pathlib import Path
 
@@ -65,6 +66,8 @@ class EPUBBookLoader(BaseBookLoader):
         self.only_filelist = ""
         self.single_translate = single_translate
         self.block_size = -1
+        self.batch_use_flag = False
+        self.batch_flag = False
 
         # monkey patch for # 173
         def _write_items_patch(obj):
@@ -142,11 +145,18 @@ class EPUBBookLoader(BaseBookLoader):
         if self.resume and index < p_to_save_len:
             p.string = self.p_to_save[index]
         else:
+            t_text = ""
+            if self.batch_flag:
+                self.translate_model.add_to_batch_translate_queue(index, new_p.text)
+            elif self.batch_use_flag:
+                t_text = self.translate_model.batch_translate(index)
+            else:
+                t_text = self.translate_model.translate(new_p.text)
             if type(p) == NavigableString:
-                new_p = self.translate_model.translate(new_p.text)
+                new_p = t_text
                 self.p_to_save.append(new_p)
             else:
-                new_p.string = self.translate_model.translate(new_p.text)
+                new_p.string = t_text
                 self.p_to_save.append(new_p.text)
 
         self.helper.insert_trans(
@@ -456,6 +466,18 @@ class EPUBBookLoader(BaseBookLoader):
 
         return index
 
+    def batch_init_then_wait(self):
+        name, _ = os.path.splitext(self.epub_name)
+        if self.batch_flag or self.batch_use_flag:
+            self.translate_model.batch_init(name)
+        if self.batch_use_flag:
+            start_time = time.time()
+            while not self.translate_model.is_completed_batch():
+                print("Batch translation is not completed yet")
+                time.sleep(2)
+                if time.time() - start_time > 300:  # 5 minutes
+                    raise Exception("Batch translation timed out after 5 minutes")
+
     def make_bilingual_book(self):
         self.helper = EPUBBookLoaderHelper(
             self.translate_model,
@@ -462,6 +484,7 @@ class EPUBBookLoader(BaseBookLoader):
             self.translation_style,
             self.context_flag,
         )
+        self.batch_init_then_wait()
         new_book = self._make_new_book(self.origin_book)
         all_items = list(self.origin_book.get_items())
         trans_taglist = self.translate_tags.split(",")
@@ -520,6 +543,9 @@ class EPUBBookLoader(BaseBookLoader):
                 name, _ = os.path.splitext(self.epub_name)
                 epub.write_epub(f"{name}_bilingual.epub", new_book, {})
         name, _ = os.path.splitext(self.epub_name)
-        epub.write_epub(f"{name}_bilingual.epub", new_book, {})
+        if self.batch_flag:
+            self.translate_model.batch()
+        else:
+            epub.write_epub(f"{name}_bilingual.epub", new_book, {})
         if self.accumulated_num == 1:
             pbar.close()
diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py
index 94ecfae..a19f4e4 100644
--- a/book_maker/translator/chatgptapi_translator.py
+++ b/book_maker/translator/chatgptapi_translator.py
@@ -1,8 +1,11 @@
 import re
 import time
+import os
+import shutil
 from copy import copy
 from os import environ
 from itertools import cycle
+import json
 
 from openai import AzureOpenAI, OpenAI, RateLimitError
 from rich import print
@@ -85,6 +88,8 @@ class ChatGPTAPI(Base):
         else:
             # set by user, use user's value
             self.context_paragraph_limit = CONTEXT_PARAGRAPH_LIMIT
+        self.batch_text_list = []
+        self.batch_info_cache = None
 
     def rotate_key(self):
         self.openai_client.api_key = next(self.keys)
@@ -92,10 +97,11 @@ class ChatGPTAPI(Base):
     def rotate_model(self):
         self.model = next(self.model_list)
 
-    def create_chat_completion(self, text):
+    def create_messages(self, text):
         content = self.prompt_template.format(
             text=text, language=self.language, crlf="\n"
         )
+
         sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
         messages = [
             {"role": "system", "content": sys_content},
@@ -109,7 +115,10 @@ class ChatGPTAPI(Base):
                 }
             )
         messages.append({"role": "user", "content": content})
+        return messages
 
+    def create_chat_completion(self, text):
+        messages = self.create_messages(text)
         completion = self.openai_client.chat.completions.create(
             model=self.model,
             messages=messages,
@@ -388,3 +397,187 @@ class ChatGPTAPI(Base):
         model_list = list(set(model_list))
         print(f"Using model list {model_list}")
         self.model_list = cycle(model_list)
+
+    def add_to_batch_translate_queue(self, book_index, text):
+        self.batch_text_list.append({"book_index": book_index, "text": text})
+
+    def batch_init(self, book_name):
+        self.book_name = self.sanitize_book_name(book_name)
+
+    def sanitize_book_name(self, book_name):
+        # Replace any characters that are not alphanumeric, underscore, hyphen, or dot with an underscore
+        sanitized_book_name = re.sub(r"[^\w\-_\.]", "_", book_name)
+        # Remove leading and trailing underscores and dots
+        sanitized_book_name = sanitized_book_name.strip("._")
+        return sanitized_book_name
+
+    def batch_result_file_path(self):
+        return f"{os.getcwd()}/batch_files/{self.book_name}_info.json"
+
+    def batch_dir(self):
+        return f"{os.getcwd()}/batch_files/{self.book_name}"
+
+    def custom_id(self, book_index):
+        return f"{self.book_name}-{book_index}"
+
+    def is_completed_batch(self):
+        batch_result_file_path = self.batch_result_file_path()
+
+        if not os.path.exists(batch_result_file_path):
+            print("Batch result file does not exist")
+            raise Exception("Batch result file does not exist")
+
+        with open(batch_result_file_path, "r", encoding="utf-8") as f:
+            batch_info = json.load(f)
+
+        for batch_file in batch_info["batch_files"]:
+            batch_status = self.check_batch_status(batch_file["batch_id"])
+            if batch_status.status != "completed":
+                return False
+
+        return True
+
+    def batch_translate(self, book_index):
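+        # Results are looked up by custom_id: find the sub-batch whose
+        # [start_index, end_index) range covers this paragraph index,
+        # download that batch's output file, and scan its JSONL lines for
+        # the entry whose custom_id matches this book_index.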
+        if self.batch_info_cache is None:
+            batch_result_file_path = self.batch_result_file_path()
+            with open(batch_result_file_path, "r", encoding="utf-8") as f:
+                self.batch_info_cache = json.load(f)
+
+        batch_info = self.batch_info_cache
+        target_batch = None
+        for batch in batch_info["batch_files"]:
+            if batch["start_index"] <= book_index < batch["end_index"]:
+                target_batch = batch
+                break
+
+        if not target_batch:
+            raise ValueError(f"No batch found for book_index {book_index}")
+
+        batch_status = self.check_batch_status(target_batch["batch_id"])
+        if batch_status.output_file_id is None:
+            raise ValueError(f"Batch {target_batch['batch_id']} is not completed")
+        result_content = self.get_batch_result(batch_status.output_file_id)
+        result_lines = result_content.text.split("\n")
+        custom_id = self.custom_id(book_index)
+        for line in result_lines:
+            if line.strip():
+                result = json.loads(line)
+                if result["custom_id"] == custom_id:
+                    return result["response"]["body"]["choices"][0]["message"][
+                        "content"
+                    ]
+
+        raise ValueError(f"No result found for custom_id {custom_id}")
+
+    def make_batch_request(self, book_index, text):
+        messages = self.create_messages(text)
+        return {
+            "custom_id": self.custom_id(book_index),
+            "method": "POST",
+            "url": "/v1/chat/completions",
+            "body": {"model": self.model, "messages": messages, "max_tokens": 1000},
+        }
+
+    def create_batch_files(self, dest_file_path):
+        file_paths = []
+        # max request 50,000 and max size 100MB
+        lines_per_file = 40000
+        current_file = 0
+
+        for i in range(0, len(self.batch_text_list), lines_per_file):
+            current_file += 1
+            file_path = f"{dest_file_path}/{current_file}.jsonl"
+            start_index = i
+            end_index = i + lines_per_file
+
+            # TODO: Split the file if it exceeds 100MB
+            with open(file_path, "w", encoding="utf-8") as f:
+                for text in self.batch_text_list[i : i + lines_per_file]:
+                    batch_req = self.make_batch_request(
+                        text["book_index"], text["text"]
+                    )
+                    json.dump(batch_req, f, ensure_ascii=False)
+                    f.write("\n")
+            file_paths.append(
+                {
+                    "file_path": file_path,
+                    "start_index": start_index,
+                    "end_index": end_index,
+                }
+            )
+
+        return file_paths
+
+    def upload_batch_file(self, file_path):
+        batch_input_file = self.openai_client.files.create(
+            file=open(file_path, "rb"), purpose="batch"
+        )
+        return batch_input_file.id
+
+    def batch(self):
+        self.rotate_model()
+        # current working directory
+        batch_dir = self.batch_dir()
+        batch_result_file_path = self.batch_result_file_path()
+        # cleanup batch dir and result file
+        if os.path.exists(batch_dir):
+            shutil.rmtree(batch_dir)
+        if os.path.exists(batch_result_file_path):
+            os.remove(batch_result_file_path)
+        os.makedirs(batch_dir, exist_ok=True)
+        # batch execute
+        batch_files = self.create_batch_files(batch_dir)
+        batch_info = []
+        for batch_file in batch_files:
+            file_id = self.upload_batch_file(batch_file["file_path"])
+            batch = self.batch_execute(file_id)
+            batch_info.append(
+                self.create_batch_info(
+                    file_id, batch, batch_file["start_index"], batch_file["end_index"]
+                )
+            )
+        # save batch info
+        batch_info_json = {
+            "book_id": self.book_name,
+            "batch_date": time.strftime("%Y-%m-%d %H:%M:%S"),
+            "batch_files": batch_info,
+        }
+        with open(batch_result_file_path, "w", encoding="utf-8") as f:
+            json.dump(batch_info_json, f, ensure_ascii=False, indent=2)
+
+    def create_batch_info(self, file_id, batch, start_index, end_index):
+        return {
+            "input_file_id": file_id,
+            "batch_id": batch.id,
+            "start_index": start_index,
+            "end_index": end_index,
+            "prefix": self.book_name,
+        }
+
+    def batch_execute(self, file_id):
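+        # Submit the uploaded JSONL file as an asynchronous batch job;
+        # OpenAI processes it within the 24h completion window, and the
+        # results are fetched on a later run via --batch-use.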
+        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
+        res = self.openai_client.batches.create(
+            input_file_id=file_id,
+            endpoint="/v1/chat/completions",
+            completion_window="24h",
+            metadata={
+                "description": f"Batch job for {self.book_name} at {current_time}"
+            },
+        )
+        if res.errors:
+            print(res.errors)
+            raise Exception(f"Batch execution failed: {res.errors}")
+        return res
+
+    def check_batch_status(self, batch_id):
+        return self.openai_client.batches.retrieve(batch_id)
+
+    def get_batch_result(self, output_file_id):
+        return self.openai_client.files.content(output_file_id)