From 311a292c5752f72b470315f6ee751e70bdeea3dd Mon Sep 17 00:00:00 2001 From: tcsenpai Date: Tue, 17 Sep 2024 21:04:10 +0200 Subject: [PATCH] added comments for reference --- app/api_handlers.py | 42 +++++++++++++++++++++++++++++++++++++++--- app/main.py | 17 ++++++++++++++++- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/app/api_handlers.py b/app/api_handlers.py index 6d501fa..1570b57 100644 --- a/app/api_handlers.py +++ b/app/api_handlers.py @@ -4,16 +4,19 @@ import groq import time from abc import ABC, abstractmethod +# Abstract base class for API handlers class BaseHandler(ABC): def __init__(self): - self.max_attempts = 3 - self.retry_delay = 1 + self.max_attempts = 3 # Maximum number of retry attempts + self.retry_delay = 1 # Delay between retry attempts in seconds @abstractmethod def _make_request(self, messages, max_tokens): + # Abstract method to be implemented by subclasses pass def make_api_call(self, messages, max_tokens, is_final_answer=False): + # Attempt to make an API call with retry logic for attempt in range(self.max_attempts): try: response = self._make_request(messages, max_tokens) @@ -24,15 +27,18 @@ class BaseHandler(ABC): time.sleep(self.retry_delay) def _process_response(self, response, is_final_answer): + # Default response processing (can be overridden by subclasses) return json.loads(response) def _error_response(self, error_msg, is_final_answer): + # Generate an error response return { "title": "Error", "content": f"Failed to generate {'final answer' if is_final_answer else 'step'} after {self.max_attempts} attempts. Error: {error_msg}", "next_action": "final_answer" if is_final_answer else "continue" } +# Handler for Ollama API class OllamaHandler(BaseHandler): def __init__(self, url, model): super().__init__() @@ -40,6 +46,7 @@ class OllamaHandler(BaseHandler): self.model = model def _make_request(self, messages, max_tokens): + # Make a request to the Ollama API response = requests.post( f"{self.url}/api/chat", json={ @@ -57,6 +64,30 @@ class OllamaHandler(BaseHandler): print(response.json()) return response.json()["message"]["content"] + def _process_response(self, response, is_final_answer): + # Process the Ollama API response + if isinstance(response, dict) and 'message' in response: + content = response['message']['content'] + else: + content = response + + try: + parsed_content = json.loads(content) + if 'final_answer' in parsed_content: + return { + "title": "Final Answer", + "content": parsed_content['final_answer'], + "next_action": "final_answer" + } + return parsed_content + except json.JSONDecodeError: + return { + "title": "Raw Response", + "content": content, + "next_action": "final_answer" if is_final_answer else "continue" + } + +# Handler for Perplexity API class PerplexityHandler(BaseHandler): def __init__(self, api_key, model): super().__init__() @@ -64,6 +95,7 @@ class PerplexityHandler(BaseHandler): self.model = model def _clean_messages(self, messages): + # Clean and consolidate messages for the Perplexity API cleaned_messages = [] last_role = None for message in messages: @@ -74,12 +106,13 @@ class PerplexityHandler(BaseHandler): last_role = message["role"] elif message["role"] == "user": cleaned_messages[-1]["content"] += "\n" + message["content"] - # If the last message is an assistant message, delete it + # Remove the last assistant message if present if cleaned_messages and cleaned_messages[-1]["role"] == "assistant": cleaned_messages.pop() return cleaned_messages def _make_request(self, messages, max_tokens): + # Make a request to the Perplexity API cleaned_messages = self._clean_messages(messages) url = "https://api.perplexity.ai/chat/completions" @@ -99,6 +132,7 @@ class PerplexityHandler(BaseHandler): raise # Re-raise the exception if it's not a 400 error def _process_response(self, response, is_final_answer): + # Process the Perplexity API response try: return super()._process_response(response, is_final_answer) except json.JSONDecodeError: @@ -110,12 +144,14 @@ class PerplexityHandler(BaseHandler): "next_action": "final_answer" if (is_final_answer or forced_final_answer) else "continue" } +# Handler for Groq API class GroqHandler(BaseHandler): def __init__(self): super().__init__() self.client = groq.Groq() def _make_request(self, messages, max_tokens): + # Make a request to the Groq API response = self.client.chat.completions.create( model="llama-3.1-70b-versatile", messages=messages, diff --git a/app/main.py b/app/main.py index b460c6e..dd9c1f2 100644 --- a/app/main.py +++ b/app/main.py @@ -6,21 +6,27 @@ from config_menu import config_menu, display_config from logger import logger import os -# Load environment variables +# Load environment variables from .env file load_dotenv() def load_css(): + # Load custom CSS styles with open(os.path.join(os.path.dirname(__file__), "..", "static", "styles.css")) as f: st.markdown(f'', unsafe_allow_html=True) def setup_page(): + # Configure the Streamlit page st.set_page_config(page_title="multi1 - Unified AI Reasoning Chains", page_icon="🧠", layout="wide") load_css() + + # Display the main title st.markdown("""

🧠 multi1 - Unified AI Reasoning Chains

""", unsafe_allow_html=True) + + # Display the app description st.markdown("""

This app demonstrates AI reasoning chains using different backends: Ollama, Perplexity AI, and Groq. @@ -29,6 +35,7 @@ def setup_page(): """, unsafe_allow_html=True) def get_api_handler(backend, config): + # Create and return the appropriate API handler based on the selected backend if backend == "Ollama": return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL']) elif backend == "Perplexity AI": @@ -40,14 +47,17 @@ def main(): logger.info("Starting the application") setup_page() + # Set up the sidebar for configuration st.sidebar.markdown('

', unsafe_allow_html=True) config = config_menu() + # Allow user to select the AI backend backend = st.sidebar.selectbox("Choose AI Backend", ["Ollama", "Perplexity AI", "Groq"]) display_config(backend, config) api_handler = get_api_handler(backend, config) logger.info(f"Selected backend: {backend}") + # User input field user_query = st.text_input("💬 Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?") if user_query: @@ -57,22 +67,27 @@ def main(): time_container = st.empty() try: + # Generate and display the response for steps, total_thinking_time in generate_response(user_query, api_handler): with response_container.container(): for title, content, _ in steps: if title.startswith("Final Answer"): + # Display the final answer st.markdown(f'

🎯 {title}

', unsafe_allow_html=True) st.markdown(f'
{content}
', unsafe_allow_html=True) logger.info(f"Final answer generated: {content}") else: + # Display intermediate steps with st.expander(f"📝 {title}", expanded=True): st.markdown(f'
{content}
', unsafe_allow_html=True) logger.debug(f"Step completed: {title}") + # Display total thinking time if total_thinking_time is not None: time_container.markdown(f'

⏱️ Total thinking time: {total_thinking_time:.2f} seconds

', unsafe_allow_html=True) logger.info(f"Total thinking time: {total_thinking_time:.2f} seconds") except Exception as e: + # Handle and display any errors logger.error(f"Error generating response: {str(e)}", exc_info=True) st.error("An error occurred while generating the response. Please try again.")