knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
L2Normalized.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_L2_NORMALIZED_HP
2#define KNNCOLLE_L2_NORMALIZED_HP
3
#include <cstddef>
#include <vector>
#include <cmath>
#include <memory>
#include <limits>
#include <type_traits>
#include <utility>

#include "Searcher.hpp"
#include "Prebuilt.hpp"
#include "Builder.hpp"
#include "Matrix.hpp"
13
19namespace knncolle {
20
24namespace internal {
25
/**
 * Copy a vector into `buffer` (converting to `Normalized_`) and scale it to unit L2 norm.
 *
 * The common case computes the sum of squares directly. If that sum overflows to infinity or
 * underflows to zero while some component is non-zero, the vector is first rescaled by its
 * largest magnitude so that the squares stay representable. This mirrors the classic
 * scaled Euclidean norm approach.
 *
 * Special cases:
 * - All-zero vectors are copied unchanged.
 * - Vectors containing a NaN are copied unchanged.
 * - Vectors containing an infinite value are divided by an infinite norm, giving NaN for the
 *   infinite entries and 0 elsewhere. This matches the previous behavior.
 *
 * @tparam Data_ Type of the input values.
 * @tparam Normalized_ Type of the normalized values, expected to be floating-point.
 * @param ptr Pointer to an array of `ndim` input values.
 * @param ndim Number of dimensions.
 * @param[out] buffer Pointer to an array of length `ndim`, filled with the normalized values.
 */
template<typename Data_, typename Normalized_>
void l2norm(const Data_* ptr, size_t ndim, Normalized_* buffer) {
    Normalized_ l2 = 0;
    for (size_t d = 0; d < ndim; ++d) {
        Normalized_ val = ptr[d]; // cast to Normalized_ to avoid issues with integer overflow.
        buffer[d] = val;
        l2 += val * val;
    }

    // Any NaN component propagates into 'l2'; such vectors are left as plain copies.
    if (std::isnan(l2)) {
        return;
    }

    // Fast path: the sum of squares is finite and positive, so normalize directly.
    if (l2 > 0 && std::isfinite(l2)) {
        l2 = std::sqrt(l2);
        for (size_t d = 0; d < ndim; ++d) {
            buffer[d] /= l2;
        }
        return;
    }

    // Reaching here means 'l2' is either 0 (all-zero vector, or underflow of tiny components)
    // or infinite (overflow of large components, or a genuinely infinite component).
    Normalized_ scale = 0;
    for (size_t d = 0; d < ndim; ++d) {
        Normalized_ mag = std::abs(buffer[d]);
        if (mag > scale) {
            scale = mag;
        }
    }

    if (scale == 0) {
        return; // all-zero vector: nothing to normalize.
    }

    if (!std::isfinite(scale)) {
        // A component is itself infinite; keep the historical behavior of dividing by an infinite norm.
        for (size_t d = 0; d < ndim; ++d) {
            buffer[d] /= l2;
        }
        return;
    }

    // Rescale so the largest magnitude is 1; the sum of squares then lies in [1, ndim].
    Normalized_ rescaled = 0;
    for (size_t d = 0; d < ndim; ++d) {
        buffer[d] /= scale;
        rescaled += buffer[d] * buffer[d];
    }

    rescaled = std::sqrt(rescaled);
    for (size_t d = 0; d < ndim; ++d) {
        buffer[d] /= rescaled;
    }
}
42
43}
59template<typename Index_, typename Data_, typename Distance_, typename Normalized_>
60class L2NormalizedSearcher final : public Searcher<Index_, Data_, Distance_> {
61public:
66 L2NormalizedSearcher(std::unique_ptr<Searcher<Index_, Normalized_, Distance_> > searcher, size_t num_dimensions) :
67 my_searcher(std::move(searcher)),
68 buffer(num_dimensions)
69 {}
70
71private:
72 // No way around this; the L2-normalized values must be floating-point,
73 // so the internal searcher must accept floats.
74 static_assert(std::is_floating_point<Normalized_>::value);
75
76 std::unique_ptr<Searcher<Index_, Normalized_, Distance_> > my_searcher;
77 std::vector<Normalized_> buffer;
82public:
83 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
84 my_searcher->search(i, k, output_indices, output_distances);
85 }
86
87 void search(const Data_* ptr, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
88 auto normalized = buffer.data();
89 internal::l2norm(ptr, buffer.size(), normalized);
90 my_searcher->search(normalized, k, output_indices, output_distances);
91 }
92
93public:
94 bool can_search_all() const {
95 return my_searcher->can_search_all();
96 }
97
98 Index_ search_all(Index_ i, Distance_ threshold, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
99 return my_searcher->search_all(i, threshold, output_indices, output_distances);
100 }
101
102 Index_ search_all(const Data_* ptr, Distance_ threshold, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
103 auto normalized = buffer.data();
104 internal::l2norm(ptr, buffer.size(), normalized);
105 return my_searcher->search_all(normalized, threshold, output_indices, output_distances);
106 }
110};
111
123template<typename Index_, typename Data_, typename Distance_, typename Normalized_>
124class L2NormalizedPrebuilt final : public Prebuilt<Index_, Data_, Distance_> {
125public:
129 L2NormalizedPrebuilt(std::unique_ptr<Prebuilt<Index_, Normalized_, Distance_> > prebuilt) : my_prebuilt(std::move(prebuilt)) {}
130
131private:
132 std::unique_ptr<Prebuilt<Index_, Normalized_, Distance_> > my_prebuilt;
133
134public:
138 Index_ num_observations() const {
139 return my_prebuilt->num_observations();
140 }
141
142 size_t num_dimensions() const {
143 return my_prebuilt->num_dimensions();
144 }
152 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize() const {
153 return std::make_unique<L2NormalizedSearcher<Index_, Data_, Distance_, Normalized_> >(my_prebuilt->initialize(), my_prebuilt->num_dimensions());
154 }
155};
156
// Forward declaration; the full definition (with a default for Matrix_) appears further below.
template<class Index_, class Data_, class Normalized_, class Matrix_>
class L2NormalizedMatrix;
173template<typename Index_, typename Data_, typename Normalized_>
174class L2NormalizedMatrixExtractor final : public MatrixExtractor<Normalized_> {
175public:
179 L2NormalizedMatrixExtractor(std::unique_ptr<MatrixExtractor<Data_> > extractor, size_t dim) :
180 my_extractor(std::move(extractor)), buffer(dim) {}
181
182private:
183 std::unique_ptr<MatrixExtractor<Data_> > my_extractor;
184 std::vector<Normalized_> buffer;
185
186public:
187 const Normalized_* next() {
188 auto raw = my_extractor->next();
189 auto normalized = buffer.data();
190 internal::l2norm(raw, buffer.size(), normalized);
191 return normalized;
192 }
196};
197
211template<typename Index_, typename Data_, typename Normalized_, typename Matrix_ = Matrix<Index_, Data_> >
212class L2NormalizedMatrix final : public Matrix<Index_, Normalized_> {
216public:
217 L2NormalizedMatrix(const Matrix_& matrix) : my_matrix(matrix) {}
218
219private:
220 static_assert(std::is_same<decltype(std::declval<Matrix_>().num_observations()), Index_>::value);
221 static_assert(std::is_same<typename std::remove_pointer<decltype(std::declval<Matrix_>().new_extractor()->next())>::type, const Data_>::value);
222
223 const Matrix_& my_matrix;
224
225public:
226 size_t num_dimensions() const {
227 return my_matrix.num_dimensions();
228 }
229
230 Index_ num_observations() const {
231 return my_matrix.num_observations();
232 }
233
234 std::unique_ptr<MatrixExtractor<Normalized_> > new_extractor() const {
235 return std::make_unique<L2NormalizedMatrixExtractor<Index_, Data_, Normalized_> >(my_matrix.new_extractor(), num_dimensions());
236 }
240};
241
257template<typename Index_, typename Data_, typename Distance_, typename Normalized_, class Matrix_ = Matrix<Index_, Data_> >
258class L2NormalizedBuilder final : public Builder<Index_, Data_, Distance_, Matrix_> {
259public:
264
276 typedef typename std::conditional<
277 std::is_base_of<Matrix_, NormalizedMatrix>::value,
278 Matrix_,
281
282public:
286 L2NormalizedBuilder(std::shared_ptr<const Builder<Index_, Normalized_, Distance_, BuilderMatrix> > builder) : my_builder(std::move(builder)) {}
287
288private:
289 std::shared_ptr<const Builder<Index_, Normalized_, Distance_, BuilderMatrix> > my_builder;
290
291public:
295 Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
296 NormalizedMatrix normalized(data);
297 return new L2NormalizedPrebuilt<Index_, Data_, Distance_, Normalized_>(my_builder->build_unique(normalized));
298 }
299};
300
301}
302
303#endif
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Wrapper around a builder with L2 normalization.
Definition L2Normalized.hpp:258
L2NormalizedMatrix< Index_, Data_, Normalized_, Matrix_ > NormalizedMatrix
Definition L2Normalized.hpp:263
L2NormalizedBuilder(std::shared_ptr< const Builder< Index_, Normalized_, Distance_, BuilderMatrix > > builder)
Definition L2Normalized.hpp:286
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition L2Normalized.hpp:295
std::conditional< std::is_base_of< Matrix_, NormalizedMatrix >::value, Matrix_, NormalizedMatrix >::type BuilderMatrix
Definition L2Normalized.hpp:280
Extractor for the L2NormalizedMatrix.
Definition L2Normalized.hpp:174
Wrapper around a matrix with L2 normalization.
Definition L2Normalized.hpp:212
Wrapper around a prebuilt index with L2 normalization.
Definition L2Normalized.hpp:124
L2NormalizedPrebuilt(std::unique_ptr< Prebuilt< Index_, Normalized_, Distance_ > > prebuilt)
Definition L2Normalized.hpp:129
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition L2Normalized.hpp:152
Wrapper around a search interface with L2 normalization.
Definition L2Normalized.hpp:60
L2NormalizedSearcher(std::unique_ptr< Searcher< Index_, Normalized_, Distance_ > > searcher, size_t num_dimensions)
Definition L2Normalized.hpp:66
Extractor interface for matrix data.
Definition Matrix.hpp:17
virtual const Normalized_ * next()=0
Interface for matrix data.
Definition Matrix.hpp:53
virtual std::unique_ptr< MatrixExtractor< Normalized_ > > new_extractor() const=0
virtual Index_ num_observations() const=0
virtual size_t num_dimensions() const=0
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:26
virtual Index_ num_observations() const =0
virtual size_t num_dimensions() const =0
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
virtual bool can_search_all() const
Definition Searcher.hpp:85
virtual Index_ search_all(Index_ i, Distance_ distance, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Searcher.hpp:106
virtual void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)=0
Collection of KNN algorithms.
Definition Bruteforce.hpp:23