1#ifndef KNNCOLLE_BRUTEFORCE_HPP
2#define KNNCOLLE_BRUTEFORCE_HPP
5#include "NeighborQueue.hpp"
10#include "report_all_neighbors.hpp"
25template<
typename Index_,
typename Data_,
typename Distance_,
typename DistanceMetric_>
39template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
52 internal::NeighborQueue<Index_, Distance_> my_nearest;
53 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
56 void normalize(std::vector<Distance_>* output_distances)
const {
57 if (output_distances) {
58 for (
auto& d : *output_distances) {
59 d = my_parent.my_metric->normalize(d);
65 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
66 my_nearest.reset(k + 1);
67 auto ptr = my_parent.my_data.data() +
static_cast<size_t>(i) * my_parent.my_dim;
68 my_parent.search(ptr, my_nearest);
69 my_nearest.report(output_indices, output_distances, i);
70 normalize(output_distances);
73 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
75 internal::flush_output(output_indices, output_distances, 0);
78 my_parent.search(query, my_nearest);
79 my_nearest.report(output_indices, output_distances);
80 normalize(output_distances);
88 Index_
search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
89 auto ptr = my_parent.my_data.data() +
static_cast<size_t>(i) * my_parent.my_dim;
91 if (!output_indices && !output_distances) {
94 return internal::safe_remove_self(count);
97 my_all_neighbors.clear();
99 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances, i);
100 normalize(output_distances);
101 return internal::safe_remove_self(my_all_neighbors.size());
105 Index_
search_all(
const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
106 if (!output_indices && !output_distances) {
112 my_all_neighbors.clear();
114 internal::report_all_neighbors(my_all_neighbors, output_indices, output_distances);
115 normalize(output_distances);
116 return my_all_neighbors.size();
132template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
137 std::vector<Data_> my_data;
138 std::shared_ptr<const DistanceMetric_> my_metric;
/**
 * Construct a brute-force index from already-extracted coordinates.
 *
 * @param num_dim Number of dimensions per observation.
 * @param num_obs Number of observations.
 * @param data Coordinates of all observations, stored observation-by-observation:
 * observation `i` occupies the `num_dim` values starting at offset `i * num_dim`
 * (this is the layout that the searchers index into).
 * It is expected to contain `num_dim * num_obs` values; this constructor does not check that.
 * Taken by value and moved into the index, so callers can pass an rvalue to avoid a copy.
 * @param metric Distance metric used for all subsequent searches.
 * Held by shared pointer, so the metric stays alive for as long as this index does.
 */
144 BruteforcePrebuilt(
size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
145 my_dim(num_dim), my_obs(num_obs), my_data(std::move(data)), my_metric(std::move(metric)) {}
160 void search(
const Data_* query, internal::NeighborQueue<Index_, Distance_>& nearest)
const {
161 auto copy = my_data.data();
162 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
163 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
164 auto dist_raw = my_metric->raw(my_dim, query, copy);
165 if (dist_raw <= threshold_raw) {
166 nearest.add(x, dist_raw);
167 if (nearest.is_full()) {
168 threshold_raw = nearest.limit();
174 template<
bool count_only_,
typename Output_>
175 void search_all(
const Data_* query, Distance_ threshold, Output_& all_neighbors)
const {
176 Distance_ threshold_raw = my_metric->denormalize(threshold);
177 auto copy = my_data.data();
178 for (Index_ x = 0; x < my_obs; ++x, copy += my_dim) {
179 Distance_ raw = my_metric->raw(my_dim, query, copy);
180 if (threshold_raw >= raw) {
181 if constexpr(count_only_) {
184 all_neighbors.emplace_back(raw, x);
190 friend class BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_>;
196 std::unique_ptr<Searcher<Index_, Data_, Distance_> >
initialize()
const {
197 return std::make_unique<BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
221 class Matrix_ = Matrix<Index_, Data_>,
222 class DistanceMetric_ = DistanceMetric<Data_, Distance_>
/**
 * @param metric Distance metric to store in this builder.
 * It is held by shared pointer, so the same metric object may also be shared with other builders or indices.
 * NOTE(review): presumably this is forwarded to each `BruteforcePrebuilt` created by `build_raw()`; confirm against the full source.
 */
229 BruteforceBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
237 std::shared_ptr<const DistanceMetric_> my_metric;
244 size_t ndim = data.num_dimensions();
245 size_t nobs = data.num_observations();
246 auto work = data.new_extractor();
248 std::vector<Data_> store(ndim * nobs);
249 for (
size_t o = 0; o < nobs; ++o) {
250 std::copy_n(work->next(), ndim, store.begin() + o * ndim);
Interface to build nearest-neighbor indices.
Interface for the input matrix.
Interface for prebuilt nearest-neighbor indices.
Interface for searching nearest-neighbor indices.
Perform a brute-force nearest neighbor search.
Definition Bruteforce.hpp:224
BruteforceBuilder(std::shared_ptr< const DistanceMetric_ > metric)
Definition Bruteforce.hpp:229
Prebuilt< Index_, Data_, Distance_ > * build_raw(const Matrix_ &data) const
Definition Bruteforce.hpp:243
BruteforceBuilder(const DistanceMetric_ *metric)
Definition Bruteforce.hpp:234
Index for a brute-force nearest neighbor search.
Definition Bruteforce.hpp:133
Index_ num_observations() const
Definition Bruteforce.hpp:155
std::unique_ptr< Searcher< Index_, Data_, Distance_ > > initialize() const
Definition Bruteforce.hpp:196
size_t num_dimensions() const
Definition Bruteforce.hpp:151
Brute-force nearest neighbor searcher.
Definition Bruteforce.hpp:40
void search(Index_ i, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:65
Index_ search_all(Index_ i, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:88
bool can_search_all() const
Definition Bruteforce.hpp:84
Index_ search_all(const Data_ *query, Distance_ d, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:105
void search(const Data_ *query, Index_ k, std::vector< Index_ > *output_indices, std::vector< Distance_ > *output_distances)
Definition Bruteforce.hpp:73
Interface to build nearest-neighbor search indices.
Definition Builder.hpp:28
Interface for prebuilt nearest-neighbor search indices.
Definition Prebuilt.hpp:26
Interface for searching nearest-neighbor search indices.
Definition Searcher.hpp:28
Classes for distance calculations.
Collection of KNN algorithms.
Definition Bruteforce.hpp:23