-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from constantinpape/carving
Carving
- Loading branch information
Showing
9 changed files
with
407 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
#pragma once | ||
|
||
#include <boost/pending/disjoint_sets.hpp> | ||
#include "nifty/xtensor/xtensor.hxx" | ||
#include "nifty/graph/undirected_list_graph.hxx" | ||
|
||
|
||
namespace nifty { | ||
namespace carving { | ||
|
||
|
||
// we provide implementations with kruskal and prim | ||
template<class GRAPH> | ||
class CarvingSegmenter { | ||
public: | ||
|
||
// TODO for now I made sorting edges optional, we could also just always | ||
// do this, but I still want to check how much time this costs | ||
template<class EDGES> | ||
CarvingSegmenter(const GRAPH & graph, | ||
const EDGES & edgeWeights, | ||
const bool sortEdges) : graph_(graph), | ||
nNodes_(graph.numberOfNodes()){ | ||
// check that the number of edges and len of edges agree | ||
NIFTY_CHECK_OP(edgeWeights.size(), ==, graph_.numberOfEdges(), "Number of edges does not agree"); | ||
|
||
// copy the edge weights | ||
edgeWeights_.resize(edgeWeights.size()); | ||
std::copy(edgeWeights.begin(), edgeWeights.end(), edgeWeights_.begin()); | ||
|
||
if(sortEdges) { | ||
sortEdgeIndices(); | ||
} | ||
} | ||
|
||
template<class NODES> | ||
inline void operator()(NODES & seeds, | ||
const double bias, | ||
const double noBiasBelow) const { | ||
// check that the number of nodes agree | ||
NIFTY_CHECK_OP(seeds.size(), ==, nNodes_, "Number of nodes does not agree"); | ||
|
||
// check if we can use kruskal: we don't have a bias and edges were pre-sorted | ||
const bool useKruskal = (bias == 1.) && (edgesSorted_.size() == graph_.numberOfEdges()); | ||
if(useKruskal) { | ||
// std::cout << "run kruskal" << std::endl; | ||
runKruskal(seeds); | ||
} | ||
// otherwise we need to run prim | ||
else { | ||
// std::cout << "run prim" << std::endl; | ||
runPrim(seeds, bias, noBiasBelow); | ||
} | ||
|
||
} | ||
|
||
inline std::size_t nNodes() const { | ||
return graph_.numberOfNodes(); | ||
} | ||
|
||
inline const std::vector<std::size_t> & edgesSorted() const { | ||
return edgesSorted_; | ||
} | ||
|
||
private: | ||
// argsort the edges | ||
inline void sortEdgeIndices() { | ||
// we sort edge indices in ascending order | ||
edgesSorted_.resize(graph_.numberOfEdges()); | ||
std::iota(edgesSorted_.begin(), edgesSorted_.end(), 0); | ||
std::sort(edgesSorted_.begin(), edgesSorted_.end(), [&](const std::size_t a, | ||
const std::size_t b){ | ||
return edgeWeights_[a] < edgeWeights_[b];} | ||
); | ||
} | ||
|
||
template<class NODES> | ||
inline void runPrim(NODES & seeds, | ||
const double bias, | ||
const double noBiasBelow) const { | ||
typedef typename NODES::value_type NodeType; | ||
typedef float WeightType; | ||
const NodeType backgroundSeedLabel = 1; | ||
|
||
// initialize the priority queue | ||
typedef std::pair<std::size_t, WeightType> PQElement; // PQElement contains the edge-id and the weight | ||
auto pqCompare = [](PQElement left, PQElement right) {return left.second < right.second;}; | ||
typedef std::priority_queue<PQElement, std::vector<PQElement>, decltype(pqCompare)> PriorityQueue; | ||
PriorityQueue pq(pqCompare); | ||
|
||
// put edges from seed nodes on the pq | ||
for(std::size_t nodeId = 0; nodeId < nNodes_; ++nodeId) { | ||
const NodeType seedId = seeds[nodeId]; | ||
if(seedId != 0) { | ||
|
||
// check if this is a background seed and we use bias | ||
const bool needBias = seedId == backgroundSeedLabel; | ||
|
||
// iterate over the edges going from this node | ||
// and put them on the pq | ||
for(auto adjIt = graph_.adjacencyBegin(nodeId); adjIt != graph_.adjacencyEnd(nodeId); ++adjIt) { | ||
const std::size_t edgeId = adjIt->edge(); | ||
|
||
// don't put on pq if the connected node has a seed | ||
const std::size_t node = adjIt->node(); | ||
if(seeds[node] != 0) { | ||
continue; | ||
} | ||
|
||
WeightType weight = edgeWeights_[edgeId]; | ||
if(needBias && (weight > noBiasBelow)) { | ||
weight *= bias; | ||
} | ||
|
||
pq.push(std::make_pair(edgeId, weight)); | ||
} | ||
} | ||
} | ||
|
||
// run prim | ||
while(!pq.empty()) { | ||
// extract next element from the queue | ||
const PQElement elem = pq.top(); | ||
pq.pop(); | ||
|
||
const std::size_t edgeId = elem.first; | ||
|
||
const auto u = graph_.u(edgeId); | ||
const auto v = graph_.v(edgeId); | ||
const NodeType lu = seeds[u]; | ||
const NodeType lv = seeds[v]; | ||
|
||
// check for seeds | ||
if(lu == 0 && lv == 0){ | ||
throw std::runtime_error("both have no labels"); | ||
} | ||
else if(lu != 0 && lv != 0){ | ||
continue; | ||
} | ||
|
||
const auto unlabeledNode = lu == 0 ? u : v; | ||
const NodeType seedId = lu == 0 ? lv : lu; | ||
|
||
// assign seedId to unlabeled node | ||
seeds[unlabeledNode] = seedId; | ||
|
||
// check if this is a background seed and thus we use bias | ||
const bool needBias = seedId == backgroundSeedLabel; | ||
|
||
// put outgoing edges on the pq | ||
for(auto adjIt = graph_.adjacencyBegin(unlabeledNode); adjIt != graph_.adjacencyEnd(unlabeledNode); ++adjIt) { | ||
const std::size_t nextEdge = adjIt->edge(); | ||
|
||
// check that this is not the ingoing edge | ||
if(nextEdge == edgeId) { | ||
continue; | ||
} | ||
// check that the node is not labeled | ||
const auto nextNode = adjIt->node(); | ||
if(seeds[nextNode] != 0) { | ||
continue; | ||
} | ||
|
||
WeightType weight = edgeWeights_[nextEdge]; | ||
if(needBias && weight > noBiasBelow) { | ||
weight *= bias; | ||
} | ||
pq.push(std::make_pair(nextEdge, weight)); | ||
} | ||
|
||
} | ||
} | ||
|
||
template<class NODES> | ||
inline void runKruskal(NODES & seeds) const { | ||
typedef typename NODES::value_type NodeType; | ||
// make union find and map seeds to reperesentatives | ||
std::vector<uint64_t> ranks(nNodes_); | ||
std::vector<uint64_t> parents(nNodes_); | ||
boost::disjoint_sets<uint64_t*, uint64_t*> ufd(&ranks[0], &parents[0]); | ||
|
||
for(std::size_t node = 0; node < nNodes_; ++node) { | ||
ufd.make_set(node); | ||
} | ||
|
||
// run kruskal | ||
for(const std::size_t edgeId : edgesSorted_) { | ||
// get the nodes connected by this edge | ||
// and the representatives | ||
const uint64_t u = graph_.u(edgeId); | ||
const uint64_t v = graph_.v(edgeId); | ||
const uint64_t ru = ufd.find_set(u); | ||
const uint64_t rv = ufd.find_set(v); | ||
|
||
// if the representatives are the same, continue | ||
if(ru == rv) { | ||
continue; | ||
} | ||
|
||
// get the seeds for our reperesentatives | ||
const NodeType lu = seeds[ru]; | ||
const NodeType lv = seeds[rv]; | ||
|
||
// if we have two seeded regions (both values different from 0) continue | ||
if(lu !=0 && lv != 0) { | ||
continue; | ||
} | ||
|
||
// otherwise link the two representatives | ||
ufd.link(ru, rv); | ||
|
||
// if we have a seed, propagate it | ||
if(lu != 0) { | ||
seeds[rv] = lu; | ||
} | ||
if(lv != 0) { | ||
seeds[ru] = lv; | ||
} | ||
} | ||
|
||
// write all seeds | ||
for(std::size_t node = 0; node < nNodes_; ++node) { | ||
auto & seed = seeds[node]; | ||
if(seed == 0) { | ||
seed = seeds[ufd.find_set(node)]; | ||
} | ||
} | ||
} | ||
|
||
private: | ||
const GRAPH & graph_; | ||
std::size_t nNodes_; | ||
std::vector<float> edgeWeights_; | ||
std::vector<std::size_t> edgesSorted_; | ||
}; | ||
|
||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
addPythonModule( | ||
NESTED_NAME | ||
nifty/carving | ||
SOURCES | ||
carving.cxx | ||
LIBRRARIES | ||
${Boost_FILESYSTEM_LIBRARY} | ||
${Boost_SYSTEM_LIBRARY} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/numpy.h> | ||
|
||
#define FORCE_IMPORT_ARRAY | ||
#include "xtensor-python/pytensor.hpp" | ||
#include "nifty/carving/carving.hxx" | ||
#include "nifty/graph/rag/grid_rag.hxx" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace nifty{ | ||
namespace carving{ | ||
|
||
|
||
template<class GRAPH> | ||
void exportCarvingT(py::module & module, | ||
const std::string & graphName) { | ||
|
||
typedef xt::pytensor<float, 1> WeightsType; | ||
typedef GRAPH GraphType; | ||
typedef CarvingSegmenter<GraphType> CarvingType; | ||
const auto clsName = std::string("CarvingSegmenter") + graphName; | ||
py::class_<CarvingType>(module, clsName.c_str()) | ||
.def(py::init<const GraphType &, const WeightsType &, bool>(), | ||
py::arg("graph"), | ||
py::arg("edgeWeights"), | ||
py::arg("sortEdges")=true) | ||
|
||
// TODO for some reason pure call by reference does not work | ||
// and we still need to return the seeds to see a change | ||
.def("__call__", [](const CarvingType & self, | ||
xt::pytensor<uint8_t, 1> & seeds, | ||
const double bias, | ||
const double noBiasBelow){ | ||
{ | ||
py::gil_scoped_release allowThreads; | ||
self(seeds, bias, noBiasBelow); | ||
} | ||
return seeds; | ||
}, py::arg("seeds"), | ||
py::arg("bias"), | ||
py::arg("noBiasBelow")) | ||
; | ||
|
||
} | ||
|
||
|
||
void exportCarving(py::module & module) { | ||
|
||
typedef xt::pytensor<uint32_t, 2> ExplicitLabels2D; | ||
typedef graph::GridRag<2, ExplicitLabels2D> Rag2D; | ||
exportCarvingT<Rag2D>(module, "Rag2D"); | ||
|
||
typedef xt::pytensor<uint32_t, 3> ExplicitLabels3D; | ||
typedef graph::GridRag<3, ExplicitLabels3D> Rag3D; | ||
exportCarvingT<Rag3D>(module, "Rag3D"); | ||
} | ||
|
||
} | ||
} | ||
|
||
|
||
PYBIND11_MODULE(_carving, module) { | ||
|
||
xt::import_numpy(); | ||
|
||
py::options options; | ||
options.disable_function_signatures(); | ||
|
||
module.doc() = "carving submodule of nifty"; | ||
|
||
using namespace nifty::carving; | ||
exportCarving(module); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.