Commit bd592e5
Use fixtures for device properties. Follow existing style in fixture order. Mock subprocess.run for new tests.
TensorTemplar committed Oct 5, 2024
1 parent 03b33db commit bd592e5
Showing 1 changed file with 99 additions and 69 deletions.
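
For context, the snippet below is a minimal sketch of the fixture-plus-mock pattern this commit standardizes on: monkeypatch fixtures for the torch.cuda device properties, and a patched subprocess.run so no GPU or vendor CLI is needed. The import path for check_nvlink_connectivity and the fixture name mock_cuda_is_available_false are illustrative assumptions, not part of this diff.

from unittest import mock

import pytest
import torch

from litgpt.utils import check_nvlink_connectivity  # assumed import path


@pytest.fixture
def mock_cuda_is_available_false(monkeypatch):
    """Pretend no CUDA/ROCm device is visible to torch."""
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)


@mock.patch("subprocess.run")
def test_reports_no_gpus_when_cuda_unavailable(mock_run, mock_cuda_is_available_false):
    # subprocess.run is patched, so the test never shells out to nvidia-smi or rocm-smi.
    with mock.patch("builtins.print") as mock_print:
        check_nvlink_connectivity()
    mock_print.assert_any_call("No GPUs available")

Monkeypatching torch.cuda inside fixtures and patching subprocess.run per test keeps these checks hermetic, so they run on CI machines that have neither GPUs nor the vendor tools installed.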
tests/test_utils.py: 168 changes (99 additions, 69 deletions)
@@ -484,6 +484,28 @@ def test_file_size_above_limit_on_gpu():
assert size == 4_600_000_000


@pytest.fixture
def mock_cuda_is_available_true(monkeypatch):
"""Fixture to mock torch.cuda.is_available() to return True."""
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)


@pytest.fixture
def mock_nvidia_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for NVIDIA GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "NVIDIA GeForce RTX 3090"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def mock_amd_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for AMD GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "amd instinct mi250x"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def all_nvlink_connected_output():
return mock.MagicMock(
@@ -497,11 +519,13 @@ def all_nvlink_connected_output():


@mock.patch("subprocess.run")
def test_all_nvlink_connected(mock_run, all_nvlink_connected_output):
def test_all_nvlink_connected(
mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = all_nvlink_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
@@ -522,14 +546,16 @@ def nvlink_partially_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output):
def test_nvlink_partially_connected_output(
mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvlink_partially_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


@pytest.fixture
@@ -555,14 +581,16 @@ def nvlink_not_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output):
def test_nvlink_not_connected_output(
mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvlink_not_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
mock_print.assert_any_call(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


@pytest.fixture
@@ -616,6 +644,19 @@ def nvlink_all_gpu_connected_but_other_connected_output():
)


@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(
mock_run,
nvlink_all_gpu_connected_but_other_connected_output,
mock_cuda_is_available_true,
mock_nvidia_device_properties,
):
mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def nvidia_smi_nvlink_output_dual_gpu_no_numa():
return """
@@ -625,6 +666,22 @@ def nvidia_smi_nvlink_output_dual_gpu_no_numa():
"""


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
monkeypatch, nvidia_smi_nvlink_output_dual_gpu_no_numa
):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "NVIDIA GeForce RTX 4090"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

mock_run = mock.MagicMock(return_value=mock.Mock(stdout=nvidia_smi_nvlink_output_dual_gpu_no_numa, returncode=0))
with mock.patch("subprocess.run", mock_run):
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def rocm_smi_xgmi_output_multi_gpu():
"""
@@ -646,14 +703,39 @@ def rocm_smi_xgmi_output_multi_gpu():
"""


@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(
mock_run, nvlink_all_gpu_connected_but_other_connected_output
def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi_8_gpus(
monkeypatch, rocm_smi_xgmi_output_multi_gpu
):
mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "amd instinct mi250x" # ROCM 6.0.3
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

mock_run = mock.MagicMock(return_value=mock.Mock(stdout=rocm_smi_xgmi_output_multi_gpu, returncode=0))
with mock.patch("subprocess.run", mock_run):
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via XGMI.")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity_returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")
mock_print.assert_any_call("No GPUs available")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity_returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "Unknown GPU Vendor"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("Unrecognized GPU vendor: Unknown GPU Vendor")


def test_fix_and_load_json():
@@ -708,55 +790,3 @@ def test_fix_and_load_json():

result_missing_commas = fix_and_load_json(invalid_json_missing_commas)
assert result_missing_commas == expected_output_missing_commas


def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
monkeypatch, nvidia_smi_nvlink_output_dual_gpu_no_numa
):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "NVIDIA GeForce RTX 3090"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

mock_run = mock.MagicMock(return_value=mock.Mock(stdout=nvidia_smi_nvlink_output_dual_gpu_no_numa, returncode=0))
with mock.patch("subprocess.run", mock_run):
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi_8_gpus(
monkeypatch, rocm_smi_xgmi_output_multi_gpu
):
# Mock the GPU device properties to simulate AMD GPUs
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "amd instinct mi250x" # ROCM 6.0.3
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

mock_run = mock.MagicMock(return_value=mock.Mock(stdout=rocm_smi_xgmi_output_multi_gpu, returncode=0))
with mock.patch("subprocess.run", mock_run):
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via XGMI.")


def test_check_nvlink_connectivity_returns_no_gpus_when_no_gpus(monkeypatch):
# Mock torch.cuda.is_available to return False
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("No GPUs available")


def test_check_nvlink_connectivity_returns_unrecognized_vendor_when_unrecognized_vendor(monkeypatch):
# Mock the GPU device properties to simulate an unrecognized GPU vendor
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "Unknown GPU Vendor"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("Unrecognized GPU vendor: Unknown GPU Vendor")
