knncolle_hnsw
knncolle bindings for HNSW
Loading...
Searching...
No Matches
knncolle_hnsw.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_HNSW_HPP
2#define KNNCOLLE_HNSW_HPP
3
4#include <vector>
5#include <type_traits>
6#include <queue>
7#include <algorithm>
8#include <memory>
9#include <cstddef>
10
11#include "knncolle/knncolle.hpp"
12#include "hnswlib/hnswalg.h"
13
14#include "distances.hpp"
15
26namespace knncolle_hnsw {
27
40 int num_links = 16;
41
47 int ef_construction = 200;
48
54 int ef_search = 10;
55};
56
57template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
58class HnswPrebuilt;
59
70template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
71class HnswSearcher final : public knncolle::Searcher<Index_, Data_, Distance_> {
72private:
74
75 std::priority_queue<std::pair<HnswData_, hnswlib::labeltype> > my_queue;
76
77 static constexpr bool same_internal_data = std::is_same<Data_, HnswData_>::value;
78 std::vector<HnswData_> my_buffer;
79
80public:
84 HnswSearcher(const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& parent) : my_parent(parent) {
85 if constexpr(!same_internal_data) {
86 my_buffer.resize(my_parent.my_dim);
87 }
88 }
93public:
94 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
95 my_buffer = my_parent.my_index.template getDataByLabel<HnswData_>(i);
96 Index_ kp1 = k + 1;
97 my_queue = my_parent.my_index.searchKnn(my_buffer.data(), kp1); // +1, as it forgets to discard 'self'.
98
99 if (output_indices) {
100 output_indices->clear();
101 output_indices->reserve(kp1);
102 }
103 if (output_distances) {
104 output_distances->clear();
105 output_distances->reserve(kp1);
106 }
107
108 bool self_found = false;
109 hnswlib::labeltype icopy = i;
110 while (!my_queue.empty()) {
111 const auto& top = my_queue.top();
112 if (!self_found && top.second == icopy) {
113 self_found = true;
114 } else {
115 if (output_indices) {
116 output_indices->push_back(top.second);
117 }
118 if (output_distances) {
119 output_distances->push_back(top.first);
120 }
121 }
122 my_queue.pop();
123 }
124
125 if (output_indices) {
126 std::reverse(output_indices->begin(), output_indices->end());
127 }
128 if (output_distances) {
129 std::reverse(output_distances->begin(), output_distances->end());
130 }
131
132 // Just in case we're full of ties at duplicate points, such that 'c'
133 // is not in the set. Note that, if self_found=false, we must have at
134 // least 'K+2' points for 'c' to not be detected as its own neighbor.
135 // Thus there is no need to worry whether we are popping off a non-'c'
136 // element and then returning fewer elements than expected.
137 if (!self_found) {
138 if (output_indices) {
139 output_indices->pop_back();
140 }
141 if (output_distances) {
142 output_distances->pop_back();
143 }
144 }
145
146 if (output_distances && my_parent.my_normalize) {
147 for (auto& d : *output_distances) {
148 d = my_parent.my_normalize(d);
149 }
150 }
151 }
152
153private:
154 void search_raw(const HnswData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
155 k = std::min(k, my_parent.my_obs);
156 my_queue = my_parent.my_index.searchKnn(query, k);
157
158 if (output_indices) {
159 output_indices->resize(k);
160 }
161 if (output_distances) {
162 output_distances->resize(k);
163 }
164
165 auto position = k;
166 while (!my_queue.empty()) {
167 const auto& top = my_queue.top();
168 --position;
169 if (output_indices) {
170 (*output_indices)[position] = top.second;
171 }
172 if (output_distances) {
173 (*output_distances)[position] = top.first;
174 }
175 my_queue.pop();
176 }
177
178 if (output_distances && my_parent.my_normalize) {
179 for (auto& d : *output_distances) {
180 d = my_parent.my_normalize(d);
181 }
182 }
183 }
184
185public:
186 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
187 if constexpr(same_internal_data) {
188 my_queue = my_parent.my_index.searchKnn(query, k);
189 search_raw(query, k, output_indices, output_distances);
190 } else {
191 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
192 search_raw(my_buffer.data(), k, output_indices, output_distances);
193 }
194 }
195};
196
208template<typename Index_, typename Data_, typename Distance_, typename HnswData_>
209class HnswPrebuilt : public knncolle::Prebuilt<Index_, Data_, Distance_> {
210public:
214 template<class Matrix_>
215 HnswPrebuilt(const Matrix_& data, const DistanceConfig<HnswData_>& distance_config, const HnswOptions& options) :
216 my_dim(data.num_dimensions()),
217 my_obs(data.num_observations()),
218 my_space(distance_config.create(my_dim)),
219 my_normalize(distance_config.normalize),
220 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
221 {
222 auto work = data.new_extractor();
223 if constexpr(std::is_same<Data_, HnswData_>::value) {
224 for (Index_ i = 0; i < my_obs; ++i) {
225 auto ptr = work->next();
226 my_index.addPoint(ptr, i);
227 }
228 } else {
229 std::vector<HnswData_> incoming(my_dim);
230 for (Index_ i = 0; i < my_obs; ++i) {
231 auto ptr = work->next();
232 std::copy_n(ptr, my_dim, incoming.begin());
233 my_index.addPoint(incoming.data(), i);
234 }
235 }
236
237 my_index.setEf(options.ef_search);
238 return;
239 }
244private:
245 std::size_t my_dim;
246 Index_ my_obs;
247
248 // The following must be a pointer for polymorphism, but also so that
249 // references to the object in my_index are still valid after copying.
250 std::shared_ptr<hnswlib::SpaceInterface<HnswData_> > my_space;
251
252 std::function<HnswData_(HnswData_)> my_normalize;
253 hnswlib::HierarchicalNSW<HnswData_> my_index;
254
255 friend class HnswSearcher<Index_, Data_, Distance_, HnswData_>;
256
257public:
258 std::size_t num_dimensions() const {
259 return my_dim;
260 }
261
262 Index_ num_observations() const {
263 return my_obs;
264 }
265
269 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize() const {
270 return std::make_unique<HnswSearcher<Index_, Data_, Distance_, HnswData_> >(*this);
271 }
272};
273
298template<
299 typename Index_,
300 typename Data_,
301 typename Distance_,
303 typename HnswData_ = float
304>
305class HnswBuilder : public knncolle::Builder<Index_, Data_, Distance_, Matrix_> {
306private:
307 DistanceConfig<HnswData_> my_distance_config;
308 HnswOptions my_options;
309
310public:
315 HnswBuilder(DistanceConfig<HnswData_> distance_config, HnswOptions options) : my_distance_config(std::move(distance_config)), my_options(std::move(options)) {
316 if (!my_distance_config.create) {
317 throw std::runtime_error("'distance_config.create' was not provided");
318 }
319 }
320
325 HnswBuilder(DistanceConfig<HnswData_> distance_config) : HnswBuilder(std::move(distance_config), {}) {}
326
331 return my_options;
332 }
333
334public:
339 return new HnswPrebuilt<Index_, Data_, Distance_, HnswData_>(data, my_distance_config, my_options);
340 }
341};
342
343}
344
345#endif
Perform an approximate nearest neighbor search with HNSW.
Definition knncolle_hnsw.hpp:305
HnswBuilder(DistanceConfig< HnswData_ > distance_config, HnswOptions options)
Definition knncolle_hnsw.hpp:315
knncolle::Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition knncolle_hnsw.hpp:338
HnswOptions & get_options()
Definition knncolle_hnsw.hpp:330
HnswBuilder(DistanceConfig< HnswData_ > distance_config)
Definition knncolle_hnsw.hpp:325
Prebuilt index for an Hnsw search.
Definition knncolle_hnsw.hpp:209
std::unique_ptr< knncolle::Searcher< Index_, Data_, Distance_ > > initialize() const
Definition knncolle_hnsw.hpp:269
Searcher on an Hnsw index.
Definition knncolle_hnsw.hpp:71
Distance classes for HNSW.
knncolle bindings for HNSW search.
Definition distances.hpp:13
Distance configuration for the HNSW index.
Definition distances.hpp:21
std::function< HnswData_(HnswData_)> normalize
Definition distances.hpp:31
std::function< hnswlib::SpaceInterface< HnswData_ > *(std::size_t)> create
Definition distances.hpp:25
Options for HnswBuilder and HnswPrebuilt.
Definition knncolle_hnsw.hpp:34
int num_links
Definition knncolle_hnsw.hpp:40
int ef_construction
Definition knncolle_hnsw.hpp:47
int ef_search
Definition knncolle_hnsw.hpp:54