From 1cbbe4c6d413745803360485cbf08af661f628e7 Mon Sep 17 00:00:00 2001
From: tcsenpai
Date: Wed, 18 Sep 2024 12:08:55 +0200
Subject: [PATCH] Added LiteLLM initial support (ollama tested)

---
 app/api_handlers.py             |  3 ++-
 app/config_menu.py              |  2 +-
 app/handlers/__init__.py        |  3 ++-
 app/handlers/litellm_handler.py | 23 +++++++++++++++++++++++
 app/main.py                     | 25 +++++++++++++++++++------
 app/utils.py                    | 33 +++++++++++++++++++++++++++++++++
 requirements.txt                |  1 +
 7 files changed, 81 insertions(+), 9 deletions(-)
 create mode 100644 app/handlers/litellm_handler.py

diff --git a/app/api_handlers.py b/app/api_handlers.py
index 0f0cf34..9dd8fbb 100644
--- a/app/api_handlers.py
+++ b/app/api_handlers.py
@@ -39,4 +39,5 @@ class BaseHandler(ABC):
 # Import derived handlers
 from handlers.ollama_handler import OllamaHandler
 from handlers.perplexity_handler import PerplexityHandler
-from handlers.groq_handler import GroqHandler
\ No newline at end of file
+from handlers.groq_handler import GroqHandler
+from handlers.litellm_handler import LiteLLMHandler
diff --git a/app/config_menu.py b/app/config_menu.py
index 16d859b..8946755 100644
--- a/app/config_menu.py
+++ b/app/config_menu.py
@@ -45,5 +45,5 @@ def display_config(backend, config):
         st.sidebar.markdown(f"- 🤖 Ollama Model: `{config['OLLAMA_MODEL']}`")
     elif backend == "Perplexity AI":
         st.sidebar.markdown(f"- 🧠 Perplexity AI Model: `{config['PERPLEXITY_MODEL']}`")
-    else:  # Groq
+    elif backend == "Groq":
         st.sidebar.markdown(f"- ⚡ Groq Model: `{config['GROQ_MODEL']}`")
diff --git a/app/handlers/__init__.py b/app/handlers/__init__.py
index a335078..bb51492 100644
--- a/app/handlers/__init__.py
+++ b/app/handlers/__init__.py
@@ -1,5 +1,6 @@
 from .ollama_handler import OllamaHandler
 from .perplexity_handler import PerplexityHandler
 from .groq_handler import GroqHandler
+from .litellm_handler import LiteLLMHandler

-__all__ = ['OllamaHandler', 'PerplexityHandler', 'GroqHandler']
\ No newline at end of file
+__all__ = ['OllamaHandler', 'PerplexityHandler', 'GroqHandler', 'LiteLLMHandler']
\ No newline at end of file
diff --git a/app/handlers/litellm_handler.py b/app/handlers/litellm_handler.py
new file mode 100644
index 0000000..9462d45
--- /dev/null
+++ b/app/handlers/litellm_handler.py
@@ -0,0 +1,23 @@
+from api_handlers import BaseHandler
+from litellm import completion
+
+class LiteLLMHandler(BaseHandler):
+    def __init__(self, model, api_base=None, api_key=None):
+        super().__init__()
+        self.model = model
+        self.api_base = api_base
+        self.api_key = api_key
+
+    def _make_request(self, messages, max_tokens):
+        response = completion(
+            model=self.model,
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=0.2,
+            api_base=self.api_base,
+            api_key=self.api_key
+        )
+        return response.choices[0].message.content
+
+    def _process_response(self, response, is_final_answer):
+        return super()._process_response(response, is_final_answer)
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
index dd9c1f2..110f3d8 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,10 +1,11 @@
 import streamlit as st
 from dotenv import load_dotenv
 from api_handlers import OllamaHandler, PerplexityHandler, GroqHandler
-from utils import generate_response
+from utils import generate_response, litellm_config, litellm_instructions
 from config_menu import config_menu, display_config
 from logger import logger
 import os
+from handlers.litellm_handler import LiteLLMHandler

 # Load environment variables from .env file
 load_dotenv()
