
Commit

change version to tuple format
Lzy17 committed Jan 13, 2025
1 parent d673950 commit 8fa795a
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
@@ -224,7 +224,7 @@ def supports_igemmlt(device: torch.device) -> bool:
     if device == torch.device("cpu"):
         return True
     if torch.version.hip:
-        return False if get_compute_capabilities() < 601 else True
+        return False if get_compute_capabilities() < (6, 1) else True
     if get_compute_capabilities() < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
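For context on why the comparison above changed, here is a small standalone sketch (illustration only, not part of the commit) of how Python orders the new tuple-encoded versions versus the old integer encoding:

# Illustration only: Python compares tuples element by element (lexicographically),
# so version checks read naturally once everything uses the (major, minor) form.
old_rocm = 601            # old integer encoding of ROCm 6.1
new_rocm = (6, 1)         # new tuple encoding of ROCm 6.1

print(new_rocm < (7, 5))   # True: 6 < 7 decides the comparison at the first element
print(new_rocm >= (6, 1))  # True: equal tuples compare equal
# Note that the two encodings cannot be compared against each other:
# `601 < (7, 5)` raises TypeError on Python 3.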
6 changes: 3 additions & 3 deletions bitsandbytes/gpu_specs.py
@@ -10,14 +10,14 @@
 @dataclasses.dataclass(frozen=True)
 class GPUSpecs:
     gpu_backend: str
-    highest_compute_capability: Union[int, Tuple[int, int]]
+    highest_compute_capability: Tuple[int, int]
     backend_version_string: str
     backend_version_tuple: Tuple[int, int]

     @property
     def enable_blaslt(self) -> bool:
         if torch.version.hip:
-            return self.highest_compute_capability >= 601
+            return self.highest_compute_capability >= (6, 1)
         else:
             return self.highest_compute_capability >= (7, 5)

@@ -32,7 +32,7 @@ def get_gpu_backend() -> str:
 def get_compute_capabilities() -> Union[int, Tuple[int, int]]:
     if torch.version.hip:
         hip_major, hip_minor = get_backend_version_tuple()
-        return hip_major * 100 + hip_minor
+        return (hip_major, hip_minor)
     else:
         return sorted(
             torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())
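As a quick aid for reading the enable_blaslt property above, here is a hypothetical standalone helper (not from the repository) that restates the same threshold logic with the new tuple representation:

from typing import Tuple

# Hypothetical helper, for illustration only; the real logic is the
# GPUSpecs.enable_blaslt property shown in the diff above.
def blaslt_supported(capability: Tuple[int, int], is_hip: bool) -> bool:
    # ROCm backends need version >= 6.1; CUDA devices need compute capability >= 7.5.
    return capability >= (6, 1) if is_hip else capability >= (7, 5)

assert blaslt_supported((6, 2), is_hip=True)
assert not blaslt_supported((6, 0), is_hip=True)
assert blaslt_supported((8, 6), is_hip=False)
assert not blaslt_supported((7, 2), is_hip=False)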
2 changes: 1 addition & 1 deletion tests/test_autograd.py
@@ -199,7 +199,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
     assert (idx == 0).sum().item() < n * 0.02


-@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1")
+@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1")
 @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
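The skip condition above relies on Python's chained comparison; a short sketch (illustration only, covering the ROCm tuple case) of what it evaluates to:

# Illustration only: `(0, 0) < cap < (6, 1)` is shorthand for
# `(0, 0) < cap and cap < (6, 1)`, i.e. a ROCm version was detected
# but it is older than 6.1, so the test gets skipped.
cap = (6, 0)
print((0, 0) < cap < (6, 1))  # True  -> skip on ROCm 6.0
cap = (6, 1)
print((0, 0) < cap < (6, 1))  # False -> run on ROCm 6.1 and newer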
8 changes: 4 additions & 4 deletions tests/test_cuda_setup_evaluator.py
@@ -7,18 +7,18 @@
 @pytest.fixture
 def cuda120_spec() -> GPUSpecs:
     return GPUSpecs(
-        cuda_version_string="120",
+        backend_version_string="120",
         highest_compute_capability=(8, 6),
-        cuda_version_tuple=(12, 0),
+        backend_version_tuple=(12, 0),
     )


 @pytest.fixture
 def cuda111_noblas_spec() -> GPUSpecs:
     return GPUSpecs(
-        cuda_version_string="111",
+        backend_version_string="111",
         highest_compute_capability=(7, 2),
-        cuda_version_tuple=(11, 1),
+        backend_version_tuple=(11, 1),
     )


4 changes: 2 additions & 2 deletions tests/test_functional.py
@@ -513,7 +513,7 @@ def test_vector_quant(dim1, dim2, dim3):
     assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))


-@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1")
+@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1")
 @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@@ -1818,7 +1818,7 @@ def quant_zp(x):
     print(err1, err2, err3, err4, err5, err6)


-@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1")
+@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1")
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_extract_outliers(device):
     for i in range(k):
