knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
NeighborQueue.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_NEIGHBOR_QUEUE_HPP
2#define KNNCOLLE_NEIGHBOR_QUEUE_HPP
3
4#include "utils.hpp"
5
6#include <queue>
7#include <vector>
8#include <algorithm>
9#include <cassert>
10
11#include "sanisizer/sanisizer.hpp"
12
18namespace knncolle {
19
34template<typename Index_, typename Distance_>
36public:
/**
 * Default constructor. Until reset() is called, the queue retains a single
 * neighbor, as the neighbor count member is initialized to 1.
 */
41 NeighborQueue() = default;
42
43public:
51 void reset(Index_ k) {
52 my_neighbors = sanisizer::cast<I<decltype(my_neighbors)> >(sanisizer::attest_gez(k));
53 my_full = false;
54
55 // Popping any existing elements out, just in case. This shouldn't
56 // usually be necessary if report() was called as the queue should
57 // already be completely exhausted, but sometimes report() is a no-op
58 // or there might have been an intervening exception, etc.
59 while (!my_nearest.empty()) {
60 my_nearest.pop();
61 }
62 }
63
64public:
68 bool is_full() const {
69 return my_full;
70 }
71
76 Distance_ limit() const {
77 return my_nearest.top().first;
78 }
79
80public:
88 void add(Index_ i, Distance_ d) {
89 if (!my_full) {
90 my_nearest.emplace(d, i);
91 if (my_nearest.size() == my_neighbors) {
92 my_full = true;
93 }
94 } else {
95 my_nearest.emplace(d, i);
96 my_nearest.pop();
97 }
98 return;
99 }
100
101private:
102 template<bool has_indices_, bool has_distances_>
103 void report_internal(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, const Index_ self) {
104 // We expect that nearest is non-empty, as a search should at least
105 // find 'self' (or duplicates thereof).
106 assert(!my_nearest.empty());
107 const Index_ num_expected = my_nearest.size() - 1;
108
109 if constexpr(has_indices_) {
110 output_indices->clear();
111 output_indices->reserve(num_expected);
112 }
113 if constexpr(has_distances_) {
114 output_distances->clear();
115 output_distances->reserve(num_expected);
116 }
117
118 bool found_self = false;
119 while (!my_nearest.empty()) {
120 const auto& top = my_nearest.top();
121 if (!found_self && top.second == self) {
122 found_self = true;
123 } else {
124 if constexpr(has_indices_) {
125 output_indices->push_back(top.second);
126 }
127 if constexpr(has_distances_) {
128 output_distances->push_back(top.first);
129 }
130 }
131 my_nearest.pop();
132 }
133
134 // We use push_back + reverse to give us sorting in increasing order;
135 // this is nicer than push_front() for std::vectors.
136 if constexpr(has_indices_) {
137 std::reverse(output_indices->begin(), output_indices->end());
138 }
139 if constexpr(has_distances_) {
140 std::reverse(output_distances->begin(), output_distances->end());
141 }
142
143 // Removing the most distance element if we couldn't find ourselves,
144 // e.g., because there are too many duplicates.
145 if (!found_self) {
146 if constexpr(has_indices_) {
147 output_indices->pop_back();
148 }
149 if constexpr(has_distances_) {
150 output_distances->pop_back();
151 }
152 }
153 }
154
155public:
168 void report(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, Index_ self) {
169 if (output_indices && output_distances) {
170 report_internal<true, true>(output_indices, output_distances, self);
171 } else if (output_indices) {
172 report_internal<true, false>(output_indices, NULL, self);
173 } else if (output_distances) {
174 report_internal<false, true>(NULL, output_distances, self);
175 }
176 }
177
178private:
179 template<bool has_indices_, bool has_distances_>
180 void report_internal(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
181 auto position = my_nearest.size();
182
183 if constexpr(has_indices_) {
184 sanisizer::resize(*output_indices, position);
185 }
186 if constexpr(has_distances_) {
187 sanisizer::resize(*output_distances, position);
188 }
189
190 while (!my_nearest.empty()) {
191 const auto& top = my_nearest.top();
192 --position;
193 if constexpr(has_indices_) {
194 (*output_indices)[position] = top.second;
195 }
196 if constexpr(has_distances_) {
197 (*output_distances)[position] = top.first;
198 }
199 my_nearest.pop();
200 }
201 }
202
203public:
214 void report(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
215 if (output_indices && output_distances) {
216 report_internal<true, true>(output_indices, output_distances);
217 } else if (output_indices) {
218 report_internal<true, false>(output_indices, NULL);
219 } else if (output_distances) {
220 report_internal<false, true>(NULL, output_distances);
221 }
222 }
223
224private:
225 bool my_full = false;
226 std::priority_queue<std::pair<Distance_, Index_> > my_nearest;
227 I<decltype(my_nearest.size())> my_neighbors = 1;
228};
229
230}
231
232#endif
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:35
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:168
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition NeighborQueue.hpp:214
void add(Index_ i, Distance_ d)
Definition NeighborQueue.hpp:88
Distance_ limit() const
Definition NeighborQueue.hpp:76
bool is_full() const
Definition NeighborQueue.hpp:68
void reset(Index_ k)
Definition NeighborQueue.hpp:51
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
Miscellaneous utilities for knncolle