@@ -35,13 +36,19 @@ def setup_page():
     """, unsafe_allow_html=True)

 def get_api_handler(backend, config):
-    # Create and return the appropriate API handler based on the selected backend
     if backend == "Ollama":
         return OllamaHandler(config['OLLAMA_URL'], config['OLLAMA_MODEL'])
     elif backend == "Perplexity AI":
         return PerplexityHandler(config['PERPLEXITY_API_KEY'], config['PERPLEXITY_MODEL'])
-    else:  # Groq
+    elif backend == "Groq":
         return GroqHandler(config['GROQ_API_KEY'], config['GROQ_MODEL'])
+    else:  # LiteLLM: settings come from st.session_state, not from `config`
+        litellm_settings = st.session_state.get('litellm_config', {})  # local name must not shadow utils.litellm_config()
+        return LiteLLMHandler(
+            litellm_settings.get('model', ''),
+            litellm_settings.get('api_base') or None,  # blank input -> None, matching the handler default
+            litellm_settings.get('api_key') or None  # blank input -> None, matching the handler default
+        )

 def main():
     logger.info("Starting the application")
@@ -51,9 +58,15 @@ def main():
     st.sidebar.markdown('', unsafe_allow_html=True)
     config = config_menu()

-    # Allow user to select the AI backend
-    backend = st.sidebar.selectbox("Choose AI Backend", ["Ollama", "Perplexity AI", "Groq"])
-    display_config(backend, config)
+    # Allow user to select the AI backend
+    backend = st.sidebar.selectbox("Choose AI Backend", ["LiteLLM", "Ollama", "Perplexity AI", "Groq"])
+
+    if backend == "LiteLLM":
+        litellm_instructions()
+        litellm_config()
+    else:
+        display_config(backend, config)
+
     api_handler = get_api_handler(backend, config)
     logger.info(f"Selected backend: {backend}")

diff --git a/app/utils.py b/app/utils.py
index 22a2fd2..e00211c 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,6 +1,7 @@
 import json
 import time
 import os
+import streamlit as st


 def generate_response(prompt, api_handler):# Get the absolute path to the system_prompt.txt file
@@ -80,3 +81,35 @@ def load_env_vars():
         "PERPLEXITY_API_KEY": os.getenv("PERPLEXITY_API_KEY"),
         "PERPLEXITY_MODEL": os.getenv("PERPLEXITY_MODEL", "llama-3.1-sonar-small-128k-online"),
     }
+
+def litellm_instructions():
+    st.sidebar.markdown("""
+    ### LiteLLM Configuration Instructions:
+    1. **Model**: Enter the model name (e.g., 'gpt-3.5-turbo', 'claude-2').
+       For Ollama, use 'ollama/{model_name}'
+    2. **API Base**:
+       - For Ollama: Leave blank or use 'http://localhost:11434'
+       - For OpenAI: Leave blank or use 'https://api.openai.com/v1'
+       - For Anthropic: Use 'https://api.anthropic.com'
+       - For other providers: Enter their specific API base URL
+    3. **API Key**: Enter your API key for the chosen provider (only if required by the provider).
+
+    Note: Ensure you have the necessary permissions and credits for the selected model and provider.
+    """)
+
+def litellm_config():
+    if 'litellm_config' not in st.session_state:
+        st.session_state.litellm_config = {}
+
+    col1, col2, col3 = st.columns(3)
+
+    with col1:
+        st.session_state.litellm_config['model'] = st.text_input("Model", value=st.session_state.litellm_config.get('model', 'ollama/qwen2:1.5b'))
+
+    with col2:
+        st.session_state.litellm_config['api_base'] = st.text_input("API Base", value=st.session_state.litellm_config.get('api_base', ''))
+
+    with col3:
+        st.session_state.litellm_config['api_key'] = st.text_input("API Key", value=st.session_state.litellm_config.get('api_key', ''), type="password")
+
+    st.info("Configuration is automatically saved in the session. No need to click a save button.")
diff --git a/requirements.txt b/requirements.txt
index a90c375..c1dba97 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ groq
 python-dotenv
 requests
 blessed
+litellm
\ No newline at end of file