knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Vptree.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
3
4#include "distances.hpp"
5#include "NeighborQueue.hpp"
6#include "Prebuilt.hpp"
7#include "Builder.hpp"
8#include "Matrix.hpp"
10#include "utils.hpp"
11
12#include <vector>
13#include <random>
14#include <limits>
15#include <tuple>
16#include <memory>
17#include <cstddef>
18#include <string>
19#include <cstring>
20#include <filesystem>
21
22#include "sanisizer/sanisizer.hpp"
23
30namespace knncolle {
31
/**
 * Name written to the `ALGORITHM` file by `VptreePrebuilt::save()` to identify the saved index type.
 *
 * Declared as an `inline` variable (without `static`) so that all translation units
 * including this header share a single definition instead of each getting its own copy.
 */
inline constexpr const char* vptree_prebuilt_save_name = "knncolle::Vptree";
36
40template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
41class VptreePrebuilt;
42
43template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
44class VptreeSearcher final : public Searcher<Index_, Data_, Distance_> {
45public:
46 VptreeSearcher(const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& parent) : my_parent(parent) {}
47
48private:
49 const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& my_parent;
50 NeighborQueue<Index_, Distance_> my_nearest;
51 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
52
53public:
54 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
55 my_nearest.reset(k + 1); // +1 is safe as k < num_obs.
56 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
57 Distance_ max_dist = std::numeric_limits<Distance_>::max();
58 my_parent.search_nn(0, iptr, max_dist, my_nearest);
59 my_nearest.report(output_indices, output_distances, i);
60 }
61
62 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
63 // Protect the NeighborQueue from k = 0. This also protects search_nn()
64 // when there are no observations (and no node 0 to start recursion).
65 if (k == 0 || my_parent.my_nodes.empty()) {
66 if (output_indices) {
67 output_indices->clear();
68 }
69 if (output_distances) {
70 output_distances->clear();
71 }
72
73 } else {
74 my_nearest.reset(k);
75 Distance_ max_dist = std::numeric_limits<Distance_>::max();
76 my_parent.search_nn(0, query, max_dist, my_nearest);
77 my_nearest.report(output_indices, output_distances);
78 }
79 }
80
81 bool can_search_all() const {
82 return true;
83 }
84
85 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
86 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
87
88 if (!output_indices && !output_distances) {
89 Index_ count = 0;
90 my_parent.template search_all<true>(0, iptr, d, count);
92
93 } else {
94 my_all_neighbors.clear();
95 my_parent.template search_all<false>(0, iptr, d, my_all_neighbors);
96 report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
97 return count_all_neighbors_without_self(my_all_neighbors.size());
98 }
99 }
100
101 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
102 if (my_parent.my_nodes.empty()) { // protect the search_all() method when there is not even a node 0 to start the recursion.
103 my_all_neighbors.clear();
104 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
105 return 0;
106 }
107
108 if (!output_indices && !output_distances) {
109 Index_ count = 0;
110 my_parent.template search_all<true>(0, query, d, count);
111 return count;
112
113 } else {
114 my_all_neighbors.clear();
115 my_parent.template search_all<false>(0, query, d, my_all_neighbors);
116 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
117 return my_all_neighbors.size();
118 }
119 }
120};
121
122template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
123class VptreePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
124private:
125 std::size_t my_dim;
126 Index_ my_obs;
127 std::vector<Data_> my_data;
128 std::shared_ptr<const DistanceMetric_> my_metric;
129
130public:
131 Index_ num_observations() const {
132 return my_obs;
133 }
134
135 std::size_t num_dimensions() const {
136 return my_dim;
137 }
138
139private:
140 /* Adapted from http://stevehanov.ca/blog/index.php?id=130 */
141 static const Index_ LEAF = 0;
142
143 // Single node of a VP tree.
144 struct Node {
145 Distance_ radius = 0;
146
147 // Original index of current vantage point, defining the center of the node.
148 Index_ index = 0;
149
150 // Node index of the next vantage point for all children closer than 'threshold' from the current vantage point.
151 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
152 Index_ left = LEAF;
153
154 // Node index of the next vantage point for all children further than 'threshold' from the current vantage point.
155 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
156 Index_ right = LEAF;
157 };
158
159 std::vector<Node> my_nodes;
160
161 typedef std::pair<Distance_, Index_> DataPoint;
162
163 template<class Rng_>
164 Index_ build(Index_ lower, Index_ upper, const Data_* coords, std::vector<DataPoint>& items, Rng_& rng) {
165 /*
166 * We're assuming that lower < upper at each point within this
167 * recursion. This requires some protection at the call site
168 * when nobs = 0, see the constructor.
169 */
170
171 Index_ pos = my_nodes.size();
172 my_nodes.emplace_back();
173 Node& node = my_nodes.back(); // this is safe during recursion because we reserved 'nodes' already to the number of observations, see the constructor.
174
175 Index_ gap = upper - lower;
176 if (gap > 1) { // not yet at a leaf.
177
178 /* Choose an arbitrary point and move it to the start of the [lower, upper)
179 * interval in 'items'; this is our new vantage point.
180 *
181 * Yes, I know that the modulo method does not provide strictly
182 * uniform values but statistical correctness doesn't really matter
183 * here, and I don't want std::uniform_int_distribution's
184 * implementation-specific behavior.
185 */
186 Index_ i = (rng() % gap + lower);
187 std::swap(items[lower], items[i]);
188 const auto& vantage = items[lower];
189 node.index = vantage.second;
190 const Data_* vantage_ptr = coords + sanisizer::product_unsafe<std::size_t>(vantage.second, my_dim);
191
192 // Compute distances to the new vantage point.
193 for (Index_ i = lower + 1; i < upper; ++i) {
194 const Data_* loc = coords + sanisizer::product_unsafe<std::size_t>(items[i].second, my_dim);
195 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
196 }
197
198 // Partition around the median distance from the vantage point.
199 Index_ median = lower + gap/2;
200 Index_ lower_p1 = lower + 1; // excluding the vantage point itself, obviously.
201 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
202
203 // Radius of the new node will be the distance to the median.
204 node.radius = my_metric->normalize(items[median].first);
205
206 // Recursively build tree.
207 if (lower_p1 < median) {
208 node.left = build(lower_p1, median, coords, items, rng);
209 }
210 if (median < upper) {
211 node.right = build(median, upper, coords, items, rng);
212 }
213
214 } else {
215 const auto& leaf = items[lower];
216 node.index = leaf.second;
217 }
218
219 return pos;
220 }
221
222private:
223 std::vector<Index_> my_new_locations;
224
225public:
226 VptreePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
227 my_dim(num_dim),
228 my_obs(num_obs),
229 my_data(std::move(data)),
230 my_metric(std::move(metric))
231 {
232 if (num_obs) {
233 std::vector<DataPoint> items;
234 items.reserve(my_obs);
235 for (Index_ i = 0; i < my_obs; ++i) {
236 items.emplace_back(0, i);
237 }
238
239 my_nodes.reserve(my_obs);
240
241 // Statistical correctness doesn't matter (aside from tie breaking)
242 // so we'll just use a deterministically 'random' number to ensure
243 // we get the same ties for any given dataset but a different stream
244 // of numbers between datasets. Casting to get well-defined overflow.
245 const std::mt19937_64::result_type base = 1234567890, m1 = my_obs, m2 = my_dim;
246 std::mt19937_64 rand(base * m1 + m2);
247
248 build(0, my_obs, my_data.data(), items, rand);
249
250 // Resorting data in place to match order of occurrence within
251 // 'nodes', for better cache locality.
252 auto used = sanisizer::create<std::vector<char> >(sanisizer::attest_gez(my_obs));
253 auto buffer = sanisizer::create<std::vector<Data_> >(sanisizer::attest_gez(my_dim));
254 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
255 auto host = my_data.data();
256
257 for (Index_ o = 0; o < num_obs; ++o) {
258 if (used[o]) {
259 continue;
260 }
261
262 auto& current = my_nodes[o];
263 my_new_locations[current.index] = o;
264 if (current.index == o) {
265 continue;
266 }
267
268 auto optr = host + sanisizer::product_unsafe<std::size_t>(o, my_dim);
269 std::copy_n(optr, my_dim, buffer.begin());
270 Index_ replacement = current.index;
271
272 do {
273 auto rptr = host + sanisizer::product_unsafe<std::size_t>(replacement, my_dim);
274 std::copy_n(rptr, my_dim, optr);
275 used[replacement] = 1;
276
277 const auto& next = my_nodes[replacement];
278 my_new_locations[next.index] = replacement;
279
280 optr = rptr;
281 replacement = next.index;
282 } while (replacement != o);
283
284 std::copy(buffer.begin(), buffer.end(), optr);
285 }
286 }
287 }
288
289private:
290 void search_nn(Index_ curnode_index, const Data_* target, Distance_& max_dist, NeighborQueue<Index_, Distance_>& nearest) const {
291 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_index, my_dim);
292 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
293
294 // If current node is within the maximum distance:
295 const auto& curnode = my_nodes[curnode_index];
296 if (dist <= max_dist) {
297 nearest.add(curnode.index, dist);
298 if (nearest.is_full()) {
299 max_dist = nearest.limit(); // update value of max_dist (farthest point in result list)
300 }
301 }
302
303 if (dist < curnode.radius) { // If the target lies within the radius of ball:
304 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
305 search_nn(curnode.left, target, max_dist, nearest);
306 }
307
308 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
309 search_nn(curnode.right, target, max_dist, nearest);
310 }
311
312 } else { // If the target lies outsize the radius of the ball:
313 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
314 search_nn(curnode.right, target, max_dist, nearest);
315 }
316
317 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
318 search_nn(curnode.left, target, max_dist, nearest);
319 }
320 }
321 }
322
323 template<bool count_only_, typename Output_>
324 void search_all(Index_ curnode_index, const Data_* target, Distance_ threshold, Output_& all_neighbors) const {
325 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_index, my_dim);
326 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
327
328 // If current node is within the maximum distance:
329 const auto& curnode = my_nodes[curnode_index];
330 if (dist <= threshold) {
331 if constexpr(count_only_) {
332 ++all_neighbors;
333 } else {
334 all_neighbors.emplace_back(dist, curnode.index);
335 }
336 }
337
338 if (dist < curnode.radius) { // If the target lies within the radius of ball:
339 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
340 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
341 }
342
343 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
344 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
345 }
346
347 } else { // If the target lies outsize the radius of the ball:
348 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
349 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
350 }
351
352 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
353 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
354 }
355 }
356 }
357
358 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
359
360public:
361 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
362 return initialize_known();
363 }
364
365 auto initialize_known() const {
366 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
367 }
368
369public:
370 void save(const std::filesystem::path& dir) const {
371 quick_save(dir / "ALGORITHM", vptree_prebuilt_save_name, std::strlen(vptree_prebuilt_save_name));
372 quick_save(dir / "DATA", my_data.data(), my_data.size());
373 quick_save(dir / "NUM_OBS", &my_obs, 1);
374 quick_save(dir / "NUM_DIM", &my_dim, 1);
375 quick_save(dir / "NODES", my_nodes.data(), my_nodes.size());
376 quick_save(dir / "NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
377
378 const auto distdir = dir / "DISTANCE";
379 std::filesystem::create_directory(distdir);
380 my_metric->save(distdir);
381 }
382
383 VptreePrebuilt(const std::filesystem::path& dir) {
384 quick_load(dir / "NUM_OBS", &my_obs, 1);
385 quick_load(dir / "NUM_DIM", &my_dim, 1);
386
387 my_data.resize(sanisizer::product<I<decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
388 quick_load(dir / "DATA", my_data.data(), my_data.size());
389
390 sanisizer::resize(my_nodes, sanisizer::attest_gez(my_obs));
391 quick_load(dir / "NODES", my_nodes.data(), my_nodes.size());
392
393 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
394 quick_load(dir / "NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
395
396 auto dptr = load_distance_metric_raw<Data_, Distance_>(dir / "DISTANCE");
397 auto xptr = dynamic_cast<DistanceMetric_*>(dptr);
398 if (xptr == NULL) {
399 throw std::runtime_error("cannot cast the loaded distance metric to a DistanceMetric_");
400 }
401 my_metric.reset(xptr);
402 }
403};
449template<
450 typename Index_,
451 typename Data_,
452 typename Distance_,
453 class Matrix_ = Matrix<Index_, Data_>,
454 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
455>
456class VptreeBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
457public:
461 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
462
463private:
464 std::shared_ptr<const DistanceMetric_> my_metric;
465
466public:
470 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
471 return build_known_raw(data);
472 }
477public:
481 auto build_known_raw(const Matrix_& data) const {
482 std::size_t ndim = data.num_dimensions();
483 Index_ nobs = data.num_observations();
484 auto work = data.new_known_extractor();
485
486 // We assume that that vector::size_type <= size_t, otherwise data() wouldn't be a contiguous array.
487 std::vector<Data_> store(sanisizer::product<typename std::vector<Data_>::size_type>(ndim, sanisizer::attest_gez(nobs)));
488 for (Index_ o = 0; o < nobs; ++o) {
489 std::copy_n(work->next(), ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
490 }
491
492 return new VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
493 }
494
498 auto build_known_unique(const Matrix_& data) const {
499 return std::unique_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
500 }
501
505 auto build_known_shared(const Matrix_& data) const {
506 return std::shared_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
507 }
508};
509
510}
511
512#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const =0
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:29
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:456
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Vptree.hpp:461
auto build_known_unique(const Matrix_ &data) const
Definition Vptree.hpp:498
auto build_known_raw(const Matrix_ &data) const
Definition Vptree.hpp:481
auto build_known_shared(const Matrix_ &data) const
Definition Vptree.hpp:505
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
Definition utils.hpp:57
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Definition utils.hpp:33
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().
Miscellaneous utilities for knncolle