Mirror of https://github.com/tcsenpai/multi1.git (synced 2025-06-06 02:55:21 +00:00)

Commit 311a292c57 (parent: 4dc438f536): added comments for reference
@@ -4,16 +4,19 @@ import groq
 import time
 from abc import ABC, abstractmethod
 
+# Abstract base class for API handlers
 class BaseHandler(ABC):
     def __init__(self):
-        self.max_attempts = 3
-        self.retry_delay = 1
+        self.max_attempts = 3  # Maximum number of retry attempts
+        self.retry_delay = 1  # Delay between retry attempts in seconds
 
     @abstractmethod
     def _make_request(self, messages, max_tokens):
+        # Abstract method to be implemented by subclasses
         pass
 
     def make_api_call(self, messages, max_tokens, is_final_answer=False):
+        # Attempt to make an API call with retry logic
        for attempt in range(self.max_attempts):
            try:
                response = self._make_request(messages, max_tokens)
@@ -24,15 +27,18 @@ class BaseHandler(ABC):
         time.sleep(self.retry_delay)
 
     def _process_response(self, response, is_final_answer):
+        # Default response processing (can be overridden by subclasses)
         return json.loads(response)
 
     def _error_response(self, error_msg, is_final_answer):
+        # Generate an error response
         return {
             "title": "Error",
             "content": f"Failed to generate {'final answer' if is_final_answer else 'step'} after {self.max_attempts} attempts. Error: {error_msg}",
             "next_action": "final_answer" if is_final_answer else "continue"
         }
 
+# Handler for Ollama API
 class OllamaHandler(BaseHandler):
     def __init__(self, url, model):
         super().__init__()
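The body of make_api_call's retry loop falls between these two hunks and is not shown. A minimal self-contained sketch of how such retry logic typically completes; the except branch and the last-attempt fallback are assumptions, not the commit's exact code:

import json
import time
from abc import ABC, abstractmethod

class RetryingHandler(ABC):
    # Hypothetical stand-in for BaseHandler, for illustration only
    def __init__(self):
        self.max_attempts = 3  # Maximum number of retry attempts
        self.retry_delay = 1   # Delay between retry attempts in seconds

    @abstractmethod
    def _make_request(self, messages, max_tokens):
        pass

    def make_api_call(self, messages, max_tokens, is_final_answer=False):
        # Retry up to max_attempts times; sleep between failures and
        # return a structured error response after the last failure
        for attempt in range(self.max_attempts):
            try:
                response = self._make_request(messages, max_tokens)
                return self._process_response(response, is_final_answer)
            except Exception as exc:
                if attempt == self.max_attempts - 1:
                    return self._error_response(str(exc), is_final_answer)
                time.sleep(self.retry_delay)

    def _process_response(self, response, is_final_answer):
        return json.loads(response)

    def _error_response(self, error_msg, is_final_answer):
        return {
            "title": "Error",
            "content": f"Failed after {self.max_attempts} attempts. Error: {error_msg}",
            "next_action": "final_answer" if is_final_answer else "continue",
        }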
@@ -40,6 +46,7 @@ class OllamaHandler(BaseHandler):
         self.model = model
 
     def _make_request(self, messages, max_tokens):
+        # Make a request to the Ollama API
         response = requests.post(
             f"{self.url}/api/chat",
             json={
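The json={ payload is truncated by this hunk. For orientation, a standalone request against Ollama's /api/chat endpoint typically looks like the sketch below; wiring max_tokens to options["num_predict"] is an assumption about the elided lines:

import requests

def ollama_chat(url, model, messages, max_tokens):
    # Non-streaming chat request to Ollama's /api/chat endpoint
    response = requests.post(
        f"{url}/api/chat",
        json={
            "model": model,
            "messages": messages,
            "stream": False,
            # num_predict caps the number of generated tokens (assumed mapping)
            "options": {"num_predict": max_tokens},
        },
    )
    response.raise_for_status()
    return response.json()["message"]["content"]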
@@ -57,6 +64,30 @@ class OllamaHandler(BaseHandler):
         print(response.json())
         return response.json()["message"]["content"]
 
+    def _process_response(self, response, is_final_answer):
+        # Process the Ollama API response
+        if isinstance(response, dict) and 'message' in response:
+            content = response['message']['content']
+        else:
+            content = response
+
+        try:
+            parsed_content = json.loads(content)
+            if 'final_answer' in parsed_content:
+                return {
+                    "title": "Final Answer",
+                    "content": parsed_content['final_answer'],
+                    "next_action": "final_answer"
+                }
+            return parsed_content
+        except json.JSONDecodeError:
+            return {
+                "title": "Raw Response",
+                "content": content,
+                "next_action": "final_answer" if is_final_answer else "continue"
+            }
+
+# Handler for Perplexity API
 class PerplexityHandler(BaseHandler):
     def __init__(self, api_key, model):
         super().__init__()
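The new _process_response handles three input shapes: a dict carrying a message, a JSON string (optionally with a final_answer key), and plain text. A standalone illustration using the same logic as the added method:

import json

def process_response(response, is_final_answer):
    # Mirrors OllamaHandler._process_response from the hunk above
    if isinstance(response, dict) and 'message' in response:
        content = response['message']['content']
    else:
        content = response
    try:
        parsed_content = json.loads(content)
        if 'final_answer' in parsed_content:
            return {"title": "Final Answer",
                    "content": parsed_content['final_answer'],
                    "next_action": "final_answer"}
        return parsed_content
    except json.JSONDecodeError:
        return {"title": "Raw Response",
                "content": content,
                "next_action": "final_answer" if is_final_answer else "continue"}

# JSON with a final_answer key -> wrapped as a Final Answer step
print(process_response('{"final_answer": "3"}', False))
# Plain text that is not JSON -> falls back to a Raw Response step
print(process_response("There are three R's in strawberry.", False))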
@@ -64,6 +95,7 @@ class PerplexityHandler(BaseHandler):
         self.model = model
 
     def _clean_messages(self, messages):
+        # Clean and consolidate messages for the Perplexity API
         cleaned_messages = []
         last_role = None
         for message in messages:
@@ -74,12 +106,13 @@ class PerplexityHandler(BaseHandler):
                 last_role = message["role"]
             elif message["role"] == "user":
                 cleaned_messages[-1]["content"] += "\n" + message["content"]
-        # If the last message is an assistant message, delete it
+        # Remove the last assistant message if present
         if cleaned_messages and cleaned_messages[-1]["role"] == "assistant":
             cleaned_messages.pop()
         return cleaned_messages
 
     def _make_request(self, messages, max_tokens):
+        # Make a request to the Perplexity API
         cleaned_messages = self._clean_messages(messages)
 
         url = "https://api.perplexity.ai/chat/completions"
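Based on the visible loop, the cleaning pass appends a message whenever its role changes, merges consecutive user messages into one, and finally drops a trailing assistant message, presumably because the Perplexity chat API rejects non-alternating user/assistant turns. A best-effort standalone reconstruction (the first branch of the loop is partly elided in the diff):

def clean_messages(messages):
    # Merge consecutive user messages; drop a trailing assistant message
    cleaned_messages, last_role = [], None
    for message in messages:
        if message["role"] != last_role:
            cleaned_messages.append(dict(message))
            last_role = message["role"]
        elif message["role"] == "user":
            cleaned_messages[-1]["content"] += "\n" + message["content"]
    if cleaned_messages and cleaned_messages[-1]["role"] == "assistant":
        cleaned_messages.pop()
    return cleaned_messages

print(clean_messages([
    {"role": "system", "content": "Think step by step."},
    {"role": "user", "content": "First question part."},
    {"role": "user", "content": "Second question part."},
    {"role": "assistant", "content": "This trailing reply is dropped."},
]))
# -> the system message, then a single merged user message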
@@ -99,6 +132,7 @@ class PerplexityHandler(BaseHandler):
             raise  # Re-raise the exception if it's not a 400 error
 
     def _process_response(self, response, is_final_answer):
+        # Process the Perplexity API response
         try:
             return super()._process_response(response, is_final_answer)
         except json.JSONDecodeError:
@@ -110,12 +144,14 @@ class PerplexityHandler(BaseHandler):
                 "next_action": "final_answer" if (is_final_answer or forced_final_answer) else "continue"
             }
 
+# Handler for Groq API
 class GroqHandler(BaseHandler):
     def __init__(self):
         super().__init__()
         self.client = groq.Groq()
 
     def _make_request(self, messages, max_tokens):
+        # Make a request to the Groq API
         response = self.client.chat.completions.create(
             model="llama-3.1-70b-versatile",
             messages=messages,
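The Groq request is cut off after messages=messages. The groq SDK follows the OpenAI-style chat-completions shape, so a minimal standalone call (any parameters beyond model and messages are assumptions about the elided lines) looks like:

import groq

client = groq.Groq()  # reads the GROQ_API_KEY environment variable

response = client.chat.completions.create(
    model="llama-3.1-70b-versatile",
    messages=[{"role": "user", "content": "Reply with one word: ping"}],
    max_tokens=16,  # presumably wired to the max_tokens argument
)
# The generated text lives on the first choice's message
print(response.choices[0].message.content)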
app/main.py (17 lines changed)

@@ -6,21 +6,27 @@ from config_menu import config_menu, display_config
 from logger import logger
 import os
 
-# Load environment variables
+# Load environment variables from .env file
 load_dotenv()
 
 def load_css():
+    # Load custom CSS styles
     with open(os.path.join(os.path.dirname(__file__), "..", "static", "styles.css")) as f:
         st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
 
 def setup_page():
+    # Configure the Streamlit page
     st.set_page_config(page_title="multi1 - Unified AI Reasoning Chains", page_icon="🧠", layout="wide")
     load_css()
 
+    # Display the main title
     st.markdown("""
     <h1 class="main-title">
         🧠 multi1 - Unified AI Reasoning Chains
     </h1>
     """, unsafe_allow_html=True)
 
+    # Display the app description
     st.markdown("""
     <p class="main-description">
     This app demonstrates AI reasoning chains using different backends: Ollama, Perplexity AI, and Groq.
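The relative path in load_css implies that static/ sits next to app/ rather than inside it, i.e. roughly this layout (other files omitted):

multi1/
├── app/
│   └── main.py
└── static/
    └── styles.css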
@@ -29,6 +35,7 @@ def setup_page():
     """, unsafe_allow_html=True)
 
 def get_api_handler(backend, config):
+    # Create and return the appropriate API handler based on the selected backend
     if backend == "Ollama":
         return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
     elif backend == "Perplexity AI":
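Only the Ollama branch is visible before the hunk cuts off. The remaining branches presumably construct the other two handlers from their config entries; in the sketch below the Perplexity config key names are illustrative guesses, not the repo's actual keys:

def get_api_handler(backend, config):
    # Dispatch on the backend chosen in the sidebar (uses the handler
    # classes from the first file in this diff)
    if backend == "Ollama":
        return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
    elif backend == "Perplexity AI":
        # Hypothetical key names, for illustration only
        return PerplexityHandler(config['PERPLEXITY_API_KEY'], config['PERPLEXITY_MODEL'])
    else:
        return GroqHandler()  # Groq reads its API key from the environment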
@@ -40,14 +47,17 @@ def main():
     logger.info("Starting the application")
     setup_page()
 
+    # Set up the sidebar for configuration
     st.sidebar.markdown('<h3 class="sidebar-title">⚙️ Settings</h3>', unsafe_allow_html=True)
     config = config_menu()
 
+    # Allow user to select the AI backend
     backend = st.sidebar.selectbox("Choose AI Backend", ["Ollama", "Perplexity AI", "Groq"])
     display_config(backend, config)
     api_handler = get_api_handler(backend, config)
     logger.info(f"Selected backend: {backend}")
 
+    # User input field
     user_query = st.text_input("💬 Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")
 
     if user_query:
@@ -57,22 +67,27 @@ def main():
         time_container = st.empty()
 
         try:
+            # Generate and display the response
             for steps, total_thinking_time in generate_response(user_query, api_handler):
                 with response_container.container():
                     for title, content, _ in steps:
                         if title.startswith("Final Answer"):
+                            # Display the final answer
                             st.markdown(f'<h3 class="expander-title">🎯 {title}</h3>', unsafe_allow_html=True)
                             st.markdown(f'<div>{content}</div>', unsafe_allow_html=True)
                             logger.info(f"Final answer generated: {content}")
                         else:
+                            # Display intermediate steps
                             with st.expander(f"📝 {title}", expanded=True):
                                 st.markdown(f'<div>{content}</div>', unsafe_allow_html=True)
                                 logger.debug(f"Step completed: {title}")
 
+            # Display total thinking time
             if total_thinking_time is not None:
                 time_container.markdown(f'<p class="thinking-time">⏱️ Total thinking time: {total_thinking_time:.2f} seconds</p>', unsafe_allow_html=True)
                 logger.info(f"Total thinking time: {total_thinking_time:.2f} seconds")
         except Exception as e:
+            # Handle and display any errors
             logger.error(f"Error generating response: {str(e)}", exc_info=True)
             st.error("An error occurred while generating the response. Please try again.")
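The consumption loop above pins down generate_response's interface even though its definition lives elsewhere in the repo: it yields (steps, total_thinking_time) pairs, where steps is a list of (title, content, thinking_time) triples and total_thinking_time stays None until the run completes. A stub that satisfies that inferred contract:

def generate_response(user_query, api_handler):
    # Inferred interface only: the real implementation is in the repo
    steps = [("Step 1: Understand the question", "…", 0.4)]
    yield steps, None  # intermediate update: no total time yet
    steps.append(("Final Answer", "…", 0.2))
    yield steps, 0.6   # final update carries the total thinking time

for steps, total_thinking_time in generate_response("demo", api_handler=None):
    for title, content, _ in steps:
        print(title)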