knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Vptree.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
3
#include "distances.hpp"
#include "NeighborQueue.hpp"
#include "Prebuilt.hpp"
#include "Builder.hpp"
#include "Matrix.hpp"
#include "report_all_neighbors.hpp"
#include "utils.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <filesystem>
#include <limits>
#include <memory>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "sanisizer/sanisizer.hpp"
25
32namespace knncolle {
33
// Name written to the "ALGORITHM" file by VptreePrebuilt::save(), identifying a saved VP tree index.
// 'inline constexpr' gives a single definition shared by all translation units; the previous
// 'inline static' combination gave every including translation unit its own internal copy.
inline constexpr const char* vptree_prebuilt_save_name = "knncolle::Vptree";
38
47 std::optional<typename std::mt19937_64::result_type> seed;
48};
49
53template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
54class VptreePrebuilt;
55
// One entry of the explicit traversal stack used by the VP tree searches.
// 'node' is the position (in the node vector) of a node whose second child is still pending.
// 'right' records which child was descended into first: search_nn() visits the left child
// on backtrack when it is true and the right child when it is false; search_all() ignores it.
template<typename Index_>
struct VptreeSearchHistory {
    VptreeSearchHistory(bool right, Index_ node) : node{node}, right{right} {}

    Index_ node;
    bool right;
};
62
63template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
64class VptreeSearcher final : public Searcher<Index_, Data_, Distance_> {
65public:
66 VptreeSearcher(const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& parent) : my_parent(parent) {}
67
68private:
69 const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& my_parent;
70 NeighborQueue<Index_, Distance_> my_nearest;
71 std::vector<VptreeSearchHistory<Index_> > my_history;
72 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
73
74public:
75 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
76 my_nearest.reset(k + 1); // +1 is safe as k < num_obs.
77 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
78 my_parent.search_nn(iptr, my_nearest, my_history);
79 my_nearest.report(output_indices, output_distances, i);
80 }
81
82 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
83 // Protect the NeighborQueue from k = 0. This also protects search_nn()
84 // when there are no observations (and no node 0 to start recursion).
85 if (k == 0 || my_parent.my_nodes.empty()) {
86 if (output_indices) {
87 output_indices->clear();
88 }
89 if (output_distances) {
90 output_distances->clear();
91 }
92
93 } else {
94 my_nearest.reset(k);
95 my_parent.search_nn(query, my_nearest, my_history);
96 my_nearest.report(output_indices, output_distances);
97 }
98 }
99
100 bool can_search_all() const {
101 return true;
102 }
103
104 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
105 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
106
107 if (!output_indices && !output_distances) {
108 Index_ count = 0;
109 my_parent.template search_all<true>(iptr, d, count, my_history);
111
112 } else {
113 my_all_neighbors.clear();
114 my_parent.template search_all<false>(iptr, d, my_all_neighbors, my_history);
115 report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
116 return count_all_neighbors_without_self(my_all_neighbors.size());
117 }
118 }
119
120 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
121 if (my_parent.my_nodes.empty()) { // protect the search_all() method when there is not even a node 0 to start the recursion.
122 my_all_neighbors.clear();
123 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
124 return 0;
125 }
126
127 if (!output_indices && !output_distances) {
128 Index_ count = 0;
129 my_parent.template search_all<true>(query, d, count, my_history);
130 return count;
131
132 } else {
133 my_all_neighbors.clear();
134 my_parent.template search_all<false>(query, d, my_all_neighbors, my_history);
135 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
136 return my_all_neighbors.size();
137 }
138 }
139};
140
141template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
142class VptreePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
143private:
144 std::size_t my_dim;
145 Index_ my_obs;
146 std::vector<Data_> my_data;
147 std::shared_ptr<const DistanceMetric_> my_metric;
148
149public:
150 Index_ num_observations() const {
151 return my_obs;
152 }
153
154 std::size_t num_dimensions() const {
155 return my_dim;
156 }
157
158private:
159 /* Adapted from http://stevehanov.ca/blog/index.php?id=130 */
160
161 // Normally, 'left' or 'right' must be > 0, as the first node in 'nodes' is
162 // the root and cannot be referenced from other nodes. This means that we
163 // can use 0 as a sentinel to indicate that no child exists here.
164 static constexpr Index_ TERMINAL = 0;
165
166 // Single node of a VP tree.
167 struct Node {
168 Distance_ radius = 0;
169
170 // Original index of current vantage point, defining the center of the node.
171 Index_ index = 0;
172
173 // Node index of the next vantage point for all children no more than 'threshold' from the current vantage point.
174 Index_ left = TERMINAL;
175
176 // Node index of the next vantage point for all children no less than 'threshold' from the current vantage point.
177 Index_ right = TERMINAL;
178 };
179
180 std::vector<Node> my_nodes;
181
182 void build(const VptreeOptions& options) {
183 typedef std::pair<Distance_, Index_> DataPoint;
184 std::vector<DataPoint> items;
185 items.reserve(my_obs);
186 for (Index_ i = 0; i < my_obs; ++i) {
187 items.emplace_back(0, i);
188 }
189
190 std::mt19937_64 rng([&]() {
191 if (options.seed.has_value()) {
192 return *(options.seed);
193 }
194
195 // Statistical correctness doesn't matter (aside from tie breaking)
196 // so we'll just use a deterministically 'random' number to ensure
197 // we get the same ties for any given dataset but a different stream
198 // of numbers between datasets. Casting to get well-defined overflow.
199 typedef typename std::mt19937_64::result_type SeedType;
200 const SeedType base = 1234567890, m1 = my_obs, m2 = my_dim;
201 return static_cast<SeedType>(base * m1 + m2);
202 }());
203
204 // We're assuming that lower < upper at each loop. This requires some protection at the call site when nobs = 0, see the constructor.
205 Index_ lower = 0, upper = my_obs;
206
207 // Reserving everything so there there won't be a reallocation, which ensures that pointers to various members will remain valid.
208 my_nodes.reserve(my_obs);
209 const auto coords = my_data.data();
210
211 struct BuildHistory {
212 BuildHistory(Index_ lower, Index_ upper, Index_* right) : right(right), lower(lower), upper(upper) {}
213 Index_* right; // This is a pointer to the 'Node::right' of the parent of the node-to-be-added.
214 Index_ lower, upper; // Lower and upper ranges of the items in the node-to-be-added.
215 };
216 std::vector<BuildHistory> history;
217
218 while (1) {
219 my_nodes.emplace_back();
220 Node& node = my_nodes.back();
221
222 const Index_ gap = upper - lower;
223 assert(gap > 0);
224 if (gap == 1) { // i.e., we're at a leaf.
225 const auto& leaf = items[lower];
226 node.index = leaf.second;
227
228 // If we're at a leaf, we've finished this particular branch of the tree, so we can start rolling back through history.
229 if (history.empty()) {
230 return;
231 }
232 *(history.back().right) = my_nodes.size();
233 lower = history.back().lower;
234 upper = history.back().upper;
235 history.pop_back();
236 continue;
237 }
238
239 // Choose an arbitrary point and move it to the start of the [lower, upper) interval in 'items'; this is our new vantage point.
240 // Yes, I know that the modulo method does not provide strictly uniform values but statistical correctness doesn't really matter here,
241 // and I don't want std::uniform_int_distribution's implementation-specific behavior.
242 const Index_ vp = (rng() % gap + lower);
243 std::swap(items[lower], items[vp]);
244 const auto& vantage = items[lower];
245 node.index = vantage.second;
246 const Data_* vantage_ptr = coords + sanisizer::product_unsafe<std::size_t>(vantage.second, my_dim);
247
248 // Compute distances to the new vantage point.
249 // We +1 to exclude the vantage point itself, obviously.
250 const Index_ lower_p1 = lower + 1;
251 for (Index_ i = lower_p1 ; i < upper; ++i) {
252 const Data_* loc = coords + sanisizer::product_unsafe<std::size_t>(items[i].second, my_dim);
253 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
254 }
255
256 if (gap > 2) {
257 // Partition around the median distance from the vantage point.
258 const Index_ median = lower_p1 + (gap - 1)/2;
259 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
260
261 // Radius of the new node will be the distance to the median.
262 node.radius = my_metric->normalize(items[median].first);
263
264 // The next iteration will process the left node (i.e., inside the ball).
265 // We store the boundaries of the yet-to-be-added right node to the history for later processing.
266 history.emplace_back(median, upper, &(node.right));
267 node.left = my_nodes.size();
268 lower = lower_p1;
269 upper = median;
270
271 } else {
272 // Here we only have one child, as this node has two observations and one of them was already used as the vantage point.
273 // So the other observation is used directly as the right node.
274 const Index_ median = lower_p1;
275 node.radius = my_metric->normalize(items[median].first);
276 node.right = my_nodes.size();
277 lower = median;
278
279 // Several points worth mentioning here:
280 // - No need to set upper, as we'd end up just doing upper = upper and clang complains.
281 // - This code allows us to get a node where left = TERMINAL and right != TERMINAL, but the opposite is impossible.
282 // This fact is exploited in search_nn() for some minor optimizations.
283 }
284 }
285 }
286
287private:
288 std::vector<Index_> my_new_locations;
289
290public:
291 VptreePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric, const VptreeOptions& options) :
292 my_dim(num_dim),
293 my_obs(num_obs),
294 my_data(std::move(data)),
295 my_metric(std::move(metric))
296 {
297 if (num_obs) {
298 build(options);
299
300 // Resorting data in place to match order of occurrence within 'nodes', for better cache locality.
301 auto used = sanisizer::create<std::vector<char> >(sanisizer::attest_gez(my_obs));
302 auto buffer = sanisizer::create<std::vector<Data_> >(sanisizer::attest_gez(my_dim));
303 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
304 auto host = my_data.data();
305
306 for (Index_ o = 0; o < num_obs; ++o) {
307 if (used[o]) {
308 continue;
309 }
310
311 auto& current = my_nodes[o];
312 my_new_locations[current.index] = o;
313 if (current.index == o) {
314 continue;
315 }
316
317 auto optr = host + sanisizer::product_unsafe<std::size_t>(o, my_dim);
318 std::copy_n(optr, my_dim, buffer.begin());
319 Index_ replacement = current.index;
320
321 do {
322 auto rptr = host + sanisizer::product_unsafe<std::size_t>(replacement, my_dim);
323 std::copy_n(rptr, my_dim, optr);
324 used[replacement] = 1;
325
326 const auto& next = my_nodes[replacement];
327 my_new_locations[next.index] = replacement;
328
329 optr = rptr;
330 replacement = next.index;
331 } while (replacement != o);
332
333 std::copy(buffer.begin(), buffer.end(), optr);
334 }
335 }
336 }
337
338private:
339 static bool can_progress_left(const Node& node, const Distance_ dist_to_vp, const Distance_ threshold) {
340 return node.left != TERMINAL && dist_to_vp - threshold <= node.radius;
341 }
342
343 static bool can_progress_right(const Node& node, const Distance_ dist_to_vp, const Distance_ threshold) {
344 // Using >= in the triangle inequality as there are some points that lie on the surface of the ball but are considered 'outside' the ball,
345 // e.g., the median point itself as well as anything with a tied distance.
346 return node.right != TERMINAL && dist_to_vp + threshold >= node.radius;
347 }
348
349 void search_nn(const Data_* target, NeighborQueue<Index_, Distance_>& nearest, std::vector<VptreeSearchHistory<Index_> >& history) const {
350 history.clear();
351 Index_ curnode_offset = 0;
352 Distance_ max_dist = std::numeric_limits<Distance_>::max();
353
354 while (1) {
355 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_offset, my_dim);
356 const Distance_ dist_to_vp = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
357
358 const auto& curnode = my_nodes[curnode_offset];
359 if (dist_to_vp <= max_dist) {
360 nearest.add(curnode.index, dist_to_vp);
361 if (nearest.is_full()) {
362 max_dist = nearest.limit(); // update value of max_dist (farthest point in result list)
363 }
364 }
365
366 if (dist_to_vp < curnode.radius) {
367 // If the target lies within the radius of ball, chances are that its neighbors also lie inside the ball.
368 // So we check the points inside the ball first (i.e., left node) to try to shrink max_dist as fast as possible.
369
370 // A quirk here is that, if dist_to_vp < curnode.radius, then can_progress_left must be true if curnode.left != TERMINAL.
371 // So we don't bother to compute the full function.
372 const bool can_left = curnode.left != TERMINAL;
373 const bool can_right = can_progress_right(curnode, dist_to_vp, max_dist);
374
375 if (can_left) {
376 if (can_right) {
377 history.emplace_back(false, curnode_offset);
378 }
379 curnode_offset = curnode.left;
380 continue;
381 } else if (can_right) {
382 curnode_offset = curnode.right;
383 continue;
384 }
385
386 } else {
387 // Otherwise, if the target lies at or outside the radius of the ball, chances are its neighbors also lie outside the ball.
388 // So we check the points outside the ball first (i.e., right node) to try to shrink max_dist as fast as possible.
389
390 // A quirk here is that, if dist_to_vp >= curnode.radius, then can_progress_right must be true if curnode.right != TERMINAL.
391 // So we don't bother to compute the full function.
392 const bool can_right = curnode.right != TERMINAL;
393 const bool can_left = can_progress_left(curnode, dist_to_vp, max_dist);
394
395 if (can_right) {
396 if (can_left) {
397 history.emplace_back(true, curnode_offset);
398 }
399 curnode_offset = curnode.right;
400 continue;
401 } else {
402 // The manner of construction of the VP tree prevents the existence of a node where right == TERMINAL but left != TERMINAL.
403 // As such, there's no need to consider the 'else if (can_left) {' condition that we would otherwise expect for symmetry with the inside-ball code.
404 assert(!can_left);
405 }
406 }
407
408 // We don't have anything else to do here, so we move back to the last branching node in our history.
409 if (history.empty()) {
410 return;
411 }
412
413 auto& histinfo = history.back();
414 if (!histinfo.right) {
415 curnode_offset = my_nodes[histinfo.node].right;
416 } else {
417 curnode_offset = my_nodes[histinfo.node].left;
418 }
419 history.pop_back();
420 }
421 }
422
423 template<bool count_only_, typename Output_>
424 void search_all(const Data_* target, const Distance_ threshold, Output_& all_neighbors, std::vector<VptreeSearchHistory<Index_> >& history) const {
425 history.clear();
426 Index_ curnode_offset = 0;
427
428 while (1) {
429 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_offset, my_dim);
430 const Distance_ dist_to_vp = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
431
432 const auto& curnode = my_nodes[curnode_offset];
433 if (dist_to_vp <= threshold) {
434 if constexpr(count_only_) {
435 ++all_neighbors;
436 } else {
437 all_neighbors.emplace_back(dist_to_vp, curnode.index);
438 }
439 }
440
441 const bool can_left = can_progress_left(curnode, dist_to_vp, threshold);
442 const bool can_right = can_progress_right(curnode, dist_to_vp, threshold);
443
444 // Unlike in search_nn(), we don't bother with different priorities for left/right.
445 // The threshold isn't going to change and we'd have to search both children anyway.
446 if (can_left) {
447 if (can_right) {
448 history.emplace_back(false, curnode_offset); // false is just a dummy value and is ignored in this rest of this function.
449 }
450 curnode_offset = curnode.left;
451 continue;
452 } else if (can_right) {
453 curnode_offset = curnode.right;
454 continue;
455 }
456
457 // We don't have anything else to do here, so we move back to the last branching node in our history.
458 if (history.empty()) {
459 return;
460 }
461
462 auto& histinfo = history.back();
463 curnode_offset = my_nodes[histinfo.node].right;
464 history.pop_back();
465 }
466 }
467
468 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
469
470public:
471 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
472 return initialize_known();
473 }
474
475 auto initialize_known() const {
476 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
477 }
478
479public:
480 void save(const std::filesystem::path& dir) const {
481 quick_save(dir / "ALGORITHM", vptree_prebuilt_save_name, std::strlen(vptree_prebuilt_save_name));
482 quick_save(dir / "DATA", my_data.data(), my_data.size());
483 quick_save(dir / "NUM_OBS", &my_obs, 1);
484 quick_save(dir / "NUM_DIM", &my_dim, 1);
485 quick_save(dir / "NODES", my_nodes.data(), my_nodes.size());
486 quick_save(dir / "NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
487
488 const auto distdir = dir / "DISTANCE";
489 std::filesystem::create_directory(distdir);
490 my_metric->save(distdir);
491 }
492
493 VptreePrebuilt(const std::filesystem::path& dir) {
494 quick_load(dir / "NUM_OBS", &my_obs, 1);
495 quick_load(dir / "NUM_DIM", &my_dim, 1);
496
497 my_data.resize(sanisizer::product<I<decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
498 quick_load(dir / "DATA", my_data.data(), my_data.size());
499
500 sanisizer::resize(my_nodes, sanisizer::attest_gez(my_obs));
501 quick_load(dir / "NODES", my_nodes.data(), my_nodes.size());
502
503 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
504 quick_load(dir / "NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
505
506 auto dptr = load_distance_metric_raw<Data_, Distance_>(dir / "DISTANCE");
507 auto xptr = dynamic_cast<DistanceMetric_*>(dptr);
508 if (xptr == NULL) {
509 throw std::runtime_error("cannot cast the loaded distance metric to a DistanceMetric_");
510 }
511 my_metric.reset(xptr);
512 }
513};
559template<
560 typename Index_,
561 typename Data_,
562 typename Distance_,
563 class Matrix_ = Matrix<Index_, Data_>,
564 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
565>
566class VptreeBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
567public:
572 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric, VptreeOptions options) : my_metric(std::move(metric)), my_options(std::move(options)) {}
573
579 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric) : VptreeBuilder(std::move(metric), {}) {}
580
586 return my_options;
587 }
588
589private:
590 std::shared_ptr<const DistanceMetric_> my_metric;
591 VptreeOptions my_options;
592
593public:
597 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
598 return build_known_raw(data);
599 }
604public:
608 auto build_known_raw(const Matrix_& data) const {
609 std::size_t ndim = data.num_dimensions();
610 Index_ nobs = data.num_observations();
611 auto work = data.new_known_extractor();
612
613 // We assume that that vector::size_type <= size_t, otherwise data() wouldn't be a contiguous array.
614 std::vector<Data_> store(sanisizer::product<typename std::vector<Data_>::size_type>(ndim, sanisizer::attest_gez(nobs)));
615 for (Index_ o = 0; o < nobs; ++o) {
616 std::copy_n(work->next(), ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
617 }
618
619 return new VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric, my_options);
620 }
621
625 auto build_known_unique(const Matrix_& data) const {
626 return std::unique_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
627 }
628
632 auto build_known_shared(const Matrix_& data) const {
633 return std::shared_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
634 }
635};
636
637}
638
639#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const =0
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:29
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:566
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Vptree.hpp:579
auto build_known_unique(const Matrix_ &data) const
Definition Vptree.hpp:625
VptreeOptions & get_options()
Definition Vptree.hpp:585
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric, VptreeOptions options)
Definition Vptree.hpp:572
auto build_known_raw(const Matrix_ &data) const
Definition Vptree.hpp:608
auto build_known_shared(const Matrix_ &data) const
Definition Vptree.hpp:632
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
Definition utils.hpp:57
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Definition utils.hpp:33
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().
Options for VptreeBuilder construction.
Definition Vptree.hpp:42
std::optional< typename std::mt19937_64::result_type > seed
Definition Vptree.hpp:47
Miscellaneous utilities for knncolle