Skip to content

Commit

Permalink
implement GPUSpecs
Browse files Browse the repository at this point in the history
  • Loading branch information
Lzy17 committed Nov 4, 2024
1 parent 4aad810 commit 98be798
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 69 deletions.
6 changes: 3 additions & 3 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch

from bitsandbytes.cextension import BNB_HIP_VERSION
from bitsandbytes.gpu_specs import get_compute_capabilities
import bitsandbytes.functional as F


Expand Down Expand Up @@ -224,8 +224,8 @@ def supports_igemmlt(device: torch.device) -> bool:
if device == torch.device("cpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False if get_compute_capabilities() < 601 else True
if get_compute_capabilities() < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from .base import Backend

if lib and lib.compiled_with_cuda:
if lib and lib.compiled_with_gpu:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
"adam": (
Expand Down
51 changes: 23 additions & 28 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,28 @@
import torch

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch
from bitsandbytes.gpu_specs import GPUSpecs, get_gpu_specs, get_rocm_gpu_arch

logger = logging.getLogger(__name__)


def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path:
"""
Get the disk path to the CUDA BNB native library specified by the
given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable.
Get the disk path to the GPU BNB native library specified by the
given GPU specs, taking into account the `BNB_GPU_VERSION` override environment variable.
The library is not guaranteed to exist at the returned path.
"""
if torch.version.hip:
if BNB_HIP_VERSION < 601:
return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}"
else:
return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}{DYNAMIC_LIBRARY_SUFFIX}"
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
if not cuda_specs.has_cublaslt:
library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}"
if not gpu_specs.has_blaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
if gpu_specs.gpu_backend == "rocm":
library_name += "_nohipblaslt"
else:
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"

# Do I need to change it to BNB_GPU_VERSION here? IGNORE FOR NOW!
override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
library_name_stem, _, library_name_ext = library_name.rpartition(".")
Expand All @@ -69,7 +68,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:

class BNBNativeLibrary:
_lib: ct.CDLL
compiled_with_cuda = False
compiled_with_gpu = False

def __init__(self, lib: ct.CDLL):
self._lib = lib
Expand All @@ -78,8 +77,8 @@ def __getattr__(self, item):
return getattr(self._lib, item)


class CudaBNBNativeLibrary(BNBNativeLibrary):
compiled_with_cuda = True
class GpuBNBNativeLibrary(BNBNativeLibrary):
compiled_with_gpu = True

def __init__(self, lib: ct.CDLL):
super().__init__(lib)
Expand All @@ -93,18 +92,18 @@ def __init__(self, lib: ct.CDLL):

def get_native_library() -> BNBNativeLibrary:
binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
cuda_specs = get_cuda_specs()
if cuda_specs:
cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
if cuda_binary_path.exists():
binary_path = cuda_binary_path
gpu_specs = get_gpu_specs()
if gpu_specs:
gpu_binary_path = get_gpu_bnb_library_path(gpu_specs)
if gpu_binary_path.exists():
binary_path = gpu_binary_path
else:
logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
logger.warning("Could not find the bitsandbytes %s binary at %r", gpu_specs.gpu_backend, gpu_binary_path)
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
dll = ct.cdll.LoadLibrary(str(binary_path))

if hasattr(dll, "get_context"): # only a CUDA-built library exposes this
return CudaBNBNativeLibrary(dll)
return GpuBNBNativeLibrary(dll)

return BNBNativeLibrary(dll)

Expand All @@ -113,15 +112,11 @@ def get_native_library() -> BNBNativeLibrary:

try:
if torch.version.hip:
hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor
BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}"
BNB_BACKEND = "ROCm"
HIP_ENVIRONMENT = True
else:
HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
BNB_HIP_VERSION_SHORT = ""
BNB_BACKEND = "CUDA"

HIP_ENVIRONMENT = False
lib = get_native_library()
except Exception as e:
lib = None
Expand Down
25 changes: 1 addition & 24 deletions bitsandbytes/cuda_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,4 @@ def get_cuda_specs() -> Optional[CUDASpecs]:
highest_compute_capability=(get_compute_capabilities()[-1]),
cuda_version_string=(get_cuda_version_string()),
cuda_version_tuple=get_cuda_version_tuple(),
)


