1#ifndef KNNCOLLE_HNSW_HPP
2#define KNNCOLLE_HNSW_HPP
14#include "sanisizer/sanisizer.hpp"
15#include "hnswlib/hnswalg.h"
60template<
typename Index_,
typename Data_,
typename Distance_,
typename HnswData_>
63template<
typename Index_,
typename Data_,
typename Distance_,
typename HnswData_>
66 const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& my_parent;
68 std::priority_queue<std::pair<HnswData_, hnswlib::labeltype> > my_queue;
70 static constexpr bool same_internal_data = std::is_same<Data_, HnswData_>::value;
71 std::vector<HnswData_> my_buffer;
74 HnswSearcher(
const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& parent) : my_parent(parent) {
75 if constexpr(!same_internal_data) {
76 sanisizer::resize(my_buffer, my_parent.my_dim);
81 void normalize_distances(std::vector<Distance_>& output_distances)
const {
82 switch(my_parent.my_normalize_method) {
83 case DistanceNormalizeMethod::SQRT:
84 for (
auto& d : output_distances) {
88 case DistanceNormalizeMethod::CUSTOM:
89 for (
auto& d : output_distances) {
90 d = my_parent.my_custom_normalize(d);
93 case DistanceNormalizeMethod::NONE:
99 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
100 my_buffer = my_parent.my_index.template getDataByLabel<HnswData_>(i);
102 my_queue = my_parent.my_index.searchKnn(my_buffer.data(), kp1);
104 if (output_indices) {
105 output_indices->clear();
106 output_indices->reserve(kp1);
108 if (output_distances) {
109 output_distances->clear();
110 output_distances->reserve(kp1);
113 bool self_found =
false;
114 hnswlib::labeltype icopy = i;
115 while (!my_queue.empty()) {
116 const auto& top = my_queue.top();
117 if (!self_found && top.second == icopy) {
120 if (output_indices) {
121 output_indices->push_back(top.second);
123 if (output_distances) {
124 output_distances->push_back(top.first);
130 if (output_indices) {
131 std::reverse(output_indices->begin(), output_indices->end());
133 if (output_distances) {
134 std::reverse(output_distances->begin(), output_distances->end());
143 if (output_indices) {
144 output_indices->pop_back();
146 if (output_distances) {
147 output_distances->pop_back();
151 if (output_distances) {
152 normalize_distances(*output_distances);
157 void search_raw(
const HnswData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
158 k = std::min(k, my_parent.my_obs);
159 my_queue = my_parent.my_index.searchKnn(query, k);
161 if (output_indices) {
162 output_indices->resize(k);
164 if (output_distances) {
165 output_distances->resize(k);
169 while (!my_queue.empty()) {
170 const auto& top = my_queue.top();
172 if (output_indices) {
173 (*output_indices)[position] = top.second;
175 if (output_distances) {
176 (*output_distances)[position] = top.first;
181 if (output_distances) {
182 normalize_distances(*output_distances);
187 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
188 if constexpr(same_internal_data) {
189 my_queue = my_parent.my_index.searchKnn(query, k);
190 search_raw(query, k, output_indices, output_distances);
192 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
193 search_raw(my_buffer.data(), k, output_indices, output_distances);
198template<
typename Index_,
typename Data_,
typename Distance_,
typename HnswData_>
201 template<
class Matrix_>
202 HnswPrebuilt(
const Matrix_& data,
const DistanceConfig<Distance_, HnswData_>& distance_config,
const HnswOptions& options) :
203 my_dim(data.num_dimensions()),
204 my_obs(data.num_observations()),
205 my_space(distance_config.create(my_dim)),
206 my_normalize_method(distance_config.normalize_method),
207 my_custom_normalize(distance_config.custom_normalize),
208 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
210 auto work = data.new_known_extractor();
211 if constexpr(std::is_same<Data_, HnswData_>::value) {
212 for (Index_ i = 0; i < my_obs; ++i) {
213 auto ptr = work->next();
214 my_index.addPoint(ptr, i);
217 auto incoming = sanisizer::create<std::vector<HnswData_> >(my_dim);
218 for (Index_ i = 0; i < my_obs; ++i) {
219 auto ptr = work->next();
220 std::copy_n(ptr, my_dim, incoming.begin());
221 my_index.addPoint(incoming.data(), i);
225 my_index.setEf(options.ef_search);
235 std::shared_ptr<hnswlib::SpaceInterface<HnswData_> > my_space;
238 std::function<Distance_(Distance_)> my_custom_normalize;
240 hnswlib::HierarchicalNSW<HnswData_> my_index;
242 friend class HnswSearcher<Index_, Data_, Distance_, HnswData_>;
245 std::size_t num_dimensions()
const {
249 Index_ num_observations()
const {
254 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize()
const {
255 return initialize_known();
258 auto initialize_known()
const {
259 return std::make_unique<HnswSearcher<Index_, Data_, Distance_, HnswData_> >(*this);
263 void save(
const std::string& prefix)
const {
276 auto& datafunc = custom_save_for_hnsw_data<HnswData_>();
281 auto& distfunc = custom_save_for_hnsw_distance<HnswData_>();
282 if (std::strcmp(distname,
"unknown") == 0 && distfunc) {
283 distfunc(prefix, my_space.get());
286 auto& normfunc = custom_save_for_hnsw_normalize<Distance_>();
287 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM && normfunc) {
288 normfunc(prefix, my_custom_normalize);
292 auto index_ptr =
const_cast<hnswlib::HierarchicalNSW<HnswData_>*
>(&my_index);
293 index_ptr->saveIndex(prefix +
"index");
296 HnswPrebuilt(
const std::string& prefix) :
312 if constexpr(std::is_same<HnswData_, float>::value) {
313 if (method ==
"l2") {
314 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(
new hnswlib::L2Space(my_dim));
317 if (method ==
"squared_euclidean") {
318 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(
new SquaredEuclideanDistance<HnswData_>(my_dim));
319 }
else if (method ==
"manhattan") {
320 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(
new ManhattanDistance<HnswData_>(my_dim));
323 auto& loadfun = custom_load_for_hnsw_distance<HnswData_>();
325 throw std::runtime_error(
"no loader provided for an unknown distance");
327 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(loadfun(prefix, my_dim));
330 my_normalize_method([&]() {
336 my_index(my_space.get(), prefix +
"index")
339 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM) {
340 auto& normfun = custom_load_for_hnsw_normalize<Distance_>();
342 throw std::runtime_error(
"no loader provided for an unknown normalization");
344 my_custom_normalize = normfun(prefix);
381 typename HnswData_ =
float
394 my_distance_config(std::move(distance_config)),
395 my_options(std::move(options))
397 if (!my_distance_config.
create) {
398 throw std::runtime_error(
"'distance_config.create' was not provided");
401 throw std::runtime_error(
"'distance_config.custom_normalize' was not provided");
428 return new HnswPrebuilt<Index_, Data_, Distance_, HnswData_>(data, my_distance_config, my_options);
Perform an approximate nearest neighbor search with HNSW.
Definition Hnsw.hpp:383
auto build_known_shared(const Matrix_ &data) const
Definition Hnsw.hpp:441
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config)
Definition Hnsw.hpp:409
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config, HnswOptions options)
Definition Hnsw.hpp:393
auto build_known_raw(const Matrix_ &data) const
Definition Hnsw.hpp:427
auto build_known_unique(const Matrix_ &data) const
Definition Hnsw.hpp:434
HnswOptions & get_options()
Definition Hnsw.hpp:414
Distance classes for HNSW.
knncolle bindings for HNSW search.
Definition distances.hpp:13
DistanceNormalizeMethod
Definition distances.hpp:22
const char * get_distance_name(const hnswlib::SpaceInterface< HnswData_ > *distance)
Definition distances.hpp:206
std::string quick_load_as_string(const std::string &path)
NumericType get_numeric_type()
void quick_load(const std::string &path, Input_ *const contents, const Length_ length)
void quick_save(const std::string &path, const Input_ *const contents, const Length_ length)
Distance configuration for the HNSW index.
Definition distances.hpp:31
DistanceNormalizeMethod normalize_method
Definition distances.hpp:40
std::function< hnswlib::SpaceInterface< HnswData_ > *(std::size_t)> create
Definition distances.hpp:35
std::function< Distance_(Distance_)> custom_normalize
Definition distances.hpp:46
Options for HnswBuilder and HnswPrebuilt.
Definition Hnsw.hpp:34
int num_links
Definition Hnsw.hpp:40
int ef_construction
Definition Hnsw.hpp:47
int ef_search
Definition Hnsw.hpp:54