knncolle_hnsw
knncolle bindings for HNSW
Loading...
Searching...
No Matches
Hnsw.hpp
1#ifndef KNNCOLLE_HNSW_HNSW_HPP
2#define KNNCOLLE_HNSW_HNSW_HPP
3
4#include <vector>
5#include <type_traits>
6#include <queue>
7#include <algorithm>
8#include <memory>
9#include <cstddef>
10#include <cstring>
11#include <cmath>
12#include <filesystem>
13
14#include "knncolle/knncolle.hpp"
15#include "sanisizer/sanisizer.hpp"
16#include "hnswlib/hnswalg.h"
17
18#include "distances.hpp"
19#include "utils.hpp"
20
27namespace knncolle_hnsw {
28
/**
 * Name written to the `ALGORITHM` file by `HnswPrebuilt::save()` to identify HNSW indices.
 * Declared `inline constexpr` (without `static`) so that every translation unit including
 * this header shares a single definition, rather than each getting its own internal copy.
 */
inline constexpr const char* hnsw_prebuilt_save_name = "knncolle_hnsw::Hnsw";
33
// Number of links per node; HnswPrebuilt passes this to the hnswlib::HierarchicalNSW
// constructor right after the maximum number of elements (hnswlib's `M`, presumably).
46 int num_links = 16;
47
// Passed by HnswPrebuilt to the hnswlib::HierarchicalNSW constructor after `num_links`,
// i.e., the size of the candidate list used while building the index.
53 int ef_construction = 200;
54
// Passed by HnswPrebuilt to `setEf()` once all observations have been added.
// NOTE(review): HnswPrebuilt::save() does not record this value and the loading constructor
// does not call setEf(), so a loaded index does not reapply it - confirm this is intended.
60 int ef_search = 10;
61};
62
66template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
67class HnswPrebuilt;
68
69template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
70class HnswSearcher final : public knncolle::Searcher<Index_, Data_, Distance_> {
71private:
72 const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& my_parent;
73
74 std::priority_queue<std::pair<HnswData_, hnswlib::labeltype> > my_queue;
75
76 static constexpr bool same_internal_data = std::is_same<Data_, HnswData_>::value;
77 std::vector<HnswData_> my_buffer;
78
79public:
80 HnswSearcher(const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& parent) : my_parent(parent) {
81 if constexpr(!same_internal_data) {
82 sanisizer::resize(my_buffer, my_parent.my_dim);
83 }
84 }
85
86private:
87 void normalize_distances(std::vector<Distance_>& output_distances) const {
88 switch(my_parent.my_normalize_method) {
89 case DistanceNormalizeMethod::SQRT:
90 for (auto& d : output_distances) {
91 d = std::sqrt(d);
92 }
93 break;
94 case DistanceNormalizeMethod::CUSTOM:
95 for (auto& d : output_distances) {
96 d = my_parent.my_custom_normalize(d);
97 }
98 break;
99 case DistanceNormalizeMethod::NONE:
100 break;
101 }
102 }
103
104public:
105 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
106 my_buffer = my_parent.my_index.template getDataByLabel<HnswData_>(i);
107 Index_ kp1 = k + 1;
108 my_queue = my_parent.my_index.searchKnn(my_buffer.data(), kp1); // +1, as it forgets to discard 'self'.
109
110 if (output_indices) {
111 output_indices->clear();
112 output_indices->reserve(kp1);
113 }
114 if (output_distances) {
115 output_distances->clear();
116 output_distances->reserve(kp1);
117 }
118
119 bool self_found = false;
120 hnswlib::labeltype icopy = i;
121 while (!my_queue.empty()) {
122 const auto& top = my_queue.top();
123 if (!self_found && top.second == icopy) {
124 self_found = true;
125 } else {
126 if (output_indices) {
127 output_indices->push_back(top.second);
128 }
129 if (output_distances) {
130 output_distances->push_back(top.first);
131 }
132 }
133 my_queue.pop();
134 }
135
136 if (output_indices) {
137 std::reverse(output_indices->begin(), output_indices->end());
138 }
139 if (output_distances) {
140 std::reverse(output_distances->begin(), output_distances->end());
141 }
142
143 // Just in case we're full of ties at duplicate points, such that 'c'
144 // is not in the set. Note that, if self_found=false, we must have at
145 // least 'K+2' points for 'c' to not be detected as its own neighbor.
146 // Thus there is no need to worry whether we are popping off a non-'c'
147 // element and then returning fewer elements than expected.
148 if (!self_found) {
149 if (output_indices) {
150 output_indices->pop_back();
151 }
152 if (output_distances) {
153 output_distances->pop_back();
154 }
155 }
156
157 if (output_distances) {
158 normalize_distances(*output_distances);
159 }
160 }
161
162private:
163 void search_raw(const HnswData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
164 k = std::min(k, my_parent.my_obs);
165 my_queue = my_parent.my_index.searchKnn(query, k);
166
167 if (output_indices) {
168 output_indices->resize(k);
169 }
170 if (output_distances) {
171 output_distances->resize(k);
172 }
173
174 auto position = k;
175 while (!my_queue.empty()) {
176 const auto& top = my_queue.top();
177 --position;
178 if (output_indices) {
179 (*output_indices)[position] = top.second;
180 }
181 if (output_distances) {
182 (*output_distances)[position] = top.first;
183 }
184 my_queue.pop();
185 }
186
187 if (output_distances) {
188 normalize_distances(*output_distances);
189 }
190 }
191
192public:
193 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
194 if constexpr(same_internal_data) {
195 my_queue = my_parent.my_index.searchKnn(query, k);
196 search_raw(query, k, output_indices, output_distances);
197 } else {
198 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
199 search_raw(my_buffer.data(), k, output_indices, output_distances);
200 }
201 }
202};
203
/**
 * @brief Prebuilt HNSW index.
 *
 * Holds a `hnswlib::HierarchicalNSW` index together with its distance space and the
 * distance normalization settings. It can be built from a data matrix or loaded from a
 * directory previously written by `save()`.
 *
 * @tparam Index_ Integer type for observation indices (also used as hnswlib labels).
 * @tparam Data_ Type of the input data.
 * @tparam Distance_ Type of the reported distances.
 * @tparam HnswData_ Type of the data stored inside the hnswlib index.
 */
