From 06cba4f3ca35384ecd034f8a77319d52d92dd82b Mon Sep 17 00:00:00 2001
From: tcsenpai
Date: Tue, 17 Sep 2024 12:19:25 +0200
Subject: [PATCH] unify API handlers under a more elegant base-class approach

---
 .gitignore      |   4 +-
 api_handlers.py | 209 ++++++++++++++++++++++++------------------------
 2 files changed, 108 insertions(+), 105 deletions(-)

diff --git a/.gitignore b/.gitignore
index 686ade3..4919cbe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 .env
 venv
-.venv
\ No newline at end of file
+.venv
+__pycache__/
+.venv/
\ No newline at end of file
" - if not is_final_answer: - for i in range(len(messages)): - if messages[i]["role"] == "assistant": - messages.pop(i) - - for attempt in range(3): - try: - url = "https://api.perplexity.ai/chat/completions" - payload = {"model": self.model, "messages": messages} - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - response = requests.post(url, json=payload, headers=headers) - - # Add specific handling for 400 error - if response.status_code == 400: - error_content = response.json() - print(f"HTTP 400 Error: {error_content}") - return self._error_response(f"HTTP 400 Error: {error_content}", is_final_answer) - - response.raise_for_status() - content = response.json()["choices"][0]["message"]["content"] - print("Content: ", content) - return json.loads(content) - except json.JSONDecodeError: - print("Warning: content is not a valid JSON, returning raw response") - # Better detection of final answer in the raw response for Perplexity - forced_final_answer = False - if '"next_action": "final_answer"' in content.lower().strip(): - forced_final_answer = True - print("Forced final answer: ", forced_final_answer) - - return { - "title": "Raw Response", - "content": content, - "next_action": "final_answer" if (is_final_answer|forced_final_answer) else "continue" - } - except requests.exceptions.RequestException as e: - print(f"Request failed: {e}") - if attempt == 2: - return self._error_response(str(e), is_final_answer) - time.sleep(1) + def _process_response(self, response, is_final_answer): + return json.loads(response) def _error_response(self, error_msg, is_final_answer): return { "title": "Error", - "content": f"API request failed after 3 attempts. Error: {error_msg}", - "next_action": "final_answer", + "content": f"Failed to generate {'final answer' if is_final_answer else 'step'} after {self.max_attempts} attempts. 
-class GroqHandler:
+class OllamaHandler(BaseHandler):
+    def __init__(self, url, model):
+        super().__init__()
+        self.url = url
+        self.model = model
+
+    def _make_request(self, messages, max_tokens):
+        response = requests.post(
+            f"{self.url}/api/chat",
+            json={
+                "model": self.model,
+                "messages": messages,
+                "stream": False,
+                "format": "json",
+                "options": {
+                    "num_predict": max_tokens,
+                    "temperature": 0.2
+                }
+            }
+        )
+        response.raise_for_status()
+        return response.json()["message"]["content"]
+
+class PerplexityHandler(BaseHandler):
+    def __init__(self, api_key, model):
+        super().__init__()
+        self.api_key = api_key
+        self.model = model
+
+    def _clean_messages(self, messages):
+        cleaned_messages = []
+        last_role = None
+        for message in messages:
+            if message["role"] == "system":
+                cleaned_messages.append(message)
+            elif message["role"] != last_role:
+                cleaned_messages.append(message)
+                last_role = message["role"]
+            elif message["role"] == "user":
+                cleaned_messages[-1]["content"] += "\n" + message["content"]
+        # If the last message is an assistant message, delete it
+        if cleaned_messages and cleaned_messages[-1]["role"] == "assistant":
+            cleaned_messages.pop()
+        return cleaned_messages
+
+    def _make_request(self, messages, max_tokens):
+        cleaned_messages = self._clean_messages(messages)
+
+        url = "https://api.perplexity.ai/chat/completions"
+        payload = {"model": self.model, "messages": cleaned_messages}
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+        try:
+            response = requests.post(url, json=payload, headers=headers)
+            response.raise_for_status()
+            return response.json()["choices"][0]["message"]["content"]
+        except requests.exceptions.HTTPError as http_err:
+            if response.status_code == 400:
+                error_message = response.json().get("error", {}).get("message", "Unknown error")
+                raise ValueError(f"Bad request (400): {error_message}")
+            raise  # Re-raise the exception if it's not a 400 error
+
+    def _process_response(self, response, is_final_answer):
+        try:
+            return super()._process_response(response, is_final_answer)
+        except json.JSONDecodeError:
+            print("Warning: content is not a valid JSON, returning raw response")
+            forced_final_answer = '"next_action": "final_answer"' in response.lower().strip()
+            return {
+                "title": "Raw Response",
+                "content": response,
+                "next_action": "final_answer" if (is_final_answer or forced_final_answer) else "continue"
+            }
+
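_clean_messages above enforces the message shape the Perplexity chat API expects: after the system prompt, user and assistant turns alternate, consecutive user messages are merged, and a trailing assistant message is dropped so the request ends on a user turn. One review note: the merge branch appends to cleaned_messages[-1]["content"] and therefore mutates the caller's message dicts in place. A quick sketch of the effect; the api_key and model values are dummy placeholders:

    from api_handlers import PerplexityHandler

    handler = PerplexityHandler(api_key="dummy-key", model="dummy-model")
    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Step 1"},
        {"role": "user", "content": "Step 2"},      # same role as before: merged
        {"role": "assistant", "content": "Okay."},  # trailing assistant: dropped
    ]
    print(handler._clean_messages(messages))
    # [{'role': 'system', 'content': 'You are helpful.'},
    #  {'role': 'user', 'content': 'Step 1\nStep 2'}]
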
+class GroqHandler(BaseHandler):
     def __init__(self):
+        super().__init__()
         self.client = groq.Groq()
 
-    def make_api_call(self, messages, max_tokens, is_final_answer=False):
-        for attempt in range(3):
-            try:
-                response = self.client.chat.completions.create(
-                    model="llama-3.1-70b-versatile",
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=0.2,
-                    response_format={"type": "json_object"}
-                )
-                return json.loads(response.choices[0].message.content)
-            except Exception as e:
-                if attempt == 2:
-                    return self._error_response(str(e), is_final_answer)
-                time.sleep(1)
-
-    def _error_response(self, error_msg, is_final_answer):
-        if is_final_answer:
-            return {"title": "Error", "content": f"Failed to generate final answer after 3 attempts. Error: {error_msg}"}
-        else:
-            return {"title": "Error", "content": f"Failed to generate step after 3 attempts. Error: {error_msg}", "next_action": "final_answer"}
\ No newline at end of file
+    def _make_request(self, messages, max_tokens):
+        response = self.client.chat.completions.create(
+            model="llama-3.1-70b-versatile",
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=0.2,
+            response_format={"type": "json_object"}
+        )
+        return response.choices[0].message.content
\ No newline at end of file
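With the base class in place, all three handlers expose the same call surface and are interchangeable from the caller's side; only construction differs. A usage sketch under stated assumptions: the Ollama URL and model name are examples, PPLX_API_KEY is a placeholder invented for this sketch, and groq.Groq() reads GROQ_API_KEY from the environment by default:

    import os
    from api_handlers import OllamaHandler, PerplexityHandler, GroqHandler

    handler = OllamaHandler(url="http://localhost:11434", model="llama3.1")
    # handler = PerplexityHandler(api_key=os.environ["PPLX_API_KEY"], model="sonar")  # hypothetical model name
    # handler = GroqHandler()  # requires GROQ_API_KEY in the environment

    messages = [
        {"role": "system", "content": "Respond only with the step JSON format."},
        {"role": "user", "content": "What is 2 + 2?"},
    ]
    step = handler.make_api_call(messages, max_tokens=300)
    print(step["title"], step.get("next_action"))

Keeping _process_response overridable is a nice touch: PerplexityHandler can degrade gracefully when the model returns non-JSON, while OllamaHandler and GroqHandler lean on their providers' native JSON modes (format: "json" and response_format={"type": "json_object"} respectively).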