From 0655e1d0cf713c44e304593b159e95b37b4cec26 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 21 Mar 2024 10:21:18 -0700 Subject: [PATCH] Integrate CHLO->StableHLO lowerings in XLA PiperOrigin-RevId: 617886759 --- BUILD | 26 +- mhlo/transforms/CMakeLists.txt | 1 - .../chlo_legalize_to_hlo.cc | 1976 ----------------- .../chlo_legalize_to_hlo_pass.cc | 111 +- .../chlo_legalize_to_hlo_patterns.td | 337 +-- mhlo/transforms/mhlo_passes.td | 37 +- mhlo/transforms/passes.h | 8 - mhlo/transforms/rewriters.h | 25 +- .../transforms/ChloRecomposeOps.cpp | 10 + .../experimental/transforms/Passes.h | 2 + .../chlo/chlo_legalize_to_stablehlo.mlir | 132 +- .../transforms/ChloLegalizeToStablehlo.cpp | 5 +- .../chlo/chlo_legalize_to_hlo_broadcasts.mlir | 345 --- .../chlo_legalize_to_hlo_no_broadcasts.mlir | 11 - tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir | 23 +- .../chlo/chlo_legalize_to_mhlo_basis_ops.mlir | 276 --- tests/Dialect/mhlo/lower-complex.mlir | 2 +- 17 files changed, 184 insertions(+), 3143 deletions(-) delete mode 100644 mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc delete mode 100644 tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir delete mode 100644 tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir delete mode 100644 tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir diff --git a/BUILD b/BUILD index fa2bc4bd4..f1c30e4e0 100644 --- a/BUILD +++ b/BUILD @@ -554,7 +554,7 @@ cc_library( ], strip_include_prefix = ".", deps = [ - ":chlo_legalize_to_hlo", + ":chlo_legalize_to_hlo_inc_gen", ":hlo_legalize_to_stablehlo", ":legalize_to_linalg_utils", ":legalize_to_standard_inc_gen", @@ -574,6 +574,7 @@ cc_library( "//stablehlo:base", "//stablehlo:chlo_ops", "//stablehlo:stablehlo_ops", + "//stablehlo:stablehlo_passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", @@ -860,28 +861,6 @@ cc_library( ], ) -cc_library( - name = "chlo_legalize_to_hlo", - srcs = ["mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc"], - hdrs = ["mhlo/transforms/rewriters.h"], - strip_include_prefix = ".", - deps = [ - ":chlo_legalize_to_hlo_inc_gen", - ":map_chlo_to_hlo_op", - ":mlir_hlo", - "//stablehlo:broadcast_utils", - "//stablehlo:chlo_ops", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - ], -) - gentbl_cc_library( name = "chlo_legalize_to_hlo_inc_gen", strip_include_prefix = "mhlo/transforms", @@ -950,7 +929,6 @@ cc_library( ], strip_include_prefix = ".", deps = [ - ":chlo_legalize_to_hlo", ":deallocation_passes", ":deallocation_passes_inc_gen", ":lhlo", diff --git a/mhlo/transforms/CMakeLists.txt b/mhlo/transforms/CMakeLists.txt index ba8ab2e65..b7d447620 100644 --- a/mhlo/transforms/CMakeLists.txt +++ b/mhlo/transforms/CMakeLists.txt @@ -192,7 +192,6 @@ add_mlir_library(MhloToStandard ) add_mlir_library(ChloPasses - chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc DEPENDS diff --git a/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc b/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc deleted file mode 100644 index 958d137a4..000000000 --- a/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc +++ /dev/null @@ -1,1976 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Enable the use of M_* math constants. -// NOTE: this must be first in the file to ensure that if cmath is transitively -// included by any other header it has the define set on first processing. -// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants -#define _USE_MATH_DEFINES -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/map_chlo_to_hlo_op.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/BroadcastUtils.h" -#include "stablehlo/dialect/ChloOps.h" -#include "utils/hlo_utils.h" - -namespace mlir { -namespace chlo { -namespace { - -struct ConvertConstantLikeOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ConstantLikeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = op.getType().cast(); - - // Unranked uses are not supported. - if (!resultTy.hasRank()) return failure(); - - // Lower to MHLO constant if statically shaped. - if (resultTy.hasStaticShape()) { - auto complexAttr = op.getValue().dyn_cast(); - auto attr = complexAttr - ? DenseElementsAttr::get(resultTy, complexAttr.getValue()) - : DenseElementsAttr::get(resultTy, op.getValue()); - rewriter.replaceOpWithNewOp(op, attr); - return success(); - } - - // Lower to broadcasted constant. 
- auto loc = op.getLoc(); - Value constant = rewriter.create(loc, op.getValue()); - Value shape = rewriter.create(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp( - op, resultTy, constant, shape, rewriter.getI64TensorAttr({})); - return success(); - } -}; - -template -Value materializeChebyshevPolynomialApproximation( - ConversionPatternRewriter &rewriter, Location loc, Value x, - ArrayRef coefficients) { - Value b0 = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value b1 = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value b2 = chlo::getConstantLike(rewriter, loc, 0.0, x); - for (FTy c : coefficients) { - b2 = b1; - b1 = b0; - b0 = rewriter.create(loc, x.getType(), x, b1); - b0 = rewriter.create(loc, x.getType(), b0, b2); - b0 = rewriter.create( - loc, x.getType(), b0, chlo::getConstantLike(rewriter, loc, c, x)); - } - Value result = rewriter.create(loc, x.getType(), b0, b2); - result = rewriter.create( - loc, x.getType(), result, chlo::getConstantLike(rewriter, loc, 0.5, x)); - return result; -} - -template -Value materializeBesselI1eApproximation(ConversionPatternRewriter &rewriter, - Location loc, Value x, - ArrayRef kI1eCoeffsA, - ArrayRef kI1eCoeffsB) { - Value z = rewriter.create(loc, x); - Value half = chlo::getConstantLike(rewriter, loc, 0.5, x); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value thirtyTwo = chlo::getConstantLike(rewriter, loc, 32.0, x); - Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x); - - Value tmp = rewriter.create(loc, half, z); - tmp = rewriter.create(loc, tmp, two); - - Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp, - kI1eCoeffsA); - xLe8 = rewriter.create(loc, z, xLe8); - - tmp = rewriter.create(loc, thirtyTwo, z); - tmp = rewriter.create(loc, tmp, two); - Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp, - kI1eCoeffsB); - xGt8 = rewriter.create(loc, xGt8, - rewriter.create(loc, z)); - - Value isLe8 = rewriter.create(loc, z, eight, - mhlo::ComparisonDirection::LE); - - Value select = rewriter.create(loc, isLe8, xLe8, xGt8); - return rewriter.create( - loc, rewriter.create(loc, x), select); -} - -Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const float kI1eCoeffsA[] = { - 9.38153738649577178388E-9f, -4.44505912879632808065E-8f, - 2.00329475355213526229E-7f, -8.56872026469545474066E-7f, - 3.47025130813767847674E-6f, -1.32731636560394358279E-5f, - 4.78156510755005422638E-5f, -1.61760815825896745588E-4f, - 5.12285956168575772895E-4f, -1.51357245063125314899E-3f, - 4.15642294431288815669E-3f, -1.05640848946261981558E-2f, - 2.47264490306265168283E-2f, -5.29459812080949914269E-2f, - 1.02643658689847095384E-1f, -1.76416518357834055153E-1f, - 2.52587186443633654823E-1f}; - - const float kI1eCoeffsB[] = { - -3.83538038596423702205E-9f, -2.63146884688951950684E-8f, - -2.51223623787020892529E-7f, -3.88256480887769039346E-6f, - -1.10588938762623716291E-4f, -9.76109749136146840777E-3f, - 7.78576235018280120474E-1f}; - - return materializeBesselI1eApproximation(rewriter, loc, x, kI1eCoeffsA, - kI1eCoeffsB); -} - -Value materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - - const double kI1eCoeffsA[] = { - 
2.77791411276104639959E-18, -2.11142121435816608115E-17, - 1.55363195773620046921E-16, -1.10559694773538630805E-15, - 7.60068429473540693410E-15, -5.04218550472791168711E-14, - 3.22379336594557470981E-13, -1.98397439776494371520E-12, - 1.17361862988909016308E-11, -6.66348972350202774223E-11, - 3.62559028155211703701E-10, -1.88724975172282928790E-9, - 9.38153738649577178388E-9, -4.44505912879632808065E-8, - 2.00329475355213526229E-7, -8.56872026469545474066E-7, - 3.47025130813767847674E-6, -1.32731636560394358279E-5, - 4.78156510755005422638E-5, -1.61760815825896745588E-4, - 5.12285956168575772895E-4, -1.51357245063125314899E-3, - 4.15642294431288815669E-3, -1.05640848946261981558E-2, - 2.47264490306265168283E-2, -5.29459812080949914269E-2, - 1.02643658689847095384E-1, -1.76416518357834055153E-1, - 2.52587186443633654823E-1}; - - const double kI1eCoeffsB[] = { - 7.51729631084210481353E-18, 4.41434832307170791151E-18, - -4.65030536848935832153E-17, -3.20952592199342395980E-17, - 2.96262899764595013876E-16, 3.30820231092092828324E-16, - -1.88035477551078244854E-15, -3.81440307243700780478E-15, - 1.04202769841288027642E-14, 4.27244001671195135429E-14, - -2.10154184277266431302E-14, -4.08355111109219731823E-13, - -7.19855177624590851209E-13, 2.03562854414708950722E-12, - 1.41258074366137813316E-11, 3.25260358301548823856E-11, - -1.89749581235054123450E-11, -5.58974346219658380687E-10, - -3.83538038596423702205E-9, -2.63146884688951950684E-8, - -2.51223623787020892529E-7, -3.88256480887769039346E-6, - -1.10588938762623716291E-4, -9.76109749136146840777E-3, - 7.78576235018280120474E-1}; - - return materializeBesselI1eApproximation(rewriter, loc, x, - kI1eCoeffsA, kI1eCoeffsB); -} - -Value materializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args, FloatType minPrecisionTy, - Value callback(ConversionPatternRewriter &, - Location, ValueRange)) { - auto originalTy = getElementTypeOrSelf(args.front().getType()); - auto floatOriginalTy = originalTy.dyn_cast(); - bool needsUpcast = - floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth(); - - // Upcast arguments if necessary. - llvm::SmallVector castedArgs; - if (needsUpcast) { - for (Value a : args) { - castedArgs.push_back( - rewriter.create(loc, a, minPrecisionTy)); - } - args = castedArgs; - } - - Value result = callback(rewriter, loc, args); - - // Cast back if necessary. - if (needsUpcast) { - result = rewriter.create(loc, result, originalTy); - } - - return result; -} - -struct ConvertBesselI1eOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - BesselI1eOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value x = adaptor.getOperand(); - Type ty = x.getType().cast().getElementType(); - - // For now, we support only f64, f32, f16 and bf16. 
- // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e - if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) - return failure(); - - if (ty.isF64()) { - rewriter.replaceOp( - op, materializeBesselI1eApproximationF64(rewriter, loc, x)); - return success(); - } - - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - rewriter.getF32Type(), - &materializeBesselI1eApproximationF32)); - return success(); - } -}; - -template -Value materializePolynomialApproximation(ConversionPatternRewriter &rewriter, - Location loc, Value x, - ArrayRef coefficients) { - if (coefficients.empty()) return chlo::getConstantLike(rewriter, loc, 0.0, x); - - Value poly = chlo::getConstantLike(rewriter, loc, coefficients[0], x); - for (size_t i = 1; i < coefficients.size(); ++i) { - poly = rewriter.create(loc, x.getType(), poly, x); - poly = rewriter.create( - loc, x.getType(), poly, - chlo::getConstantLike(rewriter, loc, coefficients[i], x)); - } - return poly; -} - -// Precondition is |x| >= 1. Use erf approximation, otherwise. -// -// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an -// argument and derive the final approximation for all |x| >= 1. -// This implementation is based on Cephes. -Value materializeErfcApproximationF64ForMagnituteGeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - const double kMaxlog = 7.09782712893383996843E2; - const double kErfcPCoefficients[] = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; - const double kErfcQCoefficients[] = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; - const double kErfcRCoefficients[] = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; - const double kErfcSCoefficients[] = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; - - // Let z = -x^2. - Value xSq = rewriter.create(loc, x, x); - Value z = rewriter.create(loc, xSq); - - // Materialize polynomial approximation for x in [1, 8) as - // erfc(x) = exp(z) P(|x|) / Q(|x|). - Value expZ = rewriter.create(loc, z); - Value absX = rewriter.create(loc, x); - Value polP = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients)); - Value expZMulPolyP = rewriter.create(loc, expZ, polP); - Value polQ = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients)); - Value erfcApprox18 = rewriter.create(loc, expZMulPolyP, polQ); - - // Materialize polynomial approximation for x in >= 8 as - // erfc(x) exp(z) R(|x|) / S(|x|). 
- Value polR = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients)); - Value expZMulPolyR = rewriter.create(loc, expZ, polR); - Value polS = materializePolynomialApproximation( - rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients)); - Value erfcApprox8Inf = rewriter.create(loc, expZMulPolyR, polS); - - // Combine polynomial approximations for x >= 1. - Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x); - Value absXLt8 = rewriter.create( - loc, absX, eight, mhlo::ComparisonDirection::LT); - Value erfcApprox = rewriter.create(loc, absXLt8, erfcApprox18, - erfcApprox8Inf); - - // Clamp to prevent overflow and materialize approximation for large x as - // erfc(x) = 0. - Value zLtNegMaxlog = rewriter.create( - loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), - mhlo::ComparisonDirection::LT); - Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value erfcApproxClamped = - rewriter.create(loc, zLtNegMaxlog, zero, erfcApprox); - - // Derive approximation for x <= -1 as - // erfc(x) = 2 - erfc(-x). - // Reuse previously materialized approximations all of which take |x| as their - // argument. - Value xLtZero = rewriter.create( - loc, x, zero, mhlo::ComparisonDirection::LT); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value twoSubErfcApproxClamped = - rewriter.create(loc, two, erfcApproxClamped); - return rewriter.create(loc, xLtZero, twoSubErfcApproxClamped, - erfcApproxClamped); -} - -// Precondition is |x| <= 1. Use erfc approximation, otherwise. -// This implementation is based on Cephes. -Value materializeErfApproximationF64ForMagnituteLeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - const double kErfTCoefficients[] = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; - const double kErfUCoefficients[] = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; - - // Materialize polynomial approximation for |x| <= 1 as - // erf(x) = x T(x^2) / U(x^2). - Value xSq = rewriter.create(loc, x, x); - Value polyT = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); - Value xMulPolyT = rewriter.create(loc, x, polyT); - Value polyU = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::ArrayRef(kErfUCoefficients)); - return rewriter.create(loc, xMulPolyT, polyU); -} - -// This implementation is based on Cephes. -Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - - // Rely on erf approximation for |x| < 1 - // erf(x) = erf_approx(x) - Value erfApprox = - materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x); - - // Rely on erfc approximation for |x| >= 1 and materialize erf as - // erf(x) = 1 - erfc_approx(x) - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value erfcApprox = - materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x); - Value erfcBasedApprox = - rewriter.create(loc, one, erfcApprox); - - // Materialize approximation selection based on argument. 
- Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, erfApprox, - erfcBasedApprox); -} - -Value materializeErfcApproximationF64(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF64() && - "expect f64 element type"); - - // Rely on erfc approximation for |x| >= 1 - // erfc(x) = erfc_approx(x) - Value erfcApprox = - materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x); - - // Rely on erf approximation for |x| < 1 and materialize erfc as - // erfc(x) = 1 - erf_approx(x) - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value erfApprox = - materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x); - Value erfBasedApprox = rewriter.create(loc, one, erfApprox); - - // Materialize approximation selection based on argument. - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, erfBasedApprox, - erfcApprox); -} - -// Precondition is |x| >= 1. Use erf approximation, otherwise. -// -// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an -// argument and derive the final approximation for all |x| >= 1. -// This implementation is based on Cephes. -Value materializeErfcApproximationF32ForMagnitudeGeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const double kMaxlog = 88.72283905206835; - const float kErfcPCoefficients[] = { - +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, - -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, - +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, - }; - const float kErfcRCoefficients[] = { - -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, - +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, - -2.820767439740514E-1, +5.641895067754075E-1, - }; - - // Let z = -x^2. - Value xSq = rewriter.create(loc, x, x); - Value z = rewriter.create(loc, xSq); - - // Materialize polynomial approximation for x >= 1 as - // erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2) - // erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2 - Value absX = rewriter.create(loc, x); - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value reciprocalXSq = rewriter.create(loc, one, xSq); - Value expZ = rewriter.create(loc, z); - Value oneDivAbsX = rewriter.create(loc, one, absX); - Value expZMulOneDivAbsX = rewriter.create(loc, expZ, oneDivAbsX); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value absXLtTwo = rewriter.create( - loc, absX, two, mhlo::ComparisonDirection::LT); - Value polP = materializePolynomialApproximation( - rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients)); - Value polR = materializePolynomialApproximation( - rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients)); - Value poly = rewriter.create(loc, absXLtTwo, polP, polR); - Value erfcApprox = rewriter.create(loc, expZMulOneDivAbsX, poly); - - // Clamp to prevent overflow and materialize approximation for large x as - // erfc(x) = 0. 
- Value zLtNeqMaxlog = rewriter.create( - loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), - mhlo::ComparisonDirection::LT); - Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x); - Value erfcApproxClamped = - rewriter.create(loc, zLtNeqMaxlog, zero, erfcApprox); - - // Derive approximation for x <= -1 as - // erfc(x) = 2 - erfc(-x). - // Reuse previously materialized approximations all of which take |x| as their - // argument. - Value xLtZero = rewriter.create( - loc, x, zero, mhlo::ComparisonDirection::LT); - Value twoSubErfcApprox = - rewriter.create(loc, two, erfcApproxClamped); - return rewriter.create(loc, xLtZero, twoSubErfcApprox, - erfcApproxClamped); -} - -// Precondition is |x| <= 1. Use erfc approximation, otherwise. -// This implementation is based on Cephes. -Value materializeErfApproximationF32ForMagnitudeLeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const float kErfTCoefficients[] = { - +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, - -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, - +1.128379165726710E+0, - }; - - // Materialize polynomial approximation for |x| <= 1 as - // erf(x) = x T(x^2). - Value xSq = rewriter.create(loc, x, x); - Value polyT = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); - return rewriter.create(loc, x, polyT); -} - -// This is the same approximation as used in Eigen. -Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - const float kAlpha[] = { - -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, - -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, - -1.60960333262415e-02f, - }; - const float kBeta[] = { - -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, - -7.37332916720468e-03f, -1.42647390514189e-02f, - }; - - // Clamp argument between -4 and 4. - Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x); - Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x); - x = rewriter.create(loc, x.getType(), lb, x, ub); - Value xSq = rewriter.create(loc, x, x); - - // Materialize polynomial approximation for x in [-4, 4] as - // erf(x) = x * Alpha(x^2) / Beta(x^2). 
- Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq, - llvm::ArrayRef(kAlpha)); - Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq, - llvm::ArrayRef(kBeta)); - Value xMulAlphaPoly = rewriter.create(loc, x, alphaPoly); - Value erf = rewriter.create(loc, xMulAlphaPoly, betaPoly); - Value lbErf = chlo::getConstantLike(rewriter, loc, -1.0, x); - Value ubErf = chlo::getConstantLike(rewriter, loc, 1.0, x); - return rewriter.create(loc, erf.getType(), lbErf, erf, ubErf); -} - -Value materializeErfcApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { - Value x = args.front(); - assert(x.getType().cast().getElementType().isF32() && - "expect f32 element type"); - - // Rely on erfc approximation for |x| >= 1 - // erfc(x) = erfc_approx(x) - Value erfcApprox = - materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x); - - // Rely on erf approximation for |x| < 1 and materialize erfc as - // erfc(x) = 1 - erf_approx(x) - Value one = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value erfApprox = - materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x); - Value erfBasedApprox = rewriter.create(loc, one, erfApprox); - - // Materialize approximation selection based on argument. - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, erfBasedApprox, - erfcApprox); -} - -struct BasisConvertErfOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ErfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value x = adaptor.getOperand(); - Type ty = x.getType().cast().getElementType(); - - // For now, we support only f64, f32, f16 and bf16. - if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) - return failure(); - - if (ty.isF64()) { - rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x)); - return success(); - } - - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - rewriter.getF32Type(), - &materializeErfApproximationF32)); - return success(); - } -}; - -struct ConvertErfcOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ErfcOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value x = adaptor.getOperand(); - Type ty = x.getType().cast().getElementType(); - - // For now, we support only f64, f32, f16 and bf16. 
- if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) - return failure(); - - if (ty.isF64()) { - rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x)); - return success(); - } - - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - rewriter.getF32Type(), - &materializeErfcApproximationF32)); - return success(); - } -}; - -Value erfInv32(ConversionPatternRewriter &b, Location loc, ValueRange args) { - constexpr int kDegree = 9; - constexpr std::array wLessThan5Constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array wGreaterThan5Constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - Value x = args[0]; - // Compute logarithm of (1+arg) using log1p(arg) which is more precise than - // log(1+arg) when arg is close to zero. For more details, see - // https://en.cppreference.com/w/cpp/numeric/math/log1p - Value minusXSquared = - b.create(loc, x, b.create(loc, x)); - Value w = - b.create(loc, b.create(loc, minusXSquared)); - - Value lt = b.create(loc, w, getConstantLike(b, loc, 5.0, x), - mhlo::ComparisonDirection::LT); - auto coefficient = [&](int i) { - return b.create( - loc, lt, getConstantLike(b, loc, wLessThan5Constants[i], x), - getConstantLike(b, loc, wGreaterThan5Constants[i], x)); - }; - w = b.create( - loc, lt, - b.create(loc, w, getConstantLike(b, loc, 2.5, x)), - b.create(loc, b.create(loc, w), - getConstantLike(b, loc, 3.0, x))); - Value p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b.create(loc, coefficient(i), - b.create(loc, p, w)); - } - - // Result modulo edge cases. - Value result = b.create(loc, p, x); - - // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is - // indeterminate, and can give nan or -/+inf.) 
- return b.create( - loc, - b.create(loc, b.create(loc, x), - getConstantLike(b, loc, 1, x), - mhlo::ComparisonDirection::EQ), - b.create(loc, x, getConstantLikeInfValue(b, loc, x, false)), - result); -} - -Value erfInv64(ConversionPatternRewriter &b, Location loc, ValueRange args) { - constexpr std::array wLessThan625Constants = { - -3.6444120640178196996e-21, -1.685059138182016589e-19, - 1.2858480715256400167e-18, 1.115787767802518096e-17, - -1.333171662854620906e-16, 2.0972767875968561637e-17, - 6.6376381343583238325e-15, -4.0545662729752068639e-14, - -8.1519341976054721522e-14, 2.6335093153082322977e-12, - -1.2975133253453532498e-11, -5.4154120542946279317e-11, - 1.051212273321532285e-09, -4.1126339803469836976e-09, - -2.9070369957882005086e-08, 4.2347877827932403518e-07, - -1.3654692000834678645e-06, -1.3882523362786468719e-05, - 0.0001867342080340571352, -0.00074070253416626697512, - -0.0060336708714301490533, 0.24015818242558961693, - 1.6536545626831027356}; - constexpr std::array wLessThan16Constants = { - 2.2137376921775787049e-09, 9.0756561938885390979e-08, - -2.7517406297064545428e-07, 1.8239629214389227755e-08, - 1.5027403968909827627e-06, -4.013867526981545969e-06, - 2.9234449089955446044e-06, 1.2475304481671778723e-05, - -4.7318229009055733981e-05, 6.8284851459573175448e-05, - 2.4031110387097893999e-05, -0.0003550375203628474796, - 0.00095328937973738049703, -0.0016882755560235047313, - 0.0024914420961078508066, -0.0037512085075692412107, - 0.005370914553590063617, 1.0052589676941592334, - 3.0838856104922207635, - }; - constexpr std::array wGreaterThan16Constants = { - -2.7109920616438573243e-11, -2.5556418169965252055e-10, - 1.5076572693500548083e-09, -3.7894654401267369937e-09, - 7.6157012080783393804e-09, -1.4960026627149240478e-08, - 2.9147953450901080826e-08, -6.7711997758452339498e-08, - 2.2900482228026654717e-07, -9.9298272942317002539e-07, - 4.5260625972231537039e-06, -1.9681778105531670567e-05, - 7.5995277030017761139e-05, -0.00021503011930044477347, - -0.00013871931833623122026, 1.0103004648645343977, - 4.8499064014085844221, - }; - - Value x = args[0]; - // Compute logarithm of (1+arg) using log1p(arg) which is more precise than - // log(1+arg) when arg is close to zero. 
For more details, see - // https://en.cppreference.com/w/cpp/numeric/math/log1p - Value minusXSquared = - b.create(loc, x, b.create(loc, x)); - Value w = - b.create(loc, b.create(loc, minusXSquared)); - - Value lt625 = b.create( - loc, w, getConstantLike(b, loc, 6.25, x), mhlo::ComparisonDirection::LT); - Value lt16 = b.create(loc, w, getConstantLike(b, loc, 16, x), - mhlo::ComparisonDirection::LT); - - auto coefficient = [&](int i) { - Value c = getConstantLike(b, loc, wLessThan625Constants[i], x); - if (i < 19) { - c = b.create( - loc, lt625, c, getConstantLike(b, loc, wLessThan16Constants[i], x)); - } - if (i < 17) { - c = b.create( - loc, lt16, c, getConstantLike(b, loc, wGreaterThan16Constants[i], x)); - } - return c; - }; - - Value sqrtW = b.create(loc, w); - Value wMinus3125 = - b.create(loc, w, getConstantLike(b, loc, 3.125, x)); - Value select2 = - b.create(loc, lt16, getConstantLike(b, loc, 3.25, w), - getConstantLike(b, loc, 5.0, w)); - Value select2Result = b.create(loc, sqrtW, select2); - w = b.create(loc, lt625, wMinus3125, select2Result); - - Value p = coefficient(0); - for (int i = 1; i < 17; ++i) { - p = b.create(loc, coefficient(i), - b.create(loc, p, w)); - } - for (int i = 17; i < 19; ++i) { - p = b.create( - loc, lt16, - b.create(loc, coefficient(i), - b.create(loc, p, w)), - p); - } - for (int i = 19; i < 23; ++i) { - p = b.create( - loc, lt625, - b.create(loc, coefficient(i), - b.create(loc, p, w)), - p); - } - - // Result modulo edge cases. - Value result = b.create(loc, p, x); - - // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is - // indeterminate, and can give nan or -/+inf.) - return b.create( - loc, - b.create(loc, b.create(loc, x), - getConstantLike(b, loc, 1, x), - mhlo::ComparisonDirection::EQ), - b.create(loc, x, getConstantLikeInfValue(b, loc, x, false)), - result); -} - -struct ConvertErfInvOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ErfInvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - if (op.getResult().getType().getElementType().isF64()) { - rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands())); - return success(); - } - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - minPrecisionTy, &erfInv32)); - return success(); - } -}; - -// Coefficients for the Lanczos approximation of the gamma function. The -// coefficients are uniquely determined by the choice of g and n (kLanczosGamma -// and kLanczosCoefficients.size() + 1). The coefficients below correspond to -// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and -// [7, 9] seemed to be the least sensitive to the quality of the log function. -// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 -// for a particularly inaccurate log function. 
-constexpr double kLanczosGamma = 7; // aka g -constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; -constexpr std::array kLanczosCoefficients = { - 676.520368121885098567009190444019, -1259.13921672240287047156078755283, - 771.3234287776530788486528258894, -176.61502916214059906584551354, - 12.507343278686904814458936853, -0.13857109526572011689554707, - 9.984369578019570859563e-6, 1.50563273514931155834e-7}; - -// Compute the Lgamma function using Lanczos' approximation from "A Precision -// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis -// series B. Vol. 1: -// lgamma(z + 1) = (log(2) + log(pi)) / 2 -// + (z + 1/2) * log(t(z)) -// - t(z) + log(a(z)) -// with t(z) = z + kLanczosGamma + 1/2 -// a(z) = kBaseLanczosCoeff -// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) -Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - // If the input is less than 0.5 use Euler's reflection formula. - // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) - // Let z be - // z = -x if x < 1/2 - // z = x - 1 otheriwse - Value x = args.front(); - Value half = getConstantLike(rewriter, loc, 0.5, x); - Value needToReflect = rewriter.create( - loc, x, half, mhlo::ComparisonDirection::LT); - Value negX = rewriter.create(loc, x); - Value one = getConstantLike(rewriter, loc, 1, x); - Value xSubOne = rewriter.create(loc, x, one); - Value z = rewriter.create(loc, needToReflect, negX, xSubOne); - - // Materialize - // a(z) = kBaseLanczosCoeff - // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) - Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); - for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { - Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); - Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x); - Value quotient = rewriter.create( - loc, coeff, rewriter.create(loc, z, oneBasedIndex)); - a = rewriter.create(loc, a, quotient); - } - - // To improve accuracy on platforms with less-precise log implementations, - // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the - // device. - // Materialize as - // log(t) = log(kLanczosGamma + 1/2 + z) - // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). - Value lanczosPlusHalf = - getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); - Value t = rewriter.create(loc, lanczosPlusHalf, z); - Value logTerm = - getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); - Value log1pTerm = rewriter.create( - loc, rewriter.create(loc, z, lanczosPlusHalf)); - Value logT = rewriter.create(loc, logTerm, log1pTerm); - - // Note that t(z) may be large and we need to be careful not to overflow to - // infinity in the relevant term - // r = (z + 1/2) * log(t(z)) - t(z). - // Therefore, we compute this as - // r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)). - Value tDivLogT = rewriter.create(loc, t, logT); - Value sum = rewriter.create( - loc, rewriter.create(loc, z, half), tDivLogT); - Value r = rewriter.create(loc, sum, logT); - - // Compute the final result (modulo reflection) as - // lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)). - Value logA = rewriter.create(loc, a); - Value lgamma = rewriter.create( - loc, - rewriter.create( - loc, - getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x), - r), - logA); - - // Compute the reflected value for x < 0.5 as - // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). 
- // - // The abs is needed because lgamma is the log of the absolute value of the - // gamma function. - // - // We have to be careful when computing the final term above. gamma(x) goes - // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x) - // term. The slope is large, so precision is particularly important. - // - // Because abs(sin(pi * x)) has period of 1 we can equivalently use - // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is - // more numerically accurate: It doesn't overflow to inf like pi * x would and - // if x is an integer it evaluates to exactly 0 which is important because we - // then take the log of this value, and log(0) is inf. - // - // We don't have a frac(x) primitive in HLO and computing it is tricky, but - // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our - // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). - // - // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close - // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain - // [0, 1] is symmetric across the line Y=0.5. - // - - // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of - // pi * abs_frac for values of abs_frac close to 1. - Value abs = rewriter.create(loc, x); - Value absFrac = rewriter.create( - loc, abs, rewriter.create(loc, abs)); - Value reduceAbsFrac = rewriter.create( - loc, half, absFrac, mhlo::ComparisonDirection::LT); - absFrac = rewriter.create( - loc, reduceAbsFrac, rewriter.create(loc, one, absFrac), - absFrac); - - // Materialize reflection. - Value reflectionDenom = rewriter.create( - loc, - rewriter.create( - loc, rewriter.create( - loc, getConstantLike(rewriter, loc, M_PI, x), absFrac))); - Value lgammaReflection = rewriter.create( - loc, - rewriter.create( - loc, getConstantLike(rewriter, loc, std::log(M_PI), x), - reflectionDenom), - lgamma); - - // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, - // then it "wins" and the result is +/-inf. - Value finiteReflectionDenom = - rewriter.create(loc, reflectionDenom); - Value negReflectionDenom = rewriter.create(loc, reflectionDenom); - lgammaReflection = rewriter.create( - loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom); - - // Select whether or not to rely on the reflection. - lgamma = rewriter.create(loc, needToReflect, lgammaReflection, - lgamma); - - // Materialize +/-inf behavior as - // lgamma(+/-inf) = +inf. - Value xIsInf = rewriter.create(loc, x); - return rewriter.create( - loc, xIsInf, - chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), - lgamma); -} - -// Express `cosh` as -// cosh(x) = (e^x + e^-x) / 2 -// = e^(x + log(1/2)) + e^(-x + log(1/2)) -// -// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not. -// -// This incorrectly overflows to inf for two f32 input values, namely -// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The -// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so -// we deem this acceptable. 
-Value materializeCoshApproximation(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { - CoshOp::Adaptor transformed(operands); - Value x = transformed.getOperand(); - - Value logOneHalf = - rewriter.create(loc, getConstantLike(rewriter, loc, 0.5, x)); - Value expAdd = rewriter.create( - loc, rewriter.create(loc, x, logOneHalf)); - Value expSub = rewriter.create( - loc, rewriter.create(loc, logOneHalf, x)); - return rewriter.create(loc, expAdd, expSub); -} - -struct ConvertCoshOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - CoshOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - rewriter.getF32Type(), - &materializeCoshApproximation)); - return success(); - } -}; - -// Compute the Digamma function using Lanczos' approximation from "A Precision -// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis -// series B. Vol. 1: -// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z) -// with t(z) = z + kLanczosGamma + 1/2 -// a(z) = kBaseLanczosCoeff -// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) -// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) -Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - // If the input is less than 0.5 use Euler's reflection formula. - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - // Let z be - // z = -x if x < 1/2 - // z = x - 1 otherwise - Value x = args.front(); - Value half = getConstantLike(rewriter, loc, 0.5, x); - Value needToReflect = rewriter.create( - loc, x, half, mhlo::ComparisonDirection::LT); - Value negX = rewriter.create(loc, x); - Value one = getConstantLike(rewriter, loc, 1, x); - Value xSubOne = rewriter.create(loc, x, one); - Value z = rewriter.create(loc, needToReflect, negX, xSubOne); - - // Materialize - // a(z) = kBaseLanczosCoeff - // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) - // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) - Value zero = getConstantLike(rewriter, loc, 0.0, x); - Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); - Value aPrime = zero; - for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { - Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); - Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x); - Value zTerm = rewriter.create(loc, z, oneBasedIndex); - aPrime = rewriter.create( - loc, aPrime, - rewriter.create( - loc, coeff, rewriter.create(loc, zTerm, zTerm))); - a = rewriter.create( - loc, a, rewriter.create(loc, coeff, zTerm)); - } - - // To improve accuracy on platforms with less-precise log implementations, - // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the - // device. - // Materialize as - // log(t) = log(kLanczosGamma + 1/2 + z) - // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
- Value lanczosPlusHalf = - getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); - Value t = rewriter.create(loc, lanczosPlusHalf, z); - Value logTerm = - getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); - Value log1pTerm = rewriter.create( - loc, rewriter.create(loc, z, lanczosPlusHalf)); - Value logT = rewriter.create(loc, logTerm, log1pTerm); - - // Materialize the final result (modulo reflection) as - // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z). - Value aPrimeDivA = rewriter.create(loc, aPrime, a); - Value lanczosGammaDivT = rewriter.create( - loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t); - Value digamma = rewriter.create( - loc, rewriter.create(loc, logT, aPrimeDivA), - lanczosGammaDivT); - - // We need to be careful how we compute cot(pi * input) below: For - // near-integral arguments, pi * input can lose precision. - // - // Input is already known to be less than 0.5 (otherwise we don't have to - // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to - // increase precision of pi * x and the resulting cotangent. - Value reducedX = rewriter.create( - loc, x, - rewriter.create( - loc, rewriter.create( - loc, rewriter.create( - loc, x, getConstantLike(rewriter, loc, 0.5, x))))); - - // Materialize reflection for inputs less than 0.5 as - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - // = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x) - Value pi = getConstantLike(rewriter, loc, M_PI, x); - Value piMulReducedX = rewriter.create(loc, pi, reducedX); - Value cos = rewriter.create(loc, piMulReducedX); - Value sin = rewriter.create(loc, piMulReducedX); - Value reflection = rewriter.create( - loc, digamma, - rewriter.create( - loc, rewriter.create(loc, pi, cos), sin)); - - // Select whether or not to rely on the reflection. - digamma = - rewriter.create(loc, needToReflect, reflection, digamma); - - // Digamma has poles at negative integers and zero; return nan for those. - Value isLeZero = rewriter.create( - loc, x, zero, mhlo::ComparisonDirection::LE); - Value isInt = rewriter.create( - loc, x, rewriter.create(loc, x), - mhlo::ComparisonDirection::EQ); - Value isPole = rewriter.create(loc, isLeZero, isInt); - return rewriter.create( - loc, isPole, - getConstantLike(rewriter, loc, std::numeric_limits::quiet_NaN(), - x), - digamma); -} - -Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - // Implementation ported from: - // https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917 - // Reference: Johansson, Fredrik. - // "Rigorous high-precision computation of the Hurwitz zeta function and its - // derivatives." Numerical Algorithms 69.2 (2015): 253-270. - // https://arxiv.org/abs/1309.2877 - formula (5) - // Notation is more or less kept as a reference to the whitepaper. - assert(args.size() == 2); - Value x = args[0]; - Value q = args[1]; - static const std::array kZetaCoeffs{ - -7.1661652561756670113e18, - 1.8152105401943546773e17, - -4.5979787224074726105e15, - 1.1646782814350067249e14, - -2.950130727918164224e12, - 7.47242496e10, - -1.8924375803183791606e9, - 47900160.0, - -1209600.0, - 30240.0, - -720.0, - 12.0, - }; - - // For speed we'll always use 9 iterations for the initial series estimate, - // and a 12 term expansion for the Euler-Maclaurin formula. 
- Value a = q; - Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a); - Value negPower = zero; - Value negX = rewriter.create(loc, x); - Value initialSum = rewriter.create(loc, q, negX); - Value one = chlo::getConstantLike(rewriter, loc, 1.0, a); - for (int i = 0; i < 9; ++i) { - a = rewriter.create(loc, a, one); - negPower = rewriter.create(loc, a, negX); - initialSum = rewriter.create(loc, initialSum, negPower); - } - a = rewriter.create(loc, a, one); - negPower = rewriter.create(loc, a, negX); - Value oneLikeX = chlo::getConstantLike(rewriter, loc, 1.0, x); - Value xMinusOne = rewriter.create(loc, x, oneLikeX); - Value negPowerMulA = rewriter.create(loc, negPower, a); - Value negPowerMulADivXMinusOne = - rewriter.create(loc, negPowerMulA, xMinusOne); - Value s = - rewriter.create(loc, initialSum, negPowerMulADivXMinusOne); - Value aInverseSquare = rewriter.create( - loc, one, rewriter.create(loc, a, a)); - - Value hornerSum = zero; - Value factor = one; - // Use Horner's rule for this. - // Note this differs from Cephes which does a 'naive' polynomial evaluation. - // Using Horner's rule allows to avoid some NaN's and Infs from happening, - // resulting in more numerically stable code. - for (int i = 0; i < 11; ++i) { - Value factorLhs = rewriter.create( - loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x)); - Value factorRhs = rewriter.create( - loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x)); - factor = rewriter.create(loc, factorLhs, factorRhs); - hornerSum = rewriter.create( - loc, factor, - rewriter.create( - loc, aInverseSquare, - rewriter.create( - loc, hornerSum, - chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); - } - Value zeroPointFiveLikeNegPower = - chlo::getConstantLike(rewriter, loc, .5, negPower); - Value xDivA = rewriter.create(loc, x, a); - s = rewriter.create( - loc, s, - rewriter.create( - loc, negPower, - rewriter.create( - loc, zeroPointFiveLikeNegPower, - rewriter.create( - loc, xDivA, - rewriter.create( - loc, - chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], - a), - hornerSum))))); - - // Use the initial zeta sum without the correction term coming - // from Euler-Maclaurin if it is accurate enough. - Value absNegPower = rewriter.create(loc, negPower); - Value absInitialSum = rewriter.create(loc, initialSum); - Value output = rewriter.create( - loc, - rewriter.create( - loc, absNegPower, - rewriter.create( - loc, absInitialSum, - chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)), - mhlo::ComparisonDirection::LT), - initialSum, s); - - // Function is not defined for x < 1. - Value nan = chlo::getConstantLike( - rewriter, loc, std::numeric_limits::quiet_NaN(), x); - output = rewriter.create( - loc, - rewriter.create(loc, x, oneLikeX, - mhlo::ComparisonDirection::LT), - nan, output); - - // For q <= 0, x must be an integer. - Value qLeZero = rewriter.create( - loc, q, zero, mhlo::ComparisonDirection::LE); - Value xNotInt = rewriter.create( - loc, x, rewriter.create(loc, x), - mhlo::ComparisonDirection::NE); - Value xDomainError = rewriter.create(loc, qLeZero, xNotInt); - output = rewriter.create(loc, xDomainError, nan, output); - - // For all integer q <= 0, zeta has a pole. The limit is only defined as - // +inf if x is and even integer. 
- Value inf = chlo::getConstantLike(rewriter, loc, - std::numeric_limits::infinity(), x); - Value qIsInt = rewriter.create( - loc, q, rewriter.create(loc, q), - mhlo::ComparisonDirection::EQ); - Value atPole = rewriter.create(loc, qLeZero, qIsInt); - Value two = chlo::getConstantLike(rewriter, loc, 2.0, x); - Value xIsInt = rewriter.create( - loc, x, rewriter.create(loc, x), - mhlo::ComparisonDirection::EQ); - Value xIsEven = rewriter.create( - loc, rewriter.create(loc, x, two), zero, - mhlo::ComparisonDirection::EQ); - Value xIsEvenInt = rewriter.create(loc, xIsInt, xIsEven); - output = rewriter.create( - loc, atPole, rewriter.create(loc, xIsEvenInt, inf, nan), - output); - - // For x = 1, this is the harmonic series and diverges. - output = rewriter.create( - loc, - rewriter.create(loc, x, one, - mhlo::ComparisonDirection::EQ), - inf, output); - - return output; -} - -Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc, - ValueRange args) { - PolygammaOp::Adaptor transformed(args); - Value n = transformed.getN(); - Value x = transformed.getX(); - - // Handle integer n > 0. - Value one = getConstantLike(rewriter, loc, 1.0, x); - Value two = getConstantLike(rewriter, loc, 2.0, x); - Value sign = rewriter.create( - loc, - rewriter.create(loc, two, - rewriter.create(loc, n, two)), - one); - Value nPlusOne = rewriter.create(loc, n, one); - Value expLgammaNp1 = rewriter.create( - loc, rewriter.create(loc, nPlusOne)); - Value zeta = rewriter.create(loc, nPlusOne, x); - Value result = rewriter.create( - loc, rewriter.create(loc, sign, expLgammaNp1), zeta); - - // Handle n = 0. - Value zero = getConstantLike(rewriter, loc, 0.0, x); - Value nEqZero = rewriter.create( - loc, n, zero, mhlo::ComparisonDirection::EQ); - result = rewriter.create( - loc, nEqZero, rewriter.create(loc, x), result); - - // Check that n is a natural number. Return nan, otherwise. 
- Value nonInt = rewriter.create( - loc, n, rewriter.create(loc, n), - mhlo::ComparisonDirection::NE); - Value negative = rewriter.create( - loc, n, zero, mhlo::ComparisonDirection::LT); - Value nonNatural = rewriter.create(loc, nonInt, negative); - return rewriter.create( - loc, nonNatural, - getConstantLike(rewriter, loc, std::numeric_limits::quiet_NaN(), - x), - result); -} - -struct ConvertLgammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - LgammaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - minPrecisionTy, &materializeLgamma)); - return success(); - } -}; - -struct ConvertDigammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - DigammaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - minPrecisionTy, &materializeDigamma)); - return success(); - } -}; - -Value materializeNextAfter(ConversionPatternRewriter &rewriter, Location loc, - ValueRange operands) { - NextAfterOp::Adaptor transformed(operands); - Value x = transformed.getX(); - Value y = transformed.getY(); - auto resultTy = x.getType().cast(); - auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth(); - ImplicitLocOpBuilder b(loc, rewriter); - auto intTy = resultTy.clone(b.getIntegerType(bitwidth)); - auto xAsInt = b.create(intTy, x); - auto yAsInt = b.create(intTy, y); - - // The result is NaN if either "x" or "y" are NaN. - auto xIsNan = b.create(x, x, mhlo::ComparisonDirection::NE); - auto yIsNan = b.create(y, y, mhlo::ComparisonDirection::NE); - auto nanInput = b.create(xIsNan, yIsNan); - auto resultForNan = getConstantLike( - rewriter, loc, std::numeric_limits::quiet_NaN(), x); - auto resultForNanAsInt = - b.create(intTy, resultForNan); - - // The sign bit is the MSB. - const int64_t signBit = int64_t{1} << (bitwidth - 1); - // Discard the sign bit to make the result non-negative. - auto signMask = getConstantLike(rewriter, loc, signBit, xAsInt); - auto negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt); - auto xAbs = b.create(xAsInt, negatedSignMask); - auto yAbs = b.create(yAsInt, negatedSignMask); - - // When both "x" and "y" are equal, the result is "y". - auto xAndYAreEqual = - b.create(x, y, mhlo::ComparisonDirection::EQ); - auto resultForEqual = yAsInt; - - // When both "x" and "y" are 0, the result is "y". This is a separate case - // from above because "x" and "y" might have a different sign. - auto zero = getConstantLike(rewriter, loc, 0, xAsInt); - auto xIsZero = - b.create(xAbs, zero, mhlo::ComparisonDirection::EQ); - auto yIsZero = - b.create(yAbs, zero, mhlo::ComparisonDirection::EQ); - auto resultForBothZero = yAsInt; - - auto xSign = b.create(xAsInt, signMask); - auto ySign = b.create(yAsInt, signMask); - - // If from == 0 && to != 0, we need to return the smallest subnormal number - // signed like "to". - auto one = getConstantLike(rewriter, loc, 1, xAsInt); - auto resultForXZeroYNonZero = b.create(ySign, one); - - // If the sign of "x" and "y" disagree: - // - we need to make the magnitude of "from" smaller so that it is closer to - // zero. 
- // - // Otherwise the signs agree: - // - "x" with a magnitude larger than "y" means we need to make the magnitude - // smaller. - // - "x" with a magnitude smaller than "y" means we need to make the magnitude - // larger. - auto signsDisagree = - b.create(xSign, ySign, mhlo::ComparisonDirection::NE); - auto xMagnitudeLargerThanY = - b.create(xAbs, yAbs, mhlo::ComparisonDirection::GT); - auto resultHasSmallerMagnitude = - b.create(xMagnitudeLargerThanY, signsDisagree); - auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt); - auto magnitudeAdjustment = - b.create(resultHasSmallerMagnitude, minusOne, one); - Value result = b.create(xAsInt, magnitudeAdjustment); - // Handle from == +-0. - result = b.create( - xIsZero, - b.create(yIsZero, resultForBothZero, - resultForXZeroYNonZero), - result); - // Handle from == to. - result = b.create(xAndYAreEqual, resultForEqual, result); - // Handle isnan(x) || isnan(y). - result = b.create(nanInput, resultForNanAsInt, result); - - // Cast back to the original type. - return b.create(resultTy, result); -} - -struct ConvertNextAfterOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - NextAfterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp( - op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands())); - return success(); - } -}; - -struct ConvertPolygammaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - PolygammaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - minPrecisionTy, &materializePolygamma)); - return success(); - } -}; - -// Sinh(x) = (e^x - e^-x) / 2 -// = e^(x + log(1/2)) - e^(-x + log(1/2)). -// -// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not -// inf. -// -// This incorrectly overflows to +/-inf for two f32 input values, namely -// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The -// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so -// we deem this acceptable. -Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { - SinhOp::Adaptor transformed(operands); - Value x = transformed.getOperand(); - - Value logOneHalf = - rewriter.create(loc, getConstantLike(rewriter, loc, 0.5, x)); - Value expAdd = rewriter.create( - loc, rewriter.create(loc, x, logOneHalf)); - Value expSub = rewriter.create( - loc, rewriter.create(loc, logOneHalf, x)); - return rewriter.create(loc, expAdd, expSub); -} - -// Express `sinh` as -// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 -// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. -Value materializeSinhApproximation(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { - Value largeSinhResult = - materializeSinhApproximationForLargeX(rewriter, loc, operands); - - SinhOp::Adaptor transformed(operands); - Value x = transformed.getOperand(); - - // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in - // 0. - // Rewrite this to avoid that. We use expm1(x) because that preserves the - // first order term of the taylor series of e^x. - // (e^(x) - e^(-x)) / 2. = - // (e^(x) - 1 + 1 - e^(-x)) / 2. 
- // (expm1(x) + (e^(x) - 1) / e^x) / 2. - // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2. - Value expm1 = rewriter.create(loc, x); - Value one = getConstantLike(rewriter, loc, 1.0, x); - Value oneHalf = getConstantLike(rewriter, loc, 0.5, x); - Value expm1PlusOne = rewriter.create(loc, expm1, one); - Value ratio = rewriter.create(loc, expm1, expm1PlusOne); - Value sum = rewriter.create(loc, expm1, ratio); - Value smallSinhResult = rewriter.create(loc, oneHalf, sum); - - Value absX = rewriter.create(loc, x); - Value absXLtOne = rewriter.create( - loc, absX, one, mhlo::ComparisonDirection::LT); - return rewriter.create(loc, absXLtOne, smallSinhResult, - largeSinhResult); -} - -struct ConvertSinhOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - SinhOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value x = adaptor.getOperand(); - if (x.getType().cast().getElementType().isa()) { - rewriter.replaceOp(op, materializeSinhApproximationForLargeX( - rewriter, op.getLoc(), adaptor.getOperands())); - return success(); - } - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - rewriter.getF32Type(), - &materializeSinhApproximation)); - return success(); - } -}; - -// Converts chlo.top_k to MHLO iota, sort, and slice ops. -// -// chlo.top_k sorts along last dimension of the input tensor and then returns -// the top K components' values and indices. This is translated into a few -// ops in MHLO: first generating an integer sequence for the indices, -// then sorting both the original input tensor and the indices together, and -// finally slicing out the top K components. -// -// For example, for the following IR: -// -// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> -> -// (tensor<16x8xf32>, tensor<16x8xi32>) -// -// We will get: -// -// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32> -// %2 = "mhlo.sort"(%input, %1) ({ -// ^bb0(%arg1: tensor, %arg2: tensor, -// %arg3: tensor, %arg4: tensor): -// %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ... -// "mhlo.return"(%7) : (tensor) -> () -// }) {dimension = 1 : i64, is_stable = true} : ... -// %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ... -// %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ... -// %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : -// (tensor<16x16xf32>) -> tensor<16x8xf32> -// %6 = "mhlo.slice"(%4) ... -// -struct BasisConvertTopKOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - TopKOp op, OpAdaptor /*adaptor*/, - ConversionPatternRewriter &rewriter) const override { - auto operandType = op.getOperand().getType().dyn_cast(); - if (!operandType) return failure(); - int64_t operandRank = operandType.getRank(); - int64_t lastDimIndex = operandRank - 1; - int64_t lastDimSize = operandType.getDimSize(lastDimIndex); - int64_t lastDimResultSize = - hlo::isDynamicDimSize(lastDimSize) - ?
static_cast(op.getK()) - : std::min(static_cast(op.getK()), lastDimSize); - int64_t isDynamic = !operandType.hasStaticShape(); - auto i32Type = rewriter.getIntegerType(32); - Value opShapeValue, resultShapeValue; - if (isDynamic) { - SmallVector sizesI32x1; - for (auto i = 0; i < operandType.getRank(); ++i) { - auto sizeI32 = rewriter.create( - op.getLoc(), op.getOperand(), i); - auto sizeI32x1 = rewriter.create( - op.getLoc(), RankedTensorType::get({1}, i32Type), sizeI32); - sizesI32x1.push_back(sizeI32x1); - } - opShapeValue = - rewriter.create(op.getLoc(), sizesI32x1, - /*dimension=*/0); - auto lastDimI32 = rewriter.create( - op.getLoc(), - rewriter.getI32IntegerAttr(static_cast(lastDimResultSize))); - auto lastDimI32x1 = rewriter.create( - op.getLoc(), RankedTensorType::get({1}, i32Type), lastDimI32); - sizesI32x1.back() = lastDimI32x1; - resultShapeValue = - rewriter.create(op.getLoc(), sizesI32x1, - /*dimension=*/0); - } - - // Create an Iota op for indices. - Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type); - Value iotaOp; - if (isDynamic) { - iotaOp = rewriter.create( - op.getLoc(), iotaType, opShapeValue, - rewriter.getI64IntegerAttr(lastDimIndex)); - } else { - iotaOp = rewriter.create( - op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex)); - } - - // Create the sort op. It takes two inputs, one for the original input, the - // other for the indices. Use TOTALORDER comparison type instead of the - // default comparison if the element type is of type float. - Type elementType = operandType.getElementType(); - auto sortOp = - createSortOp(&rewriter, op.getLoc(), {op.getOperand(), iotaOp}, - {elementType, i32Type}, lastDimIndex, - /*isStable=*/true, - /*direction=*/mhlo::ComparisonDirection::GT); - - // Get the sorted input and index tuple element. - auto tupleFirstElement = sortOp.getResult(0); - auto tupleSecondElement = sortOp.getResult(1); - - SmallVector beginIndices(operandRank, 0); - auto endIndices = llvm::to_vector<4>(operandType.getShape()); - endIndices.back() = lastDimResultSize; - SmallVector strides(operandRank, 1); - - // Get the slice for the top K elements. 
- auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type()); - Value values, indices; - if (isDynamic) { - Value startIndices = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get(indicesTy, beginIndices)); - Value lastIndices = rewriter.create( - op.getLoc(), resultShapeValue, rewriter.getI64Type()); - Value stridesOp = rewriter.create( - op.getLoc(), DenseIntElementsAttr::get(indicesTy, strides)); - - SmallVector resultShape = - llvm::to_vector<4>(operandType.getShape()); - resultShape.back() = lastDimResultSize; - RankedTensorType resultType = RankedTensorType::get( - resultShape, elementType, operandType.getEncoding()); - RankedTensorType indexResultType = - RankedTensorType::get(resultShape, i32Type); - - values = rewriter.create( - op.getLoc(), resultType, tupleFirstElement, startIndices, lastIndices, - stridesOp); - indices = rewriter.create( - op.getLoc(), indexResultType, tupleSecondElement, startIndices, - lastIndices, stridesOp); - } else { - values = rewriter.create( - op.getLoc(), tupleFirstElement, - DenseIntElementsAttr::get(indicesTy, beginIndices), - DenseIntElementsAttr::get(indicesTy, endIndices), - DenseIntElementsAttr::get(indicesTy, strides)); - indices = rewriter.create( - op.getLoc(), tupleSecondElement, - DenseIntElementsAttr::get(indicesTy, beginIndices), - DenseIntElementsAttr::get(indicesTy, endIndices), - DenseIntElementsAttr::get(indicesTy, strides)); - } - - rewriter.replaceOp(op, {values, indices}); - return success(); - } -}; - -struct ConvertZetaOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ZetaOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - FloatType minPrecisionTy = rewriter.getF32Type(); - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), - minPrecisionTy, &materializeZeta)); - return success(); - } -}; - -struct ConvertSelectOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - BroadcastSelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only support ranked operands. 
- Value pred = adaptor.getPred(); - Value onTrue = adaptor.getOnTrue(); - Value onFalse = adaptor.getOnFalse(); - auto predType = pred.getType().dyn_cast(); - auto onTrueType = onTrue.getType().dyn_cast(); - auto onFalseType = onFalse.getType().dyn_cast(); - auto resultType = op.getResult().getType().dyn_cast(); - if (!predType || !onTrueType || !onFalseType || !resultType) { - return failure(); - } - - auto loc = op.getLoc(); - - Value predShape = rewriter.createOrFold(loc, pred); - Value onTrueShape = rewriter.createOrFold(loc, onTrue); - Value onFalseShape = rewriter.createOrFold(loc, onFalse); - int64_t resultRank = std::max( - {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()}); - - Value broadcastableCstr = rewriter.createOrFold( - loc, ValueRange{predShape, onTrueShape, onFalseShape}); - auto assumingOp = rewriter.create( - loc, ArrayRef{resultType}, broadcastableCstr); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.createBlock(&assumingOp.getDoRegion()); - - Value resultExtents = rewriter.createOrFold( - loc, shape::getExtentTensorType(op.getContext()), - ValueRange{predShape, onTrueShape, onFalseShape}, - /*error=*/nullptr); - auto shapeType = - RankedTensorType::get({resultRank}, rewriter.getIndexType()); - resultExtents = - rewriter.createOrFold(loc, shapeType, resultExtents); - - Value broadcastedPred = pred; - // Pred has an implicit broadcast for scalars, so use that when convenient. - if (predType.getRank() > 0) { - auto predBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - predType.getRank(), resultRank)); - broadcastedPred = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), - predType.getElementType()), - pred, resultExtents, - rewriter.getI64TensorAttr(predBroadcastDimensions)); - } - auto onTrueBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - onTrueType.getRank(), resultRank)); - Value broadcastedOnTrue = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), - onTrueType.getElementType()), - onTrue, resultExtents, - rewriter.getI64TensorAttr(onTrueBroadcastDimensions)); - auto onFalseBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - onFalseType.getRank(), resultRank)); - Value broadcastedOnFalse = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), - onFalseType.getElementType()), - onFalse, resultExtents, - rewriter.getI64TensorAttr(onFalseBroadcastDimensions)); - - // And generate the final non-broadcasted ternary op. - Value finalResult = - rewriter.create(loc, resultType, broadcastedPred, - broadcastedOnTrue, broadcastedOnFalse); - rewriter.create(loc, finalResult); - rewriter.replaceOp(op, {assumingOp.getResult(0)}); - return success(); - } -}; - -// Converts binary ops that statically are determined to not broadcast directly -// to the corresponding mhlo non-broadcasting op. -template -struct ConvertTrivialNonBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, typename ChloOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only rewrite for statically determinable non-broadcasting cases. - auto lhsType = - adaptor.getLhs().getType().template dyn_cast(); - auto rhsType = - adaptor.getRhs().getType().template dyn_cast(); - if (!lhsType || !rhsType) return failure(); - - // Requires rank broadcast. 
- if (lhsType.getRank() != rhsType.getRank()) return failure(); - // Any dynamic dimension may require broadcasting and requires more - // analysis. - if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) - return failure(); - - for (auto extents : llvm::zip(lhsType.getShape(), rhsType.getShape())) { - auto lhsExtent = std::get<0>(extents); - auto rhsExtent = std::get<1>(extents); - if (lhsExtent != rhsExtent) { - return failure(); - } - } - - rewriter.replaceOp(op, Adaptor::createOp(op, op.getResult().getType(), - adaptor.getOperands(), rewriter)); - return success(); - } -}; - -// Converts a binary op with ranked broadcasting operands to explicitly -// broadcast and invoke the corresponding mhlo non-broadcasting op. -// Note that dynamic broadcasting supported by this pattern is only valid for -// "numpy" broadcasting semantics as defined here: -// https://docs.scipy.org/doc/numpy/reference/ufuncs.html -// Specifically, this includes the following cases: -// - Same rank broadcast (operands have the same static rank). -// - Different-rank broadcast, either without a broadcast_dims attribute or -// with the broadcast_dims attribute set to map to a prefix padding. -// - Legal combinations of degenerate (1-dim) implicit broadcasting. -// The restriction on broadcast_dims derives from the definition of the -// `shape.broadcast` op, which only supports prefix-padding. -template -struct ConvertRankedDynamicBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, typename ChloOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only support ranked operands. - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - auto lhsType = lhs.getType().dyn_cast(); - auto rhsType = rhs.getType().dyn_cast(); - auto resultType = - op.getResult().getType().template dyn_cast(); - if (!lhsType || !rhsType || !resultType) return failure(); - - // Check for "numpy"-style rank broadcast. - auto broadcastDimensions = op.getBroadcastDimensions(); - if (broadcastDimensions && - !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions)) { - // Note: It is unclear whether the general specification of explicit - // broadcast_dimensions on binary ops is a feature we want to carry - // forward. While it can technically be implemented for ranked-dynamic, - // it is incompatible with unranked inputs. If this warning is emitted - // in real programs, it is an indication that the feature should be - // implemented versus just falling back on the more standard definition - // of numpy-like prefix-padding. - op.emitWarning() << "unsupported non prefix-padded dynamic rank " - << "broadcast_dimensions = " << *broadcastDimensions; - return failure(); - } - - // Compute result shape. - auto loc = op.getLoc(); - - // Insert a constraint on the shapes being broadcastable and insert all - // future code into an assuming block reliant on the constraint.
- Value lhsShape = rewriter.create(loc, lhs); - Value rhsShape = rewriter.create(loc, rhs); - auto broadcastableCstr = - rewriter.create(loc, lhsShape, rhsShape); - auto assumingOp = rewriter.create( - loc, ArrayRef{resultType}, broadcastableCstr.getResult()); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.createBlock(&assumingOp.getDoRegion()); - - int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank()); - Value resultExtents = - hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, - rewriter); - - // Note that we unconditionally emit DynamicBroadcastInDim ops and let - // downstream canonicalizations fold them away if possible. This is - // because, in the dynamic case, there are many corner cases regarding - // when it is safe to omit, and some of them require analysis to prove - // properly. - auto lhsBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - lhsType.getRank(), resultRank)); - Value broadcastedLhs = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), lhsType.getElementType()), - lhs, resultExtents, rewriter.getI64TensorAttr(lhsBroadcastDimensions)); - auto rhsBroadcastDimensions = llvm::to_vector<4>( - llvm::seq(resultRank - rhsType.getRank(), resultRank)); - Value broadcastedRhs = rewriter.create( - loc, - RankedTensorType::get(resultType.getShape(), rhsType.getElementType()), - rhs, resultExtents, rewriter.getI64TensorAttr(rhsBroadcastDimensions)); - - // And generate the final non-broadcasted binary op. - Value finalResult = Adaptor::createOp( - op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter); - rewriter.create(loc, finalResult); - rewriter.replaceOp(op, {assumingOp.getResult(0)}); - return success(); - } -}; - -class ConvertDynamicReshapeOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto tensor = op.getOperand(); - auto shape = op.getOutputShape(); - - auto shapeTy = shape.getType().cast(); - auto resultTy = op.getType().cast(); - - Value inputShape = rewriter.create(loc, tensor); - Value numEls = rewriter.create(loc, inputShape); - Value cstr = rewriter.create(loc, numEls, shape); - rewriter.replaceOpWithNewOp( - op, cstr, [&](OpBuilder &b, Location l) { - Value computedShape = - b.create(l, shapeTy, numEls, shape); - SmallVector result; - result.push_back(b.create(l, resultTy, tensor, - computedShape)); - return result; - }); - - return success(); - } -}; - -#include "chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc" -} // namespace - -void populateChloBroadcastingPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - // Instantiate conversion templates for conforming binary elementwise ops - // that do not have different dtypes between operands and results and do - // not have special attributes that need to be preserved. - populateForBroadcastingBinaryOp( - context, patterns, 10); - populateForBroadcastingBinaryOp( - context, patterns, 5); - patterns - ->add( - context); -} - -void populateChloLegalizeToHloBasisOpsPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - // Patterns that decompose to a basis set of HLOs - // These are guaranteed to be convertible to StableHLO, but discard some - // higher level information that is useful to XLA compilation. 
- patterns->add(context); -} - -void populateDecomposeChloPatterns(MLIRContext *context, - RewritePatternSet *patterns) { - populateWithGenerated(*patterns); - - // Other patterns. - // clang-format off - patterns->add(context); - // clang-format on -} - -} // namespace chlo -} // namespace mlir diff --git a/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc b/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc index 058e6db42..d03c05880 100644 --- a/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc +++ b/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc @@ -19,59 +19,46 @@ limitations under the License. #include "mhlo/IR/hlo_ops.h" #include "mhlo/transforms/passes.h" #include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CHLOLEGALIZETOHLOPASS -#define GEN_PASS_DEF_CHLOLEGALIZETOHLOBASISOPSPASS +#define GEN_PASS_DEF_CHLOLEGALIZETOHIGHLEVELMHLOPASS #include "mhlo/transforms/mhlo_passes.h.inc" namespace { -struct ChloLegalizeToHloPass - : public impl::ChloLegalizeToHloPassBase { - explicit ChloLegalizeToHloPass(bool legalizeBroadcasts, - bool expandCompositions) - : ChloLegalizeToHloPassBase< - ChloLegalizeToHloPass>::ChloLegalizeToHloPassBase() { - this->legalize_broadcasts_ = legalizeBroadcasts; - this->expand_compositions_ = expandCompositions; - } +struct ChloLegalizeToHighLevelMhloPass + : public impl::ChloLegalizeToHighLevelMhloPassBase< + ChloLegalizeToHighLevelMhloPass> { + using ChloLegalizeToHighLevelMhloPassBase:: + ChloLegalizeToHighLevelMhloPassBase; void runOnOperation() override { - ConversionTarget conversionTarget(getContext()); - RewritePatternSet conversionPatterns(&getContext()); - conversionTarget.addIllegalDialect(); + MLIRContext &context = getContext(); + ConversionTarget conversionTarget(context); + RewritePatternSet conversionPatterns(&context); + + chlo::populateChloToHighLevelMhloOpPatterns(&context, &conversionPatterns); // Consider the mhlo dialect legal for tests. Also add helper dialects // that are needed by the patterns. 
- conversionTarget - .addLegalDialect(); - conversionTarget.addLegalOp(); - - if (legalize_broadcasts_) { - chlo::populateChloBroadcastingPatterns(&getContext(), - &conversionPatterns); - } - - if (expand_compositions_) { - chlo::populateDecomposeChloPatterns(&getContext(), &conversionPatterns); - } else { - conversionTarget - .addLegalOp(); - } + conversionTarget.addLegalDialect(); + conversionTarget.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), conversionTarget, std::move(conversionPatterns)))) { @@ -80,29 +67,27 @@ struct ChloLegalizeToHloPass } }; -struct ChloLegalizeToHloBasisOpsPass - : public impl::ChloLegalizeToHloBasisOpsPassBase< - ChloLegalizeToHloBasisOpsPass> { - using ChloLegalizeToHloBasisOpsPassBase::ChloLegalizeToHloBasisOpsPassBase; +struct ChloLegalizeToHloPass + : public impl::ChloLegalizeToHloPassBase { + using ChloLegalizeToHloPassBase::ChloLegalizeToHloPassBase; void runOnOperation() override { - ConversionTarget conversionTarget(getContext()); - RewritePatternSet conversionPatterns(&getContext()); + MLIRContext &context = getContext(); + ConversionTarget conversionTarget(context); + RewritePatternSet conversionPatterns(&context); - // Patterns will only be applied to these ops - conversionTarget.addIllegalOp(); + stablehlo::StablehloToHloTypeConverter typeConverter; + chlo::populateChloToHloPatterns(&context, &typeConverter, + &conversionPatterns); - // Programs with MHLO equivalents to the StableHLO ops are likely bugs - // for users of this expander pass, so best to disallow. - conversionTarget.addIllegalOp(); // TODO: Add ErfOp - - // Given that the resulting patterns should be convertible to StableHLO - // Only MHLO should be legal. + // Consider the mhlo dialect legal for tests. Also add helper dialects + // that are needed by the patterns. 
conversionTarget - .addLegalDialect(); - - chlo::populateChloLegalizeToHloBasisOpsPatterns(&getContext(), - &conversionPatterns); + .addIllegalDialect(); + conversionTarget.addLegalDialect< + MhloDialect, mlir::arith::ArithDialect, mlir::func::FuncDialect, + mlir::tensor::TensorDialect, mlir::shape::ShapeDialect>(); + conversionTarget.addLegalOp(); if (failed(applyPartialConversion(getOperation(), conversionTarget, std::move(conversionPatterns)))) { @@ -113,16 +98,26 @@ struct ChloLegalizeToHloBasisOpsPass } // namespace -std::unique_ptr> createChloLegalizeToHloPass( - bool legalizeBroadcasts, bool expandCompositions) { - return std::make_unique(legalizeBroadcasts, - expandCompositions); +} // namespace mhlo + +namespace chlo { +namespace { +#include "chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc" + +} // namespace + +void populateChloToHighLevelMhloOpPatterns(MLIRContext *, + RewritePatternSet *patterns) { + populateWithGenerated(*patterns); } -std::unique_ptr> -createChloLegalizeToHloBasisOpsPass() { - return std::make_unique(); +void populateChloToHloPatterns(MLIRContext *context, + TypeConverter *typeConverter, + RewritePatternSet *patterns) { + chlo::populateChloToHighLevelMhloOpPatterns(context, patterns); + stablehlo::populateChloToStablehloPatterns(context, patterns); + stablehlo::populateStablehloToHloPatterns(patterns, typeConverter, context); } -} // namespace mhlo +} // namespace chlo } // namespace mlir diff --git a/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td b/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td index 937739f28..497686bf2 100644 --- a/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td +++ b/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td @@ -19,345 +19,22 @@ limitations under the License. // ambiguous/different for various backends. Avoid patterns that are actually // lowering to non-canonical forms. -include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/IR/OpBase.td" include "mhlo/IR/hlo_ops.td" include "stablehlo/dialect/ChloOps.td" -class MHLO_ComparisonDirectionValue : - ConstantAttr; - //===----------------------------------------------------------------------===// -// Unary op patterns. +// Direct CHLO->MHLO conversions //===----------------------------------------------------------------------===// -// Expand acos for non-complex arguments to MHLO dialect as follows: -// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 -// = pi if x == -1 -// -// TODO(b/237376133): Support operands with complex element types separately -// using the following formula. -// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) -def : Pat<(CHLO_AcosOp NonComplexElementType:$input), - (MHLO_SelectOp - (MHLO_CompareOp - $input, - (MHLO_ConstantLike<"-1"> $input), - MHLO_ComparisonDirectionValue<"NE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_MulOp - (MHLO_ConstantLike<"2"> $input), - (MHLO_Atan2Op - (MHLO_SqrtOp - (MHLO_SubtractOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_MulOp $input, $input) - ) - ), - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - $input - ) - ) - ), - (MHLO_ConstantLike<"M_PI"> $input) - )>; - -// Expand acosh to MHLO dialect as follows: -// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 -// = log(x + sqrt((x+1)*(x-1))) -// acosh(x) = nan if x < -1 -// -// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as -// log(2*x) = log(2) + log(x). (Note this works because negative x never -// overflows; x < -1 simply yields nan. 
-def : Pat<(CHLO_AcoshOp NonComplexElementType:$input), - (MHLO_SelectOp - (MHLO_CompareOp - $input, - (MHLO_ConstantLike<"-1"> $input), - MHLO_ComparisonDirectionValue<"LT">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_ConstantLike<"NAN"> $input), - (MHLO_SelectOp - (MHLO_CompareOp - $input, - (MHLO_SqrtOp - (MHLO_ConstantLikeMaxFiniteValue $input) - ), - MHLO_ComparisonDirectionValue<"GE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_AddOp - (MHLO_LogOp $input), - (MHLO_LogOp - (MHLO_ConstantLike<"2"> $input) - ) - ), - (MHLO_LogOp - (MHLO_AddOp - $input, - (MHLO_SqrtOp - (MHLO_MulOp - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - $input - ), - (MHLO_AddOp - (MHLO_ConstantLike<"-1"> $input), - $input - ) - ) - ) - ) - ) - ) - )>; - -// Expand acosh for complex arguments to MHLO dialect as -// acosh(x) = log(x + sqrt((x+1)*(x-1))) -// -// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: -// "For now, we ignore the question of overflow if x is a -// complex type, because we don't yet have exhaustive tests for complex trig -// functions". -def : Pat<(CHLO_AcoshOp ComplexElementType:$input), - (MHLO_LogOp - (MHLO_AddOp - $input, - (MHLO_SqrtOp - (MHLO_MulOp - (MHLO_AddOp - $input, - (MHLO_ConstantLike<"1"> $input) - ), - (MHLO_SubtractOp - $input, - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - )>; - - -// Expand asin to MHLO dialect as follows: -// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) -def : Pat<(CHLO_AsinOp $input), - (MHLO_MulOp - (MHLO_ConstantLike<"2"> $input), - (MHLO_Atan2Op - $input, - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_SqrtOp - (MHLO_SubtractOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_MulOp $input, $input) - ) - ) - ) - ) - )>; - -// Expand asinh for non-complex arguments to MHLO dialect as -// asinh(x) = log(x + sqrt(x^2 + 1)) -// -// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) -// as 2*x and return log(2) + log(x). -// -// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point -// arithmetic. However, we would like to retain the low order term of this, -// which is around 0.5 * x^2 using a binomial expansion. -// Let z = sqrt(a^2 + 1) -// The following rewrite retains the lower order term. -// log(a + sqrt(a^2 + 1)) -// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) -// = log((a + a^2 + 1 + a * z + z) / (1 + z)) -// = log(1 + a + a^2 / (1 + z)) -// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) -// -// If x is negative, the above would give us some trouble; we can't approximate -// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = -// -asinh(x). 
-def : Pat<(CHLO_AsinhOp NonComplexElementType:$input), - (MHLO_MulOp - (MHLO_SignOp $input), - (MHLO_SelectOp - (MHLO_CompareOp - (MHLO_AbsOp $input), - (MHLO_SqrtOp - (MHLO_ConstantLikeMaxFiniteValue $input) - ), - MHLO_ComparisonDirectionValue<"GE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_AddOp - (MHLO_LogOp - (MHLO_AbsOp $input) - ), - (MHLO_LogOp - (MHLO_ConstantLike<"2"> $input) - ) - ), - (MHLO_SelectOp - (MHLO_CompareOp - (MHLO_AbsOp $input), - (MHLO_ConstantLike<"1"> $input), - MHLO_ComparisonDirectionValue<"LE">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_Log1pOp - (MHLO_AddOp - (MHLO_AbsOp $input), - (MHLO_MulOp - (MHLO_AbsOp $input), - (MHLO_DivOp - (MHLO_AbsOp $input), - (MHLO_AddOp - (MHLO_ConstantLike<"1"> $input), - (MHLO_SqrtOp - (MHLO_AddOp - (MHLO_MulOp - (MHLO_AbsOp $input), - (MHLO_AbsOp $input) - ), - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - ) - ) - ), - (MHLO_LogOp - (MHLO_AddOp - (MHLO_AbsOp $input), - (MHLO_SqrtOp - (MHLO_AddOp - (MHLO_MulOp - (MHLO_AbsOp $input), - (MHLO_AbsOp $input) - ), - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - ) - ) - )>; - -// Expand asinh for complex arguments to MHLO dialect as -// asinh(x) = log(x + sqrt(x^2 + 1)) -// -// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: -// "For now, we ignore the question of overflow if x is a -// complex type, because we don't yet have exhaustive tests for complex trig -// functions". -def : Pat<(CHLO_AsinhOp ComplexElementType:$input), - (MHLO_LogOp - (MHLO_AddOp - $input, - (MHLO_SqrtOp - (MHLO_AddOp - (MHLO_MulOp $input, $input), - (MHLO_ConstantLike<"1"> $input) - ) - ) - ) - )>; - -// Express `atan` as -// atan(x) = atan2(x, 1) -def : Pat<(CHLO_AtanOp $input), - (MHLO_Atan2Op - $input, - (MHLO_ConstantLike<"1"> $input) - )>; - -// Express `atanh` for non-complex arguments as follows: -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 -// atanh(x) = nan otherwise -def : Pat<(CHLO_AtanhOp NonComplexElementType:$input), - (MHLO_SelectOp - (MHLO_CompareOp - (MHLO_AbsOp $input), - (MHLO_ConstantLike<"1"> $input), - MHLO_ComparisonDirectionValue<"GT">, - (MHLO_DEFAULT_COMPARISON_TYPE) - ), - (MHLO_ConstantLike<"NAN"> $input), - (MHLO_MulOp - (MHLO_SubtractOp - (MHLO_Log1pOp $input), - (MHLO_Log1pOp - (MHLO_NegOp $input) - ) - ), - (MHLO_ConstantLike<"0.5"> $input) - ) - )>; - -// Express `atanh` for complex arguments as follows: -// atanh(x) = (log(1 + x) - log(1 + (-x))) * 0.5 -// -// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing: -// "For now, we ignore the nan edge case for complex inputs, -// because we don't yet have exhaustive tests for complex trig functions". -def : Pat<(CHLO_AtanhOp ComplexElementType:$input), - (MHLO_MulOp - (MHLO_SubtractOp - (MHLO_Log1pOp $input), - (MHLO_Log1pOp - (MHLO_NegOp $input) - ) - ), - (MHLO_ConstantLike<"0.5"> $input) - )>; - -// Express `conj` as -// conj(x) = (re(x), -im(x)). 
-def : Pat<(CHLO_ConjOp $v), - (MHLO_ComplexOp (MHLO_RealOp $v), (MHLO_NegOp (MHLO_ImagOp $v)))>; - -// Express `is_inf` as -// is_inf(x) = is_pos_inf(|x|) -def : Pat<(CHLO_IsInfOp NonComplexElementType:$input), - (CHLO_IsPosInfOp - (MHLO_AbsOp $input) - )>; - -// Express `is_pos_inf` as -// is_pos_inf(x) = (x == +inf) -def : Pat<(CHLO_IsPosInfOp NonComplexElementType:$input), - (MHLO_CompareOp - $input, - (MHLO_ConstantLikePosInfValue $input), - MHLO_ComparisonDirectionValue<"EQ">, - (MHLO_DEFAULT_COMPARISON_TYPE) - )>; - -// Express `is_neg_inf` as -// is_neg_inf(x) = (x == -inf) -def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input), - (MHLO_CompareOp - $input, - (MHLO_ConstantLikeNegInfValue $input), - MHLO_ComparisonDirectionValue<"EQ">, - (MHLO_DEFAULT_COMPARISON_TYPE) - )>; - -def : Pat<(CHLO_ConstantOp $v), - (MHLO_ConstantOp $v)>; - def : Pat<(CHLO_TanOp $v), - (MHLO_TanOp $v)>; + (MHLO_TanOp $v), + [], [], (addBenefit 10)>; def : Pat<(CHLO_ErfOp $v), - (MHLO_ErfOp $v)>; + (MHLO_ErfOp $v), + [], [], (addBenefit 10)>; def : Pat<(CHLO_TopKOp AnyRankedTensor:$v, $k), - (MHLO_TopKOp $v, $k, ConstBoolAttrTrue)>; + (MHLO_TopKOp $v, $k, ConstBoolAttrTrue), + [], [], (addBenefit 10)>; diff --git a/mhlo/transforms/mhlo_passes.td b/mhlo/transforms/mhlo_passes.td index 5531d568a..62868358f 100644 --- a/mhlo/transforms/mhlo_passes.td +++ b/mhlo/transforms/mhlo_passes.td @@ -15,29 +15,28 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> { - let summary = "Legalize CHLO to HLO."; - let constructor = "createChloLegalizeToHloPass()"; - let dependentDialects = ["mhlo::MhloDialect", "chlo::ChloDialect", - "shape::ShapeDialect", "scf::SCFDialect"]; - let options = [ - Option<"legalize_broadcasts_", "legalize-broadcasts", "bool", - /*default=*/"true", "Legalize implicit broadcasts to explicit HLO broadcasting forms">, - Option<"expand_compositions_", "expand-compositions", "bool", - /*default=*/"true", "Expands client-centric compositions to HLO primitives">, - ]; +def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "func::FuncOp"> { + let summary = "Legalize CHLO ops with direct XLA counterparts, like TopK and Erf."; + let description = [{ + Performs direct legalization of CHLO->MHLO only for high-level (non-basis) + ops with XLA support. These are MHLO ops that directly model the CHLO op, + such as TopK and Erf. + }]; + let dependentDialects = ["mhlo::MhloDialect"]; } -def ChloLegalizeToHloBasisOpsPass : Pass<"chlo-legalize-to-hlo-basis-ops", "func::FuncOp"> { - let summary = "Legalize specific CHLO ops (e.g. ErfOp and TopKOp) to basis MHLO ops."; +def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> { + let summary = "Legalize CHLO to MHLO with XLA-supported ops."; let description = [{ - XLA has specialization for certain CHLO ops (ErfOp, TopKOp), and other - backends still require decomposition of these ops into the basis set which - can be converted safely to StableHLO. This pass is needed until we have - direct CHLO to StableHLO lowerings. + Performs legalization of CHLO->StableHLO->MHLO, while also preserving MHLO + high-level operations when possible (see ChloLegalizeToHighLevelMhloPass).
}]; - let constructor = "createChloLegalizeToHloBasisOpsPass()"; - let dependentDialects = ["mhlo::MhloDialect", "chlo::ChloDialect"]; + let dependentDialects = [ + "mhlo::MhloDialect", + "mlir::shape::ShapeDialect", + "mlir::stablehlo::StablehloDialect", + "mlir::tensor::TensorDialect" + ]; } def HloCanonicalizeScatterPass : Pass<"hlo-canonicalize-scatter", "func::FuncOp"> { diff --git a/mhlo/transforms/passes.h b/mhlo/transforms/passes.h index a52c5b4ee..b9a025cdb 100644 --- a/mhlo/transforms/passes.h +++ b/mhlo/transforms/passes.h @@ -49,14 +49,6 @@ std::unique_ptr> createLegalizeSortPass(); /// Lowers from HLO dialect to Standard dialect. std::unique_ptr> createLegalizeToStdPass(); -/// Lowers from the CHLO dialect to the HLO dialect. -std::unique_ptr> createChloLegalizeToHloPass( - bool legalizeBroadcasts = true, bool expandCompositions = true); - -/// Lowers specific ops from the CHLO dialect to an HLO basis opset -std::unique_ptr> -createChloLegalizeToHloBasisOpsPass(); - // Lowers from sparse ops in CHLO dialect to Linalg dialect. std::unique_ptr> createLegalizeSparseOperationsPass( bool legalizeToCustomCalls = true); diff --git a/mhlo/transforms/rewriters.h b/mhlo/transforms/rewriters.h index 4de69f27d..14e3add6f 100644 --- a/mhlo/transforms/rewriters.h +++ b/mhlo/transforms/rewriters.h @@ -191,23 +191,16 @@ void populateLegalizeSparseOpsToCustomCallPatterns(MLIRContext *context, namespace chlo { -// Populates a collection of conversion patterns for legalizing broadcasting -// client-HLO to their non-broadcasting counterparts. -void populateChloBroadcastingPatterns(MLIRContext *context, - RewritePatternSet *patterns); +// Populates direct translations between CHLO and MHLO ops for higher level +// MHLO ops like TopK and Erf. +void populateChloToHighLevelMhloOpPatterns(MLIRContext *context, + RewritePatternSet *patterns); -// Populates a collection of conversion patterns for legalizing client-HLO to -// HLO by decomposing client-operations to corresponding sequences of more -// primitive operations. This does not include the -// PopulateChloBroadcastingPatterns above. -void populateDecomposeChloPatterns(MLIRContext *context, - RewritePatternSet *patterns); - -// Adds pattern to decompose specific CHLO ops like ErfOp and TopKOp to their -// basis set of operations. These ops have 1:1 corresponding MHLO ops, but for -// certain backends, they need to be expanded. -void populateChloLegalizeToHloBasisOpsPatterns(MLIRContext *context, - RewritePatternSet *patterns); +// Populates the combined CHLO->MHLO lowering: the direct high-level op +// translations above plus the CHLO->StableHLO->MHLO decomposition patterns. +void populateChloToHloPatterns(MLIRContext *context, + TypeConverter *typeConverter, + RewritePatternSet *patterns); } // namespace chlo diff --git a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp b/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp index 1ef587e4d..c0df37115 100644 --- a/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp +++ b/stablehlo/stablehlo/experimental/transforms/ChloRecomposeOps.cpp @@ -20,12 +20,14 @@ limitations under the License.
#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/experimental/dialect/StablehloOps.h" #include "stablehlo/experimental/transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" namespace mlir { namespace stablehlo { @@ -163,6 +165,14 @@ struct ChloRecomposeOpsPass } }; +void createChloLegalizeToStablehloPipeline(OpPassManager& pm) { + pm.addPass(mlir::stablehlo::experimental::createChloRecomposeOpsPass()); + pm.addNestedPass( + mlir::stablehlo::createChloLegalizeToStablehloPass()); + pm.addNestedPass( + mlir::stablehlo::createShapeLegalizeToStablehloPass()); +} + } // namespace experimental } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/stablehlo/experimental/transforms/Passes.h b/stablehlo/stablehlo/experimental/transforms/Passes.h index c4c9dcded..0a75849c6 100644 --- a/stablehlo/stablehlo/experimental/transforms/Passes.h +++ b/stablehlo/stablehlo/experimental/transforms/Passes.h @@ -29,6 +29,8 @@ namespace experimental { #define GEN_PASS_REGISTRATION #include "stablehlo/experimental/transforms/Passes.h.inc" +void createChloLegalizeToStablehloPipeline(OpPassManager &pm); + } // namespace experimental } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir index 663e38b64..53c635d0f 100644 --- a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +++ b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir @@ -1324,99 +1324,99 @@ func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_40:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] // CHECK: %[[TMP_41:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_40]] // CHECK: %[[TMP_42:.*]] = stablehlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_43:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_42]] + // CHECK: %[[TMP_43:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_42]] // CHECK: %[[TMP_44:.*]] = stablehlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_45:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_44]] + // CHECK: %[[TMP_45:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_44]] // CHECK: %[[TMP_46:.*]] = stablehlo.multiply %[[TMP_43]], %[[TMP_45]] // CHECK: %[[TMP_47:.*]] = stablehlo.constant dense<-1.39544646E-19> // CHECK: %[[TMP_48:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_47]] // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_48]] // CHECK: %[[TMP_50:.*]] = stablehlo.multiply %[[TMP_46]], %[[TMP_49]] // CHECK: %[[TMP_51:.*]] = stablehlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_52:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_51]] + // CHECK: %[[TMP_52:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_51]] // CHECK: %[[TMP_53:.*]] = stablehlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_54:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_53]] + // CHECK: %[[TMP_54:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_53]] // CHECK: %[[TMP_55:.*]] = stablehlo.multiply %[[TMP_52]], %[[TMP_54]] // CHECK: %[[TMP_56:.*]] = stablehlo.constant dense<5.50900303E-18> // CHECK: %[[TMP_57:.*]] = stablehlo.add %[[TMP_50]], %[[TMP_56]] // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_57]] // CHECK: %[[TMP_59:.*]] = stablehlo.multiply %[[TMP_55]], %[[TMP_58]] // CHECK: 
%[[TMP_60:.*]] = stablehlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_61:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_60]] + // CHECK: %[[TMP_61:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_60]] // CHECK: %[[TMP_62:.*]] = stablehlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_63:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_62]] + // CHECK: %[[TMP_63:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_62]] // CHECK: %[[TMP_64:.*]] = stablehlo.multiply %[[TMP_61]], %[[TMP_63]] // CHECK: %[[TMP_65:.*]] = stablehlo.constant dense<-2.17486866E-16> // CHECK: %[[TMP_66:.*]] = stablehlo.add %[[TMP_59]], %[[TMP_65]] // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_66]] // CHECK: %[[TMP_68:.*]] = stablehlo.multiply %[[TMP_64]], %[[TMP_67]] // CHECK: %[[TMP_69:.*]] = stablehlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_70:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_69]] + // CHECK: %[[TMP_70:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_69]] // CHECK: %[[TMP_71:.*]] = stablehlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_72:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_71]] + // CHECK: %[[TMP_72:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_71]] // CHECK: %[[TMP_73:.*]] = stablehlo.multiply %[[TMP_70]], %[[TMP_72]] // CHECK: %[[TMP_74:.*]] = stablehlo.constant dense<8.58606213E-15> // CHECK: %[[TMP_75:.*]] = stablehlo.add %[[TMP_68]], %[[TMP_74]] // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_75]] // CHECK: %[[TMP_77:.*]] = stablehlo.multiply %[[TMP_73]], %[[TMP_76]] // CHECK: %[[TMP_78:.*]] = stablehlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_79:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_78]] + // CHECK: %[[TMP_79:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_78]] // CHECK: %[[TMP_80:.*]] = stablehlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_81:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_80]] + // CHECK: %[[TMP_81:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_80]] // CHECK: %[[TMP_82:.*]] = stablehlo.multiply %[[TMP_79]], %[[TMP_81]] // CHECK: %[[TMP_83:.*]] = stablehlo.constant dense<-3.3896803E-13> // CHECK: %[[TMP_84:.*]] = stablehlo.add %[[TMP_77]], %[[TMP_83]] // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_84]] // CHECK: %[[TMP_86:.*]] = stablehlo.multiply %[[TMP_82]], %[[TMP_85]] // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_88:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_87]] + // CHECK: %[[TMP_88:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_87]] // CHECK: %[[TMP_89:.*]] = stablehlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_90:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_89]] + // CHECK: %[[TMP_90:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_89]] // CHECK: %[[TMP_91:.*]] = stablehlo.multiply %[[TMP_88]], %[[TMP_90]] // CHECK: %[[TMP_92:.*]] = stablehlo.constant dense<1.33825364E-11> // CHECK: %[[TMP_93:.*]] = stablehlo.add %[[TMP_86]], %[[TMP_92]] // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_93]] // CHECK: %[[TMP_95:.*]] = stablehlo.multiply %[[TMP_91]], %[[TMP_94]] // CHECK: %[[TMP_96:.*]] = stablehlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_97:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_96]] + // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_96]] // CHECK: %[[TMP_98:.*]] = stablehlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_99:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_98]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_98]] // CHECK: %[[TMP_100:.*]] = stablehlo.multiply %[[TMP_97]], %[[TMP_99]] // CHECK: %[[TMP_101:.*]] = stablehlo.constant 
dense<-5.28419031E-10> // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_95]], %[[TMP_101]] // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_102]] // CHECK: %[[TMP_104:.*]] = stablehlo.multiply %[[TMP_100]], %[[TMP_103]] // CHECK: %[[TMP_105:.*]] = stablehlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_106:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_105]] + // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_105]] // CHECK: %[[TMP_107:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_108:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_107]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_107]] // CHECK: %[[TMP_109:.*]] = stablehlo.multiply %[[TMP_106]], %[[TMP_108]] // CHECK: %[[TMP_110:.*]] = stablehlo.constant dense<2.08767563E-8> // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_104]], %[[TMP_110]] // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_111]] // CHECK: %[[TMP_113:.*]] = stablehlo.multiply %[[TMP_109]], %[[TMP_112]] // CHECK: %[[TMP_114:.*]] = stablehlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_115:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_114]] + // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_114]] // CHECK: %[[TMP_116:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_117:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_116]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_116]] // CHECK: %[[TMP_118:.*]] = stablehlo.multiply %[[TMP_115]], %[[TMP_117]] // CHECK: %[[TMP_119:.*]] = stablehlo.constant dense<-8.26719599E-7> // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_113]], %[[TMP_119]] // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_120]] // CHECK: %[[TMP_122:.*]] = stablehlo.multiply %[[TMP_118]], %[[TMP_121]] // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_123]] + // CHECK: %[[TMP_124:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_123]] // CHECK: %[[TMP_125:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_126:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_125]] + // CHECK: %[[TMP_126:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_125]] // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_124]], %[[TMP_126]] // CHECK: %[[TMP_128:.*]] = stablehlo.constant dense<3.30687835E-5> // CHECK: %[[TMP_129:.*]] = stablehlo.add %[[TMP_122]], %[[TMP_128]] // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_129]] // CHECK: %[[TMP_131:.*]] = stablehlo.multiply %[[TMP_127]], %[[TMP_130]] // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_133:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_132]] + // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_132]] // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_135:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_134]] + // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_134]] // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_135]] // CHECK: %[[TMP_137:.*]] = stablehlo.constant dense<-0.00138888892> // CHECK: %[[TMP_138:.*]] = stablehlo.add %[[TMP_131]], %[[TMP_137]] @@ -1600,99 +1600,99 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_131:.*]] = stablehlo.subtract 
%[[TMP_5]], %[[TMP_130]] + // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_133:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_132]] + // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.39544646E-19> // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_140:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_139]] + // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_142:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_141]] + // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.50900303E-18> // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_149:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_148]] + // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_151:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_150]] + // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.17486866E-16> // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_158:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_157]] + // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_160:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_159]] + // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.58606213E-15> // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_167:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_166]] + // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_169:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_168]] + // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] // CHECK: %[[TMP_171:.*]] = stablehlo.constant 
dense<-3.3896803E-13> // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_176:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_175]] + // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_178:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_177]] + // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.33825364E-11> // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_185:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_184]] + // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_187:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_186]] + // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.28419031E-10> // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_194:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_193]] + // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_196:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_195]] + // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767563E-8> // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_203:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_202]] + // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_205:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_204]] + // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.26719599E-7> // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_212:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_211]] + // CHECK: %[[TMP_212:.*]] = stablehlo.add 
%[[TMP_5]], %[[TMP_211]] // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_214:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_213]] + // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.30687835E-5> // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_221:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_220]] + // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_223:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_222]] + // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.00138888892> // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] @@ -1988,99 +1988,99 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_131:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_130]] + // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_133:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_132]] + // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_140:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_139]] + // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_142:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_141]] + // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.5090028283602295E-18> // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_149:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_148]] + // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_151:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_150]] + // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] // CHECK: 
%[[TMP_153:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_158:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_157]] + // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_160:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_159]] + // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.5860620562778452E-15> // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_167:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_166]] + // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_169:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_168]] + // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_176:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_175]] + // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_178:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_177]] + // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.3382536530684679E-11> // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_185:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_184]] + // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_187:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_186]] + // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_194:.*]] = stablehlo.subtract 
%[[TMP_5]], %[[TMP_193]] + // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_196:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_195]] + // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767569878681E-8> // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_203:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_202]] + // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_205:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_204]] + // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_212:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_211]] + // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_214:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_213]] + // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.3068783068783071E-5> // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_221:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_220]] + // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_223:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_222]] + // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.0013888888888888889> // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] diff --git a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index 2cad98b5e..018382bbd 100644 --- a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -1575,6 +1575,7 @@ static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc, static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + // Code should match XLA's materializeZeta from chlo_legalize_to_hlo.cc assert(args.size() == 2); Value x = args[0]; Value q = args[1]; @@ -1629,9 +1630,9 @@ static Value 
materializeZeta(ConversionPatternRewriter &rewriter, Location loc, // Using Horner's rule allows to avoid some NaN's and Infs from happening, // resulting in more numerically stable code. for (int i = 0; i < 11; ++i) { - Value factorLhs = rewriter.create( + Value factorLhs = rewriter.create( loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x)); - Value factorRhs = rewriter.create( + Value factorRhs = rewriter.create( loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x)); factor = rewriter.create(loc, factorLhs, factorRhs); hornerSum = rewriter.create( diff --git a/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir b/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir deleted file mode 100644 index 512e1cea6..000000000 --- a/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir +++ /dev/null @@ -1,345 +0,0 @@ -// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true expand-compositions=false" -cse -canonicalize -split-input-file -verify-diagnostics %s -o - | FileCheck %s - -// Check the non-broadcast case for each registered op, then just check a -// representative op for detailed broadcast semantics. -// CHECK-LABEL: @addWithoutBroadcast -func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.add %arg0, %arg1 - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @dynamicBroadcast -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- -// CHECK-LABEL: @dynamicBroadcastComplex -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-NEXT: %[[RESULT:.+]] = mhlo.complex 
%[[ARG0_B]], %[[ARG1_B]] : tensor> - // CHECK-NEXT: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> - %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> - func.return %0 : tensor> -} - -// ----- -// CHECK-LABEL: @dynamicBroadcastCompare -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] - // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] - // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.+]] = mhlo.compare EQ, %[[ARG0_B]], %[[ARG1_B]] : (tensor, tensor) -> tensor - // CHECK: shape.assuming_yield %[[RESULT]] - // CHECK-NEXT: } - // CHECK: return %[[FINAL_RESULT]] : tensor - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @selectv2 -func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: mhlo.select %arg0, %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - func.return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @selectv2_pred_scalar -func.func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: mhlo.select %arg0, %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - func.return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_then -func.func @selectv2_broadcast_then(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> - // CHECK-NEXT: mhlo.select %arg0, %[[BROADCAST]], %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> - func.return %0: tensor<2x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_else -func.func @selectv2_broadcast_else(%arg0: tensor, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> - // CHECK-NEXT: mhlo.select %arg0, %arg1, %[[BROADCAST]] - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> - func.return %0: tensor<2x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_pred -func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK-NEXT: %[[BROADCAST:.*]] = 
"mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1> - // CHECK-NEXT: mhlo.select %[[BROADCAST]], %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> - func.return %0: tensor<2x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_broadcast_tensor_pred -func.func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> - // CHECK-NEXT: mhlo.select %[[BROADCAST]], %arg1, %arg2 - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> - func.return %0: tensor<2x3xf16> -} - -// CHECK-LABEL: func @selectv2_broadcast_all -func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { - // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> - // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32> - // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32> - // CHECK: mhlo.select %[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]] - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> - func.return %0: tensor<8x8x8xi32> -} - -// CHECK-LABEL: func @selectv2_dynamic_ranked -func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { - // CHECK-DAG: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex> - // CHECK-DAG: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex> - // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex> - // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex> - // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) { - // CHECK-NEXT: %[[BCST:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> - // CHECK-NEXT: %[[BCST0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[BCST]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1> - // CHECK-NEXT: %[[BCST1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32> - // CHECK-NEXT: %[[BCST2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, %[[BCST]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32> - // CHECK-NEXT: %[[SELECT:.*]] = mhlo.select %[[BCST0]], %[[BCST1]], %[[BCST2]] : tensor<2x?x8xi1>, tensor<2x?x8xi32> - // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32> - // CHECK-NEXT: } - // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32> - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, 
tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> - func.return %0: tensor<2x?x8xi32> -} - -// ----- -// Verifies that broadcast_dimensions validity checks are valid. -// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions -func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that broadcast_dimensions validity checks are valid. -// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions -func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { - // CHECK: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that invalid broadcast dimensions are rejected. -func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} - // expected-error @+1 {{failed to legalize operation}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that invalid broadcast dimensions are rejected. -func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} - // expected-error @+1 {{failed to legalize operation}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} - -// ----- -// Note that broadcast_add is used as a proxy for all of the template -// expansions. Tests below merely verify that the op has an expansion. 
-// CHECK-LABEL: @andWithoutBroadcast -func.func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: mhlo.and %arg0, %arg1 - %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @atan2WithoutBroadcast -func.func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.atan2 %arg0, %arg1 - %0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @compareWithoutBroadcast -func.func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { - // CHECK: mhlo.compare EQ, %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @complexWithoutBroadcast -func.func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { - // CHECK: mhlo.complex %arg0, %arg1 : tensor<4xcomplex> - %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - func.return %0 : tensor<4xcomplex> -} - -// ----- -// CHECK-LABEL: @divideWithoutBroadcast -func.func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.divide %arg0, %arg1 - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @maximumWithoutBroadcast -func.func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.maximum %arg0, %arg1 - %0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @minimumWithoutBroadcast -func.func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.minimum %arg0, %arg1 - %0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @multiplyWithoutBroadcast -func.func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.multiply %arg0, %arg1 - %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @orWithoutBroadcast -func.func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: mhlo.or %arg0, %arg1 - %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @powerWithoutBroadcast -func.func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.power %arg0, %arg1 - %0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @remainderWithoutBroadcast -func.func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.remainder %arg0, %arg1 - %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : 
tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @shift_leftWithoutBroadcast -func.func @shift_leftWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: mhlo.shift_left %arg0, %arg1 - %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- -// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast -func.func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 - %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- -// CHECK-LABEL: @shift_right_logicalWithoutBroadcast -func.func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: mhlo.shift_right_logical %arg0, %arg1 - %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// ----- -// CHECK-LABEL: @subWithoutBroadcast -func.func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: mhlo.subtract %arg0, %arg1 - %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @xorWithoutBroadcast -func.func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK: mhlo.xor %arg0, %arg1 - %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @NextAfterWithoutBroadcast -// CHECK-SAME: (%[[LHS:.*]]: tensor<4xf32>, %[[RHS:.*]]: tensor<4xf32>) -func.func @NextAfterWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: chlo.next_after %[[LHS]], %[[RHS]] - %0 = chlo.broadcast_next_after %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @PolygammaWithoutBroadcast -// CHECK-SAME: (%[[LHS:.*]]: tensor<4xf32>, %[[RHS:.*]]: tensor<4xf32>) -func.func @PolygammaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: chlo.polygamma %[[LHS]], %[[RHS]] - %0 = chlo.broadcast_polygamma %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @ZetaWithoutBroadcast -func.func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: chlo.zeta %arg0, %arg1 - %0 = chlo.broadcast_zeta %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} diff --git a/tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir b/tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir deleted file mode 100644 index 2a22006f2..000000000 --- a/tests/Dialect/chlo/chlo_legalize_to_hlo_no_broadcasts.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=false" %s | FileCheck %s - -// CHECK-LABEL: atan_static -// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x4xf32> -func.func @atan_static(%arg0: tensor<2x3x4xf32>) -> tuple> { - // CHECK: %[[CST:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x3x4xf32> - // CHECK: mhlo.atan2 %[[ARG]], %[[CST]] : tensor<2x3x4xf32> - %0 = chlo.atan %arg0 : tensor<2x3x4xf32> -> 
tensor<2x3x4xf32> - %1 = "mhlo.tuple"(%0) : (tensor<2x3x4xf32>) -> tuple> - func.return %1 : tuple> -} diff --git a/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 44b10a5eb..4d9835dc8 100644 --- a/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -1,4 +1,5 @@ // RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-context=20 +// RUN: mlir-hlo-opt --chlo-legalize-to-high-level-mhlo --split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-HIGH-LEVEL // CHECK-LABEL: func.func @asin_bf16( // CHECK-SAME: %[[TMP_arg0:.*]]: tensor @@ -262,6 +263,7 @@ func.func @conj(%arg0: tensor<3xcomplex>) -> tensor<3xcomplex> { // CHECK-LABEL: @erf_f64 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f64(%arg : tensor) -> tensor { + // CHECK-HIGH-LEVEL: mhlo.erf // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor @@ -273,6 +275,7 @@ func.func @erf_f64(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_f32 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f32(%arg : tensor) -> tensor { + // CHECK-HIGH-LEVEL: mhlo.erf // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor @@ -284,6 +287,7 @@ func.func @erf_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_f16 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_f16(%arg : tensor) -> tensor { + // CHECK-HIGH-LEVEL: mhlo.erf // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor @@ -295,6 +299,7 @@ func.func @erf_f16(%arg : tensor) -> tensor { // CHECK-LABEL: @erf_bf16 // CHECK-SAME: %[[ARG:.*]]: tensor func.func @erf_bf16(%arg : tensor) -> tensor { + // CHECK-HIGH-LEVEL: mhlo.erf // CHECK: %[[RESULT:.*]] = mhlo.erf %[[ARG]] // CHECK: return %[[RESULT]] %1 = "chlo.erf"(%arg) : (tensor) -> tensor @@ -2256,12 +2261,9 @@ func.func @next_after_f32(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> // CHECK-LABEL: @tan_f16 // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @tan_f16(%arg : tensor) -> tensor { - // %[[TMP_0:.*]] = mhlo.convert [[ARG]] : (tensor) -> tensor - // %[[TMP_1:.*]] = mhlo.sine %[[TMP_0]] - // %[[TMP_2:.*]] = mhlo.cosine %[[TMP_0]] - // %[[TMP_3:.*]] = mhlo.divide %[[TMP_1]], %[[TMP_2]] - // %[[TMP_4:.*]] = mhlo.convert %[[TMP_3]] : (tensor) -> tensor - // return %[[TMP_4]] : tensor + // CHECK-HIGH-LEVEL: mhlo.tan + // CHECK: %[[RESULT:.*]] = mhlo.tan %[[ARG]] : tensor + // CHECK: return %[[RESULT]] %1 = chlo.tan %arg : tensor -> tensor func.return %1 : tensor } @@ -2271,10 +2273,9 @@ func.func @tan_f16(%arg : tensor) -> tensor { // CHECK-LABEL: @tan_f32 // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @tan_f32(%arg : tensor) -> tensor { - // %[[TMP_0:.*]] = mhlo.sine %[[ARG]] - // %[[TMP_1:.*]] = mhlo.cosine %[[ARG]] - // %[[TMP_2:.*]] = mhlo.divide %[[TMP_0]], %[[TMP_1]] - // return %[[TMP_2]] : tensor + // CHECK-HIGH-LEVEL: mhlo.tan + // CHECK: %[[RESULT:.*]] = mhlo.tan %[[ARG]] : tensor + // CHECK: return %[[RESULT]] %1 = chlo.tan %arg : tensor -> tensor func.return %1 : tensor } @@ -2284,6 +2285,7 @@ func.func @tan_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @top_k // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { + // CHECK-HIGH-LEVEL: mhlo.topk // CHECK: %values, %indices = mhlo.topk(%arg0, k = 
8, largest = true) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> @@ -2295,6 +2297,7 @@ func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32 // CHECK-SAME: ([[ARG:%.*]]: tensor // CHECK-SAME: -> (tensor, tensor) func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { + // CHECK-HIGH-LEVEL: mhlo.topk // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2, largest = true) : tensor -> (tensor, tensor) %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) return %values, %indices : tensor, tensor diff --git a/tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir b/tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir deleted file mode 100644 index 0a9b7eef4..000000000 --- a/tests/Dialect/chlo/chlo_legalize_to_mhlo_basis_ops.mlir +++ /dev/null @@ -1,276 +0,0 @@ -// RUN: mlir-hlo-opt --chlo-legalize-to-hlo-basis-ops --chlo-legalize-to-hlo --split-input-file -verify-diagnostics %s | FileCheck %s - -// ----- - -// CHECK-LABEL: @erf_f64 -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @erf_f64(%arg : tensor) -> tensor { - // CHECK: %[[TMP_0:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] - // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<9.6049737398705161> - // CHECK: %[[TMP_5:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_0]] - // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<90.026019720384269> - // CHECK: %[[TMP_7:.*]] = mhlo.add %[[TMP_5]], %[[TMP_6]] - // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_7]], %[[TMP_0]] - // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2232.0053459468431> - // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] - // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_0]] - // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<7003.3251411280507> - // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] - // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_0]] - // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<55592.301301039493> - // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] - // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[ARG]], %[[TMP_16]] - // CHECK: %[[TMP_20:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_22:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_0]] - // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<33.561714164750313> - // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_22]], %[[TMP_23]] - // CHECK: %[[TMP_25:.*]] = mhlo.multiply %[[TMP_24]], %[[TMP_0]] - // CHECK: %[[TMP_26:.*]] = mhlo.constant dense<521.35794978015269> - // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_25]], %[[TMP_26]] - // CHECK: %[[TMP_28:.*]] = mhlo.multiply %[[TMP_27]], %[[TMP_0]] - // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4594.3238297098014> - // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_28]], %[[TMP_29]] - // CHECK: %[[TMP_31:.*]] = mhlo.multiply %[[TMP_30]], %[[TMP_0]] - // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<22629.000061389095> - // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_31]], %[[TMP_32]] - // CHECK: %[[TMP_34:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_0]] - // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<49267.394260863592> - // CHECK: %[[TMP_36:.*]] = mhlo.add %[[TMP_34]], %[[TMP_35]] - // CHECK: %[[TMP_37:.*]] = mhlo.divide %[[TMP_17]], %[[TMP_36]] - // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[ARG]], %[[ARG]] - // CHECK: %[[TMP_40:.*]] = mhlo.negate %[[TMP_39]] - // CHECK: %[[TMP_41:.*]] = 
mhlo.exponential %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_45:.*]] = mhlo.constant dense<2.4619698147353052E-10> - // CHECK: %[[TMP_47:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_42]] - // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<0.56418956483106886> - // CHECK: %[[TMP_49:.*]] = mhlo.add %[[TMP_47]], %[[TMP_48]] - // CHECK: %[[TMP_50:.*]] = mhlo.multiply %[[TMP_49]], %[[TMP_42]] - // CHECK: %[[TMP_51:.*]] = mhlo.constant dense<7.4632105644226989> - // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_50]], %[[TMP_51]] - // CHECK: %[[TMP_53:.*]] = mhlo.multiply %[[TMP_52]], %[[TMP_42]] - // CHECK: %[[TMP_54:.*]] = mhlo.constant dense<48.637197098568137> - // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_53]], %[[TMP_54]] - // CHECK: %[[TMP_56:.*]] = mhlo.multiply %[[TMP_55]], %[[TMP_42]] - // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<196.5208329560771> - // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_56]], %[[TMP_57]] - // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_42]] - // CHECK: %[[TMP_60:.*]] = mhlo.constant dense<526.44519499547732> - // CHECK: %[[TMP_61:.*]] = mhlo.add %[[TMP_59]], %[[TMP_60]] - // CHECK: %[[TMP_62:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_42]] - // CHECK: %[[TMP_63:.*]] = mhlo.constant dense<934.52852717195765> - // CHECK: %[[TMP_64:.*]] = mhlo.add %[[TMP_62]], %[[TMP_63]] - // CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_64]], %[[TMP_42]] - // CHECK: %[[TMP_66:.*]] = mhlo.constant dense<1027.5518868951572> - // CHECK: %[[TMP_67:.*]] = mhlo.add %[[TMP_65]], %[[TMP_66]] - // CHECK: %[[TMP_68:.*]] = mhlo.multiply %[[TMP_67]], %[[TMP_42]] - // CHECK: %[[TMP_69:.*]] = mhlo.constant dense<557.53533536939938> - // CHECK: %[[TMP_70:.*]] = mhlo.add %[[TMP_68]], %[[TMP_69]] - // CHECK: %[[TMP_71:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_70]] - // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_74]], %[[TMP_42]] - // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<13.228195115474499> - // CHECK: %[[TMP_78:.*]] = mhlo.add %[[TMP_76]], %[[TMP_77]] - // CHECK: %[[TMP_79:.*]] = mhlo.multiply %[[TMP_78]], %[[TMP_42]] - // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<86.707214088598973> - // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_79]], %[[TMP_80]] - // CHECK: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_81]], %[[TMP_42]] - // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<354.93777888781989> - // CHECK: %[[TMP_84:.*]] = mhlo.add %[[TMP_82]], %[[TMP_83]] - // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_84]], %[[TMP_42]] - // CHECK: %[[TMP_86:.*]] = mhlo.constant dense<975.70850174320549> - // CHECK: %[[TMP_87:.*]] = mhlo.add %[[TMP_85]], %[[TMP_86]] - // CHECK: %[[TMP_88:.*]] = mhlo.multiply %[[TMP_87]], %[[TMP_42]] - // CHECK: %[[TMP_89:.*]] = mhlo.constant dense<1823.9091668790973> - // CHECK: %[[TMP_90:.*]] = mhlo.add %[[TMP_88]], %[[TMP_89]] - // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_90]], %[[TMP_42]] - // CHECK: %[[TMP_92:.*]] = mhlo.constant dense<2246.3376081871097> - // CHECK: %[[TMP_93:.*]] = mhlo.add %[[TMP_91]], %[[TMP_92]] - // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_93]], %[[TMP_42]] - // CHECK: %[[TMP_95:.*]] = mhlo.constant dense<1656.6630919416134> - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_94]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.multiply %[[TMP_96]], %[[TMP_42]] - // CHECK: %[[TMP_98:.*]] = mhlo.constant dense<557.53534081772773> - // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_97]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.divide %[[TMP_71]], %[[TMP_99]] - // 
CHECK: %[[TMP_103:.*]] = mhlo.constant dense<0.56418958354775506> - // CHECK: %[[TMP_105:.*]] = mhlo.multiply %[[TMP_103]], %[[TMP_42]] - // CHECK: %[[TMP_106:.*]] = mhlo.constant dense<1.275366707599781> - // CHECK: %[[TMP_107:.*]] = mhlo.add %[[TMP_105]], %[[TMP_106]] - // CHECK: %[[TMP_108:.*]] = mhlo.multiply %[[TMP_107]], %[[TMP_42]] - // CHECK: %[[TMP_109:.*]] = mhlo.constant dense<5.0190504225118051> - // CHECK: %[[TMP_110:.*]] = mhlo.add %[[TMP_108]], %[[TMP_109]] - // CHECK: %[[TMP_111:.*]] = mhlo.multiply %[[TMP_110]], %[[TMP_42]] - // CHECK: %[[TMP_112:.*]] = mhlo.constant dense<6.160210979930536> - // CHECK: %[[TMP_113:.*]] = mhlo.add %[[TMP_111]], %[[TMP_112]] - // CHECK: %[[TMP_114:.*]] = mhlo.multiply %[[TMP_113]], %[[TMP_42]] - // CHECK: %[[TMP_115:.*]] = mhlo.constant dense<7.4097426995044895> - // CHECK: %[[TMP_116:.*]] = mhlo.add %[[TMP_114]], %[[TMP_115]] - // CHECK: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_116]], %[[TMP_42]] - // CHECK: %[[TMP_118:.*]] = mhlo.constant dense<2.9788666537210022> - // CHECK: %[[TMP_119:.*]] = mhlo.add %[[TMP_117]], %[[TMP_118]] - // CHECK: %[[TMP_120:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_119]] - // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_42]] - // CHECK: %[[TMP_126:.*]] = mhlo.constant dense<2.2605286322011726> - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_125]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_127]], %[[TMP_42]] - // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<9.3960352493800147> - // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_128]], %[[TMP_129]] - // CHECK: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_42]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<12.048953980809666> - // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_131]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_42]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<17.081445074756591> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_134]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_136]], %[[TMP_42]] - // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<9.6089680906328585> - // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_137]], %[[TMP_138]] - // CHECK: %[[TMP_140:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_42]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<3.3690764510008151> - // CHECK: %[[TMP_142:.*]] = mhlo.add %[[TMP_140]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.divide %[[TMP_120]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_145:.*]] = mhlo.compare LT, %[[TMP_42]], %[[TMP_144]], NOTYPE - // CHECK: %[[TMP_146:.*]] = mhlo.select %[[TMP_145]], %[[TMP_100]], %[[TMP_143]] - // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<-709.78271289338397> - // CHECK: %[[TMP_148:.*]] = mhlo.compare LT, %[[TMP_40]], %[[TMP_147]], NOTYPE - // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_150:.*]] = mhlo.select %[[TMP_148]], %[[TMP_149]], %[[TMP_146]] - // CHECK: %[[TMP_152:.*]] = mhlo.compare LT, %[[ARG]], %[[TMP_149]], NOTYPE - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_154:.*]] = mhlo.subtract %[[TMP_153]], %[[TMP_150]] - // CHECK: %[[TMP_155:.*]] = mhlo.select %[[TMP_152]], %[[TMP_154]], %[[TMP_150]] - // CHECK: %[[TMP_156:.*]] = mhlo.subtract %[[TMP_38]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.abs %[[ARG]] - // CHECK: %[[TMP_159:.*]] = mhlo.compare LT, %[[TMP_157]], %[[TMP_38]], NOTYPE 
- // CHECK: %[[RESULT:.*]] = mhlo.select %[[TMP_159]], %[[TMP_37]], %[[TMP_156]] - // CHECK: return %[[RESULT]] - %1 = "chlo.erf"(%arg) : (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @erf_f32 -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @erf_f32(%arg : tensor) -> tensor { - // CHECK-DAG: %[[TMP_0:.*]] = mhlo.constant dense<-4.000000e+00> - // CHECK-DAG: %[[TMP_1:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_2:.*]] = mhlo.clamp %[[TMP_0]], %[[ARG]], %[[TMP_1]] - // CHECK: %[[TMP_3:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_2]] - // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<-2.72614237E-10> - // CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_6]], %[[TMP_3]] - // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.77068146E-8> - // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_8]], %[[TMP_9]] - // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_3]] - // CHECK: %[[TMP_12:.*]] = mhlo.constant dense<-2.10102394E-6> - // CHECK: %[[TMP_13:.*]] = mhlo.add %[[TMP_11]], %[[TMP_12]] - // CHECK: %[[TMP_14:.*]] = mhlo.multiply %[[TMP_13]], %[[TMP_3]] - // CHECK: %[[TMP_15:.*]] = mhlo.constant dense<-5.69250624E-5> - // CHECK: %[[TMP_16:.*]] = mhlo.add %[[TMP_14]], %[[TMP_15]] - // CHECK: %[[TMP_17:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_3]] - // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-7.34990637E-4> - // CHECK: %[[TMP_19:.*]] = mhlo.add %[[TMP_17]], %[[TMP_18]] - // CHECK: %[[TMP_20:.*]] = mhlo.multiply %[[TMP_19]], %[[TMP_3]] - // CHECK: %[[TMP_21:.*]] = mhlo.constant dense<-2.954600e-03> - // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_20]], %[[TMP_21]] - // CHECK: %[[TMP_23:.*]] = mhlo.multiply %[[TMP_22]], %[[TMP_3]] - // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<-0.0160960332> - // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_23]], %[[TMP_24]] - // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-1.45660715E-5> - // CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_28]], %[[TMP_3]] - // CHECK: %[[TMP_31:.*]] = mhlo.constant dense<-2.13374049E-4> - // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]] - // CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_3]] - // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.00168282702> - // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]] - // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_3]] - // CHECK: %[[TMP_37:.*]] = mhlo.constant dense<-0.00737332925> - // CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]] - // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_38]], %[[TMP_3]] - // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<-0.0142647391> - // CHECK: %[[TMP_41:.*]] = mhlo.add %[[TMP_39]], %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.multiply %[[TMP_2]], %[[TMP_25]] - // CHECK: %[[TMP_43:.*]] = mhlo.divide %[[TMP_42]], %[[TMP_41]] - // CHECK-DAG: %[[TMP_44:.*]] = mhlo.constant dense<-1.000000e+00> - // CHECK-DAG: %[[TMP_45:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[RESULT:.*]] = mhlo.clamp %[[TMP_44]], %[[TMP_43]], %[[TMP_45]] - // CHECK: return %[[RESULT]] - %1 = "chlo.erf"(%arg) : (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @erf_f16 -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func @erf_f16(%arg : tensor) -> tensor { - // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor - // CHECK: return %[[RESULT]] - %1 = "chlo.erf"(%arg) : (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @erf_bf16 -// CHECK-SAME: %[[ARG:.*]]: tensor -func.func 
@erf_bf16(%arg : tensor) -> tensor { - // CHECK: mhlo.convert %[[ARG]] : (tensor) -> tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %{{.*}} : (tensor) -> tensor - // CHECK: return %[[RESULT]] - %1 = "chlo.erf"(%arg) : (tensor) -> tensor - func.return %1 : tensor -} - - -// CHECK-LABEL: @top_k -// CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) -func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} - // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[ARG]], %[[IOTA]]) ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: return %[[VAL]], %[[IDX]] - %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) - func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> -} - -// ----- - -// CHECK-LABEL: @dyn_top_k -// CHECK-SAME: ([[ARG:%.*]]: tensor -// CHECK-SAME: -> (tensor, tensor) -func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { - // CHECK-NEXT: [[DIM_0_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 0 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_0_I32x1:%.*]] = mhlo.reshape [[DIM_0_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[DIM_1_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_1_I32x1:%.*]] = mhlo.reshape [[DIM_1_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[DIM_2_I32:%.*]] = "mhlo.get_dimension_size"([[ARG]]) {dimension = 2 : i64} : (tensor) -> tensor - // CHECK-NEXT: [[DIM_2_I32x1:%.*]] = mhlo.reshape [[DIM_2_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[IOTA_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[DIM_2_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[K_I32:%.*]] = mhlo.constant dense<2> : tensor - // CHECK-NEXT: [[K_I32x1:%.*]] = mhlo.reshape [[K_I32]] : (tensor) -> tensor<1xi32> - // CHECK-NEXT: [[RESULT_SHAPE:%.*]] = "mhlo.concatenate"([[DIM_0_I32x1]], [[DIM_1_I32x1]], [[K_I32x1]]) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[IOTA:%.*]] = "mhlo.dynamic_iota"([[IOTA_SHAPE]]) {iota_dimension = 2 : i64} : (tensor<3xi32>) -> tensor - // CHECK-NEXT: [[SORT:%.*]]:2 = "mhlo.sort"([[ARG]], [[IOTA]]) ({ - // CHECK-NEXT: ^bb0([[ARG_1:%.*]]: tensor, [[ARG_2:%.*]]: tensor, [[ARG_3:%.*]]: tensor, [[ARG_4:%.*]]: tensor): - // CHECK-NEXT: [[CMP:%.*]] = mhlo.compare GT, [[ARG_1]], [[ARG_2]], NOTYPE : (tensor, tensor) -> tensor - // CHECK-NEXT: mhlo.return [[CMP]] : tensor - // CHECK-NEXT: }) {dimension = 2 : i64, is_stable = true} : (tensor, tensor) -> (tensor, tensor) - // CHECK-NEXT: [[STARTS:%.*]] = mhlo.constant dense<0> : tensor<3xi64> - // CHECK-NEXT: [[LIMITS:%.*]] = mhlo.convert [[RESULT_SHAPE]] : (tensor<3xi32>) -> 
tensor<3xi64> - // CHECK-NEXT: [[STRIDES:%.*]] = mhlo.constant dense<1> : tensor<3xi64> - // CHECK-NEXT: [[VAL:%.*]] = mhlo.real_dynamic_slice [[SORT]]#0, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor - // CHECK-NEXT: [[IDX:%.*]] = mhlo.real_dynamic_slice [[SORT]]#1, [[STARTS]], [[LIMITS]], [[STRIDES]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor - // CHECK-NEXT: return [[VAL]], [[IDX]] : tensor, tensor - %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) - return %values, %indices : tensor, tensor -} diff --git a/tests/Dialect/mhlo/lower-complex.mlir b/tests/Dialect/mhlo/lower-complex.mlir index 8f59051c9..8c2e615ba 100644 --- a/tests/Dialect/mhlo/lower-complex.mlir +++ b/tests/Dialect/mhlo/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s +// RUN: mlir-hlo-opt %s --mhlo-test-lower-complex | FileCheck %s // CHECK-LABEL: @add func.func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {