feat: add support for --prompt option to gemini

This commit is contained in:
Risin 2024-10-15 18:47:29 +09:00
parent 0e13949b19
commit 4649ad1aaf

View File

@ -31,6 +31,11 @@ safety_settings = [
"threshold": "BLOCK_MEDIUM_AND_ABOVE", "threshold": "BLOCK_MEDIUM_AND_ABOVE",
}, },
PROMPT_ENV_MAP = {
"user": "BBM_GEMINIAPI_USER_MSG_TEMPLATE",
"system": "BBM_GEMINIAPI_SYS_MSG",
}
GEMINIPRO_MODEL_LIST = [ GEMINIPRO_MODEL_LIST = [
"gemini-1.5-pro", "gemini-1.5-pro",
"gemini-1.5-pro-latest", "gemini-1.5-pro-latest",
@ -57,17 +62,33 @@ class Gemini(Base):
self, self,
key, key,
language, language,
prompt_template=None,
prompt_sys_msg=None,
temperature=1.0, temperature=1.0,
interval=0.01, interval=0.01,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(key, language) super().__init__(key, language)
self.interval = interval self.interval = interval
self.prompt = (
prompt_template
or environ.get(PROMPT_ENV_MAP["user"])
or self.DEFAULT_PROMPT
)
self.prompt_sys_msg = (
prompt_sys_msg
or environ.get(PROMPT_ENV_MAP["system"])
or None # Allow None, but not empty string
)
generation_config["temperature"] = temperature generation_config["temperature"] = temperature
def create_convo(self):
model = genai.GenerativeModel( model = genai.GenerativeModel(
model_name=self.model, model_name=self.model,
generation_config=generation_config, generation_config=generation_config,
safety_settings=safety_settings, safety_settings=safety_settings,
system_instruction=self.prompt_sys_msg,
) )
self.convo = model.start_chat() self.convo = model.start_chat()
# print(model) # Uncomment to debug and inspect the model details. # print(model) # Uncomment to debug and inspect the model details.