Feat : add options to use api based providers

This commit is contained in:
martin legrand 2025-03-02 11:54:45 +01:00
parent 8b619fe71f
commit dcb5724e28
2 changed files with 48 additions and 8 deletions

View File

@ -13,6 +13,7 @@ flask==3.1.0
soundfile==0.13.1
protobuf==3.20.3
termcolor==2.3.0
openai==1.13.3
# if use chinese
ordered_set
pypinyin

View File

@ -6,6 +6,9 @@ import requests
import subprocess
import ipaddress
import platform
from dotenv import load_dotenv, set_key
from openai import OpenAI
import os
class Provider:
def __init__(self, provider_name, model, server_address = "127.0.0.1:5000"):
@ -15,12 +18,27 @@ class Provider:
self.available_providers = {
"ollama": self.ollama_fn,
"server": self.server_fn,
"test": self.test_fn,
"openai": self.openai_fn
}
if self.server != "":
self.api_key = None
self.unsafe_providers = ["openai"]
if self.provider_name not in self.available_providers:
raise ValueError(f"Unknown provider: {provider_name}")
if self.provider_name in self.unsafe_providers:
print("Warning: you are using an API provider. You data will be sent to the cloud.")
self.get_api_key(self.provider_name)
elif self.server != "":
print("Provider initialized at ", self.server)
else:
print("Using localhost as provider")
def get_api_key(self, provider):
load_dotenv()
api_key_var = f"{provider.upper()}_API_KEY"
api_key = os.getenv(api_key_var)
if not api_key:
api_key = input(f"Please enter your {provider} API key: ")
set_key(".env", api_key_var, api_key)
load_dotenv()
return api_key
def check_address_format(self, address):
"""
@ -61,7 +79,7 @@ class Provider:
print(f"An error occurred: {e}")
return False
def server_fn(self, history, verbose = True):
def server_fn(self, history, verbose = False):
"""
Use a remote server wit LLM to generate text.
"""
@ -76,12 +94,11 @@ class Provider:
while not is_complete:
response = requests.get(f"http://{self.server}/get_updated_sentence")
thought = response.json()["sentence"]
# TODO add real time streaming to stdout
is_complete = bool(response.json()["is_complete"])
time.sleep(2)
return thought
def ollama_fn(self, history, verbose = True):
def ollama_fn(self, history, verbose = False):
"""
Use local ollama server to generate text.
"""
@ -104,9 +121,27 @@ class Provider:
raise e
return thought
def openai_fn(self, history, verbose=False):
"""
Use openai to generate text.
"""
api_key = self.get_api_key("openai")
client = OpenAI(api_key=api_key)
try:
response = client.chat.completions.create(
model=self.model,
messages=history
)
thought = response.choices[0].message.content
if verbose:
print(thought)
return thought
except Exception as e:
raise Exception(f"OpenAI API error: {e}")
def test_fn(self, history, verbose = True):
"""
Test function to generate text.
This function is used to conduct tests.
"""
thought = """
This is a test response from the test provider.
@ -121,3 +156,7 @@ class Provider:
```
"""
return thought
if __name__ == "__main__":
provider = Provider("openai", "gpt-4o-mini")
print(provider.respond(["user", "Hello, how are you?"]))