1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
24#include "sanisizer/sanisizer.hpp"
/**
 * Algorithm identifier for the vantage-point tree index.
 * `VptreePrebuilt::save()` writes this string (its `std::strlen()` characters,
 * i.e., without a null terminator) to the `ALGORITHM` file of the save directory.
 *
 * Declared `inline constexpr` (without `static`) so that every translation unit
 * including this header refers to the same single definition.
 */
inline constexpr const char* vptree_prebuilt_save_name = "knncolle::Vptree";
47 std::optional<typename std::mt19937_64::result_type>
seed;
53template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
56template<
typename Index_>
57struct VptreeSearchHistory {
58 VptreeSearchHistory(
bool right, Index_ node) : node(node), right(right) {}
63template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
64class VptreeSearcher final :
public Searcher<Index_, Data_, Distance_> {
66 VptreeSearcher(
const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& parent) : my_parent(parent) {}
69 const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& my_parent;
70 NeighborQueue<Index_, Distance_> my_nearest;
71 std::vector<VptreeSearchHistory<Index_> > my_history;
72 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
75 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
76 my_nearest.reset(k + 1);
77 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
78 my_parent.search_nn(iptr, my_nearest, my_history);
79 my_nearest.report(output_indices, output_distances, i);
82 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
85 if (k == 0 || my_parent.my_nodes.empty()) {
87 output_indices->clear();
89 if (output_distances) {
90 output_distances->clear();
95 my_parent.search_nn(query, my_nearest, my_history);
96 my_nearest.report(output_indices, output_distances);
100 bool can_search_all()
const {
104 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
105 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
107 if (!output_indices && !output_distances) {
109 my_parent.template search_all<true>(iptr, d, count, my_history);
113 my_all_neighbors.clear();
114 my_parent.template search_all<false>(iptr, d, my_all_neighbors, my_history);
120 Index_ search_all(
const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
121 if (my_parent.my_nodes.empty()) {
122 my_all_neighbors.clear();
127 if (!output_indices && !output_distances) {
129 my_parent.template search_all<true>(query, d, count, my_history);
133 my_all_neighbors.clear();
134 my_parent.template search_all<false>(query, d, my_all_neighbors, my_history);
136 return my_all_neighbors.size();
141template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
142class VptreePrebuilt final :
public Prebuilt<Index_, Data_, Distance_> {
146 std::vector<Data_> my_data;
147 std::shared_ptr<const DistanceMetric_> my_metric;
150 Index_ num_observations()
const {
154 std::size_t num_dimensions()
const {
164 static constexpr Index_ TERMINAL = 0;
168 Distance_ radius = 0;
174 Index_ left = TERMINAL;
177 Index_ right = TERMINAL;
180 std::vector<Node> my_nodes;
182 void build(
const VptreeOptions& options) {
183 typedef std::pair<Distance_, Index_> DataPoint;
184 std::vector<DataPoint> items;
185 items.reserve(my_obs);
186 for (Index_ i = 0; i < my_obs; ++i) {
187 items.emplace_back(0, i);
190 std::mt19937_64 rng([&]() {
191 if (options.seed.has_value()) {
192 return *(options.seed);
199 typedef typename std::mt19937_64::result_type SeedType;
200 const SeedType base = 1234567890, m1 = my_obs, m2 = my_dim;
201 return static_cast<SeedType
>(base * m1 + m2);
205 Index_ lower = 0, upper = my_obs;
208 my_nodes.reserve(my_obs);
209 const auto coords = my_data.data();
211 struct BuildHistory {
212 BuildHistory(Index_ lower, Index_ upper, Index_* right) : right(right), lower(lower), upper(upper) {}
216 std::vector<BuildHistory> history;
219 my_nodes.emplace_back();
220 Node& node = my_nodes.back();
222 const Index_ gap = upper - lower;
225 const auto& leaf = items[lower];
226 node.index = leaf.second;
229 if (history.empty()) {
232 *(history.back().right) = my_nodes.size();
233 lower = history.back().lower;
234 upper = history.back().upper;
242 const Index_ vp = (rng() % gap + lower);
243 std::swap(items[lower], items[vp]);
244 const auto& vantage = items[lower];
245 node.index = vantage.second;
246 const Data_* vantage_ptr = coords + sanisizer::product_unsafe<std::size_t>(vantage.second, my_dim);
250 const Index_ lower_p1 = lower + 1;
251 for (Index_ i = lower_p1 ; i < upper; ++i) {
252 const Data_* loc = coords + sanisizer::product_unsafe<std::size_t>(items[i].second, my_dim);
253 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
258 const Index_ median = lower_p1 + (gap - 1)/2;
259 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
262 node.radius = my_metric->normalize(items[median].first);
266 history.emplace_back(median, upper, &(node.right));
267 node.left = my_nodes.size();
274 const Index_ median = lower_p1;
275 node.radius = my_metric->normalize(items[median].first);
276 node.right = my_nodes.size();
288 std::vector<Index_> my_new_locations;
291 VptreePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric,
const VptreeOptions& options) :
294 my_data(std::move(data)),
295 my_metric(std::move(metric))
301 auto used = sanisizer::create<std::vector<char> >(sanisizer::attest_gez(my_obs));
302 auto buffer = sanisizer::create<std::vector<Data_> >(sanisizer::attest_gez(my_dim));
303 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
304 auto host = my_data.data();
306 for (Index_ o = 0; o < num_obs; ++o) {
311 auto& current = my_nodes[o];
312 my_new_locations[current.index] = o;
313 if (current.index == o) {
317 auto optr = host + sanisizer::product_unsafe<std::size_t>(o, my_dim);
318 std::copy_n(optr, my_dim, buffer.begin());
319 Index_ replacement = current.index;
322 auto rptr = host + sanisizer::product_unsafe<std::size_t>(replacement, my_dim);
323 std::copy_n(rptr, my_dim, optr);
324 used[replacement] = 1;
326 const auto& next = my_nodes[replacement];
327 my_new_locations[next.index] = replacement;
330 replacement = next.index;
331 }
while (replacement != o);
333 std::copy(buffer.begin(), buffer.end(), optr);
339 static bool can_progress_left(
const Node& node,
const Distance_ dist_to_vp,
const Distance_ threshold) {
340 return node.left != TERMINAL && dist_to_vp - threshold <= node.radius;
343 static bool can_progress_right(
const Node& node,
const Distance_ dist_to_vp,
const Distance_ threshold) {
346 return node.right != TERMINAL && dist_to_vp + threshold >= node.radius;
349 void search_nn(
const Data_* target, NeighborQueue<Index_, Distance_>& nearest, std::vector<VptreeSearchHistory<Index_> >& history)
const {
351 Index_ curnode_offset = 0;
352 Distance_ max_dist = std::numeric_limits<Distance_>::max();
355 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_offset, my_dim);
356 const Distance_ dist_to_vp = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
358 const auto& curnode = my_nodes[curnode_offset];
359 if (dist_to_vp <= max_dist) {
360 nearest.add(curnode.index, dist_to_vp);
361 if (nearest.is_full()) {
362 max_dist = nearest.limit();
366 if (dist_to_vp < curnode.radius) {
372 const bool can_left = curnode.left != TERMINAL;
373 const bool can_right = can_progress_right(curnode, dist_to_vp, max_dist);
377 history.emplace_back(
false, curnode_offset);
379 curnode_offset = curnode.left;
381 }
else if (can_right) {
382 curnode_offset = curnode.right;
392 const bool can_right = curnode.right != TERMINAL;
393 const bool can_left = can_progress_left(curnode, dist_to_vp, max_dist);
397 history.emplace_back(
true, curnode_offset);
399 curnode_offset = curnode.right;
409 if (history.empty()) {
413 auto& histinfo = history.back();
414 if (!histinfo.right) {
415 curnode_offset = my_nodes[histinfo.node].right;
417 curnode_offset = my_nodes[histinfo.node].left;
423 template<
bool count_only_,
typename Output_>
424 void search_all(
const Data_* target,
const Distance_ threshold, Output_& all_neighbors, std::vector<VptreeSearchHistory<Index_> >& history)
const {
426 Index_ curnode_offset = 0;
429 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_offset, my_dim);
430 const Distance_ dist_to_vp = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
432 const auto& curnode = my_nodes[curnode_offset];
433 if (dist_to_vp <= threshold) {
434 if constexpr(count_only_) {
437 all_neighbors.emplace_back(dist_to_vp, curnode.index);
441 const bool can_left = can_progress_left(curnode, dist_to_vp, threshold);
442 const bool can_right = can_progress_right(curnode, dist_to_vp, threshold);
448 history.emplace_back(
false, curnode_offset);
450 curnode_offset = curnode.left;
452 }
else if (can_right) {
453 curnode_offset = curnode.right;
458 if (history.empty()) {
462 auto& histinfo = history.back();
463 curnode_offset = my_nodes[histinfo.node].right;
468 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
471 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize()
const {
472 return initialize_known();
475 auto initialize_known()
const {
476 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
480 void save(
const std::filesystem::path& dir)
const {
481 quick_save(dir /
"ALGORITHM", vptree_prebuilt_save_name, std::strlen(vptree_prebuilt_save_name));
482 quick_save(dir /
"DATA", my_data.data(), my_data.size());
485 quick_save(dir /
"NODES", my_nodes.data(), my_nodes.size());
486 quick_save(dir /
"NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
488 const auto distdir = dir /
"DISTANCE";
489 std::filesystem::create_directory(distdir);
490 my_metric->save(distdir);
493 VptreePrebuilt(
const std::filesystem::path& dir) {
497 my_data.resize(sanisizer::product<I<
decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
498 quick_load(dir /
"DATA", my_data.data(), my_data.size());
500 sanisizer::resize(my_nodes, sanisizer::attest_gez(my_obs));
501 quick_load(dir /
"NODES", my_nodes.data(), my_nodes.size());
503 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
504 quick_load(dir /
"NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
506 auto dptr = load_distance_metric_raw<Data_, Distance_>(dir /
"DISTANCE");
507 auto xptr =
dynamic_cast<DistanceMetric_*
>(dptr);
509 throw std::runtime_error(
"cannot cast the loaded distance metric to a DistanceMetric_");
511 my_metric.reset(xptr);
563 class Matrix_ = Matrix<Index_, Data_>,
564 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
572 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric,
VptreeOptions options) : my_metric(std::move(metric)), my_options(std::move(options)) {}
590 std::shared_ptr<const DistanceMetric_> my_metric;
609 std::size_t ndim = data.num_dimensions();
610 Index_ nobs = data.num_observations();
611 auto work = data.new_known_extractor();
614 std::vector<Data_> store(sanisizer::product<
typename std::vector<Data_>::size_type>(ndim, sanisizer::attest_gez(nobs)));
615 for (Index_ o = 0; o < nobs; ++o) {
616 std::copy_n(work->next(), ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
619 return new VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric, my_options);
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const =0
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:29
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:566
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Vptree.hpp:579
auto build_known_unique(const Matrix_ &data) const
Definition Vptree.hpp:625
VptreeOptions & get_options()
Definition Vptree.hpp:585
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric, VptreeOptions options)
Definition Vptree.hpp:572
auto build_known_raw(const Matrix_ &data) const
Definition Vptree.hpp:608
auto build_known_shared(const Matrix_ &data) const
Definition Vptree.hpp:632
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
Definition utils.hpp:57
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Definition utils.hpp:33
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().
Options for VptreeBuilder construction.
Definition Vptree.hpp:42
std::optional< typename std::mt19937_64::result_type > seed
Definition Vptree.hpp:47
Miscellaneous utilities for knncolle