1#ifndef KNNCOLLE_VPTREE_HPP
2#define KNNCOLLE_VPTREE_HPP
22#include "sanisizer/sanisizer.hpp"
// Identifier written to the "ALGORITHM" file by VptreePrebuilt::save(),
// presumably so that a loader can recognize a saved VP-tree index.
35inline static constexpr const char* vptree_prebuilt_save_name =
"knncolle::Vptree";
// Presumably the forward declaration of VptreePrebuilt, needed because
// VptreeSearcher refers to it before its definition (the declaration line
// itself is not visible in this excerpt).
40template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
// Searcher over a VptreePrebuilt index. It stores a reference to its parent
// index plus scratch buffers that are reused across queries.
43template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
44class VptreeSearcher final :
public Searcher<Index_, Data_, Distance_> {
46 VptreeSearcher(
const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& parent) : my_parent(parent) {}
// Non-owning reference: the parent index must outlive this searcher.
49 const VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& my_parent;
// Neighbor queue for k-NN searches; reset at the start of each search.
50 NeighborQueue<Index_, Distance_> my_nearest;
// (distance, index) pairs collected by search_all(); cleared before each use.
51 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
// Finds the k nearest neighbors of observation i, which is itself stored in the index.
// The queue holds k + 1 entries and i is passed to report(), presumably so that
// the observation itself can be dropped from the results (TODO confirm in NeighborQueue).
// NOTE(review): k + 1 overflows if k is the maximum value of Index_.
54 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
55 my_nearest.reset(k + 1);
// The parent stores coordinates in node order, so observation i lives at row my_new_locations[i].
56 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
// Start with an unbounded radius; search_nn() shrinks it once the queue is full.
57 Distance_ max_dist = std::numeric_limits<Distance_>::max();
// Node 0 is the root of the tree.
58 my_parent.search_nn(0, iptr, max_dist, my_nearest);
59 my_nearest.report(output_indices, output_distances, i);
// Finds the k nearest neighbors of an arbitrary query vector.
// query is assumed to point to my_dim values (TODO confirm against the Searcher interface).
62 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// Nothing to search for (k == 0) or nothing to search in (empty tree):
// clear the requested outputs. Part of this branch is not visible in this excerpt.
65 if (k == 0 || my_parent.my_nodes.empty()) {
67 output_indices->clear();
69 if (output_distances) {
70 output_distances->clear();
// Unbounded initial radius; search_nn() tightens it as neighbors are found.
75 Distance_ max_dist = std::numeric_limits<Distance_>::max();
// Start the traversal at the root (node 0).
76 my_parent.search_nn(0, query, max_dist, my_nearest);
77 my_nearest.report(output_indices, output_distances);
// Presumably reports whether search_all() is supported by this searcher
// (the body is not visible in this excerpt).
81 bool can_search_all()
const {
// Finds every indexed observation within distance d (inclusive) of observation i.
// If neither output is requested, the tree is only counted (search_all<true>).
// Otherwise, (distance, index) pairs are collected into my_all_neighbors.
85 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// Observation i lives at row my_new_locations[i] of the parent's node-ordered data.
86 auto iptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(my_parent.my_new_locations[i], my_parent.my_dim);
88 if (!output_indices && !output_distances) {
90 my_parent.template search_all<true>(0, iptr, d, count);
// Reuse the scratch buffer; it may hold results from a previous call.
94 my_all_neighbors.clear();
95 my_parent.template search_all<false>(0, iptr, d, my_all_neighbors);
// Finds every indexed observation within distance d (inclusive) of an arbitrary
// query vector and returns how many were found.
101 Index_ search_all(
const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// Empty index: there is no root node to start from.
102 if (my_parent.my_nodes.empty()) {
103 my_all_neighbors.clear();
// Count-only traversal when the caller wants no outputs.
108 if (!output_indices && !output_distances) {
110 my_parent.template search_all<true>(0, query, d, count);
// Reuse the scratch buffer; it may hold results from a previous call.
114 my_all_neighbors.clear();
115 my_parent.template search_all<false>(0, query, d, my_all_neighbors);
117 return my_all_neighbors.size();
// Prebuilt vantage-point tree. The constructor reorders observations in my_data
// to match the order of the tree nodes. my_metric supplies raw and normalized distances.
122template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
123class VptreePrebuilt final :
public Prebuilt<Index_, Data_, Distance_> {
// Coordinates: each observation occupies my_dim contiguous values.
127 std::vector<Data_> my_data;
// Shared with the builder that created this index.
128 std::shared_ptr<const DistanceMetric_> my_metric;
131 Index_ num_observations()
const {
135 std::size_t num_dimensions()
const {
// Child index meaning "no child". Node 0 is always the root, so it is never a
// child, which leaves 0 free to mark a missing subtree.
141 static const Index_ LEAF = 0;
// Normalized median distance from this node's vantage point, set in build();
// it separates the left (inside) and right (outside) subtrees.
145 Distance_ radius = 0;
// Tree nodes in construction order: pre-order, root first.
159 std::vector<Node> my_nodes;
// (distance to the current vantage point, observation index), used while building.
161 typedef std::pair<Distance_, Index_> DataPoint;
// Recursively builds the subtree over items[lower, upper) and returns its root's
// position in my_nodes. Each call appends exactly one node and consumes one
// observation, so the whole tree has one node per observation.
// Assumes lower < upper; the guard for an empty dataset is not visible here (TODO confirm).
164 Index_ build(Index_ lower, Index_ upper,
const Data_* coords, std::vector<DataPoint>& items, Rng_& rng) {
171 Index_ pos = my_nodes.size();
172 my_nodes.emplace_back();
// This reference is held across the recursive emplace_back() calls below.
// That is only safe because the constructor reserves one node per observation,
// so my_nodes never reallocates.
173 Node& node = my_nodes.back();
175 Index_ gap = upper - lower;
// Pick a random vantage point and move it to the front of the range.
186 Index_ i = (rng() % gap + lower);
187 std::swap(items[lower], items[i]);
188 const auto& vantage = items[lower];
189 node.index = vantage.second;
190 const Data_* vantage_ptr = coords + sanisizer::product_unsafe<std::size_t>(vantage.second, my_dim);
// Raw (unnormalized) distance from the vantage point to every other point in the range.
193 for (Index_ i = lower + 1; i < upper; ++i) {
194 const Data_* loc = coords + sanisizer::product_unsafe<std::size_t>(items[i].second, my_dim);
195 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
// Partition the remaining points (excluding the vantage point) around the median distance.
// Points before the median form the left (inside) subtree; the rest form the right (outside) subtree.
199 Index_ median = lower + gap/2;
200 Index_ lower_p1 = lower + 1;
201 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
204 node.radius = my_metric->normalize(items[median].first);
207 if (lower_p1 < median) {
208 node.left = build(lower_p1, median, coords, items, rng);
210 if (median < upper) {
211 node.right = build(median, upper, coords, items, rng);
// Presumably the single-element (leaf) case; the enclosing condition is not visible here.
215 const auto& leaf = items[lower];
216 node.index = leaf.second;
// Maps each original observation index to its row in the node-ordered my_data.
223 std::vector<Index_> my_new_locations;
// Builds the tree over num_obs observations of num_dim contiguous values each,
// taken from data. It then permutes data in place into node order.
226 VptreePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
229 my_data(std::move(data)),
230 my_metric(std::move(metric))
// One entry per observation; distances start as placeholder zeros.
233 std::vector<DataPoint> items;
234 items.reserve(my_obs);
235 for (Index_ i = 0; i < my_obs; ++i) {
236 items.emplace_back(0, i);
// build() holds references into my_nodes across recursive calls, so storage must
// never reallocate. Reserve exactly one node per observation up front.
239 my_nodes.reserve(my_obs);
// Fixed seed mixing a constant with my_obs and my_dim, so the tree is
// deterministic for a given dataset shape.
245 const std::mt19937_64::result_type base = 1234567890, m1 = my_obs, m2 = my_dim;
246 std::mt19937_64 rand(base * m1 + m2);
248 build(0, my_obs, my_data.data(), items, rand);
// Permute the rows of my_data in place so that row o holds the coordinates of
// node o's observation. Each permutation cycle is followed using one buffered row.
// used[] records rows that have already been filled while following a cycle.
252 auto used = sanisizer::create<std::vector<char> >(sanisizer::attest_gez(my_obs));
253 auto buffer = sanisizer::create<std::vector<Data_> >(sanisizer::attest_gez(my_dim));
254 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
255 auto host = my_data.data();
257 for (Index_ o = 0; o < num_obs; ++o) {
262 auto& current = my_nodes[o];
263 my_new_locations[current.index] = o;
// Row o already holds node o's observation.
264 if (current.index == o) {
// Save row o before it is overwritten by the first step of the cycle.
268 auto optr = host + sanisizer::product_unsafe<std::size_t>(o, my_dim);
269 std::copy_n(optr, my_dim, buffer.begin());
270 Index_ replacement = current.index;
273 auto rptr = host + sanisizer::product_unsafe<std::size_t>(replacement, my_dim);
274 std::copy_n(rptr, my_dim, optr);
275 used[replacement] = 1;
277 const auto& next = my_nodes[replacement];
278 my_new_locations[next.index] = replacement;
281 replacement = next.index;
282 }
while (replacement != o);
// Write the saved row o into optr, presumably the last slot of the cycle
// (assumes optr is advanced along the cycle in lines not visible here).
284 std::copy(buffer.begin(), buffer.end(), optr);
// Recursive k-nearest-neighbor search starting from node curnode_index.
// - The node's observation is added to `nearest` if it lies within max_dist.
// - max_dist shrinks to the queue's limit once the queue is full.
// - Child subtrees that could still hold closer points are then visited, the
//   target's own side of the ball first.
// The pruning tests assume the normalized distance obeys the triangle inequality.
290 void search_nn(Index_ curnode_index,
const Data_* target, Distance_& max_dist, NeighborQueue<Index_, Distance_>& nearest)
const {
// my_data is stored in node order, so row curnode_index holds this node's observation.
291 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_index, my_dim);
292 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
295 const auto& curnode = my_nodes[curnode_index];
296 if (dist <= max_dist) {
297 nearest.add(curnode.index, dist);
298 if (nearest.is_full()) {
299 max_dist = nearest.limit();
// Target inside this node's ball: search the inside (left) subtree first.
// Left points are at most `radius` from the vantage point, so they can only be
// within max_dist of the target if dist - max_dist <= radius.
303 if (dist < curnode.radius) {
304 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
305 search_nn(curnode.left, target, max_dist, nearest);
// Right points are at least `radius` from the vantage point, so they are only
// reachable if dist + max_dist >= radius.
308 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
309 search_nn(curnode.right, target, max_dist, nearest);
// Presumably the else-branch (target on or outside the ball): search the outside
// (right) subtree first, then the inside one, using the same pruning tests.
313 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
314 search_nn(curnode.right, target, max_dist, nearest);
317 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
318 search_nn(curnode.left, target, max_dist, nearest);
// Recursive fixed-radius search: visits every node whose observation is within
// `threshold` (inclusive) of target, using the same pruning tests as search_nn().
// When count_only_ is true, all_neighbors is presumably a counter that is
// incremented (that line is not visible here). Otherwise it receives
// (distance, observation index) pairs.
323 template<
bool count_only_,
typename Output_>
324 void search_all(Index_ curnode_index,
const Data_* target, Distance_ threshold, Output_& all_neighbors)
const {
// my_data is stored in node order, so row curnode_index holds this node's observation.
325 auto nptr = my_data.data() + sanisizer::product_unsafe<std::size_t>(curnode_index, my_dim);
326 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
329 const auto& curnode = my_nodes[curnode_index];
330 if (dist <= threshold) {
331 if constexpr(count_only_) {
334 all_neighbors.emplace_back(dist, curnode.index);
// Target inside the ball: inside (left) subtree first, then outside (right).
338 if (dist < curnode.radius) {
339 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
340 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
343 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
344 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
// Presumably the else-branch: outside (right) subtree first, then inside (left).
348 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
349 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
352 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
353 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
// Gives VptreeSearcher direct access to this class's internals
// (my_data, my_nodes, my_new_locations, search_nn, search_all).
358 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
// Creates a searcher bound to this index. The searcher keeps a reference to
// *this, so this index must outlive it.
361 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize()
const {
362 return initialize_known();
// Same as initialize(), but returns a pointer to the concrete VptreeSearcher type.
365 auto initialize_known()
const {
366 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
// Writes the index into directory dir. Files written: ALGORITHM (the save name),
// DATA, NODES and NEW_LOCATIONS. The metric is saved into a DISTANCE subdirectory.
// NOTE(review): NODES is written from my_nodes.data() directly, so the file
// presumably depends on the in-memory Node layout. Confirm before relying on
// portability across builds.
370 void save(
const std::filesystem::path& dir)
const {
371 quick_save(dir /
"ALGORITHM", vptree_prebuilt_save_name, std::strlen(vptree_prebuilt_save_name));
372 quick_save(dir /
"DATA", my_data.data(), my_data.size());
375 quick_save(dir /
"NODES", my_nodes.data(), my_nodes.size());
376 quick_save(dir /
"NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
378 const auto distdir = dir /
"DISTANCE";
379 std::filesystem::create_directory(distdir);
380 my_metric->save(distdir);
// Reconstructs an index from a directory written by save().
// my_obs and my_dim are presumably loaded in lines not visible here; they size the buffers below.
383 VptreePrebuilt(
const std::filesystem::path& dir) {
387 my_data.resize(sanisizer::product<I<
decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
388 quick_load(dir /
"DATA", my_data.data(), my_data.size());
390 sanisizer::resize(my_nodes, sanisizer::attest_gez(my_obs));
391 quick_load(dir /
"NODES", my_nodes.data(), my_nodes.size());
393 sanisizer::resize(my_new_locations, sanisizer::attest_gez(my_obs));
394 quick_load(dir /
"NEW_LOCATIONS", my_new_locations.data(), my_new_locations.size());
396 auto dptr = load_distance_metric_raw<Data_, Distance_>(dir /
"DISTANCE");
397 auto xptr =
dynamic_cast<DistanceMetric_*
>(dptr);
// NOTE(review): dptr is not released before the throw below when the cast fails.
// my_metric.reset(xptr) takes ownership on success, so the pointer is owning,
// and the loaded metric leaks on this error path. Consider holding dptr in a
// std::unique_ptr until the cast succeeds.
399 throw std::runtime_error(
"cannot cast the loaded distance metric to a DistanceMetric_");
401 my_metric.reset(xptr);
453 class Matrix_ = Matrix<Index_, Data_>,
454 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
// The metric is stored here and shared (via shared_ptr copies) with every index this builder creates.
461 VptreeBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
464 std::shared_ptr<const DistanceMetric_> my_metric;
// Body of build_known_raw(): copies every observation of `data` into a contiguous
// buffer with ndim values per observation, then builds a VptreePrebuilt from it.
482 std::size_t ndim = data.num_dimensions();
483 Index_ nobs = data.num_observations();
// Presumably yields a pointer to each observation's ndim values in turn (TODO confirm).
484 auto work = data.new_known_extractor();
487 std::vector<Data_> store(sanisizer::product<
typename std::vector<Data_>::size_type>(ndim, sanisizer::attest_gez(nobs)));
488 for (Index_ o = 0; o < nobs; ++o) {
489 std::copy_n(work->next(), ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
// Returns a raw pointer; presumably the caller (e.g. build_known_unique/shared) takes ownership.
492 return new VptreePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Helper class to track nearest neighbors.
Interface for prebuilt nearest-neighbor indices.
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
virtual Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const =0
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:29
Perform a nearest neighbor search based on a vantage point (VP) tree.
Definition Vptree.hpp:456
VptreeBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Vptree.hpp:461
auto build_known_unique(const Matrix_ &data) const
Definition Vptree.hpp:498
auto build_known_raw(const Matrix_ &data) const
Definition Vptree.hpp:481
auto build_known_shared(const Matrix_ &data) const
Definition Vptree.hpp:505
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
void quick_load(const std::filesystem::path &path, Input_ *const contents, const Length_ length)
Definition utils.hpp:57
Index_ count_all_neighbors_without_self(Index_ count)
Definition report_all_neighbors.hpp:23
void quick_save(const std::filesystem::path &path, const Input_ *const contents, const Length_ length)
Definition utils.hpp:33
void report_all_neighbors(std::vector< std::pair< Distance_, Index_ > > &all_neighbors, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition report_all_neighbors.hpp:106
Format the output for Searcher::search_all().
Miscellaneous utilities for knncolle