[Frontend] Upgrade split and join to be non-experimental ops. (#3204)

Also add interleave(), a simple helper based on join and reshape.
jlebar authored Feb 27, 2024
1 parent 12bcc7c commit 99b024b
Showing 13 changed files with 152 additions and 81 deletions.
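
The new interleave() is described above as a simple helper based on join and reshape. As a hedged sketch only (not the actual library implementation; the kernel name, sizes, dtypes, and CUDA device below are illustrative assumptions), an equivalent result can be produced directly from the two underlying ops:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def interleave_kernel(X, Y, Z, N: tl.constexpr):
    x = tl.load(X + tl.arange(0, N))
    y = tl.load(Y + tl.arange(0, N))
    # tl.join stacks the inputs along a new trailing axis: (N,) x (N,) -> (N, 2).
    # Flattening that result to (2*N,) alternates the elements: x0, y0, x1, y1, ...
    z = tl.reshape(tl.join(x, y), (2 * N, ))
    tl.store(Z + tl.arange(0, 2 * N), z)


x = torch.arange(0, 8, device="cuda", dtype=torch.int32)
y = torch.arange(8, 16, device="cuda", dtype=torch.int32)
z = torch.empty(16, device="cuda", dtype=torch.int32)
interleave_kernel[(1, )](x, y, z, N=8)
# z should match torch.stack([x, y], dim=-1).reshape(16),
# which is the same result tl.interleave(x, y) produces in one call.
```
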
docs/python-api/triton.language.rst: 3 additions & 0 deletions
@@ -40,9 +40,12 @@ Shape Manipulation Ops
broadcast
broadcast_to
expand_dims
+interleave
+join
permute
ravel
reshape
+split
trans
view

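For reference alongside the documentation entries above, here is a hedged sketch of the shape contract the now-public ops follow, mirroring the tests added later in this commit (the kernel name, sizes, dtypes, and CUDA device are illustrative assumptions): tl.join stacks two same-shape tensors along a new trailing axis of length 2, and tl.split removes a trailing axis of length 2, so the pair round-trips.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def join_split_roundtrip(X, Y, Z1, Z2, N: tl.constexpr):
    x = tl.load(X + tl.arange(0, N))
    y = tl.load(Y + tl.arange(0, N))
    z = tl.join(x, y)   # (N,) x (N,) -> (N, 2)
    a, b = tl.split(z)  # (N, 2) -> (N,), (N,)
    tl.store(Z1 + tl.arange(0, N), a)
    tl.store(Z2 + tl.arange(0, N), b)


x = torch.arange(0, 16, device="cuda", dtype=torch.int32)
y = torch.arange(16, 32, device="cuda", dtype=torch.int32)
z1 = torch.empty_like(x)
z2 = torch.empty_like(y)
join_split_roundtrip[(1, )](x, y, z1, z2, N=16)
# Expect z1 == x and z2 == y: split undoes join.
```
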
include/triton/Dialect/Triton/IR/TritonOps.td: 2 additions & 2 deletions
@@ -383,7 +383,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
}

-def TT_ExperimentalJoinOp : TT_Op<"experimental_join", [
+def TT_JoinOp : TT_Op<"join", [
NoMemoryEffect, SameTypeOperands,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
@@ -401,7 +401,7 @@ def TT_ExperimentalJoinOp : TT_Op<"experimental_join", [
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
}

-def TT_ExperimentalSplitOp : TT_Op<"experimental_split", [
+def TT_SplitOp : TT_Op<"split", [
NoMemoryEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"outLHS and outRHS types match",
lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp: 7 additions & 7 deletions
@@ -119,13 +119,13 @@ struct CatOpConversion : public ConvertOpToLLVMPattern<CatOp> {
return success();
}
};
-struct JoinOpConversion : public ConvertOpToLLVMPattern<ExperimentalJoinOp> {
-using OpAdaptor = typename ExperimentalJoinOp::Adaptor;
+struct JoinOpConversion : public ConvertOpToLLVMPattern<JoinOp> {
+using OpAdaptor = typename JoinOp::Adaptor;
explicit JoinOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
-: ConvertOpToLLVMPattern<ExperimentalJoinOp>(typeConverter, benefit) {}
+: ConvertOpToLLVMPattern<JoinOp>(typeConverter, benefit) {}
LogicalResult
-matchAndRewrite(ExperimentalJoinOp op, OpAdaptor adaptor,
+matchAndRewrite(JoinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We rely on the following invariants of this op (which are checked by its
// verifier):
@@ -157,11 +157,11 @@ struct JoinOpConversion : public ConvertOpToLLVMPattern<ExperimentalJoinOp> {
return success();
}
};
-struct SplitOpConversion : public ConvertOpToLLVMPattern<ExperimentalSplitOp> {
-using OpAdaptor = typename ExperimentalSplitOp::Adaptor;
+struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
+using OpAdaptor = typename SplitOp::Adaptor;
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
-matchAndRewrite(ExperimentalSplitOp op, OpAdaptor adaptor,
+matchAndRewrite(SplitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We rely on the following invariants of this op (which are checked by its
// verifier):
lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp: 7 additions & 10 deletions
@@ -321,27 +321,25 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
}
};

-struct TritonJoinOpPattern
-: public OpConversionPattern<triton::ExperimentalJoinOp> {
+struct TritonJoinOpPattern : public OpConversionPattern<triton::JoinOp> {
using OpConversionPattern::OpConversionPattern;

-LogicalResult matchAndRewrite(ExperimentalJoinOp op, OpAdaptor adaptor,
+LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Simply rely on type inference for this op. (Notably, GenericOpPattern
// does not do this, instead it assigns the default layout to the ins and
// outs.)
-addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExperimentalJoinOp>(
+addNamedAttrs(rewriter.replaceOpWithNewOp<triton::JoinOp>(
op, adaptor.getLhs(), adaptor.getRhs()),
adaptor.getAttributes());
return success();
}
};

-struct TritonSplitOpPattern
-: public OpConversionPattern<triton::ExperimentalSplitOp> {
+struct TritonSplitOpPattern : public OpConversionPattern<triton::SplitOp> {
using OpConversionPattern::OpConversionPattern;

-LogicalResult matchAndRewrite(ExperimentalSplitOp op, OpAdaptor adaptor,
+LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto src = adaptor.getSrc();
auto srcTy = src.getType().cast<RankedTensorType>();
@@ -393,9 +391,8 @@ struct TritonSplitOpPattern
src = rewriter.create<ConvertLayoutOp>(op.getLoc(), srcTy, src);
}

-addNamedAttrs(
-rewriter.replaceOpWithNewOp<triton::ExperimentalSplitOp>(op, src),
-adaptor.getAttributes());
+addNamedAttrs(rewriter.replaceOpWithNewOp<triton::SplitOp>(op, src),
+adaptor.getAttributes());
return success();
}
};
lib/Dialect/Triton/IR/Ops.cpp: 8 additions & 7 deletions
@@ -1026,11 +1026,12 @@ LogicalResult ReturnOp::verify() {
return success();
}

-// -- ExperimentalJoinOp --
-LogicalResult ExperimentalJoinOp::inferReturnTypes(
-MLIRContext *context, std::optional<Location> location, ValueRange operands,
-DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-SmallVectorImpl<Type> &inferredReturnTypes) {
+// -- JoinOp --
+LogicalResult
+JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
+ValueRange operands, DictionaryAttr attributes,
+OpaqueProperties properties, RegionRange regions,
+SmallVectorImpl<Type> &inferredReturnTypes) {
// These should have been checked by tablegen-generated code.
assert(operands.size() == 2);
assert(operands[0].getType() == operands[1].getType());
@@ -1058,8 +1059,8 @@ LogicalResult ExperimentalJoinOp::inferReturnTypes(
return success();
}

-// -- ExperimentalSplitOp --
-LogicalResult ExperimentalSplitOp::inferReturnTypes(
+// -- SplitOp --
+LogicalResult SplitOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp: 4 additions & 4 deletions
@@ -301,8 +301,8 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
}
if (user->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
user->hasTrait<OpTrait::Elementwise>() ||
-isa<ReduceOp, ExpandDimsOp, ReshapeOp, ExperimentalJoinOp,
-ExperimentalSplitOp, ConvertLayoutOp>(user)) {
+isa<ReduceOp, ExpandDimsOp, ReshapeOp, JoinOp, SplitOp,
+ConvertLayoutOp>(user)) {
setEncoding(user->getResults(), info, changed, user);
continue;
}
@@ -706,8 +706,8 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
}
if (op->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
op->hasTrait<OpTrait::Elementwise>() ||
-isa<ReduceOp, ExpandDimsOp, ReshapeOp, ExperimentalJoinOp,
-ExperimentalSplitOp, ConvertLayoutOp>(op)) {
+isa<ReduceOp, ExpandDimsOp, ReshapeOp, JoinOp, SplitOp, ConvertLayoutOp>(
+op)) {
Operation *newOp = cloneElementwise(rewriter, op, encoding);
for (auto [oldResult, newResult] :
llvm::zip(op->getResults(), newOp->getResults()))
lib/Dialect/TritonGPU/Transforms/Utility.cpp: 8 additions & 12 deletions
@@ -286,8 +286,7 @@ static std::optional<Attribute> inferDstEncoding(triton::ExpandDimsOp op,
return sliceEncoding.getParent();
}

-static std::optional<Attribute> inferDstEncoding(ExperimentalJoinOp op,
-Attribute srcEnc) {
+static std::optional<Attribute> inferDstEncoding(JoinOp op, Attribute srcEnc) {
Attribute dstEnc;
if (srcEnc.getDialect()
.getRegisteredInterface<DialectInferLayoutInterface>()
@@ -299,8 +298,7 @@ static std::optional<Attribute> inferDstEncoding(ExperimentalJoinOp op,
return std::nullopt;
}

-static std::optional<Attribute> inferDstEncoding(ExperimentalSplitOp op,
-Attribute srcEnc) {
+static std::optional<Attribute> inferDstEncoding(SplitOp op, Attribute srcEnc) {
Attribute dstEnc;
if (srcEnc.getDialect()
.getRegisteredInterface<DialectInferLayoutInterface>()
@@ -328,8 +326,7 @@ static std::optional<Attribute> inferSrcEncoding(triton::ExpandDimsOp op,
encoding);
}

-static std::optional<Attribute> inferSrcEncoding(ExperimentalJoinOp op,
-Attribute dstEnc) {
+static std::optional<Attribute> inferSrcEncoding(JoinOp op, Attribute dstEnc) {
// Split is the inverse of join.
Attribute srcEnc;
if (dstEnc.getDialect()
@@ -341,8 +338,7 @@ static std::optional<Attribute> inferSrcEncoding(ExperimentalJoinOp op,
return std::nullopt;
}

-static std::optional<Attribute> inferSrcEncoding(ExperimentalSplitOp op,
-Attribute dstEnc) {
+static std::optional<Attribute> inferSrcEncoding(SplitOp op, Attribute dstEnc) {
// Join is the inverse of split.
Attribute srcEnc;
if (dstEnc.getDialect()
@@ -438,9 +434,9 @@ std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding) {
return inferSrcEncoding(reduceOp, encoding);
if (auto expand = dyn_cast<triton::ExpandDimsOp>(op))
return inferSrcEncoding(expand, encoding);
-if (auto join = dyn_cast<triton::ExperimentalJoinOp>(op))
+if (auto join = dyn_cast<triton::JoinOp>(op))
return inferSrcEncoding(join, encoding);
-if (auto split = dyn_cast<triton::ExperimentalSplitOp>(op))
+if (auto split = dyn_cast<triton::SplitOp>(op))
return inferSrcEncoding(split, encoding);
if (auto trans = dyn_cast<triton::TransOp>(op))
return inferSrcEncoding(trans, encoding);
@@ -464,9 +460,9 @@ std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding) {
return inferDstEncoding(reduceOp, encoding);
if (auto expand = dyn_cast<triton::ExpandDimsOp>(op))
return inferDstEncoding(expand, encoding);
-if (auto join = dyn_cast<triton::ExperimentalJoinOp>(op))
+if (auto join = dyn_cast<triton::JoinOp>(op))
return inferDstEncoding(join, encoding);
-if (auto split = dyn_cast<triton::ExperimentalSplitOp>(op))
+if (auto split = dyn_cast<triton::SplitOp>(op))
return inferDstEncoding(split, encoding);
if (auto trans = dyn_cast<triton::TransOp>(op))
return inferDstEncoding(trans, encoding);
python/src/ir.cc: 2 additions & 2 deletions
@@ -1191,11 +1191,11 @@ void init_triton_ir(py::module &&m) {
})
.def("create_join",
[](TritonOpBuilder &self, Value &a, Value &b) -> Value {
-return self.create<ExperimentalJoinOp>(a, b);
+return self.create<JoinOp>(a, b);
})
.def("create_split",
[](TritonOpBuilder &self, Value &a) -> std::vector<Value> {
-auto op = self.create<ExperimentalSplitOp>(a);
+auto op = self.create<SplitOp>(a);
return std::vector<Value>(op->result_begin(), op->result_end());
})
// Implements tl.trans and tl.permute.
python/test/unit/language/test_core.py: 39 additions & 5 deletions
@@ -1511,7 +1511,7 @@ def kernel(X, Y, Z, N: tl.constexpr):
offs = tl.arange(0, N)
x = tl.load(X + offs)
y = tl.load(Y + offs)
-z = tl._experimental_join(x, y)
+z = tl.join(x, y)
tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z)

x = torch.arange(0, 128, device=device).to(torch.int32)
@@ -1531,7 +1531,7 @@ def test_join_scalars(device):
def kernel(X, Y, Z):
x = tl.load(X)
y = tl.load(Y)
-z = tl._experimental_join(x, y)
+z = tl.join(x, y)
tl.static_assert(z.shape == [2])
tl.store(Z + tl.arange(0, 2), z)

@@ -1550,7 +1550,7 @@ def test_join_with_mma(device):
@triton.jit
def kernel(X, Z):
x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16)
-x2 = tl._experimental_join(x, 2 * x) # (32,16,2)
+x2 = tl.join(x, 2 * x) # (32,16,2)
x3 = tl.reshape(x2, (32, 32))
z = tl.dot(x3, x3) # (32,32)
tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z)
@@ -1564,6 +1564,40 @@ def kernel(X, Z):
torch.testing.assert_close(z, z_ref)


+def test_interleave(device):
+if is_hip():
+pytest.skip("test_interleaves not supported on HIP")
+
+@triton.jit
+def kernel(Z, N: tl.constexpr):
+z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N))
+tl.store(Z + tl.arange(0, 2 * N), z)
+
+x = torch.arange(0, 128, device=device).to(torch.int32)
+y = torch.arange(128, 256, device=device).to(torch.int32)
+z_ref = torch.stack([x, y], dim=-1).reshape(256)
+z = torch.zeros_like(z_ref)
+kernel[(1, )](z, N=128)
+
+np.testing.assert_equal(to_numpy(z_ref), to_numpy(z))
+
+
+def test_interleave_scalars(device):
+if is_hip():
+pytest.skip("test_interleaves not supported on HIP")
+
+@triton.jit
+def kernel(X, Y, Z):
+z = tl.interleave(X, Y)
+tl.static_assert(z.shape == [tl.constexpr(2)])
+tl.store(Z + tl.arange(0, 2), z)
+
+z = torch.zeros(2, device=device)
+kernel[(1, )](10, 20, z)
+
+np.testing.assert_equal([10, 20], to_numpy(z))
+
+
def test_split(device):
if is_hip():
pytest.skip("test_split not supported on HIP")
@@ -1573,7 +1607,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr):
offs = tl.arange(0, N)
x = tl.load(X + offs)
x1 = tl.reshape(x, (N // 2, 2))
-z1, z2 = tl._experimental_split(x1)
+z1, z2 = tl.split(x1)
tl.store(Z1 + tl.arange(0, N // 2), z1)
tl.store(Z2 + tl.arange(0, N // 2), z2)

@@ -1595,7 +1629,7 @@ def test_split_to_scalar(device):
def kernel(X, Z1, Z2):
offs = tl.arange(0, 2)
x = tl.load(X + offs)
-z1, z2 = tl._experimental_split(x)
+z1, z2 = tl.split(x)
tl.static_assert(isinstance(z1, tl.tensor))
tl.static_assert(isinstance(z2, tl.tensor))
tl.static_assert(z1.shape == [])