Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LL::quotient and remove uses of divideRight and sublayoutIsIdentity #4968

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

namespace mlir {

Expand Down Expand Up @@ -189,6 +190,14 @@ bool supportMMA(triton::DotOp op, int version);

bool supportMMA(Value value, int version);

// Conversion from `srcTy` to `dstTy` involving the minimum amount of data
// transfer provided that both types can be converted to LL (if it can't it'll
// return nullopt). The output will be such that layout.getInDimNames() ==
// layout.getOutDimNames() and the conversion will not include kBlock (resp.
// kWarp or kLane) if it can be avoided
std::optional<mlir::triton::LinearLayout>
minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy);

// Conversion from `srcTy` to `dstTy` only involves reordering of registers.
// There is no need for data exchange across threads, warps, or blocks.
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth = std::nullopt);

// Given a linear layout with input dims and output dims containing a "block"
// dimension, determines if the layout moves data across block boundaries.
bool isCrossCTAConversion(const LinearLayout &layout);

// Given a linear layout where the input dimensions contain a "block" dimension,
// this method sets the "block" dimension to 0 and removes the corresponding
// output dimensions.
Expand Down
45 changes: 18 additions & 27 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,29 +575,20 @@ class LinearLayout {
return *this;
}

// divideLeft and divideRight are the inverses of operator*.
//
// Consider `a = c.divideRight(b)`, where `a` is a linear layout with
// `in-dims(a) == in-dims(b)` and `out-dims(a) == out-dims(c)`. We may remove
// some empty dimensions from `a` to form `a'` and still have `a' * b == c`.
// Therefore, there are multiple possible values that we could return for
// `(a * b).divideRight(b)` which would satisfy
// `((a * b).divideRight(b)) * b == a * b`.
//
// In the following example, we have `a * b == a' * b` when "in1" is an empty
// dimension that maps everything to 0:
//
// a = L("in1", "in2") -> ("out1", "out2")
// a' = L("in1") -> ("out1")
// b = L("in2") -> ("out2")
//
// divideLeft and divideRight resolve this ambiguity by always returning the
// "canonical" quotient, namely the one with the fewest possible size-zero
// input and output dimensions.
//
// TODO(jlebar): Implement divideLeft.
// std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
std::optional<LinearLayout> divideRight(const LinearLayout &divisor) const;
// Returns true if this layout acts trivially (as the identity) on the given
// dimensions. This means that it's the identity on those dimensions, and it
// does not map other dimensions onto those or these onto other dimensions.
bool isTrivialOver(ArrayRef<StringAttr> dimNames) const;

// For an endomorphism on dimNames (linear map that maps dimNames to dimNames)
// checks whether it is the identity map on these dimensions (i.e
// LinearLayouts::isTrivialOver) and if so, returns the sublayout of the
// remaining dimensions.
// nb. The isTrivialOver condition is more restrictive than the usual
// "leaves the subspace invariant" condition in maths.
// We can always relax it if we know how to take advantage of a conversion
// layout being block-diagonal in the future.
std::optional<LinearLayout> quotient(ArrayRef<StringAttr> dimNames) const;
lezcano marked this conversation as resolved.
Show resolved Hide resolved

// Gets a layout with only these in/out dimensions.
//
Expand All @@ -614,10 +605,10 @@ class LinearLayout {
bool sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
ArrayRef<StringAttr> outDimNames) const;

// Is the sublayout restricted to inDimNames + outDimNames and then flattened
// to 1D the identity layout (ignoring out-dim sizes)?
bool sublayoutIsIdentity(ArrayRef<StringAttr> inDimNames,
ArrayRef<StringAttr> outDimNames) const;
// Is the sublayout defined from dimNames to dimNames the identity?
// In particular, is the input and output size in these dimensions
// the same, and are the bases the identity?
bool squareSublayoutIsIdentity(ArrayRef<StringAttr> dimNames) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, are we ignoring the case where the input dimension names are not equal to the output dimension names, but the input and output are identity mappings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The identity in maths is just defined between a space and itself, so yes, this is the correct mathematical concept. Even if you have registers that map 1-to-1 to a vector of outputs, this would not be an identity map strictly speaking, as it's not mapping elements to themselves, but identifying registers with a matrix.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, makes sense then


