Add context and system message support for Claude (#438)

* option to select claude model

* add context_flag and context_paragraph_limit option for claude

* reformat with black

* remove nonexistent model
This commit is contained in:
cce 2024-12-07 18:06:38 -05:00 committed by GitHub
parent daea974d68
commit b80f1ba785
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 11 deletions

View File

@ -375,7 +375,7 @@ So you are close to reaching the limit. You have to choose your own value, there
API_KEY = options.deepl_key or env.get("BBM_DEEPL_API_KEY") API_KEY = options.deepl_key or env.get("BBM_DEEPL_API_KEY")
if not API_KEY: if not API_KEY:
raise Exception("Please provide deepl key") raise Exception("Please provide deepl key")
elif options.model == "claude": elif options.model.startswith("claude"):
API_KEY = options.claude_key or env.get("BBM_CLAUDE_API_KEY") API_KEY = options.claude_key or env.get("BBM_CLAUDE_API_KEY")
if not API_KEY: if not API_KEY:
raise Exception("Please provide claude key") raise Exception("Please provide claude key")
@ -494,6 +494,8 @@ So you are close to reaching the limit. You have to choose your own value, there
e.translate_model.set_gpt4omini_models() e.translate_model.set_gpt4omini_models()
if options.model == "gpt4o": if options.model == "gpt4o":
e.translate_model.set_gpt4o_models() e.translate_model.set_gpt4o_models()
if options.model.startswith("claude-"):
e.translate_model.set_claude_model(options.model)
if options.block_size > 0: if options.block_size > 0:
e.block_size = options.block_size e.block_size = options.block_size
if options.batch_flag: if options.batch_flag:

View File

@ -21,6 +21,11 @@ MODEL_DICT = {
"deepl": DeepL, "deepl": DeepL,
"deeplfree": DeepLFree, "deeplfree": DeepLFree,
"claude": Claude, "claude": Claude,
"claude-3-5-sonnet-latest": Claude,
"claude-3-5-sonnet-20241022": Claude,
"claude-3-5-sonnet-20240620": Claude,
"claude-3-5-haiku-latest": Claude,
"claude-3-5-haiku-20241022": Claude,
"gemini": Gemini, "gemini": Gemini,
"geminipro": Gemini, "geminipro": Gemini,
"groq": GroqClient, "groq": GroqClient,

View File

@ -13,40 +13,99 @@ class Claude(Base):
language, language,
api_base=None, api_base=None,
prompt_template=None, prompt_template=None,
prompt_sys_msg=None,
temperature=1.0, temperature=1.0,
context_flag=False,
context_paragraph_limit=5,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(key, language) super().__init__(key, language)
self.api_url = f"{api_base}" if api_base else "https://api.anthropic.com" self.api_url = api_base or "https://api.anthropic.com"
self.client = Anthropic(base_url=api_base, api_key=key, timeout=20) self.client = Anthropic(base_url=api_base, api_key=key, timeout=20)
self.model = "claude-3-5-sonnet-20241022" # default it for now self.model = "claude-3-5-sonnet-20241022" # default it for now
self.language = language self.language = language
self.prompt_template = ( self.prompt_template = (
prompt_template prompt_template
or "\n\nHuman: Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```\n\nAssistant: " or "Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```"
) )
self.prompt_sys_msg = prompt_sys_msg or ""
self.temperature = temperature self.temperature = temperature
self.context_flag = context_flag
self.context_list = []
self.context_translated_list = []
self.context_paragraph_limit = context_paragraph_limit
def rotate_key(self):
    """Key-rotation hook; intentionally a no-op for this client."""
    pass
def set_claude_model(self, model_name):
    """Select the Anthropic model used for subsequent translation requests."""
    self.model = model_name
def create_messages(self, text, intermediate_messages=None):
    """Build the message list for one translation request.

    Any *intermediate_messages* (e.g. a prior context exchange) are placed
    before the user message asking for *text* to be translated.
    """
    prompt = self.prompt_template.format(
        text=text,
        language=self.language,
    )
    history = list(intermediate_messages) if intermediate_messages else []
    return history + [{"role": "user", "content": prompt}]
def create_context_messages(self):
    """Return one user/assistant message pair carrying all saved context.

    Empty list when context tracking is disabled or nothing has been saved.
    """
    if not (self.context_flag and self.context_list):
        return []
    # Bundle every saved paragraph into a single prior exchange.
    source_blob = "\n\n".join(self.context_list)
    translated_blob = "\n\n".join(self.context_translated_list)
    user_msg = {
        "role": "user",
        "content": self.prompt_template.format(
            text=source_blob,
            language=self.language,
        ),
    }
    return [user_msg, {"role": "assistant", "content": translated_blob}]
def save_context(self, text, t_text):
    """Record a (source, translation) pair for use as future context.

    No-op when context tracking is disabled; once the paragraph limit is
    exceeded the oldest pair is dropped.
    """
    if not self.context_flag:
        return
    self.context_list.append(text)
    self.context_translated_list.append(t_text)
    # Trim from the front so only the most recent paragraphs remain.
    overflow = len(self.context_list) - self.context_paragraph_limit
    if overflow > 0:
        del self.context_list[0]
        del self.context_translated_list[0]
def translate(self, text):
    """Translate *text* via the Anthropic messages API and return the result.

    When context tracking is enabled, previously translated paragraphs are
    sent ahead of the request as a prior exchange, and the new pair is
    saved afterwards for subsequent calls.
    """
    print(text)
    self.rotate_key()
    # Fold any saved context in ahead of the current paragraph.
    messages = self.create_messages(text, self.create_context_messages())
    response = self.client.messages.create(
        model=self.model,
        system=self.prompt_sys_msg,
        messages=messages,
        temperature=self.temperature,
        max_tokens=4096,
    )
    translated = response.content[0].text
    if self.context_flag:
        self.save_context(text, translated)
    print("[bold green]" + re.sub("\n{3,}", "\n\n", translated) + "[/bold green]")
    return translated