Skip to content

Commit

Permalink
make sublayoutIsIdentity correct, thinking of layouts as linear funct…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
lezcano committed Oct 22, 2024
1 parent 5573541 commit b67f53c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 56 deletions.
8 changes: 4 additions & 4 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,10 +604,10 @@ class LinearLayout {
bool sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
ArrayRef<StringAttr> outDimNames) const;

// Is the sublayout restricted to inDimNames + outDimNames and then flattened
// to 1D the identity layout (ignoring out-dim sizes)?
bool sublayoutIsIdentity(ArrayRef<StringAttr> inDimNames,
ArrayRef<StringAttr> outDimNames) const;
// Is the sublayout defined from dimNames to dimNames the identity?
// In particular, is the input and output size in these dimensions
// the same, and are the bases the identity?
bool squareSublayoutIsIdentity(ArrayRef<StringAttr> dimNames) const;

// Computes and returns L(x, y, z).
//
Expand Down
9 changes: 5 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// simply reorder the elements of adaptor.getSrc().
return transferWithinThread(op, *conversion, adaptor, rewriter);
} else {
// Nothing to do. We should remove these ops in removeLayoutConversion.
rewriter.replaceOp(op, op.getSrc());
// The two layouts are equivalent. We should probably remove these in
// RemoveLayoutConversion.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
}
Expand All @@ -338,8 +339,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals;
outVals.resize(conversion.getOutDimSize(kRegister));
for (int i = 0; i < conversion.getOutDimSize(kRegister); i++) {
outVals.resize(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
outVals[i] = inVals[srcIdx];
}
Expand Down
21 changes: 15 additions & 6 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ bool LinearLayout::canQuotient(ArrayRef<StringAttr> dimNames) const {
// We can quotient out dimNames iff they don't affect the remainingInDimNames
// in the result. In other words, we want to check that B is zero, and C is
// zero, and D is the identity
return sublayoutIsIdentity(dimNames, dimNames) &&
return squareSublayoutIsIdentity(dimNames) &&
sublayoutIsZero(remainingInDimNames, dimNames) &&
sublayoutIsZero(dimNames, remainingOutDimNames);
}
Expand Down Expand Up @@ -760,13 +760,22 @@ bool LinearLayout::sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
return true;
}

bool LinearLayout::sublayoutIsIdentity(ArrayRef<StringAttr> inDimNames,
ArrayRef<StringAttr> outDimNames) const {
LinearLayout sl =
sublayout(inDimNames, outDimNames).flattenIns().flattenOuts();
if (sl.getNumInDims() == 0 || sl.getNumOutDims() == 0) {
bool LinearLayout::squareSublayoutIsIdentity(
ArrayRef<StringAttr> dimNames) const {
// The empty layout is the identity
if (dimNames.size() == 0) {
return true;
}
// Check that the input-output sizes are the same
LinearLayout sl = sublayout(dimNames, dimNames);
for (StringAttr dim : dimNames) {
if (getInDimSize(dim) != getOutDimSize(dim)) {
return false;
}
}
// Once the inputs and output dimensions are the same, we can just check
// that the basis for the single remaining dimension is the identity.
sl = sl.flattenIns().flattenOuts();
int b = 0;
const auto &inDimBases = sl.bases.begin()->second;
for (auto basis : inDimBases) {
Expand Down
65 changes: 23 additions & 42 deletions unittest/Tools/LinearLayoutTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,51 +613,32 @@ TEST_F(LinearLayoutTest, SublayoutIsZero) {
}

TEST_F(LinearLayoutTest, SublayoutIsIdentity) {
EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out"))
.sublayoutIsIdentity({S("in")}, {S("out")}));
EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out"))
.sublayoutIsIdentity({}, {S("out")}));
EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out"))
.sublayoutIsIdentity({S("in")}, {}));
EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out"))
.sublayoutIsIdentity({}, {}));
EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in"))
.squareSublayoutIsIdentity({S("in")}));
EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in"))
.squareSublayoutIsIdentity({}));

LinearLayout l1(
{{S("in1"), {{1, 1}, {2, 2}, {4, 4}}}, {S("in2"), {{2, 1}, {1, 2}}}},
{{S("out1"), 8}, {S("out2"), 8}}, /*requireSurjective=*/false);
EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")}));
EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out2")}));
EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1")}, {S("out1"), S("out2")}));
EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1")}, {S("out2"), S("out1")}));
EXPECT_TRUE(l1.sublayoutIsIdentity({S("in1")}, {S("out1")}));
EXPECT_TRUE(l1.sublayoutIsIdentity({S("in1")}, {S("out2")}));
EXPECT_FALSE(l1.sublayoutIsIdentity({S("in2")}, {S("out1")}));
EXPECT_TRUE(l1.sublayoutIsIdentity({S("in2")}, {S("out2")}));

LinearLayout l2 =
LinearLayout::identity1D(4, S("in1"), S("out1")) *
LinearLayout::identity1D(8, S("in2"), S("out2")) *
LinearLayout({{S("in3"), {{1, 1, 1}}}},
{{S("out1"), 2}, {S("out2"), 2}, {S("out3"), 2}},
/*requireSurjective=*/false);
EXPECT_TRUE(l2.sublayoutIsIdentity({S("in1")}, {S("out1")}));
EXPECT_TRUE(l2.sublayoutIsIdentity({S("in2")}, {S("out2")}));
EXPECT_TRUE(l2.sublayoutIsIdentity({S("in3")}, {S("out3")}));
EXPECT_FALSE(
l2.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1"), S("out2")}));
EXPECT_FALSE(l2.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")}));
EXPECT_TRUE(l2.sublayoutIsIdentity({S("in1"), S("in3")}, {S("out1")}));

LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("out1")) *
LinearLayout::identity1D(8, S("in2"), S("out2"));
EXPECT_TRUE(l3.sublayoutIsIdentity({S("in1")}, {S("out1")}));
EXPECT_TRUE(l3.sublayoutIsIdentity({S("in2")}, {S("out2")}));
EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1")}, {S("out2")}));
EXPECT_FALSE(l3.sublayoutIsIdentity({S("in2")}, {S("out1")}));
EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")}));
EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out2")}));
EXPECT_TRUE(
l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1"), S("out2")}));
{{S("in1"), 8}, {S("in2"), 8}}, /*requireSurjective=*/false);
EXPECT_TRUE(l1.squareSublayoutIsIdentity({S("in1")}));
EXPECT_FALSE(l1.squareSublayoutIsIdentity({S("in2")}));

LinearLayout l2 = LinearLayout::identity1D(4, S("in1"), S("in1")) *
LinearLayout::identity1D(8, S("in2"), S("in2")) *
LinearLayout({{S("in3"), {{1, 1, 1}}}},
{{S("in1"), 2}, {S("in2"), 2}, {S("in3"), 2}},
/*requireSurjective=*/false);
EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1")}));
EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in2")}));
EXPECT_TRUE(l2.squareSublayoutIsIdentity({S("in3")}));
EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1"), S("in2")}));

LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("in1")) *
LinearLayout::identity1D(8, S("in2"), S("in2"));
EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1")}));
EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in2")}));
EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1"), S("in2")}));
}

TEST_F(LinearLayoutTest, FreeVariableMasks) {
Expand Down

0 comments on commit b67f53c

Please sign in to comment.