Skip to content

Commit

Permalink
Port refprune to NewPassManager
Browse files Browse the repository at this point in the history
Based on the changes introduced in numba#1042 by @modiking
  • Loading branch information
yashssh committed Jul 18, 2024
1 parent f59140f commit 6b21b2a
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 68 deletions.
222 changes: 157 additions & 65 deletions ffi/custom_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,22 @@
using namespace llvm;

namespace llvm {
void initializeRefNormalizePassPass(PassRegistry &Registry);
void initializeRefPrunePassPass(PassRegistry &Registry);
void initializeRefNormalizeLegacyPassPass(PassRegistry &Registry);
void initializeRefPruneLegacyPassPass(PassRegistry &Registry);
} // namespace llvm

namespace llvm {
struct OpaqueModulePassManager;
typedef OpaqueModulePassManager *LLVMModulePassManagerRef;
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ModulePassManager, LLVMModulePassManagerRef)

struct OpaqueFunctionPassManager;
typedef OpaqueFunctionPassManager *LLVMFunctionPassManagerRef;
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(FunctionPassManager,
LLVMFunctionPassManagerRef)
} // namespace llvm

namespace {
/**
* Checks if a call instruction is an incref
*
Expand Down Expand Up @@ -104,13 +116,9 @@ template <class Tstack> struct raiiStack {
* A FunctionPass to reorder incref/decref instructions such that decrefs occur
* logically after increfs. This is a pre-requisite pass to the pruner passes.
*/
struct RefNormalizePass : public FunctionPass {
static char ID;
RefNormalizePass() : FunctionPass(ID) {
initializeRefNormalizePassPass(*PassRegistry::getPassRegistry());
}
struct RefNormalize {

bool runOnFunction(Function &F) override {
bool runOnFunction(Function &F) {
bool mutated = false;
// For each basic block in F
for (BasicBlock &bb : F) {
Expand Down Expand Up @@ -158,7 +166,16 @@ struct RefNormalizePass : public FunctionPass {
}
};

struct RefPrunePass : public FunctionPass {
typedef enum {
None = 0b0000,
PerBasicBlock = 0b0001,
Diamond = 0b0010,
Fanout = 0b0100,
FanoutRaise = 0b1000,
All = PerBasicBlock | Diamond | Fanout | FanoutRaise
} Subpasses;

struct RefPrune {
static char ID;
static size_t stats_per_bb;
static size_t stats_diamond;
Expand All @@ -175,25 +192,21 @@ struct RefPrunePass : public FunctionPass {
/**
* Enum for setting which subpasses to run, there is no interdependence.
*/
enum Subpasses {
None = 0b0000,
PerBasicBlock = 0b0001,
Diamond = 0b0010,
Fanout = 0b0100,
FanoutRaise = 0b1000,
All = PerBasicBlock | Diamond | Fanout | FanoutRaise
} flags;

RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1)
: FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) {
initializeRefPrunePassPass(*PassRegistry::getPassRegistry());
}
Subpasses flags;

DominatorTree &DT;
PostDominatorTree &PDT;

RefPrune(DominatorTree &DT, PostDominatorTree &PDT,
Subpasses flags = Subpasses::All, size_t subgraph_limit = -1)
: DT(DT), PDT(PDT), flags(flags), subgraph_limit(subgraph_limit) {}

bool isSubpassEnabledFor(Subpasses expected) {
return (flags & expected) == expected;
}

bool runOnFunction(Function &F) override {
bool runOnFunction(Function &F) {
// state for LLVM function pass mutated IR
bool mutated = false;

Expand Down Expand Up @@ -361,11 +374,6 @@ struct RefPrunePass : public FunctionPass {
*/
bool runDiamondPrune(Function &F) {
bool mutated = false;
// gets the dominator tree
auto &domtree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
// gets the post-dominator tree
auto &postdomtree =
getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();

// Find all increfs and decrefs in the Function and store them in
// incref_list and decref_list respectively.
Expand Down Expand Up @@ -394,8 +402,8 @@ struct RefPrunePass : public FunctionPass {
continue;

// incref DOM decref && decref POSTDOM incref
if (domtree.dominates(incref, decref) &&
postdomtree.dominates(decref, incref)) {
if (DT.dominates(incref, decref) &&
PDT.dominates(decref, incref)) {
// check that the decref cannot be executed multiple times
SmallBBSet tail_nodes;
tail_nodes.insert(decref->getParent());
Expand Down Expand Up @@ -1028,14 +1036,6 @@ struct RefPrunePass : public FunctionPass {
return NULL;
}

/**
* getAnalysisUsage() LLVM plumbing for the pass
*/
void getAnalysisUsage(AnalysisUsage &Info) const override {
Info.addRequired<DominatorTreeWrapperPass>();
Info.addRequired<PostDominatorTreeWrapperPass>();
}

/**
* Checks if the first argument to the supplied call_inst is NULL and
* returns true if so, false otherwise.
Expand Down Expand Up @@ -1163,34 +1163,128 @@ struct RefPrunePass : public FunctionPass {
}
}
}
}; // end of struct RefPrunePass
}; // end of struct RefPrune

} // namespace

class RefPrunePass : public PassInfoMixin<RefPrunePass> {

public:
Subpasses flags;
size_t subgraph_limit;
RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1)
: flags(flags), subgraph_limit(subgraph_limit) {}

PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
if (RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F)) {
return PreservedAnalyses::none();
}

return PreservedAnalyses::all();
}
};

char RefNormalizePass::ID = 0;
char RefPrunePass::ID = 0;
class RefNormalizePass : public PassInfoMixin<RefNormalizePass> {

size_t RefPrunePass::stats_per_bb = 0;
size_t RefPrunePass::stats_diamond = 0;
size_t RefPrunePass::stats_fanout = 0;
size_t RefPrunePass::stats_fanout_raise = 0;
public:
RefNormalizePass() = default;

INITIALIZE_PASS(RefNormalizePass, "nrtrefnormalizepass", "Normalize NRT refops",
false, false)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
RefNormalize().runOnFunction(F);

INITIALIZE_PASS_BEGIN(RefPrunePass, "nrtrefprunepass", "Prune NRT refops",
false, false)
return PreservedAnalyses::all();
}
};

class RefNormalizeLegacyPass : public FunctionPass {
public:
static char ID;
RefNormalizeLegacyPass() : FunctionPass(ID) {
initializeRefNormalizeLegacyPassPass(*PassRegistry::getPassRegistry());
}

bool runOnFunction(Function &F) override {
return RefNormalize().runOnFunction(F);
};
};

class RefPruneLegacyPass : public FunctionPass {

public:
static char ID; // Pass identification, replacement for typeid
// The maximum number of nodes that the fanout pruners will look at.
size_t subgraph_limit;
Subpasses flags;
RefPruneLegacyPass(Subpasses flags = Subpasses::All,
size_t subgraph_limit = -1)
: FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) {
initializeRefPruneLegacyPassPass(*PassRegistry::getPassRegistry());
}

bool runOnFunction(Function &F) override {
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();

auto &PDT =
getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();

return RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F);
};

/**
* getAnalysisUsage() LLVM plumbing for the pass
*/
void getAnalysisUsage(AnalysisUsage &Info) const override {
Info.addRequired<DominatorTreeWrapperPass>();
Info.addRequired<PostDominatorTreeWrapperPass>();
}
};

char RefNormalizeLegacyPass::ID = 0;
char RefPruneLegacyPass::ID = 0;

size_t RefPrune::stats_per_bb = 0;
size_t RefPrune::stats_diamond = 0;
size_t RefPrune::stats_fanout = 0;
size_t RefPrune::stats_fanout_raise = 0;

INITIALIZE_PASS(RefNormalizeLegacyPass, "nrtRefNormalize",
"Normalize NRT refops", false, false)

INITIALIZE_PASS_BEGIN(RefPruneLegacyPass, "nrtRefPruneLegacyPass",
"Prune NRT refops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)

INITIALIZE_PASS_END(RefPrunePass, "refprunepass", "Prune NRT refops", false,
false)
INITIALIZE_PASS_END(RefPruneLegacyPass, "RefPruneLegacyPass",
"Prune NRT refops", false, false)

extern "C" {

API_EXPORT(void)
LLVMPY_AddRefPrunePass(LLVMPassManagerRef PM, int subpasses,
size_t subgraph_limit) {
unwrap(PM)->add(new RefNormalizePass());
LLVMPY_AddLegacyRefPrunePass(LLVMPassManagerRef PM, int subpasses,
size_t subgraph_limit) {
unwrap(PM)->add(new RefNormalizeLegacyPass());
unwrap(PM)->add(
new RefPrunePass((RefPrunePass::Subpasses)subpasses, subgraph_limit));
new RefPruneLegacyPass((Subpasses)subpasses, subgraph_limit));
}

API_EXPORT(void)
LLVMPY_AddRefPrunePass_module(LLVMModulePassManagerRef MPM, int subpasses,
size_t subgraph_limit) {
llvm::unwrap(MPM)->addPass(
createModuleToFunctionPassAdaptor(RefNormalizePass()));
llvm::unwrap(MPM)->addPass(createModuleToFunctionPassAdaptor(
RefPrunePass((Subpasses)subpasses, subgraph_limit)));
}

API_EXPORT(void)
LLVMPY_AddRefPrunePass_function(LLVMFunctionPassManagerRef FPM, int subpasses,
size_t subgraph_limit) {
llvm::unwrap(FPM)->addPass(RefNormalizePass());
llvm::unwrap(FPM)->addPass(
RefPrunePass((Subpasses)subpasses, subgraph_limit));
}

/**
Expand All @@ -1207,24 +1301,22 @@ typedef struct PruneStats {
API_EXPORT(void)
LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) {
/* PRUNESTATS is updated with the statistics about what has been pruned from
* the RefPrunePass static state vars. This isn't threadsafe but neither is
* the RefPrune static state vars. This isn't threadsafe but neither is
* the LLVM pass infrastructure so it's all done under a python thread lock.
*
* do_print if set will print the stats to stderr.
*/
if (do_print) {
errs() << "refprune stats "
<< "per-BB " << RefPrunePass::stats_per_bb << " "
<< "diamond " << RefPrunePass::stats_diamond << " "
<< "fanout " << RefPrunePass::stats_fanout << " "
<< "fanout+raise " << RefPrunePass::stats_fanout_raise << " "
<< "\n";
errs() << "refprune stats " << "per-BB " << RefPrune::stats_per_bb
<< " " << "diamond " << RefPrune::stats_diamond << " "
<< "fanout " << RefPrune::stats_fanout << " " << "fanout+raise "
<< RefPrune::stats_fanout_raise << " " << "\n";
};

buf->basicblock = RefPrunePass::stats_per_bb;
buf->diamond = RefPrunePass::stats_diamond;
buf->fanout = RefPrunePass::stats_fanout;
buf->fanout_raise = RefPrunePass::stats_fanout_raise;
buf->basicblock = RefPrune::stats_per_bb;
buf->diamond = RefPrune::stats_diamond;
buf->fanout = RefPrune::stats_fanout;
buf->fanout_raise = RefPrune::stats_fanout_raise;
}

} // extern "C"
45 changes: 45 additions & 0 deletions llvmlite/binding/newpassmanagers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import c_bool, c_int
from enum import IntFlag
from llvmlite.binding import ffi


Expand All @@ -18,6 +19,14 @@ def create_pipeline_tuning_options(speed_level=2, size_level=0):
return PipelineTuningOptions(speed_level, size_level)


class RefPruneSubpasses(IntFlag):
PER_BB = 0b0001 # noqa: E221
DIAMOND = 0b0010 # noqa: E221
FANOUT = 0b0100 # noqa: E221
FANOUT_RAISE = 0b1000
ALL = PER_BB | DIAMOND | FANOUT | FANOUT_RAISE


class ModulePassManager(ffi.ObjectRef):

def __init__(self, ptr=None):
Expand Down Expand Up @@ -52,6 +61,24 @@ def add_jump_threading_pass(self, threshold=-1):
def _dispose(self):
ffi.lib.LLVMPY_DisposeNewModulePassManger(self)

# Non-standard LLVM passes
def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL,
subgraph_limit=1000):
"""Add Numba specific Reference count pruning pass.
Parameters
----------
subpasses_flags : RefPruneSubpasses
A bitmask to control the subpasses to be enabled.
subgraph_limit : int
Limit the fanout pruners to working on a subgraph no bigger than
this number of basic-blocks to avoid spending too much time in very
large graphs. Default is 1000. Subject to change in future
versions.
"""
iflags = RefPruneSubpasses(subpasses_flags)
ffi.lib.LLVMPY_AddRefPrunePass_module(self, iflags, subgraph_limit)


class FunctionPassManager(ffi.ObjectRef):

Expand Down Expand Up @@ -84,6 +111,24 @@ def add_jump_threading_pass(self, threshold=-1):
def _dispose(self):
ffi.lib.LLVMPY_DisposeNewFunctionPassManger(self)

# Non-standard LLVM passes
def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL,
subgraph_limit=1000):
"""Add Numba specific Reference count pruning pass.
Parameters
----------
subpasses_flags : RefPruneSubpasses
A bitmask to control the subpasses to be enabled.
subgraph_limit : int
Limit the fanout pruners to working on a subgraph no bigger than
this number of basic-blocks to avoid spending too much time in very
large graphs. Default is 1000. Subject to change in future
versions.
"""
iflags = RefPruneSubpasses(subpasses_flags)
ffi.lib.LLVMPY_AddRefPrunePass_function(self, iflags, subgraph_limit)


class PipelineTuningOptions(ffi.ObjectRef):

Expand Down
Loading

0 comments on commit 6b21b2a

Please sign in to comment.