knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Vptree.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
3
4#include "distances.hpp"
5#include "NeighborQueue.hpp"
6#include "Prebuilt.hpp"
7#include "Builder.hpp"
8#include "MockMatrix.hpp"
9#include "report_all_neighbors.hpp"
10
11#include <vector>
12#include <random>
13#include <limits>
14#include <tuple>
15
22namespace knncolle {
23
24template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
25class VptreePrebuilt;
26
38template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
39class VptreeSearcher : public Searcher<Index_, Float_> {
40public:
49private:
51 internal::NeighborQueue<Index_, Float_> my_nearest;
52 std::vector<std::pair<Float_, Index_> > my_all_neighbors;
53
54public:
55 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
56 my_nearest.reset(k + 1);
57 auto iptr = my_parent->my_data.data() + static_cast<size_t>(my_parent->my_new_locations[i]) * my_parent->my_long_ndim; // cast to avoid overflow.
58 Float_ max_dist = std::numeric_limits<Float_>::max();
59 my_parent->search_nn(0, iptr, max_dist, my_nearest);
60 my_nearest.report(output_indices, output_distances, i);
61 }
62
63 void search(const Float_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
64 if (k == 0) { // protect the NeighborQueue from k = 0.
65 internal::flush_output(output_indices, output_distances, 0);
66 } else {
67 my_nearest.reset(k);
68 Float_ max_dist = std::numeric_limits<Float_>::max();
69 my_parent->search_nn(0, query, max_dist, my_nearest);
70 my_nearest.report(output_indices, output_distances);
71 }
72 }
73
74 bool can_search_all() const {
75 return true;
76 }
77
78 Index_ search_all(Index_ i, Float_ d, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
79 auto iptr = my_parent->my_data.data() + static_cast<size_t>(my_parent->my_new_locations[i]) * my_parent->my_long_ndim; // cast to avoid overflow.
80
82 Index_ count = 0;
83 my_parent->template search_all<true>(0, iptr, d, count);
84 return internal::safe_remove_self(count);
85
86 } else {
87 my_all_neighbors.clear();
88 my_parent->template search_all<false>(0, iptr, d, my_all_neighbors);
89 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
90 return internal::safe_remove_self(my_all_neighbors.size());
91 }
92 }
93
94 Index_ search_all(const Float_* query, Float_ d, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
96 Index_ count = 0;
97 my_parent->template search_all<true>(0, query, d, count);
98 return count;
99
100 } else {
101 my_all_neighbors.clear();
102 my_parent->template search_all<false>(0, query, d, my_all_neighbors);
103 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances);
104 return my_all_neighbors.size();
105 }
106 }
107};
108
124template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
125class VptreePrebuilt : public Prebuilt<Dim_, Index_, Float_> {
126private:
127 Dim_ my_dim;
128 Index_ my_obs;
129 size_t my_long_ndim;
130 std::vector<Store_> my_data;
131
132public:
133 Index_ num_observations() const {
134 return my_obs;
135 }
136
137 Dim_ num_dimensions() const {
138 return my_dim;
139 }
140
141private:
142 /* Adapted from http://stevehanov.ca/blog/index.php?id=130 */
143 static const Index_ LEAF = 0;
144
145 // Single node of a VP tree.
146 struct Node {
147 Float_ radius = 0;
148
149 // Original index of current vantage point, defining the center of the node.
150 Index_ index = 0;
151
152 // Node index of the next vantage point for all children closer than 'threshold' from the current vantage point.
153 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
154 Index_ left = LEAF;
155
156 // Node index of the next vantage point for all children further than 'threshold' from the current vantage point.
157 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
158 Index_ right = LEAF;
159 };
160
161 std::vector<Node> my_nodes;
162
163 typedef std::pair<Float_, Index_> DataPoint;
164
165 template<class Rng_>
166 Index_ build(Index_ lower, Index_ upper, const Store_* coords, std::vector<DataPoint>& items, Rng_& rng) {
167 /*
168 * We're assuming that lower < upper at each point within this
169 * recursion. This requires some protection at the call site
170 * when nobs = 0, see the constructor.
171 */
172
173 Index_ pos = my_nodes.size();
174 my_nodes.emplace_back();
175 Node& node = my_nodes.back(); // this is safe during recursion because we reserved 'nodes' already to the number of observations, see the constructor.
176
177 Index_ gap = upper - lower;
178 if (gap > 1) { // not yet at a leaft.
179
180 /* Choose an arbitrary point and move it to the start of the [lower, upper)
181 * interval in 'items'; this is our new vantage point.
182 *
183 * Yes, I know that the modulo method does not provide strictly
184 * uniform values but statistical correctness doesn't really matter
185 * here, and I don't want std::uniform_int_distribution's
186 * implementation-specific behavior.
187 */
188 Index_ i = (rng() % gap + lower);
189 std::swap(items[lower], items[i]);
190 const auto& vantage = items[lower];
191 node.index = vantage.second;
192 const Store_* vantage_ptr = coords + static_cast<size_t>(vantage.second) * my_long_ndim; // cast to avoid overflow.
193
194 // Compute distances to the new vantage point.
195 for (Index_ i = lower + 1; i < upper; ++i) {
196 const Store_* loc = coords + static_cast<size_t>(items[i].second) * my_long_ndim; // cast to avoid overflow.
197 items[i].first = Distance_::template raw_distance<Float_>(vantage_ptr, loc, my_dim);
198 }
199
200 // Partition around the median distance from the vantage point.
201 Index_ median = lower + gap/2;
202 Index_ lower_p1 = lower + 1; // excluding the vantage point itself, obviously.
203 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
204
205 // Radius of the new node will be the distance to the median.
206 node.radius = Distance_::normalize(items[median].first);
207
208 // Recursively build tree.
209 if (lower_p1 < median) {
210 node.left = build(lower_p1, median, coords, items, rng);
211 }
212 if (median < upper) {
213 node.right = build(median, upper, coords, items, rng);
214 }
215
216 } else {
217 const auto& leaf = items[lower];
218 node.index = leaf.second;
219 }
220
221 return pos;
222 }
223
224private:
225 std::vector<Index_> my_new_locations;
226
227public:
233 VptreePrebuilt(Dim_ num_dim, Index_ num_obs, std::vector<Store_> data) :
234 my_dim(num_dim),
235 my_obs(num_obs),
236 my_long_ndim(my_dim),
237 my_data(std::move(data))
238 {
239 if (num_obs) {
240 std::vector<DataPoint> items;
241 items.reserve(my_obs);
242 for (Index_ i = 0; i < my_obs; ++i) {
243 items.emplace_back(0, i);
244 }
245
246 my_nodes.reserve(my_obs);
247
248 // Statistical correctness doesn't matter (aside from tie breaking)
249 // so we'll just use a deterministically 'random' number to ensure
250 // we get the same ties for any given dataset but a different stream
251 // of numbers between datasets. Casting to get well-defined overflow.
252 uint64_t base = 1234567890, m1 = my_obs, m2 = my_dim;
253 std::mt19937_64 rand(base * m1 + m2);
254
255 build(0, my_obs, my_data.data(), items, rand);
256
257 // Resorting data in place to match order of occurrence within
258 // 'nodes', for better cache locality.
259 std::vector<uint8_t> used(my_obs);
260 std::vector<Store_> buffer(my_dim);
261 my_new_locations.resize(my_obs);
262 auto host = my_data.data();
263
264 for (Index_ o = 0; o < num_obs; ++o) {
265 if (used[o]) {
266 continue;
267 }
268
269 auto& current = my_nodes[o];
270 my_new_locations[current.index] = o;
271 if (current.index == o) {
272 continue;
273 }
274
275 auto optr = host + static_cast<size_t>(o) * my_long_ndim;
276 std::copy_n(optr, my_dim, buffer.begin());
277 Index_ replacement = current.index;
278
279 do {
280 auto rptr = host + static_cast<size_t>(replacement) * my_long_ndim;
281 std::copy_n(rptr, my_dim, optr);
282 used[replacement] = 1;
283
284 const auto& next = my_nodes[replacement];
285 my_new_locations[next.index] = replacement;
286
287 optr = rptr;
288 replacement = next.index;
289 } while (replacement != o);
290
291 std::copy(buffer.begin(), buffer.end(), optr);
292 }
293 }
294 }
295
296private:
297 template<typename Query_>
298 void search_nn(Index_ curnode_index, const Query_* target, Float_& max_dist, internal::NeighborQueue<Index_, Float_>& nearest) const {
299 auto nptr = my_data.data() + static_cast<size_t>(curnode_index) * my_long_ndim; // cast to avoid overflow.
300 Float_ dist = Distance_::normalize(Distance_::template raw_distance<Float_>(nptr, target, my_dim));
301
302 // If current node is within the maximum distance:
303 const auto& curnode = my_nodes[curnode_index];
304 if (dist <= max_dist) {
305 nearest.add(curnode.index, dist);
306 if (nearest.is_full()) {
307 max_dist = nearest.limit(); // update value of max_dist (farthest point in result list)
308 }
309 }
310
311 if (dist < curnode.radius) { // If the target lies within the radius of ball:
312 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
313 search_nn(curnode.left, target, max_dist, nearest);
314 }
315
316 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
317 search_nn(curnode.right, target, max_dist, nearest);
318 }
319
320 } else { // If the target lies outsize the radius of the ball:
321 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
322 search_nn(curnode.right, target, max_dist, nearest);
323 }
324
325 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
326 search_nn(curnode.left, target, max_dist, nearest);
327 }
328 }
329 }
330
331 template<bool count_only_, typename Query_, typename Output_>
332 void search_all(Index_ curnode_index, const Query_* target, Float_ threshold, Output_& all_neighbors) const {
333 auto nptr = my_data.data() + static_cast<size_t>(curnode_index) * my_long_ndim; // cast to avoid overflow.
334 Float_ dist = Distance_::normalize(Distance_::template raw_distance<Float_>(nptr, target, my_dim));
335
336 // If current node is within the maximum distance:
337 const auto& curnode = my_nodes[curnode_index];
338 if (dist <= threshold) {
339 if constexpr(count_only_) {
340 ++all_neighbors;
341 } else {
342 all_neighbors.emplace_back(dist, curnode.index);
343 }
344 }
345
346 if (dist < curnode.radius) { // If the target lies within the radius of ball:
347 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
348 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
349 }
350
351 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
352 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
353 }
354
355 } else { // If the target lies outsize the radius of the ball:
356 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
357 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
358 }
359
360 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
361 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
362 }
363 }
364 }
365
366 friend class VptreeSearcher<Distance_, Dim_, Index_, Store_, Float_>;
367
368public:
372 std::unique_ptr<Searcher<Index_, Float_> > initialize() const {
373 return std::make_unique<VptreeSearcher<Distance_, Dim_, Index_, Store_, Float_> >(this);
374 }
375};
376
405template<class Distance_ = EuclideanDistance, class Matrix_ = SimpleMatrix<int, int, double>, typename Float_ = double>
406class VptreeBuilder : public Builder<Matrix_, Float_> {
407public:
412 auto ndim = data.num_dimensions();
413 auto nobs = data.num_observations();
414
415 typedef typename Matrix_::data_type Store_;
416 std::vector<typename Matrix_::data_type> store(static_cast<size_t>(ndim) * static_cast<size_t>(nobs));
417
418 auto work = data.create_workspace();
419 auto sIt = store.begin();
420 for (decltype(nobs) o = 0; o < nobs; ++o, sIt += ndim) {
421 auto ptr = data.get_observation(work);
422 std::copy_n(ptr, ndim, sIt);
423 }
424
426 }
427};
428
429};
430
431#endif
Interface to build nearest-neighbor indices.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:22
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:406
Prebuilt< typename Matrix_::dimension_type, typename Matrix_::index_type, Float_ > * build_raw(const Matrix_ &data) const
Definition Vptree.hpp:411
Index for a VP-tree search.
Definition Vptree.hpp:125
Index_ num_observations() const
Definition Vptree.hpp:133
VptreePrebuilt(Dim_ num_dim, Index_ num_obs, std::vector< Store_ > data)
Definition Vptree.hpp:233
std::unique_ptr< Searcher< Index_, Float_ > > initialize() const
Definition Vptree.hpp:372
Dim_ num_dimensions() const
Definition Vptree.hpp:137
VP-tree searcher.
Definition Vptree.hpp:39
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Vptree.hpp:55
Index_ search_all(const Float_ *query, Float_ d, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Vptree.hpp:94
Index_ search_all(Index_ i, Float_ d, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Vptree.hpp:78
bool can_search_all() const
Definition Vptree.hpp:74
void search(const Float_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Vptree.hpp:63
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:22