diff --git a/book_maker/cli.py b/book_maker/cli.py
index 8898fb3..b02930f 100644
--- a/book_maker/cli.py
+++ b/book_maker/cli.py
@@ -279,6 +279,13 @@ So you are close to reaching the limit. You have to choose your own value, there
         action="store_true",
         help="adds an additional paragraph for global, updating historical context of the story to the model's input, improving the narrative consistency for the AI model (this uses ~200 more tokens each time)",
     )
+    parser.add_argument(
+        "--context_paragraph_limit",
+        dest="context_paragraph_limit",
+        type=int,
+        default=0,
+        help="if use --use_context, set context paragraph limit",
+    )
     parser.add_argument(
         "--temperature",
         type=float,
diff --git a/book_maker/loader/epub_loader.py b/book_maker/loader/epub_loader.py
index b7f0750..f26d783 100644
--- a/book_maker/loader/epub_loader.py
+++ b/book_maker/loader/epub_loader.py
@@ -33,6 +33,7 @@ class EPUBBookLoader(BaseBookLoader):
         single_translate=False,
         context_flag=False,
         temperature=1.0,
+        context_paragraph_limit=0,
     ):
         self.epub_name = epub_name
         self.new_epub = epub.EpubBook()
@@ -41,6 +42,7 @@ class EPUBBookLoader(BaseBookLoader):
             language,
             api_base=model_api_base,
             context_flag=context_flag,
+            context_paragraph_limit=context_paragraph_limit,
             temperature=temperature,
             **prompt_config_to_kwargs(prompt_config),
         )
diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py
index def1c22..5326719 100644
--- a/book_maker/translator/chatgptapi_translator.py
+++ b/book_maker/translator/chatgptapi_translator.py
@@ -36,6 +36,7 @@ GPT4oMINI_MODEL_LIST = [
     "gpt-4o-mini",
     "gpt-4o-mini-2024-07-18",
 ]
+CONTEXT_PARAGRAPH_LIMIT = 3


 class ChatGPTAPI(Base):
@@ -49,6 +50,8 @@ class ChatGPTAPI(Base):
         prompt_template=None,
         prompt_sys_msg=None,
         temperature=1.0,
+        context_flag=False,
+        context_paragraph_limit=0,
         **kwargs,
     ) -> None:
         super().__init__(key, language)
@@ -73,6 +76,15 @@ class ChatGPTAPI(Base):
         self.deployment_id = None
         self.temperature = temperature
         self.model_list = None
+        self.context_flag = context_flag
+        self.context_list = []
+        self.context_translated_list = []
+        if context_paragraph_limit > 0:
+            # set by user, use user's value
+            self.context_paragraph_limit = context_paragraph_limit
+        else:
+            # not set by user, use default
+            self.context_paragraph_limit = CONTEXT_PARAGRAPH_LIMIT

     def rotate_key(self):
         self.openai_client.api_key = next(self.keys)
@@ -87,8 +99,11 @@ class ChatGPTAPI(Base):
         sys_content = self.system_content or self.prompt_sys_msg.format(crlf="\n")
         messages = [
             {"role": "system", "content": sys_content},
-            {"role": "user", "content": content},
         ]
+        if self.context_flag:
+            messages.append({"role": "user", "content": "\n".join(self.context_list)})
+            messages.append({"role": "assistant", "content": "\n".join(self.context_translated_list)})
+        messages.append({"role": "user", "content": content})

         completion = self.openai_client.chat.completions.create(
             model=self.model,
@@ -110,7 +125,18 @@ class ChatGPTAPI(Base):
         else:
             t_text = ""

+        if self.context_flag:
+            self.save_context(text, t_text)
+
         return t_text

+    def save_context(self, text, t_text):
+        if self.context_paragraph_limit > 0:
+            self.context_list.append(text)
+            self.context_translated_list.append(t_text)
+            # Remove the oldest context
+            if len(self.context_list) > self.context_paragraph_limit:
+                self.context_list.pop(0)
+                self.context_translated_list.pop(0)
+
     def translate(self, text, needprint=True):
         start_time = time.time()