diff --git a/scripts/windows_install.bat b/scripts/windows_install.bat index fc7a86f..55d456d 100755 --- a/scripts/windows_install.bat +++ b/scripts/windows_install.bat @@ -2,6 +2,7 @@ echo Starting installation for Windows... REM Install Python dependencies from requirements.txt +pip install pyreadline3 pip install -r requirements.txt REM Install Selenium diff --git a/sources/llm_provider.py b/sources/llm_provider.py index a0ecb3e..a04d620 100644 --- a/sources/llm_provider.py +++ b/sources/llm_provider.py @@ -45,8 +45,6 @@ class Provider: self.api_key = self.get_api_key(self.provider_name) elif self.provider_name != "ollama": pretty_print(f"Provider: {provider_name} initialized at {self.server_ip}", color="success") - if not self.is_ip_online(self.server_ip.split(':')[0]): - raise Exception(f"Server at {self.server_ip} is offline.") def get_api_key(self, provider): load_dotenv() @@ -86,9 +84,11 @@ class Provider: """ if not address: return False - if address.lower() in ["127.0.0.1", "localhost", "0.0.0.0"]: + parsed = urlparse(address if address.startswith(('http://', 'https://')) else f'http://{address}') + + hostname = parsed.hostname or address + if "127.0.0.1" in address or "localhost" in address: return True - hostname = urlparse(f'http://{address}' if not address.startswith(('http://', 'https://')) else address).hostname or address try: ip_address = socket.gethostbyname(hostname) except socket.gaierror: @@ -102,15 +102,16 @@ class Provider: except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: return False + def server_fn(self, history, verbose = False): """ Use a remote server with LLM to generate text. """ thought = "" - route_setup = f"http://{self.server_ip}/setup" - route_gen = f"http://{self.server_ip}/generate" + route_setup = f"{self.server_ip}/setup" + route_gen = f"{self.server_ip}/generate" - if not self.is_ip_online(self.server_ip.split(":")[0]): + if not self.is_ip_online(self.server_ip): pretty_print(f"Server is offline at {self.server_ip}", color="failure") try: @@ -119,7 +120,7 @@ class Provider: is_complete = False while not is_complete: try: - response = requests.get(f"http://{self.server_ip}/get_updated_sentence") + response = requests.get(f"{self.server_ip}/get_updated_sentence") if "error" in response.json(): pretty_print(response.json()["error"], color="failure") break @@ -275,15 +276,13 @@ class Provider: lm studio use endpoint /v1/chat/completions not /chat/completions like openai """ thought = "" - route_start = f"http://{self.server_ip}/v1/chat/completions" + route_start = f"{self.server_ip}/v1/chat/completions" payload = { "messages": history, "temperature": 0.7, "max_tokens": 4096, "model": self.model } - if not self.is_ip_online(self.server_ip.split(":")[0]): - raise Exception(f"Server is offline at {self.server_ip}") try: response = requests.post(route_start, json=payload) result = response.json() diff --git a/tests/test_provider.py b/tests/test_provider.py new file mode 100644 index 0000000..b1cebb2 --- /dev/null +++ b/tests/test_provider.py @@ -0,0 +1,70 @@ +import unittest +from unittest.mock import patch, MagicMock +import os, sys +import socket +import subprocess +from urllib.parse import urlparse +import platform + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) # Add project root to Python path + +from sources.llm_provider import Provider + +class TestIsIpOnline(unittest.TestCase): + def setUp(self): + self.checker = Provider("ollama", "deepseek-r1:32b") + + def test_empty_address(self): + """Test with empty address""" + result = self.checker.is_ip_online("") + self.assertFalse(result) + + def test_localhost(self): + """Test with localhost""" + test_cases = [ + "localhost", + "http://localhost", + "https://localhost", + "127.0.0.1", + "http://127.0.0.1", + "https://127.0.0.1:8080" + ] + for address in test_cases: + with self.subTest(address=address): + result = self.checker.is_ip_online(address) + self.assertTrue(result) + + def test_google_ips(self): + """Test with known Google IPs""" + google_ips = [ + "74.125.197.100", + "74.125.197.139", + "74.125.197.101", + "74.125.197.113", + "74.125.197.102", + "74.125.197.138" + ] + for ip in google_ips: + with self.subTest(ip=ip), \ + patch('socket.gethostbyname', return_value=ip), \ + patch('subprocess.run', return_value=MagicMock(returncode=0)): + result = self.checker.is_ip_online(ip) + self.assertTrue(result) + + def test_unresolvable_hostname(self): + """Test with unresolvable hostname""" + address = "nonexistent.example.com" + with patch('socket.gethostbyname', side_effect=socket.gaierror): + result = self.checker.is_ip_online(address) + self.assertFalse(result) + + def test_valid_domain(self): + """Test with valid domain name""" + address = "google.com" + with patch('socket.gethostbyname', return_value="142.250.190.78"), \ + patch('subprocess.run', return_value=MagicMock(returncode=0)): + result = self.checker.is_ip_online(address) + self.assertTrue(result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file