diff --git a/app/utils.py b/app/utils.py
index aeddd9e..62a3516 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,9 +1,10 @@
 import json
 import time
 import os
+from api_handlers import BaseHandler
 
 
-def generate_response(prompt, api_handler):
+def generate_response(prompt: str, api_handler: BaseHandler, max_steps: int = 5, max_tokens: int = 512, temperature: float = 0.2, timeout: float = 30.0, sleeptime: float = 2.0):
     messages = [
         {
             "role": "system",
@@ -20,9 +21,10 @@ def generate_response(prompt, api_handler):
     step_count = 1
     total_thinking_time = 0
 
-    while True:
+    for _ in range(max_steps):
+        time.sleep(sleeptime)  # throttle requests to avoid "too many requests" rate-limit errors
         start_time = time.time()
-        step_data = api_handler.make_api_call(messages, 300)
+        step_data = api_handler.make_api_call(messages, max_tokens=max_tokens, temperature=temperature, timeout=timeout)
         end_time = time.time()
         thinking_time = end_time - start_time
         total_thinking_time += thinking_time
@@ -64,10 +66,7 @@ def generate_response(prompt, api_handler):
 
 def load_env_vars():
     return {
-        "OLLAMA_URL": os.getenv("OLLAMA_URL", "http://localhost:11434"),
-        "OLLAMA_MODEL": os.getenv("OLLAMA_MODEL", "llama3.1:70b"),
-        "PERPLEXITY_API_KEY": os.getenv("PERPLEXITY_API_KEY"),
-        "PERPLEXITY_MODEL": os.getenv(
-            "PERPLEXITY_MODEL", "llama-3.1-sonar-small-128k-online"
-        ),
+        "MODEL": os.getenv("MODEL", "gemini/gemini-1.5-pro"),
+        "MODEL_API_KEY": os.getenv("MODEL_API_KEY"),
     }
+