204template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
205class HnswPrebuilt final : public knncolle::Prebuilt<Index_, Data_, Distance_> {
206public:
// Build the index from `data`: each observation is added in extraction order with its
// position as the hnswlib label, and `options.ef_search` is applied once building is done.
// When Data_ differs from HnswData_, each observation is first copied into a HnswData_ buffer.
207 template<class Matrix_>
208 HnswPrebuilt(const Matrix_& data, const DistanceConfig<Distance_, HnswData_>& distance_config, const HnswOptions& options) :
209 my_dim(data.num_dimensions()),
210 my_obs(data.num_observations()),
211 my_space(distance_config.create(my_dim)),
212 my_normalize_method(distance_config.normalize_method),
213 my_custom_normalize(distance_config.custom_normalize),
214 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
215 {
216 auto work = data.new_known_extractor();
217 if constexpr(std::is_same<Data_, HnswData_>::value) {
218 for (Index_ i = 0; i < my_obs; ++i) {
219 auto ptr = work->next();
220 my_index.addPoint(ptr, i);
221 }
222 } else {
223 auto incoming = sanisizer::create<std::vector<HnswData_> >(my_dim);
224 for (Index_ i = 0; i < my_obs; ++i) {
225 auto ptr = work->next();
226 std::copy_n(ptr, my_dim, incoming.begin());
227 my_index.addPoint(incoming.data(), i);
228 }
229 }
230
231 my_index.setEf(options.ef_search);
232 return;
233 }
234
// Members are initialized in this declaration order, which both constructors rely on
// (e.g., my_space is created from my_dim, and my_index is created from my_space).
235private:
236 std::size_t my_dim;
237 Index_ my_obs;
238
239 // The following must be a pointer for polymorphism, but also so that
240 // references to the object in my_index are still valid after copying.
241 std::shared_ptr<hnswlib::SpaceInterface<HnswData_> > my_space;
242
243 DistanceNormalizeMethod my_normalize_method;
244 std::function<Distance_(Distance_)> my_custom_normalize;
245
246 hnswlib::HierarchicalNSW<HnswData_> my_index;
247
// HnswSearcher reads the private members above (dimensions, index, normalization settings).
248 friend class HnswSearcher<Index_, Data_, Distance_, HnswData_>;
249
250public:
251 std::size_t num_dimensions() const {
252 return my_dim;
253 }
254
255 Index_ num_observations() const {
256 return my_obs;
257 }
258
259public:
// Returns a new HnswSearcher referencing this object, which must outlive it.
260 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize() const {
261 return initialize_known();
262 }
263
264 auto initialize_known() const {
265 return std::make_unique<HnswSearcher<Index_, Data_, Distance_, HnswData_> >(*this);
266 }
267
268public:
// Save the index into `dir`, writing the files ALGORITHM, NUM_OBS, NUM_DIM, TYPE, DISTANCE,
// NORMALIZE and INDEX, plus whatever any registered custom save functions write.
// NOTE(review): options.ef_search is not written, so it cannot be restored on load.
269 void save(const std::filesystem::path& dir) const {
270 knncolle::quick_save(dir / "ALGORITHM", hnsw_prebuilt_save_name, std::strlen(hnsw_prebuilt_save_name));
271 knncolle::quick_save(dir / "NUM_OBS", &my_obs, 1);
272 knncolle::quick_save(dir / "NUM_DIM", &my_dim, 1);
273
// Records the type used for HnswData_ (presumably a knncolle numeric type tag; `type` is
// declared on a line not shown here - TODO confirm).
275 knncolle::quick_save(dir / "TYPE", &type, 1);
276
277 const char* distname = get_distance_name(my_space.get());;
278 knncolle::quick_save(dir / "DISTANCE", distname, std::strlen(distname));
279 knncolle::quick_save(dir / "NORMALIZE", &my_normalize_method, 1);
280
281 // Custom normalization functions.
282 auto& datafunc = custom_save_for_hnsw_data<HnswData_>();
283 if (datafunc) {
284 datafunc(dir);
285 }
286
// The custom distance saver is only invoked for distances that get_distance_name() does not recognize.
287 auto& distfunc = custom_save_for_hnsw_distance<HnswData_>();
288 if (std::strcmp(distname, "unknown") == 0 && distfunc) {
289 distfunc(dir, my_space.get());
290 }
291
292 auto& normfunc = custom_save_for_hnsw_normalize<Distance_>();
293 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM && normfunc) {
294 normfunc(dir, my_custom_normalize);
295 }
296
297 // Dear God, make saveIndex() const.
// NOTE(review): passing a std::filesystem::path where hnswlib (presumably) expects a std::string
// relies on path's implicit conversion, which only exists when path::value_type is char
// (not on Windows); consider (dir / "INDEX").string().
298 auto index_ptr = const_cast<hnswlib::HierarchicalNSW<HnswData_>*>(&my_index);
299 index_ptr->saveIndex(dir / "INDEX");
300 }
301
// Load an index from a directory written by save(). Reads NUM_DIM, NUM_OBS, DISTANCE, NORMALIZE
// and INDEX; ALGORITHM and TYPE are not checked here.
// Throws std::runtime_error if DISTANCE is not a built-in distance and no custom distance loader
// is registered, or if NORMALIZE is CUSTOM and no custom normalization loader is registered.
// NOTE(review): unlike the build constructor, this never calls my_index.setEf(), so the loaded
// index uses whatever ef hnswlib's loader sets rather than the original ef_search.
302 HnswPrebuilt(const std::filesystem::path& dir) :
303 my_dim([&]() {
304 std::size_t dim;
305 knncolle::quick_load(dir / "NUM_DIM", &dim, 1);
306 return dim;
307 }()),
308
309 my_obs([&]() {
310 Index_ obs;
311 knncolle::quick_load(dir / "NUM_OBS", &obs, 1);
312 return obs;
313 }()),
314
// Recreate the distance space from its saved name: "l2" (only when HnswData_ is float),
// "squared_euclidean" or "manhattan"; anything else goes to the registered custom loader.
315 my_space([&]() {
316 std::string method = knncolle::quick_load_as_string(dir / "DISTANCE");
317
318 if constexpr(std::is_same<HnswData_, float>::value) {
319 if (method == "l2") {
320 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(new hnswlib::L2Space(my_dim));
321 }
322 }
323 if (method == "squared_euclidean") {
324 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(new SquaredEuclideanDistance<HnswData_>(my_dim));
325 } else if (method == "manhattan") {
326 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(new ManhattanDistance<HnswData_>(my_dim));
327 }
328
329 auto& loadfun = custom_load_for_hnsw_distance<HnswData_>();
330 if (!loadfun) {
331 throw std::runtime_error("no loader provided for an unknown distance");
332 }
333 return static_cast<hnswlib::SpaceInterface<HnswData_>*>(loadfun(dir, my_dim));
334 }()),
335
// Reads back the DistanceNormalizeMethod that save() wrote to NORMALIZE.
336 my_normalize_method([&]() {
338 knncolle::quick_load(dir / "NORMALIZE", &norm, 1);
339 return norm;
340 }()),
341
342 my_index(my_space.get(), dir / "INDEX")
343
344 {
345 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM) {
346 auto& normfun = custom_load_for_hnsw_normalize<Distance_>();
347 if (!normfun) {
348 throw std::runtime_error("no loader provided for an unknown normalization");
349 }
350 my_custom_normalize = normfun(dir);
351 }
352 }
353};
/**
 * @brief Perform an approximate nearest neighbor search with HNSW.
 *
 * Builds an `HnswPrebuilt` index from a data matrix using the stored distance configuration
 * and `HnswOptions`.
 *
 * @tparam Index_ Integer type for observation indices.
 * @tparam Data_ Type of the input data.
 * @tparam Distance_ Type of the reported distances.
 * @tparam Matrix_ Type of the input matrix, defaulting to `knncolle::Matrix<Index_, Data_>`.
 * @tparam HnswData_ Type of the data stored inside the hnswlib index, defaulting to `float`.
 */
