[Reland] ROCm CI (Infra + Skips) #1581
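Adds a ROCm nightly leg to the regression-test CI workflow (self-hosted ROCm runner, torch ROCm 6.3 nightly wheels) and skips the tests that do not yet pass on ROCm, marking them "ROCm enablement in progress".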

Open · wants to merge 19 commits into main
9 changes: 6 additions & 3 deletions .github/workflows/regression_test.yml
@@ -33,13 +33,19 @@ jobs:
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""
- name: ROCM Nightly
runs-on: linux.rocm.gpu.torchao
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
gpu-arch-type: "rocm"
gpu-arch-version: "6.3"

permissions:
id-token: write
contents: read
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 120
no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }}
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
@@ -74,7 +80,6 @@ jobs:
torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

- name: CPU 2.3
runs-on: linux.4xlarge
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
@@ -102,8 +107,6 @@ jobs:
conda create -n venv python=3.9 -y
conda activate venv
echo "::group::Install newer objcopy that supports --set-section-alignment"
yum install -y devtoolset-10-binutils
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r dev-requirements.txt
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -21,6 +21,7 @@
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
skip_if_rocm,
)


@@ -93,6 +94,7 @@ def test_tensor_core_layout_transpose(self):
@common_utils.parametrize(
"apply_quant", get_quantization_functions(True, True, "cuda", True)
)
@skip_if_rocm("ROCm enablement in progress")
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
@@ -170,6 +172,7 @@ def apply_uint6_weight_only_quant(linear):

@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
def test_print_quantized_module(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
@@ -182,6 +185,7 @@ class TestAffineQuantizedBasic(TestCase):

@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
@skip_if_rocm("ROCm enablement in progress")
def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
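Note: the skip_if_rocm helper imported from torchao.utils throughout these test diffs is not itself shown in this PR page. As a rough sketch only, assuming ROCm is detected via torch.version.hip (the names and exact behavior here are assumptions, not the PR's actual implementation), such a decorator could look like:

import functools
import unittest

import torch


def skip_if_rocm(message="ROCm enablement in progress"):
    # Hypothetical sketch: skip a test function or TestCase class on ROCm (HIP)
    # builds of PyTorch. torch.version.hip is a version string on ROCm builds
    # and None on CUDA/CPU builds.
    def decorator(obj):
        if isinstance(obj, type):
            # Class-level usage, e.g. the decorator placed above TestMarlinQQQ below.
            return unittest.skipIf(torch.version.hip is not None, message)(obj)

        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            if torch.version.hip is not None:
                # unittest.SkipTest is reported as a skip by unittest and pytest alike.
                raise unittest.SkipTest(message)
            return obj(*args, **kwargs)

        return wrapper

    return decorator

The module-level skips in this PR use the same check directly: when torch.version.hip is not None, pytest.skip(..., allow_module_level=True) skips the whole test file at collection time.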
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,5 +1,6 @@
import unittest

import pytest
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal import common_utils
@@ -26,6 +27,9 @@
except ModuleNotFoundError:
has_gemlite = False

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)


class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses"""
3 changes: 2 additions & 1 deletion test/dtypes/test_floatx.py
@@ -27,7 +27,7 @@
fpx_weight_only,
quantize_,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm

_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_Floatx_DTYPES = [(3, 2), (2, 2)]
@@ -109,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
@parametrize("bias", [False, True])
@parametrize("dtype", [torch.half, torch.bfloat16])
@unittest.skipIf(is_fbcode(), reason="broken in fbcode")
@skip_if_rocm("ROCm enablement in progress")
def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
N, OC, IC = 4, 256, 64
device = "cuda"
3 changes: 3 additions & 0 deletions test/dtypes/test_nf4.py
@@ -33,6 +33,7 @@
nf4_weight_only,
to_nf4,
)
from torchao.utils import skip_if_rocm

bnb_available = False

@@ -111,6 +112,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
@@ -133,6 +135,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_nf4_bnb_linear(self, dtype: torch.dtype):
"""
4 changes: 3 additions & 1 deletion test/dtypes/test_uint4.py
@@ -28,7 +28,7 @@
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm


def _apply_weight_only_uint4_quant(model):
@@ -92,6 +92,7 @@ def test_basic_tensor_ops(self):
# only test locally
# print("x:", x[0])

@skip_if_rocm("ROCm enablement in progress")
def test_gpu_quant(self):
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
x = torch.randn(*x_shape)
@@ -104,6 +105,7 @@ def test_gpu_quant(self):
# make sure it runs
opt(x)

@skip_if_rocm("ROCm enablement in progress")
def test_pt2e_quant(self):
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
QuantizationConfig,
2 changes: 2 additions & 0 deletions test/float8/test_base.py
@@ -18,6 +18,7 @@
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
skip_if_rocm,
)

if not TORCH_VERSION_AT_LEAST_2_5:
@@ -423,6 +424,7 @@ def test_linear_from_config_params(
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize("linear_bias", [True, False])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skip_if_rocm("ROCm enablement in progress")
def test_linear_from_recipe(
self,
recipe_name,
3 changes: 3 additions & 0 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -43,6 +43,9 @@
if not is_sm_at_least_89():
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

if torch.version.hip is not None:
pytest.skip("ROCm enablement in progress", allow_module_level=True)


class TestFloat8Common:
def broadcast_module(self, module: nn.Module) -> None:
2 changes: 2 additions & 0 deletions test/hqq/test_hqq_affine.py
@@ -10,6 +10,7 @@
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
skip_if_rocm,
)

cuda_available = torch.cuda.is_available()
@@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self):
ref_dot_product_error=0.000704,
)

@skip_if_rocm("ROCm enablement in progress")
def test_hqq_plain_4bit(self):
self._test_hqq(
dtype=torch.uint4,
7 changes: 7 additions & 0 deletions test/integration/test_integration.py
@@ -83,6 +83,7 @@
benchmark_model,
is_fbcode,
is_sm_at_least_90,
skip_if_rocm,
unwrap_tensor_subclass,
)

@@ -569,6 +570,7 @@ def test_per_token_linear_cpu(self):
self._test_per_token_linear_impl("cpu", dtype)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@skip_if_rocm("ROCm enablement in progress")
def test_per_token_linear_cuda(self):
for dtype in (torch.float32, torch.float16, torch.bfloat16):
self._test_per_token_linear_impl("cuda", dtype)
@@ -687,6 +689,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -706,6 +709,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -899,6 +903,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
@@ -918,6 +923,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -1071,6 +1077,7 @@ def test_gemlite_layout(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
@skip_if_rocm("ROCm enablement in progress")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if device == "cpu":
self.skipTest(f"Temporarily skipping for {device}")
3 changes: 3 additions & 0 deletions test/kernel/test_fused_kernels.py
@@ -11,6 +11,8 @@
import torch
from galore_test_utils import get_kernel, make_copy, make_data

from torchao.utils import skip_if_rocm

torch.manual_seed(0)
MAX_DIFF_no_tf32 = 1e-5
MAX_DIFF_tf32 = 1e-3
@@ -104,6 +106,7 @@ def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS)
@skip_if_rocm("ROCm enablement in progress")
def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32):
torch.backends.cuda.matmul.allow_tf32 = allow_tf32

2 changes: 2 additions & 0 deletions test/kernel/test_galore_downproj.py
@@ -11,6 +11,7 @@

from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
from torchao.utils import skip_if_rocm

torch.manual_seed(0)

@@ -29,6 +30,7 @@

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
@pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
@skip_if_rocm("ROCm enablement in progress")
def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
7 changes: 6 additions & 1 deletion test/prototype/test_awq.py
@@ -5,7 +5,11 @@
import torch

from torchao.quantization import quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
skip_if_rocm,
)

if TORCH_VERSION_AT_LEAST_2_3:
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
@@ -113,6 +117,7 @@ def test_awq_loading(device, qdtype):

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@skip_if_rocm("ROCm enablement in progress")
def test_save_weights_only():
dataset_size = 100
l1, l2, l3 = 512, 256, 128
4 changes: 4 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -30,6 +30,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
get_available_devices,
skip_if_rocm,
)

try:
@@ -112,6 +113,7 @@ class TestOptim(TestCase):
)
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
@skip_if_rocm("ROCm enablement in progress")
def test_optim_smoke(self, optim_name, dtype, device):
if optim_name.endswith("Fp8") and device == "cuda":
if not TORCH_VERSION_AT_LEAST_2_4:
@@ -185,6 +187,7 @@ def test_subclass_slice(self, subclass, shape, device):
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@skip_if_rocm("ROCm enablement in progress")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
@@ -384,6 +387,7 @@ def world_size(self) -> int:
not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
)
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@skip_if_rocm("ROCm enablement in progress")
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
if torch.cuda.get_device_capability() >= (8, 9):
3 changes: 3 additions & 0 deletions test/prototype/test_smoothquant.py
@@ -20,6 +20,9 @@
TORCH_VERSION_AT_LEAST_2_5,
)

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)


class ToyLinearModel(torch.nn.Module):
def __init__(self, m=512, n=256, k=128):
4 changes: 3 additions & 1 deletion test/prototype/test_splitk.py
@@ -13,13 +13,15 @@
except ImportError:
triton_available = False

from torchao.utils import skip_if_compute_capability_less_than

from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm


@unittest.skipIf(not triton_available, "Triton is required but not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestFP8Gemm(TestCase):
@skip_if_compute_capability_less_than(9.0)
@skip_if_rocm("ROCm enablement in progress")
def test_gemm_split_k(self):
dtype = torch.float16
qdtype = torch.float8_e4m3fn
2 changes: 2 additions & 0 deletions test/quantization/test_galore_quant.py
@@ -18,6 +18,7 @@
triton_dequant_blockwise,
triton_quantize_blockwise,
)
from torchao.utils import skip_if_rocm

SEED = 0
torch.manual_seed(SEED)
@@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
"dim1,dim2,dtype,signed,blocksize",
TEST_CONFIGS,
)
@skip_if_rocm("ROCm enablement in progress")
def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

3 changes: 2 additions & 1 deletion test/quantization/test_marlin_qqq.py
@@ -19,13 +19,14 @@
MappingType,
choose_qparams_and_quantize_affine_qqq,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm


@unittest.skipIf(
is_fbcode(),
"Skipping the test in fbcode since we don't have TARGET file for kernels",
)
@skip_if_rocm("ROCm enablement in progress")
class TestMarlinQQQ(TestCase):
def setUp(self):
super().setUp()