From 582462a73f79dce06007c219a5c2c9436211acb8 Mon Sep 17 00:00:00 2001
From: martin legrand
Date: Thu, 27 Mar 2025 18:35:07 +0100
Subject: [PATCH] feat: lm-studio integration

---
 sources/llm_provider.py | 56 ++++++++++++++++++++++++++++++++---------
 1 file changed, 44 insertions(+), 12 deletions(-)

diff --git a/sources/llm_provider.py b/sources/llm_provider.py
index 66addbf..4eaccc4 100644
--- a/sources/llm_provider.py
+++ b/sources/llm_provider.py
@@ -19,11 +19,12 @@ class Provider:
         self.provider_name = provider_name.lower()
         self.model = model
         self.is_local = is_local
-        self.server = self.check_address_format(server_address)
+        self.server_ip = self.check_address_format(server_address)
         self.available_providers = {
             "ollama": self.ollama_fn,
             "server": self.server_fn,
             "openai": self.openai_fn,
+            "lm-studio": self.lm_studio_fn,
             "huggingface": self.huggingface_fn,
             "deepseek-api": self.deepseek_fn
         }
@@ -34,11 +35,11 @@ class Provider:
         if self.provider_name in self.unsafe_providers:
             pretty_print("Warning: you are using an API provider. You data will be sent to the cloud.", color="warning")
             self.api_key = self.get_api_key(self.provider_name)
-        elif self.server != "ollama":
-            pretty_print(f"Provider: {provider_name} initialized at {self.server}", color="success")
-            self.check_address_format(self.server)
-            if not self.is_ip_online(self.server.split(':')[0]):
-                raise Exception(f"Server at {self.server} is offline.")
+        elif self.provider_name != "ollama":
+            pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success")
+            self.check_address_format(self.server_ip)
+            if not self.is_ip_online(self.server_ip.split(':')[0]):
+                raise Exception(f"Server at {self.server_ip} is offline.")
 
     def get_api_key(self, provider):
         load_dotenv()
@@ -73,7 +74,7 @@ class Provider:
         try:
             thought = llm(history, verbose)
         except ConnectionError as e:
-            raise ConnectionError(f"{str(e)}\nConnection to {self.server} failed.")
+            raise ConnectionError(f"{str(e)}\nConnection to {self.server_ip} failed.")
         except AttributeError as e:
             raise NotImplementedError(f"{str(e)}\nIs {self.provider_name} implemented ?")
         except Exception as e:
@@ -105,16 +106,16 @@ class Provider:
         Use a remote server with LLM to generate text.
         """
         thought = ""
-        route_start = f"http://{self.server}/generate"
+        route_start = f"http://{self.server_ip}/generate"
 
-        if not self.is_ip_online(self.server.split(":")[0]):
-            raise Exception(f"Server is offline at {self.server}")
+        if not self.is_ip_online(self.server_ip.split(":")[0]):
+            raise Exception(f"Server is offline at {self.server_ip}")
 
         try:
             requests.post(route_start, json={"messages": history})
             is_complete = False
             while not is_complete:
-                response = requests.get(f"http://{self.server}/get_updated_sentence")
+                response = requests.get(f"http://{self.server_ip}/get_updated_sentence")
                 thought = response.json()["sentence"]
                 is_complete = bool(response.json()["is_complete"])
                 time.sleep(2)
@@ -124,6 +125,7 @@ class Provider:
             raise e
         return thought
 
+
     def ollama_fn(self, history, verbose = False):
         """
         Use local ollama server to generate text.
@@ -170,8 +172,9 @@ class Provider:
         """
         Use openai to generate text.
         """
+        base_url = self.server_ip
         if self.is_local:
-            client = OpenAI(api_key=self.api_key, base_url=base_url)
+            client = OpenAI(api_key=self.api_key, base_url=f"http://{base_url}")
         else:
             client = OpenAI(api_key=self.api_key)
 
@@ -180,6 +183,8 @@ class Provider:
             model=self.model,
             messages=history,
         )
+        if response is None:
+            raise Exception("OpenAI response is empty.")
         thought = response.choices[0].message.content
         if verbose:
             print(thought)
@@ -204,6 +209,33 @@ class Provider:
             return thought
         except Exception as e:
             raise Exception(f"Deepseek API error: {str(e)}") from e
+
+    def lm_studio_fn(self, history, verbose = False):
+        """
+        Use local lm-studio server to generate text.
+        LM Studio exposes an OpenAI-compatible API, but under the /v1/chat/completions route.
+        """
+        thought = ""
+        route_start = f"http://{self.server_ip}/v1/chat/completions"
+        payload = {
+            "messages": history,
+            "temperature": 0.7,
+            "max_tokens": 4096,
+            "model": self.model
+        }
+        if not self.is_ip_online(self.server_ip.split(":")[0]):
+            raise Exception(f"Server is offline at {self.server_ip}")
+        try:
+            response = requests.post(route_start, json=payload)
+            result = response.json()
+            if verbose:
+                print("Response from LM Studio:", result)
+            thought = result.get("choices", [{}])[0].get("message", {}).get("content", "")
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"HTTP request failed: {str(e)}") from e
+        except Exception as e:
+            raise Exception(f"An error occurred: {str(e)}") from e
+        return thought
 
     def test_fn(self, history, verbose = True):
         """