Skip to content

Commit

Permalink
[Codegen][LLVMGPU][NFC] Cleanup contract distribution pattern for LayoutAttr (iree-org#17581)
Browse files Browse the repository at this point in the history

This patch reuses VectorContractOpInfo to remove most of the contract
info inference.
  • Loading branch information
Groverkss authored Jun 10, 2024
1 parent 363e088 commit 8ab07d2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,93 +9,39 @@
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"

namespace mlir::iree_compiler {

using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

// Identifies which operand of the contraction an index computation refers to:
// A (LHS), B (RHS), C (accumulator input), or D (result).
enum class ContractMatrixType { A, B, C, D };
// Classifies a 2-D contraction by which operands are read transposed:
// MM = A*B, MMT = A*B^T, MTM = A^T*B, MTMT = A^T*B^T. UNSUPPORTED marks
// indexing-map combinations the pattern does not handle.
enum class ContractType { MM, MMT, MTM, MTMT, UNSUPPORTED };

namespace {

// Returns true when the A (LHS) operand is read transposed, i.e. for the
// M^T*M and M^T*M^T contraction kinds.
static bool isOperandATransposed(ContractType contractType) {
  switch (contractType) {
  case ContractType::MTM:
  case ContractType::MTMT:
    return true;
  default:
    return false;
  }
}

// Returns true when the B (RHS) operand is read transposed, i.e. for the
// M*M^T and M^T*M^T contraction kinds.
static bool isOperandBTransposed(ContractType contractType) {
  switch (contractType) {
  case ContractType::MMT:
  case ContractType::MTMT:
    return true;
  default:
    return false;
  }
}
struct DistributeContractions final
: OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;

// For a plain MM contraction we compute C(i, k) += A(i, j) * B(j, k); when an
// operand is transposed its two indices must be swapped. Given the incoming
// pair (i, j), this returns either {i, j} or {j, i} depending on the
// contraction kind and which matrix is being indexed.
SmallVector<int64_t> getIndices(ContractType contractType,
                                ContractMatrixType matrixType, int i,
                                int j) const {
  // An operand needs swapped indices exactly when it is the transposed
  // side of the contraction.
  bool swapIndices =
      (contractType == ContractType::MTMT) ||
      ((contractType == ContractType::MTM) &&
       (matrixType == ContractMatrixType::A)) ||
      ((contractType == ContractType::MMT) &&
       (matrixType == ContractMatrixType::B));
  if (swapIndices)
    return {j, i};
  return {i, j};
}

// Returns the batch extent of the reduction (K) dimension of the LHS: when A
// is transposed the K dimension is the row batch, otherwise the column batch.
int64_t getReductionDimensionShape(int64_t rowBatch, int64_t colBatch,
                                   ContractType contractType) const {
  return isOperandATransposed(contractType) ? rowBatch : colBatch;
}

// Classifies the contraction by comparing its indexing maps against the
// canonical map sets for each transposition kind. Each kind is matched in
// both (m, n) orders of the result. Returns UNSUPPORTED when none match.
ContractType inferContractType(MLIRContext *ctx,
                               SmallVector<AffineMap> maps) const {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList exprs) {
    return AffineMap::inferFromExprList(exprs, ctx);
  };
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k);
  // True when the contraction's maps equal either of the two canonical
  // variants (result as (m, n) or as (n, m)) of a contraction kind.
  auto matchesEither = [&](MapList a, MapList b) {
    return (maps == infer(a)) || (maps == infer(b));
  };
  if (matchesEither({{m, k}, {k, n}, {m, n}}, {{n, k}, {k, m}, {n, m}}))
    return ContractType::MM;
  if (matchesEither({{m, k}, {n, k}, {m, n}}, {{n, k}, {m, k}, {n, m}}))
    return ContractType::MMT;
  if (matchesEither({{k, m}, {k, n}, {m, n}}, {{k, n}, {k, m}, {n, m}}))
    return ContractType::MTM;
  if (matchesEither({{k, m}, {n, k}, {m, n}}, {{k, n}, {m, k}, {n, m}}))
    return ContractType::MTMT;
  return ContractType::UNSUPPORTED;
}

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
VectorValue result = dyn_cast<VectorValue>(contractOp.getResult());
if (!result) {
return failure();
return rewriter.notifyMatchFailure(contractOp,
"result should be of type vector");
}

LayoutAttr resultLayout = dyn_cast<LayoutAttr>(signature[result]);
if (!resultLayout) {
return rewriter.notifyMatchFailure(
contractOp, "result layout should be of type LayoutAttr");
}

auto mmaAttr =
contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
if (!mmaAttr) {
return rewriter.notifyMatchFailure(
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}

constexpr int LHS = 0;
Expand All @@ -113,10 +59,6 @@ struct DistributeContractions final
}
}

