/**
 * Algorithm identifier for Annoy indices.
 *
 * `AnnoyPrebuilt::save()` writes this string into the `ALGORITHM` file of the
 * output directory. Presumably a loader reads it back to choose how to
 * deserialize the directory; confirm against the loading registry.
 */
static inline constexpr const char* annoy_prebuilt_save_name = "knncolle_annoy::Annoy";
52template<
typename Index_,
typename Data_,
typename Distance_,
typename AnnoyDistance_,
typename AnnoyIndex_,
typename AnnoyData_,
class AnnoyRng_,
class AnnoyThreadPolicy_>
60 typename AnnoyIndex_ = Index_,
61 typename AnnoyData_ = float,
62 class AnnoyRng_ = Annoy::Kiss64Random,
63 class AnnoyThreadPolicy_ = Annoy::AnnoyIndexSingleThreadedBuildPolicy
67 const AnnoyPrebuilt<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>& my_parent;
69 static constexpr bool same_internal_data = std::is_same<Data_, AnnoyData_>::value;
70 typename std::conditional<!same_internal_data, std::vector<AnnoyData_>,
bool>::type my_buffer;
72 static constexpr bool same_internal_index = std::is_same<Index_, AnnoyIndex_>::value;
73 std::vector<AnnoyIndex_> my_indices;
75 static constexpr bool same_internal_distance = std::is_same<Distance_, AnnoyData_>::value;
76 typename std::conditional<!same_internal_distance, std::vector<AnnoyData_>,
bool>::type my_distances;
78 typedef int SearchKType;
80 Index_ my_capped_k = 0;
82 SearchKType get_search_k(Index_ k)
const {
83 if (my_parent.my_search_mult < 1) {
85 }
else if (k <= my_capped_k) {
86 return my_parent.my_search_mult *
static_cast<double>(k) + 0.5;
88 return std::numeric_limits<SearchKType>::max();
93 AnnoySearcher(
const AnnoyPrebuilt<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>& parent) : my_parent(parent) {
94 if constexpr(!same_internal_data) {
95 sanisizer::resize(my_buffer, my_parent.my_dim);
98 if (my_parent.my_search_mult >= 1) {
99 my_capped_k = static_cast<double>(std::numeric_limits<SearchKType>::max()) / my_parent.my_search_mult;
104 std::pair<std::vector<AnnoyIndex_>*, std::vector<AnnoyData_>*> obtain_pointers(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, Index_ k) {
105 std::vector<AnnoyIndex_>* icopy_ptr = &my_indices;
106 if (output_indices) {
107 if constexpr(same_internal_index) {
108 icopy_ptr = output_indices;
112 icopy_ptr->reserve(k);
114 std::vector<AnnoyData_>* dcopy_ptr = NULL;
115 if (output_distances) {
116 if constexpr(same_internal_distance) {
117 dcopy_ptr = output_distances;
119 dcopy_ptr = &my_distances;
122 dcopy_ptr->reserve(k);
125 return std::make_pair(icopy_ptr, dcopy_ptr);
128 template<
typename Type_>
129 static void remove_self(std::vector<Type_>& vec, std::size_t at) {
130 if (at < vec.size()) {
131 vec.erase(vec.begin() + at);
137 template<
typename Source_,
typename Dest_>
138 static void copy_skip_self(
const std::vector<Source_>& source, std::vector<Dest_>& dest, std::size_t at) {
139 auto sIt = source.begin();
140 auto end = source.size();
142 dest.reserve(end - 1);
145 dest.insert(dest.end(), sIt, sIt + at);
146 dest.insert(dest.end(), sIt + at + 1, source.end());
153 dest.insert(dest.end(), sIt, sIt + end - 1);
158 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
160 const auto kp1 = k + 1;
162 auto ptrs = obtain_pointers(output_indices, output_distances, kp1);
163 auto icopy_ptr = ptrs.first;
164 auto dcopy_ptr = ptrs.second;
166 my_parent.my_index.get_nns_by_item(
168 sanisizer::cast<std::size_t>(sanisizer::attest_gez(kp1)),
176 const auto& cur_i = *icopy_ptr;
178 const AnnoyIndex_ icopy = i;
179 for (std::size_t x = 0, end = cur_i.size(); x < end; ++x) {
180 if (cur_i[x] == icopy) {
187 if (output_indices) {
188 if constexpr(same_internal_index) {
189 remove_self(*output_indices, at);
191 copy_skip_self(my_indices, *output_indices, at);
195 if (output_distances) {
196 if constexpr(same_internal_distance) {
197 remove_self(*output_distances, at);
199 copy_skip_self(my_distances, *output_distances, at);
205 void search_raw(
const AnnoyData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
206 auto ptrs = obtain_pointers(output_indices, output_distances, k);
207 auto icopy_ptr = ptrs.first;
208 auto dcopy_ptr = ptrs.second;
210 my_parent.my_index.get_nns_by_vector(
212 sanisizer::cast<std::size_t>(sanisizer::attest_gez(k)),
218 if (output_indices) {
219 if constexpr(!same_internal_index) {
220 output_indices->clear();
221 output_indices->insert(output_indices->end(), my_indices.begin(), my_indices.end());
225 if (output_distances) {
226 if constexpr(!same_internal_distance) {
227 output_distances->clear();
228 output_distances->insert(output_distances->end(), my_distances.begin(), my_distances.end());
234 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
235 if constexpr(same_internal_data) {
236 search_raw(query, k, output_indices, output_distances);
238 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
239 search_raw(my_buffer.data(), k, output_indices, output_distances);
248 class AnnoyDistance_,
249 typename AnnoyIndex_ = Index_,
250 typename AnnoyData_ = float,
251 class AnnoyRng_ = Annoy::Kiss64Random,
252 class AnnoyThreadPolicy_ = Annoy::AnnoyIndexSingleThreadedBuildPolicy
256 template<
class Matrix_>
257 AnnoyPrebuilt(
const Matrix_& data,
const AnnoyOptions& options) :
258 my_dim(data.num_dimensions()),
259 my_obs(data.num_observations()),
260 my_search_mult(options.search_mult),
264 sanisizer::cast<AnnoyIndex_>(sanisizer::attest_gez(my_obs));
266 auto work = data.new_known_extractor();
267 if constexpr(std::is_same<Data_, AnnoyData_>::value) {
268 for (Index_ i = 0; i < my_obs; ++i) {
269 auto ptr = work->next();
270 my_index.add_item(i, ptr);
273 auto incoming = sanisizer::create<std::vector<AnnoyData_> >(my_dim);
274 for (Index_ i = 0; i < my_obs; ++i) {
275 auto ptr = work->next();
276 std::copy_n(ptr, my_dim, incoming.begin());
277 my_index.add_item(i, incoming.data());
281 my_index.build(options.num_trees);
// Multiplier on k that AnnoySearcher uses to derive Annoy's search_k
// (initialized from AnnoyOptions::search_mult by the building constructor).
double my_search_mult;
293 struct SuperHackyThing final :
public Annoy::AnnoyIndex<AnnoyIndex_, AnnoyData_, AnnoyDistance_, AnnoyRng_, AnnoyThreadPolicy_> {
294 template<
typename ... Args_>
295 SuperHackyThing(Args_&& ... args) : Annoy::AnnoyIndex<AnnoyIndex_, AnnoyData_, AnnoyDistance_, AnnoyRng_, AnnoyThreadPolicy_>(std::forward<Args_>(args)...) {}
297 auto get_nodes()
const {
305 auto get_n_nodes()
const {
306 return this->_n_nodes;
309 SuperHackyThing my_index;
311 friend class AnnoySearcher<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>;
314 std::size_t num_dimensions()
const {
318 Index_ num_observations()
const {
322 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize()
const {
323 return initialize_known();
326 auto initialize_known()
const {
327 return std::make_unique<AnnoySearcher<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_> >(*this);
331 void save(
const std::filesystem::path& dir)
const {
332 knncolle::quick_save(dir /
"ALGORITHM", annoy_prebuilt_save_name, std::strlen(annoy_prebuilt_save_name));
342 const auto dname = get_distance_name<AnnoyDistance_>();
346 auto& icust = custom_save_for_annoy_index<AnnoyIndex_>();
351 auto& dcust = custom_save_for_annoy_data<AnnoyData_>();
356 auto& dscust = custom_save_for_annoy_distance<AnnoyDistance_>();
366 const auto idxpath = dir /
"INDEX";
367 knncolle::quick_save(idxpath,
reinterpret_cast<char*
>(my_index.get_nodes()), sanisizer::product<std::streamsize>(my_index.get_s(), my_index.get_n_nodes()));
370 AnnoyPrebuilt(
const std::filesystem::path& dir, std::size_t ndim) : my_dim(ndim), my_index(ndim) {
374 const auto idxpath = (dir /
"INDEX").
string();
376 if (!my_index.load(idxpath.c_str(),
true, &errbuf)) {
377 std::runtime_error ex(errbuf);
418 class AnnoyDistance_,
419 typename AnnoyIndex_ = Index_,
420 typename AnnoyData_ = float,
421 class AnnoyRng_ = Annoy::Kiss64Random,
422 class AnnoyThreadPolicy_ = Annoy::AnnoyIndexSingleThreadedBuildPolicy,
464 return new AnnoyPrebuilt<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>(data, my_options);