diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8762942c311c..8ee166866974 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -41,36 +41,60 @@ SmallVector reorderValues(const SmallVector &values, Type inType, if (inBitWidth == ouBitWidth) return values; if (inBitWidth == 16 && ouBitWidth == 32) { + // Register layout conversion: + // + // [0, 1], [4, 5] ⟶ [0], [1], [4], [5] + // [2, 3], [6, 7] [2], [3], [6], [7] + // + // Original access order: + // + // [0, 1], [2, 3], [4, 5], [6, 7] + // + // Transformed access order: + // + // [0], [2], [1], [3], [4], [6], [5], [7] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); ret.push_back(values[i + 2]); + ret.push_back(values[i + 1]); ret.push_back(values[i + 3]); + ret.push_back(values[i + 4]); ret.push_back(values[i + 6]); + ret.push_back(values[i + 5]); ret.push_back(values[i + 7]); } return ret; } if (inBitWidth == 8 && ouBitWidth == 16) { + // Register layout conversion: + // + // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15] + // + // Original access order: + // + // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] + // + // Transformed access order: + // + // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i + 0]); + ret.push_back(values[i]); ret.push_back(values[i + 1]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); ret.push_back(values[i + 4]); ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); ret.push_back(values[i + 6]); ret.push_back(values[i + 7]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); ret.push_back(values[i + 12]); ret.push_back(values[i + 13]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); ret.push_back(values[i + 14]); ret.push_back(values[i + 15]); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 54371d063fb1..71fd3c0cd4e7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -641,7 +641,6 @@ struct ConvertLayoutOpConversion // for the destination type, we need to pack values together // so they can be consumed by tensor core operations SmallVector vecVals; - SmallVector types; // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer // instructions to pack & unpack sub-word integers. A workaround is to // store the results of ldmatrix in i32 @@ -655,37 +654,20 @@ struct ConvertLayoutOpConversion shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); val = or_(i32_ty, val, ext); } - vecVals.push_back(val); + vecVals.push_back(bitcast(val, i32_ty)); } - elems = elems / (32 / elemSize); - types = SmallVector(elems, i32_ty); } else { unsigned vecSize = std::max(32 / elemSize, 1); Type vecTy = vec_ty(elemTy, vecSize); - types = SmallVector(elems / vecSize, vecTy); for (unsigned i = 0; i < elems; i += vecSize) { Value packed = rewriter.create(loc, vecTy); for (unsigned j = 0; j < vecSize; j++) packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(packed); + vecVals.push_back(bitcast(packed, i32_ty)); } } - - // This needs to be ordered the same way that - // ldmatrix.x4 would order it - // TODO: this needs to be refactor so we don't - // implicitly depends on how emitOffsetsForMMAV2 - // is implemented - SmallVector reorderedVals; - for (unsigned i = 0; i < vecVals.size(); i += 4) { - reorderedVals.push_back(bitcast(vecVals[i], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty)); - } - - Value view = packLLElements(loc, getTypeConverter(), reorderedVals, - rewriter, dstTy); + Value view = + packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); rewriter.replaceOp(op, view); return success(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 73c21cae6de2..21c2bee584a6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -513,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct( for (int m = 0; m < n0; ++m) for (int k = 0; k < n1; ++k) { elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); } assert(!elems.empty()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 79ccb57206ae..c2940a04386f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -75,9 +75,39 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K if (dot.getOpIdx() == 0) { - si = llvm::SmallVector{0, 8, 4, 12, 1, 9, 5, 13, - 2, 10, 6, 14, 3, 11, 7, 15}; + // Original register layout: + // + // [0, 1, 2, 3], [8, 9, 10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] + // + // Each element in the layout consists of two bf16 values. + // For example, the row [0, 1, 2, 3] expands to: + // + // [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]] + // + // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the + // second half, matching kWidth = 8. + // + // To derive four independent MMA operations, a stride of 4 is applied to + // the original register layout: + // + // 1st MMA: [0, 4, 8, 12] + // 2nd MMA: [1, 5, 9, 13] + // 3rd MMA: [2, 6, 10, 14] + // 4th MMA: [3, 7, 11, 15] + si = llvm::SmallVector{0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; } else { + // Original register layout: + // + // [0, 1, 2, 3]^T, [4, 5, 6, 7]^T + // + // A stride of 4 is applied to derive four independent MMA operations: + // + // 1st MMA: [0, 4] + // 2nd MMA: [1, 5] + // 3rd MMA: [2, 6] + // 4th MMA: [3, 7] si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7}; } @@ -112,8 +142,8 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( for (auto i = 0; i < n0; ++i) { for (auto j = 0; j < n1; j++) { vals[{b, 2 * i, 2 * j}] = elems[offset++]; - vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; vals[{b, 2 * i + 1, 2 * j}] = elems[offset++]; + vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++]; } } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 9404bb4474d0..722bf56cd015 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -80,19 +80,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { ret.push_back(v); } } - // FIXME [Dot LL] - // The DotOperandEncodingAttr without LLs encodes the - // layout as - // e0 e1 - // e2 e3 - // rather than transposed that, as the PTX docs say - // We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2) - assert(ret.size() % 16 == 0); - for (int i = 0; i < ret.size() / 16; ++i) { - for (int j = 0; j < 4; ++j) { - std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]); - } - } return ret; }