// Computes and returns L(x, y, z).
//
Expand Down
77 changes: 38 additions & 39 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,57 +640,56 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
return ans;
}

bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
// have a transformation that's the identity on kBlock, we don't need to use
// distributed shared memory. If it's also the identity on kWarp, we can
// transfer via warp-shuffles, and if it's the identity on kLane just have to
// reorder the registers
std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
MLIRContext *ctx = srcTy.getContext();
std::optional<LinearLayout> srcLayout =
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
std::optional<LinearLayout> dstLayout =
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (srcLayout.has_value() && dstLayout.has_value()) {
// comp describes the layout function for converting from src to dst.
LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
// TODO(jlebar): These checks are overly-restrictive. For example, we can
// transfer by shuffling registers (case 1) if and only if all of the bases
// for `register` have 0s for lane, warp, and block. But the check below is
// stronger than this, checking also that the choice of lane/warp/block does
// not affect the permutation of registers. If we allow different
// lane/warp/blocks to have different permutations, we can generalize this.
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
kLane, kLane) *
LinearLayout::identity1D(comp.getInDimSize(kWarp),
kWarp, kWarp) *
LinearLayout::identity1D(comp.getInDimSize(kBlock),
kBlock, kBlock))
.has_value()) {
return true;
if (!(srcLayout.has_value() && dstLayout.has_value()))
return std::nullopt;
// comp describes the layout function to create dst from src.
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
auto quotient = comp.quotient(StringAttr::get(ctx, dim));
if (!quotient.has_value()) {
break;
}
comp = *quotient;
}
return false;
return comp;
}

bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
auto layout = minimalCvtLayout(srcTy, dstTy);
MLIRContext *ctx = srcTy.getContext();
if (!layout.has_value()) {
return false;
}
auto kRegister = StringAttr::get(ctx, "register");
auto outDims = llvm::to_vector(layout->getOutDimNames());
return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister});
}

bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
auto layout = minimalCvtLayout(srcTy, dstTy);
MLIRContext *ctx = srcTy.getContext();
std::optional<LinearLayout> srcLayout =
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
std::optional<LinearLayout> dstLayout =
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (srcLayout.has_value() && dstLayout.has_value()) {
// comp describes the layout function for converting from src to dst.
LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp),
kWarp, kWarp) *
LinearLayout::identity1D(comp.getInDimSize(kBlock),
kBlock, kBlock))
.has_value()) {
return true;
}
if (!layout.has_value()) {
return false;
}
return false;
auto kRegister = StringAttr::get(ctx, "register");
auto kLane = StringAttr::get(ctx, "lane");
return llvm::to_vector(layout->getOutDimNames()) ==
llvm::SmallVector<StringAttr, 2>{kRegister, kLane};
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
Expand Down
149 changes: 60 additions & 89 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,111 +282,79 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
const auto &shape = op.getType().getShape();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
std::optional<LinearLayout> srcLayout =
toLinearLayout(shape, srcTy.getEncoding());
std::optional<LinearLayout> dstLayout =
toLinearLayout(shape, dstTy.getEncoding());
if (!srcLayout.has_value() || !dstLayout.has_value()) {
return failure();
}

