knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
find_nearest_neighbors.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_FIND_NEAREST_NEIGHBORS_HPP
2#define KNNCOLLE_FIND_NEAREST_NEIGHBORS_HPP
3
4#include <vector>
5#include <utility>
6#include <type_traits>
7
8#include "Prebuilt.hpp"
9
16#ifndef KNNCOLLE_CUSTOM_PARALLEL
17#include "subpar/subpar.hpp"
18#endif
19
20namespace knncolle {
21
34template<typename Task_, class Run_>
35void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range) {
36#ifndef KNNCOLLE_CUSTOM_PARALLEL
37 // Don't make this nothrow_ = true, as the derived methods could do anything...
38 subpar::parallelize(num_workers, num_tasks, std::move(run_task_range));
39#else
40 KNNCOLLE_CUSTOM_PARALLEL(num_workers, num_tasks, run_task_range);
41#endif
42}
43
/**
 * Container for the neighbors of many observations.
 *
 * There is one inner vector per observation.
 * Each element of an inner vector pairs a neighbor's index (`Index_`) with its distance (`Distance_`).
 */
template<class Index_, class Distance_>
using NeighborList = std::vector<std::vector<std::pair<Index_, Distance_>>>;
55
/**
 * Cap the number of neighbors requested when searching an index for neighbors of its own observations.
 *
 * The result never exceeds `num_observations - 1`.
 * This is presumably because an observation is not reported as its own neighbor (TODO confirm against the searcher implementations).
 *
 * @tparam Index_ Integer type of the number of observations.
 * @param k Requested number of neighbors.
 * @param num_observations Number of observations in the index.
 *
 * @return `k` if it is less than `num_observations`.
 * Otherwise `num_observations - 1`, or 0 when there are no observations.
 */
template<typename Index_>
int cap_k(int k, Index_ num_observations) {
    if constexpr(std::is_signed<Index_>::value) {
        if (k < num_observations) {
            return k;
        }
    } else {
        // Compare in an unsigned type at least as wide as both int and Index_.
        // Casting k to Index_ itself would truncate it when Index_ is narrower than
        // int (e.g., k = 300 becomes 44 for an 8-bit Index_), letting an over-large
        // k slip through uncapped. A negative k still wraps to a huge value and gets capped.
        using Compare_ = typename std::make_unsigned<typename std::common_type<int, Index_>::type>::type;
        if (static_cast<Compare_>(k) < static_cast<Compare_>(num_observations)) {
            return k;
        }
    }
    if (num_observations) {
        return num_observations - 1;
    }
    return 0;
}
84
/**
 * Cap the number of neighbors requested when searching an index with query points.
 *
 * Unlike `cap_k()`, the cap here is `num_observations` itself.
 *
 * @tparam Index_ Integer type of the number of observations.
 * @param k Requested number of neighbors.
 * @param num_observations Number of observations in the index.
 *
 * @return `k` if it is no greater than `num_observations`, otherwise `num_observations`.
 */
template<typename Index_>
int cap_k_query(int k, Index_ num_observations) {
    if constexpr(std::is_signed<Index_>::value) {
        if (k <= num_observations) {
            return k;
        }
    } else {
        // Compare in an unsigned type at least as wide as both int and Index_.
        // Casting k to Index_ itself would truncate it when Index_ is narrower than
        // int (e.g., k = 300 becomes 44 for an 8-bit Index_), letting an over-large
        // k slip through uncapped. A negative k still wraps to a huge value and gets capped.
        using Compare_ = typename std::make_unsigned<typename std::common_type<int, Index_>::type>::type;
        if (static_cast<Compare_>(k) <= static_cast<Compare_>(num_observations)) {
            return k;
        }
    }
    return num_observations;
}
108
128template<typename Index_, typename Data_, typename Distance_>
130 Index_ nobs = index.num_observations();
131 k = cap_k(k, nobs);
133
134 parallelize(num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
135 auto sptr = index.initialize();
136 std::vector<Index_> indices;
137 std::vector<Distance_> distances;
138 for (Index_ i = start, end = start + length; i < end; ++i) {
139 sptr->search(i, k, &indices, &distances);
140 int actual_k = indices.size();
141 output[i].reserve(actual_k);
142 for (int j = 0; j < actual_k; ++j) {
143 output[i].emplace_back(indices[j], distances[j]);
144 }
145 }
146 });
147
148 return output;
149}
150
170template<typename Index_, typename Data_, typename Distance_>
171std::vector<std::vector<Index_> > find_nearest_neighbors_index_only(const Prebuilt<Index_, Data_, Distance_>& index, int k, int num_threads = 1) {
172 Index_ nobs = index.num_observations();
173 k = cap_k(k, nobs);
174 std::vector<std::vector<Index_> > output(nobs);
175
176 parallelize(num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
177 auto sptr = index.initialize();
178 for (Index_ i = start, end = start + length; i < end; ++i) {
179 sptr->search(i, k, &(output[i]), NULL);
180 }
181 });
182
183 return output;
184}
185
186}
187
188#endif
Interface for prebuilt nearest-neighbor indices.
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
virtual Index_ num_observations() const =0
virtual std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const =0
Collection of KNN algorithms.
Definition Bruteforce.hpp:24
NeighborList< Index_, Distance_ > find_nearest_neighbors(const Prebuilt< Index_, Data_, Distance_ > &index, int k, int num_threads=1)
Definition find_nearest_neighbors.hpp:129
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
Definition find_nearest_neighbors.hpp:35
std::vector< std::vector< Index_ > > find_nearest_neighbors_index_only(const Prebuilt< Index_, Data_, Distance_ > &index, int k, int num_threads=1)
Definition find_nearest_neighbors.hpp:171
int cap_k_query(int k, Index_ num_observations)
Definition find_nearest_neighbors.hpp:96
int cap_k(int k, Index_ num_observations)
Definition find_nearest_neighbors.hpp:69
std::vector< std::vector< std::pair< Index_, Distance_ > > > NeighborList
Definition find_nearest_neighbors.hpp:54