[BACKEND] Support Hopper MMA to MMA convert_layout ops (#4492)
This PR enables MMA to MMA `convert_layout` conversions on the Hopper architecture.
We also replace the previous `isMmaToMmaShortcut` check with
`cvtReordersRegisters` in several places.
Note that MMA to MMA conversions that use shared memory still go through
the legacy `ConvertLayoutOpConversion` function; we will deprecate it
in a follow-up PR.
Jokeren authored Aug 12, 2024
1 parent 6a9a0a6 commit 7d89248
Showing 11 changed files with 214 additions and 250 deletions.
16 changes: 11 additions & 5 deletions include/triton/Analysis/Utility.h
@@ -189,18 +189,24 @@ bool supportMMA(triton::DotOp op, int version);

bool supportMMA(Value value, int version);

// Conversion from `srcTy` to `dstTy` only involves reordering of registers.
// There is no need for data exchange across threads, warps, or blocks.
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);

// Conversion from `srcTy` to `dstTy` involves data exchange across threads
// within a warp. No data exchange across warps or blocks is needed.
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy);

// Conversion from `srcTy` to `dstTy` involves data exchange across threads,
// warps, and possibly blocks.
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// TODO(jlebar): Remove this function; it's subsumed by the linear-layout case
// in cvtNeedsSharedMemory.
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
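The three `cvt*` predicates declared above classify a layout conversion by the cheapest mechanism that can implement it. Below is a minimal sketch of how a caller might use them, assuming the usual Triton include paths; the `classifyCvt` helper is hypothetical and only illustrates the intended dispatch (the real one lives in `ConvertLayoutOpToLLVM.cpp`, shown further down).

#include "triton/Analysis/Utility.h"

using namespace mlir;
using namespace mlir::triton;

// Hypothetical helper, not part of this PR: pick the cheapest way to lower a
// convert_layout from srcTy to dstTy.
enum class CvtKind { ReorderRegisters, WarpShuffle, SharedMemory };

static CvtKind classifyCvt(RankedTensorType srcTy, RankedTensorType dstTy) {
  if (cvtReordersRegisters(srcTy, dstTy))
    return CvtKind::ReorderRegisters; // case 1: per-thread register permutation
  if (cvtNeedsWarpShuffle(srcTy, dstTy))
    return CvtKind::WarpShuffle;      // case 2: data exchange within a warp
  return CvtKind::SharedMemory;       // cases 3/4: cross-warp or cross-CTA
}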
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Patterns.h
@@ -20,7 +20,7 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module);
/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
/// given |module| op, but bypass the decomposition if |shortcutFn| returns
/// true.
using ShortcutFn = std::function<bool(RankedTensorType &, RankedTensorType &)>;
using ShortcutFn = std::function<bool(RankedTensorType, RankedTensorType)>;
template <typename TensorCoreEncodingAttr>
void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
ShortcutFn shortcutFn);
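Since `ShortcutFn` now takes its `RankedTensorType` arguments by value, the shortcut predicates from `Utility.h` can be passed to the decomposition directly. The following is a hedged sketch of an assumed call site; the pass name and wiring are illustrative and not taken from this PR.

#include "mlir/IR/BuiltinOps.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Patterns.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Illustrative only: decompose unsupported mma -> dot_op conversions in
// `module`, bypassing the ones the register-level shortcut can already handle.
static void decomposeUnsupportedMmaToDotSketch(mlir::ModuleOp module) {
  using namespace mlir;
  using namespace mlir::triton;
  using namespace mlir::triton::gpu;
  decomposeTensorCoreToDotLayoutConversion<NvidiaMmaEncodingAttr>(
      module, isMmaToDotShortcut);
}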
3 changes: 2 additions & 1 deletion lib/Analysis/Allocation.cpp
@@ -44,7 +44,8 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);

assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) &&
assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
!srcMmaLayout.isHopper()) &&
"mma -> mma layout conversion is only supported on Ampere or Hopper");

// mma or dot layout does not have an order, so the order depends on the
77 changes: 46 additions & 31 deletions lib/Analysis/Utility.cpp
@@ -527,7 +527,7 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
@@ -543,21 +543,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}

static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
auto src = dyn_cast<NvidiaMmaEncodingAttr>(srcEncoding);
auto dst = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding);
if (!src || !dst)
return false;
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src && dst && src.getVersionMajor() == 3 &&
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
dst.getWarpsPerCTA()[1] == 1;
}

bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}

// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
@@ -567,14 +552,16 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
return false;
}
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
auto ans =
mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
isMmaToMmaShortcut(dotOperandLayout.getParent(), srcTy.getEncoding()) &&
(elementTypeSize == 16 || elementTypeSize == 8);
auto parentTy = RankedTensorType::get(
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
MLIRContext *ctx = srcTy.getContext();
std::optional<LinearLayout> srcLayout =
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
@@ -586,26 +573,54 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
// In principle, there's no need for shared memory if there's no
// communication between warps. However, right now we only have implemented
// the shortcut case where there's no communication between *threads*.
//
// TODO(jlebar): Remove the kLane layout once we add support for
// shuffle-based layout conversions in ConvertLayoutToLLVM.
// TODO(jlebar): These checks are overly-restrictive. For example, we can
// transfer by shuffling registers (case 1) if and only if all of the bases
// for `register` have 0s for lane, warp, and block. But the check below is
// stronger than this, checking also that the choice of lane/warp/block does
// not affect the permutation of registers. If we allow different
// lane/warp/blocks to have different permutations, we can generalize this.
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
kLane, kLane) *
LinearLayout::identity1D(comp.getInDimSize(kWarp),
kWarp, kWarp) *
LinearLayout::identity1D(comp.getInDimSize(kBlock),
kBlock, kBlock))
.has_value()) {
return false;
return true;
}
}
return false;
}

// TODO(jlebar): Remove these special cases once they're fully subsumed by the
// linear-layout check above.
return !isMmaToMmaShortcut(srcTy, dstTy) &&
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
MLIRContext *ctx = srcTy.getContext();
std::optional<LinearLayout> srcLayout =
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
std::optional<LinearLayout> dstLayout =
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (srcLayout.has_value() && dstLayout.has_value()) {
// comp describes the layout function for converting from src to dst.
LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp),
kWarp, kWarp) *
LinearLayout::identity1D(comp.getInDimSize(kBlock),
kBlock, kBlock))
.has_value()) {
return true;
}
}
return false;
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isMmaToDotShortcut(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}
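The TODO inside `cvtReordersRegisters` above describes the underlying criterion: a conversion can stay within a thread only if every basis vector of its `register` input dimension maps purely to registers, i.e. has zero lane, warp, and block components (the implemented `divideRight` check is stricter still, requiring the whole conversion to factor as a register permutation times an identity on lane/warp/block). A toy, plain-C++ model of that basis condition with made-up numbers, not the real LinearLayout API:

#include <array>
#include <cassert>

// Output components of one input basis vector of the conversion layout.
struct Basis {
  int reg, lane, warp, block;
};

// Toy version of the "register bases have 0s for lane/warp/block" condition.
static bool reordersRegistersOnly(const std::array<Basis, 3> &registerBases) {
  for (const Basis &b : registerBases)
    if (b.lane != 0 || b.warp != 0 || b.block != 0)
      return false; // this register basis crosses a thread boundary
  return true;
}

int main() {
  // Permutes registers within each thread (e.g. swaps register bits 0 and 1).
  std::array<Basis, 3> swapPairs = {{{2, 0, 0, 0}, {1, 0, 0, 0}, {4, 0, 0, 0}}};
  // Sends a register bit into the lane dimension: needs at least a warp shuffle.
  std::array<Basis, 3> crossLane = {{{0, 1, 0, 0}, {2, 0, 0, 0}, {4, 0, 0, 0}}};
  assert(reordersRegistersOnly(swapPairs));
  assert(!reordersRegistersOnly(crossLane));
  return 0;
}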
105 changes: 55 additions & 50 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -250,10 +250,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
MLIRContext *ctx = op.getContext();

const auto &shape = op.getType().getShape();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
std::optional<LinearLayout> srcLayout =
toLinearLayout(shape, op.getSrc().getType().getEncoding());
toLinearLayout(shape, srcTy.getEncoding());
std::optional<LinearLayout> dstLayout =
toLinearLayout(shape, op.getType().getEncoding());
toLinearLayout(shape, dstTy.getEncoding());
if (!srcLayout.has_value() || !dstLayout.has_value()) {
return failure();
}
@@ -270,93 +272,94 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// 4. Transfer between values in different CTAs, in which case we move
// values through distributed shared memory.
//
// We can tell which case we're in by examining `conversion`. If e.g. the
// block -> block mapping is {1, 2, 4, ...} then there's no movement between
// data in different CTAs and we know we're not in case 4.
LinearLayout conversion = srcLayout->invertAndCompose(*dstLayout);

int numLanes = conversion.getInDimSize(str_attr("lane"));
int numWarps = conversion.getInDimSize(str_attr("warp"));
int numBlocks = conversion.getInDimSize(str_attr("block"));

StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

// TODO(jlebar): These checks are overly-restrictive. For example, we can
// transfer by shuffling registers (case 1) if and only if all of the bases
// for `register` have 0s for lane, warp, and block. But the check below is
// stronger than this, checking also that the choice of lane/warp/block does
// not affect the permutation of registers. If we allow different
// lane/warp/blocks to have different permutations, we can generalize this.
if (std::optional<LinearLayout> c = conversion.divideRight(
LinearLayout::identity1D(numLanes, kLane, kLane) *
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
c.has_value()) {
return transferWithinThread(*c, op, adaptor, rewriter);
// We can tell which case we're in by examining `conversion`.
// For example, if the block -> block mapping is an identity layout: {1, 2,
// 4, ...}, then there's no movement between data in different CTAs, and we
// know we're not in case 4.
if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1.
return transferWithinThread(op, *srcLayout, *dstLayout, adaptor,
rewriter);
}

if (std::optional<LinearLayout> c = conversion.divideRight(
LinearLayout::identity1D(numWarps, kWarp, kWarp) *
LinearLayout::identity1D(numBlocks, kBlock, kBlock));
c.has_value()) {
return transferWithinLane(*c, op, adaptor, rewriter);
if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2.
return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter);
}

return transferWithinBlockOrGroup(conversion, op, *srcLayout, *dstLayout,
adaptor, rewriter);
return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor,
rewriter); // Case 3 and 4
}

LogicalResult
transferWithinThread(const LinearLayout &conversion, ConvertLayoutOp op,
OpAdaptor adaptor,
transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

// There are three possible cases:
//
// 1. `srcLayout` has the same number of registers as `dstLayout`.
// 2. `srcLayout` has fewer registers than `dstLayout`.
// 3. `srcLayout` has more registers than `dstLayout`.
//
// In the second case `srcLayout . dstLayout^-1` is not surjective
// because not all destination registers are covered.
// Since the goal is to cover all of the destination
// registers, we can instead use `dstLayout . srcLayout^-1`.
LinearLayout conversion = dstLayout.invertAndCompose(srcLayout);
auto dstToSrc = conversion.divideRight(
LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
kBlock));

assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
assert(ArrayRef(to_vector(conversion.getInDimNames())) ==
assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
ArrayRef{kRegister});
assert(ArrayRef(to_vector(conversion.getOutDimNames())) ==
assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) ==
ArrayRef{kRegister});

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals(conversion.getOutDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto dstIdx = conversion.apply({{kRegister, i}});
outVals[dstIdx.begin()->second] = inVals[i];
SmallVector<Value> outVals;
outVals.resize(dstToSrc->getInDimSize(kRegister));
for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) {
auto srcIdx = dstToSrc->apply({{kRegister, i}});
outVals[i] = inVals[srcIdx.begin()->second];
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
return success();
}

LogicalResult transferWithinLane(const LinearLayout &conversion,
ConvertLayoutOp op, OpAdaptor adaptor,
LogicalResult transferWithinLane(ConvertLayoutOp op,
const LinearLayout &srcLayout,
const LinearLayout &dstLayout,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// TODO(jlebar): Implement me.
return failure();
}

LogicalResult
transferWithinBlockOrGroup(const LinearLayout &conversion, ConvertLayoutOp op,
const LinearLayout &srcLayout,
transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
LinearLayout conversion = srcLayout.invertAndCompose(dstLayout);

// TODO(Keren): LLs support cross-CTA conversions, this function does not
if (isCrossCTAConversion(conversion))
return failure();

MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();

assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

// TODO(jlebar): For now we handle only blocked/slice -> blocked/slice
// conversions. Once we have ldmatrix support in
// TODO(jlebar): For now we handle only blocked/slice ->
// blocked/slice conversions. Once we have ldmatrix support in
// load/storeDistributedToShared, we can remove this constraint.
std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
if (isa<BlockedEncodingAttr>(layout)) {
@@ -372,6 +375,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return failure();
}

assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

SmallVector<Value> inVals =
unpackLLElements(loc, adaptor.getSrc(), rewriter);
assert(!inVals.empty());
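The comment in `transferWithinThread` above notes that when the destination layout has more registers than the source (a broadcast), `srcLayout . dstLayout^-1` would miss some destination registers, so the code composes `dstLayout.invertAndCompose(srcLayout)` instead. A toy, plain-C++ model of that copy loop with made-up register counts:

#include <array>
#include <cassert>

int main() {
  // Two values per thread in the source; the destination stores each value in
  // two registers, so it has four registers per thread.
  std::array<int, 2> inVals = {10, 11};
  // dstToSrc is the register part of dstLayout o srcLayout^-1: destination
  // register i copies from source register dstToSrc[i].
  std::array<int, 4> dstToSrc = {0, 1, 0, 1};
  std::array<int, 4> outVals{};
  for (int i = 0; i < 4; ++i)
    outVals[i] = inVals[dstToSrc[i]]; // every destination register is written
  assert(outVals[0] == 10 && outVals[1] == 11);
  assert(outVals[2] == 10 && outVals[3] == 11);
  // Going src -> dst instead would produce only two destination indices and
  // leave outVals[2] and outVals[3] unset.
  return 0;
}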
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -336,8 +336,8 @@ LogicalResult MakeRangeOp::verify() {

//-- ReduceOp --
static LogicalResult
inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy,
int axis, SmallVectorImpl<Type> &inferredReturnTypes) {
inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto retShape = argTy.getShape().vec();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {