From 94fb15359bae2ecc2d17b7b1b1c3df6a73b49ac2 Mon Sep 17 00:00:00 2001
From: rense
Date: Tue, 6 May 2025 18:03:14 +0200
Subject: [PATCH] allow connecting to remote Ollama server

---
 sources/llm_provider.py | 94 ++++++++++++++++++++++++-------------------
 1 file changed, 52 insertions(+), 42 deletions(-)

diff --git a/sources/llm_provider.py b/sources/llm_provider.py
index cde9d6d..263ed07 100644
--- a/sources/llm_provider.py
+++ b/sources/llm_provider.py
@@ -1,27 +1,27 @@
-
 import os
-import time
-import ollama
-from ollama import chat
-import requests
-import subprocess
-import ipaddress
-import httpx
-import socket
 import platform
+import socket
+import subprocess
+import time
 from urllib.parse import urlparse
-from dotenv import load_dotenv, set_key
+
+import httpx
+import requests
+from dotenv import load_dotenv
+from ollama import Client as OllamaClient
 from openai import OpenAI
-from typing import List, Tuple, Type, Dict
-from sources.utility import pretty_print, animate_thinking
+
 from sources.logger import Logger
+from sources.utility import pretty_print, animate_thinking
+
 
 class Provider:
-    def __init__(self, provider_name, model, server_address = "127.0.0.1:5000", is_local=False):
+    def __init__(self, provider_name, model, server_address="127.0.0.1:5000", is_local=False):
         self.provider_name = provider_name.lower()
         self.model = model
         self.is_local = is_local
         self.server_ip = server_address
+        self.server_address = server_address
         self.available_providers = {
             "ollama": self.ollama_fn,
             "server": self.server_fn,
@@ -44,7 +44,7 @@ class Provider:
             self.api_key = self.get_api_key(self.provider_name)
         elif self.provider_name != "ollama":
             pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
-    
+
     def get_model_name(self) -> str:
         return self.model
 
@@ -57,7 +57,7 @@ class Provider:
             exit(1)
         return api_key
 
-    def respond(self, history, verbose = True):
+    def respond(self, history, verbose=True):
         """
         Use the choosen provider to generate text.
         """
@@ -73,7 +73,8 @@ class Provider:
         except AttributeError as e:
             raise NotImplementedError(f"{str(e)}\nIs {self.provider_name} implemented ?")
         except ModuleNotFoundError as e:
-            raise ModuleNotFoundError(f"{str(e)}\nA import related to provider {self.provider_name} was not found. Is it installed ?")
+            raise ModuleNotFoundError(
+                f"{str(e)}\nAn import related to provider {self.provider_name} was not found. Is it installed?")
         except Exception as e:
             if "try again later" in str(e).lower():
                 return f"{self.provider_name} server is overloaded. Please try again later."
@@ -106,8 +107,7 @@ class Provider:
         except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e:
             return False
 
-
-    def server_fn(self, history, verbose = False):
+    def server_fn(self, history, verbose=False):
         """
         Use a remote server with LLM to generate text.
         """
@@ -141,50 +141,59 @@ class Provider:
                 pretty_print(f"An error occurred: {str(e)}", color="failure")
                 break
             except KeyError as e:
-                raise Exception(f"{str(e)}\nError occured with server route. Are you using the correct address for the config.ini provider?") from e
+                raise Exception(
+                    f"{str(e)}\nError occurred with server route. Are you using the correct address for the config.ini provider?") from e
             except Exception as e:
                 raise e
         return thought
 
-    def ollama_fn(self, history, verbose = False):
+    def ollama_fn(self, history, verbose=False):
         """
-        Use local ollama server to generate text.
+        Use local or remote Ollama server to generate text.
         """
         thought = ""
+        host = "http://localhost:11434" if self.is_local else f"http://{self.server_address}"
+        client = OllamaClient(host=host)
+
         try:
-            stream = chat(
+            stream = client.chat(
                 model=self.model,
                 messages=history,
                 stream=True,
             )
             for chunk in stream:
-              if verbose:
-                print(chunk['message']['content'], end='', flush=True)
-              thought += chunk['message']['content']
+                if verbose:
+                    print(chunk["message"]["content"], end="", flush=True)
+                thought += chunk["message"]["content"]
         except httpx.ConnectError as e:
-            raise Exception("\nOllama connection failed. provider should not be set to ollama if server address is not localhost") from e
-        except ollama.ResponseError as e:
-            if e.status_code == 404:
+            raise Exception(
+                f"\nOllama connection failed at {host}. Check if the server is running."
+            ) from e
+        except Exception as e:
+            if hasattr(e, 'status_code') and e.status_code == 404:
                 animate_thinking(f"Downloading {self.model}...")
-                ollama.pull(self.model)
-                self.ollama_fn(history, verbose)
+                client.pull(self.model)
+                return self.ollama_fn(history, verbose)
             if "refused" in str(e).lower():
-                raise Exception("Ollama connection failed. is the server running ?") from e
+                raise Exception(
+                    f"Ollama connection refused at {host}. Is the server running?"
+                ) from e
             raise e
+
         return thought
-    
+
     def huggingface_fn(self, history, verbose=False):
         """
         Use huggingface to generate text.
         """
         from huggingface_hub import InferenceClient
         client = InferenceClient(
-        	api_key=self.get_api_key("huggingface")
+            api_key=self.get_api_key("huggingface")
         )
         completion = client.chat.completions.create(
-            model=self.model, 
-            messages=history, 
-            max_tokens=1024, 
+            model=self.model,
+            messages=history,
+            max_tokens=1024,
         )
         thought = completion.choices[0].message
         return thought.content
@@ -212,7 +221,7 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"OpenAI API error: {str(e)}") from e
-    
+
     def google_fn(self, history, verbose=False):
         """
         Use google gemini to generate text.
@@ -278,8 +287,8 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"Deepseek API error: {str(e)}") from e
-    
-    def lm_studio_fn(self, history, verbose = False):
+
+    def lm_studio_fn(self, history, verbose=False):
         """
         Use local lm-studio server to generate text.
         lm studio use endpoint /v1/chat/completions not /chat/completions like openai
@@ -304,14 +313,14 @@ class Provider:
             raise Exception(f"An error occurred: {str(e)}") from e
         return thought
 
-    def dsk_deepseek(self, history, verbose = False):
+    def dsk_deepseek(self, history, verbose=False):
         """
         Use: xtekky/deepseek4free
         For free api. Api key should be set to DSK_DEEPSEEK_API_KEY
         This is an unofficial provider, you'll have to find how to set it up yourself.
         """
         from dsk.api import (
-            DeepSeekAPI, 
+            DeepSeekAPI,
             AuthenticationError,
             RateLimitError,
             NetworkError,
@@ -340,7 +349,7 @@ class Provider:
                 raise APIError(f"API error occurred: {str(e)}") from e
         return None
 
-    def test_fn(self, history, verbose = True):
+    def test_fn(self, history, verbose=True):
         """
         This function is used to conduct tests.
         """
@@ -349,6 +358,7 @@ class Provider:
         """
         return thought
 
+
 if __name__ == "__main__":
     provider = Provider("server", "deepseek-r1:32b", " x.x.x.x:8080")
    res = provider.respond(["user", "Hello, how are you?"])