knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Bruteforce.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_BRUTEFORCE_HPP
2#define KNNCOLLE_BRUTEFORCE_HPP
3
#include "distances.hpp"
#include "NeighborQueue.hpp"
#include "Searcher.hpp"
#include "Builder.hpp"
#include "Prebuilt.hpp"
#include "MockMatrix.hpp"
#include "report_all_neighbors.hpp"

#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
15
22namespace knncolle {
23
24template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
26
38template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
39class BruteforceSearcher : public Searcher<Index_, Float_> {
40public:
49private:
51 internal::NeighborQueue<Index_, Float_> my_nearest;
52 std::vector<std::pair<Float_, Index_> > my_all_neighbors;
53
54private:
55 static void normalize(std::vector<Float_>* output_distances) {
56 if (output_distances) {
57 for (auto& d : *output_distances) {
58 d = Distance_::normalize(d);
59 }
60 }
61 }
62
63public:
64 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
65 my_nearest.reset(k + 1);
66 auto ptr = my_parent->my_data.data() + static_cast<size_t>(i) * my_parent->my_long_ndim; // cast to avoid overflow.
67 my_parent->search(ptr, my_nearest);
68 my_nearest.report(output_indices, output_distances, i);
69 normalize(output_distances);
70 }
71
72 void search(const Float_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
73 if (k == 0) { // protect the NeighborQueue from k = 0.
74 internal::flush_output(output_indices, output_distances, 0);
75 } else {
76 my_nearest.reset(k);
77 my_parent->search(query, my_nearest);
78 my_nearest.report(output_indices, output_distances);
79 normalize(output_distances);
80 }
81 }
82
83 bool can_search_all() const {
84 return true;
85 }
86
87 Index_ search_all(Index_ i, Float_ d, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
88 auto ptr = my_parent->my_data.data() + static_cast<size_t>(i) * my_parent->my_long_ndim; // cast to avoid overflow.
89
90 if (!output_indices && !output_distances) {
91 Index_ count = 0;
92 my_parent->template search_all<true>(ptr, d, count);
93 return internal::safe_remove_self(count);
94
95 } else {
96 my_all_neighbors.clear();
97 my_parent->template search_all<false>(ptr, d, my_all_neighbors);
98 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
99 normalize(output_distances);
100 return internal::safe_remove_self(my_all_neighbors.size());
101 }
102 }
103
104 Index_ search_all(const Float_* query, Float_ d, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
106 Index_ count = 0;
107 my_parent->template search_all<true>(query, d, count);
108 return count;
109
110 } else {
111 my_all_neighbors.clear();
112 my_parent->template search_all<false>(query, d, my_all_neighbors);
113 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances);
114 normalize(output_distances);
115 return my_all_neighbors.size();
116 }
117 }
118};
119
135template<class Distance_, typename Dim_, typename Index_, typename Store_, typename Float_>
136class BruteforcePrebuilt : public Prebuilt<Dim_, Index_, Float_> {
137private:
138 Dim_ my_dim;
139 Index_ my_obs;
140 size_t my_long_ndim;
141 std::vector<Store_> my_data;
142
143public:
147 BruteforcePrebuilt(Dim_ num_dim, Index_ num_obs, std::vector<Store_> data) :
148 my_dim(num_dim), my_obs(num_obs), my_long_ndim(num_dim), my_data(std::move(data)) {}
153public:
154 Dim_ num_dimensions() const {
155 return my_dim;
156 }
157
158 Index_ num_observations() const {
159 return my_obs;
160 }
161
162private:
163 template<typename Query_>
164 void search(const Query_* query, internal::NeighborQueue<Index_, Float_>& nearest) const {
165 auto copy = my_data.data();
166 Float_ threshold_raw = std::numeric_limits<Float_>::infinity();
167 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
168 auto dist_raw = Distance_::template raw_distance<Float_>(query, copy, my_dim);
169 if (dist_raw <= threshold_raw) {
170 nearest.add(x, dist_raw);
171 if (nearest.is_full()) {
172 threshold_raw = nearest.limit();
173 }
174 }
175 }
176 }
177
178 template<bool count_only_, typename Query_, typename Output_>
179 void search_all(const Query_* query, Float_ threshold, Output_& all_neighbors) const {
180 Float_ threshold_raw = Distance_::denormalize(threshold);
181 auto copy = my_data.data();
182 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
183 Float_ raw_distance = Distance_::template raw_distance<Float_>(query, copy, my_dim);
184 if (threshold_raw >= raw_distance) {
185 if constexpr(count_only_) {
186 ++all_neighbors; // expect this to be an integer.
187 } else {
188 all_neighbors.emplace_back(raw_distance, x); // expect this to be a vector of (distance, index) pairs.
189 }
190 }
191 }
192 }
193
194 friend class BruteforceSearcher<Distance_, Dim_, Index_, Store_, Float_>;
195
196public:
200 std::unique_ptr<Searcher<Index_, Float_> > initialize() const {
201 return std::make_unique<BruteforceSearcher<Distance_, Dim_, Index_, Store_, Float_> >(this);
202 }
203};
204
217template<class Distance_ = EuclideanDistance, class Matrix_ = SimpleMatrix<int, int, double>, typename Float_ = double>
218class BruteforceBuilder : public Builder<Matrix_, Float_> {
219public:
224 auto ndim = data.num_dimensions();
225 auto nobs = data.num_observations();
226
227 typedef decltype(ndim) Dim_;
228 typedef decltype(nobs) Index_;
229 typedef typename Matrix_::data_type Store_;
230 std::vector<typename Matrix_::data_type> store(static_cast<size_t>(ndim) * static_cast<size_t>(nobs));
231
232 auto work = data.create_workspace();
233 auto sIt = store.begin();
234 for (decltype(nobs) o = 0; o < nobs; ++o, sIt += ndim) {
235 auto ptr = data.get_observation(work);
236 std::copy(ptr, ptr + ndim, sIt);
237 }
238
239 return new BruteforcePrebuilt<Distance_, Dim_, Index_, Store_, Float_>(ndim, nobs, std::move(store));
240 }
241};
242
243}
244
245#endif
Interface to build nearest-neighbor indices.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Perform a brute-force nearest neighbor search.
Definition Bruteforce.hpp:218
Prebuilt< typename Matrix_::dimension_type, typename Matrix_::index_type, Float_ > * build_raw(const Matrix_ &data) const
Definition Bruteforce.hpp:223
Index for a brute-force nearest neighbor search.
Definition Bruteforce.hpp:136
Index_ num_observations() const
Definition Bruteforce.hpp:158
std::unique_ptr< Searcher< Index_, Float_ > > initialize() const
Definition Bruteforce.hpp:200
Dim_ num_dimensions() const
Definition Bruteforce.hpp:154
Brute-force nearest neighbor searcher.
Definition Bruteforce.hpp:39
Index_ search_all(Index_ i, Float_ d, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Bruteforce.hpp:87
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Bruteforce.hpp:64
bool can_search_all() const
Definition Bruteforce.hpp:83
Index_ search_all(const Float_ *query, Float_ d, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Bruteforce.hpp:104
void search(const Float_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Bruteforce.hpp:72
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:22
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:22