Skip to content

Commit

Permalink
[AMD][Navi31] Introduce WMMA layout Attr (#3112)
Browse files Browse the repository at this point in the history
- Add the WMMA layout attribute to the TritonGPU dialect.
- Implement the required layout-interface methods for it.
Please note that lowering to WMMA instructions is not supported yet.

Signed-off-by: joviliast <[email protected]>
  • Loading branch information
joviliast authored Feb 22, 2024
1 parent b247953 commit f1f73d5
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 0 deletions.
55 changes: 55 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,61 @@ The data will be distributed between threads as follows:
let hasCustomAssemblyFormat = 1;
}

def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> {
  let mnemonic = "amd_wmma";

  let description = [{
    An encoding for tensors that have been produced by WMMA instructions,
    available on RDNA 3.
    A `warpsPerCTA` parameter characterizes data distribution between waves.
    An important limitation of WMMA for this layout is the shape of the tile
    processed by a single wave: it is [16, 16].
    This encoding assumes specific access to matrix elements by threads.
    Example:
    Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3].

      wave 0 [16, 16]         wave 1 [16, 16]         wave 2 [16, 16]
    -----------/\----------  -----------/\----------  -----------/\----------
    [0   1  2 ...  14  15]   [0   1  2 ...  14  15]   [0   1  2 ...  14  15]
    [16 17 18 ...  30  31]   [16 17 18 ...  30  31]   [16 17 18 ...  30  31]
    [0   1  2 ...  14  15]   [0   1  2 ...  14  15]   [0   1  2 ...  14  15]
    [16 17 18 ...  30  31]   [16 17 18 ...  30  31]   [16 17 18 ...  30  31]
              ...                      ...                      ...
    [0   1  2 ...  14  15]   [0   1  2 ...  14  15]   [0   1  2 ...  14  15]
    [16 17 18 ...  30  31]   [16 17 18 ...  30  31]   [16 17 18 ...  30  31]

      wave 3 [16, 16]         wave 4 [16, 16]         wave 5 [16, 16]
    -----------/\----------  -----------/\----------  -----------/\----------
    [0   1  2 ...  14  15]   [0   1  2 ...  14  15]   [0   1  2 ...  14  15]
    [16 17 18 ...  30  31]   [16 17 18 ...  30  31]   [16 17 18 ...  30  31]
    [0   1  2 ...  14  15]   [0   1  2 ...  14  15]   [0   1  2 ...  14  15]
    [16 17 18 ...  30  31]   [16 17 18 ...  30  31]   [16 17 18 ...  30  31]
              ...                      ...                      ...
    [0   1  2 ...  14  15]   [0   1  2 ...  14  15]   [0   1  2 ...  14  15]
    [16 17 18 ...  30  31]   [16 17 18 ...  30  31]   [16 17 18 ...  30  31]
  }];

  // warpsPerCTA__ uses a trailing double underscore because the public
  // accessor getWarpsPerCTA() is hand-written in Dialect.cpp.
  let parameters = (
    ins
    ArrayRefParameter<"unsigned">:$warpsPerCTA__,
    "CTALayoutAttr":$CTALayout
  );

  let hasCustomAssemblyFormat = 1;

  let extraClassDeclaration = extraDistributedDeclaration # [{
    bool supportReduction() const {
      return true;
    }
    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
    SmallVector<int64_t> getWMMAElemsPerInstrForOperands() const;
    SmallVector<int64_t> getWMMARepForOperands(ArrayRef<int64_t> operandShape,
                                               Type elemType, int kWidth, int opIdx) const;
    static SmallVector<unsigned> getMNKDimPerWMMAInstr();
  }];
}

def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> {
let mnemonic = "nvidia_mma";
Expand Down
171 changes: 171 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,28 @@ unsigned MfmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,

//

// Returns, per dimension, how many tensor elements each thread owns.
// The tensor is tiled by (16 * warpsPerCTA[d]) per dimension; each thread
// holds getSizePerThread()[d] elements of every tile it participates in.
SmallVector<unsigned>
AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                       Type eltTy) const {
  size_t rank = shape.size();
  // Fixed: the assert message previously said "mfma layout" (copy-paste).
  assert(rank == 2 && "Unexpected rank of wmma layout");
  (void)rank;

  // Removed: a dead `SmallVector<unsigned> elemsPerThread(rank)` local that
  // was never used by the return expression.
  auto nonKDim = getMNKDimPerWMMAInstr()[0];
  auto elemsPerThreadPerTile = getSizePerThread();
  return {ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
              elemsPerThreadPerTile[0],
          ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
              elemsPerThreadPerTile[1]};
}

