convert dora fusion test from pytest to unittest #1622

Open · wants to merge 2 commits into base: main
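
The conversion uses the parametrization helpers from torch.testing._internal.common_utils (common_utils.parametrize, instantiate_parametrized_tests, run_tests) in place of pytest.mark.parametrize, as the diff below shows. As a minimal, hypothetical sketch of that pattern (the test class, method name, and values here are illustrative only, not part of this PR):

import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import run_tests


class TestExample(common_utils.TestCase):
    # One test method is generated per parametrized value.
    @common_utils.parametrize("dtype", [torch.float16, torch.float32])
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_example(self, dtype):
        x = torch.ones(4, device="cuda", dtype=dtype)
        self.assertEqual(x.dtype, dtype)


# Expands the parametrized methods into concrete test cases.
common_utils.instantiate_parametrized_tests(TestExample)

if __name__ == "__main__":
    run_tests()

Unlike pytest.mark.parametrize, the tests must live on a common_utils.TestCase subclass and be expanded explicitly with instantiate_parametrized_tests, and module-level skips become raised unittest.SkipTest exceptions.
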
Changes from all commits
267 changes: 129 additions & 138 deletions test/dora/test_dora_fusion.py
@@ -1,46 +1,24 @@
import sys

-import pytest
+import unittest

if sys.version_info < (3, 11):
-    pytest.skip("requires Python >= 3.11", allow_module_level=True)
+    raise unittest.SkipTest("requires Python >= 3.11")

-triton = pytest.importorskip("triton", reason="requires triton")
+try:
+    import triton
+except ImportError:
+    raise unittest.SkipTest("requires triton")

import itertools

import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import run_tests

from torchao.prototype.dora.kernels.matmul import triton_mm
from torchao.prototype.dora.kernels.smallk import triton_mm_small_k

torch.manual_seed(0)

-# Test configs
-M = 4096
-N = 4096
-Ks = [int(2**i) for i in range(4, 7)]
-
-FUSED_DORA_SHAPES = [(M, N, K) for K in Ks[:1]]
-
-DTYPES = [torch.float32, torch.float16, torch.bfloat16]
-
-STORE_ACC = [False]
-EPILOGUE_NORM = [True, False]
-ADD_SOURCE = [True]
-MAGNITUDE_VECTOR = [True]
-FUSED_DORA_TEST_CONFIGS = list(
-    itertools.product(
-        FUSED_DORA_SHAPES,
-        STORE_ACC,
-        EPILOGUE_NORM,
-        ADD_SOURCE,
-        MAGNITUDE_VECTOR,
-        DTYPES,
-    )
-)


def _arg_to_id(arg):
    if isinstance(arg, (tuple, list)):
        return "x".join([str(x) for x in arg])
@@ -56,59 +34,73 @@ def check(expected, actual, dtype):
    return diff


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
-@pytest.mark.parametrize(
-    "shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype",
-    FUSED_DORA_TEST_CONFIGS,
-    ids=_arg_to_id,
-)
-def test_dora_column_norm(
-    shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype
-):
-    if not (store_acc or epilogue_norm):
-        pytest.skip("Either store_acc or epilogue_norm must be True")
-
-    M, N, K = shape
-    A = torch.randn(M, K, device="cuda", dtype=dtype)
-    B = torch.randn(K, N, device="cuda", dtype=dtype)
-    source = torch.randn(M, N, device="cuda", dtype=dtype)
-    magnitude = torch.randn(M, device="cuda", dtype=dtype)
-
-    c_ref = torch.matmul(A, B)
-    norm2_ref = 1 / c_ref.norm(2, dim=1)
-    source_ref = source + c_ref
-    source_norm2_ref = 1 / (source + c_ref).norm(2, dim=1)
-    source_norm2_magnitude_ref = magnitude * source_norm2_ref
-
-    # First test small K only kernel, no epilogue
-    # source = None # source # None
-    # magnitude = None # magnitude # None
-
-    tt_out = triton_mm_small_k(
-        A,
-        B,
-        source=source if add_source else None,
-        magnitude=magnitude if magnitude_vector else None,
-        epilogue_norm=epilogue_norm,
-        store_acc=store_acc,
-    )
+M = 4096
+N = 4096
+Ks = [int(2**i) for i in range(4, 7)]
+FUSED_DORA_SHAPES = [(M, N, K) for K in Ks[:1]]
+DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+STORE_ACC = [False]
+EPILOGUE_NORM = [True, False]
+ADD_SOURCE = [True]
+MAGNITUDE_VECTOR = [True]

-    if store_acc:
-        c_test = tt_out[0] if epilogue_norm else tt_out
-        if add_source:
-            check(source_ref, c_test, dtype)
-        else:
-            check(c_ref, c_test, dtype)
-
-    if epilogue_norm:
-        norm2_test = tt_out[1] if store_acc else tt_out
-        if add_source:
-            if magnitude_vector:
-                check(source_norm2_magnitude_ref, norm2_test, dtype)
-            else:
-                check(source_norm2_ref, norm2_test, dtype)
-        else:
-            check(norm2_ref, norm2_test, dtype)
+class TestDoraColumnNorm(common_utils.TestCase):
+
+    @common_utils.parametrize("shape", FUSED_DORA_SHAPES)
+    @common_utils.parametrize("store_acc", STORE_ACC)
+    @common_utils.parametrize("epilogue_norm", EPILOGUE_NORM)
+    @common_utils.parametrize("add_source", ADD_SOURCE)
+    @common_utils.parametrize("magnitude_vector", MAGNITUDE_VECTOR)
+    @common_utils.parametrize("dtype", DTYPES)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_dora_column_norm(
+        self, shape, store_acc, epilogue_norm, add_source, magnitude_vector, dtype
+    ):
+        if not (store_acc or epilogue_norm):
+            # pytest.skip("Either store_acc or epilogue_norm must be True")
+            raise unittest.SkipTest("Either store_acc or epilogue_norm must be True")
+
+        M, N, K = shape
+        A = torch.randn(M, K, device="cuda", dtype=dtype)
+        B = torch.randn(K, N, device="cuda", dtype=dtype)
+        source = torch.randn(M, N, device="cuda", dtype=dtype)
+        magnitude = torch.randn(M, device="cuda", dtype=dtype)
+
+        c_ref = torch.matmul(A, B)
+        norm2_ref = 1 / c_ref.norm(2, dim=1)
+        source_ref = source + c_ref
+        source_norm2_ref = 1 / (source + c_ref).norm(2, dim=1)
+        source_norm2_magnitude_ref = magnitude * source_norm2_ref
+
+        # First test small K only kernel, no epilogue
+        # source = None # source # None
+        # magnitude = None # magnitude # None
+
+        tt_out = triton_mm_small_k(
+            A,
+            B,
+            source=source if add_source else None,
+            magnitude=magnitude if magnitude_vector else None,
+            epilogue_norm=epilogue_norm,
+            store_acc=store_acc,
+        )
+
+        if store_acc:
+            c_test = tt_out[0] if epilogue_norm else tt_out
+            if add_source:
+                check(source_ref, c_test, dtype)
+            else:
+                check(c_ref, c_test, dtype)
+
+        if epilogue_norm:
+            norm2_test = tt_out[1] if store_acc else tt_out
+            if add_source:
+                if magnitude_vector:
+                    check(source_norm2_magnitude_ref, norm2_test, dtype)
+                else:
+                    check(source_norm2_ref, norm2_test, dtype)
+            else:
+                check(norm2_ref, norm2_test, dtype)


BATCH_SIZES = [int(2**i) for i in range(6)]
@@ -124,64 +116,63 @@ def test_dora_column_norm(
EPILOGUE_ELEMENTWISE_ADD = [True]
EPILOGUE_BROADCAST_SCALE = [True]

-FUSED_MATMUL_TEST_CONFIGS = list(
-    itertools.product(
-        FUSED_MATMUL_SHAPES[:1],
-        DTYPES,
-        EPILOGUE_ELEMENTWISE_ADD,
-        EPILOGUE_BROADCAST_SCALE,
-    )
-)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
-@pytest.mark.parametrize(
-    "shape, dtype, epilogue_add, epilogue_scale",
-    FUSED_MATMUL_TEST_CONFIGS,
-    ids=_arg_to_id,
-)
-def test_dora_matmul(shape, dtype, epilogue_add, epilogue_scale):
-    M, K, N = shape
-    A = torch.randn(M, K, device="cuda", dtype=dtype)
-    B = torch.randn(K, N, device="cuda", dtype=dtype)
-    C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None
-    scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None
-
-    D_ref = torch.matmul(A, B)
-    if epilogue_add:
-        D_ref += C
-    if epilogue_scale:
-        D_ref *= scale.unsqueeze(0)
-
-    D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale)
-    check(D_ref, D_test, dtype)
+class TestDoraMatmul(common_utils.TestCase):
+    @common_utils.parametrize("shape", FUSED_MATMUL_SHAPES[:1])
+    @common_utils.parametrize("dtype", DTYPES)
+    @common_utils.parametrize("epilogue_add", EPILOGUE_ELEMENTWISE_ADD)
+    @common_utils.parametrize("epilogue_scale", EPILOGUE_BROADCAST_SCALE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_dora_matmul(self, shape, dtype, epilogue_add, epilogue_scale):
+        M, K, N = shape
+        A = torch.randn(M, K, device="cuda", dtype=dtype)
+        B = torch.randn(K, N, device="cuda", dtype=dtype)
+        C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None
+        scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None
+
+        D_ref = torch.matmul(A, B)
+        if epilogue_add:
+            D_ref += C
+        if epilogue_scale:
+            D_ref *= scale.unsqueeze(0)
+
+        D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale)
+        check(D_ref, D_test, dtype)


MODES = ["default"]


-@pytest.mark.skip("TODO: torch.compile does not work with custom kernel")
-@pytest.mark.parametrize(
-    "shape, dtype, epilogue_add, epilogue_scale, mode",
-    [[*cfg, mode] for cfg in FUSED_MATMUL_TEST_CONFIGS for mode in MODES][:1],
-    ids=_arg_to_id,
-)
-def test_dora_matmul_compile(shape, dtype, epilogue_add, epilogue_scale, mode):
-    M, K, N = shape
-    A = torch.randn(M, K, device="cuda", dtype=dtype)
-    B = torch.randn(K, N, device="cuda", dtype=dtype)
-    C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None
-    scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None
-
-    D_ref = torch.matmul(A, B)
-    if epilogue_add:
-        D_ref += C
-    if epilogue_scale:
-        D_ref *= scale.unsqueeze(0)
-
-    D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale)
-    check(D_ref, D_test, dtype)
-
-    triton_compiled = torch.compile(triton_mm, mode=mode)
-    D_compiled = triton_compiled(A, B, epilogue_source=C, epilogue_scale=scale)
-    check(D_ref, D_compiled, dtype)
+class TestDoraMatmulCompile(common_utils.TestCase):
+    @common_utils.parametrize("shape", FUSED_MATMUL_SHAPES[:1])
+    @common_utils.parametrize("dtype", DTYPES)
+    @common_utils.parametrize("epilogue_add", EPILOGUE_ELEMENTWISE_ADD)
+    @common_utils.parametrize("epilogue_scale", EPILOGUE_BROADCAST_SCALE)
+    @common_utils.parametrize("mode", MODES)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skip("TODO: torch.compile does not work with custom kernel")
+    def test_dora_matmul_compile(self, shape, dtype, epilogue_add, epilogue_scale, mode):
+        M, K, N = shape
+        A = torch.randn(M, K, device="cuda", dtype=dtype)
+        B = torch.randn(K, N, device="cuda", dtype=dtype)
+        C = torch.randn(M, N, device="cuda", dtype=dtype) if epilogue_add else None
+        scale = torch.randn(N, device="cuda", dtype=dtype) if epilogue_scale else None
+
+        D_ref = torch.matmul(A, B)
+        if epilogue_add:
+            D_ref += C
+        if epilogue_scale:
+            D_ref *= scale.unsqueeze(0)
+
+        D_test = triton_mm(A, B, epilogue_source=C, epilogue_scale=scale)
+        check(D_ref, D_test, dtype)
+
+        triton_compiled = torch.compile(triton_mm, mode=mode)
+        D_compiled = triton_compiled(A, B, epilogue_source=C, epilogue_scale=scale)
+        check(D_ref, D_compiled, dtype)


+common_utils.instantiate_parametrized_tests(TestDoraColumnNorm)
+common_utils.instantiate_parametrized_tests(TestDoraMatmul)
+common_utils.instantiate_parametrized_tests(TestDoraMatmulCompile)

+if __name__ == "__main__":
+    run_tests()