knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
find_nearest_neighbors.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_FIND_NEAREST_NEIGHBORS_HPP
2#define KNNCOLLE_FIND_NEAREST_NEIGHBORS_HPP
3
4#include "Prebuilt.hpp"
5
6#include <vector>
7#include <utility>
8#include <type_traits>
9
10#include "sanisizer/sanisizer.hpp"
11#ifndef KNNCOLLE_CUSTOM_PARALLEL
12#include "subpar/subpar.hpp"
13#endif
14
15
22namespace knncolle {
23
36template<typename Task_, class Run_>
37void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range) {
38#ifndef KNNCOLLE_CUSTOM_PARALLEL
39 // Don't make this nothrow_ = true, as the derived methods could do anything...
40 subpar::parallelize(num_workers, num_tasks, std::move(run_task_range));
41#else
42 KNNCOLLE_CUSTOM_PARALLEL(num_workers, num_tasks, run_task_range);
43#endif
44}
45
/**
 * List of nearest neighbors for multiple observations.
 *
 * The outer vector has one entry per observation.
 * Each inner vector holds (index, distance) pairs, one per neighbor of that observation.
 *
 * @tparam Index_ Integer type of the observation indices.
 * @tparam Distance_ Numeric type of the distances.
 */
template<typename Index_, typename Distance_>
using NeighborList = std::vector<
    std::vector<
        std::pair<Index_, Distance_>
    >
>;
57
70template<typename Index_>
71int cap_k(int k, Index_ num_observations) {
72 if (sanisizer::is_less_than(sanisizer::attest_gez(k), sanisizer::attest_gez(num_observations))) {
73 return k;
74 } else if (num_observations) {
75 return num_observations - 1;
76 } else {
77 return 0;
78 }
79}
80
91template<typename Index_>
92int cap_k_query(int k, Index_ num_observations) {
93 return sanisizer::min(sanisizer::attest_gez(k), sanisizer::attest_gez(num_observations));
94}
95
115template<typename Index_, typename Data_, typename Distance_>
117 const Index_ nobs = index.num_observations();
118 k = cap_k(k, nobs);
119 auto output = sanisizer::create<NeighborList<Index_, Distance_> >(sanisizer::attest_gez(nobs));
120
121 parallelize(num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
122 auto sptr = index.initialize_known();
123 std::vector<Index_> indices;
124 std::vector<Distance_> distances;
125 for (Index_ i = start, end = start + length; i < end; ++i) {
126 sptr->search(i, k, &indices, &distances);
127 const auto actual_k = indices.size();
128 output[i].reserve(actual_k);
129 for (I<decltype(actual_k)> j = 0; j < actual_k; ++j) {
130 output[i].emplace_back(indices[j], distances[j]);
131 }
132 }
133 });
134
135 return output;
136}
137
157template<typename Index_, typename Data_, typename Distance_>
158std::vector<std::vector<Index_> > find_nearest_neighbors_index_only(const Prebuilt<Index_, Data_, Distance_>& index, int k, int num_threads = 1) {
159 const Index_ nobs = index.num_observations();
160 k = cap_k(k, nobs);
161 auto output = sanisizer::create<std::vector<std::vector<Index_> > >(sanisizer::attest_gez(nobs));
162
163 parallelize(num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
164 auto sptr = index.initialize_known();
165 for (Index_ i = start, end = start + length; i < end; ++i) {
166 sptr->search(i, k, &(output[i]), NULL);
167 }
168 });
169
170 return output;
171}
172
173}
174
175#endif
Interface for prebuilt nearest-neighbor indices.
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:29
auto initialize_known() const
Definition Prebuilt.hpp:98
virtual Index_ num_observations() const =0
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
NeighborList< Index_, Distance_ > find_nearest_neighbors(const Prebuilt< Index_, Data_, Distance_ > &index, int k, int num_threads=1)
Definition find_nearest_neighbors.hpp:116
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
Definition find_nearest_neighbors.hpp:37
std::vector< std::vector< Index_ > > find_nearest_neighbors_index_only(const Prebuilt< Index_, Data_, Distance_ > &index, int k, int num_threads=1)
Definition find_nearest_neighbors.hpp:158
int cap_k_query(int k, Index_ num_observations)
Definition find_nearest_neighbors.hpp:92
int cap_k(int k, Index_ num_observations)
Definition find_nearest_neighbors.hpp:71
std::vector< std::vector< std::pair< Index_, Distance_ > > > NeighborList
Definition find_nearest_neighbors.hpp:56