Preserve blocked pointers used by tt.load operation with DPAS layout #2400

Merged (3 commits), Oct 2, 2024
43 changes: 24 additions & 19 deletions test/TritonIntelGPU/rewrite-tensor-pointer.mlir
@@ -10,8 +10,7 @@
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
// CHECK: @matmul_kernel_with_block_pointers
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) {
%c4_i32 = arith.constant 4 : i32
%c256_i32 = arith.constant 256 : i32
%c1_i64 = arith.constant 1 : i64
@@ -20,9 +19,9 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
%c255_i32 = arith.constant 255 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c255_i32 : i32
%1 = arith.addi %arg4, %c255_i32 : i32
%2 = arith.divsi %1, %c256_i32 : i32
%3 = arith.addi %arg4, %c255_i32 : i32
%3 = arith.addi %arg5, %c255_i32 : i32
%4 = arith.divsi %3, %c256_i32 : i32
%5 = arith.muli %4, %c4_i32 : i32
%6 = arith.divsi %0, %5 : i32
@@ -34,35 +33,41 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
%12 = arith.remsi %0, %5 : i32
%13 = arith.divsi %12, %9 : i32
%14 = arith.muli %11, %c256_i32 : i32
%15 = arith.extsi %arg3 : i32 to i64
%16 = arith.extsi %arg5 : i32 to i64
%17 = arith.extsi %arg6 : i32 to i64
%15 = arith.extsi %arg4 : i32 to i64
%16 = arith.extsi %arg6 : i32 to i64
%17 = arith.extsi %arg7 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
%18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #dot0>>
%19 = arith.muli %13, %c256_i32 : i32
%20 = arith.extsi %arg4 : i32 to i64
%21 = arith.extsi %arg7 : i32 to i64
%20 = arith.extsi %arg5 : i32 to i64
%21 = arith.extsi %arg8 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
%23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas>
%31 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #dot0>>
%32 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #dot1>>
%30 = tt.dot %28, %29, %arg11, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas>
%31 = tt.advance %arg12, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #dot0>>
%32 = tt.advance %arg13, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #dot1>>
scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>
}
%24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
%26 = arith.extsi %arg8 : i32 to i64
%25 = arith.extsi %arg9 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
%26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>
%28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas>
%29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>

// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #[[DPAS]]>>
%27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
%30 = tt.make_tensor_ptr %arg2, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
// CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>>
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
tt.store %30, %29 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
tt.return
}
}
@@ -30,24 +30,28 @@ namespace {

/// Check if the tensor pointer should be removed. The tensor pointer should be
/// removed if:
/// - the tensor pointer does not have DotEncoding with DpasEncoding parent
/// and does not have DpasEncoding
/// - the tensor pointer pitch is not divisible by Qword bitwidth
/// - the tensor pointer is not contiguous on memory
bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
/// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
/// - its pitch is not divisible by Qword bitwidth
/// - it is not contiguous in memory
bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
LDBG("Considering removal of: " << op);
if (!op->getParentOfType<ModuleOp>()->hasAttr(
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
LDBG("Marked for removal: 2D block operation not supported");
return true;
}

auto ptrType = cast<tt::PointerType>(op.getType());
LDBG("Op ptr type: " << ptrType);
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
LDBG("Op tensor type: " << tensorType);

if (!ttgi::hasDotDpasEncoding(tensorType) &&
!(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType)))
!(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
"by load or store op with DPAS layout");
return true;
}

TypedValue<triton::PointerType> base = op.getBase();
Operation::operand_range shape = op.getShape();
@@ -60,21 +64,23 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {

int fastChangeDim = -1;
for (size_t i = 0; i < strides.size(); ++i) {
if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
if (ttgi::isConstant(strides[i], 1)) {
fastChangeDim = i;
break;
}
}

LDBG("fastChangeDim: " << fastChangeDim);
if (fastChangeDim < 0) {
LDBG("Marked for removal: fast changing dimension not found");
return true;
}

LDBG("Tensor type element type bit width: "
<< tensorType.getElementTypeBitWidth());
if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
// TODO: column major layout w/ fp8 has performance regression
LDBG("Marked for removal: column major layout with fp8 element type");
return true;
}

@@ -85,11 +91,15 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
// Across Intel platforms, the strictest pitch restriction is to be a
// multiple of OWord(128 bits).
if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
LDBG("Marked for removal: cannot use block read/write instructions");
return true;
}

return false;
}

LDBG("Marked for removal: fall-trough");

return true;
}
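
As a side note on the OWord restriction mentioned in the comment above, here is a minimal standalone sketch of the arithmetic that `ttgi::isDivisible(pitch, 128 / elementBitWidth)` enforces for a constant pitch. The helper name `pitchIsOWordAligned` and the sample pitches are invented for illustration only: the pitch, counted in elements, has to cover a whole number of 128-bit OWords (a multiple of 8 for f16, 4 for f32, 16 for i8).

```cpp
// Illustrative only: mirrors the constant-pitch case of the divisibility
// check in shouldRemove(); not part of the pass.
#include <cassert>
#include <cstdint>

// Returns true when a row pitch, expressed in elements, is a multiple of one
// OWord (128 bits) worth of elements, so 2D block reads/writes can be used.
static bool pitchIsOWordAligned(int64_t pitchInElems, unsigned elemBitWidth) {
  assert(elemBitWidth != 0 && 128 % elemBitWidth == 0);
  const int64_t elemsPerOWord = 128 / elemBitWidth; // f16 -> 8, f32 -> 4, i8 -> 16
  return pitchInElems % elemsPerOWord == 0;
}

int main() {
  assert(pitchIsOWordAligned(4096, 16));  // f16 pitch of 4096 elements: block pointer kept
  assert(!pitchIsOWordAligned(250, 32));  // f32 pitch of 250 elements: rewritten away
  return 0;
}
```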

@@ -705,28 +715,28 @@ class TritonIntelGPURewriteTensorPointerPass
void runOnOperation() override {
ModuleOp mod = getOperation();

auto usedByStoreOp = [](Value val) {
auto usedByLoadOrStoreOp = [](Value val) {
return llvm::any_of(val.getUsers(), [](Operation *user) {
return llvm::isa<tt::StoreOp>(user);
return isa<tt::LoadOp, tt::StoreOp>(user);
});
};

auto markTensorPointerForRemoval = [this](Value val,
bool isUsedByStoreOp = false) {
if (tt::isTensorPointerType(val.getType())) {
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
if (shouldRemove(makeTensorPtrOp, isUsedByStoreOp))
valueToRemove.insert(val);
}
};
auto markTensorPointerForRemoval =
[this](Value val, bool isUsedByLoadOrStoreOp = false) {
if (tt::isTensorPointerType(val.getType())) {
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
valueToRemove.insert(val);
}
};

mod.walk([&](Operation *op) {
if (llvm::isa<tt::MakeTensorPtrOp>(op)) {
if (isa<tt::MakeTensorPtrOp>(op)) {
Value result = op->getResult(0);
markTensorPointerForRemoval(result, usedByStoreOp(result));
} else if (llvm::isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
markTensorPointerForRemoval(op->getOperand(0),
llvm::isa<tt::StoreOp>(op));
isa<tt::LoadOp, tt::StoreOp>(op));
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
for (auto arg : forOp.getInitArgs())
markTensorPointerForRemoval(arg);
@@ -738,11 +748,11 @@ class TritonIntelGPURewriteTensorPointerPass

LLVM_DEBUG({
if (valueToRemove.empty())
llvm::dbgs() << "No tensor pointer to remove\n";
DBGS() << "No tensor pointer to remove";
else {
llvm::dbgs() << "Values to remove: \n";
DBGS() << "Values to remove: ";
for (auto val : valueToRemove)
llvm::dbgs() << val << "\n";
DBGS() << val;
}
});

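
To summarize the effect of the change, below is a condensed, illustrative restatement of the decision `shouldRemove` makes, using plain C++ in place of the MLIR types. `TensorPtrInfo`, `shouldRemoveSketch`, and the last two conditions (which sit partly outside the hunks shown above) are assumptions made for the sketch. The only behavioral change in this PR is the `usedByLoadOrStoreOp` term, which previously covered stores only.

```cpp
// Simplified model of the properties shouldRemove() inspects; not the pass itself.
#include <cassert>

struct TensorPtrInfo {
  bool moduleSupports2DBlock;  // triton_intel_gpu.support_sg_2d_block is set
  bool hasDotDpasLayout;       // dot layout whose parent is a DPAS layout
  bool hasDpasLayout;          // plain DPAS layout
  bool usedByLoadOrStoreOp;    // before this PR: only "used by a store"
  bool hasUnitStrideDim;       // a fast-changing (stride-1) dimension exists
  bool isColumnMajorFp8;       // 8-bit elements with the unit stride on the outer dim
  bool isInnerDimContiguous;   // the unit stride is on the innermost dimension
  bool pitchIsOWordAligned;    // pitch is a multiple of 128 bits
};

// true  => rewrite the block pointer into regular pointer arithmetic;
// false => preserve it so 2D block load/store instructions can be emitted.
static bool shouldRemoveSketch(const TensorPtrInfo &p) {
  if (!p.moduleSupports2DBlock)
    return true;
  if (!p.hasDotDpasLayout && !(p.usedByLoadOrStoreOp && p.hasDpasLayout))
    return true;
  if (!p.hasUnitStrideDim)
    return true;
  if (p.isColumnMajorFp8)
    return true;
  if (p.isInnerDimContiguous && p.pitchIsOWordAligned)
    return false;
  return true; // fall-through: remove
}

int main() {
  // A block pointer feeding tt.load with a plain DPAS layout (the case this PR
  // adds) is now preserved rather than rewritten.
  TensorPtrInfo loadWithDpas{true, false, true, true, true, false, true, true};
  assert(!shouldRemoveSketch(loadWithDpas));
  return 0;
}
```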