Skip to content

Commit

Permalink
[FRONTEND] Support passing dtype as constexpr for tma load (#4821)
Browse files Browse the repository at this point in the history
Fixing a compile error like the one below when passing dtype through a kernel arg
to `tl._experimental_descriptor_load`:

 AttributeError: 'constexpr' object has no attribute 'to_ir'
  • Loading branch information
htyu authored Sep 28, 2024
1 parent fe47f98 commit 6af74b2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions python/test/unit/hopper/test_experimental_tma.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
@triton.jit
def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
BYVAL_TMA: tl.constexpr):
BYVAL_TMA: tl.constexpr, dtype: tl.constexpr):
if not BYVAL_TMA:
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
Expand All @@ -72,11 +72,11 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
offs_k = 0
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16)
b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16)
a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype)
accumulator = tl.dot(a, b, acc=accumulator)
offs_k += BLOCK_SIZE_K
accumulator = accumulator.to(tl.float16)
accumulator = accumulator.to(dtype)
tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


Expand All @@ -101,7 +101,7 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
num_warps=8, num_stages=num_stages)
num_warps=8, num_stages=num_stages, dtype=tl.float16)
ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
if BLOCK_M >= 64 and BLOCK_N >= 64:
Expand Down
2 changes: 1 addition & 1 deletion python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,7 +1613,7 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=
This loads a tensor of data based on the descriptor and offsets.
"""
type = block_type(dtype, shape)
type = block_type(_constexpr_to_value(dtype), shape)
return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)


Expand Down

0 comments on commit 6af74b2

Please sign in to comment.