Ponca  7d8ac87a7de01d881c9fde3c42e397b44bffb901
Point Cloud Analysis library
Loading...
Searching...
No Matches
kdTree.hpp
1/*
2 This Source Code Form is subject to the terms of the Mozilla Public
3 License, v. 2.0. If a copy of the MPL was not distributed with this
4 file, You can obtain one at http://mozilla.org/MPL/2.0/.
5*/
6
7// KdTree ----------------------------------------------------------------------
8
9template <typename Traits>
10template <typename PointUserContainer, typename PointConverter>
11PONCA_MULTIARCH_HOST inline void KdTreeBase<Traits>::build(PointUserContainer&& points, PointConverter c)
12{
13 IndexContainer ids(points.size());
14 std::iota(ids.begin(), ids.end(), IndexType(0));
15 this->buildWithSampling(std::forward<PointUserContainer>(points), std::move(ids), std::move(c));
16}
17
18template <typename Traits>
19PONCA_MULTIARCH_HOST [[nodiscard]] inline bool StaticKdTreeBase<Traits>::valid() const
20{
21 if (m_bufs.points_size == 0)
22 return m_bufs.nodes_size == 0 && m_bufs.indices_size == 0;
23
24 if (m_bufs.nodes_size == 0 || m_bufs.indices_size == 0)
25 {
26 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
27 return false;
28 }
29
30 std::vector<bool> b(pointCount(), false);
31 for (unsigned int i = 0; i < sampleCount(); ++i)
32 {
33 const int idx = m_bufs.indices[i];
34 if (idx < 0 || pointCount() <= idx || b[idx])
35 {
36 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
37 return false;
38 }
39 b[idx] = true;
40 }
41
42 for (NodeIndexType n = 0; n < nodeCount(); ++n)
43 {
44 const NodeType& node = m_bufs.nodes[n];
45 if (node.is_leaf())
46 {
47 if (sampleCount() <= node.leaf_start() || node.leaf_start() + node.leaf_size() > sampleCount())
48 {
49 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
50 return false;
51 }
52 }
53 else
54 {
55 if (node.inner_split_dim() < 0 || DataPoint::Dim - 1 < node.inner_split_dim())
56 {
57 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
58 return false;
59 }
60 if (nodeCount() <= node.inner_first_child_id() || nodeCount() <= node.inner_first_child_id() + 1)
61 {
62 std::cerr << "KdTree validation check failed in " << __FILE__ << " (" << __LINE__ << ")" << std::endl;
63 return false;
64 }
65 }
66 }
67
68 return true;
69}
70
71template <typename Traits>
72PONCA_MULTIARCH_HOST inline void StaticKdTreeBase<Traits>::print(std::ostream& os, bool verbose) const
73{
74 os << "KdTree:";
75 os << "\n MaxNodes: " << MAX_NODE_COUNT;
76 os << "\n MaxPoints: " << MAX_POINT_COUNT;
77 os << "\n MaxDepth: " << MAX_DEPTH;
78 os << "\n PointCount: " << pointCount();
79 os << "\n SampleCount: " << sampleCount();
80 os << "\n NodeCount: " << nodeCount();
81
82 if (!verbose)
83 {
84 return;
85 }
86
87 os << "\n Samples: [";
88 static constexpr IndexType SAMPLES_PER_LINE = 10;
89 for (IndexType i = 0; i < sampleCount(); ++i)
90 {
91 os << (i == 0 ? "" : ",");
92 os << (i % SAMPLES_PER_LINE == 0 ? "\n " : " ");
93 os << m_bufs.indices[i];
94 }
95
96 os << "]\n Nodes:";
97 for (NodeIndexType n = 0; n < nodeCount(); ++n)
98 {
99 const NodeType& node = m_bufs.nodes[n];
100 if (node.is_leaf())
101 {
102 os << "\n - Type: Leaf";
103 os << "\n Start: " << node.leaf_start();
104 os << "\n Size: " << node.leaf_size();
105 }
106 else
107 {
108 os << "\n - Type: Inner";
109 os << "\n SplitDim: " << node.inner_split_dim();
110 os << "\n SplitValue: " << node.inner_split_value();
111 os << "\n FirstChild: " << node.inner_first_child_id();
112 }
113 }
114}
115
116template <typename Traits>
117template <typename PointUserContainer, typename IndexUserContainer, typename PointConverter>
118PONCA_MULTIARCH_HOST inline void KdTreeBase<Traits>::buildWithSampling(PointUserContainer&& points,
119 IndexUserContainer&& sampling, PointConverter c)
120{
121 PONCA_DEBUG_ASSERT(static_cast<IndexType>(Base::pointCount()) <= Base::MAX_POINT_COUNT);
122 Base::m_leaf_count = 0;
123
124 // Move, copy or convert input samples
125 c(std::forward<PointUserContainer>(points), Base::m_bufs.points);
126 Base::m_bufs.points_size = points.size();
127
128 Base::m_bufs.indices_size = sampling.size();
129 Base::m_bufs.indices = std::move(sampling);
130
131 Base::m_bufs.nodes.reserve(4 * Base::pointCount() / Base::m_min_cell_size);
132 Base::m_bufs.nodes.emplace_back();
133
134 this->buildRec(0, 0, Base::sampleCount(), 1);
135 Base::m_bufs.nodes_size = Base::m_bufs.nodes.size();
136
137 PONCA_DEBUG_ASSERT(this->valid());
138}
139
140template <typename Traits>
141PONCA_MULTIARCH_HOST inline void KdTreeBase<Traits>::buildRec(NodeIndexType node_id, IndexType start, IndexType end,
142 int level)
143{
144 NodeType& node = Base::m_bufs.nodes[node_id];
145 AabbType aabb;
146 for (IndexType i = start; i < end; ++i)
147 aabb.extend(Base::m_bufs.points[Base::m_bufs.indices[i]].pos());
148
149 node.set_is_leaf(end - start <= Base::m_min_cell_size || level >= Traits::MAX_DEPTH ||
150 // Since we add 2 nodes per inner node we need to stop if we can't add
151 // them both
152 static_cast<NodeIndexType>(Base::m_bufs.nodes.size()) > Base::MAX_NODE_COUNT - 2);
153
154 node.configure_range(start, end - start, aabb);
155 if (node.is_leaf())
156 {
157 ++Base::m_leaf_count;
158 }
159 else
160 {
161 IndexType split_dim = 0;
162 (Scalar(0.5) * aabb.diagonal()).maxCoeff(&split_dim);
163 node.configure_inner(aabb.center()[split_dim], static_cast<IndexType>(Base::m_bufs.nodes.size()), split_dim);
164 Base::m_bufs.nodes.emplace_back();
165 Base::m_bufs.nodes.emplace_back();
166
167 IndexType mid_id = this->partition(start, end, split_dim, node.inner_split_value());
168 buildRec(node.inner_first_child_id(), start, mid_id, level + 1);
169 buildRec(node.inner_first_child_id() + 1, mid_id, end, level + 1);
170 }
171}
172
173template <typename Traits>
174PONCA_MULTIARCH_HOST [[nodiscard]] inline auto KdTreeBase<Traits>::partition(IndexType start, IndexType end, int dim,
175 Scalar value) -> IndexType
176{
177 const auto& points = Base::m_bufs.points;
178
179 auto it = std::partition(std::begin(Base::m_bufs.indices) + start, std::begin(Base::m_bufs.indices) + end,
180 [&](IndexType i) { return points[i].pos()[dim] < value; });
181
182 auto distance = std::distance(std::begin(Base::m_bufs.indices), it);
183
184 return static_cast<IndexType>(distance);
185}