/**
 * Name of the KMKNN algorithm as recorded on disk.
 *
 * `KmknnPrebuilt::save()` writes this string (excluding its null terminator, as the
 * length is taken with `std::strlen`) to the `ALGORITHM` file inside the save directory.
 * NOTE(review): presumably compared against the `ALGORITHM` file when loading a saved
 * index to pick the right loader — confirm against the load path.
 */
30inline static constexpr const char* kmknn_prebuilt_save_name =
"knncolle_kmknn::Kmknn";
47template<
class KmeansFloat_>
49 static std::function<void(
const std::filesystem::path&)> fun;
74 typename KmeansIndex_ = Index_,
75 typename KmeansData_ = Data_,
76 typename KmeansCluster_ = Index_,
77 typename KmeansFloat_ = Distance_,
91 std::shared_ptr<kmeans::Initialize<KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_> >
initialize_algorithm;
97 std::shared_ptr<kmeans::Refine<KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_> >
refine_algorithm;
103template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetricData_,
class KmeansFloat_,
class DistanceMetricCenter_>
106template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetricData_,
class KmeansFloat_,
class DistanceMetricCenter_>
109 KmknnSearcher(
const KmknnPrebuilt<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>& parent) : my_parent(parent) {
110 my_center_order.reserve(my_parent.my_sizes.size());
111 if constexpr(needs_conversion) {
112 sanisizer::resize(my_query_conversion_buffer, my_parent.my_dim);
117 const KmknnPrebuilt<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>& my_parent;
119 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
120 std::vector<std::pair<Distance_, Index_> > my_center_order;
123 static constexpr bool needs_conversion = !std::is_same<KmeansFloat_, Data_>::value;
124 typename std::conditional<needs_conversion, std::vector<KmeansFloat_>,
bool>::type my_query_conversion_buffer;
126 const KmeansFloat_* sanitize_query(
const Data_* query) {
127 if constexpr(needs_conversion) {
128 auto conv_buffer = my_query_conversion_buffer.data();
129 std::copy_n(query, my_parent.my_dim, conv_buffer);
136 void finalize(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances)
const {
137 if (output_indices) {
138 for (
auto& s : *output_indices) {
139 s = my_parent.my_observation_id[s];
142 if (output_distances) {
143 for (
auto& d : *output_distances) {
144 d = my_parent.my_metric_data->normalize(d);
150 void search_nn(
const Data_* query) {
155 const auto query_san = sanitize_query(query);
156 const auto ncenters = my_parent.my_sizes.size();
157 my_center_order.clear();
158 my_center_order.reserve(ncenters);
160 for (I<
decltype(ncenters)> c = 0; c < ncenters; ++c) {
161 auto clust_ptr = my_parent.my_centers.data() + sanisizer::product_unsafe<std::size_t>(c, my_parent.my_dim);
162 my_center_order.emplace_back(my_parent.my_metric_center->raw(my_parent.my_dim, query_san, clust_ptr), c);
164 std::sort(my_center_order.begin(), my_center_order.end());
168 const auto& dist2centers = my_parent.my_dist_to_centroid;
169 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
171 for (
const auto& curcent : my_center_order) {
172 const Index_ center = curcent.second;
173 Index_ firstsubj = my_parent.my_offsets[center], lastsubj = firstsubj + my_parent.my_sizes[center];
175 if (!std::isinf(threshold_raw)) {
176 const Distance_ threshold = my_parent.my_metric_center->normalize(threshold_raw);
177 const Distance_ query2center = my_parent.my_metric_center->normalize(curcent.first);
178 const Distance_ max_subj2center = dist2centers[lastsubj - 1];
187 const Distance_ lower_bd = query2center - threshold;
188 if (max_subj2center < lower_bd) {
191 firstsubj = std::lower_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, lower_bd) - dist2centers.begin();
203 const Distance_ upper_bd = query2center + threshold;
204 if (max_subj2center > upper_bd) {
205 lastsubj = std::upper_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, upper_bd) - dist2centers.begin();
209 for (
auto s = firstsubj; s < lastsubj; ++s) {
210 const auto other_subj = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(s, my_parent.my_dim);
211 auto dist2subj_raw = my_parent.my_metric_data->raw(my_parent.my_dim, query, other_subj);
212 if (dist2subj_raw <= threshold_raw) {
213 my_nearest.
add(s, dist2subj_raw);
215 threshold_raw = my_nearest.
limit();
240 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
241 my_nearest.
reset(k + 1);
242 auto new_i = my_parent.my_new_location[i];
243 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(new_i, my_parent.my_dim);
245 my_nearest.
report(output_indices, output_distances, new_i);
246 finalize(output_indices, output_distances);
249 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
251 if (output_indices) {
252 output_indices->clear();
254 if (output_distances) {
255 output_distances->clear();
260 my_nearest.
report(output_indices, output_distances);
261 finalize(output_indices, output_distances);
266 template<
bool count_only_,
typename Output_>
267 void search_all(
const Data_* query, Distance_ threshold, Output_& all_neighbors) {
268 Distance_ threshold_raw = my_parent.my_metric_center->denormalize(threshold);
269 const auto query_san = sanitize_query(query);
272 const auto ncenters = my_parent.my_sizes.size();
273 const auto& dist2centers = my_parent.my_dist_to_centroid;
275 for (I<
decltype(ncenters)> center = 0; center < ncenters; ++center) {
276 auto center_ptr = my_parent.my_centers.data() + sanisizer::product_unsafe<std::size_t>(center, my_parent.my_dim);
277 const Distance_ query2center = my_parent.my_metric_center->normalize(my_parent.my_metric_center->raw(my_parent.my_dim, query_san, center_ptr));
278 Index_ firstsubj = my_parent.my_offsets[center], lastsubj = firstsubj + my_parent.my_sizes[center];
279 const Distance_ max_subj2center = dist2centers[lastsubj - 1];
282 const Distance_ lower_bd = query2center - threshold;
283 if (max_subj2center < lower_bd) {
286 firstsubj = std::lower_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, lower_bd) - dist2centers.begin();
289 const Distance_ upper_bd = query2center + threshold;
290 if (max_subj2center > upper_bd) {
291 lastsubj = std::upper_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, upper_bd) - dist2centers.begin();
294 for (
auto s = firstsubj; s < lastsubj; ++s) {
295 const auto other_ptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(s, my_parent.my_dim);
296 auto dist2cell_raw = my_parent.my_metric_data->raw(my_parent.my_dim, query, other_ptr);
297 if (dist2cell_raw <= threshold_raw) {
298 if constexpr(count_only_) {
301 all_neighbors.emplace_back(dist2cell_raw, s);
309 bool can_search_all()
const {
313 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
314 auto new_i = my_parent.my_new_location[i];
315 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(new_i, my_parent.my_dim);
317 if (!output_indices && !output_distances) {
319 search_all<true>(iptr, d, count);
323 my_all_neighbors.clear();
324 search_all<false>(iptr, d, my_all_neighbors);
326 finalize(output_indices, output_distances);
331 Index_ search_all(
const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
332 if (!output_indices && !output_distances) {
334 search_all<true>(query, d, count);
338 my_all_neighbors.clear();
339 search_all<false>(query, d, my_all_neighbors);
341 finalize(output_indices, output_distances);
342 return my_all_neighbors.size();
347template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetricData_,
typename KmeansFloat_,
class DistanceMetricCenter_>
354 Index_ num_observations()
const {
358 std::size_t num_dimensions()
const {
363 std::vector<Data_> my_data;
364 std::shared_ptr<const DistanceMetricData_> my_metric_data;
365 std::shared_ptr<const DistanceMetricCenter_> my_metric_center;
367 std::vector<Index_> my_sizes;
368 std::vector<Index_> my_offsets;
370 std::vector<KmeansFloat_> my_centers;
372 std::vector<Index_> my_observation_id, my_new_location;
373 std::vector<Distance_> my_dist_to_centroid;
376 template<
typename KmeansIndex_,
typename KmeansData_,
typename KmeansCluster_,
class KmeansMatrix_>
380 std::vector<Data_> data,
381 std::shared_ptr<const DistanceMetricData_> metric_data,
382 std::shared_ptr<const DistanceMetricCenter_> metric_center,
383 const KmknnOptions<Index_, Data_, Distance_, KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_>& options
387 my_data(std::move(data)),
388 my_metric_data(std::move(metric_data)),
389 my_metric_center(std::move(metric_center))
391 auto init = options.initialize_algorithm;
392 if (init ==
nullptr) {
395 auto refine = options.refine_algorithm;
396 if (refine ==
nullptr) {
400 KmeansCluster_ ncenters = sanisizer::from_float<KmeansCluster_>(std::ceil(std::pow(my_obs, options.power)));
401 my_centers.resize(sanisizer::product<I<
decltype(my_centers.size())> >(sanisizer::attest_gez(ncenters), my_dim));
403 constexpr bool same_data = std::is_same<Data_, KmeansData_>::value;
404 typename std::conditional<same_data, bool, std::vector<KmeansData_> >::type kmeans_data_buffer;
405 const KmeansData_* data_ptr = NULL;
406 if constexpr(same_data) {
407 data_ptr = my_data.data();
409 kmeans_data_buffer.insert(kmeans_data_buffer.end(), my_data.begin(), my_data.end());
410 data_ptr = kmeans_data_buffer.data();
414 auto clusters = sanisizer::create<std::vector<KmeansCluster_> >(sanisizer::attest_gez(my_obs));
415 auto output =
kmeans::compute(mat, *init, *refine, ncenters, my_centers.data(), clusters.data());
418 const auto survivors =
kmeans::remove_unused_centers(my_dim,
static_cast<KmeansIndex_
>(my_obs), clusters.data(), ncenters, my_centers.data(), output.sizes);
419 if (survivors < ncenters) {
420 ncenters = survivors;
421 my_centers.resize(sanisizer::product_unsafe<I<
decltype(my_centers.size())> >(ncenters, my_dim));
422 output.sizes.resize(ncenters);
425 if constexpr(std::is_same<Index_, KmeansIndex_>::value) {
426 my_sizes.swap(output.sizes);
428 my_sizes.insert(my_sizes.end(), output.sizes.begin(), output.sizes.end());
431 sanisizer::resize(my_offsets, sanisizer::attest_gez(ncenters));
432 for (KmeansCluster_ i = 1; i < ncenters; ++i) {
433 my_offsets[i] = my_offsets[i - 1] + my_sizes[i - 1];
437 auto by_distance = sanisizer::create<std::vector<std::pair<Distance_, Index_> > >(sanisizer::attest_gez(my_obs));
439 static constexpr bool needs_conversion = !std::is_same<KmeansFloat_, Data_>::value;
440 typename std::conditional<needs_conversion, std::vector<KmeansFloat_>,
bool>::type conversion_buffer;
441 if constexpr(needs_conversion) {
442 sanisizer::resize(conversion_buffer, my_dim);
445 auto sofar = my_offsets;
446 for (Index_ o = 0; o < my_obs; ++o) {
447 auto optr = my_data.data() + sanisizer::product_unsafe<std::size_t>(o, my_dim);
449 const KmeansFloat_* observation = NULL;
450 if constexpr(needs_conversion) {
451 std::copy_n(optr, my_dim, conversion_buffer.data());
452 observation = conversion_buffer.data();
457 auto clustid = clusters[o];
458 auto cptr = my_centers.data() + sanisizer::product_unsafe<std::size_t>(clustid, my_dim);
460 auto& counter = sofar[clustid];
461 auto& current = by_distance[counter];
462 current.first = my_metric_center->normalize(my_metric_center->raw(my_dim, observation, cptr));
468 for (KmeansCluster_ c = 0; c < ncenters; ++c) {
469 auto begin = by_distance.data() + my_offsets[c];
470 std::sort(begin, begin + my_sizes[c]);
476 auto used = sanisizer::create<std::vector<unsigned char> >(sanisizer::attest_gez(my_obs));
477 auto buffer = sanisizer::create<std::vector<Data_> >(my_dim);
478 sanisizer::resize(my_observation_id, sanisizer::attest_gez(my_obs));
479 sanisizer::resize(my_dist_to_centroid, sanisizer::attest_gez(my_obs));
480 sanisizer::resize(my_new_location, sanisizer::attest_gez(my_obs));
482 for (Index_ o = 0; o < my_obs; ++o) {
487 const auto& current = by_distance[o];
488 my_observation_id[o] = current.second;
489 my_dist_to_centroid[o] = current.first;
490 my_new_location[current.second] = o;
491 if (current.second == o) {
497 auto optr = my_data.data() + sanisizer::product_unsafe<std::size_t>(o, my_dim);
498 std::copy_n(optr, my_dim, buffer.data());
499 Index_ replacement = current.second;
501 auto rptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(replacement, my_dim);
502 std::copy_n(rptr, my_dim, optr);
503 used[replacement] = 1;
505 const auto& next = by_distance[replacement];
506 my_observation_id[replacement] = next.second;
507 my_dist_to_centroid[replacement] = next.first;
508 my_new_location[next.second] = replacement;
511 replacement = next.second;
512 }
while (replacement != o);
514 std::copy(buffer.begin(), buffer.end(), optr);
519 friend class KmknnSearcher<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>;
522 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize()
const {
523 return initialize_known();
526 auto initialize_known()
const {
527 return std::make_unique<KmknnSearcher<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_> >(*this);
531 void save(
const std::filesystem::path& dir)
const {
532 knncolle::quick_save(dir /
"ALGORITHM", kmknn_prebuilt_save_name, std::strlen(kmknn_prebuilt_save_name));
536 const auto num_centers = my_sizes.size();
542 knncolle::quick_save(dir /
"OBSERVATION_ID", my_observation_id.data(), my_observation_id.size());
544 knncolle::quick_save(dir /
"DIST_TO_CENTROID", my_dist_to_centroid.data(), my_dist_to_centroid.size());
548 auto& kfcust = custom_save_for_kmknn_kmeansfloat<KmeansFloat_>();
554 const auto distdir = dir /
"DISTANCE_DATA";
555 std::filesystem::create_directory(distdir);
556 my_metric_data->save(distdir);
560 const auto distdir = dir /
"DISTANCE_CENTER";
561 std::filesystem::create_directory(distdir);
562 my_metric_center->save(distdir);
566 KmknnPrebuilt(
const std::filesystem::path& dir) {
569 auto num_centers = my_sizes.size();
572 my_data.resize(sanisizer::product<I<
decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
575 sanisizer::resize(my_sizes, sanisizer::attest_gez(num_centers));
577 sanisizer::resize(my_offsets, sanisizer::attest_gez(num_centers));
579 my_centers.resize(sanisizer::product<I<
decltype(my_centers.size())> >(my_dim, sanisizer::attest_gez(num_centers)));
582 sanisizer::resize(my_observation_id, sanisizer::attest_gez(my_obs));
583 knncolle::quick_load(dir /
"OBSERVATION_ID", my_observation_id.data(), my_observation_id.size());
584 sanisizer::resize(my_new_location, sanisizer::attest_gez(my_obs));
586 sanisizer::resize(my_dist_to_centroid, sanisizer::attest_gez(my_obs));
587 knncolle::quick_load(dir /
"DIST_TO_CENTROID", my_dist_to_centroid.data(), my_dist_to_centroid.size());
591 auto xptr =
dynamic_cast<DistanceMetricData_*
>(dptr);
593 throw std::runtime_error(
"cannot cast the loaded distance metric to a DistanceMetricData_");
595 my_metric_data.reset(xptr);
600 auto xptr =
dynamic_cast<DistanceMetricCenter_*
>(dptr);
602 throw std::runtime_error(
"cannot cast the loaded distance metric to a DistanceMetricCenter_");
604 my_metric_center.reset(xptr);
653 typename KmeansIndex_ = Index_,
654 typename KmeansData_ = Data_,
655 typename KmeansCluster_ = Index_,
656 typename KmeansFloat_ = Distance_,
668 std::shared_ptr<const DistanceMetricData_> my_metric_data;
669 std::shared_ptr<const DistanceMetricCenter_> my_metric_center;
687 std::shared_ptr<const DistanceMetricData_> metric_data,
688 std::shared_ptr<const DistanceMetricCenter_> metric_center,
691 my_metric_data(std::move(metric_data)),
692 my_metric_center(std::move(metric_center)),
693 my_options(std::move(options))
/**
 * Construct a builder from the two distance metrics, using value-initialized options.
 *
 * Delegates to the three-argument constructor, passing `{}` as the options argument.
 * NOTE(review): `{}` presumably yields a default-constructed `KmknnOptions` — confirm
 * against the three-argument constructor's signature.
 *
 * @param metric_data Distance metric for comparisons between observations; moved into
 * the delegated constructor.
 * @param metric_center Distance metric for comparisons against cluster centers; moved
 * into the delegated constructor.
 */
702 KmknnBuilder(std::shared_ptr<const DistanceMetricData_> metric_data, std::shared_ptr<const DistanceMetricCenter_> metric_center) :
703 KmknnBuilder(std::move(metric_data), std::move(metric_center), {}) {}
734 const auto ndim = data.num_dimensions();
735 const auto nobs = data.num_observations();
737 typedef std::vector<Data_> Store;
738 Store store(sanisizer::product<typename Store::size_type>(ndim, nobs));
740 auto work = data.new_known_extractor();
741 for (I<
decltype(nobs)> o = 0; o < nobs; ++o) {
742 auto ptr = work->next();
743 std::copy_n(ptr, ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
746 return new KmknnPrebuilt<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>(