/**
 * Identifier string for the brute-force index.
 * `BruteforcePrebuilt::save()` writes these characters (without a terminating
 * null, as the length is taken from `std::strlen`) to the `ALGORITHM` file.
 */
inline static constexpr const char* bruteforce_prebuilt_save_name = "knncolle::Bruteforce";
/**
 * Forward declaration of the prebuilt brute-force index, so that
 * `BruteforceSearcher` can hold a reference to it before the class is defined.
 *
 * @tparam Index_ Integer type for observation indices.
 * @tparam Data_ Numeric type of the stored coordinates.
 * @tparam Distance_ Numeric type of the reported distances.
 * @tparam DistanceMetric_ Class providing the distance calculations.
 */
template<typename Index_, typename Data_, typename Distance_, typename DistanceMetric_>
class BruteforcePrebuilt;
// Searcher that answers queries against a BruteforcePrebuilt index.
// It derives from Searcher<Index_, Data_, Distance_> and is marked final.
// NOTE(review): the embedded line counter skips values throughout this class
// (e.g. 44, 46-47, 49, 51-52), so access specifiers and other members such as
// my_nearest are not visible in this excerpt.
42template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
43class BruteforceSearcher final :
public Searcher<Index_, Data_, Distance_> {
// Binds this searcher to the index it will query. Only a reference is kept,
// so the referenced index has to remain alive while the searcher is in use.
45 BruteforceSearcher(
const BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& parent) : my_parent(parent) {}
// Index being searched (not owned).
48 const BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>& my_parent;
// Scratch buffer of (distance, index) pairs; the search_all() overloads clear
// it and pass it to the parent's search_all<false>() to be filled.
50 std::vector<std::pair<Distance_, Index_> > my_all_neighbors;
// Replaces every element of *output_distances with the result of the parent
// metric's normalize() on that element. Does nothing when output_distances is
// null.
// NOTE(review): the closing braces (counter values 57-59) are not visible in
// this excerpt.
53 void normalize(std::vector<Distance_>* output_distances)
const {
54 if (output_distances) {
55 for (
auto& d : *output_distances) {
56 d = my_parent.my_metric->normalize(d);
// Nearest-neighbor search for the i-th observation stored in the parent index.
// Results are written through output_indices / output_distances by
// my_nearest.report(), after which distances are passed through normalize().
//
// The queue is reset with k + 1 rather than k; presumably the extra slot is
// for observation i itself, which is also handed to report() — TODO confirm
// against the type of my_nearest, which is not declared in this excerpt.
62 void search(Index_ i, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
63 my_nearest.
reset(k + 1);
// Start of the parent's data plus an offset computed from i and my_dim
// (presumably i * my_dim, i.e. the first coordinate of observation i).
64 auto ptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(i, my_parent.my_dim);
65 my_parent.search(ptr, my_nearest);
66 my_nearest.
report(output_indices, output_distances, i);
67 normalize(output_distances);
// Nearest-neighbor search for an arbitrary query point given as a pointer to
// its coordinates.
// NOTE(review): the embedded counter jumps 70 -> 73 -> 75 -> 80, so the
// conditions guarding the clear() calls below, and whatever prepares
// my_nearest before the parent search, are not visible in this excerpt. In
// particular, confirm that output_indices is null-checked before it is
// dereferenced, as output_distances is.
70 void search(
const Data_* query, Index_ k, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
73 output_indices->clear();
75 if (output_distances) {
76 output_distances->clear();
// Scan the parent index, hand the results to the caller's vectors, and then
// pass the reported distances through normalize().
80 my_parent.search(query, my_nearest);
81 my_nearest.
report(output_indices, output_distances);
82 normalize(output_distances);
// Presumably reports whether the search_all() overloads are supported by this
// searcher; the body (and hence the returned value) is not visible in this
// excerpt — TODO confirm.
86 bool can_search_all()
const {
// Range search around the i-th observation of the parent index, using d as
// the distance threshold.
// NOTE(review): the embedded counter skips 92, 94, 96-98, 101 and 103+, so
// the declaration of `count`, the code that fills the output vectors, and the
// return statements are not visible in this excerpt.
90 Index_ search_all(Index_ i, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// Start of the parent's data plus an offset computed from i and my_dim
// (presumably i * my_dim, i.e. the first coordinate of observation i).
91 auto ptr = my_parent.my_data.data() + sanisizer::product_unsafe<std::size_t>(i, my_parent.my_dim);
// When the caller supplies neither output vector, the count-only variant of
// the parent's search_all is used, so no neighbor list is built.
93 if (!output_indices && !output_distances) {
95 my_parent.template search_all<true>(ptr, d, count);
// Otherwise the reusable buffer is emptied and refilled with the neighbors.
99 my_all_neighbors.clear();
100 my_parent.template search_all<false>(ptr, d, my_all_neighbors);
102 normalize(output_distances);
// Range search around an arbitrary query point, using d as the distance
// threshold.
// NOTE(review): the embedded counter skips 109, 111-113 and 116, so the
// declaration of `count`, the return for the count-only path, and the code
// that fills the output vectors are not visible in this excerpt.
107 Index_ search_all(
const Data_* query, Distance_ d, std::vector<Index_>* output_indices, std::vector<Distance_>* output_distances) {
// When the caller supplies neither output vector, the count-only variant of
// the parent's search_all is used, so no neighbor list is built.
108 if (!output_indices && !output_distances) {
110 my_parent.template search_all<true>(query, d, count);
// Otherwise the reusable buffer is emptied and refilled with the neighbors.
114 my_all_neighbors.clear();
115 my_parent.template search_all<false>(query, d, my_all_neighbors);
117 normalize(output_distances);
// The number of collected neighbors is the return value on this path.
118 return my_all_neighbors.size();
// Prebuilt brute-force index: it stores the coordinate data and the distance
// metric, and derives from Prebuilt<Index_, Data_, Distance_>. Marked final.
// NOTE(review): the embedded counter skips 125-127 and 130-131, so access
// specifiers and the declarations of my_dim / my_obs (used by the constructor
// and search routines) are not visible in this excerpt.
123template<
typename Index_,
typename Data_,
typename Distance_,
class DistanceMetric_>
124class BruteforcePrebuilt final :
public Prebuilt<Index_, Data_, Distance_> {
// Coordinate values for all observations in one flat vector; the search code
// addresses an observation at an offset computed from its index and my_dim.
128 std::vector<Data_> my_data;
// Shared, immutable distance metric used for raw(), normalize(),
// denormalize() and save().
129 std::shared_ptr<const DistanceMetric_> my_metric;
// Constructs the index from already-prepared pieces. `data` and `metric` are
// taken by value and moved into the members, so the caller can std::move its
// own objects in to avoid a copy.
// No check relating data.size() to num_dim and num_obs is made here.
132 BruteforcePrebuilt(std::size_t num_dim, Index_ num_obs, std::vector<Data_> data, std::shared_ptr<const DistanceMetric_> metric) :
133 my_dim(num_dim), my_obs(num_obs), my_data(std::move(data)), my_metric(std::move(metric)) {}
// Accessor returning a std::size_t; presumably the number of dimensions
// (my_dim) — the body is not visible in this excerpt, TODO confirm.
136 std::size_t num_dimensions()
const {
// Accessor returning an Index_; presumably the number of observations
// (my_obs) — the body is not visible in this excerpt, TODO confirm.
140 Index_ num_observations()
const {
// Body of the exhaustive k-nearest-neighbor scan.
// NOTE(review): the signature (counter values 143-145) is not visible in this
// excerpt; `query` and `nearest` are presumably its parameters. Line 151 and
// the closing braces are elided too, so the condition (if any) guarding the
// threshold update is unknown.
//
// The threshold starts at infinity so that the comparison below does not
// reject candidates before `nearest` supplies a tighter limit.
146 Distance_ threshold_raw = std::numeric_limits<Distance_>::infinity();
147 for (Index_ x = 0; x < my_obs; ++x) {
// Raw (un-normalized) distance between the query and observation x.
148 auto dist_raw = my_metric->raw(my_dim, query, my_data.data() + sanisizer::product_unsafe<std::size_t>(x, my_dim));
// Inclusive comparison: a candidate exactly at the threshold is still added.
149 if (dist_raw <= threshold_raw) {
150 nearest.
add(x, dist_raw);
// Tighten the threshold to whatever limit the queue now reports.
152 threshold_raw = nearest.
limit();
// Exhaustive range search: visits every observation and handles those whose
// raw distance to `query` is no greater than the threshold.
//
// count_only_ selects the behavior at compile time via `if constexpr`.
// NOTE(review): lines 165-166 and the closing braces are elided, so what the
// count-only branch does with `all_neighbors` is not visible here; presumably
// the emplace_back below belongs to the non-count-only branch — TODO confirm.
158 template<
bool count_only_,
typename Output_>
159 void search_all(
const Data_* query, Distance_ threshold, Output_& all_neighbors)
const {
// Convert the caller's threshold with the metric's denormalize() so that it
// can be compared against the values returned by raw().
160 Distance_ threshold_raw = my_metric->denormalize(threshold);
161 for (Index_ x = 0; x < my_obs; ++x) {
162 Distance_ raw = my_metric->raw(my_dim, query, my_data.data() + sanisizer::product_unsafe<std::size_t>(x, my_dim));
// Inclusive comparison: observations exactly at the threshold are kept.
163 if (threshold_raw >= raw) {
164 if constexpr(count_only_) {
// Record the hit as a (raw distance, observation index) pair.
167 all_neighbors.emplace_back(raw, x);
// Grants the matching BruteforceSearcher access to this class's non-public
// members; the searcher reads my_data, my_dim and my_metric and calls
// search() / search_all() on its parent.
173 friend class BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_>;
// Creates a searcher for this index and returns it as a pointer to the
// Searcher base type. Delegates to initialize_known(), whose result is
// converted to std::unique_ptr<Searcher<...>> on return.
176 std::unique_ptr<Searcher<Index_, Data_, Distance_> > initialize()
const {
177 return initialize_known();
// Creates a searcher for this index and returns it with its concrete type,
// i.e. std::unique_ptr<BruteforceSearcher<...>>. The searcher is constructed
// from *this, so it refers back to this index.
180 auto initialize_known()
const {
181 return std::make_unique<BruteforceSearcher<Index_, Data_, Distance_, DistanceMetric_> >(*this);
// Writes this index into the directory `dir`.
// NOTE(review): the embedded counter skips 188-190, so any additional files
// written between DATA and DISTANCE (e.g. the dimensions) are not visible in
// this excerpt.
185 void save(
const std::filesystem::path& dir)
const {
// Algorithm identifier; the length comes from strlen, so no terminating null
// is included in the count.
186 quick_save(dir /
"ALGORITHM", bruteforce_prebuilt_save_name, std::strlen(bruteforce_prebuilt_save_name));
// All stored coordinate values, my_data.size() elements.
187 quick_save(dir /
"DATA", my_data.data(), my_data.size());
// The metric saves itself into its own subdirectory, which is created first.
191 const auto distdir = dir /
"DISTANCE";
192 std::filesystem::create_directory(distdir);
193 my_metric->save(distdir);
// Constructs the index by loading it from the directory `dir`.
// Throws std::runtime_error when the loaded metric cannot be cast to
// DistanceMetric_.
// NOTE(review): the embedded counter skips 197-199, 202, 205 and 207, so the
// code that sets my_obs / my_dim before they are used below, and the condition
// guarding the throw, are not visible in this excerpt.
196 BruteforcePrebuilt(
const std::filesystem::path& dir) {
// Size the buffer from my_obs and my_dim (via sanisizer::product), then read
// that many values from the DATA file.
200 my_data.resize(sanisizer::product<I<
decltype(my_data.size())> >(sanisizer::attest_gez(my_obs), my_dim));
201 quick_load(dir /
"DATA", my_data.data(), my_data.size());
// Load the metric from the DISTANCE subdirectory and downcast it to the
// concrete metric type this class was instantiated with.
// NOTE(review): dptr appears to be a raw pointer. If the dynamic_cast fails
// and the exception below is thrown, nothing visible here releases it —
// confirm the ownership contract of load_distance_metric_raw.
203 auto dptr = load_distance_metric_raw<Data_, Distance_>(dir /
"DISTANCE");
204 auto xptr =
dynamic_cast<DistanceMetric_*
>(dptr);
206 throw std::runtime_error(
"cannot cast the loaded distance metric to a DistanceMetric_");
// my_metric takes ownership of the successfully cast pointer.
208 my_metric.reset(xptr);
// Stores the distance metric that will be handed to every index this builder
// creates. The shared_ptr is taken by value and moved into the member.
// NOTE(review): the enclosing class declaration is not visible in this
// excerpt.
243 BruteforceBuilder(std::shared_ptr<const DistanceMetric_> metric) : my_metric(std::move(metric)) {}
// Metric shared with the indices built from this builder.
246 std::shared_ptr<const DistanceMetric_> my_metric;
// Body of the index-building routine: copies every observation out of `data`
// into one flat buffer and wraps it in a new BruteforcePrebuilt.
// NOTE(review): the signature is not visible in this excerpt (the block name
// is a guess); `data` is presumably its parameter. The counter also skips
// 267-268 and 272-273.
264 std::size_t ndim = data.num_dimensions();
265 const Index_ nobs = data.num_observations();
266 auto work = data.new_known_extractor();
// Allocate room for ndim values per observation; sanisizer::product is
// presumably an overflow-checked multiplication — TODO confirm.
269 std::vector<Data_> store(sanisizer::product<
typename std::vector<Data_>::size_type>(ndim, sanisizer::attest_gez(nobs)));
// Each call to work->next() supplies a pointer from which ndim values are
// copied into the slot for observation o.
270 for (Index_ o = 0; o < nobs; ++o) {
271 std::copy_n(work->next(), ndim, store.data() + sanisizer::product_unsafe<std::size_t>(o, ndim));
// The buffer is moved into the index; my_metric is copied, so the builder and
// the new index share the same metric object. The result is a raw pointer
// from `new`, so the caller is presumably responsible for its lifetime.
274 return new BruteforcePrebuilt<Index_, Data_, Distance_, DistanceMetric_>(ndim, nobs, std::move(store), my_metric);