Mirror of https://github.com/tcsenpai/agenticSeek.git
Synced 2025-06-07 19:45:27 +00:00
feat: lm-studio integration
This commit is contained in:
parent bd951e19d3
commit 582462a73f
@@ -19,11 +19,12 @@ class Provider:
         self.provider_name = provider_name.lower()
         self.model = model
         self.is_local = is_local
-        self.server = self.check_address_format(server_address)
+        self.server_ip = self.check_address_format(server_address)
         self.available_providers = {
             "ollama": self.ollama_fn,
             "server": self.server_fn,
             "openai": self.openai_fn,
+            "lm-studio": self.lm_studio_fn,
             "huggingface": self.huggingface_fn,
             "deepseek-api": self.deepseek_fn
         }
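Registering "lm-studio" in available_providers is what makes the new backend selectable by name. A minimal sketch of this dict-based dispatch, assuming (as a later hunk suggests) that the selected function is eventually called as llm(history, verbose); the class and names below are illustrative stand-ins, not code from the commit:

class MiniProvider:
    """Illustrative stand-in for the dispatch pattern, not the project's actual class."""
    def __init__(self, provider_name):
        self.provider_name = provider_name.lower()
        # provider name -> backend function, mirroring available_providers above
        self.available_providers = {"lm-studio": self.lm_studio_fn}

    def lm_studio_fn(self, history, verbose=False):
        return f"would POST {len(history)} message(s) to /v1/chat/completions"

    def respond(self, history, verbose=False):
        llm = self.available_providers[self.provider_name]  # plain dict lookup, no if/elif chain
        return llm(history, verbose)

print(MiniProvider("lm-studio").respond([{"role": "user", "content": "hi"}]))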
@@ -34,11 +35,11 @@ class Provider:
         if self.provider_name in self.unsafe_providers:
             pretty_print("Warning: you are using an API provider. You data will be sent to the cloud.", color="warning")
             self.api_key = self.get_api_key(self.provider_name)
-        elif self.server != "ollama":
-            pretty_print(f"Provider: {provider_name} initialized at {self.server}", color="success")
-            self.check_address_format(self.server)
-            if not self.is_ip_online(self.server.split(':')[0]):
-                raise Exception(f"Server at {self.server} is offline.")
+        elif self.provider_name != "ollama":
+            pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
+            self.check_address_format(self.server_ip)
+            if not self.is_ip_online(self.server_ip.split(':')[0]):
+                raise Exception(f"Server at {self.server_ip} is offline.")

     def get_api_key(self, provider):
         load_dotenv()
@@ -73,7 +74,7 @@ class Provider:
         try:
             thought = llm(history, verbose)
         except ConnectionError as e:
-            raise ConnectionError(f"{str(e)}\nConnection to {self.server} failed.")
+            raise ConnectionError(f"{str(e)}\nConnection to {self.server_ip} failed.")
         except AttributeError as e:
             raise NotImplementedError(f"{str(e)}\nIs {self.provider_name} implemented ?")
         except Exception as e:
@@ -105,16 +106,16 @@ class Provider:
         Use a remote server with LLM to generate text.
         """
         thought = ""
-        route_start = f"http://{self.server}/generate"
+        route_start = f"http://{self.server_ip}/generate"

-        if not self.is_ip_online(self.server.split(":")[0]):
-            raise Exception(f"Server is offline at {self.server}")
+        if not self.is_ip_online(self.server_ip.split(":")[0]):
+            raise Exception(f"Server is offline at {self.server_ip}")

         try:
             requests.post(route_start, json={"messages": history})
             is_complete = False
             while not is_complete:
-                response = requests.get(f"http://{self.server}/get_updated_sentence")
+                response = requests.get(f"http://{self.server_ip}/get_updated_sentence")
                 thought = response.json()["sentence"]
                 is_complete = bool(response.json()["is_complete"])
                 time.sleep(2)
@@ -124,6 +125,7 @@ class Provider:
             raise e
         return thought

+
     def ollama_fn(self, history, verbose = False):
         """
         Use local ollama server to generate text.
@@ -170,8 +172,9 @@ class Provider:
         """
         Use openai to generate text.
         """
+        base_url = self.server_ip
         if self.is_local:
-            client = OpenAI(api_key=self.api_key, base_url=base_url)
+            client = OpenAI(api_key=self.api_key, base_url=f"http://{base_url}")
         else:
             client = OpenAI(api_key=self.api_key)

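For context, this is roughly what pointing the OpenAI client at a local OpenAI-compatible server looks like once a scheme is prefixed to the address. The host, port, trailing /v1, and model name below are assumptions for illustration, not values taken from the commit:

from openai import OpenAI

# Local OpenAI-compatible endpoint; LM Studio usually listens on 127.0.0.1:1234.
# Adjust to whatever server_address the Provider was constructed with.
client = OpenAI(api_key="not-needed-locally", base_url="http://127.0.0.1:1234/v1")

response = client.chat.completions.create(
    model="local-model",  # hypothetical identifier of the model loaded locally
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)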
@@ -180,6 +183,8 @@ class Provider:
             model=self.model,
             messages=history,
         )
+        if response is None:
+            raise Exception("OpenAI response is empty.")
         thought = response.choices[0].message.content
         if verbose:
             print(thought)
@@ -204,6 +209,33 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"Deepseek API error: {str(e)}") from e

+    def lm_studio_fn(self, history, verbose = False):
+        """
+        Use local lm-studio server to generate text.
+        lm studio use endpoint /v1/chat/completions not /chat/completions like openai
+        """
+        thought = ""
+        route_start = f"http://{self.server_ip}/v1/chat/completions"
+        payload = {
+            "messages": history,
+            "temperature": 0.7,
+            "max_tokens": 4096,
+            "model": self.model
+        }
+        if not self.is_ip_online(self.server_ip.split(":")[0]):
+            raise Exception(f"Server is offline at {self.server_ip}")
+        try:
+            response = requests.post(route_start, json=payload)
+            result = response.json()
+            if verbose:
+                print("Response from LM Studio:", result)
+            return result.get("choices", [{}])[0].get("message", {}).get("content", "")
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"HTTP request failed: {str(e)}") from e
+        except Exception as e:
+            raise Exception(f"An error occurred: {str(e)}") from e
+        return thought
+
     def test_fn(self, history, verbose = True):
         """
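A hedged end-to-end sketch of what lm_studio_fn sends over the wire, useful for checking a local LM Studio instance independently of the Provider class. The host/port (LM Studio's usual default is 127.0.0.1:1234) and the model name are assumptions, not part of the commit:

import requests

server_ip = "127.0.0.1:1234"  # assumed LM Studio default; match your server_address
payload = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "temperature": 0.7,
    "max_tokens": 4096,
    "model": "local-model",  # hypothetical; LM Studio serves whichever model is currently loaded
}

# Same endpoint the new lm_studio_fn targets: /v1/chat/completions, not /chat/completions.
response = requests.post(f"http://{server_ip}/v1/chat/completions", json=payload, timeout=120)
result = response.json()
print(result.get("choices", [{}])[0].get("message", {}).get("content", ""))

Since LM Studio exposes an OpenAI-compatible surface under /v1, the dedicated lm_studio_fn and an openai_fn pointed at a local base_url are two routes to the same kind of endpoint.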