-
The error message should be better :) Not on my workstation right now, but can you try: …

Normal assignment removes the `constexpr` qualifier, which is necessary for Triton tensor shapes.
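For illustration, a minimal sketch of that pitfall (hypothetical kernel, just to show the pattern described above; only values annotated `tl.constexpr` can be used as tensor shapes):

```python
import triton
import triton.language as tl

@triton.jit
def constexpr_demo(out, BLK: tl.constexpr):
    sz = BLK                                  # plain assignment: `sz` is no longer constexpr
    x = tl.zeros((BLK,), dtype=tl.float32)    # OK: shape is a constexpr
    # x = tl.zeros((sz,), dtype=tl.float32)   # fails: shape must be constexpr
    tl.store(out + tl.arange(0, BLK), x)
```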
-
Okay, I've fixed a few more issues in my kernel:

```python
#!/usr/bin/env python
import numpy as np
import mxnet as mx
import triton
import triton.language as tl

# Reduce over axis 1: (x, y, z) -> (x, z)
@triton.jit
def kernRowReduceAxis1(
    d0, s0,
    x_stride: tl.constexpr,
    y_stride: tl.constexpr,
    BLKSZ_Y: tl.constexpr,
    BLKSZ_Z: tl.constexpr,
):
    idx_x = tl.program_id(0)
    # Flat offsets covering one (y, z) tile, split into y and z indices
    idx_yz = tl.arange(0, BLKSZ_Y * BLKSZ_Z)
    idx_z = idx_yz % BLKSZ_Z
    idx_y = idx_yz // BLKSZ_Z
    # Mask off the padding introduced by rounding block sizes up to powers of two
    rmask_z = idx_z < y_stride
    rmask_y = idx_y < (x_stride // y_stride)
    rmask = rmask_z & rmask_y
    x = tl.load(s0 + idx_x * x_stride + idx_y * y_stride + idx_z, mask=rmask)
    x = tl.reshape(x, (BLKSZ_Y, BLKSZ_Z))
    xsum = tl.sum(x, axis=0)
    # Only the first y-row writes the reduced result
    wmask = (idx_y == 0) & rmask_z
    tl.store(d0 + idx_x * y_stride + idx_z, xsum, mask=wmask)

def AS_BLOCK_SIZE(i: int):
    # Round up to the next power of two
    return 1 << ((i - 1).bit_length())

GPU0 = mx.gpu()
N, W, C = 256, 384, 256
v_x = mx.nd.random_uniform(0., 1., shape=(N, W, C), dtype=np.float32, ctx=GPU0)
v_y = mx.nd.empty(shape=(N, C), dtype=np.float32, ctx=GPU0)
# Make sure the async init ops finish before launching the kernel
hash(str(v_x))
hash(str(v_y))
print('[D] begin kernel launch')
kernRowReduceAxis1[(N,)](
    v_y, v_x,
    W * C, C,
    AS_BLOCK_SIZE(W),
    AS_BLOCK_SIZE(C))
print('[D] end kernel launch')
```

Running the above script segfaults.
The vecadd kernel works under MXNet so far. To narrow the debugging scope, I'll work on the master branch once I've got my torch env working.
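(Side note: `AS_BLOCK_SIZE` just rounds up to the next power of two, which Triton block dimensions require. A quick sanity check; the last two lines assume your Triton version ships the `triton.next_power_of_2` helper:)

```python
assert AS_BLOCK_SIZE(256) == 256
assert AS_BLOCK_SIZE(384) == 512
# Equivalent built-in helper, if available in your Triton version:
import triton
assert triton.next_power_of_2(384) == 512
```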
-
Hello! Sorry for the delay. I think you are seeing a segfault because the shapes in the kernel don't match up when the program runs. Definitely a bug that you're not getting a better error message. All the frontend-level compiler error messages inside Triton should be double-checked, because clearly they're not user-friendly enough at the moment. As for the practical uses of reshape, they are used implicitly in broadcasting (e.g., …). PS: sorry for all the trouble, Triton should have better error messages.
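A minimal sketch of what "used implicitly in broadcasting" means (illustrative fragment, not from this thread):

```python
import triton
import triton.language as tl

@triton.jit
def broadcast_demo(out):
    a = tl.arange(0, 4)[:, None]   # shape (4, 1)
    b = tl.arange(0, 8)[None, :]   # shape (1, 8)
    # The addition implicitly expands both operands to shape (4, 8),
    # the same machinery that reshape exposes explicitly.
    c = a + b
    offs = tl.arange(0, 4)[:, None] * 8 + tl.arange(0, 8)[None, :]
    tl.store(out + offs, c)
```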
-
I encountered this bug while working with my fork & Apache MXNet. Using `triton.language.reshape` would throw a compilation error. I can't tell where I broke the package, or whether it's buggy on the master branch. Running the above script gives: …

Switching `tl.reshape(x, [sz_y, sz_z])` into `tl.reshape(x, (sz_y, sz_z))` gives a different error: …

`triton.language.reshape` has no unit test, and I can't find any usage of `triton.language.reshape` in a GitHub global search, so I can't tell whether this part is currently broken on master. Any help on how to fix this bug would be greatly appreciated.
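For reference, a minimal `tl.reshape` call that sidesteps both issues raised here — a tuple shape, with the constexpr parameters used directly and no intermediate assignment (hypothetical kernel, just to show the pattern):

```python
import triton
import triton.language as tl

@triton.jit
def reshape_demo(dst, src, BLK_Y: tl.constexpr, BLK_Z: tl.constexpr):
    x = tl.load(src + tl.arange(0, BLK_Y * BLK_Z))
    # Tuple of constexpr values; no `sz_y = BLK_Y` assignment in between.
    x = tl.reshape(x, (BLK_Y, BLK_Z))
    tl.store(dst + tl.arange(0, BLK_Z), tl.sum(x, axis=0))
```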