Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Update LLVM to llvm/llvm-project@ac8bb735. C++ changes are related to
change in behavior of TypeConverter changed in

iree-org/llvm-project@3cc311a.
It used to generate UnrealizedConversionCastOp, during
applySignatureConversion in GenericOpTypePropagation of
TypePropagationPass.cpp, however now it's not. This causes
unrealized_conversion_cast to be generated later and hence survive the
pass. To repro above behavior, try undo the C++ change in this PR and
then:

```
wget https://gist.githubusercontent.com/raikonenfnu/dfb3b274007df8c4be87daf9ee67a5f4/raw/e48cc07e5fa558cd2c450b0e3ae46568136e1be6/type_propagate_repro.mlir
iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-type-propagation))' propagate_test.mlir -o /dev/null

error: failed to legalize unresolved materialization from ('i8') to ('i1') that remained live after conversion
  ^bb0(%in: i1, %in_0: f32, %in_1: f32, %out: f32):
       ^
propagate_test.mlir:5:8: note: see current operation: %10 = "builtin.unrealized_conversion_cast"(%arg0) : (i8) -> i1
propagate_test.mlir:6:11: note: see existing live user here: %10 = arith.select %9, %in_0, %in_1 : f32
```

Additionally, we made API changes in
6ed8924 from:
1. `applyPatternsAndFoldGreedily` -> `applyPatternsGreedily`
2. `applyOpPatternsAndFold` -> `applyOpPatternsGreedily`
To resolve depracated API error in bazel 

This PR also carries the following reverts:

llvm/llvm-project#119461

The main issue with PR 119461 is it breaks e2e riscv test by making it
get stuck on infinite loop.
```
/path/to/iree-build/tools/iree-compile --output-format=vm-bytecode --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=llvm-cpu --iree-input-type=stablehlo --iree-input-demote-f64-to-f32 --iree-llvmcpu-target-cpu=generic /path/to/iree/tests/e2e/stablehlo_ops/three_fry.mlir -o three_fly_exec_target.mlir --iree-llvmcpu-target-triple=riscv64 --iree-llvmcpu-target-abi=lp64d --iree-llvmcpu-target-cpu-features=+m,+a,+d,+zvl512b,+v --mlir-disable-threading
> infinite loop
```