// There are four cases to handle.
//
// 1. Transfer between values in the same thread, in which case we simply
// reorder the elements of adaptor.getSrc().
// 2. Transfer between values in the same warp, in which case we try to
// move values using warp shuffles, though if the pattern is complicated
// enough we may fall back to using shared memory (case 3).
// 3. Transfer between values in the same CTA, in which case we move values
// through shared memory.
// 4. Transfer between values in different CTAs, in which case we move
// values through distributed shared memory.
//
// We can tell which case we're in by examining `conversion`.
// For example, if the block -> block mapping is an identity layout: {1, 2,
// 4, ...}, then there's no movement between data in different CTAs, and we
// know we're not in case 4.
if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1.
return transferWithinThread(op, *srcLayout, *dstLayout, adaptor,
rewriter);
auto conversion = minimalCvtLayout(srcTy, dstTy);
if (!conversion.has_value()) {
return rewriter.notifyMatchFailure(
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
}

if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2.
return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter);
assert(to_vector(conversion->getInDimNames()) ==
to_vector(conversion->getOutDimNames()));
auto dims = conversion->getInDimNames();
if (llvm::is_contained(dims, str_attr("block"))) {
// Case 1: Transfer between values in different CTAs.
// This requires moving values through distributed shared memory.
return rewriter.notifyMatchFailure(
op, "NYI: Transfer between different CTAs");
} else if (llvm::is_contained(dims, str_attr("warp"))) {
// Case 2: Transfer between values in the same CTA, in which case we move
// values through shared memory.
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
} else if (llvm::is_contained(dims, str_attr("lane"))) {
// Case 3. Transfer between values in the same warp, in which case we try
// to move values using warp shuffles, though if the pattern is
// complicated enough we may fall back to using shared memory
// TODO(Keren): implement warp shuffle instead of using the general
// approach that uses shared memory
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
} else if (llvm::is_contained(dims, str_attr("register"))) {
// Case 4. Transfer between values in the same thread, in which case we
// simply reorder the elements of adaptor.getSrc().
return transferWithinThread(op, *conversion, adaptor, rewriter);
} else {
// The two layouts are equivalent. We should probably remove these in
// RemoveLayoutConversion.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}

return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor,
rewriter); // Case 3 and 4
}

LogicalResult
transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

// There are three possible cases:
//
// 1. `srcLayout` has the same number of registers as `dstLayout`.
// 2. `srcLayout` has fewer registers than `dstLayout`.
// 3. `srcLayout` has more registers than `dstLayout`.
//
// In the second case `srcLayout . dstLayout^-1` is not surjective
// because not all destination registers are covered.
// Since the goal is to cover all of the destination
// registers, we can instead use `dstLayout . srcLayout^-1`.
LinearLayout conversion = dstLayout.invertAndCompose(srcLayout);
auto dstToSrc = conversion.divideRight(
LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
kBlock));

assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
ArrayRef{kRegister});
assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) ==
ArrayRef{kRegister});

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals;
outVals.resize(dstToSrc->getInDimSize(kRegister));
for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) {
auto srcIdx = dstToSrc->apply({{kRegister, i}});
outVals[i] = inVals[srcIdx.begin()->second];
outVals.resize(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
outVals[i] = inVals[srcIdx];
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
return success();
}

LogicalResult transferWithinLane(ConvertLayoutOp op,
const LinearLayout &srcLayout,
const LinearLayout &dstLayout,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// TODO(Keren): implement warp shuffle instead of using the general approach
// that uses shared memory
return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor,
rewriter);
}

LogicalResult
transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
LinearLayout conversion = srcLayout.invertAndCompose(dstLayout);

// TODO(Keren): LLs support cross-CTA conversions, this function does not
if (isCrossCTAConversion(conversion))
return failure();

LogicalResult transferWithinBlock(ConvertLayoutOp op,
const LinearLayout &srcLayout,
const LinearLayout &dstLayout,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType();
Expand Down Expand Up @@ -461,11 +429,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
}

// Pretty sure this is the identity function ATM
// It'd be better to simply call `quotient({kBlock})` and
// remove kBlock from transferWithinBlockImpl
auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
SmallVector<Value> outVals =
transferWithinBlock(inVals, op, srcLayoutWithinBlock,
dstLayoutWithinBlock, adaptor, rewriter);
transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock,
dstLayoutWithinBlock, adaptor, rewriter);

// Unmunge output values
for (const auto &it : llvm::enumerate(outVals)) {
Expand Down Expand Up @@ -499,10 +470,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}

SmallVector<Value>
transferWithinBlock(ArrayRef<Value> inVals, ConvertLayoutOp op,
const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
const LinearLayout &srcLayout,
const LinearLayout &dstLayout, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();

Expand Down
Loading
Loading