// Observation coordinates, my_long_ndim values per observation. Taken over from the
// constructor's `data` argument and then reordered in place during construction.
189 std::vector<Store_> my_data;
// Number of observations assigned to each retained (non-empty) cluster.
191 std::vector<Index_> my_sizes;
// Start of each cluster's run of observations in the reordered data; built as the
// running sum of my_sizes.
192 std::vector<Index_> my_offsets;
// Cluster centers stored back to back, my_long_ndim values per center.
193 std::vector<Store_> my_centers;
// my_observation_id[i] is the pre-reordering index of the observation now held at position i;
// my_new_location[j] is the position now holding the observation that used to be at index j.
195 std::vector<Index_> my_observation_id, my_new_location;
// Normalized distance from the observation at each reordered position to its cluster center.
// Within one cluster's run these are in ascending order, which the searches rely on.
196 std::vector<Float_> my_dist_to_centroid;
// Tail of the constructor's member-initializer list followed by its body.
// NOTE(review): the constructor signature and earlier initializers are not visible here.
// Overall flow of the visible code: cluster the observations with kmeans::compute, compact
// away clusters that received no observations, compute per-cluster offsets, rank every
// observation by distance to its center, and finally permute my_data in place so that each
// cluster is contiguous and ordered by that distance.
208 my_long_ndim(my_dim),
209 my_data(std::move(data))
// NOTE(review): the handling of a null `init` / `refine` is not visible here; presumably a
// default is substituted, since both pointers are dereferenced below — confirm.
212 if (init ==
nullptr) {
216 if (refine ==
nullptr) {
// Requested number of centers is ceil(my_obs ^ options.power).
220 Index_ ncenters = std::ceil(std::pow(my_obs, options.
power));
// Widen before multiplying so the element count is computed in size_t.
221 my_centers.resize(
static_cast<size_t>(ncenters) * my_long_ndim);
// `clusters` has one slot per observation and is later read as that observation's cluster id.
224 std::vector<Index_> clusters(my_obs);
// presumably kmeans::compute fills my_centers and `clusters`; output.sizes is read below as
// the per-cluster observation count — verify against the kmeans library.
225 auto output =
kmeans::compute(mat, *init, *refine, ncenters, my_centers.data(), clusters.data());
// Compact the non-empty clusters to the front: `survivors` is the next free slot and
// remap[c] records where old cluster c ended up.
229 my_sizes.resize(ncenters);
230 std::vector<Index_> remap(ncenters);
231 Index_ survivors = 0;
232 for (Index_ c = 0; c < ncenters; ++c) {
233 if (output.sizes[c]) {
// Slide this cluster's center down into the survivor slot.
235 auto src = my_centers.begin() +
static_cast<size_t>(c) * my_long_ndim;
236 auto dest = my_centers.begin() +
static_cast<size_t>(survivors) * my_long_ndim;
237 std::copy_n(src, my_dim, dest);
239 remap[c] = survivors;
240 my_sizes[survivors] = output.sizes[c];
// Only when at least one cluster was empty: visit every assignment and shrink the
// per-cluster arrays to the surviving count.
// NOTE(review): the loop body is not visible here; presumably it rewrites each id via remap.
245 if (survivors < ncenters) {
246 for (
auto& c : clusters) {
249 ncenters = survivors;
250 my_centers.resize(
static_cast<size_t>(ncenters) * my_long_ndim);
251 my_sizes.resize(ncenters);
// Offsets are the exclusive running sum of sizes; element 0 stays at its resize() default of 0.
255 my_offsets.resize(ncenters);
256 for (Index_ i = 1; i < ncenters; ++i) {
257 my_offsets[i] = my_offsets[i - 1] + my_sizes[i - 1];
// One (distance-to-center, index) entry per observation, grouped by cluster.
261 std::vector<std::pair<Float_, Index_> > by_distance(my_obs);
// `sofar` starts as a copy of the offsets and serves as a per-cluster write cursor.
263 auto sofar = my_offsets;
264 auto host = my_data.data();
265 for (Index_ o = 0; o < my_obs; ++o) {
266 auto optr = host +
static_cast<size_t>(o) * my_long_ndim;
267 auto clustid = clusters[o];
268 auto cptr = my_centers.data() +
static_cast<size_t>(clustid) * my_long_ndim;
// Slot for this observation inside its cluster's run.
// NOTE(review): setting `current.second` and advancing the cursor are not visible here.
270 auto& counter = sofar[clustid];
271 auto& current = by_distance[counter];
272 current.first = Distance_::normalize(Distance_::template raw_distance<Float_>(optr, cptr, my_dim));
// Sort each cluster's run; std::pair ordering compares the distance first, then the index.
278 for (Index_ c = 0; c < ncenters; ++c) {
279 auto begin = by_distance.begin() + my_offsets[c];
280 std::sort(begin, begin + my_sizes[c]);
// Permute my_data in place into by_distance order. `buffer` holds one observation's
// coordinates while the rows of a permutation cycle are shifted, and `used` flags rows
// already written as part of a cycle.
286 auto host = my_data.data();
287 std::vector<uint8_t> used(my_obs);
288 std::vector<Store_> buffer(my_dim);
289 my_observation_id.resize(my_obs);
290 my_dist_to_centroid.resize(my_obs);
291 my_new_location.resize(my_obs);
293 for (Index_ o = 0; o < my_obs; ++o) {
// NOTE(review): a check on `used[o]` presumably precedes this point but is not visible here.
298 const auto& current = by_distance[o];
299 my_observation_id[o] = current.second;
300 my_dist_to_centroid[o] = current.first;
301 my_new_location[current.second] = o;
// The observation is already in its final position, so there is nothing to move.
302 if (current.second == o) {
// Save row o, then repeatedly pull the row that belongs in the current hole until the
// chain leads back to o.
308 auto optr = host +
static_cast<size_t>(o) * my_long_ndim;
309 std::copy_n(optr, my_dim, buffer.begin());
310 Index_ replacement = current.second;
312 auto rptr = host +
static_cast<size_t>(replacement) * my_long_ndim;
313 std::copy_n(rptr, my_dim, optr);
314 used[replacement] = 1;
// Record the bookkeeping for the position just vacated, which the next step will fill.
316 const auto& next = by_distance[replacement];
317 my_observation_id[replacement] = next.second;
318 my_dist_to_centroid[replacement] = next.first;
319 my_new_location[next.second] = replacement;
// NOTE(review): the update of `optr` for the next step is not visible here.
322 replacement = next.second;
323 }
while (replacement != o);
// Drop the saved copy of row o into the last hole of the cycle.
325 std::copy(buffer.begin(), buffer.end(), optr);
// Nearest-neighbour search for a single query point.
//
// target:       query coordinates; my_dim values are read from it.
// nearest:      queue that receives candidates as (position in the reordered data, raw
//               distance). Once it reports is_full(), its limit() becomes the pruning threshold.
// center_order: caller-supplied workspace; cleared and refilled with (raw distance to
//               center, center index) pairs, then sorted ascending.
//
// Indices and distances pushed into `nearest` are still positions in the reordered data
// and raw distances.
333 template<
typename Query_>
334 void search_nn(
const Query_* target, internal::NeighborQueue<Index_, Float_>& nearest, std::vector<std::pair<Float_, Index_> >& center_order)
const {
// Rank all centers by raw distance to the query so closer clusters are scanned first.
340 center_order.clear();
341 size_t ncenters = my_sizes.size();
342 center_order.reserve(ncenters);
343 auto clust_ptr = my_centers.data();
344 for (
size_t c = 0; c < ncenters; ++c, clust_ptr += my_dim) {
345 center_order.emplace_back(Distance_::template raw_distance<Float_>(target, clust_ptr, my_dim), c);
347 std::sort(center_order.begin(), center_order.end());
// Infinite until the queue is full, i.e. no pruning is possible before then.
350 Float_ threshold_raw = std::numeric_limits<Float_>::infinity();
351 for (
const auto& curcent : center_order) {
352 const Index_ center = curcent.second;
353 const Float_ dist2center = Distance_::normalize(curcent.first);
// dIt points at this cluster's run of ascending distances-to-center; the last entry is
// therefore the largest.
355 const auto cur_nobs = my_sizes[center];
356 const Float_* dIt = my_dist_to_centroid.data() + my_offsets[center];
357 const Float_ maxdist = *(dIt + cur_nobs - 1);
359 Index_ firstcell = 0;
360#if KNNCOLLE_KMKNN_USE_UPPER
361 Float_ upper_bd = std::numeric_limits<Float_>::max();
364 if (!std::isinf(threshold_raw)) {
365 const Float_ threshold = Distance_::normalize(threshold_raw);
// Triangle inequality: a member closer to the center than (dist2center - threshold)
// cannot be within `threshold` of the query.
371 const Float_ lower_bd = dist2center - threshold;
// Even the farthest member is below the bound, so the whole cluster is out of reach.
// NOTE(review): the statement that skips the cluster is not visible here.
372 if (maxdist < lower_bd) {
// Binary search for the first member whose distance to the center reaches the bound.
376 firstcell = std::lower_bound(dIt, dIt + cur_nobs, lower_bd) - dIt;
378#if KNNCOLLE_KMKNN_USE_UPPER
382 upper_bd = threshold + dist2center;
// Walk the remaining members of the cluster, one my_dim-long row at a time.
386 const auto cur_start = my_offsets[center];
387 const auto* other_cell = my_data.data() + my_long_ndim *
static_cast<size_t>(cur_start + firstcell);
388 for (
auto celldex = firstcell; celldex < cur_nobs; ++celldex, other_cell += my_dim) {
389#if KNNCOLLE_KMKNN_USE_UPPER
// Optional upper-bound test, compiled in only when KNNCOLLE_KMKNN_USE_UPPER is set.
// NOTE(review): the action taken when it fires is not visible here.
390 if (*(dIt + celldex) > upper_bd) {
395 auto dist2cell_raw = Distance_::template raw_distance<Float_>(target, other_cell, my_dim);
396 if (dist2cell_raw <= threshold_raw) {
397 nearest.add(cur_start + celldex, dist2cell_raw);
// Tighten the threshold as soon as the queue holds its full complement.
398 if (nearest.is_full()) {
399 threshold_raw = nearest.limit();
400#if KNNCOLLE_KMKNN_USE_UPPER
401 upper_bd = Distance_::normalize(threshold_raw) + dist2center;
// Range search: visits every observation whose raw distance to `target` is at most the
// denormalized `threshold`.
//
// count_only_:   compile-time switch selecting the branch taken for each hit.
//                NOTE(review): the count-only branch body is not visible here.
// target:        query coordinates; my_dim values are read from it.
// threshold:     search radius in normalized distance units.
// all_neighbors: in the visible non-count branch, receives (raw distance, position in the
//                reordered data) via emplace_back.
409 template<
bool count_only_,
typename Query_,
typename Output_>
410 void search_all(
const Query_* target, Float_ threshold, Output_& all_neighbors)
const {
// Comparisons against individual observations are done on raw distances.
411 Float_ threshold_raw = Distance_::denormalize(threshold);
// Unlike search_nn, centers are visited in storage order, as the radius is fixed.
416 Index_ ncenters = my_sizes.size();
417 auto center_ptr = my_centers.data();
418 for (Index_ center = 0; center < ncenters; ++center, center_ptr += my_dim) {
419 const Float_ dist2center = Distance_::normalize(Distance_::template raw_distance<Float_>(target, center_ptr, my_dim));
// dIt points at this cluster's run of ascending distances-to-center.
421 auto cur_nobs = my_sizes[center];
422 const Float_* dIt = my_dist_to_centroid.data() + my_offsets[center];
423 const Float_ maxdist = *(dIt + cur_nobs - 1);
// Triangle inequality: members closer to the center than this bound cannot be within
// `threshold` of the query.
429 const Float_ lower_bd = dist2center - threshold;
// Even the farthest member is below the bound.
// NOTE(review): the statement that skips the cluster is not visible here.
430 if (maxdist < lower_bd) {
// Binary search for the first member whose distance to the center reaches the bound.
434 Index_ firstcell = std::lower_bound(dIt, dIt + cur_nobs, lower_bd) - dIt;
435#if KNNCOLLE_KMKNN_USE_UPPER
439 Float_ upper_bd = threshold + dist2center;
442 const auto cur_start = my_offsets[center];
443 auto other_ptr = my_data.data() + my_long_ndim *
static_cast<size_t>(cur_start + firstcell);
444 for (
auto celldex = firstcell; celldex < cur_nobs; ++celldex, other_ptr += my_dim) {
445#if KNNCOLLE_KMKNN_USE_UPPER
// Optional upper-bound test, compiled in only when KNNCOLLE_KMKNN_USE_UPPER is set.
// NOTE(review): the action taken when it fires is not visible here.
446 if (*(dIt + celldex) > upper_bd) {
451 auto dist2cell_raw = Distance_::template raw_distance<Float_>(target, other_ptr, my_dim);
452 if (dist2cell_raw <= threshold_raw) {
453 if constexpr(count_only_) {
456 all_neighbors.emplace_back(dist2cell_raw, cur_start + celldex);
// Converts search results in place. Either pointer may be null, in which case that
// output is left untouched.
//
// output_indices:   each entry is replaced by my_observation_id[entry], i.e. a position in
//                   the reordered data becomes the corresponding pre-reordering index.
// output_distances: each entry is passed through Distance_::normalize.
463 void normalize(std::vector<Index_>* output_indices, std::vector<Float_>* output_distances)
const {
464 if (output_indices) {
465 for (
auto& s : *output_indices) {
466 s = my_observation_id[s];
469 if (output_distances) {
470 for (
auto& d : *output_distances) {
471 d = Distance_::normalize(d);
// Grants the matching KmknnSearcher instantiation access to this class's non-public members.
476 friend class KmknnSearcher<Distance_, Dim_, Index_, Store_, Float_>;
// Creates a KmknnSearcher for this index and returns it through the Searcher interface.
// The searcher is constructed from `this`.
// NOTE(review): presumably it keeps that pointer, so this object must outlive the returned
// searcher — confirm against KmknnSearcher.
482 std::unique_ptr<Searcher<Index_, Float_> >
initialize()
const {
483 return std::make_unique<KmknnSearcher<Distance_, Dim_, Index_, Store_, Float_> >(
this);