knncolle
Collection of KNN methods in C++
Loading...
Searching...
No Matches
NeighborQueue.hpp
Go to the documentation of this file.
1#ifndef KNNCOLLE_NEIGHBOR_QUEUE_HPP
2#define KNNCOLLE_NEIGHBOR_QUEUE_HPP
3
4#include "utils.hpp"
5
6#include <queue>
7#include <vector>
8#include <algorithm>
9#include <cassert>
10
11#include "sanisizer/sanisizer.hpp"
12
18namespace knncolle {
19
34template<typename Index_, typename Distance_>
36public:
41 NeighborQueue() = default;
42
43public:
51 void reset(Index_ k) {
52 my_neighbors = sanisizer::cast<I<decltype(my_neighbors)> >(sanisizer::attest_gez(k));
53 my_full = false;
54
55 // Popping any existing elements out, just in case. This shouldn't
56 // usually be necessary if report() was called as the queue should
57 // already be completely exhausted, but sometimes report() is a no-op
58 // or there might have been an intervening exception, etc.
59 while (!my_nearest.empty()) {
60 my_nearest.pop();
61 }
62 }
63
64public:
68 bool is_full() const {
69 return my_full;
70 }
71
76 Distance_ limit() const {
77 return my_nearest.top().first;
78 }
79
80public:
88 void add(Index_ i, Distance_ d) {
89 if (!my_full) {
90 my_nearest.emplace(d, i);
91 if (my_nearest.size() == my_neighbors) {
92 my_full = true;
93 }
94 } else {
95 my_nearest.emplace(d, i);
96 my_nearest.pop();
97 }
98 return;
99 }
100
101public:
114 void report(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances, Index_ self) {
115 // We expect that nearest is non-empty, as a search should at least
116 // find 'self' (or duplicates thereof).
117 assert(!my_nearest.empty());
118 auto num_expected = my_nearest.size() - 1;
119
120 if (output_indices) {
121 output_indices->clear();
122 output_indices->reserve(num_expected);
123 }
124 if (output_distances) {
125 output_distances->clear();
126 output_distances->reserve(num_expected);
127 }
128
129 bool found_self = false;
130 while (!my_nearest.empty()) {
131 const auto& top = my_nearest.top();
132 if (!found_self && top.second == self) {
133 found_self = true;
134 } else {
135 if (output_indices) {
136 output_indices->push_back(top.second);
137 }
138 if (output_distances) {
139 output_distances->push_back(top.first);
140 }
141 }
142 my_nearest.pop();
143 }
144
145 // We use push_back + reverse to give us sorting in increasing order;
146 // this is nicer than push_front() for std::vectors.
147 if (output_indices) {
148 std::reverse(output_indices->begin(), output_indices->end());
149 }
150 if (output_distances) {
151 std::reverse(output_distances->begin(), output_distances->end());
152 }
153
154 // Removing the most distance element if we couldn't find ourselves,
155 // e.g., because there are too many duplicates.
156 if (!found_self) {
157 if (output_indices) {
158 output_indices->pop_back();
159 }
160 if (output_distances) {
161 output_distances->pop_back();
162 }
163 }
164 }
165
166private:
167 template<bool has_indices_, bool has_distances_>
168 void report_internal(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
169 auto position = my_nearest.size();
170
171 if constexpr(has_indices_) {
172 sanisizer::resize(*output_indices, position);
173 }
174 if constexpr(has_distances_) {
175 sanisizer::resize(*output_distances, position);
176 }
177
178 while (!my_nearest.empty()) {
179 const auto& top = my_nearest.top();
180 --position;
181 if constexpr(has_indices_) {
182 (*output_indices)[position] = top.second;
183 }
184 if constexpr(has_distances_) {
185 (*output_distances)[position] = top.first;
186 }
187 my_nearest.pop();
188 }
189 }
190
191public:
202 void report(std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
203 if (output_indices && output_distances) {
204 report_internal<true, true>(output_indices, output_distances);
205 } else if (output_indices) {
206 report_internal<true, false>(output_indices, NULL);
207 } else if (output_distances) {
208 report_internal<false, true>(NULL, output_distances);
209 }
210 }
211
212private:
213 bool my_full = false;
214 std::priority_queue<std::pair<Distance_, Index_> > my_nearest;
215 decltype(my_nearest.size()) my_neighbors = 1;
216};
217
218}
219
220#endif
Helper class to track nearest neighbors.
Definition NeighborQueue.hpp:35
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances, Index_ self)
Definition NeighborQueue.hpp:114
void report(std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition NeighborQueue.hpp:202
void add(Index_ i, Distance_ d)
Definition NeighborQueue.hpp:88
Distance_ limit() const
Definition NeighborQueue.hpp:76
bool is_full() const
Definition NeighborQueue.hpp:68
void reset(Index_ k)
Definition NeighborQueue.hpp:51
Collection of KNN algorithms.
Definition Bruteforce.hpp:29
Miscellaneous utilities for knncolle