diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
index 3f60799bdf8a..ac51f211234f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
@@ -78,6 +78,44 @@ static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter,
   });
 }
 
+// Tiles the non-packed outer dims of rank-5 tensor.unpack ops to size 1 so
+// that the ops can later be rank-reduced to the 4D form expected by the
+// unpack ukernel.
+static void tileNonPackedDimsFor5DUnPackOps(RewriterBase &rewriter,
+                                            FunctionOpInterface funcOp) {
+  funcOp.walk([&](tensor::UnPackOp unpackOp) {
+    if (unpackOp.getSourceRank() != 5 || unpackOp.getDestRank() != 3) {
+      return;
+    }
+
+    OpFoldResult zero = rewriter.getIndexAttr(0);
+    OpFoldResult one = rewriter.getIndexAttr(1);
+    // Tile every non-packed dim to 1; leave the packed dims untiled.
+    SmallVector<OpFoldResult> tileSizes(unpackOp.getDestRank(), one);
+    for (auto dim : unpackOp.getInnerDimsPos()) {
+      tileSizes[dim] = zero;
+    }
+
+    // Skip the tiling if the non-packed dim size is already 1.
+    RankedTensorType destType = unpackOp.getDestType();
+    for (auto [idx, val] : llvm::enumerate(tileSizes)) {
+      if (isConstantIntValue(val, 1) && destType.getDimSize(idx) == 1)
+        return;
+    }
+
+    auto tilingInterfaceOp = cast<TilingInterface>(unpackOp.getOperation());
+    auto options = scf::SCFTilingOptions().setTileSizes(tileSizes);
+    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
+    if (!outerDimsPerm.empty()) {
+      options.setInterchange(outerDimsPerm);
+    }
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
+    assert(succeeded(tilingResult));
+    rewriter.replaceOp(unpackOp, tilingResult->replacements);
+  });
+}
+
 /// Returns true if:
 /// 1. `genericOp` is element-wise with all identity indexing maps
 /// 2. `genericOp` has only one input and one output with the same shape
@@ -278,6 +316,90 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern<tensor::PackOp> {
   }
 };
 
+// Rewrites a 5D -> 3D tensor.unpack op into a 4D -> 2D unpack (plus
+// rank-reducing slices) when its only non-packed outer dim has size 1.
+struct Convert5DUnPackto4DUnPackPattern
+    : public OpRewritePattern<tensor::UnPackOp> {
+  using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
+                                PatternRewriter &rewriter) const override {
+    if (unpackOp.getSourceRank() != 5 || unpackOp.getDestRank() != 3) {
+      return failure();
+    }
+
+    llvm::SmallDenseSet<int64_t> s(unpackOp.getInnerDimsPos().begin(),
+                                   unpackOp.getInnerDimsPos().end());
+
+    SmallVector<int64_t> seqOrOuterDimsPerm =
+        llvm::to_vector(llvm::seq<int64_t>(0, unpackOp.getDestRank()));
+    if (!unpackOp.getOuterDimsPerm().empty()) {
+      applyPermutationToVector(seqOrOuterDimsPerm, unpackOp.getOuterDimsPerm());
+    }
+
+    // Find the unit non-packed dim: srcPos is its position among the
+    // (permuted) source outer dims; destPos is the matching dest dim.
+    int64_t srcPos = 0;
+    int64_t destPos = 0;
+    for (auto [idx, val] : llvm::enumerate(seqOrOuterDimsPerm)) {
+      if (s.contains(val))
+        continue;
+      srcPos = idx;
+      destPos = val;
+      break;
+    }
+
+    if (unpackOp.getSourceType().getDimSize(srcPos) != 1) {
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expected unit source dim at srcPos");
+    }
+    if (unpackOp.getDestType().getDimSize(destPos) != 1) {
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expected unit dest dim at destPos");
+    }
+
+    // Calculate the new innerDimsPos and outerDimsPerm after removal of the
+    // unit non-packed/tiled dimension.
+    SmallVector<int64_t> newInnerDimsPos =
+        llvm::to_vector(unpackOp.getInnerDimsPos());
+    for (auto &val : newInnerDimsPos) {
+      assert(val != destPos);
+      if (val > destPos)
+        val--;
+    }
+
+    SmallVector<int64_t> newOuterDimsPerm =
+        llvm::to_vector(unpackOp.getOuterDimsPerm());
+    if (!newOuterDimsPerm.empty()) {
+      newOuterDimsPerm.erase(newOuterDimsPerm.begin() + srcPos);
+      for (auto &val : newOuterDimsPerm) {
+        if (val > destPos)
+          val--;
+      }
+    }
+
+    Location loc = unpackOp.getLoc();
+    RankedTensorType reducedSrcType =
+        RankedTensorType::Builder(unpackOp.getSourceType()).dropDim(srcPos);
+    auto reducedSrc = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, unpackOp.getSource(), reducedSrcType);
+
+    RankedTensorType reducedDestType =
+        RankedTensorType::Builder(unpackOp.getDestType()).dropDim(destPos);
+    auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, unpackOp.getDest(), reducedDestType);
+
+    auto newUnpackOp = rewriter.create<tensor::UnPackOp>(
+        loc, reducedSrc, reducedDest, newInnerDimsPos,
+        unpackOp.getMixedTiles(), newOuterDimsPerm);
+
+    auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, newUnpackOp.getResult(), unpackOp.getDest());
+    rewriter.replaceOp(unpackOp, insertSliceOp);
+    return success();
+  }
+};
+
 struct CPUPrepareUkernelsPass
     : public CPUPrepareUkernelsBase<CPUPrepareUkernelsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -305,6 +427,10 @@ void CPUPrepareUkernelsPass::runOnOperation() {
     tileNonPackedDimsFor3DPackOps(rewriter, funcOp);
     patterns.add<Convert3DPackto2DPackPattern>(ctx);
   }
+  if (hasUkernel(targetAttr, "unpack")) {
+    tileNonPackedDimsFor5DUnPackOps(rewriter, funcOp);
+    patterns.add<Convert5DUnPackto4DUnPackPattern>(ctx);
+  }
 
   // Canonicalize extract and insert slice ops created during the conversion.
   tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
@@ -313,6 +439,7 @@ void CPUPrepareUkernelsPass::runOnOperation() {
   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
+  tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::populateFoldTensorEmptyPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/prepare_ukernels.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/prepare_ukernels.mlir
index ba294796cd37..199596d99833 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/prepare_ukernels.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/prepare_ukernels.mlir
@@ -136,7 +136,6 @@ func.func @batch_mmt4d_with_fill_batch_dim(%arg0: tensor<12x10x32x8x1xf32>, %arg
 // CHECK:         }
 // CHECK:         return %[[TILED_RES]] : tensor<12x10x80x8x4xf32>
 
-
 // -----
 
 func.func @batch_mmt4d_with_lowering_config(%arg0: tensor<12x4x64x8x1xf16>, %arg1: tensor<12x4x64x8x1xf16>, %arg2: tensor<12x4x4x8x8xf16>) -> tensor<12x4x4x8x8xf16> attributes {
@@ -208,3 +207,69 @@ func.func @do_not_decompose_pack(%arg0: tensor<1x16384x512xbf16>, %arg1: tensor<
 }
 // CHECK-LABEL: func.func @do_not_decompose_pack
 // CHECK:         tensor.pack {{.+}} : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>
+
+// -----
+
+func.func @unpack_without_transpose(%arg0: tensor<1828x8x64x16x16xf32>) -> tensor<1828x128x1024xf32> attributes {
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "unpack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
+} {
+  %6 = tensor.empty() : tensor<1828x128x1024xf32>
+  %unpack = tensor.unpack %arg0
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %6 : tensor<1828x8x64x16x16xf32> -> tensor<1828x128x1024xf32>
+  return %unpack : tensor<1828x128x1024xf32>
+}
+// CHECK-LABEL: func.func @unpack_without_transpose(
+// CHECK-SAME:    %[[SRC:.*]]: tensor<1828x8x64x16x16xf32>) -> tensor<1828x128x1024xf32>
+// CHECK:         %[[CST_1:.*]] = arith.constant 1 : index
+// CHECK:         %[[CST_1828:.*]] = arith.constant 1828 : index
+// CHECK:         %[[CST_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[DEST:.*]] = tensor.empty() : tensor<1828x128x1024xf32>
+// CHECK:         %[[RES:.*]] = scf.for %[[ITER:.*]] = %[[CST_0]] to %[[CST_1828]]
+// CHECK-SAME:      step %[[CST_1]] iter_args(%[[ITER_ARG:.*]] = %[[DEST]]) -> (tensor<1828x128x1024xf32>) {
+// CHECK:           %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][%[[ITER]], 0, 0, 0, 0] [1, 8, 64, 16, 16] [1, 1, 1, 1, 1]
+// CHECK-SAME:        : tensor<1828x8x64x16x16xf32> to tensor<8x64x16x16xf32>
+// CHECK:           %[[DEST_SLICE:.*]] = tensor.extract_slice %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 1024] [1, 1, 1]
+// CHECK-SAME:        : tensor<1828x128x1024xf32> to tensor<128x1024xf32>
+// CHECK:           %[[UNPACK:.*]] = tensor.unpack %[[SRC_SLICE]]
+// CHECK-SAME:        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK-SAME:        into %[[DEST_SLICE]] : tensor<8x64x16x16xf32> -> tensor<128x1024xf32>
+// CHECK:           %[[NEW_ITER_ARG:.*]] = tensor.insert_slice %[[UNPACK]] into %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 1024] [1, 1, 1]
+// CHECK-SAME:        : tensor<128x1024xf32> into tensor<1828x128x1024xf32>
+// CHECK:           scf.yield %[[NEW_ITER_ARG]] : tensor<1828x128x1024xf32>
+// CHECK:         }
+// CHECK:         return %[[RES]] : tensor<1828x128x1024xf32>
+// CHECK:       }
+
+// -----
+
+func.func @unpack_outer_dim_transpose(%arg0: tensor<4x8x29241x16x16xf32>) -> tensor<29241x128x64xf32> attributes {
+  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "unpack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
+} {
+  %4 = tensor.empty() : tensor<29241x128x64xf32>
+  %unpack = tensor.unpack %arg0 outer_dims_perm = [2, 1, 0] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %4 : tensor<4x8x29241x16x16xf32> -> tensor<29241x128x64xf32>
+  return %unpack : tensor<29241x128x64xf32>
+}
+// CHECK-LABEL: func.func @unpack_outer_dim_transpose(
+// CHECK-SAME:    %[[SRC:.*]]: tensor<4x8x29241x16x16xf32>) -> tensor<29241x128x64xf32>
+// CHECK:         %[[CST_1:.*]] = arith.constant 1 : index
+// CHECK:         %[[CST_29K:.*]] = arith.constant 29241 : index
+// CHECK:         %[[CST_0:.*]] = arith.constant 0 : index
+// CHECK:         %[[DEST:.*]] = tensor.empty() : tensor<29241x128x64xf32>
+// CHECK:         %[[RES:.*]] = scf.for %[[ITER:.*]] = %[[CST_0]] to %[[CST_29K]] step %[[CST_1]]
+// CHECK-SAME:      iter_args(%[[ITER_ARG:.*]] = %[[DEST]]) -> (tensor<29241x128x64xf32>) {
+// CHECK:           %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, 0, %[[ITER]], 0, 0] [4, 8, 1, 16, 16] [1, 1, 1, 1, 1]
+// CHECK-SAME:        : tensor<4x8x29241x16x16xf32> to tensor<4x8x16x16xf32>
+// CHECK:           %[[DEST_SLICE:.*]] = tensor.extract_slice %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 64] [1, 1, 1]
+// CHECK-SAME:        : tensor<29241x128x64xf32> to tensor<128x64xf32>
+// CHECK:           %[[UNPACK:.*]] = tensor.unpack %[[SRC_SLICE]]
+// CHECK-SAME:        outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK-SAME:        into %[[DEST_SLICE]] : tensor<4x8x16x16xf32> -> tensor<128x64xf32>
+// CHECK:           %[[NEW_ITER_ARG:.*]] = tensor.insert_slice %[[UNPACK]] into %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 64] [1, 1, 1]
+// CHECK-SAME:        : tensor<128x64xf32> into tensor<29241x128x64xf32>
+// CHECK:           scf.yield %[[NEW_ITER_ARG]] : tensor<29241x128x64xf32>
+// CHECK:         }
+// CHECK:         return %[[RES]] : tensor<29241x128x64xf32>
+// CHECK:       }
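
A note for reviewers: the index remapping in `Convert5DUnPackto4DUnPackPattern` is easy to check by hand. Below is a minimal standalone sketch (plain C++17, not part of the patch; the literal values are copied from the `@unpack_outer_dim_transpose` test) that mirrors the `newInnerDimsPos`/`newOuterDimsPerm` computation and reproduces the attributes the FileCheck lines expect on the rank-reduced unpack.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Inputs mirroring @unpack_outer_dim_transpose: outer_dims_perm = [2, 1, 0],
  // inner_dims_pos = [1, 2]. The unit non-packed dim is found at srcPos = 2
  // (position in the permuted source outer dims) and destPos = 0 (dest dim).
  const int64_t srcPos = 2;
  const int64_t destPos = 0;

  // New inner_dims_pos: renumber the dims above the dropped dest dim.
  std::vector<int64_t> newInnerDimsPos = {1, 2};
  for (int64_t &val : newInnerDimsPos) {
    assert(val != destPos); // destPos is non-packed by construction.
    if (val > destPos)
      --val;
  }

  // New outer_dims_perm: erase the unit dim's entry, then renumber.
  std::vector<int64_t> newOuterDimsPerm = {2, 1, 0};
  newOuterDimsPerm.erase(newOuterDimsPerm.begin() + srcPos);
  for (int64_t &val : newOuterDimsPerm) {
    if (val > destPos)
      --val;
  }

  // Matches the rank-reduced unpack the test expects:
  // outer_dims_perm = [1, 0] inner_dims_pos = [0, 1].
  assert((newInnerDimsPos == std::vector<int64_t>{0, 1}));
  assert((newOuterDimsPerm == std::vector<int64_t>{1, 0}));
  return 0;
}
```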