mirror of
https://github.com/tcsenpai/agenticSeek.git
synced 2025-06-05 02:25:27 +00:00
Feat : add options to use api based providers
This commit is contained in:
parent
8b619fe71f
commit
dcb5724e28
@ -13,6 +13,7 @@ flask==3.1.0
|
||||
soundfile==0.13.1
|
||||
protobuf==3.20.3
|
||||
termcolor==2.3.0
|
||||
openai==1.13.3
|
||||
# if use chinese
|
||||
ordered_set
|
||||
pypinyin
|
||||
|
@ -6,6 +6,9 @@ import requests
|
||||
import subprocess
|
||||
import ipaddress
|
||||
import platform
|
||||
from dotenv import load_dotenv, set_key
|
||||
from openai import OpenAI
|
||||
import os
|
||||
|
||||
class Provider:
|
||||
def __init__(self, provider_name, model, server_address = "127.0.0.1:5000"):
|
||||
@ -15,12 +18,27 @@ class Provider:
|
||||
self.available_providers = {
|
||||
"ollama": self.ollama_fn,
|
||||
"server": self.server_fn,
|
||||
"test": self.test_fn,
|
||||
"openai": self.openai_fn
|
||||
}
|
||||
if self.server != "":
|
||||
self.api_key = None
|
||||
self.unsafe_providers = ["openai"]
|
||||
if self.provider_name not in self.available_providers:
|
||||
raise ValueError(f"Unknown provider: {provider_name}")
|
||||
if self.provider_name in self.unsafe_providers:
|
||||
print("Warning: you are using an API provider. You data will be sent to the cloud.")
|
||||
self.get_api_key(self.provider_name)
|
||||
elif self.server != "":
|
||||
print("Provider initialized at ", self.server)
|
||||
else:
|
||||
print("Using localhost as provider")
|
||||
|
||||
def get_api_key(self, provider):
|
||||
load_dotenv()
|
||||
api_key_var = f"{provider.upper()}_API_KEY"
|
||||
api_key = os.getenv(api_key_var)
|
||||
if not api_key:
|
||||
api_key = input(f"Please enter your {provider} API key: ")
|
||||
set_key(".env", api_key_var, api_key)
|
||||
load_dotenv()
|
||||
return api_key
|
||||
|
||||
def check_address_format(self, address):
|
||||
"""
|
||||
@ -61,7 +79,7 @@ class Provider:
|
||||
print(f"An error occurred: {e}")
|
||||
return False
|
||||
|
||||
def server_fn(self, history, verbose = True):
|
||||
def server_fn(self, history, verbose = False):
|
||||
"""
|
||||
Use a remote server wit LLM to generate text.
|
||||
"""
|
||||
@ -76,12 +94,11 @@ class Provider:
|
||||
while not is_complete:
|
||||
response = requests.get(f"http://{self.server}/get_updated_sentence")
|
||||
thought = response.json()["sentence"]
|
||||
# TODO add real time streaming to stdout
|
||||
is_complete = bool(response.json()["is_complete"])
|
||||
time.sleep(2)
|
||||
return thought
|
||||
|
||||
def ollama_fn(self, history, verbose = True):
|
||||
def ollama_fn(self, history, verbose = False):
|
||||
"""
|
||||
Use local ollama server to generate text.
|
||||
"""
|
||||
@ -104,9 +121,27 @@ class Provider:
|
||||
raise e
|
||||
return thought
|
||||
|
||||
def openai_fn(self, history, verbose=False):
|
||||
"""
|
||||
Use openai to generate text.
|
||||
"""
|
||||
api_key = self.get_api_key("openai")
|
||||
client = OpenAI(api_key=api_key)
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=history
|
||||
)
|
||||
thought = response.choices[0].message.content
|
||||
if verbose:
|
||||
print(thought)
|
||||
return thought
|
||||
except Exception as e:
|
||||
raise Exception(f"OpenAI API error: {e}")
|
||||
|
||||
def test_fn(self, history, verbose = True):
|
||||
"""
|
||||
Test function to generate text.
|
||||
This function is used to conduct tests.
|
||||
"""
|
||||
thought = """
|
||||
This is a test response from the test provider.
|
||||
@ -121,3 +156,7 @@ class Provider:
|
||||
```
|
||||
"""
|
||||
return thought
|
||||
|
||||
if __name__ == "__main__":
|
||||
provider = Provider("openai", "gpt-4o-mini")
|
||||
print(provider.respond(["user", "Hello, how are you?"]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user