// Flat coordinate storage. Other members address one observation's values at
// my_data.data() + static_cast<size_t>(row) * my_long_ndim and read my_dim values.
130 std::vector<Store_> my_data;
// Sentinel child value: the search routines only descend into node.left or
// node.right when it differs from LEAF.
// NOTE(review): presumably 0 is safe as a sentinel because position 0 is taken
// by the first node that build() creates and so is never a child - confirm.
143 static const Index_ LEAF = 0;
// Tree nodes; build() appends one entry per call via emplace_back().
161 std::vector<Node> my_nodes;
// Work item used while building the tree: 'first' receives the raw distance to
// the current vantage point, 'second' is used as the observation index.
163 typedef std::pair<Float_, Index_> DataPoint;
// Builds the part of the tree covering items[lower, upper), appending a new node
// to my_nodes and recursing for the children.
//   lower, upper - bounds of the half-open range of 'items' handled by this call.
//   coords       - base pointer of the coordinates; observation k is read at
//                  coords + static_cast<size_t>(k) * my_long_ndim.
//   items        - (distance, observation index) pairs; reordered and overwritten in place.
//   rng          - random source, invoked as rng() to choose the vantage point.
// NOTE(review): some lines of this function (the branch on 'gap', the closing
// braces and the return statement) are not visible in this excerpt; the
// comments below describe only the visible statements.
166 Index_ build(Index_ lower, Index_ upper,
const Store_* coords, std::vector<DataPoint>& items, Rng_& rng) {
// Position that the new node occupies in my_nodes (the size before appending).
// NOTE(review): presumably this is the value returned to the caller - confirm.
173 Index_ pos = my_nodes.size();
174 my_nodes.emplace_back();
// This reference is held across the recursive calls below, which also append to
// my_nodes; the constructor calls my_nodes.reserve(my_obs) before build().
// NOTE(review): presumably that reservation is what keeps 'node' valid - confirm.
175 Node& node = my_nodes.back();
177 Index_ gap = upper - lower;
// Pick a position in [lower, upper) from the random source and swap that item
// to the front of the range; it becomes the vantage point for this node.
188 Index_ i = (rng() % gap + lower);
189 std::swap(items[lower], items[i]);
190 const auto& vantage = items[lower];
191 node.index = vantage.second;
192 const Store_* vantage_ptr = coords +
static_cast<size_t>(vantage.second) * my_long_ndim;
// Store the raw distance from the vantage point in 'first' of every other item
// in the range.
195 for (Index_ i = lower + 1; i < upper; ++i) {
196 const Store_* loc = coords +
static_cast<size_t>(items[i].second) * my_long_ndim;
197 items[i].first = Distance_::template raw_distance<Float_>(vantage_ptr, loc, my_dim);
// Partially order the non-vantage items so that items[median] is the element
// that would sit there if [lower + 1, upper) were fully sorted; no comparator is
// passed, so the pairs are compared with their default ordering.
201 Index_ median = lower + gap/2;
202 Index_ lower_p1 = lower + 1;
203 std::nth_element(items.begin() + lower_p1, items.begin() + median, items.begin() + upper);
// The node's radius is the normalized form of the raw distance found at 'median'.
206 node.radius = Distance_::normalize(items[median].first);
// Left child: items in [lower + 1, median), built only if that range is non-empty.
209 if (lower_p1 < median) {
210 node.left = build(lower_p1, median, coords, items, rng);
// Right child: items in [median, upper), which includes the item at 'median'.
212 if (median < upper) {
213 node.right = build(median, upper, coords, items, rng);
// Node index taken directly from items[lower], with no distance computation.
// NOTE(review): the enclosing condition is not visible here; presumably this is
// the branch taken when the range holds a single item - confirm.
217 const auto& leaf = items[lower];
218 node.index = leaf.second;
// Indexed by a node's 'index' value and holding the node position assigned to
// it; the constructor fills it with my_new_locations[node.index] = position.
225 std::vector<Index_> my_new_locations;
// NOTE(review): this looks like the tail of a constructor's initializer list
// followed by its body; the signature, the earlier initializers and several body
// lines (braces, 'continue' statements) are not visible in this excerpt.
// my_long_ndim is initialized from my_dim and is used as the per-observation
// stride in the pointer arithmetic throughout this file.
236 my_long_ndim(my_dim),
// Takes over the caller's buffer by move rather than copying it.
237 my_data(std::move(data))
// One work item per observation: distance 0 and the observation's own index.
240 std::vector<DataPoint> items;
241 items.reserve(my_obs);
242 for (Index_ i = 0; i < my_obs; ++i) {
243 items.emplace_back(0, i);
// Reserve space for my_obs nodes before building; build() keeps a reference to
// my_nodes.back() while its recursive calls append further nodes.
246 my_nodes.reserve(my_obs);
// The seed is computed only from my_obs and my_dim, so two datasets with the
// same dimensions start from the same generator state.
252 uint64_t base = 1234567890, m1 = my_obs, m2 = my_dim;
253 std::mt19937_64 rand(base * m1 + m2);
// Build the tree over the full range [0, my_obs).
255 build(0, my_obs, my_data.data(), items, rand);
// Scratch state for rearranging my_data in place: per-row flags, a buffer
// holding one observation's my_dim values, and the table of new locations.
259 std::vector<uint8_t> used(my_obs);
260 std::vector<Store_> buffer(my_dim);
261 my_new_locations.resize(my_obs);
262 auto host = my_data.data();
// Visit each node position 'o' and record where that node's observation ends up.
// NOTE(review): a check of used[o] is presumably made at the top of this loop,
// but it is not visible in this excerpt.
264 for (Index_ o = 0; o < num_obs; ++o) {
269 auto& current = my_nodes[o];
270 my_new_locations[current.index] = o;
// Case where the node at position 'o' already refers to observation 'o'.
// NOTE(review): the body of this branch is not visible; presumably it skips
// the copying below - confirm.
271 if (current.index == o) {
// Save the values currently stored in row 'o' so they can be written back later.
275 auto optr = host +
static_cast<size_t>(o) * my_long_ndim;
276 std::copy_n(optr, my_dim, buffer.begin());
277 Index_ replacement = current.index;
// Copy row 'replacement' into the slot addressed by optr, mark that row as used,
// record the new location for the node at position 'replacement', then move on
// to that node's own index. The loop ends when the chain arrives back at 'o'.
// NOTE(review): any update of 'optr' inside the loop is not visible in this
// excerpt.
280 auto rptr = host +
static_cast<size_t>(replacement) * my_long_ndim;
281 std::copy_n(rptr, my_dim, optr);
282 used[replacement] = 1;
284 const auto& next = my_nodes[replacement];
285 my_new_locations[next.index] = replacement;
288 replacement = next.index;
289 }
while (replacement != o);
// Write the values saved from row 'o' into the slot addressed by optr.
291 std::copy(buffer.begin(), buffer.end(), optr);
// Recursive nearest-neighbor search from the node at position curnode_index.
//   curnode_index - position of the node in my_nodes; the same value selects the
//                   row of my_data that is compared with 'target'.
//   target        - query coordinates; my_dim values are passed to raw_distance.
//   max_dist      - current search bound, updated in place.
//   nearest       - queue that collects (index, distance) candidates.
// NOTE(review): closing braces and the 'else' separating the two child-visit
// orderings are not visible in this excerpt.
297 template<
typename Query_>
298 void search_nn(Index_ curnode_index,
const Query_* target, Float_& max_dist, internal::NeighborQueue<Index_, Float_>& nearest)
const {
// Normalized distance between the query and this node's row of my_data.
299 auto nptr = my_data.data() +
static_cast<size_t>(curnode_index) * my_long_ndim;
300 Float_ dist = Distance_::normalize(Distance_::template raw_distance<Float_>(nptr, target, my_dim));
// Within the bound: report the node's 'index' value (not its position) and,
// once the queue reports it is full, replace the bound with nearest.limit().
303 const auto& curnode = my_nodes[curnode_index];
304 if (dist <= max_dist) {
305 nearest.add(curnode.index, dist);
306 if (nearest.is_full()) {
307 max_dist = nearest.limit();
// Query lies inside this node's radius: the left child is tested first, then
// the right. A child is entered only if it is not LEAF and its bound test
// passes: dist - max_dist <= radius for left, dist + max_dist >= radius for right.
311 if (dist < curnode.radius) {
312 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
313 search_nn(curnode.left, target, max_dist, nearest);
316 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
317 search_nn(curnode.right, target, max_dist, nearest);
// Same two tests in the opposite order: right child first, then left.
// NOTE(review): presumably this is the branch for dist >= curnode.radius - confirm.
321 if (curnode.right != LEAF && dist + max_dist >= curnode.radius) {
322 search_nn(curnode.right, target, max_dist, nearest);
325 if (curnode.left != LEAF && dist - max_dist <= curnode.radius) {
326 search_nn(curnode.left, target, max_dist, nearest);
// Recursive range search from the node at position curnode_index, using a fixed
// distance threshold.
//   count_only_   - compile-time switch tested with 'if constexpr' when a hit is found.
//   curnode_index - position of the node in my_nodes; the same value selects the
//                   row of my_data that is compared with 'target'.
//   target        - query coordinates; my_dim values are passed to raw_distance.
//   threshold     - distance bound; unlike search_nn it is passed by value and
//                   is not modified here.
//   all_neighbors - output object; receives (distance, index) via emplace_back
//                   in the visible branch.
// NOTE(review): the body of the count_only_ branch, the 'else' lines and the
// closing braces are not visible in this excerpt.
331 template<
bool count_only_,
typename Query_,
typename Output_>
332 void search_all(Index_ curnode_index,
const Query_* target, Float_ threshold, Output_& all_neighbors)
const {
// Normalized distance between the query and this node's row of my_data.
333 auto nptr = my_data.data() +
static_cast<size_t>(curnode_index) * my_long_ndim;
334 Float_ dist = Distance_::normalize(Distance_::template raw_distance<Float_>(nptr, target, my_dim));
337 const auto& curnode = my_nodes[curnode_index];
// Hit: the node lies within the threshold.
338 if (dist <= threshold) {
339 if constexpr(count_only_) {
// Records the distance together with the node's 'index' value.
342 all_neighbors.emplace_back(dist, curnode.index);
// Query lies inside this node's radius: the left child is tested first, then
// the right. A child is entered only if it is not LEAF and its bound test
// passes: dist - threshold <= radius for left, dist + threshold >= radius for right.
346 if (dist < curnode.radius) {
347 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
348 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
351 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
352 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
// Same two tests in the opposite order: right child first, then left.
// NOTE(review): presumably this is the branch for dist >= curnode.radius - confirm.
356 if (curnode.right != LEAF && dist + threshold >= curnode.radius) {
357 search_all<count_only_>(curnode.right, target, threshold, all_neighbors);
360 if (curnode.left != LEAF && dist - threshold <= curnode.radius) {
361 search_all<count_only_>(curnode.left, target, threshold, all_neighbors);
// Grants the searcher specialization with the same template arguments access to
// this class's non-public members.
366 friend class VptreeSearcher<Distance_, Dim_, Index_, Store_, Float_>;
// Creates a new VptreeSearcher constructed from a pointer to this object and
// returns it as a std::unique_ptr to Searcher<Index_, Float_>.
// NOTE(review): the searcher receives 'this', so it presumably must not outlive
// this object - confirm against VptreeSearcher.
372 std::unique_ptr<Searcher<Index_, Float_> >
initialize()
const {
373 return std::make_unique<VptreeSearcher<Distance_, Dim_, Index_, Store_, Float_> >(
this);