From c96abe63005d58a57761b5dd3afdd14cdefb34f5 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore" <ettore.tiotto@intel.com>
Date: Tue, 1 Oct 2024 20:31:00 +0000
Subject: [PATCH 1/3] Preserve blocked pointers used by tt.load operation with DPAS layout

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
---
 .../RewriteTensorPointer.cpp                  | 60 +++++++++++--------
 1 file changed, 35 insertions(+), 25 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
index dc205496b..1c45599dc 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
@@ -30,15 +30,16 @@ namespace {
 
 /// Check if the tensor pointer should be removed. The tensor pointer should be
 /// removed if:
-/// - the tensor pointer does not have DotEncoding with DpasEncoding parent
-///   and does not have DpasEncoding
-/// - the tensor pointer pitch is not divisible by Qword bitwidth
-/// - the tensor pointer is not contiguous on memory
-bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
+/// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
+/// - its pitch is not divisible by Qword bitwidth
+/// - it is not contiguous in memory
+bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
   LDBG("Considering removal of: " << op);
   if (!op->getParentOfType<ModuleOp>()->hasAttr(
-          ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
+          ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
+    LDBG("Marked for removal: 2D block operation not supported");
     return true;
+  }
 
   auto ptrType = cast<tt::PointerType>(op.getType());
   LDBG("Op ptr type: " << ptrType);
@@ -46,8 +47,11 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
   LDBG("Op tensor type: " << tensorType);
 
   if (!ttgi::hasDotDpasEncoding(tensorType) &&
-      !(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType)))
+      !(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
+    LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
+         "by load or store op with DPAS layout");
     return true;
+  }
 
   TypedValue<tt::PointerType> base = op.getBase();
   Operation::operand_range shape = op.getShape();
@@ -60,7 +64,7 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
 
   int fastChangeDim = -1;
   for (size_t i = 0; i < strides.size(); ++i) {
-    if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
+    if (ttgi::isConstant(strides[i], 1)) {
       fastChangeDim = i;
       break;
     }
@@ -68,6 +72,7 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
 
   LDBG("fastChangeDim: " << fastChangeDim);
   if (fastChangeDim < 0) {
+    LDBG("Marked for removal: fast changing dimension not found");
     return true;
   }
 
@@ -75,6 +80,7 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
        << tensorType.getElementTypeBitWidth());
   if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
     // TODO: column major layout w/ fp8 has performance regression
+    LDBG("Marked for removal: column major layout with fp8 element type");
     return true;
   }
 
@@ -85,11 +91,15 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
     // Across Intel platforms, the strictest pitch restriction is to be a
     // multiple of OWord(128 bits).
     if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
+      LDBG("Marked for removal: cannot use block read/write instructions");
       return true;
     }
 
     return false;
   }
+
+  LDBG("Marked for removal: fall-through");
+  return true;
 }
@@ -705,28 +715,28 @@ class TritonIntelGPURewriteTensorPointerPass
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
-    auto usedByStoreOp = [](Value val) {
+    auto usedByLoadOrStoreOp = [](Value val) {
       return llvm::any_of(val.getUsers(), [](Operation *user) {
-        return llvm::isa<tt::StoreOp>(user);
+        return isa<tt::LoadOp, tt::StoreOp>(user);
       });
     };
 
-    auto markTensorPointerForRemoval = [this](Value val,
-                                              bool isUsedByStoreOp = false) {
-      if (tt::isTensorPointerType(val.getType())) {
-        tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
-        if (shouldRemove(makeTensorPtrOp, isUsedByStoreOp))
-          valueToRemove.insert(val);
-      }
-    };
+    auto markTensorPointerForRemoval =
+        [this](Value val, bool isUsedByLoadOrStoreOp = false) {
+          if (tt::isTensorPointerType(val.getType())) {
+            tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
+            if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
+              valueToRemove.insert(val);
+          }
+        };
 
     mod.walk([&](Operation *op) {
-      if (llvm::isa<tt::MakeTensorPtrOp>(op)) {
+      if (isa<tt::MakeTensorPtrOp>(op)) {
         Value result = op->getResult(0);
-        markTensorPointerForRemoval(result, usedByStoreOp(result));
-      } else if (llvm::isa<tt::AdvanceOp, tt::StoreOp, tt::LoadOp>(op)) {
+        markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
+      } else if (isa<tt::AdvanceOp, tt::StoreOp, tt::LoadOp>(op)) {
         markTensorPointerForRemoval(op->getOperand(0),
-                                    llvm::isa<tt::StoreOp>(op));
+                                    isa<tt::LoadOp, tt::StoreOp>(op));
       } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
         for (auto arg : forOp.getInitArgs())
           markTensorPointerForRemoval(arg);
@@ -738,11 +748,11 @@ class TritonIntelGPURewriteTensorPointerPass
 
     LLVM_DEBUG({
       if (valueToRemove.empty())
-        llvm::dbgs() << "No tensor pointer to remove\n";
+        LDBG("No tensor pointer to remove");
       else {
-        llvm::dbgs() << "Values to remove: \n";
+        LDBG("Values to remove: ");
         for (auto val : valueToRemove)
-          llvm::dbgs() << val << "\n";
+          LDBG(val);
       }
     });
 

From 7a2c62a5c70ccf03cdcfa44736f5e84c6043e641 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore" <ettore.tiotto@intel.com>
Date: Wed, 2 Oct 2024 16:57:37 +0000
Subject: [PATCH 2/3] Add lit test

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
---
 .../rewrite-tensor-pointer.mlir               | 43 +++++++++++--------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir
index 1ee5d5f87..761c82717 100644
--- a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir
+++ b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir
@@ -10,19 +10,18 @@
 #dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
 #dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
 module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
-    // CHECK: @matmul_kernel_with_block_pointers
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) {
     %c4_i32 = arith.constant 4 : i32
     %c256_i32 = arith.constant 256 : i32
     %c1_i64 = arith.constant 1 : i64
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
     %c255_i32 = arith.constant 255 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
     %0 = tt.get_program_id x : i32
-    %1 = arith.addi %arg3, %c255_i32 : i32
+    %1 = arith.addi %arg4, %c255_i32 : i32
     %2 = arith.divsi %1, %c256_i32 : i32
-    %3 = arith.addi %arg4, %c255_i32 : i32
+    %3 = arith.addi %arg5, %c255_i32 : i32
     %4 = arith.divsi %3, %c256_i32 : i32
     %5 = arith.muli %4, %c4_i32 : i32
     %6 = arith.divsi %0, %5 : i32
@@ -34,35 +33,41 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
     %12 = arith.remsi %0, %5 : i32
     %13 = arith.divsi %12, %9 : i32
     %14 = arith.muli %11, %c256_i32 : i32
-    %15 = arith.extsi %arg3 : i32 to i64
-    %16 = arith.extsi %arg5 : i32 to i64
-    %17 = arith.extsi %arg6 : i32 to i64
+    %15 = arith.extsi %arg4 : i32 to i64
+    %16 = arith.extsi %arg6 : i32 to i64
+    %17 = arith.extsi %arg7 : i32 to i64
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
     %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #dot0>>
     %19 = arith.muli %13, %c256_i32 : i32
-    %20 = arith.extsi %arg4 : i32 to i64
-    %21 = arith.extsi %arg7 : i32 to i64
+    %20 = arith.extsi %arg5 : i32 to i64
+    %21 = arith.extsi %arg8 : i32 to i64
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
-    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
+    %23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
       // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
       // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
-      %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
-      %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
+      %28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
+      %29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
      // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
       // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
       // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
-      %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas>
-      %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #dot0>>
-      %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #dot1>>
+      %30 = tt.dot %28, %29, %arg11, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas>
+      %31 = tt.advance %arg12, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #dot0>>
+      %32 = tt.advance %arg13, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #dot1>>
       scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>
     }
-    %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
-    %26 = arith.extsi %arg8 : i32 to i64
+    %25 = arith.extsi %arg9 : i32 to i64
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
+    %26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
+    // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
+    %27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>
+    %28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas>
+    %29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
+
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #[[DPAS]]>>
-    %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
+    %30 = tt.make_tensor_ptr %arg2, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
     // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>>
-    tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
+    tt.store %30, %29 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
     tt.return
   }
 }

From bd2daa449aea3dd46ca78ff98760d26c2778300c Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore" <ettore.tiotto@intel.com>
Date: Wed, 2 Oct 2024 17:47:03 +0000
Subject: [PATCH 3/3] Address code review comments

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
---
 .../lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
index 1c45599dc..801982320 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
@@ -748,11 +748,11 @@ class TritonIntelGPURewriteTensorPointerPass
 
     LLVM_DEBUG({
       if (valueToRemove.empty())
-        LDBG("No tensor pointer to remove");
+        DBGS() << "No tensor pointer to remove";
      else {
-        LDBG("Values to remove: ");
+        DBGS() << "Values to remove: ";
        for (auto val : valueToRemove)
-          LDBG(val);
+          DBGS() << val;
      }
    });
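
Note on the pitch rule these patches leave in place (an illustrative sketch only: the helper pitchSupports2DBlockIO and the example numbers below are ours, not part of the patches). shouldRemove() keeps a blocked tensor pointer only if, besides the layout checks, its pitch is divisible by an OWord (128 bits) worth of elements, i.e. ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth()). In standalone C++ that check reads roughly as:

  #include <cstdint>

  // Sketch of the OWord pitch rule from shouldRemove(): Intel 2D block
  // read/write instructions need the pitch (as we read the pass, the
  // stride of the slower-varying dimension, counted in elements) to be
  // a whole number of OWords (128 bits).
  bool pitchSupports2DBlockIO(int64_t pitchInElems, unsigned elemBitWidth) {
    const unsigned owordInElems = 128 / elemBitWidth; // 8 for f16, 4 for f32
    return pitchInElems % owordInElems == 0;
  }

For the f16 GEMM in the lit test, a pitch of, say, 4096 elements passes (4096 % 8 == 0), while a pitch of 4100 would mark the tensor pointer for removal, and the pass would rewrite it into unblocked tensor-of-pointers accesses.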