// Flat coordinate storage. The constructor rearranges it in place, and the search
// methods read one point at offset (node index) * my_dim.
147 std::vector<Data_> my_data;
// Shared distance metric; the code in this class calls its raw() and normalize() methods.
148 std::shared_ptr<const DistanceMetric_> my_metric;
// Sentinel child index: the search methods do not descend into a child equal to LEAF.
// NOTE(review): presumably 0 is safe as a sentinel because node 0 is the root and is
// never anyone's child - confirm.
161 static const Index_ LEAF = 0;
// Node radius: build() sets it to the normalized distance at the median position,
// and the search methods compare the target's distance against it.
165 Distance_ radius = 0;
// Tree nodes, appended by build() in the order in which they are created.
179 std::vector<Node> my_nodes;
// (distance, observation index) pair used as working storage while building the tree.
181 typedef std::pair<Distance_, Index_> DataPoint;
// Recursively creates the node(s) for items[lower, upper) and appends them to my_nodes.
//
// lower, upper: bounds of the half-open range of 'items' handled by this call.
// coords: base of the coordinate array; observation k is read at coords + k * my_dim.
// items: (distance, observation index) pairs; reordered in place and their
//        distance field is overwritten on every call that picks a vantage point.
// rng: callable random source used to choose the vantage point.
//
// NOTE(review): the return statement is not visible here; since callers store the
// result in node.left / node.right, it presumably returns 'pos' - confirm.
184 Index_ build(Index_ lower, Index_ upper,
const Data_* coords, std::vector<DataPoint>& items, Rng_& rng) {
// Index that the newly appended node occupies in my_nodes.
191 Index_ pos = my_nodes.size();
192 my_nodes.emplace_back();
// This reference is used after the recursive calls below have appended more nodes,
// so it is only valid while my_nodes does not reallocate.
// NOTE(review): relies on the caller having reserved enough capacity beforehand - confirm.
193 Node& node = my_nodes.back();
195 Index_ gap = upper - lower;
// Pick one item of the range by modulo and swap it to the front as the vantage point.
// NOTE(review): 'rng() % gap' requires gap != 0; presumably guarded by a condition
// that is not visible here - confirm.
206 Index_ i = (rng() % gap + lower);
207 std::swap(items[lower], items[i]);
208 const auto& vantage = items[lower];
209 node.index = vantage.second;
// The index is cast to std::size_t before the multiplication so that the offset
// is computed in std::size_t rather than in Index_.
210 const Data_* vantage_ptr = coords +
static_cast<std::size_t
>(vantage.second) * my_dim;
// Store the raw (un-normalized) distance from the vantage point in every other item.
213 for (Index_ i = lower + 1; i < upper; ++i) {
214 const Data_* loc = coords +
static_cast<std::size_t
>(items[i].second) * my_dim;
215 items[i].first = my_metric->raw(my_dim, vantage_ptr, loc);
// Partition the non-vantage items around the element at the median position:
// after nth_element, nothing before 'median' compares greater than items[median].
219 Index_ median = lower + gap/2;
220 Index_ lower_p1 = lower + 1;
221 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
// The radius is the normalized form of the raw distance at the median position.
224 node.radius = my_metric->normalize(items[median].first);
// Items in [lower_p1, median) go to the left child, items in [median, upper) to the right.
227 if (lower_p1 < median) {
228 node.left = build(lower_p1, median, coords, items, rng);
230 if (median < upper) {
231 node.right = build(median, upper, coords, items, rng);
// NOTE(review): presumably the alternative branch for a range holding a single item
// (the enclosing condition is not visible here); it only records that item's index.
235 const auto& leaf = items[lower];
236 node.index = leaf.second;
// Indexed by original observation index; the constructor stores the index of the
// node in my_nodes whose 'index' field equals that observation.
243 std::vector<Index_> my_new_locations;
// Takes ownership of the coordinates and the metric, builds the tree, and then
// rearranges my_data in place while filling my_new_locations.
//
// num_dim: number of dimensions per observation.
// num_obs: number of observations.
// data: coordinates, taken by value and moved into my_data; observation k is read
//       at offset k * my_dim.
// metric: distance metric, moved into my_metric.
//
// NOTE(review): the rearrangement appears intended to make row k of my_data hold the
// coordinates of my_nodes[k].index, which is how the search methods index it; some
// statements of the loop are not visible here - confirm.
249 VptreePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
252 my_data(std::move(data)),
253 my_metric(std::move(metric))
// One (distance, index) item per observation; the distance starts at 0 and is
// overwritten by build().
256 std::vector<DataPoint> items;
257 items.reserve(my_obs);
258 for (Index_ i = 0; i < my_obs; ++i) {
259 items.emplace_back(0, i);
// Reserve before building: build() keeps a reference to a node while appending more.
262 my_nodes.reserve(my_obs);
// The seed is computed from my_obs and my_dim only, in uint64_t arithmetic, so the
// random stream is the same for any two inputs with the same dimensions.
268 uint64_t base = 1234567890, m1 = my_obs, m2 = my_dim;
269 std::mt19937_64 rand(base * m1 + m2);
271 build(0, my_obs, my_data.data(), items, rand);
// Working storage for the in-place rearrangement: 'used' flags rows, 'buffer'
// holds one row (my_dim values).
275 std::vector<uint8_t> used(my_obs);
276 std::vector<Data_> buffer(my_dim);
277 my_new_locations.resize(my_obs);
278 auto host = my_data.data();
280 for (Index_ o = 0; o < num_obs; ++o) {
// Record that the observation referenced by node 'o' is located at node position 'o'.
285 auto& current = my_nodes[o];
286 my_new_locations[current.index] = o;
// NOTE(review): presumably skips to the next 'o' when the row is already in place
// (the body of this condition is not visible here) - confirm.
287 if (current.index == o) {
// Save row 'o' before it is overwritten.
291 auto optr = host +
static_cast<std::size_t
>(o) * my_dim;
292 std::copy_n(optr, my_dim, buffer.begin());
293 Index_ replacement = current.index;
// Copy row 'replacement' into the destination row and flag 'replacement' as used.
296 auto rptr = host +
static_cast<std::size_t
>(replacement) * my_dim;
297 std::copy_n(rptr, my_dim, optr);
298 used[replacement] = 1;
// Record the location for the node at position 'replacement', then follow its index.
300 const auto& next = my_nodes[replacement];
301 my_new_locations[next.index] = replacement;
304 replacement = next.index;
// Keep following the chain until it leads back to the starting row 'o'.
305 }
while (replacement != o);
// Write the saved copy of the original row 'o' to the last destination.
307 std::copy(buffer.begin(), buffer.end(), optr);
// Recursive nearest-neighbor search starting at node 'curnode_index'.
//
// curnode_index: index into my_nodes; also used as the row of my_data holding this
//                node's coordinates.
// target: coordinates of the query point (my_dim values are passed to the metric).
// max_dist: in/out search bound; replaced by nearest.limit() whenever the queue
//           reports is_full() after an insertion.
// nearest: queue receiving (observation index, distance) via add().
316 void search_nn(Index_ curnode_index,
const Data_* target, Distance_& max_dist, NeighborQueue<Index_, Distance_>& nearest)
const {
// Normalized distance from the target to this node's point.
317 auto nptr = my_data.data() +
static_cast<std::size_t
>(curnode_index) * my_dim;
318 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
321 const auto& curnode = my_nodes[curnode_index];
// Offer this node's point to the queue when it is within the current bound.
322 if (dist <= max_dist) {
323 nearest.add(curnode.index, dist);
324 if (nearest.is_full()) {
325 max_dist = nearest.limit();
// Target is closer to this node's point than the radius: try the left child first,
// then the right. A child is visited only if it exists (!= LEAF) and its bound
// check against the radius passes with the current max_dist.
329 if (dist < curnode.radius) {
330 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
331 search_nn(curnode.left, target, max_dist, nearest);
334 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
335 search_nn(curnode.right, target, max_dist, nearest);
// NOTE(review): presumably the 'else' branch (its header is not visible here):
// the same two checks, but the right child is tried before the left.
339 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
340 search_nn(curnode.right, target, max_dist, nearest);
343 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
344 search_nn(curnode.left, target, max_dist, nearest);
// Recursive range search starting at node 'curnode_index' with a fixed threshold.
//
// count_only_: compile-time switch selecting what is done for a point within the
//              threshold. NOTE(review): the count-only branch body is not visible
//              here; presumably it increments all_neighbors as a counter - confirm.
// curnode_index: index into my_nodes; also used as the row of my_data holding this
//                node's coordinates.
// target: coordinates of the query point (my_dim values are passed to the metric).
// threshold: distance bound, compared against normalized distances.
// all_neighbors: output; in the visible branch it receives (distance, observation
//                index) via emplace_back().
349 template<
bool count_only_,
typename Output_>
350 void search_all(Index_ curnode_index,
const Data_* target, Distance_ threshold, Output_& all_neighbors)
const {
// Normalized distance from the target to this node's point.
351 auto nptr = my_data.data() +
static_cast<std::size_t
>(curnode_index) * my_dim;
352 Distance_ dist = my_metric->normalize(my_metric->raw(my_dim, nptr, target));
355 const auto& curnode = my_nodes[curnode_index];
356 if (dist <= threshold) {
357 if constexpr(count_only_) {
360 all_neighbors.emplace_back(dist, curnode.index);
// Target is closer to this node's point than the radius: try the left child first,
// then the right. A child is visited only if it exists (!= LEAF) and its bound
// check against the radius passes for the given threshold.
364 if (dist < curnode.radius) {
365 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
366 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
369 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
370 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
// NOTE(review): presumably the 'else' branch (its header is not visible here):
// the same two checks, but the right child is tried before the left.
374 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
375 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
378 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
379 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
// Gives the matching VptreeSearcher instantiation access to this class's non-public members.
384 friend class VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_>;
// Creates a VptreeSearcher constructed from *this and returns it as a unique_ptr
// to the Searcher type.
390 std::unique_ptr<Searcher<Index_, Data_, Distance_> >
initialize()
const {
391 return std::make_unique<VptreeSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);