Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate LLVM at llvm/llvm-project@af20aff35ec3 #2670

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "0876c11ceeb093904decc4d89bef213d483a5656"
LLVM_COMMIT = "af20aff35ec37ead88903bc3e44f6a81c5c9ca4e"

LLVM_SHA256 = "8379577a71645bbba89dea08beba32b3e56b833da7340ba5be7efa3986c8f8ed"
LLVM_SHA256 = "6e31682011d8c483c6a41adf5389eb09ad7db84331ca985d33a5d59efd0388f6"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0876c11ceeb093904decc4d89bef213d483a5656
af20aff35ec37ead88903bc3e44f6a81c5c9ca4e
32 changes: 27 additions & 5 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ class RefinementKey {
// Which correlates to <func, sym_int_values, arg_types>
class RefineShapeState {
public:
RefineShapeState(
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn)
: additionalPatternsFn(additionalPatternsFn) {}

enum class RefinementState {
NOT_ALREADY_REFINED,
ALREADY_REFINED,
Expand Down Expand Up @@ -431,7 +435,14 @@ class RefineShapeState {
});
}

void addAdditionalPatterns(RewritePatternSet& patterns) {
if (additionalPatternsFn.has_value())
additionalPatternsFn.value()(&patterns);
}

private:
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn;

// Maps refined functions to the refinement context: the values of dimension
// arguments and the types of non-global-constant arguments. A function is
// added here when we start refining it.
Expand Down Expand Up @@ -1001,7 +1012,7 @@ struct UpdateRegionTypePattern : public OpRewritePattern<ReturnOp> {
LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
RefineShapeState& state) {
MLIRContext* context = func.getContext();
RewritePatternSet patterns(context);
RewritePatternSet patterns(func->getContext());
GreedyRewriteConfig config;

// The algorithm behind this pass consists of a single traversal of the
Expand All @@ -1019,6 +1030,9 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
populateStablehloRefineShapesPatterns(&patterns, context);
patterns.add<RefineCallOpPattern>(context, state);

// Populate additional patterns for StableHLO extensions.
state.addAdditionalPatterns(patterns);

// The folding patterns implement partial evaluation of shape computations
// which is a critical part of implementing type refinement for ops like
// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
Expand Down Expand Up @@ -1103,15 +1117,23 @@ struct StablehloRefineShapesPass

// Start with empty state, and no dim args / token args.
MLIRContext* context = func.getContext();
RefineShapeState state;
RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
if (failed(refineFunction(*context, state, key)))
return signalPassFailure();
if (failed(refineEntryFunction(*context, func))) return signalPassFailure();
}
};

} // namespace

LogicalResult refineEntryFunction(
MLIRContext& context, func::FuncOp func,
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) {
// Start with empty state, and no dim args / token args.
RefineShapeState state(additionalPatternsFn);
RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
if (failed(refineFunction(context, state, key)))
return func.emitError("Failed to refine entry function");
return success();
}

func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) {
// Only one function per module is supported at the moment to avoid the need
// to think about iterative type inference algorithms.
Expand Down
13 changes: 12 additions & 1 deletion stablehlo/transforms/StablehloRefineShapes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H
#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -101,6 +100,18 @@ LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op,
return refineReturnShape(rewriter, op, shape);
}

// Entrypoint for any pass adding extensibility to the StableHLO shape
// refinement pass. If program is inlined before shape refinement,
// populateShapeRefinementPatterns can be safely used, but if shape refinement
// needs to operate on programs with functions and calls, then
// additionalPatterns will need to be populated and passed in.
using AdditionalShapeRefinementPatternsFn =
std::function<void(RewritePatternSet*)>;
LogicalResult refineEntryFunction(
MLIRContext& context, func::FuncOp func,
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn =
std::nullopt);

// Custom call used to buffer operands for shape refinement
// This is a temporary artifact that is introduced by StablehloRefineArguments
// and is washed away during StablehloRefineShapes.
Expand Down