382template<
383 typename Index_,
384 typename Data_,
385 typename Distance_,
386 class Matrix_ = knncolle::Matrix<Index_, Data_>,
387 typename HnswData_ = float
388>
389class HnswBuilder final : public knncolle::Builder<Index_, Data_, Distance_, Matrix_> {
390private:
391 DistanceConfig<Distance_, HnswData_> my_distance_config;
392 HnswOptions my_options;
393
394public:
400 my_distance_config(std::move(distance_config)),
401 my_options(std::move(options))
402 {
// Validate the configuration eagerly so that a missing 'create' (or a missing
// 'custom_normalize' when CUSTOM normalization is requested) fails at construction
// rather than when an index is built or searched.
403 if (!my_distance_config.create) {
404 throw std::runtime_error("'distance_config.create' was not provided");
405 }
406 if (my_distance_config.normalize_method == DistanceNormalizeMethod::CUSTOM && !my_distance_config.custom_normalize) {
407 throw std::runtime_error("'distance_config.custom_normalize' was not provided");
408 }
409 }
410
// Delegates with value-initialized options, i.e., the default member values of HnswOptions.
415 HnswBuilder(DistanceConfig<Distance_, HnswData_> distance_config) : HnswBuilder(std::move(distance_config), {}) {}
416
421 return my_options;
422 }
423
424public:
428 knncolle::Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
429 return build_known_raw(data);
430 }
435public:
// Returns a newly allocated HnswPrebuilt; the caller takes ownership of the raw pointer.
439 auto build_known_raw(const Matrix_& data) const {
440 return new HnswPrebuilt<Index_, Data_, Distance_, HnswData_>(data, my_distance_config, my_options);
441 }
442
// The build_known_raw() call inside decltype is unevaluated, so only one index is built.
// `I<...>` presumably strips the reference from the decltype result - TODO confirm.
446 auto build_known_unique(const Matrix_& data) const {
447 return std::unique_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
448 }
449
453 auto build_known_shared(const Matrix_& data) const {
454 return std::shared_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
455 }
456};
457
458}
459
460#endif
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const=0
Perform an approximate nearest neighbor search with HNSW.
Definition Hnsw.hpp:389
auto build_known_shared(const Matrix_ &data) const
Definition Hnsw.hpp:453
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config)
Definition Hnsw.hpp:415
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config, HnswOptions options)
Definition Hnsw.hpp:399
auto build_known_raw(const Matrix_ &data) const
Definition Hnsw.hpp:439
auto build_known_unique(const Matrix_ &data) const
Definition Hnsw.hpp:446
HnswOptions & get_options()
Definition Hnsw.hpp:420
Distance classes for HNSW.
knncolle bindings for HNSW search.
Definition distances.hpp:13
DistanceNormalizeMethod
Definition distances.hpp:22
const char * get_distance_name(const hnswlib::SpaceInterface< HnswData_ > *distance)
Definition distances.hpp:212
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
NumericType get_numeric_type()
std::string quick_load_as_string(const std::filesystem::path &path)
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Distance configuration for the HNSW index.
Definition distances.hpp:31
DistanceNormalizeMethod normalize_method
Definition distances.hpp:40
std::function< hnswlib::SpaceInterface< HnswData_ > *(std::size_t)> create
Definition distances.hpp:35
std::function< Distance_(Distance_)> custom_normalize
Definition distances.hpp:46
Options for HnswBuilder and HnswPrebuilt.
Definition Hnsw.hpp:40
int num_links
Definition Hnsw.hpp:46
int ef_construction
Definition Hnsw.hpp:53
int ef_search
Definition Hnsw.hpp:60