Skip to content

Commit

Permalink
Include all materialization functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vimarsh6739 committed Dec 11, 2024
1 parent 0808a99 commit c9996d6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
21 changes: 14 additions & 7 deletions stablehlo/transforms/ChloDecompositionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,27 @@ namespace stablehlo {

// Utility functions used in the Chlo to stablehlo legalization.

Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, Value val);
Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args);

Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
bool negative);
Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands);

Value getConstantLikeSmallestNormalizedValue(OpBuilder &b, Location loc,
Value val);
Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands);

Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args);
Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands);

Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args);

Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args);

Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args);

} // namespace stablehlo
} // namespace mlir

Expand Down
43 changes: 24 additions & 19 deletions stablehlo/transforms/ChloLegalizeToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,30 +170,29 @@ static void populateForBroadcastingBinaryOp(MLIRContext *context,
mlir::stablehlo::CompareOp, HloCompareAdaptor>>(
context, args...);
}
} // namespace

Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, Value val) {
static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc,
Value val) {
auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
return getConstantLike(
b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
}

Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
bool negative) {
static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
bool negative) {
auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
return getConstantLike(
b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
}

Value getConstantLikeSmallestNormalizedValue(OpBuilder &b, Location loc,
Value val) {
static Value getConstantLikeSmallestNormalizedValue(OpBuilder &b, Location loc,
Value val) {
auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
return getConstantLike(
b, loc, llvm::APFloat::getSmallestNormalized(ty.getFloatSemantics()),
val);
}

namespace {
//===----------------------------------------------------------------------===//
// Broadcasting Patterns.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1397,8 +1396,6 @@ Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma);
}

namespace {

// Express `cosh` as
// cosh(x) = (e^x + e^-x) / 2
// = e^(x + log(1/2)) + e^(-x + log(1/2))
Expand All @@ -1409,8 +1406,8 @@ namespace {
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
mlir::chlo::CoshOp::Adaptor transformed(operands);
Value x = transformed.getOperand();

Expand All @@ -1423,6 +1420,8 @@ static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
return rewriter.create<mlir::stablehlo::AddOp>(loc, expAdd, expSub);
}

namespace {

struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -1562,8 +1561,9 @@ static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc,
b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
}

static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
} // namespace
Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
// Implementation ported from:
// https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917
// Reference: Johansson, Fredrik.
Expand Down Expand Up @@ -1713,8 +1713,8 @@ static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
return output;
}

static Value materializePolygamma(ConversionPatternRewriter &rewriter,
Location loc, ValueRange args) {
Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
mlir::chlo::PolygammaOp::Adaptor transformed(args);
Value n = transformed.getN();
Value x = transformed.getX();
Expand Down Expand Up @@ -1757,6 +1757,8 @@ static Value materializePolygamma(ConversionPatternRewriter &rewriter,
result);
}

namespace {

struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -1900,6 +1902,7 @@ struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
return success();
}
};
} // namespace

// Sinh(x) = (e^x - e^-x) / 2
// = e^(x + log(1/2)) - e^(-x + log(1/2)).
Expand All @@ -1911,8 +1914,8 @@ struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
static Value materializeSinhApproximationForLargeX(
ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) {
Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
mlir::chlo::SinhOp::Adaptor transformed(operands);
Value x = transformed.getOperand();

Expand All @@ -1928,8 +1931,8 @@ static Value materializeSinhApproximationForLargeX(
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
Value largeSinhResult =
materializeSinhApproximationForLargeX(rewriter, loc, operands);

Expand Down Expand Up @@ -1961,6 +1964,8 @@ static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
loc, absXLtOne, smallSinhResult, largeSinhResult);
}

namespace {

struct ConvertSinhOp final : OpConversionPattern<mlir::chlo::SinhOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down

0 comments on commit c9996d6

Please sign in to comment.