Merge pull request #167 from rense/feature/remote-ollama

Allow connecting to a remote Ollama server
Martin authored 2025-05-06 19:49:18 +02:00, committed by GitHub
commit 3678c091ac

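The diff routes the Ollama client to a configurable host instead of the hard-coded localhost default. As a rough sketch of the pattern being adopted (the host address and model name here are placeholders, not values from this PR), the ollama-python client accepts a remote host and streams chat chunks:

from ollama import Client

# Placeholder remote address; Ollama listens on port 11434 by default.
client = Client(host="http://192.168.1.50:11434")

# Stream the response chunk by chunk, as ollama_fn does in the diff below.
stream = client.chat(
    model="deepseek-r1:14b",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    stream=True,
)
for chunk in stream:
    print(chunk["message"]["content"], end="", flush=True)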

@@ -1,27 +1,27 @@
 import os
-import time
-import ollama
-from ollama import chat
-import requests
-import subprocess
-import ipaddress
-import httpx
-import socket
 import platform
+import socket
+import subprocess
+import time
 from urllib.parse import urlparse
-from dotenv import load_dotenv, set_key
+
+import httpx
+import requests
+from dotenv import load_dotenv
+from ollama import Client as OllamaClient
 from openai import OpenAI
-from typing import List, Tuple, Type, Dict
-from sources.utility import pretty_print, animate_thinking
+
 from sources.logger import Logger
+from sources.utility import pretty_print, animate_thinking
 
 
 class Provider:
-    def __init__(self, provider_name, model, server_address = "127.0.0.1:5000", is_local=False):
+    def __init__(self, provider_name, model, server_address="127.0.0.1:5000", is_local=False):
         self.provider_name = provider_name.lower()
         self.model = model
         self.is_local = is_local
         self.server_ip = server_address
+        self.server_address = server_address
         self.available_providers = {
             "ollama": self.ollama_fn,
             "server": self.server_fn,
@@ -44,7 +44,7 @@ class Provider:
             self.api_key = self.get_api_key(self.provider_name)
         elif self.provider_name != "ollama":
             pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
-
+
     def get_model_name(self) -> str:
         return self.model
 
@@ -57,7 +57,7 @@ class Provider:
             exit(1)
         return api_key
 
-    def respond(self, history, verbose = True):
+    def respond(self, history, verbose=True):
         """
         Use the choosen provider to generate text.
         """
@@ -73,7 +73,8 @@ class Provider:
         except AttributeError as e:
             raise NotImplementedError(f"{str(e)}\nIs {self.provider_name} implemented ?")
         except ModuleNotFoundError as e:
-            raise ModuleNotFoundError(f"{str(e)}\nA import related to provider {self.provider_name} was not found. Is it installed ?")
+            raise ModuleNotFoundError(
+                f"{str(e)}\nA import related to provider {self.provider_name} was not found. Is it installed ?")
         except Exception as e:
             if "try again later" in str(e).lower():
                 return f"{self.provider_name} server is overloaded. Please try again later."
@@ -106,8 +107,7 @@ class Provider:
         except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e:
             return False
 
-
-    def server_fn(self, history, verbose = False):
+    def server_fn(self, history, verbose=False):
         """
         Use a remote server with LLM to generate text.
         """
@@ -141,50 +141,59 @@ class Provider:
                 pretty_print(f"An error occurred: {str(e)}", color="failure")
                 break
             except KeyError as e:
-                raise Exception(f"{str(e)}\nError occured with server route. Are you using the correct address for the config.ini provider?") from e
+                raise Exception(
+                    f"{str(e)}\nError occured with server route. Are you using the correct address for the config.ini provider?") from e
             except Exception as e:
                 raise e
         return thought
 
-    def ollama_fn(self, history, verbose = False):
+    def ollama_fn(self, history, verbose=False):
         """
-        Use local ollama server to generate text.
+        Use local or remote Ollama server to generate text.
         """
         thought = ""
+        host = "http://localhost:11434" if self.is_local else f"http://{self.server_address}"
+        client = OllamaClient(host=host)
         try:
-            stream = chat(
+            stream = client.chat(
                 model=self.model,
                 messages=history,
                 stream=True,
             )
             for chunk in stream:
                 if verbose:
-                    print(chunk['message']['content'], end='', flush=True)
-                thought += chunk['message']['content']
+                    print(chunk["message"]["content"], end="", flush=True)
+                thought += chunk["message"]["content"]
         except httpx.ConnectError as e:
-            raise Exception("\nOllama connection failed. provider should not be set to ollama if server address is not localhost") from e
-        except ollama.ResponseError as e:
-            if e.status_code == 404:
+            raise Exception(
+                f"\nOllama connection failed at {host}. Check if the server is running."
+            ) from e
+        except Exception as e:
+            if hasattr(e, 'status_code') and e.status_code == 404:
                 animate_thinking(f"Downloading {self.model}...")
-                ollama.pull(self.model)
+                client.pull(self.model)
                 self.ollama_fn(history, verbose)
             if "refused" in str(e).lower():
-                raise Exception("Ollama connection failed. is the server running ?") from e
+                raise Exception(
+                    f"Ollama connection refused at {host}. Is the server running?"
+                ) from e
             raise e
         return thought
 
     def huggingface_fn(self, history, verbose=False):
         """
         Use huggingface to generate text.
         """
         from huggingface_hub import InferenceClient
         client = InferenceClient(
             api_key=self.get_api_key("huggingface")
         )
         completion = client.chat.completions.create(
             model=self.model,
             messages=history,
             max_tokens=1024,
         )
         thought = completion.choices[0].message
         return thought.content
@@ -212,7 +221,7 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"OpenAI API error: {str(e)}") from e
-
+
     def google_fn(self, history, verbose=False):
         """
         Use google gemini to generate text.
@@ -278,8 +287,8 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"Deepseek API error: {str(e)}") from e
 
-    def lm_studio_fn(self, history, verbose = False):
+    def lm_studio_fn(self, history, verbose=False):
         """
         Use local lm-studio server to generate text.
         lm studio use endpoint /v1/chat/completions not /chat/completions like openai
@@ -304,14 +313,14 @@ class Provider:
             raise Exception(f"An error occurred: {str(e)}") from e
         return thought
 
-    def dsk_deepseek(self, history, verbose = False):
+    def dsk_deepseek(self, history, verbose=False):
         """
         Use: xtekky/deepseek4free
         For free api. Api key should be set to DSK_DEEPSEEK_API_KEY
         This is an unofficial provider, you'll have to find how to set it up yourself.
         """
         from dsk.api import (
             DeepSeekAPI,
             AuthenticationError,
             RateLimitError,
             NetworkError,
@@ -340,7 +349,7 @@ class Provider:
             raise APIError(f"API error occurred: {str(e)}") from e
         return None
 
-    def test_fn(self, history, verbose = True):
+    def test_fn(self, history, verbose=True):
         """
         This function is used to conduct tests.
         """
@@ -349,6 +358,7 @@ class Provider:
         """
         return thought
 
+
 if __name__ == "__main__":
     provider = Provider("server", "deepseek-r1:32b", " x.x.x.x:8080")
     res = provider.respond(["user", "Hello, how are you?"])
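With this change, pointing the provider at a remote Ollama instance is a matter of configuration. A minimal usage sketch based on the constructor and ollama_fn in this diff (the address below is a placeholder, not a value from the PR):

# Local Ollama: ollama_fn resolves the host to http://localhost:11434
local = Provider("ollama", "deepseek-r1:14b", is_local=True)

# Remote Ollama: ollama_fn resolves the host to http://<server_address>
remote = Provider("ollama", "deepseek-r1:14b", server_address="192.168.1.50:11434", is_local=False)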