feature: GPT-4 support (#262)

* feat: add gpt4 support

* Update prompt_template_sample.json

* fix: cleaned up formatting for quotes

* feature: added context functionality (--use_context) for the GPT4 model, which accumulates a running paragraph of historical context for the current passage

* fix: propagated context_flag argument to txt_loader and srt_loader

* Updated README to include GPT4 parameters

* Removed debug output

* fix: lint

---------

Co-authored-by: yihong <zouzou0208@gmail.com>
astromaddie 2023-05-11 22:57:32 +09:00 committed by GitHub
parent c965ace98c
commit cd56b1e0ab
8 changed files with 366 additions and 8 deletions


@@ -19,6 +19,9 @@ The bilingual_book_maker is an AI translation tool that uses ChatGPT to assist u
 - Use `--openai_key` option to specify OpenAI API key. If you have multiple keys, separate them by commas (xxx,xxx,xxx) to reduce errors caused by API call limits.
   Or, just set environment variable `BBM_OPENAI_API_KEY` instead.
 - A sample book, `test_books/animal_farm.epub`, is provided for testing purposes.
+- The default underlying model is [GPT-3.5-turbo](https://openai.com/blog/introducing-chatgpt-and-whisper-apis), which is used by ChatGPT currently. Use `--model gpt4` to change the underlying model to `GPT4` and use `--model gpt3` to change the model to `GPT3`.
+  If using `GPT4`, you can add `--use_context` to add a context paragraph to each passage sent to the model for translation (see below).
+- support DeepL model [DeepL Translator](https://rapidapi.com/splintPRO/api/deepl-translator) need pay to get the token use `--model deepl --deepl_key ${deepl_key}`
 - The default underlying model is [GPT-3.5-turbo](https://openai.com/blog/introducing-chatgpt-and-whisper-apis), which is used by ChatGPT currently. Use `--model gpt3` to change the underlying model to `GPT3`
 - Support DeepL model [DeepL Translator](https://rapidapi.com/splintPRO/api/deepl-translator) need pay to get the token use `--model deepl --deepl_key ${deepl_key}`
 - Support [Claude](https://console.anthropic.com/docs) model, use `--model claude --claude_key ${claude_key}`
@@ -44,6 +47,7 @@ The bilingual_book_maker is an AI translation tool that uses ChatGPT to assist u
 - `--accumulated_num` Wait for how many tokens have been accumulated before starting the translation. gpt3.5 limits the total_token to 4090. For example, if you use --accumulated_num 1600, maybe openai will
   output 2200 tokens and maybe 200 tokens for other messages in the system messages user messages, 1600+2200+200=4000, So you are close to reaching the limit. You have to choose your own
   value, there is no way to know if the limit is reached before sending
+- `--use_context` prompts the GPT4 model to create a one-paragraph summary. If it is the beginning of the translation, it will summarise the entire passage sent (the size depending on `--accumulated_num`); for every subsequent passage, it will amend the summary to include details from the most recent passage, creating a running one-paragraph context payload of the important details of the entire translated work. This improves the consistency of flow and tone of each translation.
 - `--translation_style` example: `--translation_style "color: #808080; font-style: italic;"`
 - `--retranslate` `--retranslate "$translated_filepath" "file_name_in_epub" "start_str" "end_str"(optional)`<br>
   Retranslate from start_str to end_str's tag:
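To make the `--use_context` round trip concrete, here is a minimal standalone Python sketch (not part of the diff; the reply text is invented): the running summary is carried between passages, and each reply's `<summary>` block is peeled off the translation, as the new GPT4 translator does later in this commit.

import re

# Seed context, as the GPT4 translator seeds it
context = "<summary>The start of the story.</summary>"

# Hypothetical model reply: a refreshed summary followed by the translation
reply = "<summary>Old Major shares his dream.</summary>\nTranslated passage..."

match = re.search(r"(<summary>.*?</summary>)", reply, re.DOTALL)
if match:
    context = match.group(0)  # carried into the next passage's prompt
    translation = reply.replace(context, "", 1).strip()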
@@ -67,6 +71,9 @@ python3 make_book.py --book_name test_books/animal_farm.epub --openai_key ${open
 # Set env OPENAI_API_KEY to ignore option --openai_key
 export OPENAI_API_KEY=${your_api_key}
+
+# Use the GPT-4 model with context to Japanese
+python3 make_book.py --book_name test_books/animal_farm.epub --model gpt4 --use_context --language ja
 # Use the GPT-3 model with Japanese
 python3 make_book.py --book_name test_books/animal_farm.epub --model gpt3 --language ja


@@ -241,6 +241,12 @@ So you are close to reaching the limit. You have to choose your own value, there
         action="store_true",
         help="output translated book, no bilingual",
     )
+    parser.add_argument(
+        "--use_context",
+        dest="context_flag",
+        action="store_true",
+        help="adds an additional paragraph for global, updating historical context of the story to the model's input, improving the narrative consistency for the AI model (this uses ~200 more tokens each time)",
+    )
     options = parser.parse_args()
@@ -256,7 +262,7 @@ So you are close to reaching the limit. You have to choose your own value, there
     translate_model = MODEL_DICT.get(options.model)
     assert translate_model is not None, "unsupported model"
     API_KEY = ""
-    if options.model in ["gpt3", "chatgptapi"]:
+    if options.model in ["gpt3", "chatgptapi", "gpt4"]:
         if OPENAI_API_KEY := (
             options.openai_key
             or env.get(
@@ -324,6 +330,7 @@ So you are close to reaching the limit. You have to choose your own value, there
         test_num=options.test_num,
         prompt_config=parse_prompt_arg(options.prompt_arg),
         single_translate=options.single_translate,
+        context_flag=options.context_flag,
     )
     # other options
     if options.allow_navigable_strings:
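A standalone sketch (separate from the diff) of how the `store_true` flag above behaves: `dest="context_flag"` puts the value on `options.context_flag`, which stays False when the flag is absent.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_context", dest="context_flag", action="store_true")

assert parser.parse_args([]).context_flag is False
assert parser.parse_args(["--use_context"]).context_flag is True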


@@ -30,6 +30,7 @@ class EPUBBookLoader(BaseBookLoader):
         test_num=5,
         prompt_config=None,
         single_translate=False,
+        context_flag=False,
     ):
         self.epub_name = epub_name
         self.new_epub = epub.EpubBook()
@@ -37,6 +38,7 @@ class EPUBBookLoader(BaseBookLoader):
             key,
             language,
             api_base=model_api_base,
+            context_flag=context_flag,
             **prompt_config_to_kwargs(prompt_config),
         )
         self.is_test = is_test
@@ -46,8 +48,12 @@ class EPUBBookLoader(BaseBookLoader):
         self.allow_navigable_strings = False
         self.accumulated_num = 1
         self.translation_style = ""
+        self.context_flag = context_flag
         self.helper = EPUBBookLoaderHelper(
-            self.translate_model, self.accumulated_num, self.translation_style
+            self.translate_model,
+            self.accumulated_num,
+            self.translation_style,
+            self.context_flag,
         )
         self.retranslate = None
         self.exclude_filelist = ""
@@ -378,7 +384,10 @@ class EPUBBookLoader(BaseBookLoader):
     def make_bilingual_book(self):
         self.helper = EPUBBookLoaderHelper(
-            self.translate_model, self.accumulated_num, self.translation_style
+            self.translate_model,
+            self.accumulated_num,
+            self.translation_style,
+            self.context_flag,
         )
         new_book = self._make_new_book(self.origin_book)
         all_items = list(self.origin_book.get_items())


@@ -3,10 +3,13 @@ from copy import copy

 class EPUBBookLoaderHelper:
-    def __init__(self, translate_model, accumulated_num, translation_style):
+    def __init__(
+        self, translate_model, accumulated_num, translation_style, context_flag
+    ):
         self.translate_model = translate_model
         self.accumulated_num = accumulated_num
         self.translation_style = translation_style
+        self.context_flag = context_flag

     def insert_trans(self, p, text, translation_style="", single_translate=False):
         if (
@@ -23,19 +26,21 @@ class EPUBBookLoaderHelper:
             p.extract()

     def deal_new(self, p, wait_p_list, single_translate=False):
-        self.deal_old(wait_p_list, single_translate)
+        self.deal_old(wait_p_list, single_translate, self.context_flag)
         self.insert_trans(
             p,
-            shorter_result_link(self.translate_model.translate(p.text)),
+            shorter_result_link(
+                self.translate_model.translate(p.text, self.context_flag)
+            ),
             self.translation_style,
             single_translate,
         )

-    def deal_old(self, wait_p_list, single_translate=False):
+    def deal_old(self, wait_p_list, single_translate=False, context_flag=False):
         if not wait_p_list:
             return

-        result_txt_list = self.translate_model.translate_list(wait_p_list)
+        result_txt_list = self.translate_model.translate_list(wait_p_list, context_flag)
         for i in range(len(wait_p_list)):
             if i < len(result_txt_list):
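The signature changes above widen the contract between the helper and its translate model: `translate` and `translate_list` must now accept a context flag. A hypothetical stub (names invented for illustration) satisfying that contract:

class StubModel:
    """Minimal model shape accepted by EPUBBookLoaderHelper after this commit."""

    def translate(self, text, context_flag=False):
        return f"[{text}]"  # placeholder translation

    def translate_list(self, plist, context_flag=False):
        return [self.translate(p.text, context_flag) for p in plist]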


@@ -23,6 +23,7 @@ class SRTBookLoader(BaseBookLoader):
         test_num=5,
         prompt_config=None,
         single_translate=False,
+        context_flag=False,
     ) -> None:
         self.srt_name = srt_name
         self.translate_model = model(


@@ -19,6 +19,7 @@ class TXTBookLoader(BaseBookLoader):
         test_num=5,
         prompt_config=None,
         single_translate=False,
+        context_flag=False,
     ) -> None:
         self.txt_name = txt_name
         self.translate_model = model(


@@ -3,6 +3,7 @@ from book_maker.translator.chatgptapi_translator import ChatGPTAPI
 from book_maker.translator.deepl_translator import DeepL
 from book_maker.translator.google_translator import Google
 from book_maker.translator.gpt3_translator import GPT3
+from book_maker.translator.gpt4_translator import GPT4
 from book_maker.translator.claude_translator import Claude

 MODEL_DICT = {
@@ -11,6 +12,7 @@ MODEL_DICT = {
     "google": Google,
     "caiyun": Caiyun,
     "deepl": DeepL,
+    "gpt4": GPT4,
     "claude": Claude,
     # add more here
 }
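The registry is consumed by the CLI roughly as follows (sketch; the key and language values are placeholders): unknown model names fall through `dict.get` as None and are caught by the assert.

translate_model = MODEL_DICT.get("gpt4")  # None for unsupported names
assert translate_model is not None, "unsupported model"
model = translate_model("sk-...", "ja", context_flag=True)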


@@ -0,0 +1,326 @@
+import re
+import time
+from copy import copy
+from os import environ, linesep
+
+from rich import print
+
+import openai
+
+from .base_translator import Base
+
+PROMPT_ENV_MAP = {
+    "user": "BBM_CHATGPTAPI_USER_MSG_TEMPLATE",
+    "system": "BBM_CHATGPTAPI_SYS_MSG",
+}
+
+
+class GPT4(Base):
+    DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"
+
+    def __init__(
+        self,
+        key,
+        language,
+        api_base=None,
+        prompt_template=None,
+        prompt_sys_msg=None,
+        context_flag=False,
+        **kwargs,
+    ) -> None:
+        super().__init__(key, language)
+        self.context_flag = context_flag
+        self.context = "<summary>The start of the story.</summary>"
+        self.key_len = len(key.split(","))
+
+        if api_base:
+            openai.api_base = api_base
+        self.prompt_template = (
+            prompt_template
+            or environ.get(PROMPT_ENV_MAP["user"])
+            or self.DEFAULT_PROMPT
+        )
+        self.prompt_sys_msg = (
+            prompt_sys_msg
+            or environ.get(
+                "OPENAI_API_SYS_MSG",
+            )  # XXX: for backward compatibility, deprecate soon
+            or environ.get(PROMPT_ENV_MAP["system"])
+            or ""
+        )
+        self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
+        self.deployment_id = None
+
+    def rotate_key(self):
+        openai.api_key = next(self.keys)
+
+    def create_chat_completion(self, text):
+        # content = self.prompt_template.format(
+        #     text=text, language=self.language, crlf="\n"
+        # )
+        content = f"{self.context if self.context_flag else ''} {self.prompt_template.format(text=text, language=self.language, crlf=linesep)}"
+        sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
+
+        context_sys_str = "For each passage given, you may be provided a summary of the story up until this point (wrapped in tags '<summary>' and '</summary>') for context within the query, to provide background context of the story up until this point. If it's provided, use the context summary to aid you in translation with deeper comprehension, and write a new summary above the returned translation, wrapped in '<summary>' HTML-like tags, including important details (if relevant) from the new passage, retaining the most important key details from the existing summary, and dropping out less important details. If the summary is blank, assume it is the start of the story and write a summary from scratch. Do not make the summary longer than a paragraph, and smaller details can be replaced based on the relative importance of new details. The summary should be formatted in straightforward, inornate text, briefly summarising the entire story (from the start, including information before the given passage, leading up to the given passage) to act as an instructional payload for a Large-Language AI Model to fully understand the context of the passage."
+        sys_content = f"{self.system_content or self.prompt_sys_msg.format(crlf=linesep)} {context_sys_str if self.context_flag else ''} "
+
+        messages = [
+            {"role": "system", "content": sys_content},
+            {"role": "user", "content": content},
+        ]
+
+        if self.deployment_id:
+            return openai.ChatCompletion.create(
+                engine=self.deployment_id,
+                messages=messages,
+            )
+        return openai.ChatCompletion.create(
+            model="gpt-4",
+            messages=messages,
+        )
+
+    def get_translation(self, text):
+        self.rotate_key()
+
+        completion = {}
+        try:
+            completion = self.create_chat_completion(text)
+        except Exception:
+            if (
+                "choices" not in completion
+                or not isinstance(completion["choices"], list)
+                or len(completion["choices"]) == 0
+            ):
+                raise
+            if completion["choices"][0]["finish_reason"] != "length":
+                raise
+
+        # work well or exception finish by length limit
+        choice = completion["choices"][0]
+
+        t_text = choice.get("message").get("content", "").encode("utf8").decode()
+
+        if choice["finish_reason"] == "length":
+            with open("log/long_text.txt", "a") as f:
+                print(
+                    f"""==================================================
+The total token is too long and cannot be completely translated\n
+{text}
+                    """,
+                    file=f,
+                )
+
+        return t_text
+
+    def translate(self, text, needprint=True):
+        # print("=================================================")
+        start_time = time.time()
+        # todo: Determine whether to print according to the cli option
+        if needprint:
+            print(re.sub("\n{3,}", "\n\n", text))
+
+        attempt_count = 0
+        max_attempts = 3
+        t_text = ""
+
+        while attempt_count < max_attempts:
+            try:
+                t_text = self.get_translation(text)
+                # Extract the text between <summary> and </summary> tags
+                # (including the tags), save it as the next context, then
+                # remove it from the returned translation.
+                context_match = re.search(
+                    r"(<summary>.*?</summary>)", t_text, re.DOTALL
+                )
+                if context_match:
+                    self.context = context_match.group(0)
+                    t_text = t_text.replace(self.context, "", 1)
+                else:
+                    pass
+                    # self.context = ""
+                break
+            except Exception as e:
+                # todo: better sleep time? why sleep always about key_len
+                # 1. openai server error or own network interruption, sleep for a fixed time
+                # 2. an apikey has no money or reached its limit, don't sleep, just replace it with another apikey
+                # 3. all apikeys reached their limit, then use the current sleep
+                sleep_time = int(60 / self.key_len)
+                print(e, f"will sleep {sleep_time} seconds")
+                time.sleep(sleep_time)
+
+                attempt_count += 1
+                if attempt_count == max_attempts:
+                    print(f"Get {attempt_count} consecutive exceptions")
+                    raise
+
+        # todo: Determine whether to print according to the cli option
+        if needprint:
+            print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
+
+        time.time() - start_time
+        # print(f"translation time: {elapsed_time:.1f}s")
+
+        return t_text
+
+    def translate_and_split_lines(self, text):
+        result_str = self.translate(text, False)
+        lines = result_str.split("\n")
+        lines = [line.strip() for line in lines if line.strip() != ""]
+        return lines
+
+    def get_best_result_list(
+        self,
+        plist_len,
+        new_str,
+        sleep_dur,
+        result_list,
+        max_retries=15,
+    ):
+        if len(result_list) == plist_len:
+            return result_list, 0
+
+        best_result_list = result_list
+        retry_count = 0
+
+        while retry_count < max_retries and len(result_list) != plist_len:
+            print(
+                f"bug: {plist_len} -> {len(result_list)} : Number of paragraphs before and after translation",
+            )
+            print(f"sleep for {sleep_dur}s and retry {retry_count+1} ...")
+            time.sleep(sleep_dur)
+            retry_count += 1
+            result_list = self.translate_and_split_lines(new_str)
+            if (
+                len(result_list) == plist_len
+                or len(best_result_list) < len(result_list) <= plist_len
+                or (
+                    len(result_list) < len(best_result_list)
+                    and len(best_result_list) > plist_len
+                )
+            ):
+                best_result_list = result_list
+
+        return best_result_list, retry_count
+
+    def log_retry(self, state, retry_count, elapsed_time, log_path="log/buglog.txt"):
+        if retry_count == 0:
+            return
+        print(f"retry {state}")
+        with open(log_path, "a", encoding="utf-8") as f:
+            print(
+                f"retry {state}, count = {retry_count}, time = {elapsed_time:.1f}s",
+                file=f,
+            )
+
+    def log_translation_mismatch(
+        self,
+        plist_len,
+        result_list,
+        new_str,
+        sep,
+        log_path="log/buglog.txt",
+    ):
+        if len(result_list) == plist_len:
+            return
+        newlist = new_str.split(sep)
+        with open(log_path, "a", encoding="utf-8") as f:
+            print(f"problem size: {plist_len - len(result_list)}", file=f)
+            for i in range(len(newlist)):
+                print(newlist[i], file=f)
+                print(file=f)
+                if i < len(result_list):
+                    print("............................................", file=f)
+                    print(result_list[i], file=f)
+                    print(file=f)
+                print("=============================", file=f)
+        print(
+            f"bug: {plist_len} paragraphs of text translated into {len(result_list)} paragraphs",
+        )
+        print("continue")
+
+    def join_lines(self, text):
+        lines = text.split("\n")
+        new_lines = []
+        temp_line = []
+
+        # join
+        for line in lines:
+            if line.strip():
+                temp_line.append(line.strip())
+            else:
+                if temp_line:
+                    new_lines.append(" ".join(temp_line))
+                    temp_line = []
+                new_lines.append(line)
+
+        if temp_line:
+            new_lines.append(" ".join(temp_line))
+
+        text = "\n".join(new_lines)
+
+        # del ^M
+        text = text.replace("^M", "\r")
+        lines = text.split("\n")
+        filtered_lines = [line for line in lines if line.strip() != "\r"]
+        new_text = "\n".join(filtered_lines)
+
+        return new_text
+
+    def translate_list(self, plist, context_flag):
+        sep = "\n\n\n\n\n"
+        # new_str = sep.join([item.text for item in plist])
+
+        new_str = ""
+        i = 1
+        for p in plist:
+            temp_p = copy(p)
+            for sup in temp_p.find_all("sup"):
+                sup.extract()
+            new_str += f"({i}) {temp_p.get_text().strip()}{sep}"
+            i = i + 1
+
+        if new_str.endswith(sep):
+            new_str = new_str[: -len(sep)]
+
+        new_str = self.join_lines(new_str)
+
+        plist_len = len(plist)
+
+        print(f"plist len = {len(plist)}")
+
+        result_list = self.translate_and_split_lines(new_str)
+
+        start_time = time.time()
+
+        result_list, retry_count = self.get_best_result_list(
+            plist_len,
+            new_str,
+            6,
+            result_list,
+        )
+
+        end_time = time.time()
+
+        state = "fail" if len(result_list) != plist_len else "success"
+        log_path = "log/buglog.txt"
+
+        self.log_retry(state, retry_count, end_time - start_time, log_path)
+        self.log_translation_mismatch(plist_len, result_list, new_str, sep, log_path)
+
+        # del (num), num. sometimes (num) will be translated to num.
+        result_list = [re.sub(r"^(\(\d+\)|\d+\.|(\d+))\s*", "", s) for s in result_list]
+
+        # # Remove the context paragraph from the final output
+        # if self.context:
+        #     context_len = len(self.context)
+        #     result_list = [
+        #         s[context_len:] if s.startswith(self.context) else s
+        #         for s in result_list
+        #     ]
+
+        return result_list
+
+    def set_deployment_id(self, deployment_id):
+        openai.api_type = "azure"
+        openai.api_version = "2023-03-15-preview"
+        self.deployment_id = deployment_id
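A usage sketch of the new translator (the key is a placeholder, and each `translate` call hits the OpenAI API): with `context_flag=True`, the `<summary>` extracted from one reply is folded into the prompt of the next, giving the model a running memory of the story.

from book_maker.translator.gpt4_translator import GPT4

t = GPT4("sk-xxx", "ja", context_flag=True)
first = t.translate("It was a bright cold day in April...")
# t.context now holds the latest <summary>...</summary> payload
second = t.translate("Winston Smith slipped quickly through the doors...")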