feat: lm-studio integration

martin legrand 2025-03-27 18:35:07 +01:00
parent bd951e19d3
commit 582462a73f


@@ -19,11 +19,12 @@ class Provider:
         self.provider_name = provider_name.lower()
         self.model = model
         self.is_local = is_local
-        self.server = self.check_address_format(server_address)
+        self.server_ip = self.check_address_format(server_address)
         self.available_providers = {
             "ollama": self.ollama_fn,
             "server": self.server_fn,
             "openai": self.openai_fn,
+            "lm-studio": self.lm_studio_fn,
             "huggingface": self.huggingface_fn,
             "deepseek-api": self.deepseek_fn
         }
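The dict above only registers the new "lm-studio" entry; a later hunk shows the chosen callable being invoked as llm(history, verbose). A minimal dispatch sketch under that assumption (the wrapper name is hypothetical, only the registry lookup and call pattern come from this diff):

    # Hypothetical caller -- only available_providers[...] and the
    # (history, verbose) call signature are taken from the diff.
    def dispatch(provider, history, verbose=False):
        llm = provider.available_providers[provider.provider_name]  # e.g. provider.lm_studio_fn
        return llm(history, verbose)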
@@ -34,11 +35,11 @@ class Provider:
         if self.provider_name in self.unsafe_providers:
             pretty_print("Warning: you are using an API provider. You data will be sent to the cloud.", color="warning")
             self.api_key = self.get_api_key(self.provider_name)
-        elif self.server != "ollama":
-            pretty_print(f"Provider: {provider_name} initialized at {self.server}", color="success")
-            self.check_address_format(self.server)
-            if not self.is_ip_online(self.server.split(':')[0]):
-                raise Exception(f"Server at {self.server} is offline.")
+        elif self.provider_name != "ollama":
+            pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
+            self.check_address_format(self.server_ip)
+            if not self.is_ip_online(self.server_ip.split(':')[0]):
+                raise Exception(f"Server at {self.server_ip} is offline.")
 
     def get_api_key(self, provider):
         load_dotenv()
@@ -73,7 +74,7 @@ class Provider:
         try:
             thought = llm(history, verbose)
         except ConnectionError as e:
-            raise ConnectionError(f"{str(e)}\nConnection to {self.server} failed.")
+            raise ConnectionError(f"{str(e)}\nConnection to {self.server_ip} failed.")
         except AttributeError as e:
             raise NotImplementedError(f"{str(e)}\nIs {self.provider_name} implemented ?")
         except Exception as e:
@@ -105,16 +106,16 @@ class Provider:
         Use a remote server with LLM to generate text.
         """
         thought = ""
-        route_start = f"http://{self.server}/generate"
-        if not self.is_ip_online(self.server.split(":")[0]):
-            raise Exception(f"Server is offline at {self.server}")
+        route_start = f"http://{self.server_ip}/generate"
+        if not self.is_ip_online(self.server_ip.split(":")[0]):
+            raise Exception(f"Server is offline at {self.server_ip}")
         try:
             requests.post(route_start, json={"messages": history})
             is_complete = False
             while not is_complete:
-                response = requests.get(f"http://{self.server}/get_updated_sentence")
+                response = requests.get(f"http://{self.server_ip}/get_updated_sentence")
                 thought = response.json()["sentence"]
                 is_complete = bool(response.json()["is_complete"])
                 time.sleep(2)
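The polling loop in server_fn implies a small HTTP contract: POST /generate with {"messages": history}, then GET /get_updated_sentence until "is_complete" is true. A rough Flask stub of that contract, purely illustrative (the project's real server is not part of this diff; route names and JSON keys come from the client code above, everything else is assumed):

    # Illustrative stub only -- not the project's server implementation.
    from flask import Flask, jsonify, request

    app = Flask(__name__)
    state = {"sentence": "", "is_complete": False}

    @app.route("/generate", methods=["POST"])
    def generate():
        history = request.json["messages"]  # same payload server_fn POSTs
        state["sentence"] = "(model output would accumulate here)"
        state["is_complete"] = True
        return jsonify({"status": "started"})

    @app.route("/get_updated_sentence", methods=["GET"])
    def get_updated_sentence():
        # server_fn polls this every 2 seconds until is_complete is true
        return jsonify(state)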
@@ -124,6 +125,7 @@ class Provider:
             raise e
         return thought
+
     def ollama_fn(self, history, verbose = False):
         """
         Use local ollama server to generate text.
@@ -170,8 +172,9 @@ class Provider:
         """
         Use openai to generate text.
         """
+        base_url = self.server_ip
         if self.is_local:
-            client = OpenAI(api_key=self.api_key, base_url=base_url)
+            client = OpenAI(api_key=self.api_key, base_url=f"http://{base_url}")
         else:
             client = OpenAI(api_key=self.api_key)
@@ -180,6 +183,8 @@ class Provider:
             model=self.model,
             messages=history,
         )
+        if response is None:
+            raise Exception("OpenAI response is empty.")
         thought = response.choices[0].message.content
         if verbose:
             print(thought)
@@ -204,6 +209,33 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"Deepseek API error: {str(e)}") from e
+
+    def lm_studio_fn(self, history, verbose = False):
+        """
+        Use local lm-studio server to generate text.
+        lm studio use endpoint /v1/chat/completions not /chat/completions like openai
+        """
+        thought = ""
+        route_start = f"http://{self.server_ip}/v1/chat/completions"
+        payload = {
+            "messages": history,
+            "temperature": 0.7,
+            "max_tokens": 4096,
+            "model": self.model
+        }
+        if not self.is_ip_online(self.server_ip.split(":")[0]):
+            raise Exception(f"Server is offline at {self.server_ip}")
+        try:
+            response = requests.post(route_start, json=payload)
+            result = response.json()
+            if verbose:
+                print("Response from LM Studio:", result)
+            return result.get("choices", [{}])[0].get("message", {}).get("content", "")
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"HTTP request failed: {str(e)}") from e
+        except Exception as e:
+            raise Exception(f"An error occurred: {str(e)}") from e
+        return thought
 
     def test_fn(self, history, verbose = True):
         """