Ponca  93eea5457c48839cb5d16642765afa89fc7cfe66
Point Cloud Analysis library
Loading...
Searching...
No Matches
kdTree.hpp
/*
 This Source Code Form is subject to the terms of the Mozilla Public
 License, v. 2.0. If a copy of the MPL was not distributed with this
 file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/

// KdTree ----------------------------------------------------------------------
8
9template<typename Traits>
10template<typename PointUserContainer, typename PointConverter>
11PONCA_MULTIARCH_HOST inline void KdTreeBase<Traits>::build(PointUserContainer&& points, PointConverter c)
12{
13 IndexContainer ids (points.size());
14 std::iota(ids.begin(), ids.end(), IndexType(0));
15 this->buildWithSampling(std::forward<PointUserContainer>(points), std::move(ids), std::move(c));
16}
17
18template<typename Traits>
19PONCA_MULTIARCH_HOST [[nodiscard]] inline bool StaticKdTreeBase<Traits>::valid() const
20{
21 if (m_bufs.points_size == 0)
22 return m_bufs.nodes_size == 0 && m_bufs.indices_size == 0;
23
24 if(m_bufs.nodes_size == 0 || m_bufs.indices_size == 0)
25 {
26 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
27 return false;
28 }
29
30 std::vector<bool> b(pointCount(), false);
31 for(unsigned int i = 0; i < sampleCount(); ++i)
32 {
33 const int idx = m_bufs.indices[i];
34 if(idx < 0 || pointCount() <= idx || b[idx])
35 {
36 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
37 return false;
38 }
39 b[idx] = true;
40 }
41
42 for(NodeIndexType n=0;n<nodeCount();++n)
43 {
44 const NodeType& node = m_bufs.nodes[n];
45 if(node.is_leaf())
46 {
47 if(sampleCount() <= node.leaf_start() || node.leaf_start()+node.leaf_size() > sampleCount())
48 {
49 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
50 return false;
51 }
52 }
53 else
54 {
55 if(node.inner_split_dim() < 0 || DataPoint::Dim-1 < node.inner_split_dim())
56 {
57 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
58 return false;
59 }
60 if(nodeCount() <= node.inner_first_child_id() || nodeCount() <= node.inner_first_child_id()+1)
61 {
62 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
63 return false;
64 }
65 }
66 }
67
68 return true;
69}
70
71template<typename Traits>
72PONCA_MULTIARCH_HOST inline void StaticKdTreeBase<Traits>::print(std::ostream& os, bool verbose) const
73{
74 os << "KdTree:";
75 os << "\n MaxNodes: " << MAX_NODE_COUNT;
76 os << "\n MaxPoints: " << MAX_POINT_COUNT;
77 os << "\n MaxDepth: " << MAX_DEPTH;
78 os << "\n PointCount: " << pointCount();
79 os << "\n SampleCount: " << sampleCount();
80 os << "\n NodeCount: " << nodeCount();
81
82 if (!verbose)
83 {
84 return;
85 }
86
87 os << "\n Samples: [";
88 static constexpr IndexType SAMPLES_PER_LINE = 10;
89 for (IndexType i = 0; i < sampleCount(); ++i)
90 {
91 os << (i == 0 ? "" : ",");
92 os << (i % SAMPLES_PER_LINE == 0 ? "\n " : " ");
93 os << m_bufs.indices[i];
94 }
95
96 os << "]\n Nodes:";
97 for (NodeIndexType n = 0; n < nodeCount(); ++n)
98 {
99 const NodeType& node = m_bufs.nodes[n];
100 if (node.is_leaf())
101 {
102 os << "\n - Type: Leaf";
103 os << "\n Start: " << node.leaf_start();
104 os << "\n Size: " << node.leaf_size();
105 }
106 else
107 {
108 os << "\n - Type: Inner";
109 os << "\n SplitDim: " << node.inner_split_dim();
110 os << "\n SplitValue: " << node.inner_split_value();
111 os << "\n FirstChild: " << node.inner_first_child_id();
112 }
113 }
114}
115
116template<typename Traits>
117template<typename PointUserContainer, typename IndexUserContainer, typename PointConverter>
118PONCA_MULTIARCH_HOST inline void KdTreeBase<Traits>::buildWithSampling(
119 PointUserContainer&& points, IndexUserContainer&& sampling, PointConverter c
120) {
121 PONCA_DEBUG_ASSERT(static_cast<IndexType>(Base::pointCount()) <= Base::MAX_POINT_COUNT);
122 Base::m_leaf_count = 0;
123
124 // Move, copy or convert input samples
125 c(std::forward<PointUserContainer>(points), Base::m_bufs.points);
126 Base::m_bufs.points_size = points.size();
127
128 Base::m_bufs.indices_size = sampling.size();
129 Base::m_bufs.indices = std::move(sampling);
130
131 Base::m_bufs.nodes.reserve(4 * Base::pointCount() / Base::m_min_cell_size);
132 Base::m_bufs.nodes.emplace_back();
133
134 this->buildRec(0, 0, Base::sampleCount(), 1);
135 Base::m_bufs.nodes_size = Base::m_bufs.nodes.size();
136
137 PONCA_DEBUG_ASSERT(this->valid());
138}
139
140template<typename Traits>
141PONCA_MULTIARCH_HOST inline void KdTreeBase<Traits>::buildRec(NodeIndexType node_id, IndexType start, IndexType end, int level)
142{
143 NodeType& node = Base::m_bufs.nodes[node_id];
144 AabbType aabb;
145 for(IndexType i=start; i<end; ++i)
146 aabb.extend(Base::m_bufs.points[Base::m_bufs.indices[i]].pos());
147
148 node.set_is_leaf(
149 end-start <= Base::m_min_cell_size ||
150 level >= Traits::MAX_DEPTH ||
151 // Since we add 2 nodes per inner node we need to stop if we can't add
152 // them both
153 static_cast<NodeIndexType>(Base::m_bufs.nodes.size()) > Base::MAX_NODE_COUNT - 2);
154
155 node.configure_range(start, end-start, aabb);
156 if (node.is_leaf())
157 {
158 ++Base::m_leaf_count;
159 }
160 else
161 {
162 IndexType split_dim = 0;
163 (Scalar(0.5) * aabb.diagonal()).maxCoeff(&split_dim);
164 node.configure_inner(aabb.center()[split_dim], static_cast<IndexType>(Base::m_bufs.nodes.size()), split_dim);
165 Base::m_bufs.nodes.emplace_back();
166 Base::m_bufs.nodes.emplace_back();
167
168 IndexType mid_id = this->partition(start, end, split_dim, node.inner_split_value());
169 buildRec(node.inner_first_child_id(), start, mid_id, level+1);
170 buildRec(node.inner_first_child_id()+1, mid_id, end, level+1);
171 }
172}
173
174template<typename Traits>
175PONCA_MULTIARCH_HOST [[nodiscard]] inline auto KdTreeBase<Traits>::partition(IndexType start, IndexType end, int dim, Scalar value)
176 -> IndexType
177{
178 const auto& points = Base::m_bufs.points;
179
180 auto it = std::partition(std::begin(Base::m_bufs.indices)+start, std::begin(Base::m_bufs.indices)+end, [&](IndexType i)
181 {
182 return points[i].pos()[dim] < value;
183 });
184
185 auto distance = std::distance(std::begin(Base::m_bufs.indices), it);
186
187 return static_cast<IndexType>(distance);
188}