[LinalgExt] Add Interfaces for implementing fusion support for `iree_linalg_ext.custom_op`. (iree-org#18647)

These methods allow dispatch region formation to automatically pick up
fusion of `custom_op` with its producers and consumers, similar to what
is already supported for `LinalgOp`s.

Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar authored Oct 1, 2024
1 parent 451ef71 commit 8de9856
Showing 3 changed files with 148 additions and 1 deletion.
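For orientation, here is a minimal sketch (not part of this commit) of how dispatch-region formation code could consume the interface methods implemented below on a `custom_op`. The helper name `canFuseWithProducer` and the projected-permutation fusability condition are illustrative assumptions, not code from this change.

```cpp
// Illustrative sketch only: querying the LinalgFusionInterface-style methods
// that this commit implements on iree_linalg_ext.custom_op. The helper name
// and the fusability condition are assumptions, not code from this change.
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::iree_compiler;

static bool canFuseWithProducer(IREE::LinalgExt::CustomOp customOp,
                                unsigned inputOperandIndex) {
  // Maps for the ins() operands; the outs() maps come from
  // getIndexingMapsForResults().
  SmallVector<AffineMap> inputMaps = customOp.getIndexingMapsForOperands();
  if (inputOperandIndex >= inputMaps.size())
    return false;
  // A simple illustrative condition: only fuse along operands whose access
  // pattern is a projected permutation of the custom_op loops, so a producer
  // can be tiled consistently with the consumer's iteration space.
  AffineMap operandMap = inputMaps[inputOperandIndex];
  return operandMap && operandMap.isProjectedPermutation();
}
```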
29 changes: 29 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1766,6 +1766,35 @@ LogicalResult CustomOp::verify() {
  return success();
}

SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
  return llvm::map_to_vector(
      getIndexingMaps().getValue().take_front(getNumDpsInputs()),
      [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
}

SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
  return llvm::map_to_vector(
      getIndexingMaps().getValue().take_back(getNumDpsInits()),
      [](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
}

SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
  return llvm::map_to_vector(getIteratorTypes(), [](Attribute attr) {
    return cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue();
  });
}

LogicalResult
CustomOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  for (auto init : getOutputs()) {
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(builder, getLoc(), init);
    reifiedReturnShapes.emplace_back(std::move(sizes));
  }
  return success();
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME)                                        \
  void OP_NAME::getEffects(                                                   \
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>     \
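The `reifyResultShapes` implementation above lets generic shape-materialization utilities compute the dynamic sizes of a `custom_op`'s results from its `outs()` operands. Below is a hedged sketch of such a caller, using only the upstream `ReifyRankedShapedTypeOpInterface`; the helper name `getFirstResultSizes` is hypothetical.

```cpp
// Illustrative only: retrieving the reified result sizes of an op through
// ReifyRankedShapedTypeOpInterface, as generic IR utilities (e.g. when
// materializing destination tensors) would do.
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

static FailureOr<SmallVector<OpFoldResult>>
getFirstResultSizes(OpBuilder &builder, Operation *op) {
  auto reifyOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
  if (!reifyOp)
    return failure();
  // One vector of sizes per result; CustomOp::reifyResultShapes above fills
  // these from the sizes of the corresponding outs() operands.
  ReifiedRankedShapedTypeDims reifiedShapes;
  if (failed(reifyOp.reifyResultShapes(builder, reifiedShapes)) ||
      reifiedShapes.empty())
    return failure();
  return reifiedShapes.front();
}
```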
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1574,7 +1574,12 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
// Custom tilable op
//===---------------------------------------------------------------------===//

-def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op"> {
+def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
+    DeclareOpInterfaceMethods<LinalgFusionInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<TilingInterface,
+        ["getLoopIteratorTypes"]>
+]> {
  let summary = "Custom operation for compiling with IREE";
  let description = [{
    This operation is meant to allow computation sequences that are fused at
@@ -795,3 +795,116 @@ util.func public @no_batch_mmt4d_fusion(%arg0: tensor<1x1x64x1x1xf32>,
// CHECK-SAME: outs(%[[INIT0]] : tensor<1x1x32x1x4xf32>)
// CHECK: flow.return %[[GEN]] : tensor<1x1x32x1x4xf32>
// CHECK: util.return %[[DISP1]] : tensor<1x1x32x1x4xf32>

// -----

util.func @custom_op_consumer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
  %0 = iree_linalg_ext.custom_op {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
      ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
      %1 = linalg.generic {
          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
          iterator_types = ["parallel", "reduction"]}
          ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
        ^bb1(%bb0 : f32, %bb1 : f32) :
          %2 = arith.addf %bb0, %bb1 : f32
          linalg.yield %2 : f32
      } -> tensor<?xf32>
      iree_linalg_ext.yield %1 : tensor<?xf32>
  } -> tensor<?xf32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %4 = arith.mulf %b0, %b0 : f32
      linalg.yield %4 : f32
  } -> tensor<?xf32>
  util.return %3 : tensor<?xf32>
}
// CHECK-LABEL: func public @custom_op_consumer_fusion
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
// CHECK: linalg.generic
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[CUSTOM_OP]] :
// CHECK: flow.return %[[GENERIC]]
// CHECK: util.return %[[DISPATCH]]

// -----

util.func @custom_op_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %1 = arith.mulf %b0, %b0 : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %2 = iree_linalg_ext.custom_op {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<reduction>]}
      ins(%0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?xf32>):
      %3 = linalg.generic {
          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
          iterator_types = ["parallel", "reduction"]}
          ins(%b0 : tensor<?x?xf32>) outs(%b1 : tensor<?xf32>) {
        ^bb1(%bb0 : f32, %bb1 : f32) :
          %4 = arith.addf %bb0, %bb1 : f32
          linalg.yield %4 : f32
      } -> tensor<?xf32>
      iree_linalg_ext.yield %3 : tensor<?xf32>
  } -> tensor<?xf32>
  util.return %2 : tensor<?xf32>
}
// CHECK-LABEL: func public @custom_op_producer_fusion
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
// CHECK-SAME: ins(%[[GENERIC]] :
// CHECK: flow.return %[[CUSTOM_OP]]
// CHECK: util.return %[[DISPATCH]]

// -----

util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %1 = arith.mulf %b0, %b0 : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %2 = iree_linalg_ext.custom_op {
      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
                       affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
                       affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
                       affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
                       affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
      iterator_types = [#iree_linalg_ext.iterator_type<parallel>, #iree_linalg_ext.iterator_type<parallel>]}
      ins(%0, %arg1, %arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg4 : tensor<?x?xf32>) {
    ^bb0(%b0 : tensor<?x?xf32>, %b1 : tensor<?x?xf32>, %b2 : tensor<?x?xf32>, %b3 : tensor<?x?xf32>, %b4 : tensor<?x?xf32>):
      %3 = linalg.matmul ins(%b0, %b1 : tensor<?x?xf32>, tensor<?x?xf32>)
          outs(%b2 : tensor<?x?xf32>) -> tensor<?x?xf32>
      %4 = linalg.matmul ins(%3, %b3 : tensor<?x?xf32>, tensor<?x?xf32>)
          outs(%b4 : tensor<?x?xf32>) -> tensor<?x?xf32>
      iree_linalg_ext.yield %4 : tensor<?x?xf32>
  } -> tensor<?x?xf32>
  util.return %2 : tensor<?x?xf32>
}
// CHECK-LABEL: func public @custom_op_no_producer_fusion
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: flow.return %[[GENERIC]]
// CHECK: %[[DISPATCH2:.+]] = flow.dispatch.region
// CHECK: %[[CUSTOM_OP:.+]] = iree_linalg_ext.custom_op
// CHECK-SAME: ins(%[[DISPATCH1]],
// CHECK: flow.return %[[CUSTOM_OP]]
// CHECK: util.return %[[DISPATCH2]]