// Total element count per thread: product of the per-dimension counts.
unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
                                                     Type eltTy) const {
  auto elemsPerDim = getElemsPerThread(shape, eltTy);
  return product<unsigned>(elemsPerDim);
}

//

SmallVector<unsigned>
NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
Expand Down Expand Up @@ -1204,6 +1226,63 @@ void MfmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "}>";
}

//===----------------------------------------------------------------------===//
// WMMA encoding
//===----------------------------------------------------------------------===//

// Parses the textual form `<{warpsPerCTA = [...], CTAsPerCGA = [...],
// CTASplitNum = [...], CTAOrder = [...]}>`; the three CTA* entries are
// optional and are resolved into a CTALayoutAttr (or a default) by
// getCTALayoutOrError.
Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr attrDict;
  if (parser.parseAttribute(attrDict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  SmallVector<unsigned> warpsPerCTA;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;

  // Each key appears at most once, so an else-if chain is equivalent to the
  // independent checks; unknown keys are ignored (matching prior behavior).
  for (const NamedAttribute &namedAttr : attrDict) {
    if (namedAttr.getName() == "warpsPerCTA") {
      if (parseIntArrayAttr(parser, namedAttr, warpsPerCTA, "warpsPerCTA")
              .failed())
        return {};
    } else if (namedAttr.getName() == "CTAsPerCGA") {
      if (parseIntArrayAttr(parser, namedAttr, CTAsPerCGA.emplace(),
                            "CTAsPerCGA")
              .failed())
        return {};
    } else if (namedAttr.getName() == "CTASplitNum") {
      if (parseIntArrayAttr(parser, namedAttr, CTASplitNum.emplace(),
                            "CTASplitNum")
              .failed())
        return {};
    } else if (namedAttr.getName() == "CTAOrder") {
      if (parseIntArrayAttr(parser, namedAttr, CTAOrder.emplace(), "CTAOrder")
              .failed())
        return {};
    }
  }

  std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
      parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
  if (!CTALayout.has_value())
    return {};

  return parser.getChecked<AMDWmmaEncodingAttr>(parser.getContext(),
                                                warpsPerCTA, *CTALayout);
}

// Mirror of parse(): always prints warpsPerCTA; maybePrintCTALayout decides
// whether the CTA layout fields need to be emitted.
void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
  auto warpsPerCTA = getWarpsPerCTA();
  printer << "<{"
          << "warpsPerCTA = [" << ArrayRef(warpsPerCTA) << "]";
  maybePrintCTALayout(getContext(), printer, getCTALayout(),
                      /*rank=*/warpsPerCTA.size());
  printer << "}>";
}

//===----------------------------------------------------------------------===//
// Sliced Encoding
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1436,6 +1515,98 @@ MfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
}
}

// Shape of the tile covered by one CTA: the fixed 16x16 per-wave tile scaled
// by the wave grid. tensorShape is unused; the tile is fixed by the layout.
SmallVector<unsigned>
AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
  auto warpsPerCTA = getWarpsPerCTA();
  auto nonKDim = getMNKDimPerWMMAInstr()[0];
  return {nonKDim * warpsPerCTA[0], nonKDim * warpsPerCTA[1]};
}
// Copies the CTAsPerCGA array out of the attached CTALayout attribute.
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
  auto ctasPerCGA = getCTALayout().getCTAsPerCGA();
  return SmallVector<unsigned>(ctasPerCGA.begin(), ctasPerCGA.end());
}
// Copies the CTAOrder array out of the attached CTALayout attribute.
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAOrder() const {
  auto ctaOrder = getCTALayout().getCTAOrder();
  return SmallVector<unsigned>(ctaOrder.begin(), ctaOrder.end());
}
// Copies the CTASplitNum array out of the attached CTALayout attribute.
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTASplitNum() const {
  auto ctaSplitNum = getCTALayout().getCTASplitNum();
  return SmallVector<unsigned>(ctaSplitNum.begin(), ctaSplitNum.end());
}
// Materializes the stored warpsPerCTA__ parameter as an owned vector.
SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
  auto storedWarps = getWarpsPerCTA__();
  return SmallVector<unsigned>(storedWarps.begin(), storedWarps.end());
}
// Wave traversal order; delegates to the generic layout-order helper.
SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
  SmallVector<unsigned> order = ::getOrder(*this);
  return order;
}
// Thread traversal order; same generic helper as getWarpOrder().
SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
  SmallVector<unsigned> order = ::getOrder(*this);
  return order;
}
// Threads per wave along each dimension: the instruction tile extent divided
// by the number of elements each thread holds in that dimension.
SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadsPerWarp() const {
  auto mnkDims = getMNKDimPerWMMAInstr();
  auto sizePerThread = getSizePerThread();
  return {mnkDims[0] / sizePerThread[0], mnkDims[1] / sizePerThread[1]};
}

