added comments for reference
commit 311a292c57 (parent 4dc438f536)
@@ -4,16 +4,19 @@ import groq
 import time
 from abc import ABC, abstractmethod
 
+# Abstract base class for API handlers
 class BaseHandler(ABC):
     def __init__(self):
-        self.max_attempts = 3
-        self.retry_delay = 1
+        self.max_attempts = 3  # Maximum number of retry attempts
+        self.retry_delay = 1  # Delay between retry attempts in seconds
 
     @abstractmethod
     def _make_request(self, messages, max_tokens):
+        # Abstract method to be implemented by subclasses
         pass
 
     def make_api_call(self, messages, max_tokens, is_final_answer=False):
+        # Attempt to make an API call with retry logic
        for attempt in range(self.max_attempts):
            try:
                response = self._make_request(messages, max_tokens)
@@ -24,15 +27,18 @@ class BaseHandler(ABC):
                 time.sleep(self.retry_delay)
 
     def _process_response(self, response, is_final_answer):
+        # Default response processing (can be overridden by subclasses)
         return json.loads(response)
 
     def _error_response(self, error_msg, is_final_answer):
+        # Generate an error response
         return {
             "title": "Error",
             "content": f"Failed to generate {'final answer' if is_final_answer else 'step'} after {self.max_attempts} attempts. Error: {error_msg}",
             "next_action": "final_answer" if is_final_answer else "continue"
         }
 
+# Handler for Ollama API
 class OllamaHandler(BaseHandler):
     def __init__(self, url, model):
         super().__init__()
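
Note: the diff elides the middle of make_api_call between the two hunks above. Pieced together from the visible fragments (the retry loop, the time.sleep(self.retry_delay) back-off, and _error_response), the full method plausibly reads as follows; this is a sketch for orientation, not the repo's exact code:

    def make_api_call(self, messages, max_tokens, is_final_answer=False):
        # Attempt to make an API call with retry logic
        for attempt in range(self.max_attempts):
            try:
                response = self._make_request(messages, max_tokens)
                return self._process_response(response, is_final_answer)
            except Exception as e:
                if attempt == self.max_attempts - 1:
                    # Out of retries: return a structured error rather than raising
                    return self._error_response(str(e), is_final_answer)
                time.sleep(self.retry_delay)  # back off before the next attempt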
@@ -40,6 +46,7 @@ class OllamaHandler(BaseHandler):
         self.model = model
 
     def _make_request(self, messages, max_tokens):
+        # Make a request to the Ollama API
         response = requests.post(
             f"{self.url}/api/chat",
             json={
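
Note: the JSON payload is truncated in this hunk. Based on Ollama's documented /api/chat endpoint, the body plausibly carries the model name, the message history, and a token cap; the field choices below are assumptions, not the repo's confirmed code:

    response = requests.post(
        f"{self.url}/api/chat",
        json={
            "model": self.model,
            "messages": messages,
            "stream": False,  # request a single JSON reply, not a token stream
            "options": {"num_predict": max_tokens},  # assumption: max_tokens maps to num_predict
        },
    )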
@@ -57,6 +64,30 @@ class OllamaHandler(BaseHandler):
         print(response.json())
         return response.json()["message"]["content"]
 
+    def _process_response(self, response, is_final_answer):
+        # Process the Ollama API response
+        if isinstance(response, dict) and 'message' in response:
+            content = response['message']['content']
+        else:
+            content = response
+
+        try:
+            parsed_content = json.loads(content)
+            if 'final_answer' in parsed_content:
+                return {
+                    "title": "Final Answer",
+                    "content": parsed_content['final_answer'],
+                    "next_action": "final_answer"
+                }
+            return parsed_content
+        except json.JSONDecodeError:
+            return {
+                "title": "Raw Response",
+                "content": content,
+                "next_action": "final_answer" if is_final_answer else "continue"
+            }
+
+# Handler for Perplexity API
 class PerplexityHandler(BaseHandler):
     def __init__(self, api_key, model):
         super().__init__()
@@ -64,6 +95,7 @@ class PerplexityHandler(BaseHandler):
         self.model = model
 
     def _clean_messages(self, messages):
+        # Clean and consolidate messages for the Perplexity API
         cleaned_messages = []
         last_role = None
         for message in messages:
@@ -74,12 +106,13 @@ class PerplexityHandler(BaseHandler):
                 last_role = message["role"]
             elif message["role"] == "user":
                 cleaned_messages[-1]["content"] += "\n" + message["content"]
-        # If the last message is an assistant message, delete it
+        # Remove the last assistant message if present
         if cleaned_messages and cleaned_messages[-1]["role"] == "assistant":
             cleaned_messages.pop()
         return cleaned_messages
 
     def _make_request(self, messages, max_tokens):
+        # Make a request to the Perplexity API
         cleaned_messages = self._clean_messages(messages)
 
         url = "https://api.perplexity.ai/chat/completions"
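
Note: to make the cleaning logic concrete — the hunk truncates the first branch of the loop, but the visible elif implies that consecutive user messages are merged into one turn and a trailing assistant message is dropped, matching the strictly alternating roles the Perplexity API expects. A hypothetical input and result:

    messages = [
        {"role": "system", "content": "You are a step-by-step reasoner."},
        {"role": "user", "content": "First instruction."},
        {"role": "user", "content": "Second instruction."},  # merged into the turn above
        {"role": "assistant", "content": "Partial reply."},  # trailing assistant turn, removed
    ]
    # _clean_messages(messages) ->
    # [{"role": "system", "content": "You are a step-by-step reasoner."},
    #  {"role": "user", "content": "First instruction.\nSecond instruction."}]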
@@ -99,6 +132,7 @@ class PerplexityHandler(BaseHandler):
             raise  # Re-raise the exception if it's not a 400 error
 
     def _process_response(self, response, is_final_answer):
+        # Process the Perplexity API response
         try:
             return super()._process_response(response, is_final_answer)
         except json.JSONDecodeError:
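
Note: the request body itself falls between these hunks. Perplexity's chat-completions endpoint is OpenAI-compatible, so the elided code plausibly looks like the sketch below; the surrounding try/except (implied by the re-raise comment above) would special-case HTTP 400. Everything here is an assumption apart from the endpoint URL visible in the diff:

    payload = {
        "model": self.model,
        "messages": cleaned_messages,
        "max_tokens": max_tokens,
    }
    headers = {"Authorization": f"Bearer {self.api_key}"}
    response = requests.post(url, json=payload, headers=headers)
    response.raise_for_status()  # a 400 gets special handling; anything else is re-raised
    return response.json()["choices"][0]["message"]["content"]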
@@ -110,12 +144,14 @@ class PerplexityHandler(BaseHandler):
                 "next_action": "final_answer" if (is_final_answer or forced_final_answer) else "continue"
             }
 
+# Handler for Groq API
 class GroqHandler(BaseHandler):
     def __init__(self):
         super().__init__()
         self.client = groq.Groq()
 
     def _make_request(self, messages, max_tokens):
+        # Make a request to the Groq API
         response = self.client.chat.completions.create(
             model="llama-3.1-70b-versatile",
             messages=messages,
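
Note: the Groq call is cut off at the messages argument. Since the groq client mirrors the OpenAI SDK, the call plausibly finishes along these lines (the max_tokens pass-through and the return expression are assumptions):

    response = self.client.chat.completions.create(
        model="llama-3.1-70b-versatile",
        messages=messages,
        max_tokens=max_tokens,  # assumption: the cap is forwarded to the API
    )
    return response.choices[0].message.content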
app/main.py (17 changed lines)
@@ -6,21 +6,27 @@ from config_menu import config_menu, display_config
 from logger import logger
 import os
 
-# Load environment variables
+# Load environment variables from .env file
 load_dotenv()
 
 def load_css():
+    # Load custom CSS styles
     with open(os.path.join(os.path.dirname(__file__), "..", "static", "styles.css")) as f:
         st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
 
 def setup_page():
+    # Configure the Streamlit page
     st.set_page_config(page_title="multi1 - Unified AI Reasoning Chains", page_icon="🧠", layout="wide")
     load_css()
+
+    # Display the main title
     st.markdown("""
         <h1 class="main-title">
             🧠 multi1 - Unified AI Reasoning Chains
         </h1>
     """, unsafe_allow_html=True)
+
+    # Display the app description
     st.markdown("""
         <p class="main-description">
             This app demonstrates AI reasoning chains using different backends: Ollama, Perplexity AI, and Groq.
@@ -29,6 +35,7 @@ def setup_page():
     """, unsafe_allow_html=True)
 
 def get_api_handler(backend, config):
+    # Create and return the appropriate API handler based on the selected backend
     if backend == "Ollama":
         return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
     elif backend == "Perplexity AI":
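
Note: the factory is truncated after the Perplexity branch. Given the three handlers defined above, the remaining branches plausibly read as follows; the Perplexity config keys are hypothetical, named by analogy with the Ollama ones:

    def get_api_handler(backend, config):
        # Create and return the appropriate API handler based on the selected backend
        if backend == "Ollama":
            return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
        elif backend == "Perplexity AI":
            return PerplexityHandler(config['PERPLEXITY_API_KEY'], config['PERPLEXITY_MODEL'])
        else:  # "Groq": groq.Groq() reads its API key from the environment
            return GroqHandler()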
@@ -40,14 +47,17 @@ def main():
     logger.info("Starting the application")
     setup_page()
 
+    # Set up the sidebar for configuration
     st.sidebar.markdown('<h3 class="sidebar-title">⚙️ Settings</h3>', unsafe_allow_html=True)
     config = config_menu()
 
+    # Allow user to select the AI backend
     backend = st.sidebar.selectbox("Choose AI Backend", ["Ollama", "Perplexity AI", "Groq"])
     display_config(backend, config)
     api_handler = get_api_handler(backend, config)
     logger.info(f"Selected backend: {backend}")
 
+    # User input field
     user_query = st.text_input("💬 Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")
 
     if user_query:
@@ -57,22 +67,27 @@ def main():
         time_container = st.empty()
 
         try:
+            # Generate and display the response
             for steps, total_thinking_time in generate_response(user_query, api_handler):
                 with response_container.container():
                     for title, content, _ in steps:
                         if title.startswith("Final Answer"):
+                            # Display the final answer
                             st.markdown(f'<h3 class="expander-title">🎯 {title}</h3>', unsafe_allow_html=True)
                             st.markdown(f'<div>{content}</div>', unsafe_allow_html=True)
                             logger.info(f"Final answer generated: {content}")
                         else:
+                            # Display intermediate steps
                             with st.expander(f"📝 {title}", expanded=True):
                                 st.markdown(f'<div>{content}</div>', unsafe_allow_html=True)
                                 logger.debug(f"Step completed: {title}")
 
+                # Display total thinking time
                 if total_thinking_time is not None:
                     time_container.markdown(f'<p class="thinking-time">⏱️ Total thinking time: {total_thinking_time:.2f} seconds</p>', unsafe_allow_html=True)
                     logger.info(f"Total thinking time: {total_thinking_time:.2f} seconds")
         except Exception as e:
+            # Handle and display any errors
             logger.error(f"Error generating response: {str(e)}", exc_info=True)
             st.error("An error occurred while generating the response. Please try again.")
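
Note: generate_response is not part of this diff, but the loop above pins down its contract — a generator yielding (steps, total_thinking_time) tuples, where each step is a (title, content, thinking_time) triple and the total is None until the run finishes. A minimal hypothetical stub that satisfies that consumer:

    import time

    def generate_response(user_query, api_handler):
        # Hypothetical stub matching how main() consumes the generator.
        start = time.time()
        messages = [{"role": "user", "content": user_query}]
        step = api_handler.make_api_call(messages, max_tokens=512)
        steps = [(step["title"], step["content"], time.time() - start)]
        yield steps, None                  # intermediate update: no total yet
        yield steps, time.time() - start   # final update with total thinking time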