knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
L2Normalized.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_L2_NORMALIZED_HP
2#define KNNCOLLE_L2_NORMALIZED_HP
3
4#include <vector>
5#include <cmath>
6#include <memory>
7
8#include "Searcher.hpp"
9#include "Prebuilt.hpp"
10#include "Builder.hpp"
11#include "MockMatrix.hpp"
12
18namespace knncolle {
19
namespace internal {

/*
 * L2-normalizes a vector of length `ndim`.
 *
 * @param ptr Pointer to the input vector, of length `ndim`.
 * @param ndim Number of dimensions.
 * @param buffer Pointer to an output array of length `ndim`.
 *
 * @return `buffer`, filled with the normalized values; or `ptr` itself if every element is zero,
 * in which case `buffer` is left untouched. Callers must therefore always use the returned pointer.
 *
 * The sum of squares is normally computed directly. If it underflows to zero or overflows to
 * infinity even though the vector is non-zero and finite, the values are first rescaled by their
 * largest magnitude so that a meaningful norm can still be obtained. The ordinary path is unchanged.
 */
template<typename Float_>
const Float_* l2norm(const Float_* ptr, size_t ndim, Float_* buffer) {
    Float_ l2 = 0;
    for (size_t d = 0; d < ndim; ++d) {
        auto val = ptr[d];
        l2 += val * val;
    }

    if (l2 == 0 || std::isinf(l2)) {
        // Either a genuinely zero vector, or the squares underflowed/overflowed.
        Float_ scale = 0;
        for (size_t d = 0; d < ndim; ++d) {
            auto mag = std::abs(ptr[d]);
            if (mag > scale) {
                scale = mag;
            }
        }

        if (scale == 0) {
            // Every element is zero; there is nothing to normalize.
            return ptr;
        }

        if (std::isfinite(scale)) {
            // After dividing by the largest magnitude, every value lies in [-1, 1] and at least one
            // has magnitude 1, so the rescaled sum of squares lies in [1, ndim]. It can neither
            // underflow to zero nor overflow.
            Float_ rescaled = 0;
            for (size_t d = 0; d < ndim; ++d) {
                auto val = ptr[d] / scale;
                rescaled += val * val;
            }
            rescaled = std::sqrt(rescaled);
            for (size_t d = 0; d < ndim; ++d) {
                buffer[d] = (ptr[d] / scale) / rescaled;
            }
            return buffer;
        }

        // Infinite input values: fall through to the direct computation, as before.
    }

    l2 = std::sqrt(l2);
    for (size_t d = 0; d < ndim; ++d) {
        buffer[d] = ptr[d] / l2;
    }
    return buffer;
}

}
59template<typename Index_, typename Float_>
60class L2NormalizedSearcher : public Searcher<Index_, Float_> {
61public:
66 L2NormalizedSearcher(std::unique_ptr<Searcher<Index_, Float_> > searcher, size_t num_dimensions) :
67 my_searcher(std::move(searcher)),
68 buffer(num_dimensions)
69 {}
70
71private:
72 std::unique_ptr<Searcher<Index_, Float_> > my_searcher;
73 std::vector<Float_> buffer;
74
78public:
79 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
80 my_searcher->search(i, k, output_indices, output_distances);
81 }
82
83 void search(const Float_* ptr, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
84 auto normalized = internal::l2norm(ptr, buffer.size(), buffer.data());
85 my_searcher->search(normalized, k, output_indices, output_distances);
86 }
87
88public:
89 bool can_search_all() const {
90 return my_searcher->can_search_all();
91 }
92
93 Index_ search_all(Index_ i, Float_ threshold, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
94 return my_searcher->search_all(i, threshold, output_indices, output_distances);
95 }
96
97 Index_ search_all(const Float_* ptr, Float_ threshold, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
98 auto normalized = internal::l2norm(ptr, buffer.size(), buffer.data());
99 return my_searcher->search_all(normalized, threshold, output_indices, output_distances);
100 }
104};
105
116template<typename Dim_, typename Index_, typename Float_>
117class L2NormalizedPrebuilt : public Prebuilt<Dim_, Index_, Float_> {
118public:
122 L2NormalizedPrebuilt(std::unique_ptr<Prebuilt<Dim_, Index_, Float_> > prebuilt) : my_prebuilt(std::move(prebuilt)) {}
123
124private:
125 std::unique_ptr<Prebuilt<Dim_, Index_, Float_> > my_prebuilt;
126
127public:
131 Index_ num_observations() const {
132 return my_prebuilt->num_observations();
133 }
134
135 Dim_ num_dimensions() const {
136 return my_prebuilt->num_dimensions();
137 }
145 std::unique_ptr<Searcher<Index_, Float_> > initialize() const {
146 return std::make_unique<L2NormalizedSearcher<Index_, Float_> >(my_prebuilt->initialize(), my_prebuilt->num_dimensions());
147 }
148};
149
// Wrapper around a matrix that returns each observation in L2-normalized form.
// NOTE(review): the `class L2NormalizedMatrix` declaration line is absent from this listing.
159template<class Matrix_ = SimpleMatrix<int, int, double> >
164public:
// `matrix` is held by reference, not copied: it must outlive this wrapper.
165 L2NormalizedMatrix(const Matrix_& matrix) : my_matrix(matrix) {}
166
167private:
168 const Matrix_& my_matrix;
169
170public:
// Type aliases are forwarded from the wrapped matrix.
171 typedef typename Matrix_::data_type data_type;
172 typedef typename Matrix_::index_type index_type;
173 typedef typename Matrix_::dimension_type dimension_type;
174
175 dimension_type num_dimensions() const {
176 return my_matrix.num_dimensions();
177 }
178
179 index_type num_observations() const {
180 return my_matrix.num_observations();
181 }
182
// Pairs the wrapped matrix's workspace with a buffer of `n` values for the normalized observation.
// NOTE(review): `inner` is default-constructed here rather than obtained from
// `my_matrix.create_workspace()`; this assumes Matrix_::Workspace is default-constructible
// and needs no matrix-specific setup - confirm for non-SimpleMatrix types.
183 struct Workspace {
184 Workspace(size_t n) : normalized(n) {}
185 typename Matrix_::Workspace inner;
186 std::vector<data_type> normalized;
187 };
188
// The normalization buffer is sized to the number of dimensions of the wrapped matrix.
189 Workspace create_workspace() const {
190 return Workspace(my_matrix.num_dimensions());
191 }
192
// Fetches an observation from the wrapped matrix via `workspace.inner` (presumably the next one
// in sequence - confirm against Matrix_::get_observation) and L2-normalizes it.
// Returns a pointer into `workspace.normalized`, or the wrapped matrix's own pointer if the
// observation is all zeros. Presumably only valid until the next call with the same workspace.
193 const data_type* get_observation(Workspace& workspace) const {
194 auto ptr = my_matrix.get_observation(workspace.inner);
195 size_t ndim = workspace.normalized.size();
196 return internal::l2norm(ptr, ndim, workspace.normalized.data());
197 }
201};
202
// Wrapper around a builder with L2 normalization. The wrapped builder operates on
// L2NormalizedMatrix<Matrix_>, so it sees each observation in normalized form.
// NOTE(review): the `build_raw(const Matrix_&) const` member (original lines 230-235) is absent from this listing.
211template<class Matrix_ = SimpleMatrix<int, int, double>, typename Float_ = double>
212class L2NormalizedBuilder : public Builder<Matrix_, Float_> {
213public:
// Takes ownership of `builder`.
218 L2NormalizedBuilder(std::unique_ptr<Builder<L2NormalizedMatrix<Matrix_>, Float_> > builder) : my_builder(std::move(builder)) {}
219
// Takes ownership of the raw pointer: it is stored in a std::unique_ptr and deleted with this wrapper.
224 L2NormalizedBuilder(Builder<L2NormalizedMatrix<Matrix_>, Float_>* builder) : my_builder(builder) {}
225
226private:
227 std::unique_ptr<Builder<L2NormalizedMatrix<Matrix_>, Float_> > my_builder;
228
229public:
236};
237
238}
239
240#endif
Interface to build nearest-neighbor indices.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:22
Wrapper around a builder with L2 normalization.
Definition L2Normalized.hpp:212
L2NormalizedBuilder(Builder< L2NormalizedMatrix< Matrix_ >, Float_ > *builder)
Definition L2Normalized.hpp:224
L2NormalizedBuilder(std::unique_ptr< Builder< L2NormalizedMatrix< Matrix_ >, Float_ > > builder)
Definition L2Normalized.hpp:218
Prebuilt< typename Matrix_::dimension_type, typename Matrix_::index_type, Float_ > * build_raw(const Matrix_ &data) const
Definition L2Normalized.hpp:233
Wrapper around a matrix with L2 normalization.
Definition L2Normalized.hpp:160
Wrapper around a prebuilt index with L2 normalization.
Definition L2Normalized.hpp:117
std::unique_ptr< Searcher< Index_, Float_ > > initialize() const
Definition L2Normalized.hpp:145
L2NormalizedPrebuilt(std::unique_ptr< Prebuilt< Dim_, Index_, Float_ > > prebuilt)
Definition L2Normalized.hpp:122
Wrapper around a search interface with L2 normalization.
Definition L2Normalized.hpp:60
L2NormalizedSearcher(std::unique_ptr< Searcher< Index_, Float_ > > searcher, size_t num_dimensions)
Definition L2Normalized.hpp:66
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:28
virtual Index_ num_observations() const =0
virtual Dim_ num_dimensions() const =0
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
virtual void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)=0
virtual bool can_search_all() const
Definition Searcher.hpp:80
virtual Index_ search_all(Index_ i, Float_ distance, std::vector< Index_ > *output_indices, std::vector< Float_ > *output_distances)
Definition Searcher.hpp:101
Collection of KNN algorithms.
Definition Bruteforce.hpp:22