knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Bruteforce.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_BRUTEFORCE_HPP
2#define KNNCOLLE_BRUTEFORCE_HPP
3
#include "distances.hpp"
#include "NeighborQueue.hpp"
#include "Searcher.hpp"
#include "Builder.hpp"
#include "Prebuilt.hpp"
#include "Matrix.hpp"

#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
17
24namespace knncolle {
25
26template<typename Index_, typename Data_, typename Distance_, typename DistanceMetric_>
28
40template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
41class BruteforceSearcher final : public Searcher<Index_, Data_, Distance_> {
42public:
51private:
54 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
55
56private:
57 void normalize(std::vector<Distance_>* output_distances) const {
58 if (output_distances) {
59 for (auto& d : *output_distances) {
60 d = my_parent.my_metric->normalize(d);
61 }
62 }
63 }
64
65public:
66 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
67 my_nearest.reset(k + 1);
68 auto ptr = my_parent.my_data.data() + static_cast<std::size_t>(i) * my_parent.my_dim; // cast to avoid overflow.
69 my_parent.search(ptr, my_nearest);
70 my_nearest.report(output_indices, output_distances, i);
71 normalize(output_distances);
72 }
73
74 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
75 if (k == 0) { // protect the NeighborQueue from k = 0.
76 if (output_indices) {
77 output_indices->clear();
78 }
79 if (output_distances) {
80 output_distances->clear();
81 }
82 } else {
83 my_nearest.reset(k);
84 my_parent.search(query, my_nearest);
85 my_nearest.report(output_indices, output_distances);
86 normalize(output_distances);
87 }
88 }
89
90 bool can_search_all() const {
91 return true;
92 }
93
94 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
95 auto ptr = my_parent.my_data.data() + static_cast<std::size_t>(i) * my_parent.my_dim; // cast to avoid overflow.
96
97 if (!output_indices && !output_distances) {
98 Index_ count = 0;
99 my_parent.template search_all<true>(ptr, d, count);
101
102 } else {
103 my_all_neighbors.clear();
104 my_parent.template search_all<false>(ptr, d, my_all_neighbors);
105 report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
106 normalize(output_distances);
107 return count_all_neighbors_without_self(my_all_neighbors.size());
108 }
109 }
110
111 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
112 if (!output_indices && !output_distances) {
113 Index_ count = 0;
114 my_parent.template search_all<true>(query, d, count);
115 return count;
116
117 } else {
118 my_all_neighbors.clear();
119 my_parent.template search_all<false>(query, d, my_all_neighbors);
120 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
121 normalize(output_distances);
122 return my_all_neighbors.size();
123 }
124 }
125};
126
138template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
139class BruteforcePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
140private:
141 std::size_t my_dim;
142 Index_ my_obs;
143 std::vector<Data_> my_data;
144 std::shared_ptr<const DistanceMetric_> my_metric;
145
146public:
150 BruteforcePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
151 my_dim(num_dim), my_obs(num_obs), my_data(std::move(data)), my_metric(std::move(metric)) {}
156public:
157 std::size_t num_dimensions() const {
158 return my_dim;
159 }
160
161 Index_ num_observations() const {
162 return my_obs;
163 }
164
165private:
166 void search(const Data_* query, NeighborQueue<Index_, Distance_>& nearest) const {
167 auto copy = my_data.data();
168 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
169 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
170 auto dist_raw = my_metric->raw(my_dim, query, copy);
171 if (dist_raw <= threshold_raw) {
172 nearest.add(x, dist_raw);
173 if (nearest.is_full()) {
174 threshold_raw = nearest.limit();
175 }
176 }
177 }
178 }
179
180 template<bool count_only_, typename Output_>
181 void search_all(const Data_* query, Distance_ threshold, Output_& all_neighbors) const {
182 Distance_ threshold_raw = my_metric->denormalize(threshold);
183 auto copy = my_data.data();
184 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
185 Distance_ raw = my_metric->raw(my_dim, query, copy);
186 if (threshold_raw >= raw) {
187 if constexpr(count_only_) {
188 ++all_neighbors; // expect this to be an integer.
189 } else {
190 all_neighbors.emplace_back(raw, x); // expect this to be a vector of (distance, index) pairs.
191 }
192 }
193 }
194 }
195
196 friend class BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_>;
197
198public:
202 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
203 return std::make_unique<BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
204 }
205};
206
223template<
224 typename Index_,
225 typename Data_,
226 typename Distance_,
227 class Matrix_ = Matrix<Index_, Data_>,
228 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
229>
230class BruteforceBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
231public:
235 BruteforceBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
236
237private:
238 std::shared_ptr<const DistanceMetric_> my_metric;
239
240public:
244 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
245 std::size_t ndim = data.num_dimensions();
246 Index_ nobs = data.num_observations();
247 auto work = data.new_extractor();
248
249 std::vector<Data_> store(ndim * static_cast<std::size_t>(nobs)); // cast to avoid overflow.
250 for (Index_ o = 0; o < nobs; ++o) {
251 std::copy_n(work->next(), ndim, store.begin() + static_cast<std::size_t>(o) * ndim); // cast to size_t to avoid overflow.
252 }
253
254 return new BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
255 }
256};
257
258}
259
260#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Perform a brute-force nearest neighbor search.
Definition Bruteforce.hpp:230
BruteforceBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Bruteforce.hpp:235
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition Bruteforce.hpp:244
Index for a brute-force nearest neighbor search.
Definition Bruteforce.hpp:139
Index_ num_observations() const
Definition Bruteforce.hpp:161
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition Bruteforce.hpp:202
std::size_t num_dimensions() const
Definition Bruteforce.hpp:157
Brute-force nearest neighbor searcher.
Definition Bruteforce.hpp:41
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:66
Index_ search_all(Index_ i, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:94
bool can_search_all() const
Definition Bruteforce.hpp:90
Index_ search_all(const Data_ *query, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:111
void search(const Data_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:74
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:30
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:109
void add(Index_ i, Distance_ d)
Definition NeighborQueue.hpp:83
Distance_ limit() const
Definition NeighborQueue.hpp:71
bool is_full() const
Definition NeighborQueue.hpp:63
void reset(Index_ k)
Definition NeighborQueue.hpp:46
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:24
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().