diff --git a/BUILD.bazel b/BUILD.bazel index 117a706646..b755374dc7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1133,6 +1133,7 @@ cc_library( "stablehlo/transforms/VhloToVersion.cpp", ], hdrs = [ + "stablehlo/transforms/ChloDecompositionUtils.h", "stablehlo/transforms/MapStablehloToVhlo.h", "stablehlo/transforms/PassUtils.h", "stablehlo/transforms/Passes.h", diff --git a/stablehlo/transforms/ChloDecompositionUtils.h b/stablehlo/transforms/ChloDecompositionUtils.h new file mode 100644 index 0000000000..6bf3a4ffe6 --- /dev/null +++ b/stablehlo/transforms/ChloDecompositionUtils.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_ +#define STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_ + +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace stablehlo { + +// Utility functions used in the Chlo to stablehlo legalization. 
+ +Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args); + +Value materializeDigamma(OpBuilder &rewriter, Location loc, ValueRange args); + +Value materializeZeta(OpBuilder &rewriter, Location loc, ValueRange args); + +Value materializePolygamma(OpBuilder &rewriter, Location loc, ValueRange args); + +} // namespace stablehlo +} // namespace mlir + +#endif // STABLEHLO_TRANSFORMS_CHLO_DECOMP_UTILS_H_ diff --git a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index ef3a60756a..5c0cb155ea 100644 --- a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -46,6 +46,7 @@ #include "stablehlo/dialect/BroadcastUtils.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/ChloDecompositionUtils.h" #include "stablehlo/transforms/PassUtils.h" #include "stablehlo/transforms/Passes.h" @@ -462,8 +463,7 @@ struct ConvertConstantOp final : OpConversionPattern { template static Value materializeChebyshevPolynomialApproximation( - ConversionPatternRewriter &rewriter, Location loc, Value x, - ArrayRef coefficients) { + OpBuilder &rewriter, Location loc, Value x, ArrayRef coefficients) { Value b0 = getConstantLike(rewriter, loc, 0.0, x); Value b1 = getConstantLike(rewriter, loc, 0.0, x); Value b2 = getConstantLike(rewriter, loc, 0.0, x); @@ -483,9 +483,10 @@ static Value materializeChebyshevPolynomialApproximation( } template -static Value materializeBesselI1eApproximation( - ConversionPatternRewriter &rewriter, Location loc, Value x, - ArrayRef kI1eCoeffsA, ArrayRef kI1eCoeffsB) { +static Value materializeBesselI1eApproximation(OpBuilder &rewriter, + Location loc, Value x, + ArrayRef kI1eCoeffsA, + ArrayRef kI1eCoeffsB) { Value z = rewriter.create(loc, x); Value half = getConstantLike(rewriter, loc, 0.5, x); Value two = getConstantLike(rewriter, loc, 2.0, x); @@ -515,8 +516,8 @@ static Value 
materializeBesselI1eApproximation( loc, rewriter.create(loc, x), select); } -Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { +Value materializeBesselI1eApproximationF32(OpBuilder &rewriter, Location loc, + ValueRange args) { Value x = args.front(); assert(cast(x.getType()).getElementType().isF32() && "expect f32 element type"); @@ -541,8 +542,9 @@ Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter, kI1eCoeffsB); } -static Value materializeBesselI1eApproximationF64( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { +static Value materializeBesselI1eApproximationF64(OpBuilder &rewriter, + Location loc, + ValueRange args) { Value x = args.front(); assert(cast(x.getType()).getElementType().isF64() && "expect f64 element type"); @@ -586,8 +588,8 @@ static Value materializeBesselI1eApproximationF64( static Value materializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, ValueRange args, FloatType minPrecisionTy, - Value callback(ConversionPatternRewriter &, - Location, ValueRange)) { + Value callback(OpBuilder &, Location, + ValueRange)) { Type originalTy = getElementTypeOrSelf(args.front().getType()); auto floatOriginalTy = dyn_cast(originalTy); bool needsUpcast = @@ -645,9 +647,9 @@ struct ConvertBesselI1eOp final : OpConversionPattern { }; template -static Value materializePolynomialApproximation( - ConversionPatternRewriter &rewriter, Location loc, Value x, - ArrayRef coefficients) { +static Value materializePolynomialApproximation(OpBuilder &rewriter, + Location loc, Value x, + ArrayRef coefficients) { if (coefficients.empty()) return getConstantLike(rewriter, loc, 0.0, x); Value poly = getConstantLike(rewriter, loc, coefficients[0], x); @@ -836,7 +838,7 @@ static Value materializeErfcApproximationF64( // argument and derive the final approximation for all |x| >= 1. // This implementation is based on Cephes. 
static Value materializeErfcApproximationF32ForMagnitudeGeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + OpBuilder &rewriter, Location loc, ValueRange args) { Value x = args.front(); assert(cast(x.getType()).getElementType().isF32() && "expect f32 element type"); @@ -902,7 +904,7 @@ static Value materializeErfcApproximationF32ForMagnitudeGeOne( // Precondition is |x| <= 1. Use erfc approximation, otherwise. // This implementation is based on Cephes. static Value materializeErfApproximationF32ForMagnitudeLeOne( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { + OpBuilder &rewriter, Location loc, ValueRange args) { Value x = args.front(); assert(cast(x.getType()).getElementType().isF32() && "expect f32 element type"); @@ -921,8 +923,8 @@ static Value materializeErfApproximationF32ForMagnitudeLeOne( } // This is the same approximation as used in Eigen. -static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { +static Value materializeErfApproximationF32(OpBuilder &rewriter, Location loc, + ValueRange args) { Value x = args.front(); assert(cast(x.getType()).getElementType().isF32() && "expect f32 element type"); @@ -958,8 +960,8 @@ static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter, erf, ubErf); } -static Value materializeErfcApproximationF32( - ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { +static Value materializeErfcApproximationF32(OpBuilder &rewriter, Location loc, + ValueRange args) { Value x = args.front(); assert(cast(x.getType()).getElementType().isF32() && "expect f32 element type"); @@ -1041,8 +1043,7 @@ struct ConvertErfcOp final : OpConversionPattern { } }; -static Value erfInv32(ConversionPatternRewriter &b, Location loc, - ValueRange args) { +static Value erfInv32(OpBuilder &b, Location loc, ValueRange args) { constexpr int kDegree = 9; constexpr std::array wLessThan5Constants = { 
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, @@ -1248,6 +1249,8 @@ constexpr std::array kLanczosCoefficients = { 12.507343278686904814458936853, -0.13857109526572011689554707, 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +} // namespace + // Compute the Lgamma function using Lanczos' approximation from "A Precision // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis // series B. Vol. 1: @@ -1257,8 +1260,7 @@ constexpr std::array kLanczosCoefficients = { // with t(z) = z + kLanczosGamma + 1/2 // a(z) = kBaseLanczosCoeff // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) -static Value materializeLgamma(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { +Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args) { // If the input is less than 0.5 use Euler's reflection formula. // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) // Let z be @@ -1393,6 +1395,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter, getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma); } +namespace { + // Express `cosh` as // cosh(x) = (e^x + e^-x) / 2 // = e^(x + log(1/2)) + e^(-x + log(1/2)) @@ -1403,8 +1407,8 @@ static Value materializeLgamma(ConversionPatternRewriter &rewriter, // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so // we deem this acceptable. -static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { +static Value materializeCoshApproximation(OpBuilder &rewriter, Location loc, + ValueRange operands) { mlir::chlo::CoshOp::Adaptor transformed(operands); Value x = transformed.getOperand(); @@ -1431,6 +1435,8 @@ struct ConvertCoshOp final : OpConversionPattern { } }; +} // namespace + // Compute the Digamma function using Lanczos' approximation from "A Precision // Approximation of the Gamma Function". 
SIAM Journal on Numerical Analysis // series B. Vol. 1: @@ -1439,8 +1445,7 @@ struct ConvertCoshOp final : OpConversionPattern { // a(z) = kBaseLanczosCoeff // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) -static Value materializeDigamma(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { +Value materializeDigamma(OpBuilder &rewriter, Location loc, ValueRange args) { // If the input is less than 0.5 use Euler's reflection formula. // digamma(x) = digamma(1 - x) - pi * cot(pi * x) // Let z be @@ -1545,6 +1550,8 @@ static Value materializeDigamma(ConversionPatternRewriter &rewriter, digamma); } +namespace { + static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc, Value val) { auto ty = cast(getElementTypeOrSelf(val.getType())); @@ -1552,7 +1559,9 @@ static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc, b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val); } +} // namespace + -static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, +Value materializeZeta(OpBuilder &rewriter, Location loc, ValueRange args) { // Implementation ported from: // https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917 @@ -1703,8 +1712,7 @@ static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, return output; } -static Value materializePolygamma(ConversionPatternRewriter &rewriter, - Location loc, ValueRange args) { +Value materializePolygamma(OpBuilder &rewriter, Location loc, ValueRange args) { mlir::chlo::PolygammaOp::Adaptor transformed(args); Value n = transformed.getN(); Value x = transformed.getX(); @@ -1747,6 +1755,8 @@ static Value materializePolygamma(ConversionPatternRewriter &rewriter, result); } +namespace { + struct ConvertLgammaOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ 
-1901,8 +1911,9 @@ struct ConvertPolygammaOp final : OpConversionPattern { // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so // we deem this acceptable. -static Value materializeSinhApproximationForLargeX( - ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) { +static Value materializeSinhApproximationForLargeX(OpBuilder &rewriter, + Location loc, + ValueRange operands) { mlir::chlo::SinhOp::Adaptor transformed(operands); Value x = transformed.getOperand(); @@ -1918,8 +1929,8 @@ static Value materializeSinhApproximationForLargeX( // Express `sinh` as // sinh(x) = (e^x - e^-x) / 2 if |x| < 1 // = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. -static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter, - Location loc, ValueRange operands) { +static Value materializeSinhApproximation(OpBuilder &rewriter, Location loc, + ValueRange operands) { Value largeSinhResult = materializeSinhApproximationForLargeX(rewriter, loc, operands);