LayoutAttr resultLayout = dyn_cast<LayoutAttr>(signature[result]);
if (!resultLayout)
return failure();

Type elementType =
llvm::cast<ShapedType>(operands[ACC].getType()).getElementType();
SmallVector<int64_t> vectorShape = resultLayout.getDistributedShape();
Expand All @@ -125,39 +67,36 @@ struct DistributeContractions final
Value vector = rewriter.create<arith::ConstantOp>(
loc, vectorType, rewriter.getZeroAttr(vectorType));

ContractType contractType = inferContractType(
contractOp.getContext(), contractOp.getIndexingMapsArray());
if (contractType == ContractType::UNSUPPORTED)
return failure();

auto mmaAttr =
contractOp->getAttrOfType<IREE::GPU::MMAAttr>("iree.amdgpu.mma");
if (!mmaAttr) {
return rewriter.notifyMatchFailure(
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}
VectorContractOpInfo opInfo(contractOp);
auto [lhsK, rhsK] = opInfo.getOperandKIndex();

std::optional<int64_t> rowBatch = layouts[LHS].getBatchDim(0);
if (!rowBatch)
return failure();
std::optional<int64_t> colBatch = layouts[LHS].getBatchDim(1);
if (!colBatch)
std::optional<int64_t> kBatch = layouts[LHS].getBatchDim(lhsK);
if (!kBatch) {
return failure();

int K = getReductionDimensionShape(rowBatch.value(), colBatch.value(),
contractType);
}

auto contractFn = [&](const LayoutIterator::State &state) {
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
SmallVector<int64_t> indices = state.computeIteratorProjectedSIMTIndex();
Value dMatrix = rewriter.create<vector::ExtractOp>(
loc, getDistributed(rewriter, operands[ACC], layouts[ACC]), indices);
for (int k = 0; k < K; k++) {
for (int k = 0; k < kBatch; ++k) {
SmallVector<int64_t> lhsIndices(2);
SmallVector<int64_t> rhsIndices(2);
lhsIndices[lhsM] = indices[0];
lhsIndices[lhsK] = k;
rhsIndices[rhsN] = indices[1];
rhsIndices[rhsK] = k;

Value aMatrix = rewriter.create<vector::ExtractOp>(
loc, getDistributed(rewriter, operands[LHS], layouts[LHS]),
getIndices(contractType, ContractMatrixType::A, indices[0], k));
lhsIndices);

Value bMatrix = rewriter.create<vector::ExtractOp>(
loc, getDistributed(rewriter, operands[RHS], layouts[RHS]),
getIndices(contractType, ContractMatrixType::B, k, indices[1]));
rhsIndices);

dMatrix = mmaAttr
.buildMmaOperation(rewriter, loc, dMatrix.getType(),
aMatrix, bMatrix, dMatrix)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
"//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AMDGPUDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ iree_cc_library(
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::Utils
PUBLIC
)

Expand Down

0 comments on commit 8ab07d2

Please sign in to comment.