From d7d88da2d39e0dc26118bc1fd5f33fa1e5f914b8 Mon Sep 17 00:00:00 2001 From: Igor Benav Date: Thu, 31 Oct 2024 23:59:14 -0300 Subject: [PATCH] added tests for ollama manager --- tests/ollama/__init__.py | 185 +++++++++++++++ tests/ollama/test_manager.py | 420 +++++++++++++++++++++++++++++++++++ 2 files changed, 605 insertions(+) create mode 100644 tests/ollama/test_manager.py diff --git a/tests/ollama/__init__.py b/tests/ollama/__init__.py index e69de29..4f22f62 100644 --- a/tests/ollama/__init__.py +++ b/tests/ollama/__init__.py @@ -0,0 +1,185 @@ +import pytest + +from clientai.ollama import OllamaServerConfig + + +def test_default_config(): + """Test default configuration initialization.""" + config = OllamaServerConfig() + assert config.host == "127.0.0.1" + assert config.port == 11434 + assert config.timeout == 30 + assert config.check_interval == 1.0 + assert config.gpu_layers is None + assert config.compute_unit is None + assert config.cpu_threads is None + assert config.memory_limit is None + assert config.gpu_memory_fraction is None + assert config.gpu_devices is None + assert config.env_vars == {} + assert config.extra_args == [] + + +def test_host_validation(): + """Test host address validation.""" + OllamaServerConfig(host="127.0.0.1") + OllamaServerConfig(host="0.0.0.0") + OllamaServerConfig(host="localhost") + OllamaServerConfig(host="example.com") + + with pytest.raises(ValueError, match="Invalid host"): + OllamaServerConfig(host="invalid..host") + with pytest.raises(ValueError, match="Host cannot be empty"): + OllamaServerConfig(host="") + + +def test_port_validation(): + """Test port number validation.""" + OllamaServerConfig(port=1) + OllamaServerConfig(port=8080) + OllamaServerConfig(port=65535) + + with pytest.raises(ValueError, match="Port must be between 1 and 65535"): + OllamaServerConfig(port=0) + with pytest.raises(ValueError, match="Port must be between 1 and 65535"): + OllamaServerConfig(port=65536) + + +def test_timeout_and_interval_validation(): + """Test timeout and check interval validation.""" + OllamaServerConfig(timeout=30, check_interval=1.0) + OllamaServerConfig(timeout=60, check_interval=2.0) + + with pytest.raises(ValueError, match="Timeout must be positive"): + OllamaServerConfig(timeout=0) + with pytest.raises(ValueError, match="Check interval must be positive"): + OllamaServerConfig(check_interval=0) + with pytest.raises( + ValueError, match="Check interval cannot be greater than timeout" + ): + OllamaServerConfig(timeout=10, check_interval=20) + + +def test_gpu_settings_validation(): + """Test GPU-related settings validation.""" + OllamaServerConfig(gpu_layers=35) + OllamaServerConfig(gpu_memory_fraction=0.8) + OllamaServerConfig(gpu_devices=0) + OllamaServerConfig(gpu_devices=[0, 1]) + + with pytest.raises( + ValueError, match="gpu_layers must be a non-negative integer" + ): + OllamaServerConfig(gpu_layers=-1) + + with pytest.raises( + ValueError, match="gpu_memory_fraction must be between 0.0 and 1.0" + ): + OllamaServerConfig(gpu_memory_fraction=1.5) + + with pytest.raises(ValueError, match="GPU device ID must be non-negative"): + OllamaServerConfig(gpu_devices=-1) + with pytest.raises( + ValueError, match="All GPU device IDs must be non-negative integers" + ): + OllamaServerConfig(gpu_devices=[-1, 0]) + with pytest.raises( + ValueError, match="Duplicate GPU device IDs are not allowed" + ): + OllamaServerConfig(gpu_devices=[0, 0]) + with pytest.raises( + ValueError, match="gpu_devices must be an integer or list of integers" + ): + 
OllamaServerConfig(gpu_devices="0") + + +def test_compute_unit_validation(): + """Test compute unit validation.""" + OllamaServerConfig(compute_unit="cpu") + OllamaServerConfig(compute_unit="gpu") + OllamaServerConfig(compute_unit="auto") + + with pytest.raises(ValueError, match="compute_unit must be one of"): + OllamaServerConfig(compute_unit="invalid") + + +def test_cpu_threads_validation(): + """Test CPU threads validation.""" + OllamaServerConfig(cpu_threads=1) + OllamaServerConfig(cpu_threads=8) + + with pytest.raises( + ValueError, match="cpu_threads must be a positive integer" + ): + OllamaServerConfig(cpu_threads=0) + with pytest.raises( + ValueError, match="cpu_threads must be a positive integer" + ): + OllamaServerConfig(cpu_threads=-1) + + +def test_memory_limit_validation(): + """Test memory limit validation.""" + OllamaServerConfig(memory_limit="8GiB") + OllamaServerConfig(memory_limit="1024MiB") + OllamaServerConfig(memory_limit="1TiB") + OllamaServerConfig(memory_limit="0.5GiB") + + with pytest.raises(ValueError, match="memory_limit must be in format"): + OllamaServerConfig(memory_limit="8GB") + with pytest.raises(ValueError, match="memory_limit must be in format"): + OllamaServerConfig(memory_limit="8G") + with pytest.raises(ValueError, match="memory_limit must be in format"): + OllamaServerConfig(memory_limit="8 GiB") + + with pytest.raises( + ValueError, match="memory_limit in MiB must be at least 100" + ): + OllamaServerConfig(memory_limit="50MiB") + with pytest.raises( + ValueError, match="memory_limit in GiB must be at least 0.1" + ): + OllamaServerConfig(memory_limit="0.05GiB") + with pytest.raises( + ValueError, match="memory_limit in TiB must be at least 0.001" + ): + OllamaServerConfig(memory_limit="0.0005TiB") + + +def test_env_vars_validation(): + """Test environment variables validation.""" + OllamaServerConfig(env_vars={"KEY": "value"}) + OllamaServerConfig(env_vars={"MULTIPLE": "vars", "ARE": "valid"}) + + with pytest.raises( + ValueError, match="All environment variables must be strings" + ): + OllamaServerConfig(env_vars={"KEY": 123}) + with pytest.raises( + ValueError, match="All environment variables must be strings" + ): + OllamaServerConfig(env_vars={123: "value"}) + + +def test_extra_args_validation(): + """Test extra arguments validation.""" + OllamaServerConfig(extra_args=["--verbose"]) + OllamaServerConfig(extra_args=["--arg1", "--arg2"]) + + with pytest.raises( + ValueError, match="All extra arguments must be strings" + ): + OllamaServerConfig(extra_args=[123]) + with pytest.raises( + ValueError, match="All extra arguments must be strings" + ): + OllamaServerConfig(extra_args=["--valid", 123]) + + +def test_base_url_property(): + """Test base_url property.""" + config = OllamaServerConfig(host="localhost", port=8080) + assert config.base_url == "http://localhost:8080" + + config = OllamaServerConfig() + assert config.base_url == "http://127.0.0.1:11434" diff --git a/tests/ollama/test_manager.py b/tests/ollama/test_manager.py new file mode 100644 index 0000000..a43d2b8 --- /dev/null +++ b/tests/ollama/test_manager.py @@ -0,0 +1,420 @@ +import http.client +import subprocess +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pytest + +from clientai.ollama.manager import OllamaManager, OllamaServerConfig +from clientai.ollama.manager.exceptions import ( + ExecutableNotFoundError, + ServerStartupError, + ServerTimeoutError, +) +from clientai.ollama.manager.platform_info import ( + GPUVendor, + Platform, + PlatformInfo, +) 
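+
+# Test-double strategy: the fixtures below replace the manager's external
+# dependencies with mocks (subprocess.Popen for the `ollama serve` process,
+# http.client.HTTPConnection for health-check polling, and PlatformInfo for
+# platform/GPU detection), so no real Ollama binary is required to run these
+# tests.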
+ + +class MockProcess: + """Mock for subprocess.Popen that supports type hints.""" + + def __init__(self, returncode: int = 0) -> None: + self.returncode = returncode + self._poll_value = None + self.terminate = MagicMock() + self.wait = MagicMock() + self.communicate = MagicMock(return_value=("", "")) + + def poll(self) -> int | None: + """Mock poll method.""" + return self._poll_value + + +@pytest.fixture +def mock_subprocess(): + """Mock subprocess.Popen for server process management.""" + with patch( + "clientai.ollama.manager.core.subprocess", autospec=True + ) as mock: + process = MockProcess() + mock.Popen = MagicMock(return_value=process) + mock.CREATE_NO_WINDOW = 0x08000000 + mock.PIPE = subprocess.PIPE + yield mock + + +@pytest.fixture +def mock_http_client(): + """Mock http.client for server health checks.""" + with patch( + "clientai.ollama.manager.core.http.client.HTTPConnection" + ) as mock: + mock_conn = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_conn.getresponse.return_value = mock_response + mock.return_value = mock_conn + yield mock + + +@pytest.fixture +def mock_platform_info(): + """Mock PlatformInfo for system information.""" + with patch( + "clientai.ollama.manager.core.PlatformInfo", autospec=True + ) as MockPlatformInfo: + platform_info = MagicMock(spec=PlatformInfo) + platform_info.platform = Platform.LINUX + platform_info.gpu_vendor = GPUVendor.NVIDIA + platform_info.cpu_count = 8 + + def get_environment(config: Any) -> dict[str, str]: + """Dynamic environment generation based on config.""" + env = { + "PATH": "/usr/local/bin", + } + if config.gpu_layers is not None: + env["OLLAMA_GPU_LAYERS"] = str(config.gpu_layers) + + if platform_info.gpu_vendor == GPUVendor.NVIDIA: + if config.gpu_devices is not None: + devices = ( + config.gpu_devices + if isinstance(config.gpu_devices, list) + else [config.gpu_devices] + ) + env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices)) + if config.gpu_memory_fraction is not None: + env["CUDA_MEM_FRACTION"] = str(config.gpu_memory_fraction) + + elif platform_info.gpu_vendor == GPUVendor.AMD: + if config.gpu_devices is not None: + devices = ( + config.gpu_devices + if isinstance(config.gpu_devices, list) + else [config.gpu_devices] + ) + env["GPU_DEVICE_ORDINAL"] = ",".join(map(str, devices)) + if config.gpu_memory_fraction is not None: + env["GPU_MAX_HEAP_SIZE"] = ( + f"{int(config.gpu_memory_fraction * 100)}%" + ) + + return env + + platform_info.get_environment.side_effect = get_environment + + def get_server_command(config: Any) -> list[str]: + """Get platform-specific server command.""" + base_cmd = ["ollama", "serve"] + if config.host != "127.0.0.1": + base_cmd.extend(["--host", config.host]) + if config.port != 11434: + base_cmd.extend(["--port", str(config.port)]) + return base_cmd + + platform_info.get_server_command.side_effect = get_server_command + + MockPlatformInfo.return_value = platform_info + yield platform_info + + +@pytest.fixture +def manager(mock_platform_info): + """Create a manager instance with test configuration.""" + config = OllamaServerConfig( + host="127.0.0.1", + port=11434, + gpu_layers=35, + gpu_memory_fraction=0.8, + gpu_devices=0, + ) + return OllamaManager(config) + + +def test_init_default_config(): + """Test manager initialization with default config.""" + manager = OllamaManager() + assert isinstance(manager.config, OllamaServerConfig) + assert manager.config.host == "127.0.0.1" + assert manager.config.port == 11434 + + +def test_init_custom_config(): + 
"""Test manager initialization with custom config.""" + config = OllamaServerConfig(host="localhost", port=11435, gpu_layers=35) + manager = OllamaManager(config) + assert manager.config == config + + +def test_start_server_success( + manager, mock_subprocess, mock_http_client, mock_platform_info +): + """Test successful server startup.""" + manager.start() + + mock_subprocess.Popen.assert_called_once() + assert manager._process is not None + mock_platform_info.get_environment.assert_called_once_with(manager.config) + mock_platform_info.get_server_command.assert_called_once_with( + manager.config + ) + + +def test_start_server_executable_not_found(manager, mock_subprocess): + """Test error handling when Ollama executable is not found.""" + mock_subprocess.Popen.side_effect = FileNotFoundError() + + with pytest.raises(ExecutableNotFoundError) as exc_info: + manager.start() + + assert "Ollama executable not found" in str(exc_info.value) + assert manager._process is None + + +def test_start_server_already_running(manager): + """Test error when attempting to start an already running server.""" + manager._process = cast(subprocess.Popen[str], MockProcess()) + + with pytest.raises(ServerStartupError) as exc_info: + manager.start() + + assert "already running" in str(exc_info.value) + + +def test_stop_server_success(manager): + """Test successful server shutdown.""" + process = MockProcess() + manager._process = cast(subprocess.Popen[str], process) + + manager.stop() + + process.terminate.assert_called_once() + process.wait.assert_called_once() + assert manager._process is None + + +def test_stop_server_not_running(manager): + """Test stopping server when it's not running.""" + manager._process = None + manager.stop() + + +def test_context_manager(manager, mock_subprocess, mock_http_client): + """Test using manager as a context manager.""" + with manager as m: + assert m._process is not None + assert isinstance(m, OllamaManager) + assert m._process is None + + +@pytest.mark.parametrize( + "platform_type,expected_cmd", + [ + (Platform.WINDOWS, ["ollama.exe", "serve"]), + (Platform.LINUX, ["ollama", "serve"]), + (Platform.MACOS, ["ollama", "serve"]), + ], +) +def test_platform_specific_commands(platform_type, expected_cmd): + """Test platform-specific command generation.""" + with patch( + "clientai.ollama.manager.core.PlatformInfo", autospec=True + ) as MockPlatformInfo: + platform_info = MagicMock(spec=PlatformInfo) + platform_info.platform = platform_type + + def get_server_command(config): + base_cmd = ( + ["ollama.exe"] + if platform_type == Platform.WINDOWS + else ["ollama"] + ) + return base_cmd + ["serve"] + + platform_info.get_server_command.side_effect = get_server_command + MockPlatformInfo.return_value = platform_info + + manager = OllamaManager() + result = manager._platform_info.get_server_command(manager.config) + assert result == expected_cmd + + +@pytest.mark.parametrize( + "gpu_vendor,config,expected_env", + [ + ( + GPUVendor.NVIDIA, + OllamaServerConfig( + gpu_layers=35, gpu_memory_fraction=0.8, gpu_devices=0 + ), + { + "PATH": "/usr/local/bin", + "OLLAMA_GPU_LAYERS": "35", + "CUDA_VISIBLE_DEVICES": "0", + "CUDA_MEM_FRACTION": "0.8", + }, + ), + ( + GPUVendor.AMD, + OllamaServerConfig( + gpu_layers=35, gpu_memory_fraction=0.8, gpu_devices=0 + ), + { + "PATH": "/usr/local/bin", + "OLLAMA_GPU_LAYERS": "35", + "GPU_DEVICE_ORDINAL": "0", + "GPU_MAX_HEAP_SIZE": "80%", + }, + ), + ( + GPUVendor.APPLE, + OllamaServerConfig(gpu_layers=35), + {"PATH": "/usr/local/bin", 
"OLLAMA_GPU_LAYERS": "35"}, + ), + ( + GPUVendor.NONE, + OllamaServerConfig(gpu_layers=35), + {"PATH": "/usr/local/bin", "OLLAMA_GPU_LAYERS": "35"}, + ), + ], +) +def test_gpu_specific_environment(gpu_vendor, config, expected_env): + """Test GPU-specific environment variable generation.""" + with patch( + "clientai.ollama.manager.core.PlatformInfo", autospec=True + ) as MockPlatformInfo: + platform_info = MagicMock(spec=PlatformInfo) + platform_info.platform = Platform.LINUX + platform_info.gpu_vendor = gpu_vendor + + def get_environment(cfg): + env = { + "PATH": "/usr/local/bin", + } + + if cfg.gpu_layers is not None: + env["OLLAMA_GPU_LAYERS"] = str(cfg.gpu_layers) + + if gpu_vendor == GPUVendor.NVIDIA: + if cfg.gpu_devices is not None: + devices = ( + cfg.gpu_devices + if isinstance(cfg.gpu_devices, list) + else [cfg.gpu_devices] + ) + env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices)) + if cfg.gpu_memory_fraction is not None: + env["CUDA_MEM_FRACTION"] = str(cfg.gpu_memory_fraction) + + elif gpu_vendor == GPUVendor.AMD: + if cfg.gpu_devices is not None: + devices = ( + cfg.gpu_devices + if isinstance(cfg.gpu_devices, list) + else [cfg.gpu_devices] + ) + env["GPU_DEVICE_ORDINAL"] = ",".join(map(str, devices)) + if cfg.gpu_memory_fraction is not None: + env["GPU_MAX_HEAP_SIZE"] = ( + f"{int(cfg.gpu_memory_fraction * 100)}%" + ) + + return env + + platform_info.get_environment.side_effect = get_environment + MockPlatformInfo.return_value = platform_info + + manager = OllamaManager(config) + env = manager._platform_info.get_environment(config) + + for key, value in expected_env.items(): + assert ( + env[key] == value + ), f"Expected {key}={value}, got {env.get(key)}" + + for key in env: + assert ( + key in expected_env + ), f"Unexpected environment variable: {key}" + + +def test_health_check_error_handling( + manager, mock_subprocess, mock_http_client +): + """Test various HTTP health check error scenarios.""" + manager.config.timeout = 0.1 + + process = MockProcess() + process._poll_value = None + mock_subprocess.Popen.return_value = process + + mock_response = MagicMock() + mock_response.status = 500 + mock_http_client.return_value.getresponse.return_value = mock_response + + with pytest.raises(ServerTimeoutError): + manager.start() + + mock_http_client.return_value.request.side_effect = OSError() + + with pytest.raises(ServerTimeoutError): + manager.start() + + mock_http_client.return_value.request.side_effect = ( + http.client.HTTPException() + ) + + with pytest.raises(ServerTimeoutError): + manager.start() + + +def test_cleanup_on_error(manager, mock_subprocess, mock_http_client): + """Test proper cleanup when an error occurs during startup.""" + process = MockProcess() + mock_subprocess.Popen.return_value = process + + mock_http_client.return_value.request.side_effect = ( + http.client.HTTPException("Connection failed") + ) + manager.config.timeout = 0.1 + + with pytest.raises(ServerTimeoutError) as exc_info: + manager.start() + + assert manager._process is None + assert "did not start within" in str(exc_info.value) + mock_http_client.return_value.close.assert_called() + process.terminate.assert_called() + + +def test_resource_error_handling(manager, mock_subprocess): + """Test handling of resource allocation errors.""" + process = MockProcess(returncode=1) + process._poll_value = 1 + process.communicate.return_value = ("", "cannot allocate memory") + mock_subprocess.Popen.return_value = process + + with pytest.raises(ServerStartupError) as exc_info: + manager.start() + + 
assert "cannot allocate memory" in str(exc_info.value) + + +def test_custom_host_port(mock_subprocess, mock_http_client): + """Test server startup with custom host and port.""" + config = OllamaServerConfig(host="localhost", port=11435) + manager = OllamaManager(config) + + process = MockProcess() + mock_subprocess.Popen.return_value = process + + try: + manager.start() + except ServerTimeoutError: + pass + + mock_http_client.assert_called_with("localhost", 11435, timeout=5)