knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Vptree.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
3
4#include "distances.hpp"
5#include "NeighborQueue.hpp"
6#include "Prebuilt.hpp"
7#include "Builder.hpp"
8#include "Matrix.hpp"
10
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <random>
#include <tuple>
#include <utility>
#include <vector>
16
23namespace knncolle {
24
28template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
29class VptreePrebuilt;
45template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
46class VptreeSearcher final : public Searcher<Index_, Data_, Distance_> {
47public:
56private:
59 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
60
61public:
62 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
63 my_nearest.reset(k + 1);
64 auto iptr = my_parent.my_data.data() + static_cast<size_t>(my_parent.my_new_locations[i]) * my_parent.my_dim; // cast to avoid overflow.
65 Distance_ max_dist = std::numeric_limits<Distance_>::max();
66 my_parent.search_nn(0, iptr, max_dist, my_nearest);
67 my_nearest.report(output_indices, output_distances, i);
68 }
69
70 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
71 // Protect the NeighborQueue from k = 0. This also protects search_nn()
72 // when there are no observations (and no node 0 to start recursion).
73 if (k == 0) {
74 if (output_indices) {
75 output_indices->clear();
76 }
77 if (output_distances) {
78 output_distances->clear();
79 }
80
81 } else {
82 my_nearest.reset(k);
83 Distance_ max_dist = std::numeric_limits<Distance_>::max();
84 my_parent.search_nn(0, query, max_dist, my_nearest);
85 my_nearest.report(output_indices, output_distances);
86 }
87 }
88
89 bool can_search_all() const {
90 return true;
91 }
92
93 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
94 auto iptr = my_parent.my_data.data() + static_cast<size_t>(my_parent.my_new_locations[i]) * my_parent.my_dim; // cast to avoid overflow.
95
96 if (!output_indices && !output_distances) {
97 Index_ count = 0;
98 my_parent.template search_all<true>(0, iptr, d, count);
100
101 } else {
102 my_all_neighbors.clear();
103 my_parent.template search_all<false>(0, iptr, d, my_all_neighbors);
104 report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
105 return count_all_neighbors_without_self(my_all_neighbors.size());
106 }
107 }
108
109 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
110 if (my_parent.my_data.empty()) { // protect the search_all() method when there is not even a node 0 to start the recursion.
111 my_all_neighbors.clear();
112 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
113 return 0;
114 }
115
116 if (!output_indices && !output_distances) {
117 Index_ count = 0;
118 my_parent.template search_all<true>(0, query, d, count);
119 return count;
120
121 } else {
122 my_all_neighbors.clear();
123 my_parent.template search_all<false>(0, query, d, my_all_neighbors);
124 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
125 return my_all_neighbors.size();
126 }
127 }
128};
129
141template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
142class VptreePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
143private:
144 size_t my_dim;
145 Index_ my_obs;
146 std::vector<Data_> my_data;
147 std::shared_ptr<const DistanceMetric_> my_metric;
148
149public:
150 Index_ num_observations() const {
151 return my_obs;
152 }
153
154 size_t num_dimensions() const {
155 return my_dim;
156 }
157
158private:
159 /* Adapted from http://stevehanov.ca/blog/index.php?id=130 */
160 static const Index_ LEAF = 0;
161
162 // Single node of a VP tree.
163 struct Node {
164 Distance_ radius = 0;
165
166 // Original index of current vantage point, defining the center of the node.
167 Index_ index = 0;
168
169 // Node index of the next vantage point for all children closer than 'threshold' from the current vantage point.
170 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
171 Index_ left = LEAF;
172
173 // Node index of the next vantage point for all children further than 'threshold' from the current vantage point.
174 // This must be > 0, as the first node in 'nodes' is the root and cannot be referenced from other nodes.
175 Index_ right = LEAF;
176 };
177
178 std::vector<Node> my_nodes;
179
180 typedef std::pair<Distance_, Index_> DataPoint;
181
182 template<class Rng_>
183 Index_ build(Index_ lower, Index_ upper, const Data_* coords, std::vector<DataPoint>& items, Rng_& rng) {
184 /*
185 * We're assuming that lower < upper at each point within this
186 * recursion. This requires some protection at the call site
187 * when nobs = 0, see the constructor.
188 */
189
190 Index_ pos = my_nodes.size();
191 my_nodes.emplace_back();
192 Node& node = my_nodes.back(); // this is safe during recursion because we reserved 'nodes' already to the number of observations, see the constructor.
193
194 Index_ gap = upper - lower;
195 if (gap > 1) { // not yet at a leaft.
196
197 /* Choose an arbitrary point and move it to the start of the [lower, upper)
198 * interval in 'items'; this is our new vantage point.
199 *
200 * Yes, I know that the modulo method does not provide strictly
201 * uniform values but statistical correctness doesn't really matter
202 * here, and I don't want std::uniform_int_distribution's
203 * implementation-specific behavior.
204 */
205 Index_ i = (rng() % gap + lower);
206 std::swap(items[lower], items[i]);
207 const auto& vantage = items[lower];
208 node.index = vantage.second;
209 const Data_* vantage_ptr = coords + static_cast<size_t>(vantage.second) * my_dim; // cast to avoid overflow.
210
211 // Compute distances to the new vantage point.
212 for (Index_ i = lower + 1; i < upper; ++i) {
213 const Data_* loc = coords + static_cast<size_t>(items[i].second) * my_dim; // cast to avoid overflow.
214 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
215 }
216
217 // Partition around the median distance from the vantage point.
218 Index_ median = lower + gap/2;
219 Index_ lower_p1 = lower + 1; // excluding the vantage point itself, obviously.
220 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
221
222 // Radius of the new node will be the distance to the median.
223 node.radius = my_metric->normalize(items[median].first);
224
225 // Recursively build tree.
226 if (lower_p1 < median) {
227 node.left = build(lower_p1, median, coords, items, rng);
228 }
229 if (median < upper) {
230 node.right = build(median, upper, coords, items, rng);
231 }
232
233 } else {
234 const auto& leaf = items[lower];
235 node.index = leaf.second;
236 }
237
238 return pos;
239 }
240
241private:
242 std::vector<Index_> my_new_locations;
243
244public:
248 VptreePrebuilt(size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
249 my_dim(num_dim),
250 my_obs(num_obs),
251 my_data(std::move(data)),
252 my_metric(std::move(metric))
253 {
254 if (num_obs) {
255 std::vector<DataPoint> items;
256 items.reserve(my_obs);
257 for (Index_ i = 0; i < my_obs; ++i) {
258 items.emplace_back(0, i);
259 }
260
261 my_nodes.reserve(my_obs);
262
263 // Statistical correctness doesn't matter (aside from tie breaking)
264 // so we'll just use a deterministically 'random' number to ensure
265 // we get the same ties for any given dataset but a different stream
266 // of numbers between datasets. Casting to get well-defined overflow.
267 uint64_t base = 1234567890, m1 = my_obs, m2 = my_dim;
268 std::mt19937_64 rand(base * m1 + m2);
269
270 build(0, my_obs, my_data.data(), items, rand);
271
272 // Resorting data in place to match order of occurrence within
273 // 'nodes', for better cache locality.
274 std::vector<uint8_t> used(my_obs);
275 std::vector<Data_> buffer(my_dim);
276 my_new_locations.resize(my_obs);
277 auto host = my_data.data();
278
279 for (Index_ o = 0; o < num_obs; ++o) {
280 if (used[o]) {
281 continue;
282 }
283
284 auto& current = my_nodes[o];
285 my_new_locations[current.index] = o;
286 if (current.index == o) {
287 continue;
288 }
289
290 auto optr = host + static_cast<size_t>(o) * my_dim;
291 std::copy_n(optr, my_dim, buffer.begin());
292 Index_ replacement = current.index;
293
294 do {
295 auto rptr = host + static_cast<size_t>(replacement) * my_dim;
296 std::copy_n(rptr, my_dim, optr);
297 used[replacement] = 1;
298
299 const auto& next = my_nodes[replacement];
300 my_new_locations[next.index] = replacement;
301
302 optr = rptr;
303 replacement = next.index;
304 } while (replacement != o);
305
306 std::copy(buffer.begin(), buffer.end(), optr);
307 }
308 }
309 }
314private:
315 void search_nn(Index_ curnode_index, const Data_* target, Distance_& max_dist, NeighborQueue<Index_, Distance_>& nearest) const {
316 auto nptr = my_data.data() + static_cast<size_t>(curnode_index) * my_dim; // cast to avoid overflow.
317 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
318
319 // If current node is within the maximum distance:
320 const auto& curnode = my_nodes[curnode_index];
321 if (dist <= max_dist) {
322 nearest.add(curnode.index, dist);
323 if (nearest.is_full()) {
324 max_dist = nearest.limit(); // update value of max_dist (farthest point in result list)
325 }
326 }
327
328 if (dist < curnode.radius) { // If the target lies within the radius of ball:
329 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
330 search_nn(curnode.left, target, max_dist, nearest);
331 }
332
333 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
334 search_nn(curnode.right, target, max_dist, nearest);
335 }
336
337 } else { // If the target lies outsize the radius of the ball:
338 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
339 search_nn(curnode.right, target, max_dist, nearest);
340 }
341
342 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
343 search_nn(curnode.left, target, max_dist, nearest);
344 }
345 }
346 }
347
348 template<bool count_only_, typename Output_>
349 void search_all(Index_ curnode_index, const Data_* target, Distance_ threshold, Output_& all_neighbors) const {
350 auto nptr = my_data.data() + static_cast<size_t>(curnode_index) * my_dim; // cast to avoid overflow.
351 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
352
353 // If current node is within the maximum distance:
354 const auto& curnode = my_nodes[curnode_index];
355 if (dist <= threshold) {
356 if constexpr(count_only_) {
357 ++all_neighbors;
358 } else {
359 all_neighbors.emplace_back(dist, curnode.index);
360 }
361 }
362
363 if (dist < curnode.radius) { // If the target lies within the radius of ball:
364 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child first
365 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
366 }
367
368 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child
369 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
370 }
371
372 } else { // If the target lies outsize the radius of the ball:
373 if (curnode.right != LEAF && dist + threshold >= curnode.radius) { // if there can still be neighbors outside the ball, recursively search right child first
374 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
375 }
376
377 if (curnode.left != LEAF && dist - threshold <= curnode.radius) { // if there can still be neighbors inside the ball, recursively search left child
378 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
379 }
380 }
381 }
382
383 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
384
385public:
389 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
390 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
391 }
392};
393
426template<
427 typename Index_,
428 typename Data_,
429 typename Distance_,
430 class Matrix_ = Matrix<Index_, Data_>,
431 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
432>
433class VptreeBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
434public:
438 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
439
443 VptreeBuilder(const DistanceMetric_* metric) : VptreeBuilder(std::shared_ptr<const DistanceMetric_>(metric)) {}
444
445private:
446 std::shared_ptr<const DistanceMetric_> my_metric;
447
448public:
452 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
453 size_t ndim = data.num_dimensions();
454 size_t nobs = data.num_observations();
455 auto work = data.new_extractor();
456
457 std::vector<Data_> store(ndim * nobs);
458 for (size_t o = 0; o < nobs; ++o) {
459 std::copy_n(work->next(), ndim, store.begin() + o * ndim);
460 }
461
462 return new VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
463 }
464};
465
466};
467
468#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:30
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:109
void reset(Index_ k)
Definition NeighborQueue.hpp:46
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:26
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:433
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Vptree.hpp:438
VptreeBuilder(const DistanceMetric_ *metric)
Definition Vptree.hpp:443
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition Vptree.hpp:452
Index for a VP-tree search.
Definition Vptree.hpp:142
size_t num_dimensions() const
Definition Vptree.hpp:154
Index_ num_observations() const
Definition Vptree.hpp:150
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition Vptree.hpp:389
VP-tree searcher.
Definition Vptree.hpp:46
Index_ search_all(Index_ i, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:93
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:62
void search(const Data_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:70
Index_ search_all(const Data_ *query, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Vptree.hpp:109
bool can_search_all() const
Definition Vptree.hpp:89
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:23
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().