diff --git a/book_maker/translator/__init__.py b/book_maker/translator/__init__.py
index 2e55bcc..4410c86 100644
--- a/book_maker/translator/__init__.py
+++ b/book_maker/translator/__init__.py
@@ -26,6 +26,7 @@ MODEL_DICT = {
     "claude-3-5-sonnet-20240620": Claude,
     "claude-3-5-haiku-latest": Claude,
     "claude-3-5-haiku-20241022": Claude,
+    "claude-3-5-haiku-20240620": Claude,
     "gemini": Gemini,
     "geminipro": Gemini,
     "groq": GroqClient,
diff --git a/book_maker/translator/claude_translator.py b/book_maker/translator/claude_translator.py
index 3b6417d..0d97dd7 100644
--- a/book_maker/translator/claude_translator.py
+++ b/book_maker/translator/claude_translator.py
@@ -13,19 +13,27 @@ class Claude(Base):
         language,
         api_base=None,
         prompt_template=None,
+        prompt_sys_msg=None,
         temperature=1.0,
+        context_flag=False,
+        context_paragraph_limit=5,
         **kwargs,
     ) -> None:
         super().__init__(key, language)
-        self.api_url = f"{api_base}" if api_base else "https://api.anthropic.com"
+        self.api_url = api_base or "https://api.anthropic.com"
         self.client = Anthropic(base_url=api_base, api_key=key, timeout=20)
         self.model = "claude-3-5-sonnet-20241022"  # default it for now
         self.language = language
         self.prompt_template = (
             prompt_template
-            or "\n\nHuman: Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```\n\nAssistant: "
+            or "Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```"
         )
+        self.prompt_sys_msg = prompt_sys_msg or ""
         self.temperature = temperature
+        self.context_flag = context_flag
+        self.context_list = []
+        self.context_translated_list = []
+        self.context_paragraph_limit = context_paragraph_limit
 
     def rotate_key(self):
         pass
@@ -33,23 +41,71 @@ class Claude(Base):
     def set_claude_model(self, model_name):
         self.model = model_name
 
+    def create_messages(self, text, intermediate_messages=None):
+        """Create messages for the current translation request"""
+        current_msg = {"role": "user", "content": self.prompt_template.format(
+            text=text,
+            language=self.language,
+        )}
+
+        messages = []
+        if intermediate_messages:
+            messages.extend(intermediate_messages)
+        messages.append(current_msg)
+
+        return messages
+
+    def create_context_messages(self):
+        """Create a message pair containing all context paragraphs"""
+        if not self.context_flag or not self.context_list:
+            return []
+
+        # Create a single message pair for all previous context
+        return [
+            {
+                "role": "user",
+                "content": self.prompt_template.format(
+                    text="\n\n".join(self.context_list),
+                    language=self.language,
+                )
+            },
+            {
+                "role": "assistant",
+                "content": "\n\n".join(self.context_translated_list)
+            }
+        ]
+
+    def save_context(self, text, t_text):
+        """Save the current translation pair to context"""
+        if not self.context_flag:
+            return
+
+        self.context_list.append(text)
+        self.context_translated_list.append(t_text)
+
+        # Keep only the most recent paragraphs within the limit
+        if len(self.context_list) > self.context_paragraph_limit:
+            self.context_list.pop(0)
+            self.context_translated_list.pop(0)
+
     def translate(self, text):
         print(text)
         self.rotate_key()
-        prompt = self.prompt_template.format(
-            text=text,
-            language=self.language,
-        )
-        message = [{"role": "user", "content": prompt}]
+
+        # Create messages with context
+        messages = self.create_messages(text, self.create_context_messages())
+
         r = self.client.messages.create(
             max_tokens=4096,
-            messages=message,
+            messages=messages,
+            system=self.prompt_sys_msg,
             temperature=self.temperature,
             model=self.model,
         )
         t_text = r.content[0].text
-        # api limit rate and spider rule
-        time.sleep(1)
+
+        if self.context_flag:
+            self.save_context(text, t_text)
 
         print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
         return t_text