Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Updated VHLO test files, there was an upstream change that modified the
in-memory use-list in a fully compatible way, likely caused by something
about the order of transforms applied, or the inclusion of new
unreazlied_casts as intermediate values.
  • Loading branch information
GleasonK authored Jan 13, 2025
1 parent 51028d9 commit b2d36c5
Show file tree
Hide file tree
Showing 31 changed files with 45 additions and 29 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "e86910337f98e57f5b9253f7d80d5b916eb1d97e"
LLVM_COMMIT = "faa3f752896903c2d09d389970d3d0ebf50a1073"

LLVM_SHA256 = "4ca0eff0ca86ed6f2fdb7682354fdf4c85151d90ac9fb6e55a868e4191359e9f"
LLVM_SHA256 = "2c8b76b370dca2a70dac1036244598d357867071217074c5cdf15c43295b0042"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e86910337f98e57f5b9253f7d80d5b916eb1d97e
faa3f752896903c2d09d389970d3d0ebf50a1073
14 changes: 11 additions & 3 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,15 @@ def main(kind="CHLO"):

output_file = os.path.relpath(
os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo",
"transforms", output_filename)),
os.path.join(
os.path.dirname(__file__),
"..",
"..",
"stablehlo",
"transforms",
output_filename,
)
),
os.getcwd(),
)

Expand Down Expand Up @@ -113,7 +120,8 @@ def main(kind="CHLO"):
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
f"{fa.algorithms.__name__} does not define {fname}. Skipping.")
f"{fa.algorithms.__name__} does not define {fname}. Skipping."
)
continue
ctx = fa.Context(paths=[fa.algorithms],
parameters=dict(rewrite_keep_integer_literals=True))
Expand Down
15 changes: 10 additions & 5 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,16 @@ def main():
params = fa.utils.function_validation_parameters(opname, dtype)
max_ulp_difference = op.get(
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))
params.get("max_valid_ulp_count", default_max_ulp_difference),
)

nmp = fa.utils.numpy_with_mpmath(
extra_prec_multiplier=op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier",
default_extra_prec_multiplier)),
params.get(
"extra_prec_multiplier", default_extra_prec_multiplier
),
),
flush_subnormals=flush_subnormals,
)

Expand Down Expand Up @@ -224,8 +227,10 @@ def main():
continue

f = open(fname, "w")
f.write(f"// RUN: stablehlo-opt {passes} %s |"
" stablehlo-translate --interpret\n")
f.write(
f"// RUN: stablehlo-opt {passes} %s |"
" stablehlo-translate --interpret\n"
)
f.write(
"// This file is generated, see build_tools/math/README.md for more"
" information.\n")
Expand Down
2 changes: 2 additions & 0 deletions stablehlo/conversions/linalg/transforms/TypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ RemoveSignTypeConverter::RemoveSignTypeConverter() {

LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() {
addArgumentMaterialization(scalarToTensor);
addSourceMaterialization(scalarToTensor);
addTargetMaterialization(scalarToTensor);
}

} // namespace mlir::stablehlo
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
}

