knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Bruteforce.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_BRUTEFORCE_HPP
2#define KNNCOLLE_BRUTEFORCE_HPP
3
4#include "distances.hpp"
5#include "NeighborQueue.hpp"
6#include "Searcher.hpp"
7#include "Builder.hpp"
8#include "Prebuilt.hpp"
9#include "Matrix.hpp"
10#include "report_all_neighbors.hpp"
11
12#include <vector>
13#include <type_traits>
14#include <limits>
15#include <memory>
16
23namespace knncolle {
24
25template<typename Index_, typename Data_, typename Distance_, typename DistanceMetric_>
27
39template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
40class BruteforceSearcher final : public Searcher<Index_, Data_, Distance_> {
41public:
50private:
52 internal::NeighborQueue<Index_, Distance_> my_nearest;
53 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
54
55private:
56 void normalize(std::vector<Distance_>* output_distances) const {
57 if (output_distances) {
58 for (auto& d : *output_distances) {
59 d = my_parent.my_metric->normalize(d);
60 }
61 }
62 }
63
64public:
65 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
66 my_nearest.reset(k + 1);
67 auto ptr = my_parent.my_data.data() + static_cast<size_t>(i) * my_parent.my_dim; // cast to avoid overflow.
68 my_parent.search(ptr, my_nearest);
69 my_nearest.report(output_indices, output_distances, i);
70 normalize(output_distances);
71 }
72
73 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
74 if (k == 0) { // protect the NeighborQueue from k = 0.
75 internal::flush_output(output_indices, output_distances, 0);
76 } else {
77 my_nearest.reset(k);
78 my_parent.search(query, my_nearest);
79 my_nearest.report(output_indices, output_distances);
80 normalize(output_distances);
81 }
82 }
83
84 bool can_search_all() const {
85 return true;
86 }
87
88 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
89 auto ptr = my_parent.my_data.data() + static_cast<size_t>(i) * my_parent.my_dim; // cast to avoid overflow.
90
91 if (!output_indices && !output_distances) {
92 Index_ count = 0;
93 my_parent.template search_all<true>(ptr, d, count);
94 return internal::safe_remove_self(count);
95
96 } else {
97 my_all_neighbors.clear();
98 my_parent.template search_all<false>(ptr, d, my_all_neighbors);
99 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
100 normalize(output_distances);
101 return internal::safe_remove_self(my_all_neighbors.size());
102 }
103 }
104
105 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
106 if (!output_indices && !output_distances) {
107 Index_ count = 0;
108 my_parent.template search_all<true>(query, d, count);
109 return count;
110
111 } else {
112 my_all_neighbors.clear();
113 my_parent.template search_all<false>(query, d, my_all_neighbors);
114 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances);
115 normalize(output_distances);
116 return my_all_neighbors.size();
117 }
118 }
119};
120
132template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
133class BruteforcePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
134private:
135 size_t my_dim;
136 Index_ my_obs;
137 std::vector<Data_> my_data;
138 std::shared_ptr<const DistanceMetric_> my_metric;
139
140public:
144 BruteforcePrebuilt(size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
145 my_dim(num_dim), my_obs(num_obs), my_data(std::move(data)), my_metric(std::move(metric)) {}
150public:
151 size_t num_dimensions() const {
152 return my_dim;
153 }
154
155 Index_ num_observations() const {
156 return my_obs;
157 }
158
159private:
160 void search(const Data_* query, internal::NeighborQueue<Index_, Distance_>& nearest) const {
161 auto copy = my_data.data();
162 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
163 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
164 auto dist_raw = my_metric->raw(my_dim, query, copy);
165 if (dist_raw <= threshold_raw) {
166 nearest.add(x, dist_raw);
167 if (nearest.is_full()) {
168 threshold_raw = nearest.limit();
169 }
170 }
171 }
172 }
173
174 template<bool count_only_, typename Output_>
175 void search_all(const Data_* query, Distance_ threshold, Output_& all_neighbors) const {
176 Distance_ threshold_raw = my_metric->denormalize(threshold);
177 auto copy = my_data.data();
178 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
179 Distance_ raw = my_metric->raw(my_dim, query, copy);
180 if (threshold_raw >= raw) {
181 if constexpr(count_only_) {
182 ++all_neighbors; // expect this to be an integer.
183 } else {
184 all_neighbors.emplace_back(raw, x); // expect this to be a vector of (distance, index) pairs.
185 }
186 }
187 }
188 }
189
190 friend class BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_>;
191
192public:
196 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
197 return std::make_unique<BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
198 }
199};
200
217template<
218 typename Index_,
219 typename Data_,
220 typename Distance_,
221 class Matrix_ = Matrix<Index_, Data_>,
222 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
223>
224class BruteforceBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
225public:
229 BruteforceBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
230
234 BruteforceBuilder(const DistanceMetric_* metric) : BruteforceBuilder(std::shared_ptr<const DistanceMetric_>(metric)) {}
235
236private:
237 std::shared_ptr<const DistanceMetric_> my_metric;
238
239public:
243 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
244 size_t ndim = data.num_dimensions();
245 size_t nobs = data.num_observations();
246 auto work = data.new_extractor();
247
248 std::vector<Data_> store(ndim * nobs);
249 for (size_t o = 0; o < nobs; ++o) {
250 std::copy_n(work->next(), ndim, store.begin() + o * ndim);
251 }
252
253 return new BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
254 }
255};
256
257}
258
259#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Perform a brute-force nearest neighbor search.
Definition Bruteforce.hpp:224
BruteforceBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Bruteforce.hpp:229
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition Bruteforce.hpp:243
BruteforceBuilder(const DistanceMetric_ *metric)
Definition Bruteforce.hpp:234
Index for a brute-force nearest neighbor search.
Definition Bruteforce.hpp:133
Index_ num_observations() const
Definition Bruteforce.hpp:155
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition Bruteforce.hpp:196
size_t num_dimensions() const
Definition Bruteforce.hpp:151
Brute-force nearest neighbor searcher.
Definition Bruteforce.hpp:40
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:65
Index_ search_all(Index_ i, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:88
bool can_search_all() const
Definition Bruteforce.hpp:84
Index_ search_all(const Data_ *query, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:105
void search(const Data_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:73
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:26
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:23