// Coordinate storage, taken by move from the constructor's `data` argument. The search
// routines read the values for position i starting at offset i * my_dim.
146 std::vector<Data_> my_data;
// Shared distance metric; the code in this class calls its raw() and normalize() members.
147 std::shared_ptr<const DistanceMetric_> my_metric;
// Sentinel child index: search_nn()/search_all() only descend into a child whose index != LEAF.
160 static const Index_ LEAF = 0;
// NOTE(review): presumably a field of Node (the struct header is not visible in this excerpt).
// build() stores the normalized median distance from the vantage point here.
164 Distance_ radius = 0;
// Tree nodes. build() appends one node per call; the searches index this vector by node position.
178 std::vector<Node> my_nodes;
// Work item used while building: `first` holds a distance to the current vantage point,
// `second` holds an observation index.
180 typedef std::pair<Distance_, Index_> DataPoint;
// Recursively adds the tree node covering the half-open range [lower, upper) of `items`.
//   lower, upper - bounds of the range of `items` handled by this call.
//   coords       - coordinate array; values for observation k start at coords + k * my_dim.
//   items        - (distance, observation index) pairs; reordered in place.
//   rng          - random number generator used to choose the vantage point.
// NOTE(review): the leaf/non-leaf branch condition and the return statement are not visible
// in this excerpt. The return value is stored in node.left / node.right by the caller, so it
// is presumably `pos` - confirm against the full file.
183 Index_ build(Index_ lower, Index_ upper,
const Data_* coords, std::vector<DataPoint>& items, Rng_& rng) {
// Position that the node created below will occupy in my_nodes.
190 Index_ pos = my_nodes.size();
191 my_nodes.emplace_back();
// This reference is held across the recursive calls below, which also append to my_nodes;
// the constructor reserves my_obs entries before calling build(), presumably so that no
// reallocation invalidates it.
192 Node& node = my_nodes.back();
194 Index_ gap = upper - lower;
// Pick a random item within the range and swap it to the front; it becomes the vantage point.
205 Index_ i = (rng() % gap + lower);
206 std::swap(items[lower], items[i]);
207 const auto& vantage = items[lower];
208 node.index = vantage.second;
209 const Data_* vantage_ptr = coords +
static_cast<size_t>(vantage.second) * my_dim;
// Record the metric's raw() distance from the vantage point for every other item in the range.
212 for (Index_ i = lower + 1; i < upper; ++i) {
213 const Data_* loc = coords +
static_cast<size_t>(items[i].second) * my_dim;
214 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
// Partition the items after the vantage point around the one at position `median`
// (pairs compare by distance first, then by index).
218 Index_ median = lower + gap/2;
219 Index_ lower_p1 = lower + 1;
220 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
// The node's radius is the normalized distance of the item at the median position.
223 node.radius = my_metric->normalize(items[median].first);
// Left child covers [lower + 1, median); it is skipped when that range is empty.
226 if (lower_p1 < median) {
227 node.left = build(lower_p1, median, coords, items, rng);
// Right child covers [median, upper).
229 if (median < upper) {
230 node.right = build(median, upper, coords, items, rng);
// NOTE(review): presumably the single-item (leaf) case, where only the observation index is
// recorded; the condition selecting this path is not visible here.
234 const auto& leaf = items[lower];
235 node.index = leaf.second;
// Filled by the constructor: entry k is the node position assigned to the observation whose
// index (as stored in Node::index) is k.
242 std::vector<Index_> my_new_locations;
// Builds the tree over `data` and then permutes the stored coordinates in place.
//   num_dim - number of dimensions (presumably stored in my_dim; that initializer is not visible here).
//   num_obs - number of observations (presumably stored in my_obs; likewise not visible here).
//   data    - coordinate values, moved into my_data.
//   metric  - distance metric, moved into my_metric.
248 VptreePrebuilt(
size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
251 my_data(std::move(data)),
252 my_metric(std::move(metric))
// One work item per observation: distance starts at 0, index is the observation number.
255 std::vector<DataPoint> items;
256 items.reserve(my_obs);
257 for (Index_ i = 0; i < my_obs; ++i) {
258 items.emplace_back(0, i);
// Reserve all nodes up front; build() keeps a reference into my_nodes while it recurses.
261 my_nodes.reserve(my_obs);
// The seed is computed only from my_obs and my_dim, so it is the same for inputs of the same size.
267 uint64_t base = 1234567890, m1 = my_obs, m2 = my_dim;
268 std::mt19937_64 rand(base * m1 + m2);
270 build(0, my_obs, my_data.data(), items, rand);
// In-place permutation of my_data: the row at position o is overwritten with the row of the
// observation that node o refers to. `used` flags rows already handled and `buffer` holds one
// row that is written back at the end of each pass.
// NOTE(review): the check on `used`, the opening of the do-loop and possibly other statements
// inside it are not visible in this excerpt.
274 std::vector<uint8_t> used(my_obs);
275 std::vector<Data_> buffer(my_dim);
276 my_new_locations.resize(my_obs);
277 auto host = my_data.data();
279 for (Index_ o = 0; o < num_obs; ++o) {
284 auto& current = my_nodes[o];
// Record where the observation referenced by node o ends up.
285 my_new_locations[current.index] = o;
// NOTE(review): when node o already refers to observation o the row is in place; the body of
// this branch is not visible here.
286 if (current.index == o) {
// Save row o before it is overwritten.
290 auto optr = host +
static_cast<size_t>(o) * my_dim;
291 std::copy_n(optr, my_dim, buffer.begin());
292 Index_ replacement = current.index;
// Copy the row at `replacement` into the destination and mark it as handled.
295 auto rptr = host +
static_cast<size_t>(replacement) * my_dim;
296 std::copy_n(rptr, my_dim, optr);
297 used[replacement] = 1;
299 const auto& next = my_nodes[replacement];
300 my_new_locations[next.index] = replacement;
// Move on to the observation referenced by the node at `replacement`; stop once back at o.
303 replacement = next.index;
304 }
while (replacement != o);
// Write the saved row to the destination pointed to by optr.
306 std::copy(buffer.begin(), buffer.end(), optr);
// Recursive nearest-neighbour search starting from node `curnode_index`.
//   curnode_index - position of the node to visit; also used as the row offset into my_data.
//   target        - coordinates of the query point (my_dim values).
//   max_dist      - current search radius; updated in place once `nearest` is full.
//   nearest       - queue collecting (observation index, distance) results.
315 void search_nn(Index_ curnode_index,
const Data_* target, Distance_& max_dist, NeighborQueue<Index_, Distance_>& nearest)
const {
// Normalized distance between the query and the point stored at this node's position.
316 auto nptr = my_data.data() +
static_cast<size_t>(curnode_index) * my_dim;
317 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
320 const auto& curnode = my_nodes[curnode_index];
// Offer this node's observation as a candidate; once the queue reports full, tighten max_dist
// to the queue's limit().
321 if (dist <= max_dist) {
322 nearest.add(curnode.index, dist);
323 if (nearest.is_full()) {
324 max_dist = nearest.limit();
// Query lies inside this node's radius: visit the left child first, then the right.
// Each child is visited only if it exists (!= LEAF) and passes the radius comparison on its line.
328 if (dist < curnode.radius) {
329 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
330 search_nn(curnode.left, target, max_dist, nearest);
333 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
334 search_nn(curnode.right, target, max_dist, nearest);
// NOTE(review): presumably the else branch (query outside the radius), which is not visible
// in this excerpt: same checks, but the right child is visited before the left.
338 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
339 search_nn(curnode.right, target, max_dist, nearest);
342 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
343 search_nn(curnode.left, target, max_dist, nearest);
// Recursive range search: visits node `curnode_index` and handles every point whose normalized
// distance to `target` is <= threshold.
//   count_only_   - compile-time switch selecting the count-only branch (its body is not visible
//                   in this excerpt).
//   curnode_index - position of the node to visit; also used as the row offset into my_data.
//   target        - coordinates of the query point (my_dim values).
//   threshold     - fixed search radius; unlike search_nn's max_dist it is passed by value.
//   all_neighbors - output; on the visible path it receives (distance, observation index) entries.
348 template<
bool count_only_,
typename Output_>
349 void search_all(Index_ curnode_index,
const Data_* target, Distance_ threshold, Output_& all_neighbors)
const {
// Normalized distance between the query and the point stored at this node's position.
350 auto nptr = my_data.data() +
static_cast<size_t>(curnode_index) * my_dim;
351 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
354 const auto& curnode = my_nodes[curnode_index];
355 if (dist <= threshold) {
// NOTE(review): the statement executed when count_only_ is true is not visible here; the
// emplace_back below is presumably the alternative taken when it is false.
356 if constexpr(count_only_) {
359 all_neighbors.emplace_back(dist, curnode.index);
// Query lies inside this node's radius: visit the left child first, then the right.
// Each child is visited only if it exists (!= LEAF) and passes the radius comparison on its line.
363 if (dist < curnode.radius) {
364 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
365 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
368 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
369 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
// NOTE(review): presumably the else branch (query outside the radius), which is not visible
// in this excerpt: same checks, but the right child is visited before the left.
373 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
374 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
377 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
378 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
// Grants the matching VptreeSearcher instantiation access to this class's non-public members.
383 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
// Creates a new VptreeSearcher constructed from *this and returns it through the
// Searcher base-class pointer type.
389 std::unique_ptr<Searcher<Index_, Data_, Distance_> >
initialize()
const {
390 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);