From df3d588dde006e05b85d4d37674c774e64d7b12f Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 8 Aug 2024 21:45:03 -0700 Subject: [PATCH] Erase shape_assertion ops (#18167) Just dropping shape_assertion custom call ops. --- .../Conversion/StableHLOCustomCalls.cpp | 17 ++++++++++++++++- tests/e2e/stablehlo_ops/BUILD.bazel | 4 ++++ tests/e2e/stablehlo_ops/CMakeLists.txt | 10 ++++++++++ tests/e2e/stablehlo_ops/shape_assertion.mlir | 9 +++++++++ 4 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/stablehlo_ops/shape_assertion.mlir diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp index c410f71f384e..a4118689191b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp @@ -221,6 +221,21 @@ struct HouseholderReflectorRewriter final } }; +struct ShapeAssertionDrop final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + using OpAdaptor = mlir::stablehlo::CustomCallOp::Adaptor; + + LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op, + PatternRewriter &rewriter) const final { + if (op.getCallTargetName() != "shape_assertion") { + return rewriter.notifyMatchFailure(op, "not shape_assertion"); + } + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass Definition. //===----------------------------------------------------------------------===// @@ -237,7 +252,7 @@ struct LegalizeStableHLOCustomCalls final MLIRContext *ctx = f.getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns.add(ctx); if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { signalPassFailure(); } diff --git a/tests/e2e/stablehlo_ops/BUILD.bazel b/tests/e2e/stablehlo_ops/BUILD.bazel index e10b2171477d..fb9a65d470cf 100644 --- a/tests/e2e/stablehlo_ops/BUILD.bazel +++ b/tests/e2e/stablehlo_ops/BUILD.bazel @@ -65,6 +65,7 @@ ALL_SRCS = enforce_glob( "scatter.mlir", "scatter_dynamic.mlir", "select.mlir", + "shape_assertion.mlir", "sine.mlir", "slice.mlir", "sort.mlir", @@ -169,6 +170,7 @@ iree_check_single_backend_test_suite( "scatter.mlir", "scatter_dynamic.mlir", "select.mlir", + "shape_assertion.mlir", "sine.mlir", "slice.mlir", "sort.mlir", @@ -247,6 +249,7 @@ iree_check_single_backend_test_suite( "scatter.mlir", "scatter_dynamic.mlir", "select.mlir", + "shape_assertion.mlir", "sine.mlir", "slice.mlir", "sort.mlir", @@ -381,6 +384,7 @@ iree_check_single_backend_test_suite( "scatter.mlir", "scatter_dynamic.mlir", "select.mlir", + "shape_assertion.mlir", "sine.mlir", "slice.mlir", "sort.mlir", diff --git a/tests/e2e/stablehlo_ops/CMakeLists.txt b/tests/e2e/stablehlo_ops/CMakeLists.txt index 2050c5286702..4ebc7b5c0a7f 100644 --- a/tests/e2e/stablehlo_ops/CMakeLists.txt +++ b/tests/e2e/stablehlo_ops/CMakeLists.txt @@ -65,6 +65,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -142,6 +143,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -219,6 +221,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -291,6 +294,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -366,6 +370,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -450,6 +455,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -533,6 +539,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -607,6 +614,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -680,6 +688,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" @@ -754,6 +763,7 @@ iree_check_single_backend_test_suite( "scatter.mlir" "scatter_dynamic.mlir" "select.mlir" + "shape_assertion.mlir" "sine.mlir" "slice.mlir" "sort.mlir" diff --git a/tests/e2e/stablehlo_ops/shape_assertion.mlir b/tests/e2e/stablehlo_ops/shape_assertion.mlir new file mode 100644 index 000000000000..0d5de866912a --- /dev/null +++ b/tests/e2e/stablehlo_ops/shape_assertion.mlir @@ -0,0 +1,9 @@ +func.func @tensor() { + %0 = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> + %1 = util.unfoldable_constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf32> + %4 = stablehlo.compare EQ, %0, %1, NOTYPE : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + stablehlo.custom_call @shape_assertion(%4) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<4xi1>) -> () + %result = "stablehlo.add"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + check.expect_almost_eq_const(%result, dense<[6.0, 8.0, 10.0, 12.0]> : tensor<4xf32>) : tensor<4xf32> + return +}