Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 authored Dec 16, 2024
1 parent 38fe0f4 commit 54dafb1
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 9 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "0876c11ceeb093904decc4d89bef213d483a5656"
LLVM_COMMIT = "af20aff35ec37ead88903bc3e44f6a81c5c9ca4e"

LLVM_SHA256 = "8379577a71645bbba89dea08beba32b3e56b833da7340ba5be7efa3986c8f8ed"
LLVM_SHA256 = "6e31682011d8c483c6a41adf5389eb09ad7db84331ca985d33a5d59efd0388f6"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0876c11ceeb093904decc4d89bef213d483a5656
af20aff35ec37ead88903bc3e44f6a81c5c9ca4e
32 changes: 27 additions & 5 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ class RefinementKey {
// Which correlates to <func, sym_int_values, arg_types>
class RefineShapeState {
public:
RefineShapeState(
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn)
: additionalPatternsFn(additionalPatternsFn) {}

enum class RefinementState {
NOT_ALREADY_REFINED,
ALREADY_REFINED,
Expand Down Expand Up @@ -431,7 +435,14 @@ class RefineShapeState {
});
}

void addAdditionalPatterns(RewritePatternSet& patterns) {
if (additionalPatternsFn.has_value())
additionalPatternsFn.value()(&patterns);
}

private:
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn;

// Maps refined functions to the refinement context: the values of dimension
// arguments and the types of non-global-constant arguments. A function is
// added here when we start refining it.
Expand Down Expand Up @@ -1001,7 +1012,7 @@ struct UpdateRegionTypePattern : public OpRewritePattern<ReturnOp> {
LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
RefineShapeState& state) {
MLIRContext* context = func.getContext();
RewritePatternSet patterns(context);
RewritePatternSet patterns(func->getContext());
GreedyRewriteConfig config;

// The algorithm behind this pass consists of a single traversal of the
Expand All @@ -1019,6 +1030,9 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
populateStablehloRefineShapesPatterns(&patterns, context);
patterns.add<RefineCallOpPattern>(context, state);

// Populate additional patterns for StableHLO extensions.
state.addAdditionalPatterns(patterns);

// The folding patterns implement partial evaluation of shape computations
// which is a critical part of implementing type refinement for ops like
// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
Expand Down Expand Up @@ -1103,15 +1117,23 @@ struct StablehloRefineShapesPass

// Start with empty state, and no dim args / token args.
MLIRContext* context = func.getContext();
RefineShapeState state;
RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
if (failed(refineFunction(*context, state, key)))
return signalPassFailure();
if (failed(refineEntryFunction(*context, func))) return signalPassFailure();
}
};

} // namespace

LogicalResult refineEntryFunction(
MLIRContext& context, func::FuncOp func,
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) {
// Start with empty state, and no dim args / token args.
RefineShapeState state(additionalPatternsFn);
RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
if (failed(refineFunction(context, state, key)))
return func.emitError("Failed to refine entry function");
return success();
}

func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) {
// Only one function per module is supported at the moment to avoid the need
// to think about iterative type inference algorithms.
Expand Down
13 changes: 12 additions & 1 deletion stablehlo/transforms/StablehloRefineShapes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H
#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
Expand Down Expand Up @@ -101,6 +100,18 @@ LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op,
return refineReturnShape(rewriter, op, shape);
}

// Entrypoint for any pass adding extensibility to the StableHLO shape
// refinement pass. If program is inlined before shape refinement,
// populateShapeRefinementPatterns can be safely used, but if shape refinement
// needs to operate on programs with functions and calls, then
// additionalPatterns will need to be populated and passed in.
using AdditionalShapeRefinementPatternsFn =
std::function<void(RewritePatternSet*)>;
LogicalResult refineEntryFunction(
MLIRContext& context, func::FuncOp func,
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn =
std::nullopt);

// Custom call used to buffer operands for shape refinement
// This is a temporary artifact that is introduced by StablehloRefineArguments
// and is washed away during StablehloRefineShapes.
Expand Down

0 comments on commit 54dafb1

Please sign in to comment.