46template<
typename Index_,
typename Data_,
typename Distance_,
typename AnnoyDistance_,
typename AnnoyIndex_,
typename AnnoyData_,
class AnnoyRng_,
class AnnoyThreadPolicy_>
54 typename AnnoyIndex_ = Index_,
55 typename AnnoyData_ = float,
56 class AnnoyRng_ = Annoy::Kiss64Random,
57 class AnnoyThreadPolicy_ = Annoy::AnnoyIndexSingleThreadedBuildPolicy
// Prebuilt index that this searcher queries; bound by reference in the constructor.
61 const AnnoyPrebuilt<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>& my_parent;
// True when the caller-facing data type is the same as Annoy's internal data type.
63 static constexpr bool same_internal_data = std::is_same<Data_, AnnoyData_>::value;
// Conversion buffer for query vectors: a real vector only when the two data types differ,
// otherwise a placeholder bool.
64 typename std::conditional<!same_internal_data, std::vector<AnnoyData_>,
bool>::type my_buffer;
// True when the caller-facing index type is the same as Annoy's internal index type.
66 static constexpr bool same_internal_index = std::is_same<Index_, AnnoyIndex_>::value;
// Scratch vector of neighbor indices in Annoy's own index type.
67 std::vector<AnnoyIndex_> my_indices;
// True when the caller-facing distance type is the same as Annoy's internal data type.
69 static constexpr bool same_internal_distance = std::is_same<Distance_, AnnoyData_>::value;
// Scratch vector of distances: a real vector only when the distance types differ,
// otherwise a placeholder bool.
70 typename std::conditional<!same_internal_distance, std::vector<AnnoyData_>,
bool>::type my_distances;
// Integer type of the value produced by get_search_k().
72 typedef int SearchKType;
// Upper bound on 'k' used by get_search_k(); stays 0 unless the constructor overwrites it.
74 Index_ my_capped_k = 0;
// Derives a SearchKType value from the requested number of neighbors 'k' and the
// parent's search multiplier.
// NOTE(review): parts of this body (the statement of the first branch, the final 'else'
// and the closing braces) are not visible in this listing - confirm against the full file.
76 SearchKType get_search_k(Index_ k)
const {
// Multiplier below 1: the value returned for this case is not visible here.
77 if (my_parent.my_search_mult < 1) {
79 }
else if (k <= my_capped_k) {
// Scale k by the multiplier; the added 0.5 makes the truncating conversion to
// SearchKType round non-negative values to the nearest integer.
80 return my_parent.my_search_mult *
static_cast<double>(k) + 0.5;
// Largest representable SearchKType; presumably the fallback once k exceeds my_capped_k.
82 return std::numeric_limits<SearchKType>::max();
// Creates a searcher bound to 'parent', sizing the query conversion buffer and
// precomputing my_capped_k from the parent's search multiplier.
// NOTE(review): closing braces are not visible in this listing.
87 AnnoySearcher(
const AnnoyPrebuilt<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>& parent) : my_parent(parent) {
// The buffer only exists as a vector when Data_ and AnnoyData_ differ; size it to one
// observation (my_dim values).
88 if constexpr(!same_internal_data) {
89 sanisizer::resize(my_buffer, my_parent.my_dim);
// max(SearchKType) divided by the multiplier; get_search_k() only multiplies when
// k <= my_capped_k, presumably so that the product stays within SearchKType.
92 if (my_parent.my_search_mult >= 1) {
93 my_capped_k = static_cast<double>(std::numeric_limits<SearchKType>::max()) / my_parent.my_search_mult;
// Picks the vectors that should receive the raw neighbor indices and distances, and
// reserves room for 'k' entries in each.
// Returns a pair (index vector pointer, distance vector pointer); the distance pointer
// stays NULL when 'output_distances' is null.
// NOTE(review): the 'else' lines and closing braces are not visible in this listing.
98 std::pair<std::vector<AnnoyIndex_>*, std::vector<AnnoyData_>*> obtain_pointers(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, Index_ k) {
// Default to the internal scratch vector for indices.
99 std::vector<AnnoyIndex_>* icopy_ptr = &my_indices;
100 if (output_indices) {
// Same index type: the caller's vector can be used directly.
101 if constexpr(same_internal_index) {
102 icopy_ptr = output_indices;
106 icopy_ptr->reserve(k);
108 std::vector<AnnoyData_>* dcopy_ptr = NULL;
109 if (output_distances) {
// Same distance type: use the caller's vector; otherwise the internal scratch vector.
110 if constexpr(same_internal_distance) {
111 dcopy_ptr = output_distances;
113 dcopy_ptr = &my_distances;
116 dcopy_ptr->reserve(k);
119 return std::make_pair(icopy_ptr, dcopy_ptr);
/**
 * Drops one element from `vec`, used to remove the query observation from its own
 * list of neighbors.
 *
 * @tparam Type_ Element type of the vector.
 * @param[in,out] vec Vector of search results (indices or distances), edited in place.
 * @param at Position of the element to erase. A value at or beyond `vec.size()`
 * means that no position was identified; the last element is dropped instead,
 * which is the same fallback that `copy_skip_self()` applies. An empty vector is
 * left untouched.
 */
template<typename Type_>
static void remove_self(std::vector<Type_>& vec, std::size_t at) {
    if (at < vec.size()) {
        vec.erase(vec.begin() + at);
    } else if (!vec.empty()) {
        // Guarded because pop_back() on an empty vector is undefined behaviour.
        vec.pop_back();
    }
}
/**
 * Copies `source` into `dest`, converting element types and leaving out one element.
 *
 * @tparam Source_ Element type of the input vector.
 * @tparam Dest_ Element type of the output vector; must be constructible from `Source_`.
 * @param source Search results (indices or distances) that still contain one extra entry.
 * @param[out] dest Receives the copy. Any previous contents are discarded.
 * @param at Position of the element to leave out. A value at or beyond `source.size()`
 * means that no position was identified; the last element is left out instead.
 *
 * On return, `dest` holds `source.size() - 1` elements, or none if `source` is empty.
 */
template<typename Source_, typename Dest_>
static void copy_skip_self(const std::vector<Source_>& source, std::vector<Dest_>& dest, std::size_t at) {
    dest.clear();

    const auto end = source.size();
    if (end == 0) {
        // Nothing to copy; also avoids the unsigned wrap-around of 'end - 1' below.
        return;
    }
    dest.reserve(end - 1);

    const auto sIt = source.begin();
    if (at < end) {
        // Everything before 'at', then everything after it.
        dest.insert(dest.end(), sIt, sIt + at);
        dest.insert(dest.end(), sIt + at + 1, source.end());
    } else {
        // No position to skip: keep all but the last element.
        dest.insert(dest.end(), sIt, sIt + (end - 1));
    }
}
// Finds neighbors of observation 'i' in the parent index and writes them to the
// non-null output vectors, with 'i' itself taken out of the results.
// NOTE(review): several lines of this body are not visible in this listing, including
// the remaining arguments of get_nns_by_item(), the declaration/assignment of 'at',
// the 'else' lines and the closing braces.
152 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// One extra neighbor is requested; one element is removed again further down.
// NOTE(review): 'k + 1' assumes k is below the maximum of its type - confirm with callers.
154 const auto kp1 = k + 1;
156 auto ptrs = obtain_pointers(output_indices, output_distances, kp1);
157 auto icopy_ptr = ptrs.first;
158 auto dcopy_ptr = ptrs.second;
160 my_parent.my_index.get_nns_by_item(
162 sanisizer::cast<std::size_t>(sanisizer::attest_gez(kp1)),
170 const auto& cur_i = *icopy_ptr;
// Scan the reported indices for the query observation itself.
172 const AnnoyIndex_ icopy = i;
173 for (std::size_t x = 0, end = cur_i.size(); x < end; ++x) {
174 if (cur_i[x] == icopy) {
// Indices: edit the caller's vector in place when Annoy wrote into it directly,
// otherwise convert from the scratch vector while skipping position 'at'.
181 if (output_indices) {
182 if constexpr(same_internal_index) {
183 remove_self(*output_indices, at);
185 copy_skip_self(my_indices, *output_indices, at);
// Distances: same treatment as the indices.
189 if (output_distances) {
190 if constexpr(same_internal_distance) {
191 remove_self(*output_distances, at);
193 copy_skip_self(my_distances, *output_distances, at);
// Finds the neighbors of a query vector that is already in Annoy's data type and writes
// them to the non-null output vectors.
// NOTE(review): the remaining arguments of get_nns_by_vector() and the closing braces
// are not visible in this listing.
199 void search_raw(
const AnnoyData_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
200 auto ptrs = obtain_pointers(output_indices, output_distances, k);
201 auto icopy_ptr = ptrs.first;
202 auto dcopy_ptr = ptrs.second;
204 my_parent.my_index.get_nns_by_vector(
206 sanisizer::cast<std::size_t>(sanisizer::attest_gez(k)),
// When the index types differ, the results sit in my_indices and are copied (with
// conversion) into the caller's vector, replacing its previous contents.
212 if (output_indices) {
213 if constexpr(!same_internal_index) {
214 output_indices->clear();
215 output_indices->insert(output_indices->end(), my_indices.begin(), my_indices.end());
// Same for distances, copied from my_distances.
219 if (output_distances) {
220 if constexpr(!same_internal_distance) {
221 output_distances->clear();
222 output_distances->insert(output_distances->end(), my_distances.begin(), my_distances.end());
// Finds the neighbors of a query vector given in the caller-facing data type.
// 'query' is read for my_parent.my_dim values when a conversion is needed.
// NOTE(review): the 'else' line and closing braces are not visible in this listing.
228 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// Same data type: pass the query straight through.
229 if constexpr(same_internal_data) {
230 search_raw(query, k, output_indices, output_distances);
// Different data type: convert the query into my_buffer (sized to my_dim in the
// constructor) and search with the converted copy.
232 std::copy_n(query, my_parent.my_dim, my_buffer.begin());
233 search_raw(my_buffer.data(), k, output_indices, output_distances);
242 class AnnoyDistance_,
243 typename AnnoyIndex_ = Index_,
244 typename AnnoyData_ = float,
245 class AnnoyRng_ = Annoy::Kiss64Random,
246 class AnnoyThreadPolicy_ = Annoy::AnnoyIndexSingleThreadedBuildPolicy
// Builds the Annoy index from the observations in 'data', using the settings in 'options'.
// NOTE(review): part of the member-initializer list, the 'else' line and the closing
// braces are not visible in this listing.
250 template<
class Matrix_>
251 AnnoyPrebuilt(
const Matrix_& data,
const AnnoyOptions& options) :
252 my_dim(data.num_dimensions()),
253 my_obs(data.num_observations()),
254 my_search_mult(options.search_mult),
// Presumably verifies that the number of observations is representable as AnnoyIndex_
// (the result is discarded) - confirm against the sanisizer documentation.
258 sanisizer::cast<AnnoyIndex_>(sanisizer::attest_gez(my_obs));
// 'work' hands out one observation per call to next().
260 auto work = data.new_known_extractor();
// Same data type: each observation pointer is passed to Annoy as-is.
261 if constexpr(std::is_same<Data_, AnnoyData_>::value) {
262 for (Index_ i = 0; i < my_obs; ++i) {
263 auto ptr = work->next();
264 my_index.add_item(i, ptr);
// Different data type: convert each observation into a reusable AnnoyData_ buffer
// of length my_dim before adding it.
267 auto incoming = sanisizer::create<std::vector<AnnoyData_> >(my_dim);
268 for (Index_ i = 0; i < my_obs; ++i) {
269 auto ptr = work->next();
270 std::copy_n(ptr, my_dim, incoming.begin());
271 my_index.add_item(i, incoming.data());
// Build the index with the requested number of trees once all items are added.
275 my_index.build(options.num_trees);
// Search multiplier; read by AnnoySearcher::get_search_k() and the AnnoySearcher constructor.
282 double my_search_mult;
// Thin subclass of Annoy::AnnoyIndex that adds accessors for internals of the base
// class (see get_nodes() and get_n_nodes()), which save() uses.
// NOTE(review): the body of get_nodes() and the closing braces are not visible in this listing.
287 struct SuperHackyThing final :
public Annoy::AnnoyIndex<AnnoyIndex_, AnnoyData_, AnnoyDistance_, AnnoyRng_, AnnoyThreadPolicy_> {
// Forwards every constructor argument to the Annoy::AnnoyIndex base class.
288 template<
typename ... Args_>
289 SuperHackyThing(Args_&& ... args) : Annoy::AnnoyIndex<AnnoyIndex_, AnnoyData_, AnnoyDistance_, AnnoyRng_, AnnoyThreadPolicy_>(std::forward<Args_>(args)...) {}
291 auto get_nodes()
const {
// Returns the base class's _n_nodes member.
299 auto get_n_nodes()
const {
300 return this->_n_nodes;
// The Annoy index instance owned by this object.
303 SuperHackyThing my_index;
// The searcher reads my_dim, my_search_mult and my_index directly, hence the friendship.
305 friend class AnnoySearcher<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>;
// Accessor for the number of dimensions.
// NOTE(review): the body is not visible in this listing; presumably it returns my_dim.
308 std::size_t num_dimensions()
const {
// Accessor for the number of observations.
// NOTE(review): the body is not visible in this listing; presumably it returns my_obs.
312 Index_ num_observations()
const {
// Creates a new searcher for this index, returned through the generic
// knncolle::Searcher interface. Delegates to initialize_known().
316 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize()
const {
317 return initialize_known();
// Creates a new searcher for this index, returned as a std::unique_ptr to the concrete
// AnnoySearcher type. The searcher stores a reference to this object, so this object
// must remain alive while the searcher is in use.
320 auto initialize_known()
const {
321 return std::make_unique<AnnoySearcher<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_> >(*this);
// Writes this index to files whose paths start with 'prefix'.
// NOTE(review): most of this body (including the declaration of 'types' and whatever is
// done with 'types' and 'dname') is not visible in this listing.
325 void save(
const std::string& prefix)
const {
// Record the numeric types Annoy was instantiated with.
332 types[0] = get_numeric_type<AnnoyIndex_>();
333 types[1] = get_numeric_type<AnnoyData_>();
// Name of the Annoy distance in use.
336 const auto dname = get_distance_name<AnnoyDistance_>();
// The node storage is written as raw bytes to "<prefix>index"; the byte count is
// get_s() * get_n_nodes(), computed through sanisizer::product (presumably overflow-checked).
344 const auto idxpath = prefix +
"index";
345 knncolle::quick_save(idxpath,
reinterpret_cast<char*
>(my_index.get_nodes()), sanisizer::product<std::streamsize>(my_index.get_s(), my_index.get_n_nodes()));
// Restores an index with 'ndim' dimensions from files whose paths start with 'prefix'.
// NOTE(review): most of this body (the declaration of 'errbuf', what happens to 'ex',
// how the remaining members are restored) is not visible in this listing.
// NOTE(review): 'prefix' is taken by value; a const reference would avoid a copy.
348 AnnoyPrebuilt(
const std::string prefix, std::size_t ndim) : my_dim(ndim), my_index(ndim) {
// Same "<prefix>index" path that save() writes to.
352 const auto idxpath = prefix +
"index";
// On failure, the message placed in 'errbuf' is turned into a std::runtime_error.
354 if (!my_index.load(idxpath.c_str(),
true, &errbuf)) {
355 std::runtime_error ex(errbuf);
396 class AnnoyDistance_,
397 typename AnnoyIndex_ = Index_,
398 typename AnnoyData_ = float,
399 class AnnoyRng_ = Annoy::Kiss64Random,
400 class AnnoyThreadPolicy_ = Annoy::AnnoyIndexSingleThreadedBuildPolicy,
442 return new AnnoyPrebuilt<Index_, Data_, Distance_, AnnoyDistance_, AnnoyIndex_, AnnoyData_, AnnoyRng_, AnnoyThreadPolicy_>(data, my_options);