knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Kmknn.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_KMKNN_HPP
2#define KNNCOLLE_KMKNN_HPP
3
4#include "distances.hpp"
5#include "NeighborQueue.hpp"
6#include "Prebuilt.hpp"
7#include "Builder.hpp"
8#include "MockMatrix.hpp"
9#include "report_all_neighbors.hpp"
10
11#include "kmeans/kmeans.hpp"
12
13#include <algorithm>
14#include <vector>
15#include <memory>
16#include <limits>
17#include <cmath>
18
25namespace knncolle {
26
// Options for KmknnBuilder and KmknnPrebuilt construction.
// NOTE(review): the line that names this template (listing line 41) is elided
// from this rendering; only the template header and the members are visible.
40template<typename Dim_ = int, typename Index_ = int, typename Store_ = double>
// Exponent applied to the number of observations to pick the number of
// k-means centers; the KmknnPrebuilt constructor uses ceil(num_obs ^ power).
46 double power = 0.5;
47
48 // Note that we use Store_ as the k-means output type, as we'll
49 // be storing the cluster centers as Store_'s, not Float_'s.
50
// k-means initialization algorithm. The KmknnPrebuilt constructor tests this
// for nullptr before dereferencing it (the fallback itself is not visible here).
55 std::shared_ptr<kmeans::Initialize<kmeans::SimpleMatrix<Store_, Index_, Dim_>, Index_, Store_> > initialize_algorithm;
56
// k-means refinement algorithm. Handled the same way as initialize_algorithm:
// the KmknnPrebuilt constructor tests it for nullptr before dereferencing it.
61 std::shared_ptr<kmeans::Refine<kmeans::SimpleMatrix<Store_, Index_, Dim_>, Index_, Store_> > refine_algorithm;
62};
63
64
// Forward declaration: KmknnSearcher (defined next) calls into KmknnPrebuilt
// through its my_parent pointer, while the full class is defined further down.
65template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
66class KmknnPrebuilt;
67
79template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
80class KmknnSearcher : public Searcher<Index_, Float_> {
81public:
86 center_order.reserve(my_parent->my_sizes.size());
87 }
92private:
94 internal::NeighborQueue<Index_, Float_> my_nearest;
95 std::vector<std::pair<Float_, Index_> > my_all_neighbors;
96 std::vector<std::pair<Float_, Index_> > center_order;
97
98public:
// Finds the k nearest neighbors of the i-th observation of the indexed dataset.
//
// i: index of the observation, in the caller's original ordering.
// k: number of neighbors requested.
// output_indices / output_distances: destinations for the results; both are
//   forwarded to report() and then to KmknnPrebuilt::normalize(), which skips
//   a null pointer.
99 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
// One extra slot is requested; presumably this leaves room for the observation
// itself, whose position (new_i) is handed to report() below - confirm against
// NeighborQueue::report.
100 my_nearest.reset(k + 1);
// Translate the original index into its position in the reordered data.
101 auto new_i = my_parent->my_new_location[i];
102 auto iptr = my_parent->my_data.data() + static_cast<size_t>(new_i) * my_parent->my_long_ndim; // cast to avoid overflow.
103 my_parent->search_nn(iptr, my_nearest, center_order);
104 my_nearest.report(output_indices, output_distances, new_i);
// Map reordered positions back to original indices and normalize the distances.
105 my_parent->normalize(output_indices, output_distances);
106 }
107
// Finds the k nearest neighbors of an arbitrary query point.
//
// query: pointer to the query coordinates; search_nn reads my_dim values from it.
// k: number of neighbors requested; k == 0 is handled separately below.
// output_indices / output_distances: destinations for the results; normalize()
//   skips a null pointer.
108 void search(const Float_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
109 if (k == 0) { // protect the NeighborQueue from k = 0.
// Presumably leaves the outputs empty (length 0) - confirm against flush_output.
110 internal::flush_output(output_indices, output_distances, 0);
111 } else {
112 my_nearest.reset(k);
113 my_parent->search_nn(query, my_nearest, center_order);
114 my_nearest.report(output_indices, output_distances);
// Map reordered positions back to original indices and normalize the distances.
115 my_parent->normalize(output_indices, output_distances);
116 }
117 }
118
// Reports whether this searcher provides search_all(); it always returns true.
119 bool can_search_all() const {
120 return true;
121 }
122
// Range search around the i-th observation of the indexed dataset, using
// distance threshold d; returns a neighbor count.
//
// NOTE(review): the line that opens the first branch (listing line 127) is
// elided from this rendering, so the exact condition is not visible here. The
// `} else {` below pairs with it.
123 Index_ search_all(Index_ i, Float_ d, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
// Translate the original index into its position in the reordered data.
124 auto new_i = my_parent->my_new_location[i];
125 auto iptr = my_parent->my_data.data() + static_cast<size_t>(new_i) * my_parent->my_long_ndim; // cast to avoid overflow.
126
// Count-only branch: hits are tallied and neither output vector is touched.
128 Index_ count = 0;
129 my_parent->template search_all<true>(iptr, d, count);
// Presumably discounts the observation itself from the tally - confirm
// against safe_remove_self.
130 return internal::safe_remove_self(count);
131
132 } else {
// Collecting branch: gather (raw distance, position) pairs, report them with
// new_i passed along, then convert to original indices / normalized distances.
133 my_all_neighbors.clear();
134 my_parent->template search_all<false>(iptr, d, my_all_neighbors);
135 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances, new_i);
136 my_parent->normalize(output_indices, output_distances);
137 return internal::safe_remove_self(my_all_neighbors.size());
138 }
139 }
140
// Range search around an arbitrary query point, using distance threshold d;
// returns the number of observations found.
//
// NOTE(review): the line that opens the first branch (listing line 142) is
// elided from this rendering, so the exact condition is not visible here. The
// `} else {` below pairs with it.
141 Index_ search_all(const Float_* query, Float_ d, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
// Count-only branch: hits are tallied and neither output vector is touched.
143 Index_ count = 0;
144 my_parent->template search_all<true>(query, d, count);
145 return count;
146
147 } else {
// Collecting branch: gather (raw distance, position) pairs, report them, then
// convert to original indices / normalized distances.
148 my_all_neighbors.clear();
149 my_parent->template search_all<false>(query, d, my_all_neighbors);
150 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances);
151 my_parent->normalize(output_indices, output_distances);
152 return my_all_neighbors.size();
153 }
154 }
155};
156
172template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
173class KmknnPrebuilt : public Prebuilt<Dim_, Index_, Float_> {
174private:
175 Dim_ my_dim;
176 Index_ my_obs;
177 size_t my_long_ndim;
178
179public:
// Returns the number of observations in the index (the num_obs value given
// to the constructor).
180 Index_ num_observations() const {
181 return my_obs;
182 }
183
// Returns the number of dimensions per observation (the num_dim value given
// to the constructor).
184 Dim_ num_dimensions() const {
185 return my_dim;
186 }
187
188private:
// Observation coordinates, my_dim values per observation. The constructor
// permutes these in place so that each cluster's members are stored together.
189 std::vector<Store_> my_data;
190
// Per-cluster bookkeeping: number of members, position of the first member in
// the reordered data, and center coordinates (my_dim values per center).
191 std::vector<Index_> my_sizes;
192 std::vector<Index_> my_offsets;
193 std::vector<Store_> my_centers;
194
// my_observation_id[p] is the original index of the observation stored at
// reordered position p; my_new_location is the inverse mapping.
195 std::vector<Index_> my_observation_id, my_new_location;
// Normalized distance from the observation at each reordered position to its
// assigned center; sorted in ascending order within each cluster.
196 std::vector<Float_> my_dist_to_centroid;
197
198public:
// Builds the KMKNN index: clusters the observations with k-means, discards
// empty clusters, and reorders the data so that each cluster is contiguous and
// sorted by distance to its center.
//
// num_dim: number of dimensions per observation.
// num_obs: number of observations.
// data: coordinate values, taken by value and moved into my_data; observation
//   o is read starting at offset o * num_dim.
// options: supplies the exponent for the number of centers and the k-means
//   initialization/refinement algorithms.
205 KmknnPrebuilt(Dim_ num_dim, Index_ num_obs, std::vector<Store_> data, const KmknnOptions<Dim_, Index_, Store_>& options) :
206 my_dim(num_dim),
207 my_obs(num_obs),
208 my_long_ndim(my_dim),
209 my_data(std::move(data))
210 {
// NOTE(review): the bodies of the two nullptr checks below (listing lines 213
// and 217) are elided from this rendering; presumably they install default
// algorithms - confirm against the real header.
211 auto init = options.initialize_algorithm;
212 if (init == nullptr) {
214 }
215 auto refine = options.refine_algorithm;
216 if (refine == nullptr) {
218 }
219
// Number of centers requested from k-means: ceil(num_obs ^ power).
220 Index_ ncenters = std::ceil(std::pow(my_obs, options.power));
221 my_centers.resize(static_cast<size_t>(ncenters) * my_long_ndim); // cast to avoid overflow problems.
222
// Run k-means: centers are written into my_centers, 'clusters' receives one
// assignment per observation, and output.sizes is read below as the
// per-cluster member count.
223 kmeans::SimpleMatrix mat(my_dim, my_obs, my_data.data());
224 std::vector<Index_> clusters(my_obs);
225 auto output = kmeans::compute(mat, *init, *refine, ncenters, my_centers.data(), clusters.data());
226
227 // Removing empty clusters, e.g., due to duplicate points.
228 {
229 my_sizes.resize(ncenters);
// remap[c] is the compacted index of surviving cluster c.
230 std::vector<Index_> remap(ncenters);
231 Index_ survivors = 0;
232 for (Index_ c = 0; c < ncenters; ++c) {
233 if (output.sizes[c]) {
// Shift this center's coordinates down over the gap left by empty clusters.
234 if (c > survivors) {
235 auto src = my_centers.begin() + static_cast<size_t>(c) * my_long_ndim; // cast to avoid overflow.
236 auto dest = my_centers.begin() + static_cast<size_t>(survivors) * my_long_ndim;
237 std::copy_n(src, my_dim, dest);
238 }
239 remap[c] = survivors;
240 my_sizes[survivors] = output.sizes[c];
241 ++survivors;
242 }
243 }
244
// Only rewrite the assignments and shrink the arrays if something was dropped.
245 if (survivors < ncenters) {
246 for (auto& c : clusters) {
247 c = remap[c];
248 }
249 ncenters = survivors;
250 my_centers.resize(static_cast<size_t>(ncenters) * my_long_ndim);
251 my_sizes.resize(ncenters);
252 }
253 }
254
// my_offsets[c] is the position of cluster c's first member in the reordered
// layout: a running sum of the sizes, with my_offsets[0] left at the zero
// that resize() value-initializes it to.
255 my_offsets.resize(ncenters);
256 for (Index_ i = 1; i < ncenters; ++i) {
257 my_offsets[i] = my_offsets[i - 1] + my_sizes[i - 1];
258 }
259
260 // Organize points correctly; firstly, sorting by distance from the assigned center.
261 std::vector<std::pair<Float_, Index_> > by_distance(my_obs);
262 {
// 'sofar' starts as a copy of the offsets and tracks the next free slot in
// each cluster's range of by_distance.
263 auto sofar = my_offsets;
264 auto host = my_data.data();
265 for (Index_ o = 0; o < my_obs; ++o) {
266 auto optr = host + static_cast<size_t>(o) * my_long_ndim;
267 auto clustid = clusters[o];
268 auto cptr = my_centers.data() + static_cast<size_t>(clustid) * my_long_ndim;
269
270 auto& counter = sofar[clustid];
271 auto& current = by_distance[counter];
// Store the normalized distance to the center along with the original index.
272 current.first = Distance_::normalize(Distance_::template raw_distance<Float_>(optr, cptr, my_dim));
273 current.second = o;
274
275 ++counter;
276 }
277
// Sort each cluster's slice; pairs compare on the distance first.
278 for (Index_ c = 0; c < ncenters; ++c) {
279 auto begin = by_distance.begin() + my_offsets[c];
280 std::sort(begin, begin + my_sizes[c]);
281 }
282 }
283
284 // Permuting in-place to mirror the reordered distances, so that the search is more cache-friendly.
285 {
286 auto host = my_data.data();
// NOTE(review): uint8_t needs <cstdint>, which is not among the includes
// visible at the top of this listing; it may arrive transitively - confirm.
// 'used' marks positions already filled while following an earlier cycle.
287 std::vector<uint8_t> used(my_obs);
288 std::vector<Store_> buffer(my_dim);
289 my_observation_id.resize(my_obs);
290 my_dist_to_centroid.resize(my_obs);
291 my_new_location.resize(my_obs);
292
293 for (Index_ o = 0; o < my_obs; ++o) {
294 if (used[o]) {
295 continue;
296 }
297
298 const auto& current = by_distance[o];
299 my_observation_id[o] = current.second;
300 my_dist_to_centroid[o] = current.first;
301 my_new_location[current.second] = o;
// Observation already sits at its final position; nothing to move.
302 if (current.second == o) {
303 continue;
304 }
305
306 // We recursively perform a "thread" of replacements until we
307 // are able to find the home of the originally replaced 'o'.
// (The loop below is iterative: it follows the permutation cycle that starts
// at 'o', with the coordinates originally at 'o' parked in 'buffer'.)
308 auto optr = host + static_cast<size_t>(o) * my_long_ndim;
309 std::copy_n(optr, my_dim, buffer.begin());
310 Index_ replacement = current.second;
311 do {
312 auto rptr = host + static_cast<size_t>(replacement) * my_long_ndim;
313 std::copy_n(rptr, my_dim, optr);
314 used[replacement] = 1;
315
316 const auto& next = by_distance[replacement];
317 my_observation_id[replacement] = next.second;
318 my_dist_to_centroid[replacement] = next.first;
319 my_new_location[next.second] = replacement;
320
321 optr = rptr;
322 replacement = next.second;
323 } while (replacement != o);
324
// Close the cycle: the parked coordinates go into the last vacated slot.
325 std::copy(buffer.begin(), buffer.end(), optr);
326 }
327 }
328
329 return;
330 }
331
332private:
// Nearest-neighbor search for 'target', accumulating candidates in 'nearest'.
//
// target: query coordinates; my_dim values are read.
// nearest: queue that receives (reordered position, raw distance) candidates.
//   Positions refer to the reordered my_data and distances are not normalized;
//   the callers in KmknnSearcher apply normalize() afterwards.
// center_order: caller-owned scratch space, cleared and refilled on each call.
333 template<typename Query_>
334 void search_nn(const Query_* target, internal::NeighborQueue<Index_, Float_>& nearest, std::vector<std::pair<Float_, Index_> >& center_order) const {
335 /* Computing distances to all centers and sorting them. The aim is to
336 * go through the nearest centers first, to try to get the shortest
337 * threshold (i.e., 'nearest.limit()') possible at the start;
338 * this allows us to skip searches of the later clusters.
339 */
340 center_order.clear();
341 size_t ncenters = my_sizes.size();
342 center_order.reserve(ncenters);
343 auto clust_ptr = my_centers.data();
344 for (size_t c = 0; c < ncenters; ++c, clust_ptr += my_dim) {
345 center_order.emplace_back(Distance_::template raw_distance<Float_>(target, clust_ptr, my_dim), c);
346 }
347 std::sort(center_order.begin(), center_order.end());
348
349 // Computing the distance to each center, and deciding whether to proceed for each cluster.
// threshold_raw stays infinite until the queue is full, so nothing is pruned
// before then.
350 Float_ threshold_raw = std::numeric_limits<Float_>::infinity();
351 for (const auto& curcent : center_order) {
352 const Index_ center = curcent.second;
353 const Float_ dist2center = Distance_::normalize(curcent.first);
354
// Member-to-center distances are sorted ascending within a cluster, so the
// last entry is the largest.
355 const auto cur_nobs = my_sizes[center];
356 const Float_* dIt = my_dist_to_centroid.data() + my_offsets[center];
357 const Float_ maxdist = *(dIt + cur_nobs - 1);
358
359 Index_ firstcell = 0;
360#if KNNCOLLE_KMKNN_USE_UPPER
361 Float_ upper_bd = std::numeric_limits<Float_>::max();
362#endif
363
364 if (!std::isinf(threshold_raw)) {
365 const Float_ threshold = Distance_::normalize(threshold_raw);
366
367 /* The conditional expression below exploits the triangle inequality; it is equivalent to asking whether:
368 * threshold + maxdist < dist2center
369 * All points (if any) within this cluster with distances above 'lower_bd' are potentially countable.
370 */
371 const Float_ lower_bd = dist2center - threshold;
372 if (maxdist < lower_bd) {
373 continue;
374 }
375
// Binary search for the first member whose distance to the center is at
// least lower_bd; earlier members are skipped.
376 firstcell = std::lower_bound(dIt, dIt + cur_nobs, lower_bd) - dIt;
377
378#if KNNCOLLE_KMKNN_USE_UPPER
379 /* This exploits the reverse triangle inequality, to ignore points where:
380 * threshold + dist2center < point-to-center distance
381 */
382 upper_bd = threshold + dist2center;
383#endif
384 }
385
386 const auto cur_start = my_offsets[center];
387 const auto* other_cell = my_data.data() + my_long_ndim * static_cast<size_t>(cur_start + firstcell); // cast to avoid overflow issues.
388 for (auto celldex = firstcell; celldex < cur_nobs; ++celldex, other_cell += my_dim) {
389#if KNNCOLLE_KMKNN_USE_UPPER
390 if (*(dIt + celldex) > upper_bd) {
391 break;
392 }
393#endif
394
395 auto dist2cell_raw = Distance_::template raw_distance<Float_>(target, other_cell, my_dim);
396 if (dist2cell_raw <= threshold_raw) {
// The candidate is recorded by its position in the reordered data.
397 nearest.add(cur_start + celldex, dist2cell_raw);
398 if (nearest.is_full()) {
399 threshold_raw = nearest.limit(); // Shrinking the threshold, if an earlier NN has been found.
400#if KNNCOLLE_KMKNN_USE_UPPER
401 upper_bd = Distance_::normalize(threshold_raw) + dist2center;
402#endif
403 }
404 }
405 }
406 }
407 }
408
// Range search: visits every indexed observation whose raw distance to
// 'target' is at most Distance_::denormalize(threshold).
//
// count_only_: when true, 'all_neighbors' is incremented once per hit;
//   otherwise a (raw distance, reordered position) pair is appended to it.
// target: query coordinates; my_dim values are read.
// threshold: distance limit on the normalized scale (it is compared against
//   normalized center distances and denormalized for the per-point check).
// all_neighbors: counter or container, depending on count_only_.
409 template<bool count_only_, typename Query_, typename Output_>
410 void search_all(const Query_* target, Float_ threshold, Output_& all_neighbors) const {
411 Float_ threshold_raw = Distance_::denormalize(threshold);
412
413 /* Computing distances to all centers. We don't sort them here
414 * because the threshold is constant so there's no point.
415 */
416 Index_ ncenters = my_sizes.size();
417 auto center_ptr = my_centers.data();
418 for (Index_ center = 0; center < ncenters; ++center, center_ptr += my_dim) {
419 const Float_ dist2center = Distance_::normalize(Distance_::template raw_distance<Float_>(target, center_ptr, my_dim));
420
// Member-to-center distances are sorted ascending within a cluster, so the
// last entry is the largest.
421 auto cur_nobs = my_sizes[center];
422 const Float_* dIt = my_dist_to_centroid.data() + my_offsets[center];
423 const Float_ maxdist = *(dIt + cur_nobs - 1);
424
425 /* The conditional expression below exploits the triangle inequality; it is equivalent to asking whether:
426 * threshold + maxdist < dist2center
427 * All points (if any) within this cluster with distances above 'lower_bd' are potentially countable.
428 */
429 const Float_ lower_bd = dist2center - threshold;
430 if (maxdist < lower_bd) {
431 continue;
432 }
433
// Binary search for the first member whose distance to the center is at
// least lower_bd; earlier members are skipped.
434 Index_ firstcell = std::lower_bound(dIt, dIt + cur_nobs, lower_bd) - dIt;
435#if KNNCOLLE_KMKNN_USE_UPPER
436 /* This exploits the reverse triangle inequality, to ignore points where:
437 * threshold + dist2center < point-to-center distance
438 */
439 Float_ upper_bd = threshold + dist2center;
440#endif
441
442 const auto cur_start = my_offsets[center];
443 auto other_ptr = my_data.data() + my_long_ndim * static_cast<size_t>(cur_start + firstcell); // cast to avoid overflow issues.
444 for (auto celldex = firstcell; celldex < cur_nobs; ++celldex, other_ptr += my_dim) {
445#if KNNCOLLE_KMKNN_USE_UPPER
446 if (*(dIt + celldex) > upper_bd) {
447 break;
448 }
449#endif
450
451 auto dist2cell_raw = Distance_::template raw_distance<Float_>(target, other_ptr, my_dim);
452 if (dist2cell_raw <= threshold_raw) {
453 if constexpr(count_only_) {
454 ++all_neighbors;
455 } else {
// Hits are recorded by their position in the reordered data.
456 all_neighbors.emplace_back(dist2cell_raw, cur_start + celldex);
457 }
458 }
459 }
460 }
461 }
462
// Converts search results in place into caller-facing values.
//
// output_indices: if non-null, each entry is treated as a position in the
//   reordered data and replaced with the original index via my_observation_id.
// output_distances: if non-null, each entry is replaced with
//   Distance_::normalize() of itself.
463 void normalize(std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) const {
464 if (output_indices) {
465 for (auto& s : *output_indices) {
466 s = my_observation_id[s];
467 }
468 }
469 if (output_distances) {
470 for (auto& d : *output_distances) {
471 d = Distance_::normalize(d);
472 }
473 }
474 }
475
476 friend class KmknnSearcher<Distance_, Dim_, Index_, Store_, Float_>;
477
478public:
// Creates a new KmknnSearcher for this index and returns it through the
// Searcher interface. The searcher is constructed from `this`; presumably the
// index must therefore outlive the returned searcher - confirm against
// KmknnSearcher's constructor, which is not visible in this listing.
482 std::unique_ptr<Searcher<Index_, Float_> > initialize() const {
483 return std::make_unique<KmknnSearcher<Distance_, Dim_, Index_, Store_, Float_> >(this);
484 }
485};
486
506template<class Distance_ = EuclideanDistance, class Matrix_ = SimpleMatrix<int, int, double>, typename Float_ = double>
507class KmknnBuilder : public Builder<Matrix_, Float_> {
508public:
513
514private:
515 Options my_options;
516
517public:
// Constructs a builder from the supplied options, which are moved into
// my_options and later handed to each KmknnPrebuilt that is built.
521 KmknnBuilder(Options options) : my_options(std::move(options)) {}
522
// Default constructor; my_options is default-initialized.
526 KmknnBuilder() = default;
527
533 return my_options;
534 }
535
536public:
// Body of build_raw. NOTE(review): the signature line (listing line 540) is
// elided from this rendering; 'data' is the matrix being indexed.
// Copies every observation into one contiguous vector and constructs a
// KmknnPrebuilt from it with the builder's options.
541 auto ndim = data.num_dimensions();
542 auto nobs = data.num_observations();
543
544 typedef typename Matrix_::data_type Store_;
// Both factors are widened to size_t before multiplying.
545 std::vector<Store_> store(static_cast<size_t>(ndim) * static_cast<size_t>(nobs));
546
547 auto work = data.create_workspace();
548 auto sIt = store.begin();
549 for (decltype(nobs) o = 0; o < nobs; ++o, sIt += ndim) {
// Presumably each call yields the next observation in order and advances
// 'work' - confirm against the Matrix_ interface.
550 auto ptr = data.get_observation(work);
551 std::copy_n(ptr, ndim, sIt);
552 }
553
// NOTE(review): returns a raw pointer obtained from `new`; ownership passes
// to the caller, presumably wrapped by the Builder base class - confirm.
554 return new KmknnPrebuilt<Distance_, decltype(ndim), decltype(nobs), Store_, Float_>(ndim, nobs, std::move(store), my_options);
555 }
556};
557
558}
559
560#endif
Interface to build nearest-neighbor indices.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:22
Perform a nearest neighbor search based on k-means clustering.
Definition Kmknn.hpp:507
Prebuilt< typename Matrix_::dimension_type, typename Matrix_::index_type, Float_ > * build_raw(const Matrix_ &data) const
Definition Kmknn.hpp:540
KmknnOptions< typename Matrix_::dimension_type, typename Matrix_::index_type, typename Matrix_::data_type > Options
Definition Kmknn.hpp:512
KmknnBuilder(Options options)
Definition Kmknn.hpp:521
Options & get_options()
Definition Kmknn.hpp:532
Index for a KMKNN search.
Definition Kmknn.hpp:173
std::unique_ptr< Searcher< Index_, Float_ > > initialize() const
Definition Kmknn.hpp:482
KmknnPrebuilt(Dim_ num_dim, Index_ num_obs, std::vector< Store_ > data, const KmknnOptions< Dim_, Index_, Store_ > &options)
Definition Kmknn.hpp:205
Index_ num_observations() const
Definition Kmknn.hpp:180
Dim_ num_dimensions() const
Definition Kmknn.hpp:184
KMKNN searcher.
Definition Kmknn.hpp:80
Index_ search_all(const Float_ *query, Float_ d, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Kmknn.hpp:141
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Kmknn.hpp:99
bool can_search_all() const
Definition Kmknn.hpp:119
void search(const Float_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Kmknn.hpp:108
Index_ search_all(Index_ i, Float_ d, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Kmknn.hpp:123
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Classes for distance calculations.
Details< typename Matrix_::index_type > compute(const Matrix_ &data, const Initialize< Matrix_, Cluster_, Float_ > &initialize, const Refine< Matrix_, Cluster_, Float_ > &refine, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters)
Collection of KNN algorithms.
Definition Bruteforce.hpp:22
Options for KmknnBuilder and KmknnPrebuilt construction.
Definition Kmknn.hpp:41
std::shared_ptr< kmeans::Initialize< kmeans::SimpleMatrix< Store_, Index_, Dim_ >, Index_, Store_ > > initialize_algorithm
Definition Kmknn.hpp:55
double power
Definition Kmknn.hpp:46
std::shared_ptr< kmeans::Refine< kmeans::SimpleMatrix< Store_, Index_, Dim_ >, Index_, Store_ > > refine_algorithm
Definition Kmknn.hpp:61