diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst
index 94c3b47a7357..80f6f3d09e83 100644
--- a/docs/python-api/triton.language.rst
+++ b/docs/python-api/triton.language.rst
@@ -40,9 +40,12 @@ Shape Manipulation Ops
     broadcast
     broadcast_to
     expand_dims
+    interleave
+    join
     permute
     ravel
     reshape
+    split
     trans
     view
 
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 293333e211ea..67babce875e4 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -383,7 +383,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
 }
 
-def TT_ExperimentalJoinOp : TT_Op<"experimental_join", [
+def TT_JoinOp : TT_Op<"join", [
     NoMemoryEffect, SameTypeOperands,
     DeclareOpInterfaceMethods<InferTypeOpInterface>,
 ]> {
@@ -401,7 +401,7 @@ def TT_ExperimentalJoinOp : TT_Op<"experimental_join", [
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
 }
 
-def TT_ExperimentalSplitOp : TT_Op<"experimental_split", [
+def TT_SplitOp : TT_Op<"split", [
     NoMemoryEffect,
     DeclareOpInterfaceMethods<InferTypeOpInterface>,
     TypesMatchWith<"outLHS and outRHS types match",
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 6d456321fedf..2ce7ea36cdb2 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -119,13 +119,13 @@ struct CatOpConversion : public ConvertOpToLLVMPattern {
     return success();
   }
 };
-struct JoinOpConversion : public ConvertOpToLLVMPattern<ExperimentalJoinOp> {
-  using OpAdaptor = typename ExperimentalJoinOp::Adaptor;
+struct JoinOpConversion : public ConvertOpToLLVMPattern<JoinOp> {
+  using OpAdaptor = typename JoinOp::Adaptor;
   explicit JoinOpConversion(LLVMTypeConverter &typeConverter,
                             PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<ExperimentalJoinOp>(typeConverter, benefit) {}
+      : ConvertOpToLLVMPattern<JoinOp>(typeConverter, benefit) {}
   LogicalResult
-  matchAndRewrite(ExperimentalJoinOp op, OpAdaptor adaptor,
+  matchAndRewrite(JoinOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // We rely on the following invariants of this op (which are checked by its
     // verifier):
@@ -157,11 +157,11 @@ struct JoinOpConversion : public ConvertOpToLLVMPattern {
     return success();
   }
 };
-struct SplitOpConversion : public ConvertOpToLLVMPattern<ExperimentalSplitOp> {
-  using OpAdaptor = typename ExperimentalSplitOp::Adaptor;
+struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
+  using OpAdaptor = typename SplitOp::Adaptor;
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(ExperimentalSplitOp op, OpAdaptor adaptor,
+  matchAndRewrite(SplitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // We rely on the following invariants of this op (which are checked by its
     // verifier):
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index befae59fb497..27d99ff820ad 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -321,27 +321,25 @@ struct TritonCatPattern : public OpConversionPattern {
   }
 };
 
-struct TritonJoinOpPattern
-    : public OpConversionPattern<triton::ExperimentalJoinOp> {
+struct TritonJoinOpPattern : public OpConversionPattern<triton::JoinOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ExperimentalJoinOp op, OpAdaptor adaptor,
+  LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
     // Simply rely on type inference for this op. (Notably, GenericOpPattern
     // does not do this, instead it assigns the default layout to the ins and
     // outs.)
-    addNamedAttrs(rewriter.replaceOpWithNewOp<ExperimentalJoinOp>(
+    addNamedAttrs(rewriter.replaceOpWithNewOp<JoinOp>(
                       op, adaptor.getLhs(), adaptor.getRhs()),
                   adaptor.getAttributes());
     return success();
   }
 };
 
-struct TritonSplitOpPattern
-    : public OpConversionPattern<triton::ExperimentalSplitOp> {
+struct TritonSplitOpPattern : public OpConversionPattern<triton::SplitOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ExperimentalSplitOp op, OpAdaptor adaptor,
+  LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
     auto src = adaptor.getSrc();
     auto srcTy = src.getType().cast<RankedTensorType>();
@@ -393,9 +391,8 @@ struct TritonSplitOpPattern
       src = rewriter.create<ConvertLayoutOp>(op.getLoc(), srcTy, src);
     }
 
-    addNamedAttrs(
-        rewriter.replaceOpWithNewOp<ExperimentalSplitOp>(op, src),
-        adaptor.getAttributes());
+    addNamedAttrs(rewriter.replaceOpWithNewOp<SplitOp>(op, src),
+                  adaptor.getAttributes());
     return success();
   }
 };
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 39ae5b14c59e..b4d28fb08502 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -1026,11 +1026,12 @@ LogicalResult ReturnOp::verify() {
   return success();
 }
 
-// -- ExperimentalJoinOp --
-LogicalResult ExperimentalJoinOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
+// -- JoinOp --
+LogicalResult
+JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
+                         ValueRange operands, DictionaryAttr attributes,
+                         OpaqueProperties properties, RegionRange regions,
+                         SmallVectorImpl<Type> &inferredReturnTypes) {
   // These should have been checked by tablegen-generated code.
   assert(operands.size() == 2);
   assert(operands[0].getType() == operands[1].getType());
@@ -1058,8 +1059,8 @@ LogicalResult ExperimentalJoinOp::inferReturnTypes(
   return success();
 }
 
-// -- ExperimentalSplitOp --
-LogicalResult ExperimentalSplitOp::inferReturnTypes(
+// -- SplitOp --
+LogicalResult SplitOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index 0e7130de2a03..4625f803189c 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -301,8 +301,8 @@ SmallVector LayoutPropagation::propagateToUsers(Value value,
     }
     if (user->hasTrait() ||
         user->hasTrait() ||
-        isa(user)) {
+        isa(user)) {
       setEncoding(user->getResults(), info, changed, user);
       continue;
     }
@@ -706,8 +706,8 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
   }
   if (op->hasTrait() ||
       op->hasTrait() ||
-      isa(op)) {
+      isa(
+          op)) {
     Operation *newOp = cloneElementwise(rewriter, op, encoding);
     for (auto [oldResult, newResult] :
          llvm::zip(op->getResults(), newOp->getResults()))
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 3b4d9362a078..8dd6d43a116a 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -286,8 +286,7 @@ static std::optional inferDstEncoding(triton::ExpandDimsOp op,
   return sliceEncoding.getParent();
 }
 
-static std::optional<Attribute> inferDstEncoding(ExperimentalJoinOp op,
-                                                 Attribute srcEnc) {
+static std::optional<Attribute> inferDstEncoding(JoinOp op, Attribute srcEnc) {
   Attribute dstEnc;
   if (srcEnc.getDialect()
           .getRegisteredInterface<DialectInferLayoutInterface>()
@@ -299,8 +298,7 @@ static std::optional inferDstEncoding(ExperimentalJoinOp op,
   return std::nullopt;
 }
 
-static std::optional<Attribute> inferDstEncoding(ExperimentalSplitOp op,
-                                                 Attribute srcEnc) {
+static std::optional<Attribute> inferDstEncoding(SplitOp op, Attribute srcEnc) {
   Attribute dstEnc;
   if (srcEnc.getDialect()
           .getRegisteredInterface<DialectInferLayoutInterface>()
@@ -328,8 +326,7 @@ static std::optional inferSrcEncoding(triton::ExpandDimsOp op,
                          encoding);
 }
 
-static std::optional<Attribute> inferSrcEncoding(ExperimentalJoinOp op,
-                                                 Attribute dstEnc) {
+static std::optional<Attribute> inferSrcEncoding(JoinOp op, Attribute dstEnc) {
   // Split is the inverse of join.
   Attribute srcEnc;
   if (dstEnc.getDialect()
@@ -341,8 +338,7 @@ static std::optional inferSrcEncoding(ExperimentalJoinOp op,
   return std::nullopt;
 }
 
-static std::optional<Attribute> inferSrcEncoding(ExperimentalSplitOp op,
-                                                 Attribute dstEnc) {
+static std::optional<Attribute> inferSrcEncoding(SplitOp op, Attribute dstEnc) {
   // Join is the inverse of split.
   Attribute srcEnc;
   if (dstEnc.getDialect()
@@ -438,9 +434,9 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) {
     return inferSrcEncoding(reduceOp, encoding);
   if (auto expand = dyn_cast<triton::ExpandDimsOp>(op))
     return inferSrcEncoding(expand, encoding);
-  if (auto join = dyn_cast<ExperimentalJoinOp>(op))
+  if (auto join = dyn_cast<JoinOp>(op))
     return inferSrcEncoding(join, encoding);
-  if (auto split = dyn_cast<ExperimentalSplitOp>(op))
+  if (auto split = dyn_cast<SplitOp>(op))
     return inferSrcEncoding(split, encoding);
   if (auto trans = dyn_cast<triton::TransOp>(op))
     return inferSrcEncoding(trans, encoding);
@@ -464,9 +460,9 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) {
     return inferDstEncoding(reduceOp, encoding);
   if (auto expand = dyn_cast<triton::ExpandDimsOp>(op))
     return inferDstEncoding(expand, encoding);
-  if (auto join = dyn_cast<ExperimentalJoinOp>(op))
+  if (auto join = dyn_cast<JoinOp>(op))
     return inferDstEncoding(join, encoding);
-  if (auto split = dyn_cast<ExperimentalSplitOp>(op))
+  if (auto split = dyn_cast<SplitOp>(op))
     return inferDstEncoding(split, encoding);
   if (auto trans = dyn_cast<triton::TransOp>(op))
     return inferDstEncoding(trans, encoding);
diff --git a/python/src/ir.cc b/python/src/ir.cc
index 1bf12f725f98..31c35a29b574 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -1191,11 +1191,11 @@ void init_triton_ir(py::module &&m) {
            })
       .def("create_join",
            [](TritonOpBuilder &self, Value &a, Value &b) -> Value {
-             return self.create<JoinOp>(a, b);
+             return self.create<JoinOp>(a, b);
           })
       .def("create_split",
            [](TritonOpBuilder &self, Value &a) -> std::vector<Value> {
-             auto op = self.create<SplitOp>(a);
+             auto op = self.create<SplitOp>(a);
              return std::vector<Value>(op->result_begin(), op->result_end());
           })
      // Implements tl.trans and tl.permute.
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 361be81d3d4b..ad58679cd571 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1511,7 +1511,7 @@ def kernel(X, Y, Z, N: tl.constexpr):
         offs = tl.arange(0, N)
         x = tl.load(X + offs)
         y = tl.load(Y + offs)
-        z = tl._experimental_join(x, y)
+        z = tl.join(x, y)
         tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z)
 
     x = torch.arange(0, 128, device=device).to(torch.int32)
@@ -1531,7 +1531,7 @@ def test_join_scalars(device):
     def kernel(X, Y, Z):
         x = tl.load(X)
         y = tl.load(Y)
-        z = tl._experimental_join(x, y)
+        z = tl.join(x, y)
         tl.static_assert(z.shape == [2])
         tl.store(Z + tl.arange(0, 2), z)
 
@@ -1550,7 +1550,7 @@ def test_join_with_mma(device):
     @triton.jit
     def kernel(X, Z):
         x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :])  # (32,16)
-        x2 = tl._experimental_join(x, 2 * x)  # (32,16,2)
+        x2 = tl.join(x, 2 * x)  # (32,16,2)
         x3 = tl.reshape(x2, (32, 32))
         z = tl.dot(x3, x3)  # (32,32)
         tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z)
@@ -1564,6 +1564,40 @@ def kernel(X, Z):
     torch.testing.assert_close(z, z_ref)
 
 
+def test_interleave(device):
+    if is_hip():
+        pytest.skip("test_interleave not supported on HIP")
+
+    @triton.jit
+    def kernel(Z, N: tl.constexpr):
+        z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N))
+        tl.store(Z + tl.arange(0, 2 * N), z)
+
+    x = torch.arange(0, 128, device=device).to(torch.int32)
+    y = torch.arange(128, 256, device=device).to(torch.int32)
+    z_ref = torch.stack([x, y], dim=-1).reshape(256)
+    z = torch.zeros_like(z_ref)
+    kernel[(1, )](z, N=128)
+
+    np.testing.assert_equal(to_numpy(z_ref), to_numpy(z))
+
+
+def test_interleave_scalars(device):
+    if is_hip():
+        pytest.skip("test_interleave_scalars not supported on HIP")
+
+    @triton.jit
+    def kernel(X, Y, Z):
+        z = tl.interleave(X, Y)
+        tl.static_assert(z.shape == [tl.constexpr(2)])
+        tl.store(Z + tl.arange(0, 2), z)
+
+    z = torch.zeros(2, device=device)
+    kernel[(1, )](10, 20, z)
+
+    np.testing.assert_equal([10, 20], to_numpy(z))
+
+
 def test_split(device):
     if is_hip():
         pytest.skip("test_split not supported on HIP")
@@ -1573,7 +1607,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr):
         offs = tl.arange(0, N)
         x = tl.load(X + offs)
         x1 = tl.reshape(x, (N // 2, 2))
-        z1, z2 = tl._experimental_split(x1)
+        z1, z2 = tl.split(x1)
         tl.store(Z1 + tl.arange(0, N // 2), z1)
         tl.store(Z2 + tl.arange(0, N // 2), z2)
 
@@ -1595,7 +1629,7 @@ def test_split_to_scalar(device):
     def kernel(X, Z1, Z2):
         offs = tl.arange(0, 2)
         x = tl.load(X + offs)
-        z1, z2 = tl._experimental_split(x)
+        z1, z2 = tl.split(x)
         tl.static_assert(isinstance(z1, tl.tensor))
         tl.static_assert(isinstance(z2, tl.tensor))
         tl.static_assert(z1.shape == [])
diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py
index 4bf719c5bdf7..9429db23b02e 100644
--- a/python/triton/language/__init__.py
+++ b/python/triton/language/__init__.py
@@ -10,13 +10,14 @@
     cumprod,
     cumsum,
     flip,
+    interleave,
     max,
     min,
+    ravel,
     sigmoid,
     softmax,
     sort,
     sum,
-    ravel,
     swizzle2d,
     xor_sum,
     zeros,
@@ -25,6 +26,8 @@
 from .core import (
     PropagateNan,
     TRITON_MAX_TENSOR_NUMEL,
+    _experimental_join,
+    _experimental_split,
     abs,
     advance,
     arange,
@@ -42,9 +45,9 @@
     broadcast,
     broadcast_to,
     cat,
+    clamp,
     constexpr,
     cos,
-    clamp,
     debug_barrier,
     device_assert,
     device_print,
@@ -52,7 +55,6 @@
     dtype,
     exp,
     expand_dims,
-    full,
     fdiv,
     float16,
     float32,
@@ -61,6 +63,7 @@
     float8e4b15x4,
     float8e4nv,
     float8e5,
+    full,
     function_type,
     histogram,
     inline_asm_elementwise,
@@ -69,6 +72,7 @@
     int32,
     int64,
     int8,
+    join,
     load,
     log,
     make_block_ptr,
@@ -86,14 +90,14 @@
     reduce,
     reshape,
     sin,
+    split,
     sqrt,
     static_assert,
     static_print,
-    store,
     static_range,
+    store,
     tensor,
     trans,
-    # triton,
     uint16,
     uint32,
     uint64,
@@ -102,8 +106,6 @@
     view,
     void,
     where,
-    _experimental_join,
-    _experimental_split,
 )
 from .random import (
     pair_uniform_to_normal,
@@ -121,11 +123,13 @@
 __all__ = [
     "PropagateNan",
     "TRITON_MAX_TENSOR_NUMEL",
+    "_experimental_join",
+    "_experimental_split",
     "abs",
     "advance",
     "arange",
-    "argmin",
     "argmax",
+    "argmin",
     "associative_scan",
     "atomic_add",
     "atomic_and",
@@ -168,16 +172,18 @@
     "function_type",
     "histogram",
     "inline_asm_elementwise",
+    "interleave",
     "int1",
     "int16",
     "int32",
     "int64",
     "int8",
     "ir",
-    "math",
+    "join",
     "load",
     "log",
     "make_block_ptr",
+    "math",
     "max",
     "max_constancy",
     "max_contiguous",
@@ -207,10 +213,11 @@
     "sin",
     "softmax",
     "sort",
+    "split",
     "sqrt",
-    "static_range",
     "static_assert",
     "static_print",
+    "static_range",
     "store",
     "sum",
     "swizzle2d",
@@ -219,9 +226,9 @@
     "triton",
     "uint16",
     "uint32",
-    "uint_to_uniform_float",
     "uint64",
     "uint8",
+    "uint_to_uniform_float",
     "umulhi",
     "view",
     "void",
@@ -229,8 +236,6 @@
     "xor_sum",
     "zeros",
     "zeros_like",
-    _experimental_join,
-    _experimental_split,
 ]
 
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 2665b6e88002..a11182e1e53e 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -922,7 +922,7 @@ def trans(self, *dims) -> tensor:
     def permute(self, *dims) -> tensor:
         ...
 
-    def _experimental_split(self) -> tuple[tensor, tensor]:
+    def split(self) -> tuple[tensor, tensor]:
         ...
 
     def view(self, *shape) -> tensor:
@@ -1234,6 +1234,12 @@ def cat(input, other, can_reorder=False, _builder=None):
 
 @builtin
 def _experimental_join(a, b, _builder=None):
+    """Forwards to core.join for temporary backwards compat."""
+    return join(a, b, _builder=_builder)
+
+
+@builtin
+def join(a, b, _builder=None):
     """
     Join the given tensors in a new, minor dimension.
 
@@ -1256,6 +1262,13 @@ def _experimental_join(a, b, _builder=None):
     return semantic.join(a, b, _builder)
 
 
+@_tensor_member_fn
+@builtin
+def _experimental_split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
+    """Forwards to core.split for temporary backwards compat."""
+    return split(a, _builder=_builder, _generator=_generator)
+
+
 @jit
 def _take_first(a, b):
     return a
@@ -1263,7 +1276,7 @@ def _take_first(a, b):
 
 @_tensor_member_fn
 @builtin
-def _experimental_split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
+def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
     """
     Split a tensor in two along its last dim, which must have size 2.
 
diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py
index f1700bd5a8ae..7833d39fe97b 100644
--- a/python/triton/language/standard.py
+++ b/python/triton/language/standard.py
@@ -416,3 +416,25 @@ def flip(x, dim=None):
         y = sum(y * flip2, i + 1, keep_dims=True)
     x = core.reshape(y, x.shape)
     return x
+
+
+@jit
+def interleave(a, b):
+    """
+    Interleaves the values of two tensors along their last dimension.
+
+    The two tensors must have the same shape.
+
+    Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
+    """
+    c = core.join(a, b)
+
+    assert isinstance(c.shape, list)
+    if len(c.shape) == 1:
+        # We must have interleaved two scalars.
+        return c
+    else:
+        # This `else` is necessary because Triton's AST parser doesn't
+        # understand that if we take the `if` above we definitely don't run this
+        # `else`.
+        return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir
index 0d31887f8d28..9200350b239b 100644
--- a/test/Triton/invalid.mlir
+++ b/test/Triton/invalid.mlir
@@ -29,7 +29,7 @@ tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) {
 // expected-note @+1 {{prior use}}
 tt.func public @fn(%arg0: tensor<32xf32>, %arg1: tensor<33xf32>) {
   // expected-error @+1 {{expects different type}}
-  %a = tt.experimental_join %arg0, %arg1 : tensor<32xf32> -> tensor<32x2xf32>
+  %a = tt.join %arg0, %arg1 : tensor<32xf32> -> tensor<32x2xf32>
   tt.return
 }
 
@@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<32xf32>, %arg1: tensor<33xf32>) {
 // expected-note @+1 {{prior use}}
 tt.func public @fn(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf16>) {
   // expected-error @+1 {{expects different type}}
-  %a = tt.experimental_join %arg0, %arg1 : tensor<32x32xf32> -> tensor<32x32x2xf32>
+  %a = tt.join %arg0, %arg1 : tensor<32x32xf32> -> tensor<32x32x2xf32>
   tt.return
 }
 
@@ -47,7 +47,7 @@ tt.func public @fn(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf16>) {
 tt.func public @fn(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) {
   // expected-error @+2 {{op failed to infer returned types}}
   // expected-error @+1 {{incompatible with return type}}
-  %a = tt.experimental_join %arg0, %arg1 : tensor<32xf32> -> tensor<64xf32>
+  %a = tt.join %arg0, %arg1 : tensor<32xf32> -> tensor<64xf32>
   tt.return
 }
 
@@ -56,7 +56,7 @@ tt.func public @fn(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) {
 tt.func public @fn(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) {
   // expected-error @+2 {{op failed to infer returned types}}
   // expected-error @+1 {{incompatible with return type}}
-  %a = tt.experimental_join %arg0, %arg1 : tensor<32x32xf32> -> tensor<32x64xf32>
+  %a = tt.join %arg0, %arg1 : tensor<32x32xf32> -> tensor<32x64xf32>
   tt.return
 }
 
@@ -64,7 +64,7 @@ tt.func public @fn(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) {
 
 // This one is OK
 tt.func public @fn(%arg0: tensor<f32>, %arg1: tensor<f32>) {
-  %a = tt.experimental_join %arg0, %arg1 : tensor<f32> -> tensor<2xf32>
+  %a = tt.join %arg0, %arg1 : tensor<f32> -> tensor<2xf32>
   tt.return
 }
 
@@ -72,7 +72,7 @@ tt.func public @fn(%arg0: tensor<f32>, %arg1: tensor<f32>) {
 
 tt.func public @fn(%arg0: f32, %arg1: f32) {
   // expected-error @+1 {{kind of type}}
-  %a = tt.experimental_join %arg0, %arg1 : f32 -> tensor<2xf32>
+  %a = tt.join %arg0, %arg1 : f32 -> tensor<2xf32>
   tt.return
 }
 
@@ -131,7 +131,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<32xf32, #blocked>) {
   // expected-error @+2 {{op failed to infer returned types}}
   // expected-error @+1 {{incompatible with return type}}
-  %a = tt.experimental_join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32>
+  %a = tt.join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32>
   tt.return
 }
 } // end module
@@ -143,7 +143,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<32xf32>) {
   // expected-error @+2 {{op failed to infer returned types}}
   // expected-error @+1 {{incompatible with return type}}
-  %a = tt.experimental_join %arg0, %arg0 : tensor<32xf32> -> tensor<32x2xf32, #shared>
+  %a = tt.join %arg0, %arg0 : tensor<32xf32> -> tensor<32x2xf32, #shared>
   tt.return
 }
 } // end module
@@ -156,7 +156,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<32xf32, #shared>) {
   // expected-error @+2 {{can only operate on BlockedEncoding}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a = tt.experimental_join %arg0, %arg0 : tensor<32xf32, #shared> -> tensor<32x2xf32, #blocked>
+  %a = tt.join %arg0, %arg0 : tensor<32xf32, #shared> -> tensor<32x2xf32, #blocked>
   tt.return
 }
 } // end module
@@ -170,7 +170,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<32xf32, #blocked>) {
   // expected-error @+2 {{order}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a = tt.experimental_join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32, #blocked1>
+  %a = tt.join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32, #blocked1>
   tt.return
 }
 } // end module
@@ -180,7 +180,7 @@ tt.func public @fn(%arg0: tensor<32xf32, #blocked>) {
 tt.func public @fn(%arg0: tensor<32xf32>) {
   // expected-error @+2 {{last dimension}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a, %b = tt.experimental_split %arg0 : tensor<32xf32> -> tensor<16xf32>
+  %a, %b = tt.split %arg0 : tensor<32xf32> -> tensor<16xf32>
   tt.return
 }
 
@@ -189,7 +189,7 @@ tt.func public @fn(%arg0: tensor<32xf32>) {
 tt.func public @fn(%arg0: tensor<32x2xf32>) {
   // expected-error @+2 {{op inferred type}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a, %b = tt.experimental_split %arg0 : tensor<32x2xf32> -> tensor<32xf16>
+  %a, %b = tt.split %arg0 : tensor<32x2xf32> -> tensor<32xf16>
   tt.return
 }
 
@@ -197,13 +197,13 @@ tt.func public @fn(%arg0: tensor<32x2xf32>) {
 
 tt.func public @fn(%arg0: f32) {
   // expected-error @+1 {{invalid kind of type}}
-  %a, %b = tt.experimental_split %arg0 : f32 -> f16
+  %a, %b = tt.split %arg0 : f32 -> f16
   tt.return
 }
 
 // -----
 
 tt.func public @fn(%arg0: tensor<2xf32>) {
-  %a, %b = tt.experimental_split %arg0 : tensor<2xf32> -> tensor<f32> // OK
+  %a, %b = tt.split %arg0 : tensor<2xf32> -> tensor<f32> // OK
   tt.return
 }
 
@@ -217,7 +217,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) {
   // expected-error @+2 {{last dimension}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a, %b = tt.experimental_split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
+  %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
   tt.return
 }
 } // end module
@@ -232,7 +232,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) {
   // expected-error @+2 {{op inferred type}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a, %b = tt.experimental_split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
+  %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
   tt.return
 }
 } // end module
@@ -247,7 +247,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c
 tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) {
   // expected-error @+2 {{op inferred type}}
   // expected-error @+1 {{op failed to infer returned types}}
-  %a, %b = tt.experimental_split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
+  %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1>
   tt.return
 }
 } // end module
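
Usage sketch (not part of the patch): a minimal example of the renamed public API, tl.join and tl.split, plus the new tl.interleave, mirroring the unit tests added above. The "cuda" device string, sizes, and dtype are illustrative assumptions, not something this change mandates.

import torch
import triton
import triton.language as tl

@triton.jit
def pair_kernel(X, Y, SUM, INTER, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    pairs = tl.join(x, y)          # (N, 2): row i is [x_i, y_i]
    a, b = tl.split(pairs)         # inverse of join: two (N,) tensors
    tl.store(SUM + offs, a + b)    # equals x + y
    z = tl.interleave(x, y)        # (2 * N,): x0, y0, x1, y1, ...
    tl.store(INTER + tl.arange(0, 2 * N), z)

x = torch.arange(0, 64, device="cuda", dtype=torch.int32)
y = torch.arange(64, 128, device="cuda", dtype=torch.int32)
s = torch.empty_like(x)
inter = torch.empty(128, device="cuda", dtype=torch.int32)
pair_kernel[(1, )](x, y, s, inter, N=64)
assert torch.equal(s, x + y)
assert torch.equal(inter, torch.stack([x, y], dim=-1).reshape(128))

The old tl._experimental_join / tl._experimental_split spellings still work for now; per the core.py changes above they simply forward to the new functions.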