knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
find_nearest_neighbors.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_FIND_NEAREST_NEIGHBORS_HPP
2#define KNNCOLLE_FIND_NEAREST_NEIGHBORS_HPP
3
4#include <vector>
5#include <utility>
6#include <type_traits>
7
8#include "Prebuilt.hpp"
9
16#ifndef KNNCOLLE_CUSTOM_PARALLEL
17#include "subpar/subpar.hpp"
18#endif
19
20namespace knncolle {
21
34template<typename Task_, class Run_>
35void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range) {
36#ifndef KNNCOLLE_CUSTOM_PARALLEL
37 // Don't make this nothrow_ = true, as the derived methods could do anything...
38 subpar::parallelize(num_workers, num_tasks, std::move(run_task_range));
39#else
40 KNNCOLLE_CUSTOM_PARALLEL(num_workers, num_tasks, run_task_range);
41#endif
42}
43
/**
 * Container for the neighbors of a set of observations.
 *
 * The outer vector holds one entry per observation. Each inner vector holds
 * `(index, distance)` pairs describing that observation's neighbors.
 *
 * @tparam Index_ Type of the neighbor indices.
 * @tparam Float_ Type of the distances.
 */
template<typename Index_ = int, typename Float_ = double>
using NeighborList = std::vector<std::vector<std::pair<Index_, Float_>>>;
55
/**
 * Cap the number of neighbors so that it is smaller than the number of observations.
 *
 * @tparam Index_ Integer type of the observation count; may be signed or unsigned.
 *
 * @param k Requested number of neighbors.
 * @param num_observations Number of observations in the dataset.
 *
 * @return `k` if it is less than `num_observations`; otherwise
 * `num_observations - 1`, or 0 if there are no observations.
 *
 * For a signed `Index_`, a negative `k` compares as less than any non-negative
 * count and is returned unchanged. For an unsigned `Index_`, a negative `k`
 * wraps to a large unsigned value and is therefore capped.
 */
template<typename Index_>
int cap_k(int k, Index_ num_observations) {
    if constexpr(std::is_signed<Index_>::value) {
        // Both operands are signed, so the usual arithmetic conversions compare them exactly.
        if (k < num_observations) {
            return k;
        }
    } else {
        // Compare in a type at least as wide as both 'unsigned int' and 'Index_'.
        // Casting 'k' directly to a narrow 'Index_' (e.g., uint16_t) would truncate it
        // and could make an over-large 'k' look smaller than 'num_observations'.
        using Compare = typename std::common_type<unsigned int, Index_>::type;
        if (static_cast<Compare>(k) < static_cast<Compare>(num_observations)) {
            return k;
        }
    }

    // Here k >= num_observations, so num_observations - 1 is representable as int.
    if (num_observations) {
        return num_observations - 1;
    }
    return 0;
}
84
104template<typename Dim_, typename Index_, typename Float_>
106 Index_ nobs = index.num_observations();
107 k = cap_k(k, nobs);
108 NeighborList<Index_, Float_> output(nobs);
109
110 parallelize(num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
111 auto sptr = index.initialize();
112 std::vector<Index_> indices;
113 std::vector<Float_> distances;
114 for (Index_ i = start, end = start + length; i < end; ++i) {
115 sptr->search(i, k, &indices, &distances);
116 int actual_k = indices.size();
117 output[i].reserve(actual_k);
118 for (int j = 0; j < actual_k; ++j) {
119 output[i].emplace_back(indices[j], distances[j]);
120 }
121 }
122 });
123
124 return output;
125}
126
146template<typename Dim_, typename Index_, typename Float_>
147std::vector<std::vector<Index_> > find_nearest_neighbors_index_only(const Prebuilt<Dim_, Index_, Float_>& index, int k, int num_threads = 1) {
148 Index_ nobs = index.num_observations();
149 k = cap_k(k, nobs);
150 std::vector<std::vector<Index_> > output(nobs);
151
152 parallelize(num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
153 auto sptr = index.initialize();
154 for (Index_ i = start, end = start + length; i < end; ++i) {
155 sptr->search(i, k, &(output[i]), NULL);
156 }
157 });
158
159 return output;
160}
161
162}
163
164#endif
Interface for prebuilt nearest-neighbor indices.
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
virtual Index_ num_observations() const =0
virtual std::unique_ptr< Searcher< Index_, Float_ > > initialize() const =0
Collection of KNN algorithms.
Definition Bruteforce.hpp:22
std::vector< std::vector< Index_ > > find_nearest_neighbors_index_only(const Prebuilt< Dim_, Index_, Float_ > &index, int k, int num_threads=1)
Definition find_nearest_neighbors.hpp:147
NeighborList< Index_, Float_ > find_nearest_neighbors(const Prebuilt< Dim_, Index_, Float_ > &index, int k, int num_threads=1)
Definition find_nearest_neighbors.hpp:105
std::vector< std::vector< std::pair< Index_, Float_ > > > NeighborList
Definition find_nearest_neighbors.hpp:54
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
Definition find_nearest_neighbors.hpp:35
int cap_k(int k, Index_ num_observations)
Definition find_nearest_neighbors.hpp:69