knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
Bruteforce.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_BRUTEFORCE_HPP
2#define KNNCOLLE_BRUTEFORCE_HPP
3
#include "distances.hpp"
#include "NeighborQueue.hpp"
#include "Searcher.hpp"
#include "Builder.hpp"
#include "Prebuilt.hpp"
#include "Matrix.hpp"
#include "report_all_neighbors.hpp"
#include "utils.hpp"

#include <algorithm>
#include <vector>
#include <limits>
#include <memory>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <cstring>
#include <filesystem>
#include <utility>

#include "sanisizer/sanisizer.hpp"
22
29namespace knncolle {
30
/**
 * Identifier written to the `ALGORITHM` file by `BruteforcePrebuilt::save()`.
 *
 * Declared as a C++17 inline variable: `static` is deliberately omitted, as it would force internal linkage
 * (one private copy per translation unit) and turn `inline` into a no-op.
 */
inline constexpr const char* bruteforce_prebuilt_save_name = "knncolle::Bruteforce";
35
// Declared ahead of its definition because BruteforceSearcher stores a reference to this type.
template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
class BruteforcePrebuilt;
41
42template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
43class BruteforceSearcher final : public Searcher<Index_, Data_, Distance_> {
44public:
45 BruteforceSearcher(const BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& parent) : my_parent(parent) {}
46
47private:
48 const BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& my_parent;
50 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
51
52private:
53 void normalize(std::vector<Distance_>* output_distances) const {
54 if (output_distances) {
55 for (auto& d : *output_distances) {
56 d = my_parent.my_metric->normalize(d);
57 }
58 }
59 }
60
61public:
62 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
63 my_nearest.reset(k + 1); // +1 is safe as k < num_obs.
64 auto ptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(i, my_parent.my_dim);
65 my_parent.search(ptr, my_nearest);
66 my_nearest.report(output_indices, output_distances, i);
67 normalize(output_distances);
68 }
69
70 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
71 if (k == 0) { // protect the NeighborQueue from k = 0.
72 if (output_indices) {
73 output_indices->clear();
74 }
75 if (output_distances) {
76 output_distances->clear();
77 }
78 } else {
79 my_nearest.reset(k);
80 my_parent.search(query, my_nearest);
81 my_nearest.report(output_indices, output_distances);
82 normalize(output_distances);
83 }
84 }
85
86 bool can_search_all() const {
87 return true;
88 }
89
90 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
91 auto ptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(i, my_parent.my_dim);
92
93 if (!output_indices && !output_distances) {
94 Index_ count = 0;
95 my_parent.template search_all<true>(ptr, d, count);
97
98 } else {
99 my_all_neighbors.clear();
100 my_parent.template search_all<false>(ptr, d, my_all_neighbors);
101 report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
102 normalize(output_distances);
103 return count_all_neighbors_without_self(my_all_neighbors.size());
104 }
105 }
106
107 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
108 if (!output_indices && !output_distances) {
109 Index_ count = 0;
110 my_parent.template search_all<true>(query, d, count);
111 return count;
112
113 } else {
114 my_all_neighbors.clear();
115 my_parent.template search_all<false>(query, d, my_all_neighbors);
116 report_all_neighbors(my_all_neighbors, output_indices, output_distances);
117 normalize(output_distances);
118 return my_all_neighbors.size();
119 }
120 }
121};
122
123template<typename Index_, typename Data_, typename Distance_, class DistanceMetric_>
124class BruteforcePrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
125private:
126 std::size_t my_dim;
127 Index_ my_obs;
128 std::vector<Data_> my_data;
129 std::shared_ptr<const DistanceMetric_> my_metric;
130
131public:
132 BruteforcePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
133 my_dim(num_dim), my_obs(num_obs), my_data(std::move(data)), my_metric(std::move(metric)) {}
134
135public:
136 std::size_t num_dimensions() const {
137 return my_dim;
138 }
139
140 Index_ num_observations() const {
141 return my_obs;
142 }
143
144private:
145 void search(const Data_* query, NeighborQueue<Index_, Distance_>& nearest) const {
146 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
147 for (Index_ x = 0; x < my_obs; ++x) {
148 auto dist_raw = my_metric->raw(my_dim, query, my_data.data() + sanisizer::product_unsafe<std::size_t>(x, my_dim));
149 if (dist_raw <= threshold_raw) {
150 nearest.add(x, dist_raw);
151 if (nearest.is_full()) {
152 threshold_raw = nearest.limit();
153 }
154 }
155 }
156 }
157
158 template<bool count_only_, typename Output_>
159 void search_all(const Data_* query, Distance_ threshold, Output_& all_neighbors) const {
160 Distance_ threshold_raw = my_metric->denormalize(threshold);
161 for (Index_ x = 0; x < my_obs; ++x) {
162 Distance_ raw = my_metric->raw(my_dim, query, my_data.data() + sanisizer::product_unsafe<std::size_t>(x, my_dim));
163 if (threshold_raw >= raw) {
164 if constexpr(count_only_) {
165 ++all_neighbors; // expect this to be an integer.
166 } else {
167 all_neighbors.emplace_back(raw, x); // expect this to be a vector of (distance, index) pairs.
168 }
169 }
170 }
171 }
172
173 friend class BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_>;
174
175public:
176 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
177 return initialize_known();
178 }
179
180 auto initialize_known() const {
181 return std::make_unique<BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
182 }
183
184public:
185 void save(const std::filesystem::path& dir) const {
186 quick_save(dir / "ALGORITHM", bruteforce_prebuilt_save_name, std::strlen(bruteforce_prebuilt_save_name));
187 quick_save(dir / "DATA", my_data.data(), my_data.size());
188 quick_save(dir / "NUM_OBS", &my_obs, 1);
189 quick_save(dir / "NUM_DIM", &my_dim, 1);
190
191 const auto distdir = dir / "DISTANCE";
192 std::filesystem::create_directory(distdir);
193 my_metric->save(distdir);
194 }
195
196 BruteforcePrebuilt(const std::filesystem::path& dir) {
197 quick_load(dir / "NUM_OBS", &my_obs, 1);
198 quick_load(dir / "NUM_DIM", &my_dim, 1);
199
200 my_data.resize(sanisizer::product<I<decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
201 quick_load(dir / "DATA", my_data.data(), my_data.size());
202
203 auto dptr = load_distance_metric_raw<Data_, Distance_>(dir / "DISTANCE");
204 auto xptr = dynamic_cast<DistanceMetric_*>(dptr);
205 if (xptr == NULL) {
206 throw std::runtime_error("cannot cast the loaded distance metric to a DistanceMetric_");
207 }
208 my_metric.reset(xptr);
209 }
210};
231template<
232 typename Index_,
233 typename Data_,
234 typename Distance_,
235 class Matrix_ = Matrix<Index_, Data_>,
236 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
237>
238class BruteforceBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
239public:
243 BruteforceBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
244
245private:
246 std::shared_ptr<const DistanceMetric_> my_metric;
247
248public:
252 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
253 return build_known_raw(data);
254 }
259public:
263 auto build_known_raw(const Matrix_& data) const {
264 std::size_t ndim = data.num_dimensions();
265 const Index_ nobs = data.num_observations();
266 auto work = data.new_known_extractor();
267
268 // We assume that that vector::size_type <= size_t, otherwise data() wouldn't be a contiguous array.
269 std::vector<Data_> store(sanisizer::product<typename std::vector<Data_>::size_type>(ndim, sanisizer::attest_gez(nobs)));
270 for (Index_ o = 0; o < nobs; ++o) {
271 std::copy_n(work->next(), ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
272 }
273
274 return new BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
275 }
276
280 auto build_known_unique(const Matrix_& data) const {
281 return std::unique_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
282 }
283
287 auto build_known_shared(const Matrix_& data) const {
288 return std::shared_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
289 }
290};
291
292}
293
294#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Perform a brute-force nearest neighbor search.
Definition Bruteforce.hpp:238
BruteforceBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Bruteforce.hpp:243
auto build_known_shared(const Matrix_ &data) const
Definition Bruteforce.hpp:287
auto build_known_unique(const Matrix_ &data) const
Definition Bruteforce.hpp:280
auto build_known_raw(const Matrix_ &data) const
Definition Bruteforce.hpp:263
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const =0
Interface for a distance metric.
Definition distances.hpp:30
Interface for matrix data.
Definition Matrix.hpp:59
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:35
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:168
void add(Index_ i, Distance_ d)
Definition NeighborQueue.hpp:88
Distance_ limit() const
Definition NeighborQueue.hpp:76
bool is_full() const
Definition NeighborQueue.hpp:68
void reset(Index_ k)
Definition NeighborQueue.hpp:51
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:29
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
Definition utils.hpp:57
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Definition utils.hpp:33
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().
Miscellaneous utilities for knncolle