diff --git a/app/api_handlers.py b/app/api_handlers.py
index 0f0cf34..9dd8fbb 100644
--- a/app/api_handlers.py
+++ b/app/api_handlers.py
@@ -39,4 +39,5 @@ class BaseHandler(ABC):
 # Import derived handlers
 from handlers.ollama_handler import OllamaHandler
 from handlers.perplexity_handler import PerplexityHandler
-from handlers.groq_handler import GroqHandler
\ No newline at end of file
+from handlers.groq_handler import GroqHandler
+from handlers.litellm_handler import LiteLLMHandler
diff --git a/app/config_menu.py b/app/config_menu.py
index 16d859b..8946755 100644
--- a/app/config_menu.py
+++ b/app/config_menu.py
@@ -45,5 +45,5 @@ def display_config(backend, config):
         st.sidebar.markdown(f"- 🤖 Ollama Model: `{config['OLLAMA_MODEL']}`")
     elif backend == "Perplexity AI":
         st.sidebar.markdown(f"- 🧠 Perplexity AI Model: `{config['PERPLEXITY_MODEL']}`")
-    else:  # Groq
+    elif backend == "Groq":
         st.sidebar.markdown(f"- ⚡ Groq Model: `{config['GROQ_MODEL']}`")
diff --git a/app/handlers/__init__.py b/app/handlers/__init__.py
index a335078..bb51492 100644
--- a/app/handlers/__init__.py
+++ b/app/handlers/__init__.py
@@ -1,5 +1,6 @@
 from .ollama_handler import OllamaHandler
 from .perplexity_handler import PerplexityHandler
 from .groq_handler import GroqHandler
+from .litellm_handler import LiteLLMHandler
 
-__all__ = ['OllamaHandler', 'PerplexityHandler', 'GroqHandler']
\ No newline at end of file
+__all__ = ['OllamaHandler', 'PerplexityHandler', 'GroqHandler', 'LiteLLMHandler']
\ No newline at end of file
diff --git a/app/handlers/litellm_handler.py b/app/handlers/litellm_handler.py
new file mode 100644
index 0000000..9462d45
--- /dev/null
+++ b/app/handlers/litellm_handler.py
@@ -0,0 +1,23 @@
+from api_handlers import BaseHandler
+from litellm import completion
+
+class LiteLLMHandler(BaseHandler):
+    def __init__(self, model, api_base=None, api_key=None):
+        super().__init__()
+        self.model = model
+        self.api_base = api_base
+        self.api_key = api_key
+
+    def _make_request(self, messages, max_tokens):
+        response = completion(
+            model=self.model,
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=0.2,
+            api_base=self.api_base,
+            api_key=self.api_key
+        )
+        return response.choices[0].message.content
+
+    def _process_response(self, response, is_final_answer):
+        return super()._process_response(response, is_final_answer)
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
index dd9c1f2..110f3d8 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,10 +1,11 @@
 import streamlit as st
 from dotenv import load_dotenv
 from api_handlers import OllamaHandler, PerplexityHandler, GroqHandler
-from utils import generate_response
+from utils import generate_response, litellm_config, litellm_instructions
 from config_menu import config_menu, display_config
 from logger import logger
 import os
+from handlers.litellm_handler import LiteLLMHandler
 
 # Load environment variables from .env file
 load_dotenv()
@@ -35,13 +36,19 @@ def setup_page():
     """, unsafe_allow_html=True)
 
 def get_api_handler(backend, config):
-    # Create and return the appropriate API handler based on the selected backend
     if backend == "Ollama":
         return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
     elif backend == "Perplexity AI":
         return PerplexityHandler(config['PERPLEXITY_API_KEY'], config['PERPLEXITY_MODEL'])
-    else:  # Groq
+    elif backend == "Groq":
         return GroqHandler(config['GROQ_API_KEY'], config['GROQ_MODEL'])
+    else:  # LiteLLM
+        litellm_config = st.session_state.get('litellm_config', {})
+        return LiteLLMHandler(
+            litellm_config.get('model', ''),
+            litellm_config.get('api_base', ''),
+            litellm_config.get('api_key', '')
+        )
 
 def main():
     logger.info("Starting the application")
@@ -51,9 +58,15 @@ def main():
     st.sidebar.markdown('