Skip to content

Commit

Permalink
[Codegen][GPU] Disable consumer fusion for multi use cases (iree-org#…
Browse files Browse the repository at this point in the history
…18723)

The upstream patterns for doing consumer fusion currently don't support
cases where multiple operands of the consumer come from the producer
loop. This disables fusion of such cases and sends it down the fallback
path.
  • Loading branch information
qedawkins authored Oct 8, 2024
1 parent 0f28d44 commit 0e16a89
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ struct FuseTilableForallConsumers final
}

tensor::ParallelInsertSliceOp producerSlice;
scf::ForallOp sliceOwner;
Value fusionOperand;
for (auto operand : dpsOp.getDpsInputs()) {
auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
if (!forallProducer) {
Expand All @@ -288,6 +290,8 @@ struct FuseTilableForallConsumers final
auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(user);
if (sliceOp && sliceOp.getDest() == iterArg) {
producerSlice = sliceOp;
sliceOwner = forallProducer;
fusionOperand = operand;
break;
}
}
Expand All @@ -297,7 +301,16 @@ struct FuseTilableForallConsumers final
}

if (!producerSlice) {
return failure();
return rewriter.notifyMatchFailure(tilableOp,
"no scf.forall producer to fuse into");
}

for (auto operand : tilableOp->getOperands()) {
if (operand != fusionOperand && operand.getDefiningOp() == sliceOwner) {
return rewriter.notifyMatchFailure(tilableOp,
"unimplemented: Cannot fuse op with "
"multiple uses of producer loop");
}
}

FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,35 @@ func.func @forall_hoist_unit_loop_with_fill(%3: tensor<1x128xf16>, %4: tensor<12
// CHECK: scf.forall.in_parallel
// CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]] into %[[ITER]]
// CHECK: return %[[OUTER_PARALLEL]]

// -----

func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%empty = tensor.empty() : tensor<128x128xf16>
%10:2 = scf.forall (%arg5, %arg6) in (32, 32) shared_outs(%arg7 = %empty, %arg8 = %empty) -> (tensor<128x128xf16>, tensor<128x128xf16>) {
%extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
%extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
%extracted_slice_3 = tensor.extract_slice %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
%16 = linalg.copy ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_2 : tensor<2x2xf16>) -> tensor<2x2xf16>
%17 = linalg.transpose ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_3 : tensor<2x2xf16>) permutation = [1, 0]
scf.forall.in_parallel {
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
tensor.parallel_insert_slice %17 into %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
}
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
%add = linalg.add
ins(%10#0, %10#1 : tensor<128x128xf16>, tensor<128x128xf16>)
outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
return %add : tensor<128x128xf16>
}

// CHECK-LABEL: func @no_fuse_multi_use
// CHECK: scf.forall
// CHECK: linalg.copy
// CHECK: linalg.transpose
// CHECK: scf.forall.in_parallel
// CHECK: linalg.add
// CHECK: return

0 comments on commit 0e16a89

Please sign in to comment.