mirror of
https://github.com/yihong0618/bilingual_book_maker.git
synced 2025-06-05 19:15:34 +00:00
Add context and system message support for Claude (#438)
* option to select claude model * add context_flag and context_paragraph_limit option for claude * reformat with black * remove nonexistent model
This commit is contained in:
parent
daea974d68
commit
b80f1ba785
@ -375,7 +375,7 @@ So you are close to reaching the limit. You have to choose your own value, there
|
|||||||
API_KEY = options.deepl_key or env.get("BBM_DEEPL_API_KEY")
|
API_KEY = options.deepl_key or env.get("BBM_DEEPL_API_KEY")
|
||||||
if not API_KEY:
|
if not API_KEY:
|
||||||
raise Exception("Please provide deepl key")
|
raise Exception("Please provide deepl key")
|
||||||
elif options.model == "claude":
|
elif options.model.startswith("claude"):
|
||||||
API_KEY = options.claude_key or env.get("BBM_CLAUDE_API_KEY")
|
API_KEY = options.claude_key or env.get("BBM_CLAUDE_API_KEY")
|
||||||
if not API_KEY:
|
if not API_KEY:
|
||||||
raise Exception("Please provide claude key")
|
raise Exception("Please provide claude key")
|
||||||
@ -494,6 +494,8 @@ So you are close to reaching the limit. You have to choose your own value, there
|
|||||||
e.translate_model.set_gpt4omini_models()
|
e.translate_model.set_gpt4omini_models()
|
||||||
if options.model == "gpt4o":
|
if options.model == "gpt4o":
|
||||||
e.translate_model.set_gpt4o_models()
|
e.translate_model.set_gpt4o_models()
|
||||||
|
if options.model.startswith("claude-"):
|
||||||
|
e.translate_model.set_claude_model(options.model)
|
||||||
if options.block_size > 0:
|
if options.block_size > 0:
|
||||||
e.block_size = options.block_size
|
e.block_size = options.block_size
|
||||||
if options.batch_flag:
|
if options.batch_flag:
|
||||||
|
@ -21,6 +21,11 @@ MODEL_DICT = {
|
|||||||
"deepl": DeepL,
|
"deepl": DeepL,
|
||||||
"deeplfree": DeepLFree,
|
"deeplfree": DeepLFree,
|
||||||
"claude": Claude,
|
"claude": Claude,
|
||||||
|
"claude-3-5-sonnet-latest": Claude,
|
||||||
|
"claude-3-5-sonnet-20241022": Claude,
|
||||||
|
"claude-3-5-sonnet-20240620": Claude,
|
||||||
|
"claude-3-5-haiku-latest": Claude,
|
||||||
|
"claude-3-5-haiku-20241022": Claude,
|
||||||
"gemini": Gemini,
|
"gemini": Gemini,
|
||||||
"geminipro": Gemini,
|
"geminipro": Gemini,
|
||||||
"groq": GroqClient,
|
"groq": GroqClient,
|
||||||
|
@ -13,40 +13,99 @@ class Claude(Base):
|
|||||||
language,
|
language,
|
||||||
api_base=None,
|
api_base=None,
|
||||||
prompt_template=None,
|
prompt_template=None,
|
||||||
|
prompt_sys_msg=None,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
|
context_flag=False,
|
||||||
|
context_paragraph_limit=5,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(key, language)
|
super().__init__(key, language)
|
||||||
self.api_url = f"{api_base}" if api_base else "https://api.anthropic.com"
|
self.api_url = api_base or "https://api.anthropic.com"
|
||||||
self.client = Anthropic(base_url=api_base, api_key=key, timeout=20)
|
self.client = Anthropic(base_url=api_base, api_key=key, timeout=20)
|
||||||
self.model = "claude-3-5-sonnet-20241022" # default it for now
|
self.model = "claude-3-5-sonnet-20241022" # default it for now
|
||||||
self.language = language
|
self.language = language
|
||||||
self.prompt_template = (
|
self.prompt_template = (
|
||||||
prompt_template
|
prompt_template
|
||||||
or "\n\nHuman: Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```\n\nAssistant: "
|
or "Help me translate the text within triple backticks into {language} and provide only the translated result.\n```{text}```"
|
||||||
)
|
)
|
||||||
|
self.prompt_sys_msg = prompt_sys_msg or ""
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
self.context_flag = context_flag
|
||||||
|
self.context_list = []
|
||||||
|
self.context_translated_list = []
|
||||||
|
self.context_paragraph_limit = context_paragraph_limit
|
||||||
|
|
||||||
def rotate_key(self):
    """No-op hook: a single API key is configured, so there is nothing to rotate."""
    pass
|
|
||||||
|
def set_claude_model(self, model_name):
    """Select the Claude model identifier used by subsequent translate() calls."""
    self.model = model_name
||||||
|
|
||||||
|
def create_messages(self, text, intermediate_messages=None):
    """Build the message list for one translation request.

    Any *intermediate_messages* (e.g. a saved context pair) are placed
    ahead of the new user message so the model reads them first.
    """
    prompt = self.prompt_template.format(text=text, language=self.language)
    history = list(intermediate_messages) if intermediate_messages else []
    return history + [{"role": "user", "content": prompt}]
||||||
|
|
||||||
|
def create_context_messages(self):
    """Return a user/assistant message pair carrying all saved context.

    Yields an empty list when context is disabled or nothing has been
    saved yet, so the result can always be fed to create_messages().
    """
    if not (self.context_flag and self.context_list):
        return []

    # Bundle every saved paragraph (and its translation) into one exchange.
    source_blob = "\n\n".join(self.context_list)
    translated_blob = "\n\n".join(self.context_translated_list)
    user_msg = {
        "role": "user",
        "content": self.prompt_template.format(
            text=source_blob,
            language=self.language,
        ),
    }
    return [user_msg, {"role": "assistant", "content": translated_blob}]
||||||
|
|
||||||
|
def save_context(self, text, t_text):
    """Record one source/translation pair, evicting the oldest beyond the limit."""
    if not self.context_flag:
        return

    self.context_list.append(text)
    self.context_translated_list.append(t_text)

    # Keep the rolling window at context_paragraph_limit pairs:
    # drop from the front (oldest) while over the cap.
    while len(self.context_list) > self.context_paragraph_limit:
        del self.context_list[0]
        del self.context_translated_list[0]
||||||
|
|
||||||
def translate(self, text):
    """Translate *text* via the Anthropic messages API, threading saved context through."""
    print(text)
    self.rotate_key()

    # Prepend any saved context pair so the model keeps terminology consistent.
    request_messages = self.create_messages(text, self.create_context_messages())

    response = self.client.messages.create(
        max_tokens=4096,
        messages=request_messages,
        system=self.prompt_sys_msg,
        temperature=self.temperature,
        model=self.model,
    )
    t_text = response.content[0].text

    if self.context_flag:
        self.save_context(text, t_text)

    # Collapse runs of 3+ newlines before echoing the colored result.
    condensed = re.sub("\n{3,}", "\n\n", t_text)
    print("[bold green]" + condensed + "[/bold green]")
    return t_text
||||||
|
Loading…
x
Reference in New Issue
Block a user