// Reusable holder for the results of hnswlib's searchKnn(). With the default
// comparator, std::priority_queue keeps the pair with the largest distance at top().
84 std::priority_queue<std::pair<InternalData_, hnswlib::labeltype> > my_queue;
// Scratch vector in the index's internal type: holds the stored vector fetched for
// search-by-index, and the converted query when Float_ differs from InternalData_.
85 std::vector<InternalData_> my_buffer;
// True when the caller-facing floating-point type equals the index's internal type,
// so queries can be handed to hnswlib without conversion.
87 static constexpr bool same_internal = std::is_same<Float_, InternalData_>::value;
// Pre-size the conversion buffer to the data dimensionality only when a type
// conversion will actually be needed.
94 if constexpr(!same_internal) {
95 my_buffer.resize(my_parent->my_dim);
// Find the k nearest neighbors of observation `i`, which is already stored in the
// index, excluding `i` itself from the results. Either output pointer may be null,
// in which case that output is not produced.
103 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
// Use the vector stored under label `i` as the query.
104 my_buffer = my_parent->my_index.template getDataByLabel<InternalData_>(i);
// NOTE(review): `kp1` is assumed to be k + 1, i.e. one extra slot for the query
// point, which is expected to appear among its own neighbors — confirm.
106 my_queue = my_parent->my_index.searchKnn(my_buffer.data(), kp1);
108 if (output_indices) {
109 output_indices->clear();
110 output_indices->reserve(kp1);
112 if (output_distances) {
113 output_distances->clear();
114 output_distances->reserve(kp1);
// Drain the queue, skipping only the first result whose label matches `i`
// (compared as hnswlib::labeltype, hence the copy).
117 bool self_found =
false;
118 hnswlib::labeltype icopy = i;
119 while (!my_queue.empty()) {
120 const auto& top = my_queue.top();
121 if (!self_found && top.second == icopy) {
124 if (output_indices) {
125 output_indices->push_back(top.second);
127 if (output_distances) {
128 output_distances->push_back(top.first);
// The queue yields the farthest neighbor first, so reverse to obtain neighbors
// in increasing order of distance.
134 if (output_indices) {
135 std::reverse(output_indices->begin(), output_indices->end());
137 if (output_distances) {
138 std::reverse(output_distances->begin(), output_distances->end());
// Drop the farthest remaining neighbor.
// NOTE(review): presumably this only runs when `self_found` is false, so that at
// most k neighbors are reported — confirm the guarding condition.
147 if (output_indices) {
148 output_indices->pop_back();
150 if (output_distances) {
151 output_distances->pop_back();
// Optionally transform raw distances (e.g. sqrt of squared Euclidean distances).
155 if (output_distances && my_parent->my_normalize) {
156 for (
auto& d : *output_distances) {
157 d = my_parent->my_normalize(d);
// Core k-NN search for a query already expressed in the index's internal type.
// Either output pointer may be null, in which case that output is not produced.
163 void search_raw(
const InternalData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
// Never request more neighbors than there are observations.
164 k = std::min(k, my_parent->my_obs);
165 my_queue = my_parent->my_index.searchKnn(query, k);
// Outputs are sized to k up front and then filled by position.
167 if (output_indices) {
168 output_indices->resize(k);
170 if (output_distances) {
171 output_distances->resize(k);
// The queue yields the farthest neighbor first. NOTE(review): `position` is
// presumably initialized to k and decremented before each write, so outputs end up
// in increasing order of distance — confirm. If searchKnn() ever returns fewer than
// k results, the leading slots would keep their resize() values; sizing the outputs
// by my_queue.size() instead would avoid that.
175 while (!my_queue.empty()) {
176 const auto& top = my_queue.top();
178 if (output_indices) {
179 (*output_indices)[position] = top.second;
181 if (output_distances) {
182 (*output_distances)[position] = top.first;
// Optionally transform raw distances (e.g. sqrt of squared Euclidean distances).
187 if (output_distances && my_parent->my_normalize) {
188 for (
auto& d : *output_distances) {
189 d = my_parent->my_normalize(d);
195 void search(
const Float_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Float_>* output_distances) {
196 if constexpr(same_internal) {
197 my_queue = my_parent->my_index.searchKnn(query, k);
198 search_raw(query, k, output_indices, output_distances);
200 std::copy_n(query, my_parent->my_dim, my_buffer.begin());
201 search_raw(my_buffer.data(), k, output_indices, output_distances);
// Build the HNSW index over every observation in `data`.
225 template<
class Matrix_>
227 my_dim(data.num_dimensions()),
228 my_obs(data.num_observations()),
// Distance space: a user-supplied distance (NOTE(review): confirm which option
// selects it), else hnswlib's L2Space for float data, else the generic
// SquaredEuclideanDistance. Both fallbacks presumably produce squared distances,
// hence the sqrt normalizer below. The raw `new` results are presumably adopted by
// the my_space shared_ptr — confirm.
231 return options.distance_options.create(my_dim);
232 }
else if constexpr(std::is_same<InternalData_, float>::value) {
233 return static_cast<hnswlib::SpaceInterface<InternalData_>*>(new hnswlib::L2Space(my_dim));
235 return static_cast<hnswlib::SpaceInterface<InternalData_>*>(new SquaredEuclideanDistance<InternalData_>(my_dim));
// Distance normalizer: the user-supplied one, an empty std::function (no
// normalization), or sqrt to turn squared Euclidean distances into Euclidean ones.
// NOTE(review): confirm which condition selects each case.
240 return options.distance_options.normalize;
242 return std::function<InternalData_(InternalData_)>();
244 return std::function<InternalData_(InternalData_)>([](InternalData_ x) -> InternalData_ { return std::sqrt(x); });
// Index sized for my_obs elements; num_links and ef_construction are passed
// through to hnswlib.
247 my_index(my_space.get(), my_obs, options.num_links, options.ef_construction)
249 typedef typename Matrix_::data_type Data_;
// get_observation(work) takes no index, so it presumably yields observations in
// order; observation i is added under label i.
250 auto work = data.create_workspace();
// Matching types: pass the observation pointer straight to hnswlib.
251 if constexpr(std::is_same<Data_, InternalData_>::value) {
252 for (Index_ i = 0; i < my_obs; ++i) {
253 auto ptr = data.get_observation(work);
254 my_index.addPoint(ptr, i);
// Otherwise convert each observation into the internal type first, reusing one buffer.
257 std::vector<InternalData_> incoming(my_dim);
258 for (Index_ i = 0; i < my_obs; ++i) {
259 auto ptr = data.get_observation(work);
260 std::copy_n(ptr, my_dim, incoming.begin());
261 my_index.addPoint(incoming.data(), i);
// Query-time search breadth for subsequent searchKnn() calls.
265 my_index.setEf(options.ef_search);
// Distance space used by my_index. Declared before my_index so it is initialized
// first: my_index's constructor receives my_space.get().
278 std::shared_ptr<hnswlib::SpaceInterface<InternalData_> > my_space;
// Optional post-processing applied to each reported distance; empty means none.
280 std::function<InternalData_(InternalData_)> my_normalize;
// The HNSW graph holding all observations.
281 hnswlib::HierarchicalNSW<InternalData_> my_index;
// The searcher reads this object's private members (my_index, my_dim, my_obs,
// my_normalize) directly.
283 friend class HnswSearcher<Dim_, Index_, Float_, InternalData_>;
// Dimensionality of the indexed data (presumably returns my_dim).
286 Dim_ num_dimensions()
const {
// Number of indexed observations (presumably returns my_obs).
290 Index_ num_observations()
const {
// Create a new searcher bound to this index. The searcher is handed `this` and
// dereferences it as my_parent, so this object must outlive every searcher it creates.
294 std::unique_ptr<knncolle::Searcher<Index_, Float_> > initialize()
const {
295 return std::make_unique<HnswSearcher<Dim_, Index_, Float_, InternalData_> >(
this);