// Elements held per thread per tile: 8 along dim 0, 1 along dim 1
// (consistent with a 16x16 tile spread over the wave via getThreadsPerWarp).
SmallVector<unsigned> AMDWmmaEncodingAttr::getSizePerThread() const {
  SmallVector<unsigned> sizePerThread{8, 1};
  return sizePerThread;
}
// Per-thread element counts for a dot operand: opIdx 0 is the A operand,
// opIdx 1 is the B operand; any other index is a caller bug.
SmallVector<unsigned>
AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
  switch (opIdx) {
  case 0:
    return {1, 16};
  case 1:
    return {16, 1};
  default:
    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
  }
}

// CTA-tile shape for a dot operand: the non-K dimension comes from the
// parent layout's CTA tile, the K dimension spans the full operand shape.
SmallVector<unsigned>
AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
                                                      int opIdx) const {
  auto parentCTATile = getShapePerCTATile(shape);
  if (opIdx == 0)
    return {parentCTATile[0], static_cast<unsigned>(shape[1])};
  if (opIdx == 1)
    return {static_cast<unsigned>(shape[0]), parentCTATile[1]};
  llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}

// Total per-thread element count for a dot operand:
// (elements per instruction tile) x (tile repetitions) x (waves in the CTA).
unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
    ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
  auto warpsPerCTA = getWarpsPerCTA();
  auto instrTile = getWMMAElemsPerInstrForOperands();
  auto repetitions = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
  return product(instrTile) * product(repetitions) * warpsPerCTA[0] *
         warpsPerCTA[1];
}

// Operand tile consumed by a single WMMA instruction: 16x16 for both A and B.
SmallVector<int64_t>
AMDWmmaEncodingAttr::getWMMAElemsPerInstrForOperands() const {
  SmallVector<int64_t> elemsPerInstr{16, 16};
  return elemsPerInstr;
}

// Number of instruction-tile repetitions each wave performs per dimension for
// a dot operand. The non-K dimension is shared across warpsPerCTA; the K
// dimension is not. Clamped to at least 1 for operands smaller than a tile.
SmallVector<int64_t>
AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef<int64_t> operandShape,
                                           Type elemType, int kWidth,
                                           int opIdx) const {
  auto instrTile = getWMMAElemsPerInstrForOperands();
  auto warpsPerCTA = getWarpsPerCTA();
  if (opIdx == 0) {
    int64_t repM =
        std::max<int64_t>(1, operandShape[0] / (instrTile[0] * warpsPerCTA[0]));
    int64_t repK = std::max<int64_t>(1, operandShape[1] / instrTile[1]);
    return {repM, repK};
  }
  assert(opIdx == 1);
  int64_t repK = std::max<int64_t>(1, operandShape[0] / instrTile[0]);
  int64_t repN =
      std::max<int64_t>(1, operandShape[1] / (instrTile[1] * warpsPerCTA[1]));
  return {repK, repN};
}

// M, N, K extents of one WMMA instruction; fixed at 16x16x16 here.
// TODO: move magic numbers out of the code
SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr() {
  SmallVector<unsigned> mnkDims{16, 16, 16};
  return mnkDims;
}

//===----------------------------------------------------------------------===//
// Mma encoding
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions test/Conversion/triton_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: triton-opt %s | FileCheck %s

// CHECK: #[[WMMA:.*]] = #triton_gpu.amd_wmma

tt.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
// scalar -> scalar
// CHECK: i64 -> !tt.ptr<f32, 1>
Expand Down Expand Up @@ -232,3 +234,21 @@ tt.func @histogram(%0: tensor<512xi32>) {
%1 = tt.histogram %0 : tensor<512xi32> -> tensor<16xi32>
tt.return
}

// Source layout for the conversions below; values chosen so one warp of 32
// threads covers a 16x16 tile (2x2 per thread, 4x8 threads).
#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>

module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// Round-trips the amd_wmma encoding through the parser and printer.
// CHECK-LABEL: wmma_layout
tt.func @wmma_layout(%0: tensor<16x16xf16, #blocked>) {
%1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{warpsPerCTA = [1, 1]}>>
// CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[WMMA]]>
tt.return
}

// Same round-trip, but with amd_wmma as the parent of a dot_op encoding.
// CHECK-LABEL: wmma_dot_op_layout
tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) {
%1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{warpsPerCTA = [1, 1]}>}>>
// CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA]]}>>
tt.return
}
}

0 comments on commit f1f73d5

Please sign in to comment.