From 8de98566951a5a306edd78e9422d769c53e36209 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
Date: Tue, 1 Oct 2024 16:47:29 -0700
Subject: [PATCH] [LinalgExt] Add Interfaces for implementing fusion support
 for `iree_linalg_ext.custom_op`. (#18647)

These methods allow the dispatch region formation to automatically pick
up fusion of `custom_op` with producers/consumers similar to what is
supported with `LinalgOp`s.

Signed-off-by: MaheshRavishankar
---
 .../Dialect/LinalgExt/IR/LinalgExtOps.cpp     |  29 +++++
 .../Dialect/LinalgExt/IR/LinalgExtOps.td      |   7 +-
 .../test/form_dispatch_regions.mlir           | 113 ++++++++++++++++++
 3 files changed, 148 insertions(+), 1 deletion(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 85c6da0f783d..d4d73442a5be 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1766,6 +1766,35 @@ LogicalResult CustomOp::verify() {
   return success();
 }
 
+SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
+  return llvm::map_to_vector(
+      getIndexingMaps().getValue().take_front(getNumDpsInputs()),
+      [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
+}
+
+SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
+  return llvm::map_to_vector(
+      getIndexingMaps().getValue().take_back(getNumDpsInits()),
+      [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
+}
+
+SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
+  return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
+    return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
+  });
+}
+
+LogicalResult
+CustomOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  for (auto init : getOutputs()) {
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(builder, getLoc(), init);
+    reifiedReturnShapes.emplace_back(std::move(sizes));
+  }
+  return success();
+}
+
 #define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
   void OP_NAME::getEffects(                                                    \
       SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 9de0ae58f54c..eb66a38836d5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1574,7 +1574,12 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
 // Custom tilable op
 //===---------------------------------------------------------------------===//
 
-def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op"> {
+def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
+    DeclareOpInterfaceMethods<LinalgExtInterface>,
+    DeclareOpInterfaceMethods<LinalgFusionInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>
+   ]> {
   let summary = "Custom operation for compiling with IREE";
   let description = [{
     This operation is meant to allow computation sequences that are fused at
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 62bca48be35a..3f3c91b6d5e5 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -795,3 +795,116 @@ util.func public @no_batch_mmt4d_fusion(%arg0: tensor<1x1x64x1x1xf32>,
 // CHECK-SAME:     outs(%[[INIT0]] : tensor<1x1x32x1x4xf32>)
 // CHECK:          flow.return %[[GEN]] : tensor<1x1x32x1x4xf32>
 // CHECK:          util.return %[[DISP1]] : tensor<1x1x32x1x4xf32>
+
+// -----
+
+util.func @custom_op_consumer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
+      ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
+      %1 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+          iterator_types = ["parallel", "reduction"]}
+          ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
+        ^bb1(%bb0 : f32, %bb1 : f32) :
+          %2 = arith.addf %bb0, %bb1 : f32
+          linalg.yield %2 : f32
+      } -> tensor<?xf32>
+      iree_linalg_ext.yield %1 : tensor<?xf32>
+  } -> tensor<?xf32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %4 = arith.mulf %b0, %b0 : f32
+      linalg.yield %4 :f32
+  } -> tensor<?xf32>
+  util.return %3 : tensor<?xf32>
+}
+// CHECK-LABEL: func public @custom_op_consumer_fusion
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+// CHECK: linalg.generic
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CUSTOM_OP]] :
+// CHECK: flow.return %[[GENERIC]]
+// CHECK: util.return %[[DISPATCH]]
+
+// -----
+
+util.func @custom_op_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 :f32
+  } -> tensor<?x?xf32>
+  %2 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
+      ins(%0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
+      %3 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+          iterator_types = ["parallel", "reduction"]}
+          ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
+        ^bb1(%bb0 : f32, %bb1 : f32) :
+          %4 = arith.addf %bb0, %bb1 : f32
+          linalg.yield %4 : f32
+      } -> tensor<?xf32>
+      iree_linalg_ext.yield %3 : tensor<?xf32>
+  } -> tensor<?xf32>
+  util.return %2 : tensor<?xf32>
+}
+// CHECK-LABEL: func public @custom_op_producer_fusion
+// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+// CHECK-SAME: ins(%[[GENERIC]] :
+// CHECK: flow.return %[[CUSTOM_OP]]
+// CHECK: util.return %[[DISPATCH]]
+
+// -----
+
+util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %1 = arith.mulf %b0, %b0 : f32
+      linalg.yield %1 :f32
+  } -> tensor<?x?xf32>
+  %2 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<parallel>]}
+      ins(%0, %arg1, %arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) {
+    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?x?xf32>, %b2 : tensor<?x?xf32>, %b3 : tensor<?x?xf32>, %b4 : tensor<?x?xf32>):
+      %3 = linalg.matmul ins(%b0, %b1 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%b2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %4 = linalg.matmul ins(%3, %b3 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%b4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      iree_linalg_ext.yield %4 : tensor<?x?xf32>
+  } -> tensor<?x?xf32>
+  util.return %2 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func public @custom_op_no_producer_fusion
+// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: flow.return %[[GENERIC]]
+// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.region
+// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
+// CHECK-SAME: ins(%[[DISPATCH1]],
+// CHECK: flow.return %[[CUSTOM_OP]]
+// CHECK: util.return %[[DISPATCH2]]