[AMD] NFC: Refactor AccelerateAMDMatmul patterns
This commit refactors the AccelerateAMDMatmul patterns in preparation for mxfp support, so that the upcoming changes are easier to review.
antiagainst committed Oct 23, 2024
1 parent a20ce64 commit e785eef
Showing 1 changed file with 109 additions and 102 deletions: third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
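The recurring shape of the refactor is the move from an untyped RewritePattern, which registers against an op name and casts inside matchAndRewrite, to a typed OpRewritePattern<tt::DotOp> with an explicit PatternBenefit. A minimal compile-only sketch of the two forms, using MLIR's arith::AddIOp as a stand-in for tt::DotOp (the AddRewrite* names and placeholder bodies are illustrative, not from this commit):

// Sketch only: contrasts the two pattern styles touched by this commit.
// arith::AddIOp stands in for tt::DotOp; bodies are placeholders.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Before: untyped pattern keyed by operation name; a hard-coded benefit of 2
// and a manual cast on every use of the matched op.
struct AddRewriteOld : public RewritePattern {
  AddRewriteOld(MLIRContext *context)
      : RewritePattern(arith::AddIOp::getOperationName(), /*benefit=*/2,
                       context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto addOp = cast<arith::AddIOp>(op);
    return rewriter.notifyMatchFailure(addOp, "placeholder body");
  }
};

// After: typed pattern; the driver supplies the concrete op, and the benefit
// becomes a constructor parameter with a default.
struct AddRewriteNew : public OpRewritePattern<arith::AddIOp> {
  AddRewriteNew(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit) {}
  LogicalResult matchAndRewrite(arith::AddIOp addOp,
                                PatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(addOp, "placeholder body");
  }
};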
@@ -5,6 +5,8 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <memory>

@@ -36,16 +38,15 @@ int getWmmaVersion(StringRef archGen) {
   return 0;
 }
 
-SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
-                                      const ArrayRef<int64_t> shape,
-                                      int numWarps,
-                                      SmallVector<int64_t, 2> shapePerWarp) {
+SmallVector<unsigned, 3>
+warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
+             std::pair<int64_t, int64_t> shapePerWarp) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
     return {(unsigned)numWarps, 1, 1};
 
-  auto filter = [&dotOp](Operation *op) {
+  auto filter = [dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
   };
   ForwardSliceOptions fwdOpt;
@@ -55,17 +56,17 @@ SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
   bwdOpt.filter = filter;
   auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
   for (Operation *op : slices)
-    if (isa<tt::DotOp>(op) && (op != dotOp))
+    if (op->hasTrait<OpTrait::DotLike>() && (op != dotOp))
       return {(unsigned)numWarps, 1};
 
   SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
   SmallVector<unsigned, 2> ret = {1, 1};
   do {
     if (ret[0] * ret[1] >= numWarps)
       break;
-    if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >=
-        tensorShape[1] / shapePerWarp[1] / ret[1]) {
-      if (ret[0] < tensorShape[0] / shapePerWarp[0]) {
+    if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >=
+        tensorShape[1] / shapePerWarp.second / ret[1]) {
+      if (ret[0] < tensorShape[0] / shapePerWarp.first) {
         ret[0] *= 2;
       } else
         ret[1] *= 2;
@@ -74,24 +75,89 @@
     }
   } while (true);
 
-  if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
+  if (ret[1] * shapePerWarp.second > tensorShape[1]) {
     return {ret[1], ret[0]};
   }
 
   return ret;
 }
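To see what the loop above computes, here is a self-contained trace with assumed inputs (a 128x128 result tile, 8 warps, one 32x32 MFMA tile per warp); the outer else branch, collapsed in the diff view, is restored from the surrounding context:

// Standalone sketch: replays the warpsPerTile split loop with assumed inputs.
#include <cstdint>
#include <cstdio>
#include <utility>

int main() {
  int64_t tensorShape[2] = {128, 128};              // M x N of the dot result
  unsigned numWarps = 8;                            // warps per CTA
  std::pair<int64_t, int64_t> shapePerWarp{32, 32}; // one MFMA tile per warp
  unsigned ret[2] = {1, 1};
  do {
    if (ret[0] * ret[1] >= numWarps)
      break;
    // Prefer splitting the dimension that still has more tiles to cover.
    if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >=
        tensorShape[1] / shapePerWarp.second / ret[1]) {
      if (ret[0] < tensorShape[0] / shapePerWarp.first)
        ret[0] *= 2;
      else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  } while (true);
  // Trace: {1,1} -> {1,2} -> {2,2} -> {2,4}; stops once 2*4 == 8 warps.
  std::printf("warpsPerTile = {%u, %u}\n", ret[0], ret[1]);
  return 0;
}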

-SmallVector<unsigned, 2>
-warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
-                 SmallVector<int64_t, 2> shapePerWarp) {
+SmallVector<unsigned, 3>
+warpsPerTileMFMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
+                 std::pair<int64_t, int64_t> shapePerWarp) {
   return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
 }
 
-SmallVector<unsigned, 2>
-warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
-  return warpsPerTile(dotOp, shape, numWarps,
-                      {ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[0],
-                       ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr()[1]});
+SmallVector<unsigned, 3>
+warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
+  auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr();
+  return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]});
 }

+// Chooses a proper MFMA instruction that can be used to compute the given
+// dot op. If enforcedNonKDim is not zero, it will be used to override the
+// default logic to choose an MFMA with matching M/N dims.
+FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
+                                          Type aElemType, Type bElemType,
+                                          int inputKSize, int mfmaVersion,
+                                          int enforcedNonKDim) {
+  // Number of matrix elements along the k dim per one MFMA instruction.
+  unsigned kDim = 0;
+
+  auto resShape = cType.getShape();
+  auto rank = resShape.size();
+  auto M = resShape[rank - 2];
+  auto N = resShape[rank - 1];
+
+  unsigned mDim = 0;
+  unsigned nDim = 0;
+  if (enforcedNonKDim != 0) {
+    mDim = nDim = enforcedNonKDim;
+  } else {
+    int minSize = std::min(M, N);
+    if (minSize >= 32) {
+      mDim = 32;
+      nDim = 32;
+    }
+    if (minSize >= 16 && minSize < 32) {
+      mDim = 16;
+      nDim = 16;
+    }
+    if (minSize < 16) {
+      if (M < 16 && N >= 64) {
+        mDim = 4;
+        nDim = 64;
+      } else if (M >= 64 && N < 16) {
+        mDim = 64;
+        nDim = 4;
+      } else {
+        assert(inputKSize >= 64 &&
+               "k should be at least 64 to use this layout");
+        mDim = 4;
+        nDim = 4;
+      }
+    }
+  }
+  assert(mDim != 0 && nDim != 0);
+
+  auto maybeMfmaInsn =
+      MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion);
+  if (failed(maybeMfmaInsn))
+    llvm::report_fatal_error("No match found in MFMA database\n");
+
+  kDim = maybeMfmaInsn->getKDim();
+  assert(kDim != 0);
+  assert(M % mDim == 0 && N % nDim == 0);
+  assert(inputKSize % kDim == 0);
+  return maybeMfmaInsn;
+}

+FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
+                                          int nonKDim) {
+  RankedTensorType aType = dot.getA().getType();
+  return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(),
+                               dot.getB().getType().getElementType(),
+                               aType.getShape().back(), mfmaVersion, nonKDim);
+}
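The new free function makes the nonK-dimension choice easy to read: 32x32 for large tiles, 16x16 for medium ones, and skewed 4x64 / 64x4 / 4x4 shapes when M or N is small. A standalone sketch of just that decision with a few assumed shapes (the MfmaInsn::selectMfma lookup and the kDim divisibility asserts are elided):

// Standalone sketch of the MFMA nonK-dim selection logic above.
#include <algorithm>
#include <cstdio>

static void pick(int M, int N) {
  unsigned mDim = 0, nDim = 0;
  int minSize = std::min(M, N);
  if (minSize >= 32)
    mDim = nDim = 32;               // big tiles: 32x32 MFMA
  if (minSize >= 16 && minSize < 32)
    mDim = nDim = 16;               // medium tiles: 16x16 MFMA
  if (minSize < 16) {
    if (M < 16 && N >= 64) {        // short-and-wide: 4x64
      mDim = 4; nDim = 64;
    } else if (M >= 64 && N < 16) { // tall-and-narrow: 64x4
      mDim = 64; nDim = 4;
    } else {                        // both small: 4x4, needs K >= 64
      mDim = 4; nDim = 4;
    }
  }
  std::printf("M=%3d N=%3d -> %ux%u\n", M, N, mDim, nDim);
}

int main() {
  pick(128, 128); // -> 32x32
  pick(16, 64);   // -> 16x16
  pick(8, 128);   // -> 4x64
  pick(8, 8);     // -> 4x4 (the real code asserts K >= 64 here)
  return 0;
}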

using OperandTypesVector = SmallVector<Type, 4>;
@@ -259,15 +325,16 @@ Value convertAndCastTensor(PatternRewriter &rewriter, Value value,
   return castedTensor;
 }
 
-class BlockedToMFMA : public RewritePattern {
+class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
   int mfmaVersion;
-  int enforcedNonKDim;
+  int nonKDim;
   int kPack;
 
 public:
-  BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack)
-      : RewritePattern(tt::DotOp::getOperationName(), 2, context),
-        mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}
+  BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
+        nonKDim(nonKDim), kPack(kPack) {}
 
   bool isSecondDot(tt::DotOp &dotOp) const {
     auto filter = [&dotOp](Operation *op) {
@@ -285,75 +352,15 @@ class BlockedToMFMA : public RewritePattern {
     return false;
   }
 
-  /// @brief Choose MFMA instruction parameters
-  /// @param dot target dot operation
-  /// @return MfmaInsn or failure
-  FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot) const {
-    // number of matrix elements along k dim per one MFMA intruction
-    unsigned kDim = 0;
-    auto opType = cast<RankedTensorType>(dot.getA().getType());
-    auto dataTypeA = opType.getElementType();
-    auto dataTypeB =
-        cast<RankedTensorType>(dot.getB().getType()).getElementType();
-
-    auto resType = cast<RankedTensorType>(dot.getD().getType());
-    auto resShape = resType.getShape();
-    auto rank = resShape.size();
-    auto M = resShape[rank - 2];
-    auto N = resShape[rank - 1];
-
-    unsigned mDim = 0;
-    unsigned nDim = 0;
-    if (enforcedNonKDim != 0) {
-      mDim = enforcedNonKDim;
-      nDim = enforcedNonKDim;
-    } else {
-      int minSize = std::min(M, N);
-      if (minSize >= 32) {
-        mDim = 32;
-        nDim = 32;
-      }
-      if (minSize >= 16 && minSize < 32) {
-        mDim = 16;
-        nDim = 16;
-      }
-      if (minSize < 16) {
-        if (M < 16 && N >= 64) {
-          mDim = 4;
-          nDim = 64;
-        } else if (M >= 64 && N < 16) {
-          mDim = 64;
-          nDim = 4;
-        } else {
-          assert(opType.getShape()[rank - 1] >= 64 &&
-                 "k should be at least 64 to use this layout");
-          mDim = 4;
-          nDim = 4;
-        }
-      }
-    }
-    assert(mDim != 0 && nDim != 0);
-
-    auto maybeMfmaInsn =
-        MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion);
-    if (failed(maybeMfmaInsn))
-      llvm::report_fatal_error("No match found in MFMA database\n");
-
-    kDim = maybeMfmaInsn->getKDim();
-    assert(kDim != 0);
-    assert(M % mDim == 0 && N % nDim == 0);
-    assert(opType.getShape()[rank - 1] % kDim == 0);
-    return maybeMfmaInsn;
-  }
-
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    auto dotOp = cast<tt::DotOp>(op);
-
     RankedTensorType oldRetType = dotOp.getType();
-    if (!oldRetType.getEncoding() ||
-        !isa<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
-      return failure();
+    if (!isa_and_nonnull<BlockedEncodingAttr>(dotOp.getType().getEncoding()))
+      return rewriter.notifyMatchFailure(
+          dotOp, "expected blocked encoding result tensor");
 
     if (!supportMFMA(dotOp))
       return failure();
@@ -362,7 +369,7 @@
 
     // get MFMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
-    auto mod = op->getParentOfType<ModuleOp>();
+    auto mod = dotOp->getParentOfType<ModuleOp>();
     int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
 
     // operands
@@ -374,7 +381,7 @@
 
     ttg::AMDMfmaEncodingAttr mfmaEnc;
 
-    auto mfmaInstr = chooseMfmaInstruction(dotOp);
+    auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
     auto mDim = mfmaInstr.value().getMDim();
     auto nDim = mfmaInstr.value().getNDim();
     auto kDim = mfmaInstr.value().getKDim();
@@ -397,7 +404,7 @@
       mfmaAccType = rewriter.getF32Type();
 
     // convert accumulator
-    auto oldAcc = dotOp.getOperand(2);
+    auto oldAcc = dotOp.getC();
     auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);
 
     // Here is a brief explanation of kWidth, kBase, and kDim
@@ -456,11 +463,12 @@ class BlockedToMFMA : public RewritePattern {
         convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
                              oldRetType.getElementType());
 
-    rewriter.replaceOp(op, dotOutput);
+    rewriter.replaceOp(dotOp, dotOutput);
 
     return success();
   }
 };
+
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                             Type promotedType) {
   Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
@@ -566,18 +574,17 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
   });
 }
 
-class BlockedToWMMA : public RewritePattern {
+class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
   int wmmaVersion;
 
 public:
-  BlockedToWMMA(MLIRContext *context, int wmmaVersion)
-      : RewritePattern(tt::DotOp::getOperationName(), 2, context),
-        wmmaVersion(wmmaVersion) {}
+  BlockedToWMMA(MLIRContext *context, int wmmaVersion,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {}
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-    auto ctx = op->getContext();
-    auto dotOp = cast<tt::DotOp>(op);
+    auto ctx = dotOp->getContext();
 
     Value a = dotOp.getA();
     Value b = dotOp.getB();
@@ -603,7 +610,7 @@
 
     if (wmmaVersion == 2 && llvm::isa<FloatType>(oldAType) &&
        oldAType.getIntOrFloatBitWidth() == 8) {
-      return rewriter.notifyMatchFailure(op, "not supported yet");
+      return rewriter.notifyMatchFailure(dotOp, "not supported yet");
     }
 
     // get operand types
@@ -612,7 +619,7 @@
       return failure();
 
     // get WMMA encoding for the given number of warps
-    auto mod = op->getParentOfType<ModuleOp>();
+    auto mod = dotOp->getParentOfType<ModuleOp>();
     int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
 
     ttg::AMDWmmaEncodingAttr wmmaEnc;
@@ -626,7 +633,7 @@
     auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc);
 
     // convert accumulator
-    auto oldAcc = dotOp.getOperand(2);
+    auto oldAcc = dotOp.getC();
     auto newAcc =
         convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]);

@@ -653,7 +660,7 @@
 
     Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding,
                                            oldRetType.getElementType());
-    rewriter.replaceOp(op, dotOutput);
+    rewriter.replaceOp(dotOp, dotOutput);
     return success();
   }
 };
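With both classes now taking a PatternBenefit, callers decide pattern priority at registration time instead of it being baked into each pattern. A hedged sketch of how such patterns are typically wired into a greedy rewrite (the runAccelerateMatmulPatterns wrapper and the benefit value are assumptions for illustration; the actual pass body is outside this diff):

// Sketch: registering the refactored patterns and running the greedy driver
// (matching the GreedyPatternRewriteDriver.h include at the top of the file).
// The wrapper function name and the benefit value are hypothetical.
void runAccelerateMatmulPatterns(ModuleOp m, int mfmaVersion, int nonKDim,
                                 int kPack, int wmmaVersion) {
  MLIRContext *context = m.getContext();
  RewritePatternSet patterns(context);
  patterns.add<BlockedToMFMA>(context, mfmaVersion, nonKDim, kPack,
                              /*benefit=*/2);
  patterns.add<BlockedToWMMA>(context, wmmaVersion, /*benefit=*/2);
  if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
    m.emitError("failed to apply AccelerateAMDMatmul patterns");
}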
