Mirror of https://github.com/yihong0618/bilingual_book_maker.git, synced 2025-06-06 11:35:49 +00:00
feat: add temperature parameter (#278)
* add `temperature` parameter to OpenAI-based translators and `Claude`
* add `temperature` parameter to book loaders
* add `--temperature` option to the CLI
Parent: 9ba0dd4b91
Commit: 20ad3ba7c1
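The change threads a single sampling-temperature value from the new CLI option through the book loaders into each translator's API call. The sketch below is a minimal, self-contained illustration of that flow, not the project's code: `DemoTranslator` and the `__main__` wiring are hypothetical stand-ins for the real loader/translator classes, and it uses the legacy `openai.ChatCompletion.create` call that the diff itself targets.

# Minimal sketch of the temperature flow: CLI option -> translator constructor -> API call.
# DemoTranslator is a hypothetical stand-in for ChatGPTAPI/GPT4; only the wiring is shown.
import argparse

import openai  # legacy 0.x SDK, matching the openai.ChatCompletion.create calls in the diff


class DemoTranslator:
    def __init__(self, key, language, temperature=1.0, **kwargs):
        openai.api_key = key
        self.language = language
        self.temperature = temperature  # stored once, reused for every request

    def translate(self, text):
        resp = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": f"Translate to {self.language}: {text}"}],
            temperature=self.temperature,  # the new knob; default 1.0 keeps the previous behaviour
        )
        return resp["choices"][0]["message"]["content"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--temperature", type=float, default=1.0)
    options = parser.parse_args()
    translator = DemoTranslator("sk-...", "zh-hans", temperature=options.temperature)
    print(translator.translate("Hello, world."))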
@@ -247,6 +247,12 @@ So you are close to reaching the limit. You have to choose your own value, there
         action="store_true",
         help="adds an additional paragraph for global, updating historical context of the story to the model's input, improving the narrative consistency for the AI model (this uses ~200 more tokens each time)",
     )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="temperature parameter for `gpt3`/`chatgptapi`/`gpt4`/`claude`",
+    )

     options = parser.parse_args()

@@ -331,6 +337,7 @@ So you are close to reaching the limit. You have to choose your own value, there
         prompt_config=parse_prompt_arg(options.prompt_arg),
         single_translate=options.single_translate,
         context_flag=options.context_flag,
+        temperature=options.temperature,
     )
     # other options
     if options.allow_navigable_strings:
@@ -31,6 +31,7 @@ class EPUBBookLoader(BaseBookLoader):
         prompt_config=None,
         single_translate=False,
         context_flag=False,
+        temperature=1.0,
     ):
         self.epub_name = epub_name
         self.new_epub = epub.EpubBook()
@@ -39,6 +40,7 @@ class EPUBBookLoader(BaseBookLoader):
             language,
             api_base=model_api_base,
             context_flag=context_flag,
+            temperature=temperature,
             **prompt_config_to_kwargs(prompt_config),
         )
         self.is_test = is_test
@@ -24,12 +24,14 @@ class SRTBookLoader(BaseBookLoader):
         prompt_config=None,
         single_translate=False,
         context_flag=False,
+        temperature=1.0,
     ) -> None:
         self.srt_name = srt_name
         self.translate_model = model(
             key,
             language,
             api_base=model_api_base,
+            temperature=temperature,
             **prompt_config_to_kwargs(
                 {
                     "system": "You are a srt subtitle file translator.",
@@ -20,12 +20,14 @@ class TXTBookLoader(BaseBookLoader):
         prompt_config=None,
         single_translate=False,
         context_flag=False,
+        temperature=1.0,
     ) -> None:
         self.txt_name = txt_name
         self.translate_model = model(
             key,
             language,
             api_base=model_api_base,
+            temperature=temperature,
             **prompt_config_to_kwargs(prompt_config),
         )
         self.is_test = is_test
@@ -24,6 +24,7 @@ class ChatGPTAPI(Base):
         api_base=None,
         prompt_template=None,
         prompt_sys_msg=None,
+        temperature=1.0,
         **kwargs,
     ) -> None:
         super().__init__(key, language)
@@ -46,6 +47,7 @@ class ChatGPTAPI(Base):
         )
         self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
         self.deployment_id = None
+        self.temperature = temperature

     def rotate_key(self):
         openai.api_key = next(self.keys)
@@ -64,11 +66,13 @@ class ChatGPTAPI(Base):
             return openai.ChatCompletion.create(
                 engine=self.deployment_id,
                 messages=messages,
+                temperature=self.temperature,
             )

         return openai.ChatCompletion.create(
             model="gpt-3.5-turbo",
             messages=messages,
+            temperature=self.temperature,
         )

     def get_translation(self, text):
@@ -7,7 +7,13 @@ from .base_translator import Base

 class Claude(Base):
     def __init__(
-        self, key, language, api_base=None, prompt_template=None, **kwargs
+        self,
+        key,
+        language,
+        api_base=None,
+        prompt_template=None,
+        temperature=1.0,
+        **kwargs,
     ) -> None:
         super().__init__(key, language)
         self.api_url = (
@@ -23,7 +29,7 @@ class Claude(Base):
             "prompt": "",
             "model": "claude-v1.3",
             "max_tokens_to_sample": 1024,
-            "temperature": 1,
+            "temperature": temperature,
             "stop_sequences": ["\n\nHuman:"],
         }
         self.session = requests.session()
@@ -7,7 +7,13 @@ from .base_translator import Base

 class GPT3(Base):
     def __init__(
-        self, key, language, api_base=None, prompt_template=None, **kwargs
+        self,
+        key,
+        language,
+        api_base=None,
+        prompt_template=None,
+        temperature=1.0,
+        **kwargs,
     ) -> None:
         super().__init__(key, language)
         self.api_url = (
@@ -23,7 +29,7 @@ class GPT3(Base):
             "prompt": "",
             "model": "text-davinci-003",
             "max_tokens": 1024,
-            "temperature": 1,
+            "temperature": temperature,
             "top_p": 1,
         }
         self.session = requests.session()
@@ -25,6 +25,7 @@ class GPT4(Base):
         prompt_template=None,
         prompt_sys_msg=None,
         context_flag=False,
+        temperature=1.0,
         **kwargs,
     ) -> None:
         super().__init__(key, language)
@@ -49,6 +50,7 @@ class GPT4(Base):
         )
         self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
         self.deployment_id = None
+        self.temperature = temperature

     def rotate_key(self):
         openai.api_key = next(self.keys)
@@ -75,11 +77,13 @@ class GPT4(Base):
             return openai.ChatCompletion.create(
                 engine=self.deployment_id,
                 messages=messages,
+                temperature=self.temperature,
             )

         return openai.ChatCompletion.create(
             model="gpt-4",
             messages=messages,
+            temperature=self.temperature,
         )

     def get_translation(self, text):
|
Loading…
x
Reference in New Issue
Block a user