Skip to content

Commit

Permalink
[BACKEND] Fix the register accessing order of dot operands of mmav2 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Oct 24, 2024
1 parent 3c13f09 commit 3613bf4
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 49 deletions.
44 changes: 34 additions & 10 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,60 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
// Register layout conversion:
//
// [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
// [2, 3], [6, 7] [2], [3], [6], [7]
//
// Original access order:
//
// [0, 1], [2, 3], [4, 5], [6, 7]
//
// Transformed access order:
//
// [0], [2], [1], [3], [4], [6], [5], [7]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 7]);
}
return ret;
}
if (inBitWidth == 8 && ouBitWidth == 16) {
// Register layout conversion:
//
// [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
// [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15]
//
// Original access order:
//
// [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
//
// Transformed access order:
//
// [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 12]);
ret.push_back(values[i + 13]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
SmallVector<Value> vecVals;
SmallVector<Type> types;
// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of ldmatrix in i32
Expand All @@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion
shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
val = or_(i32_ty, val, ext);
}
vecVals.push_back(val);
vecVals.push_back(bitcast(val, i32_ty));
}
elems = elems / (32 / elemSize);
types = SmallVector<Type>(elems, i32_ty);
} else {
unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
Type vecTy = vec_ty(elemTy, vecSize);
types = SmallVector<Type>(elems / vecSize, vecTy);
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
vecVals.push_back(bitcast(packed, i32_ty));
}
}

// This needs to be ordered the same way that
// ldmatrix.x4 would order it
// TODO: this needs to be refactor so we don't
// implicitly depends on how emitOffsetsForMMAV2
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
}

Value view = packLLElements(loc, getTypeConverter(), reorderedVals,
rewriter, dstTy);
Value view =
packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct(
for (int m = 0; m < n0; ++m)
for (int k = 0; k < n1; ++k) {
elems.push_back(vals.at({b, 2 * m, 2 * k}));
elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
elems.push_back(vals.at({b, 2 * m + 1, 2 * k}));
elems.push_back(vals.at({b, 2 * m, 2 * k + 1}));
elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
}
assert(!elems.empty());
Expand Down
36 changes: 33 additions & 3 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,39 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(

// For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
if (dot.getOpIdx() == 0) {
si = llvm::SmallVector<unsigned>{0, 8, 4, 12, 1, 9, 5, 13,
2, 10, 6, 14, 3, 11, 7, 15};
// Original register layout:
//
// [0, 1, 2, 3], [8, 9, 10, 11]
// [4, 5, 6, 7], [12, 13, 14, 15]
//
// Each element in the layout consists of two bf16 values.
// For example, the row [0, 1, 2, 3] expands to:
//
// [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]]
//
// Here, 0/0 refers to the first half of element 0, and 0/1 refers to the
// second half, matching kWidth = 8.
//
// To derive four independent MMA operations, a stride of 4 is applied to
// the original register layout:
//
// 1st MMA: [0, 4, 8, 12]
// 2nd MMA: [1, 5, 9, 13]
// 3rd MMA: [2, 6, 10, 14]
// 4th MMA: [3, 7, 11, 15]
si = llvm::SmallVector<unsigned>{0, 4, 8, 12, 1, 5, 9, 13,
2, 6, 10, 14, 3, 7, 11, 15};
} else {
// Original register layout:
//
// [0, 1, 2, 3]^T, [4, 5, 6, 7]^T
//
// A stride of 4 is applied to derive four independent MMA operations:
//
// 1st MMA: [0, 4]
// 2nd MMA: [1, 5]
// 3rd MMA: [2, 6]
// 4th MMA: [3, 7]
si = llvm::SmallVector<unsigned>{0, 4, 1, 5, 2, 6, 3, 7};
}

Expand Down Expand Up @@ -112,8 +142,8 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
for (auto i = 0; i < n0; ++i) {
for (auto j = 0; j < n1; j++) {
vals[{b, 2 * i, 2 * j}] = elems[offset++];
vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
vals[{b, 2 * i + 1, 2 * j}] = elems[offset++];
vals[{b, 2 * i, 2 * j + 1}] = elems[offset++];
vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++];
}
}
Expand Down
13 changes: 0 additions & 13 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
ret.push_back(v);
}
}
// FIXME [Dot LL]
// The DotOperandEncodingAttr without LLs encodes the
// layout as
// e0 e1
// e2 e3
// rather than transposed that, as the PTX docs say
// We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2)
assert(ret.size() % 16 == 0);
for (int i = 0; i < ret.size() / 16; ++i) {
for (int j = 0; j < 4; ++j) {
std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]);
}
}

return ret;
}
Expand Down

0 comments on commit 3613bf4

Please sign in to comment.