---------

Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu authored Dec 31, 2024
1 parent a43d893 commit f27feff
Show file tree
Hide file tree
Showing 176 changed files with 318 additions and 405 deletions.
3 changes: 1 addition & 2 deletions compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2182,8 +2182,7 @@ struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> {
mlir::shape::CstrBroadcastableOp::getCanonicalizationPatterns(patterns,
ctx);
mlir::tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ struct LegalizeShapeComputations final

auto func = this->getOperation();
populateLegalizeShapeComputationPatterns(&ctx, &patterns);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
if (failed(applyPatternsGreedily(func, std::move(patterns)))) {
this->signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,7 @@ struct StableHLOCanonicalize final
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateCanonicalizationPatterns(ctx, &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,7 @@ struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePreprocessingDotGeneralToDotPatterns(&getContext(), &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ struct EinsumToDotGeneral final
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePreprocessingEinsumToDotGeneralPatterns(&getContext(), &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ struct FlattenTuplesInCFG final
patterns.insert<DetupleCallOp, DetupleIndirectCallOp, DetupleConditionOp,
DetupleReturnOp, DetupleBranchOp>(ctx);
populateCanonicalizationPatterns(ctx, &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ struct FlattenTuplesInSCF final
patterns
.add<DetupleYieldOp, DetupleConditionOp, DetupleIfOp, DetupleWhileOp>(
ctx);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ struct GatherToTorchIndexSelect final
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populatePreprocessingGatherToTorchIndexSelectPatterns(ctx, &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ struct LowerComplex final : impl::LowerComplexBase<LowerComplex> {
RewritePatternSet patterns(ctx);
populatePreprocessingComplexPatterns(ctx, &patterns);
populateCanonicalizationPatterns(ctx, &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1968,8 +1968,7 @@ struct StableHLOToStableHLOPreprocessing final
patterns.insert<ReorderConvOpKernelDimensions>(context);
patterns.insert<ReorderConvOpOutputDimensions>(context);
}
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,7 @@ struct UnfuseBatchNorm final : impl::UnfuseBatchNormBase<UnfuseBatchNorm> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePreprocessingUnfuseBatchNormPatterns(&getContext(), &patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ struct LegalizeStableHLOCustomCalls final

RewritePatternSet patterns(ctx);
patterns.add<HouseholderReflectorRewriter, ShapeAssertionDrop>(ctx);
if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) {
if (failed(applyPatternsGreedily(f, std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ struct ConvertStableHloToIreeInputDialects final
std::unique_ptr<TypeConverter> typeConverter =
createStableHloToLinalgTypeConverter();
typeConverter->addArgumentMaterialization(scalarToTensor);
typeConverter->addSourceMaterialization(scalarToTensor);
typeConverter->addTargetMaterialization(scalarToTensor);

// Run stablehlo canonicalization patterns with a high benefit to avoid some
// expensive expansions.
Expand Down Expand Up @@ -610,7 +612,7 @@ struct ConvertStableHloToIreeInputDialects final
RewritePatternSet removeUnusedOperandsResultsPatterns(context);
linalg::populateEraseUnusedOperandsAndResultsPatterns(
removeUnusedOperandsResultsPatterns);
if (failed(applyPatternsAndFoldGreedily(
if (failed(applyPatternsGreedily(
getOperation(),
std::move(removeUnusedOperandsResultsPatterns)))) {
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ RemoveSignTypeConverter::RemoveSignTypeConverter() {

LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() {
addArgumentMaterialization(scalarToTensor);
addSourceMaterialization(scalarToTensor);
addTargetMaterialization(scalarToTensor);
}

} // namespace mlir::iree_compiler::stablehlo
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ class BitCastQuantTensorPass final
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<BitCastQuantizedMatmul, BitCastViewDtype>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ class ConvertTMTensorToLinalgExtPass final
patterns.add<ScatterOpConversion>(context);
patterns.add<AttentionOpConversion>(context);

if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ void BlockDynamicDimensionsPass::runOnOperation() {
memref::populateResolveRankedShapedTypeResultDimsPatterns(
bubbleExpandShapePatterns);
populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
operation, std::move(bubbleExpandShapePatterns)))) {
if (failed(applyPatternsGreedily(operation,
std::move(bubbleExpandShapePatterns)))) {
operation->emitOpError(
"failed in application of bubble up expand shape patterns");
return signalPassFailure();
Expand All @@ -380,8 +380,8 @@ void BlockDynamicDimensionsPass::runOnOperation() {
context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(
removeBarrierOpsPatterns);
if (failed(applyPatternsAndFoldGreedily(
operation, std::move(removeBarrierOpsPatterns)))) {
if (failed(applyPatternsGreedily(operation,
std::move(removeBarrierOpsPatterns)))) {
operation->emitOpError("failed in cleanup patterns");
return signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ void BubbleUpOrdinalOpsPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<BubbleUpAcrossCastOp<arith::IndexCastUIOp>>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,7 @@ void CPULowerToUKernelsPass::runOnOperation() {
// These patterns are inherently specific to the VMVX backend.
patterns.insert<LowerToUKernelPattern<IREE::Codegen::QueryTileSizesOp>>(
context, isVMVXBackend);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ void CPUPrepareUkernelsPass::runOnOperation() {
tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
tensor::populateFoldTensorEmptyPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ struct CleanupBufferAllocViewPass final
RewritePatternSet patterns(&getContext());
populateReshapeToInterfaceTensorPatterns(patterns);
populateRemoveDeadMemAllocPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ class ConcretizePadResultShapePass final
{
RewritePatternSet patterns(context);
populateConcretizePadResultShapePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns),
config))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct ConfigTrackingCanonicalizerPass final
{
config.listener = &listener;
LogicalResult didConverge =
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
applyPatternsGreedily(getOperation(), *patterns, config);
config.listener = nullptr;
if (this->testConvergence && failed(didConverge)) {
getOperation()->emitError("Canonicalizer failed to converge");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ struct ConvertBf16ArithToF32Pass final
cleanupPatterns
.insert<PropagateCastF<arith::TruncFOp>, PropagateCastF<arith::ExtFOp>>(
context);
if (applyPatternsAndFoldGreedily(this->getOperation(),
std::move(cleanupPatterns))
if (applyPatternsGreedily(this->getOperation(), std::move(cleanupPatterns))
.failed()) {
return this->signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ void ConvertToDestinationPassingStylePass::runOnOperation() {
{
RewritePatternSet patterns(context);
patterns.insert<RemoveCstOutsDependency>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
Expand All @@ -632,15 +632,15 @@ void ConvertToDestinationPassingStylePass::runOnOperation() {
{
RewritePatternSet patterns(context);
linalg::populateEraseUnusedOperandsAndResultsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

{
RewritePatternSet patterns(context);
patterns.insert<SwitchStoreOfIfResultValue>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
if (configFn.has_value()) {
patterns.add<SetIGEMMConfiguration>(context, configFn.value());
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return failure();
}
}
Expand Down Expand Up @@ -150,8 +150,8 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
tensor::ExpandShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(bubbleCollapseShapePatterns)))) {
if (failed(applyPatternsGreedily(funcOp,
std::move(bubbleCollapseShapePatterns)))) {
return failure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ class DecomposeConvolutionToLowerDimOpsPass final
// 2. Run the patterns. This is the key part of this pass.
RewritePatternSet patterns(context);
linalg::populateDecomposeConvolutionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class DecomposeLinalgGenericPass final
RewritePatternSet patterns(context);
linalg::populateDecomposeLinalgOpsPattern(patterns);
linalg::GenericOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ static LogicalResult commonRunOnOperation(
RewritePatternSet patterns(ctx);
patterns.add<linalg::DecomposeOuterUnitDimsPackOpPattern,
linalg::DecomposeOuterUnitDimsUnPackOpPattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
funcOp.emitError(
"failed to apply generalization patterns on pack/unpack ops for "
"outer unit dims cases");
Expand All @@ -123,7 +123,7 @@ static LogicalResult commonRunOnOperation(
if (!tileOuterToOne) {
RewritePatternSet patterns(ctx);
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
funcOp.emitError(
"failed to apply generalization patterns on pack/unpack ops for "
"general cases.");
Expand Down Expand Up @@ -223,7 +223,7 @@ static LogicalResult commonRunOnOperation(
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
ctx->getOrLoadDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return failure();
}
}
Expand All @@ -242,7 +242,7 @@ static LogicalResult commonRunOnOperation(
patterns.add<linalg::DecomposeOuterUnitDimsPackOpPattern,
linalg::DecomposeOuterUnitDimsUnPackOpPattern>(ctx);
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return failure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void DropVectorUnitDimsPass::runOnOperation() {
vector::populateDropUnitDimWithShapeCastPatterns(patterns);
vector::InsertOp::getCanonicalizationPatterns(patterns, ctx);
vector::ExtractOp::getCanonicalizationPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
} // namespace
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,17 @@ struct EmulateNarrowTypePass final

RewritePatternSet sinkBroadcast(ctx);
vector::populateSinkVectorOpsPatterns(sinkBroadcast);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(sinkBroadcast)))) {
if (failed(
applyPatternsGreedily(getOperation(), std::move(sinkBroadcast)))) {
getOperation()->emitOpError("failed in sinking of broadcasts");
return signalPassFailure();
}

// Also do the `bitcast -> extui/extsi` rewrite.
RewritePatternSet foldExtPatterns(ctx);
vector::populateVectorNarrowTypeRewritePatterns(foldExtPatterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(foldExtPatterns)))) {
if (failed(applyPatternsGreedily(getOperation(),
std::move(foldExtPatterns)))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ struct ExtractAddressComputationPass final
void ExtractAddressComputationPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateExtractAddressComputationPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
Expand Down
Loading

0 comments on commit f27feff

Please sign in to comment.