knncolle_hnsw
knncolle bindings for HNSW
Loading...
Searching...
No Matches
Hnsw.hpp
1#ifndef KNNCOLLE_HNSW_HPP
2#define KNNCOLLE_HNSW_HPP
3
4#include <vector>
5#include <type_traits>
6#include <queue>
7#include <algorithm>
8#include <memory>
9#include <cstddef>
10#include <cstring>
11#include <cmath>
12
13#include "knncolle/knncolle.hpp"
14#include "sanisizer/sanisizer.hpp"
15#include "hnswlib/hnswalg.h"
16
17#include "distances.hpp"
18#include "utils.hpp"
19
26namespace knncolle_hnsw {
27
/**
 * @brief Options for `HnswBuilder` and `HnswPrebuilt`.
 *
 * All fields are forwarded unchanged to hnswlib when a `HnswPrebuilt` is built from a data matrix.
 */
struct HnswOptions {
    /**
     * Passed as the third argument of the `hnswlib::HierarchicalNSW` constructor.
     * Presumably hnswlib's `M` parameter (number of links per node) - see the hnswlib documentation.
     */
    int num_links = 16;

    /**
     * Passed as the fourth argument of the `hnswlib::HierarchicalNSW` constructor.
     * Presumably hnswlib's `ef_construction` parameter - see the hnswlib documentation.
     */
    int ef_construction = 200;

    /**
     * Passed to `hnswlib::HierarchicalNSW::setEf()` once all observations have been added to the index.
     */
    int ef_search = 10;
};
56
60template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
61class HnswPrebuilt;
62
63template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
64class HnswSearcher final : public knncolle::Searcher<Index_, Data_, Distance_> {
65private:
66 const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& my_parent;
67
68 std::priority_queue<std::pair<HnswData_, hnswlib::labeltype> > my_queue;
69
70 static constexpr bool same_internal_data = std::is_same<Data_, HnswData_>::value;
71 std::vector<HnswData_> my_buffer;
72
73public:
74 HnswSearcher(const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& parent) : my_parent(parent) {
75 if constexpr(!same_internal_data) {
76 sanisizer::resize(my_buffer, my_parent.my_dim);
77 }
78 }
79
80private:
81 void normalize_distances(std::vector<Distance_>& output_distances) const {
82 switch(my_parent.my_normalize_method) {
83 case DistanceNormalizeMethod::SQRT:
84 for (auto& d : output_distances) {
85 d = std::sqrt(d);
86 }
87 break;
88 case DistanceNormalizeMethod::CUSTOM:
89 for (auto& d : output_distances) {
90 d = my_parent.my_custom_normalize(d);
91 }
92 break;
93 case DistanceNormalizeMethod::NONE:
94 break;
95 }
96 }
97
98public:
99 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
100 my_buffer = my_parent.my_index.template getDataByLabel<HnswData_>(i);
101 Index_ kp1 = k + 1;
102 my_queue = my_parent.my_index.searchKnn(my_buffer.data(), kp1); // +1, as it forgets to discard 'self'.
103
104 if (output_indices) {
105 output_indices->clear();
106 output_indices->reserve(kp1);
107 }
108 if (output_distances) {
109 output_distances->clear();
110 output_distances->reserve(kp1);
111 }
112
113 bool self_found = false;
114 hnswlib::labeltype icopy = i;
115 while (!my_queue.empty()) {
116 const auto& top = my_queue.top();
117 if (!self_found && top.second == icopy) {
118 self_found = true;
119 } else {
120 if (output_indices) {
121 output_indices->push_back(top.second);
122 }
123 if (output_distances) {
124 output_distances->push_back(top.first);
125 }
126 }
127 my_queue.pop();
128 }
129
130 if (output_indices) {
131 std::reverse(output_indices->begin(), output_indices->end());
132 }
133 if (output_distances) {
134 std::reverse(output_distances->begin(), output_distances->end());
135 }
136
137 // Just in case we're full of ties at duplicate points, such that 'c'
138 // is not in the set. Note that, if self_found=false, we must have at
139 // least 'K+2' points for 'c' to not be detected as its own neighbor.
140 // Thus there is no need to worry whether we are popping off a non-'c'
141 // element and then returning fewer elements than expected.
142 if (!self_found) {
143 if (output_indices) {
144 output_indices->pop_back();
145 }
146 if (output_distances) {
147 output_distances->pop_back();
148 }
149 }
150
151 if (output_distances) {
152 normalize_distances(*output_distances);
153 }
154 }
155
156private:
157 void search_raw(const HnswData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
158 k = std::min(k, my_parent.my_obs);
159 my_queue = my_parent.my_index.searchKnn(query, k);
160
161 if (output_indices) {
162 output_indices->resize(k);
163 }
164 if (output_distances) {
165 output_distances->resize(k);
166 }
167
168 auto position = k;
169 while (!my_queue.empty()) {
170 const auto& top = my_queue.top();
171 --position;
172 if (output_indices) {
173 (*output_indices)[position] = top.second;
174 }
175 if (output_distances) {
176 (*output_distances)[position] = top.first;
177 }
178 my_queue.pop();
179 }
180
181 if (output_distances) {
182 normalize_distances(*output_distances);
183 }
184 }
185
186public:
187 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
188 if constexpr(same_internal_data) {
189 my_queue = my_parent.my_index.searchKnn(query, k);
190 search_raw(query, k, output_indices, output_distances);
191 } else {
192 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
193 search_raw(my_buffer.data(), k, output_indices, output_distances);
194 }
195 }
196};
197
198template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
199class HnswPrebuilt : public knncolle::Prebuilt<Index_, Data_, Distance_> {
200public:
201 template<class Matrix_>
202 HnswPrebuilt(const Matrix_& data, const DistanceConfig<Distance_, HnswData_>& distance_config, const HnswOptions& options) :
203 my_dim(data.num_dimensions()),
204 my_obs(data.num_observations()),
205 my_space(distance_config.create(my_dim)),
206 my_normalize_method(distance_config.normalize_method),
207 my_custom_normalize(distance_config.custom_normalize),
208 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
209 {
210 auto work = data.new_known_extractor();
211 if constexpr(std::is_same<Data_, HnswData_>::value) {
212 for (Index_ i = 0; i < my_obs; ++i) {
213 auto ptr = work->next();
214 my_index.addPoint(ptr, i);
215 }
216 } else {
217 auto incoming = sanisizer::create<std::vector<HnswData_> >(my_dim);
218 for (Index_ i = 0; i < my_obs; ++i) {
219 auto ptr = work->next();
220 std::copy_n(ptr, my_dim, incoming.begin());
221 my_index.addPoint(incoming.data(), i);
222 }
223 }
224
225 my_index.setEf(options.ef_search);
226 return;
227 }
228
229private:
230 std::size_t my_dim;
231 Index_ my_obs;
232
233 // The following must be a pointer for polymorphism, but also so that
234 // references to the object in my_index are still valid after copying.
235 std::shared_ptr<hnswlib::SpaceInterface<HnswData_> > my_space;
236
237 DistanceNormalizeMethod my_normalize_method;
238 std::function<Distance_(Distance_)> my_custom_normalize;
239
240 hnswlib::HierarchicalNSW<HnswData_> my_index;
241
242 friend class HnswSearcher<Index_, Data_, Distance_, HnswData_>;
243
244public:
245 std::size_t num_dimensions() const {
246 return my_dim;
247 }
248
249 Index_ num_observations() const {
250 return my_obs;
251 }
252
253public:
254 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize() const {
255 return initialize_known();
256 }
257
258 auto initialize_known() const {
259 return std::make_unique<HnswSearcher<Index_, Data_, Distance_, HnswData_> >(*this);
260 }
261
262public:
263 void save(const std::string& prefix) const {
264 knncolle::quick_save(prefix + "ALGORITHM", save_name, std::strlen(save_name));
265 knncolle::quick_save(prefix + "num_obs", &my_obs, 1);
266 knncolle::quick_save(prefix + "num_dim", &my_dim, 1);
267
269 knncolle::quick_save(prefix + "type", &type, 1);
270
271 const char* distname = get_distance_name(my_space.get());;
272 knncolle::quick_save(prefix + "distance", distname, std::strlen(distname));
273 knncolle::quick_save(prefix + "normalize", &my_normalize_method, 1);
274
275 // Custom normalization functions.
276 auto& datafunc = custom_save_for_hnsw_data<HnswData_>();
277 if (datafunc) {
278 datafunc(prefix);
279 }
280
281 auto& distfunc = custom_save_for_hnsw_distance<HnswData_>();
282 if (std::strcmp(distname, "unknown") == 0 && distfunc) {
283 distfunc(prefix, my_space.get());
284 }
285
286 auto& normfunc = custom_save_for_hnsw_normalize<Distance_>();
287 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM && normfunc) {
288 normfunc(prefix, my_custom_normalize);
289 }
290
291 // Dear God, make saveIndex() const.
292 auto index_ptr = const_cast<hnswlib::HierarchicalNSW<HnswData_>*>(&my_index);
293 index_ptr->saveIndex(prefix + "index");
294 }
295
296 HnswPrebuilt(const std::string& prefix) :
297 my_dim([&]() {
298 std::size_t dim;
299 knncolle::quick_load(prefix + "num_dim", &dim, 1);
300 return dim;
301 }()),
302
303 my_obs([&]() {
304 Index_ obs;
305 knncolle::quick_load(prefix + "num_obs", &obs, 1);
306 return obs;
307 }()),
308
309 my_space([&]() {
310 std::string method = knncolle::quick_load_as_string(prefix + "distance");
311
312 if constexpr(std::is_same<HnswData_, float>::value) {
313 if (method == "l2") {
314 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(new hnswlib::L2Space(my_dim));
315 }
316 }
317 if (method == "squared_euclidean") {
318 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(new SquaredEuclideanDistance<HnswData_>(my_dim));
319 } else if (method == "manhattan") {
320 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(new ManhattanDistance<HnswData_>(my_dim));
321 }
322
323 auto& loadfun = custom_load_for_hnsw_distance<HnswData_>();
324 if (!loadfun) {
325 throw std::runtime_error("no loader provided for an unknown distance");
326 }
327 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(loadfun(prefix, my_dim));
328 }()),
329
330 my_normalize_method([&]() {
332 knncolle::quick_load(prefix + "normalize", &norm, 1);
333 return norm;
334 }()),
335
336 my_index(my_space.get(), prefix + "index")
337
338 {
339 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM) {
340 auto& normfun = custom_load_for_hnsw_normalize<Distance_>();
341 if (!normfun) {
342 throw std::runtime_error("no loader provided for an unknown normalization");
343 }
344 my_custom_normalize = normfun(prefix);
345 }
346 }
347};
376template<
377 typename Index_,
378 typename Data_,
379 typename Distance_,
380 class Matrix_ = knncolle::Matrix<Index_, Data_>,
381 typename HnswData_ = float
382>
383class HnswBuilder : public knncolle::Builder<Index_, Data_, Distance_, Matrix_> {
384private:
385 DistanceConfig<Distance_, HnswData_> my_distance_config;
386 HnswOptions my_options;
387
388public:
394 my_distance_config(std::move(distance_config)),
395 my_options(std::move(options))
396 {
397 if (!my_distance_config.create) {
398 throw std::runtime_error("'distance_config.create' was not provided");
399 }
400 if (my_distance_config.normalize_method == DistanceNormalizeMethod::CUSTOM && !my_distance_config.custom_normalize) {
401 throw std::runtime_error("'distance_config.custom_normalize' was not provided");
402 }
403 }
404
409 HnswBuilder(DistanceConfig<Distance_, HnswData_> distance_config) : HnswBuilder(std::move(distance_config), {}) {}
410
415 return my_options;
416 }
417
418public:
419 knncolle::Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
420 return build_known_raw(data);
421 }
422
423public:
427 auto build_known_raw(const Matrix_& data) const {
428 return new HnswPrebuilt<Index_, Data_, Distance_, HnswData_>(data, my_distance_config, my_options);
429 }
430
434 auto build_known_unique(const Matrix_& data) const {
435 return std::unique_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
436 }
437
441 auto build_known_shared(const Matrix_& data) const {
442 return std::shared_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
443 }
444};
445
446}
447
448#endif
Perform an approximate nearest neighbor search with HNSW.
Definition Hnsw.hpp:383
auto build_known_shared(const Matrix_ &data) const
Definition Hnsw.hpp:441
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config)
Definition Hnsw.hpp:409
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config, HnswOptions options)
Definition Hnsw.hpp:393
auto build_known_raw(const Matrix_ &data) const
Definition Hnsw.hpp:427
auto build_known_unique(const Matrix_ &data) const
Definition Hnsw.hpp:434
HnswOptions & get_options()
Definition Hnsw.hpp:414
Distance classes for HNSW.
knncolle bindings for HNSW search.
Definition distances.hpp:13
DistanceNormalizeMethod
Definition distances.hpp:22
const char * get_distance_name(const hnswlib::SpaceInterface< HnswData_ > *distance)
Definition distances.hpp:206
std::string quick_load_as_string(const std::string &path)
NumericType get_numeric_type()
void quick_load(const std::string &path, Input_ *const contents, const Length_ length)
void quick_save(const std::string &path, const Input_ *const contents, const Length_ length)
Distance configuration for the HNSW index.
Definition distances.hpp:31
DistanceNormalizeMethod normalize_method
Definition distances.hpp:40
std::function< hnswlib::SpaceInterface< HnswData_ > *(std::size_t)> create
Definition distances.hpp:35
std::function< Distance_(Distance_)> custom_normalize
Definition distances.hpp:46
Options for HnswBuilder and HnswPrebuilt.
Definition Hnsw.hpp:34
int num_links
Definition Hnsw.hpp:40
int ef_construction
Definition Hnsw.hpp:47
int ef_search
Definition Hnsw.hpp:54