From dcb5724e28db4866920471c66ae959ff4b4f3bef Mon Sep 17 00:00:00 2001 From: martin legrand Date: Sun, 2 Mar 2025 11:54:45 +0100 Subject: [PATCH] Feat : add options to use api based providers --- requirements.txt | 1 + sources/llm_provider.py | 55 +++++++++++++++++++++++++++++++++++------ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3a4967e..ddbc9f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ flask==3.1.0 soundfile==0.13.1 protobuf==3.20.3 termcolor==2.3.0 +openai==1.13.3 # if use chinese ordered_set pypinyin diff --git a/sources/llm_provider.py b/sources/llm_provider.py index f41d182..46f4c76 100644 --- a/sources/llm_provider.py +++ b/sources/llm_provider.py @@ -6,6 +6,9 @@ import requests import subprocess import ipaddress import platform +from dotenv import load_dotenv, set_key +from openai import OpenAI +import os class Provider: def __init__(self, provider_name, model, server_address = "127.0.0.1:5000"): @@ -15,12 +18,27 @@ class Provider: self.available_providers = { "ollama": self.ollama_fn, "server": self.server_fn, - "test": self.test_fn, + "openai": self.openai_fn } - if self.server != "": + self.api_key = None + self.unsafe_providers = ["openai"] + if self.provider_name not in self.available_providers: + raise ValueError(f"Unknown provider: {provider_name}") + if self.provider_name in self.unsafe_providers: + print("Warning: you are using an API provider. 
Your data will be sent to the cloud.") + self.get_api_key(self.provider_name) + elif self.server != "": print("Provider initialized at ", self.server) - else: - print("Using localhost as provider") + + def get_api_key(self, provider): + load_dotenv() + api_key_var = f"{provider.upper()}_API_KEY" + api_key = os.getenv(api_key_var) + if not api_key: + api_key = input(f"Please enter your {provider} API key: ") + set_key(".env", api_key_var, api_key) + load_dotenv() + return api_key def check_address_format(self, address): """ @@ -61,7 +79,7 @@ class Provider: print(f"An error occurred: {e}") return False - def server_fn(self, history, verbose = True): + def server_fn(self, history, verbose = False): """ Use a remote server wit LLM to generate text. """ @@ -76,12 +94,11 @@ class Provider: while not is_complete: response = requests.get(f"http://{self.server}/get_updated_sentence") thought = response.json()["sentence"] - # TODO add real time streaming to stdout is_complete = bool(response.json()["is_complete"]) time.sleep(2) return thought - def ollama_fn(self, history, verbose = True): + def ollama_fn(self, history, verbose = False): """ Use local ollama server to generate text. """ @@ -104,9 +121,27 @@ class Provider: raise e return thought + def openai_fn(self, history, verbose=False): + """ + Use OpenAI to generate text. + """ + api_key = self.get_api_key("openai") + client = OpenAI(api_key=api_key) + try: + response = client.chat.completions.create( + model=self.model, + messages=history + ) + thought = response.choices[0].message.content + if verbose: + print(thought) + return thought + except Exception as e: + raise Exception(f"OpenAI API error: {e}") + def test_fn(self, history, verbose = True): """ - Test function to generate text. + This function is used to conduct tests. """ thought = """ This is a test response from the test provider. 
@@ -121,3 +156,7 @@ class Provider: ``` """ return thought + +if __name__ == "__main__": + provider = Provider("openai", "gpt-4o-mini") + print(provider.respond(["user", "Hello, how are you?"]))