Skip to content

Commit

Permalink
Merge pull request #13 from constantinpape/carving
Browse files Browse the repository at this point in the history
Carving
  • Loading branch information
constantinpape authored Feb 11, 2019
2 parents 0a19bc3 + 10a1a4c commit d3397c7
Show file tree
Hide file tree
Showing 9 changed files with 407 additions and 34 deletions.
239 changes: 239 additions & 0 deletions include/nifty/carving/carving.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
#pragma once

#include <boost/pending/disjoint_sets.hpp>
#include "nifty/xtensor/xtensor.hxx"
#include "nifty/graph/undirected_list_graph.hxx"


namespace nifty {
namespace carving {


// we provide implementations with kruskal and prim
template<class GRAPH>
class CarvingSegmenter {
public:

// TODO for now I made sorting edges optional, we could also just always
// do this, but I still want to check how much time this costs
template<class EDGES>
CarvingSegmenter(const GRAPH & graph,
const EDGES & edgeWeights,
const bool sortEdges) : graph_(graph),
nNodes_(graph.numberOfNodes()){
// check that the number of edges and len of edges agree
NIFTY_CHECK_OP(edgeWeights.size(), ==, graph_.numberOfEdges(), "Number of edges does not agree");

// copy the edge weights
edgeWeights_.resize(edgeWeights.size());
std::copy(edgeWeights.begin(), edgeWeights.end(), edgeWeights_.begin());

if(sortEdges) {
sortEdgeIndices();
}
}

template<class NODES>
inline void operator()(NODES & seeds,
const double bias,
const double noBiasBelow) const {
// check that the number of nodes agree
NIFTY_CHECK_OP(seeds.size(), ==, nNodes_, "Number of nodes does not agree");

// check if we can use kruskal: we don't have a bias and edges were pre-sorted
const bool useKruskal = (bias == 1.) && (edgesSorted_.size() == graph_.numberOfEdges());
if(useKruskal) {
// std::cout << "run kruskal" << std::endl;
runKruskal(seeds);
}
// otherwise we need to run prim
else {
// std::cout << "run prim" << std::endl;
runPrim(seeds, bias, noBiasBelow);
}

}

inline std::size_t nNodes() const {
return graph_.numberOfNodes();
}

inline const std::vector<std::size_t> & edgesSorted() const {
return edgesSorted_;
}

private:
// argsort the edges
inline void sortEdgeIndices() {
// we sort edge indices in ascending order
edgesSorted_.resize(graph_.numberOfEdges());
std::iota(edgesSorted_.begin(), edgesSorted_.end(), 0);
std::sort(edgesSorted_.begin(), edgesSorted_.end(), [&](const std::size_t a,
const std::size_t b){
return edgeWeights_[a] < edgeWeights_[b];}
);
}

template<class NODES>
inline void runPrim(NODES & seeds,
const double bias,
const double noBiasBelow) const {
typedef typename NODES::value_type NodeType;
typedef float WeightType;
const NodeType backgroundSeedLabel = 1;

// initialize the priority queue
typedef std::pair<std::size_t, WeightType> PQElement; // PQElement contains the edge-id and the weight
auto pqCompare = [](PQElement left, PQElement right) {return left.second < right.second;};
typedef std::priority_queue<PQElement, std::vector<PQElement>, decltype(pqCompare)> PriorityQueue;
PriorityQueue pq(pqCompare);

// put edges from seed nodes on the pq
for(std::size_t nodeId = 0; nodeId < nNodes_; ++nodeId) {
const NodeType seedId = seeds[nodeId];
if(seedId != 0) {

// check if this is a background seed and we use bias
const bool needBias = seedId == backgroundSeedLabel;

// iterate over the edges going from this node
// and put them on the pq
for(auto adjIt = graph_.adjacencyBegin(nodeId); adjIt != graph_.adjacencyEnd(nodeId); ++adjIt) {
const std::size_t edgeId = adjIt->edge();

// don't put on pq if the connected node has a seed
const std::size_t node = adjIt->node();
if(seeds[node] != 0) {
continue;
}

WeightType weight = edgeWeights_[edgeId];
if(needBias && (weight > noBiasBelow)) {
weight *= bias;
}

pq.push(std::make_pair(edgeId, weight));
}
}
}

// run prim
while(!pq.empty()) {
// extract next element from the queue
const PQElement elem = pq.top();
pq.pop();

const std::size_t edgeId = elem.first;

const auto u = graph_.u(edgeId);
const auto v = graph_.v(edgeId);
const NodeType lu = seeds[u];
const NodeType lv = seeds[v];

// check for seeds
if(lu == 0 && lv == 0){
throw std::runtime_error("both have no labels");
}
else if(lu != 0 && lv != 0){
continue;
}

const auto unlabeledNode = lu == 0 ? u : v;
const NodeType seedId = lu == 0 ? lv : lu;

// assign seedId to unlabeled node
seeds[unlabeledNode] = seedId;

// check if this is a background seed and thus we use bias
const bool needBias = seedId == backgroundSeedLabel;

// put outgoing edges on the pq
for(auto adjIt = graph_.adjacencyBegin(unlabeledNode); adjIt != graph_.adjacencyEnd(unlabeledNode); ++adjIt) {
const std::size_t nextEdge = adjIt->edge();

// check that this is not the ingoing edge
if(nextEdge == edgeId) {
continue;
}
// check that the node is not labeled
const auto nextNode = adjIt->node();
if(seeds[nextNode] != 0) {
continue;
}

WeightType weight = edgeWeights_[nextEdge];
if(needBias && weight > noBiasBelow) {
weight *= bias;
}
pq.push(std::make_pair(nextEdge, weight));
}

}
}

template<class NODES>
inline void runKruskal(NODES & seeds) const {
typedef typename NODES::value_type NodeType;
// make union find and map seeds to reperesentatives
std::vector<uint64_t> ranks(nNodes_);
std::vector<uint64_t> parents(nNodes_);
boost::disjoint_sets<uint64_t*, uint64_t*> ufd(&ranks[0], &parents[0]);

for(std::size_t node = 0; node < nNodes_; ++node) {
ufd.make_set(node);
}

// run kruskal
for(const std::size_t edgeId : edgesSorted_) {
// get the nodes connected by this edge
// and the representatives
const uint64_t u = graph_.u(edgeId);
const uint64_t v = graph_.v(edgeId);
const uint64_t ru = ufd.find_set(u);
const uint64_t rv = ufd.find_set(v);

// if the representatives are the same, continue
if(ru == rv) {
continue;
}

// get the seeds for our reperesentatives
const NodeType lu = seeds[ru];
const NodeType lv = seeds[rv];

// if we have two seeded regions (both values different from 0) continue
if(lu !=0 && lv != 0) {
continue;
}

// otherwise link the two representatives
ufd.link(ru, rv);

// if we have a seed, propagate it
if(lu != 0) {
seeds[rv] = lu;
}
if(lv != 0) {
seeds[ru] = lv;
}
}

// write all seeds
for(std::size_t node = 0; node < nNodes_; ++node) {
auto & seed = seeds[node];
if(seed == 0) {
seed = seeds[ufd.find_set(node)];
}
}
}

private:
const GRAPH & graph_;
std::size_t nNodes_;
std::vector<float> edgeWeights_;
std::vector<std::size_t> edgesSorted_;
};


}
}
1 change: 1 addition & 0 deletions include/nifty/graph/rag/grid_rag_features.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
namespace nifty{
namespace graph{

// TODO parallelize
template<size_t DIM, class GRAPH_LABELS, class LABELS, class NODE_MAP>
void gridRagAccumulateLabels(const GridRag<DIM, GRAPH_LABELS> & graph,
const xt::xexpression<LABELS> & dataExp,
Expand Down
1 change: 1 addition & 0 deletions src/python/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,4 @@ add_subdirectory(filters)
add_subdirectory(ground_truth)
add_subdirectory(distributed)
add_subdirectory(skeletons)
add_subdirectory(carving)
9 changes: 9 additions & 0 deletions src/python/lib/carving/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
addPythonModule(
NESTED_NAME
nifty/carving
SOURCES
carving.cxx
LIBRRARIES
${Boost_FILESYSTEM_LIBRARY}
${Boost_SYSTEM_LIBRARY}
)
74 changes: 74 additions & 0 deletions src/python/lib/carving/carving.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#define FORCE_IMPORT_ARRAY
#include "xtensor-python/pytensor.hpp"
#include "nifty/carving/carving.hxx"
#include "nifty/graph/rag/grid_rag.hxx"

namespace py = pybind11;

namespace nifty{
namespace carving{


template<class GRAPH>
void exportCarvingT(py::module & module,
const std::string & graphName) {

typedef xt::pytensor<float, 1> WeightsType;
typedef GRAPH GraphType;
typedef CarvingSegmenter<GraphType> CarvingType;
const auto clsName = std::string("CarvingSegmenter") + graphName;
py::class_<CarvingType>(module, clsName.c_str())
.def(py::init<const GraphType &, const WeightsType &, bool>(),
py::arg("graph"),
py::arg("edgeWeights"),
py::arg("sortEdges")=true)

// TODO for some reason pure call by reference does not work
// and we still need to return the seeds to see a change
.def("__call__", [](const CarvingType & self,
xt::pytensor<uint8_t, 1> & seeds,
const double bias,
const double noBiasBelow){
{
py::gil_scoped_release allowThreads;
self(seeds, bias, noBiasBelow);
}
return seeds;
}, py::arg("seeds"),
py::arg("bias"),
py::arg("noBiasBelow"))
;

}


void exportCarving(py::module & module) {

typedef xt::pytensor<uint32_t, 2> ExplicitLabels2D;
typedef graph::GridRag<2, ExplicitLabels2D> Rag2D;
exportCarvingT<Rag2D>(module, "Rag2D");

typedef xt::pytensor<uint32_t, 3> ExplicitLabels3D;
typedef graph::GridRag<3, ExplicitLabels3D> Rag3D;
exportCarvingT<Rag3D>(module, "Rag3D");
}

}
}


PYBIND11_MODULE(_carving, module) {

xt::import_numpy();

py::options options;
options.disable_function_signatures();

module.doc() = "carving submodule of nifty";

using namespace nifty::carving;
exportCarving(module);
}
3 changes: 0 additions & 3 deletions src/python/lib/graph/connected_components.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ namespace graph{
" numpy.ndarray : connected components labels"
);





typedef GRAPH GraphType;
typedef ComponentsUfd<GraphType> ComponentsType;
Expand Down
Loading

0 comments on commit d3397c7

Please sign in to comment.