diff --git a/book_maker/translator/gemini_translator.py b/book_maker/translator/gemini_translator.py index 0241d03..119114e 100644 --- a/book_maker/translator/gemini_translator.py +++ b/book_maker/translator/gemini_translator.py @@ -31,6 +31,11 @@ safety_settings = [ "threshold": "BLOCK_MEDIUM_AND_ABOVE", }, +PROMPT_ENV_MAP = { + "user": "BBM_GEMINIAPI_USER_MSG_TEMPLATE", + "system": "BBM_GEMINIAPI_SYS_MSG", +} + GEMINIPRO_MODEL_LIST = [ "gemini-1.5-pro", "gemini-1.5-pro-latest", @@ -57,17 +62,33 @@ class Gemini(Base): self, key, language, + prompt_template=None, + prompt_sys_msg=None, temperature=1.0, interval=0.01, **kwargs, ) -> None: super().__init__(key, language) self.interval = interval + self.prompt = ( + prompt_template + or environ.get(PROMPT_ENV_MAP["user"]) + or self.DEFAULT_PROMPT + ) + self.prompt_sys_msg = ( + prompt_sys_msg + or environ.get(PROMPT_ENV_MAP["system"]) + or None # Allow None, but not empty string + ) + generation_config["temperature"] = temperature + + def create_convo(self): model = genai.GenerativeModel( model_name=self.model, generation_config=generation_config, safety_settings=safety_settings, + system_instruction=self.prompt_sys_msg, ) self.convo = model.start_chat() # print(model) # Uncomment to debug and inspect the model details.