1#ifndef KNNCOLLE_HNSW_HNSW_HPP
2#define KNNCOLLE_HNSW_HNSW_HPP
15#include "sanisizer/sanisizer.hpp"
16#include "hnswlib/hnswalg.h"
// Algorithm identifier written to the "ALGORITHM" file by HnswPrebuilt::save().
// NOTE(review): presumably compared against on reload to recognize a saved
// HNSW index — confirm against the loader.
32inline static constexpr const char* hnsw_prebuilt_save_name =
"knncolle_hnsw::Hnsw";
// Template header of a declaration whose declarator line is not shown in this
// listing. NOTE(review): presumably the forward declaration of HnswPrebuilt,
// which HnswSearcher refers to before HnswPrebuilt is defined — confirm.
66template<
typename Index_,
typename Data_,
typename Distance_,
typename HnswData_>
// HnswSearcher: search state bound to a parent HnswPrebuilt index. Each
// instance owns a result queue and a scratch buffer that are reused across
// searches. NOTE(review): the class-head line is not shown in this listing.
69template<
typename Index_,
typename Data_,
typename Distance_,
typename HnswData_>
// Non-owning reference to the index being searched; the HnswPrebuilt object
// must outlive this searcher.
72 const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& my_parent;
// Results of hnswlib's searchKnn() as (distance, label) pairs. std::priority_queue
// is a max-heap, so top() is the farthest remaining neighbor.
74 std::priority_queue<std::pair<HnswData_, hnswlib::labeltype> > my_queue;
// True when the caller's data type equals the index's storage type, so queries
// can be handed to hnswlib without conversion.
76 static constexpr bool same_internal_data = std::is_same<Data_, HnswData_>::value;
// Scratch buffer: holds a converted query (pointer-based search) or the stored
// vector of an observation (index-based search).
77 std::vector<HnswData_> my_buffer;
// Binds the searcher to its parent index. The scratch buffer is pre-sized to
// the dimensionality only when queries must be converted to HnswData_.
80 HnswSearcher(
const HnswPrebuilt<Index_, Data_, Distance_, HnswData_>& parent) : my_parent(parent) {
81 if constexpr(!same_internal_data) {
// NOTE(review): sanisizer::resize presumably guards against my_dim overflowing
// the vector's size type — confirm.
82 sanisizer::resize(my_buffer, my_parent.my_dim);
// Applies the parent index's distance normalization to every entry of
// output_distances, in place.
// - SQRT: transforms each distance (the statement is not shown in this
//   listing; presumably d = std::sqrt(d)).
// - CUSTOM: replaces each distance with my_parent.my_custom_normalize(d).
// - NONE: presumably leaves the distances unchanged (branch body not shown).
87 void normalize_distances(std::vector<Distance_>& output_distances)
const {
88 switch(my_parent.my_normalize_method) {
89 case DistanceNormalizeMethod::SQRT:
90 for (
auto& d : output_distances) {
94 case DistanceNormalizeMethod::CUSTOM:
95 for (
auto& d : output_distances) {
96 d = my_parent.my_custom_normalize(d);
99 case DistanceNormalizeMethod::NONE:
// Finds the nearest neighbors of observation i among the indexed points,
// excluding i itself.
// The stored vector for label i is fetched from the index and used as the
// query. kp1 neighbors are requested; NOTE(review): kp1's definition is not
// shown, presumably k + 1 so that one slot can absorb the self-match.
// Results go to whichever of output_indices / output_distances is non-null,
// nearest first; distances are normalized at the end.
105 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
106 my_buffer = my_parent.my_index.template getDataByLabel<HnswData_>(i);
108 my_queue = my_parent.my_index.searchKnn(my_buffer.data(), kp1);
110 if (output_indices) {
111 output_indices->clear();
112 output_indices->reserve(kp1);
114 if (output_distances) {
115 output_distances->clear();
116 output_distances->reserve(kp1);
// Drain the max-heap (farthest neighbor first). The first entry labelled i is
// treated specially (handling not shown; presumably skipped as the self-match).
119 bool self_found =
false;
120 hnswlib::labeltype icopy = i;
121 while (!my_queue.empty()) {
122 const auto& top = my_queue.top();
123 if (!self_found && top.second == icopy) {
126 if (output_indices) {
127 output_indices->push_back(top.second);
129 if (output_distances) {
130 output_distances->push_back(top.first);
// Entries were collected farthest-first; reverse into nearest-first order.
136 if (output_indices) {
137 std::reverse(output_indices->begin(), output_indices->end());
139 if (output_distances) {
140 std::reverse(output_distances->begin(), output_distances->end());
// NOTE(review): the condition guarding these pop_back() calls is not shown;
// presumably it drops the farthest neighbor when i was not among the results,
// so that at most k neighbors are reported.
149 if (output_indices) {
150 output_indices->pop_back();
152 if (output_distances) {
153 output_distances->pop_back();
157 if (output_distances) {
158 normalize_distances(*output_distances);
// Searches for the k nearest neighbors of a query that is already in the
// index's HnswData_ type. k is capped at the number of indexed observations.
// Output vectors are resized to k and filled while draining the max-heap;
// NOTE(review): the `position` bookkeeping is not shown, presumably filling
// from the back so the result is nearest-first.
// NOTE(review): if hnswlib ever returns fewer than k results, some slots would
// keep their resize() value — confirm that cannot happen.
163 void search_raw(
const HnswData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
164 k = std::min(k, my_parent.my_obs);
165 my_queue = my_parent.my_index.searchKnn(query, k);
167 if (output_indices) {
168 output_indices->resize(k);
170 if (output_distances) {
171 output_distances->resize(k);
175 while (!my_queue.empty()) {
176 const auto& top = my_queue.top();
178 if (output_indices) {
179 (*output_indices)[position] = top.second;
181 if (output_distances) {
182 (*output_distances)[position] = top.first;
187 if (output_distances) {
188 normalize_distances(*output_distances);
193 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
194 if constexpr(same_internal_data) {
195 my_queue = my_parent.my_index.searchKnn(query, k);
196 search_raw(query, k, output_indices, output_distances);
198 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
199 search_raw(my_buffer.data(), k, output_indices, output_distances);
// HnswPrebuilt: an hnswlib index plus the distance and normalization settings
// needed to search it. NOTE(review): the class-head line is not shown in this
// listing.
204template<
typename Index_,
typename Data_,
typename Distance_,
typename HnswData_>
// Builds the index from every observation of `data`, inserting the i-th
// extracted observation under label i. When Data_ differs from HnswData_, each
// observation is first copied (converted) into a temporary HnswData_ buffer.
// After all points are added, the search-time ef is set from options.ef_search.
207 template<
class Matrix_>
208 HnswPrebuilt(
const Matrix_& data,
const DistanceConfig<Distance_, HnswData_>& distance_config,
const HnswOptions& options) :
209 my_dim(data.num_dimensions()),
210 my_obs(data.num_observations()),
211 my_space(distance_config.create(my_dim)),
212 my_normalize_method(distance_config.normalize_method),
213 my_custom_normalize(distance_config.custom_normalize),
// my_index is declared after my_space, so my_space is already initialized when
// its raw pointer is passed here. Capacity is my_obs; num_links and
// ef_construction are hnswlib's M and ef_construction.
214 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
216 auto work = data.new_known_extractor();
217 if constexpr(std::is_same<Data_, HnswData_>::value) {
// Types match: pass each observation's pointer straight to hnswlib.
// NOTE(review): assumes work->next() yields observations in order 0..my_obs-1.
218 for (Index_ i = 0; i < my_obs; ++i) {
219 auto ptr = work->next();
220 my_index.addPoint(ptr, i);
// Types differ: convert each observation into `incoming` before insertion.
223 auto incoming = sanisizer::create<std::vector<HnswData_> >(my_dim);
224 for (Index_ i = 0; i < my_obs; ++i) {
225 auto ptr = work->next();
226 std::copy_n(ptr, my_dim, incoming.begin());
227 my_index.addPoint(incoming.data(), i);
231 my_index.setEf(options.ef_search);
// Distance space used by my_index; owns the object returned by
// DistanceConfig::create (or by the loader when reloading from disk).
241 std::shared_ptr<hnswlib::SpaceInterface<HnswData_> > my_space;
// User-supplied distance transform, consulted when the normalization method is CUSTOM.
244 std::function<Distance_(Distance_)> my_custom_normalize;
// The hnswlib index. It is declared after my_space, so my_space is initialized
// first in the constructors' member-initializer lists.
246 hnswlib::HierarchicalNSW<HnswData_> my_index;
// HnswSearcher accesses these members (my_index, my_dim, my_obs, normalization
// settings) directly.
248 friend class HnswSearcher<Index_, Data_, Distance_, HnswData_>;
// Number of dimensions per observation (body not shown in this listing).
251 std::size_t num_dimensions()
const {
// Number of indexed observations (body not shown in this listing).
255 Index_ num_observations()
const {
// Creates a new searcher exposed through the generic knncolle::Searcher interface.
260 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize()
const {
261 return initialize_known();
// Like initialize(), but returns the concrete HnswSearcher type. The searcher
// keeps a reference to *this, so this object must outlive it.
264 auto initialize_known()
const {
265 return std::make_unique<HnswSearcher<Index_, Data_, Distance_, HnswData_> >(*this);
// Saves this index into directory `dir`. The algorithm name is written to
// "ALGORITHM" and the hnswlib index to "INDEX". Registered custom hooks are
// invoked for an unknown distance (distfunc) and for CUSTOM normalization
// (normfunc). A data-type hook (datafunc) is fetched, but its use is not shown
// in this listing.
269 void save(
const std::filesystem::path& dir)
const {
270 knncolle::quick_save(dir /
"ALGORITHM", hnsw_prebuilt_save_name, std::strlen(hnsw_prebuilt_save_name));
282 auto& datafunc = custom_save_for_hnsw_data<HnswData_>();
// NOTE(review): `distname` is defined on lines not shown; presumably the result
// of get_distance_name(my_space.get()).
287 auto& distfunc = custom_save_for_hnsw_distance<HnswData_>();
288 if (std::strcmp(distname,
"unknown") == 0 && distfunc) {
289 distfunc(dir, my_space.get());
292 auto& normfunc = custom_save_for_hnsw_normalize<Distance_>();
293 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM && normfunc) {
294 normfunc(dir, my_custom_normalize);
// hnswlib's saveIndex() is not a const member function, hence the const_cast
// from this const method. NOTE(review): relies on saveIndex() not modifying the
// index — confirm for the hnswlib version in use.
298 auto index_ptr =
const_cast<hnswlib::HierarchicalNSW<HnswData_>*
>(&my_index);
299 index_ptr->saveIndex(dir /
"INDEX");
// Reloads an index previously written by save() from directory `dir`.
// The distance space is rebuilt from a saved distance name (`method`, defined
// on lines not shown): "l2" (float only), "squared_euclidean" and "manhattan"
// are built in. Any other name goes through custom_load_for_hnsw_distance;
// std::runtime_error is thrown when no loader is available (the exact condition
// is not shown). A CUSTOM normalization is restored through
// custom_load_for_hnsw_normalize, with the same error behavior.
302 HnswPrebuilt(
const std::filesystem::path& dir) :
// hnswlib::L2Space operates on float, so "l2" is only recognized when
// HnswData_ is float.
318 if constexpr(std::is_same<HnswData_, float>::value) {
319 if (method ==
"l2") {
320 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(
new hnswlib::L2Space(my_dim));
323 if (method ==
"squared_euclidean") {
324 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(
new SquaredEuclideanDistance<HnswData_>(my_dim));
325 }
else if (method ==
"manhattan") {
326 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(
new ManhattanDistance<HnswData_>(my_dim));
// Unrecognized distance: defer to a user-registered loader.
329 auto& loadfun = custom_load_for_hnsw_distance<HnswData_>();
331 throw std::runtime_error(
"no loader provided for an unknown distance");
333 return static_cast<hnswlib::SpaceInterface<HnswData_>*
>(loadfun(dir, my_dim));
336 my_normalize_method([&]() {
// hnswlib's loading constructor reads the serialized index written by save().
342 my_index(my_space.get(), dir /
"INDEX")
345 if (my_normalize_method == DistanceNormalizeMethod::CUSTOM) {
346 auto& normfun = custom_load_for_hnsw_normalize<Distance_>();
348 throw std::runtime_error(
"no loader provided for an unknown normalization");
350 my_custom_normalize = normfun(dir);
// Default storage type used inside HnswBuilder's index.
387 typename HnswData_ =
float
// Builder constructor: takes ownership of the distance configuration and
// options, and rejects a configuration without `create`. NOTE(review): the
// condition for the custom_normalize check is not shown; presumably it fires
// when the method is CUSTOM and custom_normalize is empty.
400 my_distance_config(std::move(distance_config)),
401 my_options(std::move(options))
403 if (!my_distance_config.
create) {
404 throw std::runtime_error(
"'distance_config.create' was not provided");
407 throw std::runtime_error(
"'distance_config.custom_normalize' was not provided");
// build_known_raw: builds a new HnswPrebuilt from `data`; the caller owns the
// returned raw pointer.
440 return new HnswPrebuilt<Index_, Data_, Distance_, HnswData_>(data, my_distance_config, my_options);
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const=0
Perform an approximate nearest neighbor search with HNSW.
Definition Hnsw.hpp:389
auto build_known_shared(const Matrix_ &data) const
Definition Hnsw.hpp:453
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config)
Definition Hnsw.hpp:415
HnswBuilder(DistanceConfig< Distance_, HnswData_ > distance_config, HnswOptions options)
Definition Hnsw.hpp:399
auto build_known_raw(const Matrix_ &data) const
Definition Hnsw.hpp:439
auto build_known_unique(const Matrix_ &data) const
Definition Hnsw.hpp:446
HnswOptions & get_options()
Definition Hnsw.hpp:420
Distance classes for HNSW.
knncolle bindings for HNSW search.
Definition distances.hpp:13
DistanceNormalizeMethod
Definition distances.hpp:22
const char * get_distance_name(const hnswlib::SpaceInterface< HnswData_ > *distance)
Definition distances.hpp:212
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
NumericType get_numeric_type()
std::string quick_load_as_string(const std::filesystem::path &path)
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Distance configuration for the HNSW index.
Definition distances.hpp:31
DistanceNormalizeMethod normalize_method
Definition distances.hpp:40
std::function< hnswlib::SpaceInterface< HnswData_ > *(std::size_t)> create
Definition distances.hpp:35
std::function< Distance_(Distance_)> custom_normalize
Definition distances.hpp:46
Options for HnswBuilder and HnswPrebuilt.
Definition Hnsw.hpp:40
int num_links
Definition Hnsw.hpp:46
int ef_construction
Definition Hnsw.hpp:53
int ef_search
Definition Hnsw.hpp:60