[CPU] Add support for unpack ukernel preparation (iree-org#17498)
The unpack ukernel only works for 4D cases. Similar to what is done for
pack, this revision adds support for converting 5D unpack to 4D unpack.
pashu123 authored Jun 6, 2024
1 parent 5404ad7 commit 58feff3
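
As a sketch of the rewrite this commit enables (shapes here are hypothetical, chosen to mirror the lit tests below): a 5D tensor.unpack whose only non-packed outer dim is unit-sized is rank-reduced to the 4D form the ukernel accepts.

// Before: 5D unpack with a unit leading dim (hypothetical shapes).
%unpack = tensor.unpack %src outer_dims_perm = [0, 1, 2]
    inner_dims_pos = [1, 2] inner_tiles = [16, 16]
    into %dst : tensor<1x8x64x16x16xf32> -> tensor<1x128x1024xf32>

// After: drop the unit dim on both operands, unpack in 4D, and insert the
// result back into the original 3D destination.
%src4d = tensor.extract_slice %src[0, 0, 0, 0, 0] [1, 8, 64, 16, 16] [1, 1, 1, 1, 1]
    : tensor<1x8x64x16x16xf32> to tensor<8x64x16x16xf32>
%dst2d = tensor.extract_slice %dst[0, 0, 0] [1, 128, 1024] [1, 1, 1]
    : tensor<1x128x1024xf32> to tensor<128x1024xf32>
%unpack4d = tensor.unpack %src4d outer_dims_perm = [0, 1]
    inner_dims_pos = [0, 1] inner_tiles = [16, 16]
    into %dst2d : tensor<8x64x16x16xf32> -> tensor<128x1024xf32>
%res = tensor.insert_slice %unpack4d into %dst[0, 0, 0] [1, 128, 1024] [1, 1, 1]
    : tensor<128x1024xf32> into tensor<1x128x1024xf32>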
Showing 2 changed files with 185 additions and 1 deletion.
118 changes: 118 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
@@ -78,6 +78,41 @@ static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter,
  });
}

static void tileNonPackedDimsFor5DPUnpackOps(RewriterBase &rewriter,
                                             FunctionOpInterface funcOp) {
  funcOp.walk([&](tensor::UnPackOp unpackOp) {
    if (unpackOp.getSourceRank() != 5 || unpackOp.getDestRank() != 3) {
      return;
    }

    // Tile size 1 for the non-packed outer dims; 0 (i.e. no tiling) for the
    // dims carrying inner tiles.
    OpFoldResult zero = rewriter.getIndexAttr(0);
    OpFoldResult one = rewriter.getIndexAttr(1);
    SmallVector<OpFoldResult> tileSizes(unpackOp.getDestRank(), one);
    for (auto dim : unpackOp.getInnerDimsPos()) {
      tileSizes[dim] = zero;
    }

    // Skip the tiling if the size is already 1.
    RankedTensorType destType = unpackOp.getDestType();
    for (auto [idx, val] : llvm::enumerate(tileSizes)) {
      if (val && destType.getDimSize(idx) == 1)
        return;
    }

    auto tilingInterfaceOp = cast<TilingInterface>(unpackOp.getOperation());
    auto options = scf::SCFTilingOptions().setTileSizes(tileSizes);
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty()) {
      options.setInterchange(outerDimsPerm);
    }
    FailureOr<scf::SCFTilingResult> tilingResult =
        scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
    assert(succeeded(tilingResult));
    rewriter.replaceOp(unpackOp, tilingResult->replacements);
  });
}
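
For intuition, a small worked example of the tile sizes this helper picks (layout assumed to match the first lit test below):

// dest rank = 3, inner_dims_pos = [1, 2]:
//   tileSizes starts as [1, 1, 1]; the packed dims get zeroed -> [1, 0, 0]
// tileUsingSCF then emits a single scf.for over dim 0 (step 1), so every
// iteration yields a 5D source slice with a unit dim, which the pattern
// below can rank-reduce to 4D.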

/// Returns true if:
/// 1. `genericOp` is element-wise with all identity indexing maps
/// 2. `genericOp` has only one input and one output with the same shape
@@ -278,6 +313,84 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern<tensor::PackOp> {
  }
};

struct Convert5DUnPackto4DUnPackPattern
    : public OpRewritePattern<tensor::UnPackOp> {
  using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    if (unpackOp.getSourceRank() != 5 || unpackOp.getDestRank() != 3) {
      return failure();
    }

    llvm::SmallDenseSet<int64_t> s(unpackOp.getInnerDimsPos().begin(),
                                   unpackOp.getInnerDimsPos().end());

    SmallVector<int64_t> seqOrOuterDimsPerm =
        llvm::to_vector(llvm::seq<int64_t>(0, unpackOp.getDestRank()));
    if (!unpackOp.getOuterDimsPerm().empty()) {
      applyPermutationToVector(seqOrOuterDimsPerm, unpackOp.getOuterDimsPerm());
    }

    // Find the first non-packed dim: `srcPos` is its index in the source,
    // `destPos` its index in the destination.
    int64_t srcPos = 0;
    int64_t destPos = 0;
    for (auto [idx, val] : llvm::enumerate(seqOrOuterDimsPerm)) {
      if (s.contains(val))
        continue;
      srcPos = idx;
      destPos = val;
      break;
    }

    if (unpackOp.getSourceType().getDimSize(srcPos) != 1) {
      return rewriter.notifyMatchFailure(unpackOp, "srcPos != 1");
    }

    if (unpackOp.getDestType().getDimSize(destPos) != 1) {
      return rewriter.notifyMatchFailure(unpackOp, "destPos != 1");
    }

    // Calculate the new innerDimsPos and outerDimsPerm after removal of the
    // unit non-packed/tiled dimension.
    SmallVector<int64_t> newInnerDimsPos(unpackOp.getInnerDimsPos());
    for (auto &val : newInnerDimsPos) {
      assert(val != destPos);
      if (val > destPos)
        val--;
    }

    SmallVector<int64_t> newOuterDimsPerm(unpackOp.getOuterDimsPerm());
    if (!newOuterDimsPerm.empty()) {
      newOuterDimsPerm.erase(newOuterDimsPerm.begin() + srcPos);
      for (auto &val : newOuterDimsPerm) {
        if (val > destPos)
          val--;
      }
    }

    Location loc = unpackOp.getLoc();
    auto reducedSrcType =
        RankedTensorType::Builder(unpackOp.getSourceType()).dropDim(srcPos);
    auto reducedSrc = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, unpackOp.getSource(), reducedSrcType);

    auto reducedDestType =
        RankedTensorType::Builder(unpackOp.getDestType()).dropDim(destPos);
    auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, unpackOp.getDest(), reducedDestType);

    auto newUnpackOp = rewriter.create<tensor::UnPackOp>(
        loc, reducedSrc, reducedDest, newInnerDimsPos, unpackOp.getMixedTiles(),
        newOuterDimsPerm);

    auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, newUnpackOp.getResult(), unpackOp.getDest());
    rewriter.replaceOp(unpackOp, insertSliceOp);
    return success();
  }
};
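
To make the index bookkeeping concrete, here is a trace of the values for the transposed lit test below (outer_dims_perm = [2, 1, 0], inner_dims_pos = [1, 2]):

// seqOrOuterDimsPerm = applyPermutation([0, 1, 2], [2, 1, 0]) = [2, 1, 0]
// First entry not in innerDimsPos {1, 2}: idx = 2, val = 0,
//   so srcPos = 2 (unit dim in the source) and destPos = 0 (in the dest).
// newInnerDimsPos: [1, 2] -> [0, 1]          (values > destPos shift down)
// newOuterDimsPerm: erase index 2 of [2, 1, 0] -> [2, 1] -> [1, 0],
// matching the "outer_dims_perm = [1, 0]" in the 4D unpack CHECKed below.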

struct CPUPrepareUkernelsPass
    : public CPUPrepareUkernelsBase<CPUPrepareUkernelsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
@@ -305,6 +418,10 @@ void CPUPrepareUkernelsPass::runOnOperation() {
    tileNonPackedDimsFor3DPackOps(rewriter, funcOp);
    patterns.add<Convert3DPackto2DPackPattern>(ctx);
  }
  if (hasUkernel(targetAttr, "unpack")) {
    tileNonPackedDimsFor5DPUnpackOps(rewriter, funcOp);
    patterns.add<Convert5DUnPackto4DUnPackPattern>(ctx);
  }

  // Canonicalize extract and insert slice ops created during the conversion.
  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
@@ -313,6 +430,7 @@ void CPUPrepareUkernelsPass::runOnOperation() {
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
67 changes: 67 additions & 1 deletion
@@ -136,7 +136,6 @@ func.func @batch_mmt4d_with_fill_batch_dim(%arg0: tensor<12x10x32x8x1xf32>, %arg
// CHECK: }
// CHECK: return %[[TILED_RES]] : tensor<12x10x80x8x4xf32>


// -----

func.func @batch_mmt4d_with_lowering_config(%arg0: tensor<12x4x64x8x1xf16>, %arg1: tensor<12x4x64x8x1xf16>, %arg2: tensor<12x4x4x8x8xf16>) -> tensor<12x4x4x8x8xf16> attributes {
@@ -208,3 +207,70 @@ func.func @do_not_decompose_pack(%arg0: tensor<1x16384x512xbf16>, %arg1: tensor<
}
// CHECK-LABEL: func.func @do_not_decompose_pack
// CHECK: tensor.pack {{.+}} : tensor<1x16384x512xbf16> -> tensor<1x1024x256x16x2xbf16>

// -----

func.func @unpack_without_transpose(%arg0: tensor<1828x8x64x16x16xf32>) -> tensor<1828x128x1024xf32> attributes {
  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "unpack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
} {
  %6 = tensor.empty() : tensor<1828x128x1024xf32>
  %unpack = tensor.unpack %arg0
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [16, 16]
      into %6 : tensor<1828x8x64x16x16xf32> -> tensor<1828x128x1024xf32>
  return %unpack : tensor<1828x128x1024xf32>
}
// CHECK-LABEL: func.func @unpack_without_transpose(
// CHECK: %[[SRC:.*]]: tensor<1828x8x64x16x16xf32>) -> tensor<1828x128x1024xf32>
// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
// CHECK: %[[CST_1828:.*]] = arith.constant 1828 : index
// CHECK: %[[CST_0:.*]] = arith.constant 0 : index
// CHECK: %[[DEST:.*]] = tensor.empty() : tensor<1828x128x1024xf32>
// CHECK: %[[RES:.*]] = scf.for %[[ITER:.*]] = %[[CST_0]] to %[[CST_1828]]
// CHECK-SAME: step %[[CST_1]] iter_args(%[[ITER_ARG:.*]] = %[[DEST]]) -> (tensor<1828x128x1024xf32>) {
// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][%[[ITER]], 0, 0, 0, 0] [1, 8, 64, 16, 16] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<1828x8x64x16x16xf32> to tensor<8x64x16x16xf32>
// CHECK: %[[DEST_SLICE:.*]] = tensor.extract_slice %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 1024] [1, 1, 1]
// CHECK-SAME: : tensor<1828x128x1024xf32> to tensor<128x1024xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[SRC_SLICE]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16]
// CHECK-SAME: into %[[DEST_SLICE]] : tensor<8x64x16x16xf32> -> tensor<128x1024xf32>
// CHECK: %[[NEW_ITER_ARG:.*]] = tensor.insert_slice %[[UNPACK]] into %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 1024] [1, 1, 1]
// CHECK-SAME: : tensor<128x1024xf32> into tensor<1828x128x1024xf32>
// CHECK: scf.yield %[[NEW_ITER_ARG]] : tensor<1828x128x1024xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<1828x128x1024xf32>
// CHECK: }

// -----

func.func @unpack_outer_dim_transpose(%arg0: tensor<4x8x29241x16x16xf32>) -> tensor<29241x128x64xf32> attributes {
  hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {ukernels = "unpack", target_triple="x86_64-xyz-xyz", cpu_features=""}>
} {
  %cst = arith.constant 0.000000e+00 : bf16
  %4 = tensor.empty() : tensor<29241x128x64xf32>
  %unpack = tensor.unpack %arg0 outer_dims_perm = [2, 1, 0] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %4 : tensor<4x8x29241x16x16xf32> -> tensor<29241x128x64xf32>
  return %unpack : tensor<29241x128x64xf32>
}
// CHECK-LABEL: func.func @unpack_outer_dim_transpose(
// CHECK: %[[SRC:.*]]: tensor<4x8x29241x16x16xf32>) -> tensor<29241x128x64xf32>
// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
// CHECK: %[[CST_29K:.*]] = arith.constant 29241 : index
// CHECK: %[[CST_0:.*]] = arith.constant 0 : index
// CHECK: %[[DEST:.*]] = tensor.empty() : tensor<29241x128x64xf32>
// CHECK: %[[RES:.*]] = scf.for %[[ITER:.*]] = %[[CST_0]] to %[[CST_29K]] step %[[CST_1]]
// CHECK-SAME: iter_args(%[[ITER_ARG:.*]] = %[[DEST]]) -> (tensor<29241x128x64xf32>) {
// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, 0, %[[ITER]], 0, 0] [4, 8, 1, 16, 16] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<4x8x29241x16x16xf32> to tensor<4x8x16x16xf32>
// CHECK: %[[DEST_SLICE:.*]] = tensor.extract_slice %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 64] [1, 1, 1]
// CHECK-SAME: : tensor<29241x128x64xf32> to tensor<128x64xf32>
// CHECK: %[[UNPACK:.*]] = tensor.unpack %[[SRC_SLICE]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16]
// CHECK-SAME: into %[[DEST_SLICE]] : tensor<4x8x16x16xf32> -> tensor<128x64xf32>
// CHECK: %[[NEW_ITER_ARG:.*]] = tensor.insert_slice %[[UNPACK]] into %[[ITER_ARG]][%[[ITER]], 0, 0] [1, 128, 64] [1, 1, 1]
// CHECK-SAME: : tensor<128x64xf32> into tensor<29241x128x64xf32>
// CHECK: scf.yield %[[NEW_ITER_ARG]] : tensor<29241x128x64xf32>
// CHECK: }
// CHECK: return %[[RES]] : tensor<29241x128x64xf32>
// CHECK: }