void StablehloLegalizeToTosaPass::runOnOperation() {
(void)applyPatternsAndFoldGreedily(getOperation(), patterns);
(void)applyPatternsGreedily(getOperation(), patterns);
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void StablehloPrepareForTosaPass::runOnOperation() {
// TODO: Enable post upstreaming decision.
// stablehlo::DotGeneralOp::getCanonicalizationPatterns(patterns, ctx);
// stablehlo::populateGeneralDotOpLoweringPatterns(&patterns, ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ struct StablehloQuantLegalizeToTosaRescalePass
}
void runOnOperation() final {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns))) {
if (failed(applyPatternsGreedily(func, patterns))) {
func.emitError(
"Failed to apply StablehloQuantLegalizeToTosaRescale pass ");
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ struct TosaRescaleLegalizeToStablehloPass
}
void runOnOperation() final {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns))) {
if (failed(applyPatternsGreedily(func, patterns))) {
func.emitError("Failed to apply TosaRescaleLegalizeToStablehlo pass ");
signalPassFailure();
}
Expand Down
5 changes: 2 additions & 3 deletions stablehlo/tests/TestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ struct HloTestInferPass : public impl::HloTestInferPassBase<HloTestInferPass> {
}

void runOnOperation() override {
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

Expand All @@ -176,7 +175,7 @@ struct HloTestSpeculatabilityPass
config.maxIterations = 1;
config.useTopDownTraversal = true;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

private:
Expand Down
6 changes: 4 additions & 2 deletions stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ func.func @asinh_f32(%arg : tensor<f32>) -> tensor<f32> {

// -----


// CHECK-LABEL: func.func @asinh_f64(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f64>) -> tensor<f64> {
// CHECK: %[[VAL_1:.*]] = stablehlo.sign %[[VAL_0]] : tensor<f64>
Expand Down Expand Up @@ -2788,7 +2787,6 @@ func.func @polygamma_f16(%lhs : tensor<f16>, %rhs : tensor<f16>) -> tensor<f16>

// -----


// CHECK-LABEL: @sinh_f32
// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
func.func @sinh_f32(%x : tensor<f32>) -> tensor<f32> {
Expand Down Expand Up @@ -3891,6 +3889,8 @@ func.func @erf_inv_wide(%arg0 : tensor<16x16xf64>) {
return
}

// -----

// CHECK-LABEL: @square_complex_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<complex<f32>>) -> tensor<complex<f32>> {
// CHECK: %[[VAL_1:.*]] = stablehlo.real %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
Expand All @@ -3916,6 +3916,8 @@ func.func @square_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32
func.return %result : tensor<complex<f32>>
}

// -----

// CHECK-LABEL: @square_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[VAL_1:.*]] = stablehlo.multiply %[[VAL_0]], %[[VAL_0]] : tensor<f32>
Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_20_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_0_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_1_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_2_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_3_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_4_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_5_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_7_0.mlir.bc
Binary file not shown.
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir.bc
Binary file not shown.
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloAggressiveFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ struct StablehloAggressiveFolderPass
}

void runOnOperation() override {
if (failed(applyPatternsAndFoldGreedily(getOperation(), patterns)))
if (failed(applyPatternsGreedily(getOperation(), patterns)))
signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloAggressiveSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ struct StablehloAggressiveSimplificationPass final
}

void runOnOperation() override {
if (failed(applyPatternsAndFoldGreedily(getOperation(), patterns)))
if (failed(applyPatternsGreedily(getOperation(), patterns)))
signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloCanonicalizeDynamism.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ struct StablehloCanonicalizeDynamismPass

void runOnOperation() override {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError("Failed to converge StablehloCanonicalizeDynamism in ")
<< config.maxIterations << " iterations";
}
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloCompatibilityExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ struct StablehloCompatibilityExpanderPass

void runOnOperation() override {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError(
"Failed to converge StableHLOCompatibilityExpanderPass in ")
<< config.maxIterations << " iterations";
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloComplexMathExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct StablehloComplexMathExpanderPass

void runOnOperation() override {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError("Failed to converge StableHLOComplexMathExpanderPass in ")
<< config.maxIterations << " iterations";
signalPassFailure();
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class StablehloLegalizeQDQToQuantizedOpPass

void runOnOperation() override {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError(
"Failed to converge StablehloLegalizeQDQToQuantizedOpPass in ")
<< config.maxIterations << " iterations";
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class StablehloLegalizeQuantizedOpToQDQPass

void runOnOperation() override {
auto func = getOperation();
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError("Failed to converge StablehloLegalizeQuantizedOpToQDQ in ")
<< config.maxIterations << " iterations";
signalPassFailure();
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
// function. This is sufficient because we only support one function per
// program at the moment.
// TODO(#1048): Find out why .maxIterations = 1 no longer works.
// There have been recent refactors to applyPatternsAndFoldGreedily
// There have been recent refactors to applyPatternsGreedily
// upstream, and that might be the reason.
config.useTopDownTraversal = true;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
Expand All @@ -1039,7 +1039,7 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
// depends on the value of their shape operands.
populateStablehloShapeFolderPatterns(&patterns, context);

if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config)))
if (failed(applyPatternsGreedily(func, std::move(patterns), config)))
func.emitError("Failed to converge StablehloRefineShapes in ")
<< config.maxIterations << " iterations";

Expand Down

0 comments on commit b2d36c5

Please sign in to comment.