knncolle_kmknn
KMKNN in knncolle
Loading...
Searching...
No Matches
Kmknn.hpp
1#ifndef KNNCOLLE_KMKNN_KMKNN_HPP
2#define KNNCOLLE_KMKNN_KMKNN_HPP
3
#include "utils.hpp"

#include "kmeans/kmeans.hpp"
#include "sanisizer/sanisizer.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <filesystem>
#include <functional>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
19
25namespace knncolle_kmknn {
26
30inline static constexpr const char* kmknn_prebuilt_save_name = "knncolle_kmknn::Kmknn";
31
/**
 * Accessor for an optional, user-settable hook that is tied to a particular `KmeansFloat_` type.
 *
 * The hook lives in a function-local static, so every instantiation of this template owns its own hook
 * and repeated calls for the same `KmeansFloat_` return a reference to the same object.
 * The hook is empty until the caller assigns to the returned reference.
 * `KmknnPrebuilt::save()` invokes it with the output directory if (and only if) it is non-empty.
 *
 * @tparam KmeansFloat_ Floating-point type of the k-means centers that the hook is associated with.
 * @return Mutable reference to the hook for `KmeansFloat_`.
 */
template<class KmeansFloat_>
std::function<void(const std::filesystem::path&)>& custom_save_for_kmknn_kmeansfloat() {
    static std::function<void(const std::filesystem::path&)> hook;
    return hook;
}
52
70template<
71 typename Index_,
72 typename Data_,
73 typename Distance_,
74 typename KmeansIndex_ = Index_,
75 typename KmeansData_ = Data_,
76 typename KmeansCluster_ = Index_,
77 typename KmeansFloat_ = Distance_,
79>
85 double power = 0.5;
86
91 std::shared_ptr<kmeans::Initialize<KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_> > initialize_algorithm;
92
97 std::shared_ptr<kmeans::Refine<KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_> > refine_algorithm;
98};
99
103template<typename Index_, typename Data_, typename Distance_, class DistanceMetricData_, class KmeansFloat_, class DistanceMetricCenter_>
104class KmknnPrebuilt;
105
106template<typename Index_, typename Data_, typename Distance_, class DistanceMetricData_, class KmeansFloat_, class DistanceMetricCenter_>
107class KmknnSearcher final : public knncolle::Searcher<Index_, Data_, Distance_> {
108public:
109 KmknnSearcher(const KmknnPrebuilt<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>& parent) : my_parent(parent) {
110 my_center_order.reserve(my_parent.my_sizes.size());
111 if constexpr(needs_conversion) {
112 sanisizer::resize(my_query_conversion_buffer, my_parent.my_dim);
113 }
114 }
115
116private:
117 const KmknnPrebuilt<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>& my_parent;
119 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
120 std::vector<std::pair<Distance_, Index_> > my_center_order;
121
122 // Converting Data_ to KmeansFloat_ if we need to.
123 static constexpr bool needs_conversion = !std::is_same<KmeansFloat_, Data_>::value;
124 typename std::conditional<needs_conversion, std::vector<KmeansFloat_>, bool>::type my_query_conversion_buffer;
125
126 const KmeansFloat_* sanitize_query(const Data_* query) {
127 if constexpr(needs_conversion) {
128 auto conv_buffer = my_query_conversion_buffer.data();
129 std::copy_n(query, my_parent.my_dim, conv_buffer);
130 return conv_buffer;
131 } else {
132 return query;
133 }
134 }
135
136 void finalize(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) const {
137 if (output_indices) {
138 for (auto& s : *output_indices) {
139 s = my_parent.my_observation_id[s];
140 }
141 }
142 if (output_distances) {
143 for (auto& d : *output_distances) {
144 d = my_parent.my_metric_data->normalize(d);
145 }
146 }
147 }
148
149private:
150 void search_nn(const Data_* query) {
151 // Computing distances to all centers and sorting them.
152 // The aim is to go through the nearest centers first, to try to get the shortest threshold (i.e., 'nearest.limit()') possible at the start;
153 // this allows us to skip searches of the later clusters.
154 {
155 const auto query_san = sanitize_query(query);
156 const auto ncenters = my_parent.my_sizes.size();
157 my_center_order.clear();
158 my_center_order.reserve(ncenters);
159
160 for (I<decltype(ncenters)> c = 0; c < ncenters; ++c) {
161 auto clust_ptr = my_parent.my_centers.data() + sanisizer::product_unsafe<std::size_t>(c, my_parent.my_dim);
162 my_center_order.emplace_back(my_parent.my_metric_center->raw(my_parent.my_dim, query_san, clust_ptr), c);
163 }
164 std::sort(my_center_order.begin(), my_center_order.end());
165 }
166
167 // Computing the distance to each center, and deciding whether to proceed for each cluster.
168 const auto& dist2centers = my_parent.my_dist_to_centroid;
169 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
170
171 for (const auto& curcent : my_center_order) {
172 const Index_ center = curcent.second;
173 Index_ firstsubj = my_parent.my_offsets[center], lastsubj = firstsubj + my_parent.my_sizes[center];
174
175 if (!std::isinf(threshold_raw)) {
176 const Distance_ threshold = my_parent.my_metric_center->normalize(threshold_raw);
177 const Distance_ query2center = my_parent.my_metric_center->normalize(curcent.first);
178 const Distance_ max_subj2center = dist2centers[lastsubj - 1];
179
180 /* This exploits the triangle inequality to ignore points where:
181 * threshold + subject-to-center < query-to-center
182 * All points (if any) within this cluster with distances at or above 'lower_bd' are potentially countable.
183 *
184 * If the maximum distance between a subject and the center is less than 'lower_bd', there's no point proceeding,
185 * as we know that all other subjects will have smaller distances and are thus uncountable.
186 */
187 const Distance_ lower_bd = query2center - threshold;
188 if (max_subj2center < lower_bd) {
189 continue;
190 }
191 firstsubj = std::lower_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, lower_bd) - dist2centers.begin();
192
193 /* This exploits the reverse triangle inequality, to ignore points where:
194 * threshold + query-to-center < subject-to-center
195 * All points (if any) within this cluster with distances at or below 'upper_bd' are potentially countable.
196 *
197 * If the maximum distance between a subject and the center is less than or equal to 'upper_bd', we can just skip the search.
198 * No subjects will a distance-to-center greater than 'upper_bd' so we know that we have to examine all subjects.
199 *
200 * We could also skip this center altogther if the minimum subject-to-center distance is greater than 'upper_bd'.
201 * However, this seems too unlikely to warrant a special clause.
202 */
203 const Distance_ upper_bd = query2center + threshold;
204 if (max_subj2center > upper_bd) {
205 lastsubj = std::upper_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, upper_bd) - dist2centers.begin();
206 }
207 }
208
209 for (auto s = firstsubj; s < lastsubj; ++s) {
210 const auto other_subj = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(s, my_parent.my_dim);
211 auto dist2subj_raw = my_parent.my_metric_data->raw(my_parent.my_dim, query, other_subj);
212 if (dist2subj_raw <= threshold_raw) {
213 my_nearest.add(s, dist2subj_raw);
214 if (my_nearest.is_full()) {
215 threshold_raw = my_nearest.limit(); // Shrinking the threshold, if an earlier NN has been found.
216
217 /* P.S. We could also consider increasing 'firstsubj' as 'threshold_raw' decreases.
218 * The idea would be to exploit the triangle inequality to quickly skip over more points.
219 * However, this is pointless because 'lower_bd' will never increase enough to skip subsequent observations.
220 * We wouldn't have been able to skip the observation that we just added,
221 * so there's no way we could skip observations with larger subject-to-center distances.
222 *
223 * P.P.S. We could also consider decreasing 'lastsubj' as 'threshold_raw' decreases.
224 * The idea would be to exploit the triangle inequality to terminate sooner.
225 * However, this doesn't seem to provide a lot of benefit in practice.
226 * In theory, we can only trim the search space if the query already lies in a center's hypersphere (as 'upper_bd' cannot decrease below 'query2center').
227 * Even then, 'upper_bd' is usually too large; testing indicates that a reduced 'upper_bd' only trims away a single observation at a time.
228 * There are also practical challenges as changes to 'lastsubj' within the loop might prevent out-of-order CPU execution;
229 * we need to do more memory accesses to 'dist2centers' to check if 'lastsubj' can be decreased;
230 * and we need to run an extra 'normalize()' to recompute 'upper_bd' inside the loop.
231 * All in all, I don't think it's worth it.
232 */
233 }
234 }
235 }
236 }
237 }
238
239public:
240 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
241 my_nearest.reset(k + 1); // +1 is safe as k < num_obs.
242 auto new_i = my_parent.my_new_location[i];
243 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(new_i, my_parent.my_dim);
244 search_nn(iptr);
245 my_nearest.report(output_indices, output_distances, new_i);
246 finalize(output_indices, output_distances);
247 }
248
249 void search(const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
250 if (k == 0) { // protect the NeighborQueue from k = 0.
251 if (output_indices) {
252 output_indices->clear();
253 }
254 if (output_distances) {
255 output_distances->clear();
256 }
257 } else {
258 my_nearest.reset(k);
259 search_nn(query);
260 my_nearest.report(output_indices, output_distances);
261 finalize(output_indices, output_distances);
262 }
263 }
264
265private:
266 template<bool count_only_, typename Output_>
267 void search_all(const Data_* query, Distance_ threshold, Output_& all_neighbors) {
268 Distance_ threshold_raw = my_parent.my_metric_center->denormalize(threshold);
269 const auto query_san = sanitize_query(query);
270
271 // Computing distances to all centers. We don't sort them here because the threshold is constant so there's no point.
272 const auto ncenters = my_parent.my_sizes.size();
273 const auto& dist2centers = my_parent.my_dist_to_centroid;
274
275 for (I<decltype(ncenters)> center = 0; center < ncenters; ++center) {
276 auto center_ptr = my_parent.my_centers.data() + sanisizer::product_unsafe<std::size_t>(center, my_parent.my_dim);
277 const Distance_ query2center = my_parent.my_metric_center->normalize(my_parent.my_metric_center->raw(my_parent.my_dim, query_san, center_ptr));
278 Index_ firstsubj = my_parent.my_offsets[center], lastsubj = firstsubj + my_parent.my_sizes[center];
279 const Distance_ max_subj2center = dist2centers[lastsubj - 1];
280
281 // Same logic as in search_nn().
282 const Distance_ lower_bd = query2center - threshold;
283 if (max_subj2center < lower_bd) {
284 continue;
285 }
286 firstsubj = std::lower_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, lower_bd) - dist2centers.begin();
287
288 // Same logic as in search_nn().
289 const Distance_ upper_bd = query2center + threshold;
290 if (max_subj2center > upper_bd) {
291 lastsubj = std::upper_bound(dist2centers.begin() + firstsubj, dist2centers.begin() + lastsubj, upper_bd) - dist2centers.begin();
292 }
293
294 for (auto s = firstsubj; s < lastsubj; ++s) {
295 const auto other_ptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(s, my_parent.my_dim);
296 auto dist2cell_raw = my_parent.my_metric_data->raw(my_parent.my_dim, query, other_ptr);
297 if (dist2cell_raw <= threshold_raw) {
298 if constexpr(count_only_) {
299 ++all_neighbors;
300 } else {
301 all_neighbors.emplace_back(dist2cell_raw, s);
302 }
303 }
304 }
305 }
306 }
307
308public:
309 bool can_search_all() const {
310 return true;
311 }
312
313 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
314 auto new_i = my_parent.my_new_location[i];
315 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(new_i, my_parent.my_dim);
316
317 if (!output_indices && !output_distances) {
318 Index_ count = 0;
319 search_all<true>(iptr, d, count);
321
322 } else {
323 my_all_neighbors.clear();
324 search_all<false>(iptr, d, my_all_neighbors);
325 knncolle::report_all_neighbors(my_all_neighbors, output_indices, output_distances, new_i);
326 finalize(output_indices, output_distances);
327 return knncolle::count_all_neighbors_without_self(my_all_neighbors.size());
328 }
329 }
330
331 Index_ search_all(const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
332 if (!output_indices && !output_distances) {
333 Index_ count = 0;
334 search_all<true>(query, d, count);
335 return count;
336
337 } else {
338 my_all_neighbors.clear();
339 search_all<false>(query, d, my_all_neighbors);
340 knncolle::report_all_neighbors(my_all_neighbors, output_indices, output_distances);
341 finalize(output_indices, output_distances);
342 return my_all_neighbors.size();
343 }
344 }
345};
346
347template<typename Index_, typename Data_, typename Distance_, class DistanceMetricData_, typename KmeansFloat_, class DistanceMetricCenter_>
348class KmknnPrebuilt final : public knncolle::Prebuilt<Index_, Data_, Distance_> {
349private:
350 std::size_t my_dim;
351 Index_ my_obs;
352
353public:
354 Index_ num_observations() const {
355 return my_obs;
356 }
357
358 std::size_t num_dimensions() const {
359 return my_dim;
360 }
361
362private:
363 std::vector<Data_> my_data;
364 std::shared_ptr<const DistanceMetricData_> my_metric_data;
365 std::shared_ptr<const DistanceMetricCenter_> my_metric_center;
366
367 std::vector<Index_> my_sizes;
368 std::vector<Index_> my_offsets;
369
370 std::vector<KmeansFloat_> my_centers;
371
372 std::vector<Index_> my_observation_id, my_new_location;
373 std::vector<Distance_> my_dist_to_centroid;
374
375public:
376 template<typename KmeansIndex_, typename KmeansData_, typename KmeansCluster_, class KmeansMatrix_>
377 KmknnPrebuilt(
378 std::size_t num_dim,
379 Index_ num_obs,
380 std::vector<Data_> data,
381 std::shared_ptr<const DistanceMetricData_> metric_data,
382 std::shared_ptr<const DistanceMetricCenter_> metric_center,
383 const KmknnOptions<Index_, Data_, Distance_, KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_>& options
384 ) :
385 my_dim(num_dim),
386 my_obs(num_obs),
387 my_data(std::move(data)),
388 my_metric_data(std::move(metric_data)),
389 my_metric_center(std::move(metric_center))
390 {
391 auto init = options.initialize_algorithm;
392 if (init == nullptr) {
394 }
395 auto refine = options.refine_algorithm;
396 if (refine == nullptr) {
398 }
399
400 KmeansCluster_ ncenters = sanisizer::from_float<KmeansCluster_>(std::ceil(std::pow(my_obs, options.power)));
401 my_centers.resize(sanisizer::product<I<decltype(my_centers.size())> >(sanisizer::attest_gez(ncenters), my_dim));
402
403 constexpr bool same_data = std::is_same<Data_, KmeansData_>::value;
404 typename std::conditional<same_data, bool, std::vector<KmeansData_> >::type kmeans_data_buffer;
405 const KmeansData_* data_ptr = NULL;
406 if constexpr(same_data) {
407 data_ptr = my_data.data();
408 } else {
409 kmeans_data_buffer.insert(kmeans_data_buffer.end(), my_data.begin(), my_data.end());
410 data_ptr = kmeans_data_buffer.data();
411 }
412
413 kmeans::SimpleMatrix<KmeansIndex_, KmeansData_> mat(my_dim, sanisizer::cast<KmeansIndex_>(sanisizer::attest_gez(my_obs)), data_ptr);
414 auto clusters = sanisizer::create<std::vector<KmeansCluster_> >(sanisizer::attest_gez(my_obs));
415 auto output = kmeans::compute(mat, *init, *refine, ncenters, my_centers.data(), clusters.data());
416
417 // Removing empty clusters, e.g., due to duplicate points.
418 const auto survivors = kmeans::remove_unused_centers(my_dim, static_cast<KmeansIndex_>(my_obs), clusters.data(), ncenters, my_centers.data(), output.sizes);
419 if (survivors < ncenters) {
420 ncenters = survivors;
421 my_centers.resize(sanisizer::product_unsafe<I<decltype(my_centers.size())> >(ncenters, my_dim));
422 output.sizes.resize(ncenters);
423 }
424
425 if constexpr(std::is_same<Index_, KmeansIndex_>::value) {
426 my_sizes.swap(output.sizes);
427 } else {
428 my_sizes.insert(my_sizes.end(), output.sizes.begin(), output.sizes.end());
429 }
430
431 sanisizer::resize(my_offsets, sanisizer::attest_gez(ncenters));
432 for (KmeansCluster_ i = 1; i < ncenters; ++i) {
433 my_offsets[i] = my_offsets[i - 1] + my_sizes[i - 1];
434 }
435
436 // Organize points correctly; firstly, sorting by distance from the assigned center.
437 auto by_distance = sanisizer::create<std::vector<std::pair<Distance_, Index_> > >(sanisizer::attest_gez(my_obs));
438 {
439 static constexpr bool needs_conversion = !std::is_same<KmeansFloat_, Data_>::value;
440 typename std::conditional<needs_conversion, std::vector<KmeansFloat_>, bool>::type conversion_buffer;
441 if constexpr(needs_conversion) {
442 sanisizer::resize(conversion_buffer, my_dim);
443 }
444
445 auto sofar = my_offsets;
446 for (Index_ o = 0; o < my_obs; ++o) {
447 auto optr = my_data.data() + sanisizer::product_unsafe<std::size_t>(o, my_dim);
448
449 const KmeansFloat_* observation = NULL;
450 if constexpr(needs_conversion) {
451 std::copy_n(optr, my_dim, conversion_buffer.data());
452 observation = conversion_buffer.data();
453 } else {
454 observation = optr;
455 }
456
457 auto clustid = clusters[o];
458 auto cptr = my_centers.data() + sanisizer::product_unsafe<std::size_t>(clustid, my_dim);
459
460 auto& counter = sofar[clustid];
461 auto& current = by_distance[counter];
462 current.first = my_metric_center->normalize(my_metric_center->raw(my_dim, observation, cptr));
463 current.second = o;
464
465 ++counter;
466 }
467
468 for (KmeansCluster_ c = 0; c < ncenters; ++c) {
469 auto begin = by_distance.data() + my_offsets[c];
470 std::sort(begin, begin + my_sizes[c]);
471 }
472 }
473
474 // Permuting in-place to mirror the reordered distances, so that the search is more cache-friendly.
475 {
476 auto used = sanisizer::create<std::vector<unsigned char> >(sanisizer::attest_gez(my_obs));
477 auto buffer = sanisizer::create<std::vector<Data_> >(my_dim);
478 sanisizer::resize(my_observation_id, sanisizer::attest_gez(my_obs));
479 sanisizer::resize(my_dist_to_centroid, sanisizer::attest_gez(my_obs));
480 sanisizer::resize(my_new_location, sanisizer::attest_gez(my_obs));
481
482 for (Index_ o = 0; o < my_obs; ++o) {
483 if (used[o]) {
484 continue;
485 }
486
487 const auto& current = by_distance[o];
488 my_observation_id[o] = current.second;
489 my_dist_to_centroid[o] = current.first;
490 my_new_location[current.second] = o;
491 if (current.second == o) {
492 continue;
493 }
494
495 // We recursively perform a "thread" of replacements until we
496 // are able to find the home of the originally replaced 'o'.
497 auto optr = my_data.data() + sanisizer::product_unsafe<std::size_t>(o, my_dim);
498 std::copy_n(optr, my_dim, buffer.data());
499 Index_ replacement = current.second;
500 do {
501 auto rptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(replacement, my_dim);
502 std::copy_n(rptr, my_dim, optr);
503 used[replacement] = 1;
504
505 const auto& next = by_distance[replacement];
506 my_observation_id[replacement] = next.second;
507 my_dist_to_centroid[replacement] = next.first;
508 my_new_location[next.second] = replacement;
509
510 optr = rptr;
511 replacement = next.second;
512 } while (replacement != o);
513
514 std::copy(buffer.begin(), buffer.end(), optr);
515 }
516 }
517 }
518
519 friend class KmknnSearcher<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>;
520
521public:
522 std::unique_ptr<knncolle::Searcher<Index_, Data_, Distance_> > initialize() const {
523 return initialize_known();
524 }
525
526 auto initialize_known() const {
527 return std::make_unique<KmknnSearcher<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_> >(*this);
528 }
529
530public:
531 void save(const std::filesystem::path& dir) const {
532 knncolle::quick_save(dir / "ALGORITHM", kmknn_prebuilt_save_name, std::strlen(kmknn_prebuilt_save_name));
533 knncolle::quick_save(dir / "DATA", my_data.data(), my_data.size());
534 knncolle::quick_save(dir / "NUM_OBS", &my_obs, 1);
535 knncolle::quick_save(dir / "NUM_DIM", &my_dim, 1);
536 const auto num_centers = my_sizes.size();
537 knncolle::quick_save(dir / "NUM_CENTERS", &num_centers, 1);
538
539 knncolle::quick_save(dir / "SIZES", my_sizes.data(), my_sizes.size());
540 knncolle::quick_save(dir / "OFFSETS", my_offsets.data(), my_offsets.size());
541 knncolle::quick_save(dir / "CENTERS", my_centers.data(), my_centers.size());
542 knncolle::quick_save(dir / "OBSERVATION_ID", my_observation_id.data(), my_observation_id.size());
543 knncolle::quick_save(dir / "NEW_LOCATION", my_new_location.data(), my_new_location.size());
544 knncolle::quick_save(dir / "DIST_TO_CENTROID", my_dist_to_centroid.data(), my_dist_to_centroid.size());
545
547 knncolle::quick_save(dir / "FLOAT_TYPE", &float_type, 1);
548 auto& kfcust = custom_save_for_kmknn_kmeansfloat<KmeansFloat_>();
549 if (kfcust) {
550 kfcust(dir);
551 }
552
553 {
554 const auto distdir = dir / "DISTANCE_DATA";
555 std::filesystem::create_directory(distdir);
556 my_metric_data->save(distdir);
557 }
558
559 {
560 const auto distdir = dir / "DISTANCE_CENTER";
561 std::filesystem::create_directory(distdir);
562 my_metric_center->save(distdir);
563 }
564 }
565
566 KmknnPrebuilt(const std::filesystem::path& dir) {
567 knncolle::quick_load(dir / "NUM_OBS", &my_obs, 1);
568 knncolle::quick_load(dir / "NUM_DIM", &my_dim, 1);
569 auto num_centers = my_sizes.size();
570 knncolle::quick_load(dir / "NUM_CENTERS", &num_centers, 1);
571
572 my_data.resize(sanisizer::product<I<decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
573 knncolle::quick_load(dir / "DATA", my_data.data(), my_data.size());
574
575 sanisizer::resize(my_sizes, sanisizer::attest_gez(num_centers));
576 knncolle::quick_load(dir / "SIZES", my_sizes.data(), my_sizes.size());
577 sanisizer::resize(my_offsets, sanisizer::attest_gez(num_centers));
578 knncolle::quick_load(dir / "OFFSETS", my_offsets.data(), my_offsets.size());
579 my_centers.resize(sanisizer::product<I<decltype(my_centers.size())> >(my_dim, sanisizer::attest_gez(num_centers)));
580 knncolle::quick_load(dir / "CENTERS", my_centers.data(), my_centers.size());
581
582 sanisizer::resize(my_observation_id, sanisizer::attest_gez(my_obs));
583 knncolle::quick_load(dir / "OBSERVATION_ID", my_observation_id.data(), my_observation_id.size());
584 sanisizer::resize(my_new_location, sanisizer::attest_gez(my_obs));
585 knncolle::quick_load(dir / "NEW_LOCATION", my_new_location.data(), my_new_location.size());
586 sanisizer::resize(my_dist_to_centroid, sanisizer::attest_gez(my_obs));
587 knncolle::quick_load(dir / "DIST_TO_CENTROID", my_dist_to_centroid.data(), my_dist_to_centroid.size());
588
589 {
590 auto dptr = knncolle::load_distance_metric_raw<Data_, Distance_>(dir / "DISTANCE_DATA");
591 auto xptr = dynamic_cast<DistanceMetricData_*>(dptr);
592 if (xptr == NULL) {
593 throw std::runtime_error("cannot cast the loaded distance metric to a DistanceMetricData_");
594 }
595 my_metric_data.reset(xptr);
596 }
597
598 {
599 auto dptr = knncolle::load_distance_metric_raw<Data_, Distance_>(dir / "DISTANCE_CENTER");
600 auto xptr = dynamic_cast<DistanceMetricCenter_*>(dptr);
601 if (xptr == NULL) {
602 throw std::runtime_error("cannot cast the loaded distance metric to a DistanceMetricCenter_");
603 }
604 my_metric_center.reset(xptr);
605 }
606 }
607};
647template<
648 typename Index_,
649 typename Data_,
650 typename Distance_,
651 class Matrix_ = knncolle::Matrix<Index_, Data_>,
652 class DistanceMetricData_ = knncolle::DistanceMetric<Data_, Distance_>,
653 typename KmeansIndex_ = Index_,
654 typename KmeansData_ = Data_,
655 typename KmeansCluster_ = Index_,
656 typename KmeansFloat_ = Distance_,
658 class DistanceMetricCenter_ = knncolle::DistanceMetric<KmeansFloat_, Distance_>
659>
660class KmknnBuilder final : public knncolle::Builder<Index_, Data_, Distance_, Matrix_> {
661public:
666
667private:
668 std::shared_ptr<const DistanceMetricData_> my_metric_data;
669 std::shared_ptr<const DistanceMetricCenter_> my_metric_center;
670 Options my_options;
671
672public:
687 std::shared_ptr<const DistanceMetricData_> metric_data,
688 std::shared_ptr<const DistanceMetricCenter_> metric_center,
689 Options options
690 ) :
691 my_metric_data(std::move(metric_data)),
692 my_metric_center(std::move(metric_center)),
693 my_options(std::move(options))
694 {}
695
702 KmknnBuilder(std::shared_ptr<const DistanceMetricData_> metric_data, std::shared_ptr<const DistanceMetricCenter_> metric_center) :
703 KmknnBuilder(std::move(metric_data), std::move(metric_center), {}) {}
704
705 // Don't provide an overload that accepts a raw metric pointer and the options,
706 // as it's possible for the raw pointer to be constructed first, and then the options
707 // is constructed but throws an error somewhere (e.g., in an IIFE), causing a memory leak.
708 // as the raw pointer is never passed to the shared_ptr for management.
709
715 return my_options;
716 }
717
718public:
722 knncolle::Prebuilt<Index_, Data_, Distance_>* build_raw(const Matrix_& data) const {
723 return build_known_raw(data);
724 }
729public:
733 auto build_known_raw(const Matrix_& data) const {
734 const auto ndim = data.num_dimensions();
735 const auto nobs = data.num_observations();
736
737 typedef std::vector<Data_> Store;
738 Store store(sanisizer::product<typename Store::size_type>(ndim, nobs));
739
740 auto work = data.new_known_extractor();
741 for (I<decltype(nobs)> o = 0; o < nobs; ++o) {
742 auto ptr = work->next();
743 std::copy_n(ptr, ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
744 }
745
746 return new KmknnPrebuilt<Index_, Data_, Distance_, DistanceMetricData_, KmeansFloat_, DistanceMetricCenter_>(
747 ndim,
748 nobs,
749 std::move(store),
750 my_metric_data,
751 my_metric_center,
752 my_options
753 );
754 }
755
759 auto build_known_unique(const Matrix_& data) const {
760 return std::unique_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
761 }
762
766 auto build_known_shared(const Matrix_& data) const {
767 return std::shared_ptr<I<decltype(*build_known_raw(data))> >(build_known_raw(data));
768 }
769};
770
771}
772
773#endif
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const=0
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
void add(Index_ i, Distance_ d)
Distance_ limit() const
void reset(Index_ k)
Perform a nearest neighbor search based on k-means clustering.
Definition Kmknn.hpp:660
KmknnBuilder(std::shared_ptr< const DistanceMetricData_ > metric_data, std::shared_ptr< const DistanceMetricCenter_ > metric_center)
Definition Kmknn.hpp:702
auto build_known_unique(const Matrix_ &data) const
Definition Kmknn.hpp:759
auto build_known_raw(const Matrix_ &data) const
Definition Kmknn.hpp:733
KmknnOptions< Index_, Data_, Distance_, KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_ > Options
Definition Kmknn.hpp:665
KmknnBuilder(std::shared_ptr< const DistanceMetricData_ > metric_data, std::shared_ptr< const DistanceMetricCenter_ > metric_center, Options options)
Definition Kmknn.hpp:686
auto build_known_shared(const Matrix_ &data) const
Definition Kmknn.hpp:766
Options & get_options()
Definition Kmknn.hpp:714
Details< Index_ > compute(const Matrix_ &data, const Initialize< Index_, Data_, Cluster_, Float_, Matrix_ > &initialize, const Refine< Index_, Data_, Cluster_, Float_, Matrix_ > &refine, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters)
Cluster_ remove_unused_centers(const std::size_t num_dimensions, const Index_ num_observations, Cluster_ *const clusters, const Cluster_ num_centers, Float_ *const centers, std::vector< Index_ > &sizes)
Namespace for the knncolle_kmknn library.
Definition Kmknn.hpp:25
std::function< void(const std::filesystem::path &)> & custom_save_for_kmknn_kmeansfloat()
Definition Kmknn.hpp:48
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
NumericType get_numeric_type()
Index_ count_all_neighbors_without_self(Index_ count)
DistanceMetric< Data_, Distance_ > * load_distance_metric_raw(const std::filesystem::path &dir)
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Options for KmknnBuilder construction.
Definition Kmknn.hpp:80
std::shared_ptr< kmeans::Initialize< KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_ > > initialize_algorithm
Definition Kmknn.hpp:91
std::shared_ptr< kmeans::Refine< KmeansIndex_, KmeansData_, KmeansCluster_, KmeansFloat_, KmeansMatrix_ > > refine_algorithm
Definition Kmknn.hpp:97
double power
Definition Kmknn.hpp:85