knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Vptree.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
3
4#include "distances.hpp"
5#include "NeighborQueue.hpp"
6#include "Prebuilt.hpp"
7#include "Builder.hpp"
8#include "Matrix.hpp"
10
11#include <vector>
12#include <random>
13#include <limits>
14#include <tuple>
15#include <memory>
16#include <cstddef>
17
24namespace knncolle {
25
29template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
30class VptreePrebuilt;
46template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
47class VptreeSearcher final : public Searcher<Index_, Data_, Distance_> {
48public:
57private:
60 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
61
62public:
63 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
64 my_nearest.reset(k + 1);
65 auto iptr = my_parent.my_data.data() + static_cast<std::size_t>(my_parent.my_new_locations[i]) * my_parent.my_dim; // cast to avoid overflow.
66 Distance_ max_dist = std::numeric_limits<Distance_>::max();
67 my_parent.search_nn(0, iptr, max_dist, my_nearest);
68 my_nearest.report(output_indices, output_distances, i);
69 }
70
71 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
72 // Protect the NeighborQueue from k = 0. This also protects search_nn()
73 // when there are no observations (and no node 0 to start recursion).
74 if (k == 0 || my_parent.my_data.empty()) {
75 if (output_indices) {
76 output_indices->clear();
77 }
78 if (output_distances) {
79 output_distances->clear();
80 }
81
82 } else {
83 my_nearest.reset(k);
84 Distance_ max_dist = std::numeric_limits<Distance_>::max();
85 my_parent.search_nn(0, query, max_dist, my_nearest);
86 my_nearest.report(output_indices, output_distances);
87 }
88 }
89
90 bool can_search_all() const {
91 return true;
92 }
93
94 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
95 auto iptr = my_parent.my_data.data() + static_cast<std::size_t>(my_parent.my_new_locations[i]) * my_parent.my_dim; // cast to avoid overflow.
96
97 if (!output_indices && !output_distances) {
98 Index_ count = 0;
99 my_parent.template search_all<true>(0, iptr, d, count);
101
102 } else {
103 my_all_neighbors.clear();
104 my_parent.template search_all<false>(0, iptr, d, my_all_neighbors);
105 report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
106 return count_all_neighbors_without_self(my_all_neighbors.size());
107 }
108 }
109
110 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
111 if (my_parent.my_data.empty()) { // protect the search_all() method when there is not even a node 0 to start the recursion.
112 my_all_neighbors.clear();
113 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
114 return 0;
115 }
116
117 if (!output_indices && !output_distances) {
118 Index_ count = 0;
119 my_parent.template search_all<true>(0, query, d, count);
120 return count;
121
122 } else {
123 my_all_neighbors.clear();
124 my_parent.template search_all<false>(0, query, d, my_all_neighbors);
125 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
126 return my_all_neighbors.size();
127 }
128 }
129};
130
142template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
143class VptreePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
144private:
145 std::size_t my_dim;
146 Index_ my_obs;
147 std::vector<Data_> my_data;
148 std::shared_ptr<const DistanceMetric_> my_metric;
149
150public:
151 Index_ num_observations() const {
152 return my_obs;
153 }
154
155 std::size_t num_dimensions() const {
156 return my_dim;
157 }
158
159private:
160 /* Adapted from http://stevehanov.ca/blog/index.php?id=130 */
161 static const Index_ LEAF = 0;
162
163 // Single node of a VP tree.
164 struct Node {
165 Distance_ radius = 0;
166
167 // Original index of current vantage point, defining the center of the node.
168 Index_ index = 0;
169
170 // Node index of the next vantage point for all children closer than 'threshold' from the current vantage point.
171 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
172 Index_ left = LEAF;
173
174 // Node index of the next vantage point for all children further than 'threshold' from the current vantage point.
175 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
176 Index_ right = LEAF;
177 };
178
179 std::vector<Node> my_nodes;
180
181 typedef std::pair<Distance_, Index_> DataPoint;
182
183 template<class Rng_>
184 Index_ build(Index_ lower, Index_ upper, const Data_* coords, std::vector<DataPoint>& items, Rng_& rng) {
185 /*
186 * We're assuming that lower < upper at each point within this
187 * recursion. This requires some protection at the call site
188 * when nobs = 0, see the constructor.
189 */
190
191 Index_ pos = my_nodes.size();
192 my_nodes.emplace_back();
193 Node& node = my_nodes.back(); // this is safe during recursion because we reserved 'nodes' already to the number of observations, see the constructor.
194
195 Index_ gap = upper - lower;
196 if (gap > 1) { // not yet at a leaft.
197
198 /* Choose an arbitrary point and move it to the start of the [lower, upper)
199 * interval in 'items'; this is our new vantage point.
200 *
201 * Yes, I know that the modulo method does not provide strictly
202 * uniform values but statistical correctness doesn't really matter
203 * here, and I don't want std::uniform_int_distribution's
204 * implementation-specific behavior.
205 */
206 Index_ i = (rng() % gap + lower);
207 std::swap(items[lower], items[i]);
208 const auto& vantage = items[lower];
209 node.index = vantage.second;
210 const Data_* vantage_ptr = coords + static_cast<std::size_t>(vantage.second) * my_dim; // cast to avoid overflow.
211
212 // Compute distances to the new vantage point.
213 for (Index_ i = lower + 1; i < upper; ++i) {
214 const Data_* loc = coords + static_cast<std::size_t>(items[i].second) * my_dim; // cast to avoid overflow.
215 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
216 }
217
218 // Partition around the median distance from the vantage point.
219 Index_ median = lower + gap/2;
220 Index_ lower_p1 = lower + 1; // excluding the vantage point itself, obviously.
221 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
222
223 // Radius of the new node will be the distance to the median.
224 node.radius = my_metric->normalize(items[median].first);
225
226 // Recursively build tree.
227 if (lower_p1 < median) {
228 node.left = build(lower_p1, median, coords, items, rng);
229 }
230 if (median < upper) {
231 node.right = build(median, upper, coords, items, rng);
232 }
233
234 } else {
235 const auto& leaf = items[lower];
236 node.index = leaf.second;
237 }
238
239 return pos;
240 }
241
242private:
243 std::vector<Index_> my_new_locations;
244
245public:
249 VptreePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
250 my_dim(num_dim),
251 my_obs(num_obs),
252 my_data(std::move(data)),
253 my_metric(std::move(metric))
254 {
255 if (num_obs) {
256 std::vector<DataPoint> items;
257 items.reserve(my_obs);
258 for (Index_ i = 0; i < my_obs; ++i) {
259 items.emplace_back(0, i);
260 }
261
262 my_nodes.reserve(my_obs);
263
264 // Statistical correctness doesn't matter (aside from tie breaking)
265 // so we'll just use a deterministically 'random' number to ensure
266 // we get the same ties for any given dataset but a different stream
267 // of numbers between datasets. Casting to get well-defined overflow.
268 uint64_t base = 1234567890, m1 = my_obs, m2 = my_dim;
269 std::mt19937_64 rand(base * m1 + m2);
270
271 build(0, my_obs, my_data.data(), items, rand);
272
273 // Resorting data in place to match order of occurrence within
274 // 'nodes', for better cache locality.
275 std::vector<uint8_t> used(my_obs);
276 std::vector<Data_> buffer(my_dim);
277 my_new_locations.resize(my_obs);
278 auto host = my_data.data();
279
280 for (Index_ o = 0; o < num_obs; ++o) {
281 if (used[o]) {
282 continue;
283 }
284
285 auto& current = my_nodes[o];
286 my_new_locations[current.index] = o;
287 if (current.index == o) {
288 continue;
289 }
290
291 auto optr = host + static_cast<std::size_t>(o) * my_dim;
292 std::copy_n(optr, my_dim, buffer.begin());
293 Index_ replacement = current.index;
294
295 do {
296 auto rptr = host + static_cast<std::size_t>(replacement) * my_dim;
297 std::copy_n(rptr, my_dim, optr);
298 used[replacement] = 1;
299
300 const auto& next = my_nodes[replacement];
301 my_new_locations[next.index] = replacement;
302
303 optr = rptr;
304 replacement = next.index;
305 } while (replacement != o);
306
307 std::copy(buffer.begin(), buffer.end(), optr);
308 }
309 }
310 }
315private:
316 void search_nn(Index_ curnode_index, const Data_* target, Distance_& max_dist, NeighborQueue<Index_, Distance_>& nearest) const {
317 auto nptr = my_data.data() + static_cast<std::size_t>(curnode_index) * my_dim; // cast to avoid overflow.
318 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
319
320 // If current node is within the maximum distance:
321 const auto& curnode = my_nodes[curnode_index];
322 if (dist <= max_dist) {
323 nearest.add(curnode.index, dist);
324 if (nearest.is_full()) {
325 max_dist = nearest.limit(); // update value of max_dist (farthest point in result list)
326 }
327 }
328
329 if (dist < curnode.radius) { // If the target lies within the radius of ball:
330 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
331 search_nn(curnode.left, target, max_dist, nearest);
332 }
333
334 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
335 search_nn(curnode.right, target, max_dist, nearest);
336 }
337
338 } else { // If the target lies outsize the radius of the ball:
339 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
340 search_nn(curnode.right, target, max_dist, nearest);
341 }
342
343 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
344 search_nn(curnode.left, target, max_dist, nearest);
345 }
346 }
347 }
348
349 template<bool count_only_, typename Output_>
350 void search_all(Index_ curnode_index, const Data_* target, Distance_ threshold, Output_& all_neighbors) const {
351 auto nptr = my_data.data() + static_cast<std::size_t>(curnode_index) * my_dim; // cast to avoid overflow.
352 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
353
354 // If current node is within the maximum distance:
355 const auto& curnode = my_nodes[curnode_index];
356 if (dist <= threshold) {
357 if constexpr(count_only_) {
358 ++all_neighbors;
359 } else {
360 all_neighbors.emplace_back(dist, curnode.index);
361 }
362 }
363
364 if (dist < curnode.radius) { // If the target lies within the radius of ball:
365 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
366 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
367 }
368
369 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
370 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
371 }
372
373 } else { // If the target lies outsize the radius of the ball:
374 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
375 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
376 }
377
378 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
379 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
380 }
381 }
382 }
383
384 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
385
386public:
390 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
391 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
392 }
393};
394
427template<
428 typename Index_,
429 typename Data_,
430 typename Distance_,
431 class Matrix_ = Matrix<Index_, Data_>,
432 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
433>
434class VptreeBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
435public:
439 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
440
441private:
442 std::shared_ptr<const DistanceMetric_> my_metric;
443
444public:
448 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
449 std::size_t ndim = data.num_dimensions();
450 Index_ nobs = data.num_observations();
451 auto work = data.new_extractor();
452
453 std::vector<Data_> store(ndim * static_cast<std::size_t>(nobs)); // cast to avoid overflow.
454 for (Index_ o = 0; o < nobs; ++o) {
455 std::copy_n(work->next(), ndim, store.begin() + static_cast<std::size_t>(o) * ndim); // cast to avoid overflow.
456 }
457
458 return new VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
459 }
460};
461
462};
463
464#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:30
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:109
void reset(Index_ k)
Definition NeighborQueue.hpp:46
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:434
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Vptree.hpp:439
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition Vptree.hpp:448
Index for a VP-tree search.
Definition Vptree.hpp:143
std::size_t num_dimensions() const
Definition Vptree.hpp:155
Index_ num_observations() const
Definition Vptree.hpp:151
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition Vptree.hpp:390
VP-tree searcher.
Definition Vptree.hpp:47
Index_ search_all(Index_ i, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:94
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:63
void search(const Data_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:71
Index_ search_all(const Data_ *query, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:110
bool can_search_all() const
Definition Vptree.hpp:90
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:24
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().