def get_rocm_gpu_arch() -> str:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
if match:
return "gfx" + match.group(1)
else:
return "unknown"
else:
return "unknown"
except Exception as e:
logger.error(f"Could not detect ROCm GPU architecture: {e}")
if torch.cuda.is_available():
logger.warning(
"""
ROCm GPU architecture detection failed despite ROCm being available.
""",
)
return "unknown"
)
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def prod(iterable):

name2qmap = {}

if lib and lib.compiled_with_cuda:
if lib and lib.compiled_with_gpu:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
"adam": (
Expand Down
86 changes: 86 additions & 0 deletions bitsandbytes/gpu_specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import dataclasses
import logging
import re
import subprocess
from typing import List, Optional, Tuple, Union

import torch


@dataclasses.dataclass(frozen=True)
class GPUSpecs:
gpu_backend: str
highest_compute_capability: Union[int, Tuple[int, int]]
backend_version_string: str
backend_version_tuple: Tuple[int, int]

@property
def has_blaslt(self) -> bool:
if torch.version.hip:
return self.highest_compute_capability >= 601
else:
return self.highest_compute_capability >= (7, 5)


def get_gpu_backend() -> str:
if torch.version.hip:
return "rocm"
else:
return "cuda"


def get_compute_capabilities() -> Union[int, Tuple[int, int]]:
if torch.version.hip:
hip_major, hip_minor = get_backend_version_tuple()
return hip_major * 100 + hip_minor
else:
return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))[-1]


def get_backend_version_tuple() -> Tuple[int, int]:
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
if torch.version.cuda:
major, minor = map(int, torch.version.cuda.split("."))
elif torch.version.hip:
major, minor = map(int, torch.version.hip.split(".")[0:2])
return major, minor


def get_backend_version_string() -> str:
major, minor = get_backend_version_tuple()
return f"{major}{minor}"


def get_gpu_specs() -> Optional[GPUSpecs]:
if not torch.cuda.is_available():
return None

return GPUSpecs(
gpu_backend=get_gpu_backend(),
highest_compute_capability=(get_compute_capabilities()),
backend_version_string=(get_backend_version_string()),
backend_version_tuple=get_backend_version_tuple(),
)


def get_rocm_gpu_arch() -> str:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
if match:
return "gfx" + match.group(1)
else:
return "unknown"
else:
return "unknown"
except Exception as e:
logger.error(f"Could not detect ROCm GPU architecture: {e}")
if torch.cuda.is_available():
logger.warning(
"""
ROCm GPU architecture detection failed despite ROCm being available.
""",
)
return "unknown"
4 changes: 2 additions & 2 deletions tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

import bitsandbytes as bnb
from bitsandbytes.cextension import BNB_HIP_VERSION
from bitsandbytes.gpu_specs import get_compute_capabilities
from tests.helpers import (
BOOLEAN_TRIPLES,
BOOLEAN_TUPLES,
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
assert (idx == 0).sum().item() < n * 0.02


@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1")
@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1")
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
Expand Down
14 changes: 7 additions & 7 deletions tests/test_cuda_setup_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path
from bitsandbytes.cuda_specs import CUDASpecs


Expand All @@ -23,19 +23,19 @@ def cuda111_noblas_spec() -> CUDASpecs:


@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
def test_get_gpu_bnb_library_path(monkeypatch, cuda120_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
assert get_gpu_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"


@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
def test_get_gpu_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
assert get_gpu_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
def test_get_gpu_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
assert get_gpu_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
7 changes: 4 additions & 3 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH
from bitsandbytes.gpu_specs import get_compute_capabilities
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
Expand Down Expand Up @@ -512,7 +513,7 @@ def test_vector_quant(dim1, dim2, dim3):
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))


@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1")
@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1")
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
Expand Down Expand Up @@ -1817,7 +1818,7 @@ def quant_zp(x):
print(err1, err2, err3, err4, err5, err6)


@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1")
@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1")
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_extract_outliers(device):
for i in range(k):
Expand Down

0 comments on commit 98be798

Please sign in to comment.