Summary of this patch (free text before the first "diff --git" header):

  * book_maker/cli.py: the Claude API-key branch now matches any
    options.model value that starts with "claude"; when the value starts
    with "claude-" it is passed to translate_model.set_claude_model().
  * book_maker/translator/__init__.py: five "claude-3-5-*" names are added
    to MODEL_DICT, all mapped to the Claude class.
  * book_maker/translator/claude_translator.py: Claude.__init__ gains
    prompt_sys_msg, context_flag and context_paragraph_limit; the default
    prompt drops the "Human:/Assistant:" framing; new methods
    set_claude_model, create_messages, create_context_messages and
    save_context are added; translate() sends the saved context plus a
    system message and no longer calls time.sleep(1) after each request.

NOTE(review): the key check uses startswith("claude") while the model
override uses startswith("claude-"), so plain "claude" keeps the default
model set in Claude.__init__.
NOTE(review): leading whitespace and blank context lines were rebuilt from
the Python syntax and the hunk header line counts; confirm against the
target tree before applying.

diff --git a/book_maker/cli.py b/book_maker/cli.py
index ee91108..c5b823d 100644
--- a/book_maker/cli.py
+++ b/book_maker/cli.py
@@ -375,7 +375,7 @@ So you are close to reaching the limit. You have to choose your own value, there
         API_KEY = options.deepl_key or env.get("BBM_DEEPL_API_KEY")
         if not API_KEY:
             raise Exception("Please provide deepl key")
-    elif options.model == "claude":
+    elif options.model.startswith("claude"):
         API_KEY = options.claude_key or env.get("BBM_CLAUDE_API_KEY")
         if not API_KEY:
             raise Exception("Please provide claude key")
@@ -494,6 +494,8 @@ So you are close to reaching the limit. You have to choose your own value, there
         e.translate_model.set_gpt4omini_models()
     if options.model == "gpt4o":
         e.translate_model.set_gpt4o_models()
+    if options.model.startswith("claude-"):
+        e.translate_model.set_claude_model(options.model)
     if options.block_size > 0:
         e.block_size = options.block_size
     if options.batch_flag:
diff --git a/book_maker/translator/__init__.py b/book_maker/translator/__init__.py
index 003dbb8..2e55bcc 100644
--- a/book_maker/translator/__init__.py
+++ b/book_maker/translator/__init__.py
@@ -21,6 +21,11 @@ MODEL_DICT = {
     "deepl": DeepL,
     "deeplfree": DeepLFree,
     "claude": Claude,
+    "claude-3-5-sonnet-latest": Claude,
+    "claude-3-5-sonnet-20241022": Claude,
+    "claude-3-5-sonnet-20240620": Claude,
+    "claude-3-5-haiku-latest": Claude,
+    "claude-3-5-haiku-20241022": Claude,
     "gemini": Gemini,
     "geminipro": Gemini,
     "groq": GroqClient,
diff --git a/book_maker/translator/claude_translator.py b/book_maker/translator/claude_translator.py
index f21da6f..5335be5 100644
--- a/book_maker/translator/claude_translator.py
+++ b/book_maker/translator/claude_translator.py
@@ -13,40 +13,99 @@ class Claude(Base):
         language,
         api_base=None,
         prompt_template=None,
+        prompt_sys_msg=None,
         temperature=1.0,
+        context_flag=False,
+        context_paragraph_limit=5,
         **kwargs,
     ) -> None:
         super().__init__(key, language)
-        self.api_url = f"{api_base}" if api_base else "https://api.anthropic.com"
+        self.api_url = api_base or "https://api.anthropic.com"
         self.client = Anthropic(base_url=api_base, api_key=key, timeout=20)
         self.model = "claude-3-5-sonnet-20241022"  # default it for now
         self.language = language
         self.prompt_template = (
             prompt_template
-            or "\n\nHuman: Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```\n\nAssistant: "
+            or "Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```"
         )
+        self.prompt_sys_msg = prompt_sys_msg or ""
         self.temperature = temperature
+        self.context_flag = context_flag
+        self.context_list = []
+        self.context_translated_list = []
+        self.context_paragraph_limit = context_paragraph_limit
 
     def rotate_key(self):
         pass
 
+    def set_claude_model(self, model_name):
+        self.model = model_name
+
+    def create_messages(self, text, intermediate_messages=None):
+        """Create messages for the current translation request"""
+        current_msg = {
+            "role": "user",
+            "content": self.prompt_template.format(
+                text=text,
+                language=self.language,
+            ),
+        }
+
+        messages = []
+        if intermediate_messages:
+            messages.extend(intermediate_messages)
+        messages.append(current_msg)
+
+        return messages
+
+    def create_context_messages(self):
+        """Create a message pair containing all context paragraphs"""
+        if not self.context_flag or not self.context_list:
+            return []
+
+        # Create a single message pair for all previous context
+        return [
+            {
+                "role": "user",
+                "content": self.prompt_template.format(
+                    text="\n\n".join(self.context_list),
+                    language=self.language,
+                ),
+            },
+            {"role": "assistant", "content": "\n\n".join(self.context_translated_list)},
+        ]
+
+    def save_context(self, text, t_text):
+        """Save the current translation pair to context"""
+        if not self.context_flag:
+            return
+
+        self.context_list.append(text)
+        self.context_translated_list.append(t_text)
+
+        # Keep only the most recent paragraphs within the limit
+        if len(self.context_list) > self.context_paragraph_limit:
+            self.context_list.pop(0)
+            self.context_translated_list.pop(0)
+
     def translate(self, text):
         print(text)
         self.rotate_key()
-        prompt = self.prompt_template.format(
-            text=text,
-            language=self.language,
-        )
-        message = [{"role": "user", "content": prompt}]
+
+        # Create messages with context
+        messages = self.create_messages(text, self.create_context_messages())
+
         r = self.client.messages.create(
             max_tokens=4096,
-            messages=message,
+            messages=messages,
+            system=self.prompt_sys_msg,
             temperature=self.temperature,
             model=self.model,
         )
         t_text = r.content[0].text
-        # api limit rate and spider rule
-        time.sleep(1)
+
+        if self.context_flag:
+            self.save_context(text, t_text)
 
         print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
         return t_text