knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
L2Normalized.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_L2_NORMALIZED_HP
2#define KNNCOLLE_L2_NORMALIZED_HP
3
4#include <vector>
5#include <cmath>
6#include <memory>
7#include <limits>
8#include <cstddef>
9
10#include "Searcher.hpp"
11#include "Prebuilt.hpp"
12#include "Builder.hpp"
13#include "Matrix.hpp"
14
20namespace knncolle {
21
25namespace internal {
26
/**
 * Copy a vector into `buffer`, converting each element to `Normalized_`,
 * and then divide every element by the vector's L2 norm.
 *
 * If the sum of squares is not greater than zero (e.g., an all-zero vector or `ndim == 0`),
 * the converted values are left in `buffer` without any scaling.
 *
 * @tparam Data_ Type of the input values.
 * @tparam Normalized_ Type of the output values; the sum of squares is accumulated in this type.
 *
 * @param ptr Pointer to an array of length `ndim` holding the input values.
 * @param ndim Number of elements to read from `ptr` and write to `buffer`.
 * @param[out] buffer Pointer to an array of length `ndim` that receives the result.
 */
template<typename Data_, typename Normalized_>
void l2norm(const Data_* ptr, std::size_t ndim, Normalized_* buffer) {
    Normalized_ sum_of_squares = 0;

    const Data_* const input_end = ptr + ndim;
    Normalized_* out = buffer;
    for (const Data_* in = ptr; in != input_end; ++in, ++out) {
        // Convert before squaring so that integer inputs are multiplied as Normalized_, avoiding integer overflow.
        const Normalized_ converted = *in;
        *out = converted;
        sum_of_squares += converted * converted;
    }

    // Nothing to scale if the squared norm is not positive; this also avoids a division by zero.
    if (!(sum_of_squares > 0)) {
        return;
    }

    const Normalized_ norm = std::sqrt(sum_of_squares);
    for (std::size_t i = 0; i < ndim; ++i) {
        buffer[i] /= norm;
    }
}
43
44}
60template<typename Index_, typename Data_, typename Distance_, typename Normalized_>
61class L2NormalizedSearcher final : public Searcher<Index_, Data_, Distance_> {
62public:
67 L2NormalizedSearcher(std::unique_ptr<Searcher<Index_, Normalized_, Distance_> > searcher, std::size_t num_dimensions) :
68 my_searcher(std::move(searcher)),
69 buffer(num_dimensions)
70 {}
71
72private:
73 // No way around this; the L2-normalized values must be floating-point,
74 // so the internal searcher must accept floats.
75 static_assert(std::is_floating_point<Normalized_>::value);
76
77 std::unique_ptr<Searcher<Index_, Normalized_, Distance_> > my_searcher;
78 std::vector<Normalized_> buffer;
83public:
84 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
85 my_searcher->search(i, k, output_indices, output_distances);
86 }
87
88 void search(const Data_* ptr, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
89 auto normalized = buffer.data();
90 internal::l2norm(ptr, buffer.size(), normalized);
91 my_searcher->search(normalized, k, output_indices, output_distances);
92 }
93
94public:
95 bool can_search_all() const {
96 return my_searcher->can_search_all();
97 }
98
99 Index_ search_all(Index_ i, Distance_ threshold, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
100 return my_searcher->search_all(i, threshold, output_indices, output_distances);
101 }
102
103 Index_ search_all(const Data_* ptr, Distance_ threshold, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
104 auto normalized = buffer.data();
105 internal::l2norm(ptr, buffer.size(), normalized);
106 return my_searcher->search_all(normalized, threshold, output_indices, output_distances);
107 }
111};
112
124template<typename Index_, typename Data_, typename Distance_, typename Normalized_>
125class L2NormalizedPrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
126public:
130 L2NormalizedPrebuilt(std::unique_ptr<Prebuilt<Index_, Normalized_, Distance_> > prebuilt) : my_prebuilt(std::move(prebuilt)) {}
131
132private:
133 std::unique_ptr<Prebuilt<Index_, Normalized_, Distance_> > my_prebuilt;
134
135public:
139 Index_ num_observations() const {
140 return my_prebuilt->num_observations();
141 }
142
143 std::size_t num_dimensions() const {
144 return my_prebuilt->num_dimensions();
145 }
153 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
154 return std::make_unique<L2NormalizedSearcher<Index_, Data_, Distance_, Normalized_> >(my_prebuilt->initialize(), my_prebuilt->num_dimensions());
155 }
156};
157
161template<typename Index_, typename Data_, typename Normalized_, typename Matrix_>
162class L2NormalizedMatrix;
174template<typename Index_, typename Data_, typename Normalized_>
175class L2NormalizedMatrixExtractor final : public MatrixExtractor<Normalized_> {
176public:
180 L2NormalizedMatrixExtractor(std::unique_ptr<MatrixExtractor<Data_> > extractor, std::size_t dim) :
181 my_extractor(std::move(extractor)), buffer(dim) {}
182
183private:
184 std::unique_ptr<MatrixExtractor<Data_> > my_extractor;
185 std::vector<Normalized_> buffer;
186
187public:
188 const Normalized_* next() {
189 auto raw = my_extractor->next();
190 auto normalized = buffer.data();
191 internal::l2norm(raw, buffer.size(), normalized);
192 return normalized;
193 }
197};
198
212template<typename Index_, typename Data_, typename Normalized_, typename Matrix_ = Matrix<Index_, Data_> >
213class L2NormalizedMatrix final : public Matrix<Index_, Normalized_> {
217public:
218 L2NormalizedMatrix(const Matrix_& matrix) : my_matrix(matrix) {}
219
220private:
221 static_assert(std::is_same<decltype(std::declval<Matrix_>().num_observations()), Index_>::value);
222 static_assert(std::is_same<typename std::remove_pointer<decltype(std::declval<Matrix_>().new_extractor()->next())>::type, const Data_>::value);
223
224 const Matrix_& my_matrix;
225
226public:
227 std::size_t num_dimensions() const {
228 return my_matrix.num_dimensions();
229 }
230
231 Index_ num_observations() const {
232 return my_matrix.num_observations();
233 }
234
235 std::unique_ptr<MatrixExtractor<Normalized_> > new_extractor() const {
236 return std::make_unique<L2NormalizedMatrixExtractor<Index_, Data_, Normalized_> >(my_matrix.new_extractor(), num_dimensions());
237 }
241};
242
258template<typename Index_, typename Data_, typename Distance_, typename Normalized_, class Matrix_ = Matrix<Index_, Data_> >
259class L2NormalizedBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
260public:
265
277 typedef typename std::conditional<
278 std::is_base_of<Matrix_, NormalizedMatrix>::value,
279 Matrix_,
282
283public:
287 L2NormalizedBuilder(std::shared_ptr<const Builder<Index_, Normalized_, Distance_, BuilderMatrix> > builder) : my_builder(std::move(builder)) {}
288
289private:
290 std::shared_ptr<const Builder<Index_, Normalized_, Distance_, BuilderMatrix> > my_builder;
291
292public:
296 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
297 NormalizedMatrix normalized(data);
298 return new L2NormalizedPrebuilt<Index_, Data_, Distance_, Normalized_>(my_builder->build_unique(normalized));
299 }
300};
301
302}
303
304#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Wrapper around a builder with L2 normalization.
Definition L2Normalized.hpp:259
L2NormalizedMatrix< Index_, Data_, Normalized_, Matrix_ > NormalizedMatrix
Definition L2Normalized.hpp:264
L2NormalizedBuilder(std::shared_ptr< const Builder< Index_, Normalized_, Distance_, BuilderMatrix > > builder)
Definition L2Normalized.hpp:287
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition L2Normalized.hpp:296
std::conditional< std::is_base_of< Matrix_, NormalizedMatrix >::value, Matrix_, NormalizedMatrix >::type BuilderMatrix
Definition L2Normalized.hpp:281
Extractor for the L2NormalizedMatrix.
Definition L2Normalized.hpp:175
Wrapper around a matrix with L2 normalization.
Definition L2Normalized.hpp:213
Wrapper around a prebuilt index with L2 normalization.
Definition L2Normalized.hpp:125
L2NormalizedPrebuilt(std::unique_ptr< Prebuilt< Index_, Normalized_, Distance_ > > prebuilt)
Definition L2Normalized.hpp:130
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition L2Normalized.hpp:153
Wrapper around a search interface with L2 normalization.
Definition L2Normalized.hpp:61
L2NormalizedSearcher(std::unique_ptr< Searcher< Index_, Normalized_, Distance_ > > searcher, std::size_t num_dimensions)
Definition L2Normalized.hpp:67
Extractor interface for matrix data.
Definition Matrix.hpp:20
virtual const Normalized_ * next()=0
Interface for matrix data.
Definition Matrix.hpp:56
virtual std::size_t num_dimensions() const=0
virtual std::unique_ptr< MatrixExtractor< Normalized_ > > new_extractor() const=0
virtual Index_ num_observations() const=0
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
virtual Index_ num_observations() const =0
virtual std::size_t num_dimensions() const =0
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
virtual bool can_search_all() const
Definition Searcher.hpp:85
virtual Index_ search_all(Index_ i, Distance_ distance, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Searcher.hpp:106
virtual void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)=0
Collection of KNN algorithms.
Definition Bruteforce.hpp:24