knncolle_hnsw
knncolle bindings for HNSW
Loading...
Searching...
No Matches
knncolle_hnsw.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_HNSW_HPP
2#define KNNCOLLE_HNSW_HPP
3
4#include <vector>
5#include <type_traits>
6#include <queue>
7#include <algorithm>
8#include <memory>
9
10#include "knncolle/knncolle.hpp"
11#include "hnswlib/hnswalg.h"
12
13#include "distances.hpp"
14
25namespace knncolle_hnsw {
26
37template<typename Dim_ = int, typename InternalData_ = float>
65
// Forward declaration: HnswSearcher (defined next) stores a pointer to the
// HnswPrebuilt that created it, so the name must be visible before the full
// class definition further down.
66template<typename Dim_, typename Index_, typename Float_, typename InternalData_>
67class HnswPrebuilt;
68
79template<typename Dim_, typename Index_, typename Float_, typename InternalData_>
80class HnswSearcher : public knncolle::Searcher<Index_, Float_> {
81private:
83
84 std::priority_queue<std::pair<InternalData_, hnswlib::labeltype> > my_queue;
85 std::vector<InternalData_> my_buffer;
86
87 static constexpr bool same_internal = std::is_same<Float_, InternalData_>::value;
88
89public:
93 HnswSearcher(const HnswPrebuilt<Dim_, Index_, Float_, InternalData_>* parent) : my_parent(parent) {
94 if constexpr(!same_internal) {
95 my_buffer.resize(my_parent->my_dim);
96 }
97 }
102public:
103 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
104 my_buffer = my_parent->my_index.template getDataByLabel<InternalData_>(i);
105 Index_ kp1 = k + 1;
106 my_queue = my_parent->my_index.searchKnn(my_buffer.data(), kp1); // +1, as it forgets to discard 'self'.
107
108 if (output_indices) {
109 output_indices->clear();
110 output_indices->reserve(kp1);
111 }
112 if (output_distances) {
113 output_distances->clear();
114 output_distances->reserve(kp1);
115 }
116
117 bool self_found = false;
118 hnswlib::labeltype icopy = i;
119 while (!my_queue.empty()) {
120 const auto& top = my_queue.top();
121 if (!self_found && top.second == icopy) {
122 self_found = true;
123 } else {
124 if (output_indices) {
125 output_indices->push_back(top.second);
126 }
127 if (output_distances) {
128 output_distances->push_back(top.first);
129 }
130 }
131 my_queue.pop();
132 }
133
134 if (output_indices) {
135 std::reverse(output_indices->begin(), output_indices->end());
136 }
137 if (output_distances) {
138 std::reverse(output_distances->begin(), output_distances->end());
139 }
140
141 // Just in case we're full of ties at duplicate points, such that 'c'
142 // is not in the set. Note that, if self_found=false, we must have at
143 // least 'K+2' points for 'c' to not be detected as its own neighbor.
144 // Thus there is no need to worry whether we are popping off a non-'c'
145 // element and then returning fewer elements than expected.
146 if (!self_found) {
147 if (output_indices) {
148 output_indices->pop_back();
149 }
150 if (output_distances) {
151 output_distances->pop_back();
152 }
153 }
154
155 if (output_distances && my_parent->my_normalize) {
156 for (auto& d : *output_distances) {
157 d = my_parent->my_normalize(d);
158 }
159 }
160 }
161
162private:
163 void search_raw(const InternalData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
164 k = std::min(k, my_parent->my_obs);
165 my_queue = my_parent->my_index.searchKnn(query, k);
166
167 if (output_indices) {
168 output_indices->resize(k);
169 }
170 if (output_distances) {
171 output_distances->resize(k);
172 }
173
174 size_t position = k;
175 while (!my_queue.empty()) {
176 const auto& top = my_queue.top();
177 --position;
178 if (output_indices) {
179 (*output_indices)[position] = top.second;
180 }
181 if (output_distances) {
182 (*output_distances)[position] = top.first;
183 }
184 my_queue.pop();
185 }
186
187 if (output_distances && my_parent->my_normalize) {
188 for (auto& d : *output_distances) {
189 d = my_parent->my_normalize(d);
190 }
191 }
192 }
193
194public:
195 void search(const Float_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
196 if constexpr(same_internal) {
197 my_queue = my_parent->my_index.searchKnn(query, k);
198 search_raw(query, k, output_indices, output_distances);
199 } else {
200 std::copy_n(query, my_parent->my_dim, my_buffer.begin());
201 search_raw(my_buffer.data(), k, output_indices, output_distances);
202 }
203 }
204};
205
219template<typename Dim_, typename Index_, typename Float_, typename InternalData_>
220class HnswPrebuilt : public knncolle::Prebuilt<Dim_, Index_, Float_> {
221public:
225 template<class Matrix_>
226 HnswPrebuilt(const Matrix_& data, const HnswOptions<Dim_, InternalData_>& options) :
227 my_dim(data.num_dimensions()),
228 my_obs(data.num_observations()),
229 my_space([&]() {
230 if (options.distance_options.create) {
231 return options.distance_options.create(my_dim);
232 } else if constexpr(std::is_same<InternalData_, float>::value) {
233 return static_cast<hnswlib::SpaceInterface<InternalData_>*>(new hnswlib::L2Space(my_dim));
234 } else {
235 return static_cast<hnswlib::SpaceInterface<InternalData_>*>(new SquaredEuclideanDistance<InternalData_>(my_dim));
236 }
237 }()),
238 my_normalize([&]() {
239 if (options.distance_options.normalize) {
240 return options.distance_options.normalize;
241 } else if (options.distance_options.create) {
242 return std::function<InternalData_(InternalData_)>();
243 } else {
244 return std::function<InternalData_(InternalData_)>([](InternalData_ x) -> InternalData_ { return std::sqrt(x); });
245 }
246 }()),
247 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
248 {
249 typedef typename Matrix_::data_type Data_;
250 auto work = data.create_workspace();
251 if constexpr(std::is_same<Data_, InternalData_>::value) {
252 for (Index_ i = 0; i < my_obs; ++i) {
253 auto ptr = data.get_observation(work);
254 my_index.addPoint(ptr, i);
255 }
256 } else {
257 std::vector<InternalData_> incoming(my_dim);
258 for (Index_ i = 0; i < my_obs; ++i) {
259 auto ptr = data.get_observation(work);
260 std::copy_n(ptr, my_dim, incoming.begin());
261 my_index.addPoint(incoming.data(), i);
262 }
263 }
264
265 my_index.setEf(options.ef_search);
266 return;
267 }
272private:
273 Dim_ my_dim;
274 Index_ my_obs;
275
276 // The following must be a pointer for polymorphism, but also so that
277 // references to the object in my_index are still valid after copying.
278 std::shared_ptr<hnswlib::SpaceInterface<InternalData_> > my_space;
279
280 std::function<InternalData_(InternalData_)> my_normalize;
281 hnswlib::HierarchicalNSW<InternalData_> my_index;
282
283 friend class HnswSearcher<Dim_, Index_, Float_, InternalData_>;
284
285public:
286 Dim_ num_dimensions() const {
287 return my_dim;
288 }
289
290 Index_ num_observations() const {
291 return my_obs;
292 }
293
294 std::unique_ptr<knncolle::Searcher<Index_, Float_> > initialize() const {
295 return std::make_unique<HnswSearcher<Dim_, Index_, Float_, InternalData_> >(this);
296 }
297};
298
// Builder for an approximate nearest neighbor search with HNSW.
//
// NOTE(review): this listing has dropped several of the original source
// lines (323, 331, 350 and 355-356, going by the gaps in the numbering), so
// the class is not complete as shown. Restore them from the real header
// before relying on this text.
322template<
// NOTE(review): the declaration of the 'Matrix_' template parameter (source
// line 323) is missing here, yet 'Matrix_' is used in the base class below.
324 typename Float_ = double,
325 typename InternalData_ = float>
326class HnswBuilder : public knncolle::Builder<Matrix_, Float_> {
327public:
// NOTE(review): the 'Options' typedef (source line 331) is missing here. The
// symbol index at the end of this page gives it as
// HnswOptions<typename Matrix_::dimension_type, InternalData_>.
332
333private:
// Options that this builder was constructed with.
334 Options my_options;
335
336public:
// Construct a builder from a set of options; the argument is moved into the member.
340 HnswBuilder(Options options) : my_options(std::move(options)) {}
341
// Default constructor; my_options is default-initialized.
345 HnswBuilder() = default;
346
// NOTE(review): the signature on source line 350 is missing here. The symbol
// index at the end of this page gives it as 'Options & get_options()'.
351 return my_options;
352 }
353
354public:
// NOTE(review): source lines 355-356 are missing; presumably a method that
// constructs the HnswPrebuilt index. Confirm against the real header.
357 }
358};
359
360}
361
362#endif
Perform an approximate nearest neighbor search with HNSW.
Definition knncolle_hnsw.hpp:326
HnswOptions< typename Matrix_::dimension_type, InternalData_ > Options
Definition knncolle_hnsw.hpp:331
Options & get_options()
Definition knncolle_hnsw.hpp:350
HnswBuilder(Options options)
Definition knncolle_hnsw.hpp:340
Prebuilt index for an Hnsw search.
Definition knncolle_hnsw.hpp:220
Searcher on an Hnsw index.
Definition knncolle_hnsw.hpp:80
knncolle bindings for HNSW search.
Definition distances.hpp:12
Distance options for the HNSW index.
Definition distances.hpp:21
Options for HnswBuilder and HnswPrebuilt.
Definition knncolle_hnsw.hpp:38
DistanceOptions< Dim_, InternalData_ > distance_options
Definition knncolle_hnsw.hpp:63
int ef_construction
Definition knncolle_hnsw.hpp:51
int ef_search
Definition knncolle_hnsw.hpp:58
int num_links
Definition knncolle_hnsw.hpp:44