added comments for reference

tcsenpai 2024-09-17 21:04:10 +02:00
parent 4dc438f536
commit 311a292c57
2 changed files with 55 additions and 4 deletions

View File

@@ -4,16 +4,19 @@ import groq
import time
from abc import ABC, abstractmethod

# Abstract base class for API handlers
class BaseHandler(ABC):
    def __init__(self):
        self.max_attempts = 3  # Maximum number of retry attempts
        self.retry_delay = 1  # Delay between retry attempts in seconds

    @abstractmethod
    def _make_request(self, messages, max_tokens):
        # Abstract method to be implemented by subclasses
        pass

    def make_api_call(self, messages, max_tokens, is_final_answer=False):
        # Attempt to make an API call with retry logic
        for attempt in range(self.max_attempts):
            try:
                response = self._make_request(messages, max_tokens)
@@ -24,15 +27,18 @@ class BaseHandler(ABC):
                time.sleep(self.retry_delay)

    def _process_response(self, response, is_final_answer):
        # Default response processing (can be overridden by subclasses)
        return json.loads(response)

    def _error_response(self, error_msg, is_final_answer):
        # Generate an error response
        return {
            "title": "Error",
            "content": f"Failed to generate {'final answer' if is_final_answer else 'step'} after {self.max_attempts} attempts. Error: {error_msg}",
            "next_action": "final_answer" if is_final_answer else "continue"
        }
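The two hunks above truncate the body of make_api_call between the request and the retry sleep. A minimal sketch of how the loop plausibly completes, assuming a catch-all except and a last-attempt fallback to _error_response (the exact handling is not shown in this diff):

    # Sketch only: plausible completion of make_api_call, inferred from the
    # visible pieces above; the real exception handling may differ.
    def make_api_call(self, messages, max_tokens, is_final_answer=False):
        for attempt in range(self.max_attempts):
            try:
                response = self._make_request(messages, max_tokens)
                return self._process_response(response, is_final_answer)
            except Exception as e:
                if attempt == self.max_attempts - 1:
                    return self._error_response(str(e), is_final_answer)
                time.sleep(self.retry_delay)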
# Handler for Ollama API
class OllamaHandler(BaseHandler):
    def __init__(self, url, model):
        super().__init__()
@@ -40,6 +46,7 @@ class OllamaHandler(BaseHandler):
        self.model = model

    def _make_request(self, messages, max_tokens):
        # Make a request to the Ollama API
        response = requests.post(
            f"{self.url}/api/chat",
            json={
@@ -57,6 +64,30 @@ class OllamaHandler(BaseHandler):
        print(response.json())
        return response.json()["message"]["content"]

    def _process_response(self, response, is_final_answer):
        # Process the Ollama API response
        if isinstance(response, dict) and 'message' in response:
            content = response['message']['content']
        else:
            content = response
        try:
            parsed_content = json.loads(content)
            if 'final_answer' in parsed_content:
                return {
                    "title": "Final Answer",
                    "content": parsed_content['final_answer'],
                    "next_action": "final_answer"
                }
            return parsed_content
        except json.JSONDecodeError:
            return {
                "title": "Raw Response",
                "content": content,
                "next_action": "final_answer" if is_final_answer else "continue"
            }
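This override accepts either a raw Ollama message dict or a plain string, and wraps non-JSON output instead of raising. An illustration with made-up inputs (the URL and model name are placeholders, not values from this commit):

    # Illustration only; URL and model are placeholders.
    handler = OllamaHandler("http://localhost:11434", "llama3.1")

    # A JSON step from the model parses cleanly:
    handler._process_response('{"title": "Step 1", "content": "...", "next_action": "continue"}', False)
    # -> {'title': 'Step 1', 'content': '...', 'next_action': 'continue'}

    # Non-JSON text is wrapped rather than raising JSONDecodeError:
    handler._process_response("just some prose", True)
    # -> {'title': 'Raw Response', 'content': 'just some prose', 'next_action': 'final_answer'}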
# Handler for Perplexity API
class PerplexityHandler(BaseHandler):
    def __init__(self, api_key, model):
        super().__init__()
@@ -64,6 +95,7 @@ class PerplexityHandler(BaseHandler):
        self.model = model

    def _clean_messages(self, messages):
        # Clean and consolidate messages for the Perplexity API
        cleaned_messages = []
        last_role = None
        for message in messages:
@@ -74,12 +106,13 @@ class PerplexityHandler(BaseHandler):
                last_role = message["role"]
            elif message["role"] == "user":
                cleaned_messages[-1]["content"] += "\n" + message["content"]

        # Remove the last assistant message if present
        if cleaned_messages and cleaned_messages[-1]["role"] == "assistant":
            cleaned_messages.pop()
        return cleaned_messages
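Perplexity's chat endpoint rejects consecutive messages from the same role, which is what this consolidation appears to work around (the branch that appends a new role falls outside this hunk). A worked example under that reading, with a hypothetical input:

    # Worked example (hypothetical input): consecutive user turns merge,
    # and a trailing assistant turn is dropped before sending.
    messages = [
        {"role": "system", "content": "You are a reasoner."},
        {"role": "user", "content": "Question"},
        {"role": "user", "content": "More context"},
        {"role": "assistant", "content": "Partial step"},
    ]
    # _clean_messages(messages) ->
    # [{"role": "system", "content": "You are a reasoner."},
    #  {"role": "user", "content": "Question\nMore context"}]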
    def _make_request(self, messages, max_tokens):
        # Make a request to the Perplexity API
        cleaned_messages = self._clean_messages(messages)
        url = "https://api.perplexity.ai/chat/completions"
@@ -99,6 +132,7 @@ class PerplexityHandler(BaseHandler):
            raise  # Re-raise the exception if it's not a 400 error

    def _process_response(self, response, is_final_answer):
        # Process the Perplexity API response
        try:
            return super()._process_response(response, is_final_answer)
        except json.JSONDecodeError:
@@ -110,12 +144,14 @@ class PerplexityHandler(BaseHandler):
                "next_action": "final_answer" if (is_final_answer or forced_final_answer) else "continue"
            }

# Handler for Groq API
class GroqHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self.client = groq.Groq()

    def _make_request(self, messages, max_tokens):
        # Make a request to the Groq API
        response = self.client.chat.completions.create(
            model="llama-3.1-70b-versatile",
            messages=messages,
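All three handlers share the BaseHandler interface, so callers only touch make_api_call and can swap backends freely. A minimal usage sketch; the message and max_tokens value are illustrative, and groq.Groq() reads its API key from the environment:

    # Illustrative usage; any handler works behind the same interface.
    handler = GroqHandler()  # or OllamaHandler(url, model) / PerplexityHandler(key, model)
    step = handler.make_api_call(
        [{"role": "user", "content": "How many 'R's are in strawberry?"}],
        max_tokens=300,
    )
    print(step["title"], step["next_action"])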

View File

@@ -6,21 +6,27 @@ from config_menu import config_menu, display_config
from logger import logger
import os

# Load environment variables from .env file
load_dotenv()

def load_css():
    # Load custom CSS styles
    with open(os.path.join(os.path.dirname(__file__), "..", "static", "styles.css")) as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

def setup_page():
    # Configure the Streamlit page
    st.set_page_config(page_title="multi1 - Unified AI Reasoning Chains", page_icon="🧠", layout="wide")
    load_css()

    # Display the main title
    st.markdown("""
    <h1 class="main-title">
        🧠 multi1 - Unified AI Reasoning Chains
    </h1>
    """, unsafe_allow_html=True)

    # Display the app description
    st.markdown("""
    <p class="main-description">
        This app demonstrates AI reasoning chains using different backends: Ollama, Perplexity AI, and Groq.
@@ -29,6 +35,7 @@ def setup_page():
    """, unsafe_allow_html=True)

def get_api_handler(backend, config):
    # Create and return the appropriate API handler based on the selected backend
    if backend == "Ollama":
        return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
    elif backend == "Perplexity AI":
@@ -40,14 +47,17 @@ def main():
    logger.info("Starting the application")
    setup_page()

    # Set up the sidebar for configuration
    st.sidebar.markdown('<h3 class="sidebar-title">⚙️ Settings</h3>', unsafe_allow_html=True)
    config = config_menu()

    # Allow user to select the AI backend
    backend = st.sidebar.selectbox("Choose AI Backend", ["Ollama", "Perplexity AI", "Groq"])
    display_config(backend, config)
    api_handler = get_api_handler(backend, config)
    logger.info(f"Selected backend: {backend}")

    # User input field
    user_query = st.text_input("💬 Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")

    if user_query:
@@ -57,22 +67,27 @@ def main():
        time_container = st.empty()

        try:
            # Generate and display the response
            for steps, total_thinking_time in generate_response(user_query, api_handler):
                with response_container.container():
                    for title, content, _ in steps:
                        if title.startswith("Final Answer"):
                            # Display the final answer
                            st.markdown(f'<h3 class="expander-title">🎯 {title}</h3>', unsafe_allow_html=True)
                            st.markdown(f'<div>{content}</div>', unsafe_allow_html=True)
                            logger.info(f"Final answer generated: {content}")
                        else:
                            # Display intermediate steps
                            with st.expander(f"📝 {title}", expanded=True):
                                st.markdown(f'<div>{content}</div>', unsafe_allow_html=True)
                            logger.debug(f"Step completed: {title}")

                # Display total thinking time
                if total_thinking_time is not None:
                    time_container.markdown(f'<p class="thinking-time">⏱️ Total thinking time: {total_thinking_time:.2f} seconds</p>', unsafe_allow_html=True)
                    logger.info(f"Total thinking time: {total_thinking_time:.2f} seconds")
        except Exception as e:
            # Handle and display any errors
            logger.error(f"Error generating response: {str(e)}", exc_info=True)
            st.error("An error occurred while generating the response. Please try again.")
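generate_response (imported elsewhere in this file) is consumed here as a generator yielding (steps, total_thinking_time) pairs, where each step is a (title, content, thinking_time) tuple and the time is None until the run finishes. A hypothetical stub with that shape, useful for exercising the UI without a live backend:

    # Hypothetical stub matching the contract the loop above consumes.
    def fake_generate_response(user_query, api_handler):
        steps = [("Step 1: Parse the question", "Counting letters...", 0.2)]
        yield steps, None  # intermediate update; total time not known yet
        steps.append(("Final Answer", "There are 3 'R's in 'strawberry'.", 0.1))
        yield steps, 0.3  # final update with total thinking time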