knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
NeighborQueue.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_NEIGHBOR_QUEUE_HPP
2#define KNNCOLLE_NEIGHBOR_QUEUE_HPP
3
4#include "utils.hpp"
5
6#include <queue>
7#include <vector>
8#include <algorithm>
9#include <cassert>
10
11#include "sanisizer/sanisizer.hpp"
12
18namespace knncolle {
19
35template<typename Index_, typename Distance_>
37public:
/**
 * Default constructor.
 * Relies on the in-class member initializers: the queue starts out empty and not full, with a neighbor count of 1.
 * reset() sets the actual neighbor count and clears any existing state.
 */
42 NeighborQueue() = default;
43
44public:
52 void reset(Index_ k) {
53 // We don't allow k == 0 as otherwise we'd be in a position where is_full() == true but limit() can't be called.
54 // If the caller doesn't want any neighbors, they're better of just aborting the search altogether.
55 assert(k > 0);
56 my_neighbors = sanisizer::cast<I<decltype(my_neighbors)> >(k);
57 my_full = false;
58
59 // Popping any existing elements out, just in case. This shouldn't
60 // usually be necessary if report() was called as the queue should
61 // already be completely exhausted, but sometimes report() is a no-op
62 // or there might have been an intervening exception, etc.
63 while (!my_nearest.empty()) {
64 my_nearest.pop();
65 }
66 }
67
68public:
/**
 * @return The value of the internal "full" flag.
 * add() raises this flag once size() equals the neighbor count passed to reset(), and reset() lowers it again.
 */
72 bool is_full() const {
73 return my_full;
74 }
75
80 Distance_ limit() const {
81 return my_nearest.top().first;
82 }
83
/**
 * @return Number of elements currently held in the underlying priority queue.
 * The return type is whatever the priority queue's own size() returns.
 */
88 auto size() const {
89 return my_nearest.size();
90 }
91
92public:
100 void add(Index_ i, Distance_ d) {
101 my_nearest.emplace(d, i);
102 if (my_full) {
103 my_nearest.pop();
104 } else if (size() == my_neighbors) {
105 my_full = true;
106 }
107 }
108
109private:
110 template<bool has_indices_, bool has_distances_>
111 void report_internal(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, const Index_ self) {
112 // We expect that nearest is non-empty, as a search should at least find 'self' (or duplicates thereof).
113 assert(!my_nearest.empty());
114 const Index_ num_expected = size() - 1;
115
116 if constexpr(has_indices_) {
117 output_indices->clear();
118 output_indices->reserve(num_expected);
119 }
120 if constexpr(has_distances_) {
121 output_distances->clear();
122 output_distances->reserve(num_expected);
123 }
124
125 bool found_self = false;
126 while (!my_nearest.empty()) {
127 const auto& top = my_nearest.top();
128 if (!found_self && top.second == self) {
129 found_self = true;
130 } else {
131 if constexpr(has_indices_) {
132 output_indices->push_back(top.second);
133 }
134 if constexpr(has_distances_) {
135 output_distances->push_back(top.first);
136 }
137 }
138 my_nearest.pop();
139 }
140
141 // We use push_back + reverse to give us sorting in increasing order;
142 // this is nicer than push_front() for std::vectors.
143 if constexpr(has_indices_) {
144 std::reverse(output_indices->begin(), output_indices->end());
145 }
146 if constexpr(has_distances_) {
147 std::reverse(output_distances->begin(), output_distances->end());
148 }
149
150 // Removing the most distance element if we couldn't find ourselves,
151 // e.g., because there are too many duplicates.
152 if (!found_self) {
153 if constexpr(has_indices_) {
154 output_indices->pop_back();
155 }
156 if constexpr(has_distances_) {
157 output_distances->pop_back();
158 }
159 }
160 }
161
162public:
178 void report(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, Index_ self) {
179 if (output_indices && output_distances) {
180 report_internal<true, true>(output_indices, output_distances, self);
181 } else if (output_indices) {
182 report_internal<true, false>(output_indices, NULL, self);
183 } else if (output_distances) {
184 report_internal<false, true>(NULL, output_distances, self);
185 }
186 }
187
188private:
189 template<bool has_indices_, bool has_distances_>
190 void report_internal(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
191 auto position = my_nearest.size();
192
193 if constexpr(has_indices_) {
194 sanisizer::resize(*output_indices, position);
195 }
196 if constexpr(has_distances_) {
197 sanisizer::resize(*output_distances, position);
198 }
199
200 while (!my_nearest.empty()) {
201 const auto& top = my_nearest.top();
202 --position;
203 if constexpr(has_indices_) {
204 (*output_indices)[position] = top.second;
205 }
206 if constexpr(has_distances_) {
207 (*output_distances)[position] = top.first;
208 }
209 my_nearest.pop();
210 }
211 }
212
213public:
223 void report(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
224 if (output_indices && output_distances) {
225 report_internal<true, true>(output_indices, output_distances);
226 } else if (output_indices) {
227 report_internal<true, false>(output_indices, NULL);
228 } else if (output_distances) {
229 report_internal<false, true>(NULL, output_distances);
230 }
231 }
232
233private:
// Set by add() once size() equals my_neighbors; cleared by reset().
234 bool my_full = false;
// (distance, index) pairs under the default comparator, so top() is the pair with the largest distance.
235 std::priority_queue<std::pair<Distance_, Index_> > my_nearest;
// Number of neighbors to retain, as set by reset(). Its type is derived from my_nearest.size(),
// so this declaration has to come after my_nearest.
// NOTE(review): I<> is not defined in this file; presumably a type helper from utils.hpp - confirm.
236 I<decltype(my_nearest.size())> my_neighbors = 1;
237};
238
239}
240
241#endif
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:36
auto size() const
Definition NeighborQueue.hpp:88
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:178
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition NeighborQueue.hpp:223
void add(Index_ i, Distance_ d)
Definition NeighborQueue.hpp:100
Distance_ limit() const
Definition NeighborQueue.hpp:80
bool is_full() const
Definition NeighborQueue.hpp:72
void reset(Index_ k)
Definition NeighborQueue.hpp:52
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
Miscellaneous utilities for knncolle