Source code for knncolle.query_neighbors

from functools import singledispatch
from typing import Sequence, Optional, Union
from dataclasses import dataclass
import numpy

from .classes import Index, GenericIndex
from . import lib_knncolle as lib
from ._utils import process_threshold, process_subset


@dataclass()
class QueryNeighborsResults:
    """Container for the output of :py:func:`~knncolle.query_neighbors.query_neighbors`.

    Both ``index`` and ``distance`` are lists holding one NumPy array per
    observation in ``query``. For a given query observation, its array in
    ``index`` holds the indices of the observations of ``X`` that lie within
    the requested threshold distance. Its array in ``distance`` holds the
    distances to those same observations. Within each query observation,
    neighbors are guaranteed to be sorted by increasing distance.

    ``index`` is None when ``get_index = False`` was requested, and
    ``distance`` is None when ``get_distance = False`` was requested.
    """

    # Per-query arrays of neighbor indices, or None if indices were not requested.
    index: Optional[list]

    # Per-query arrays of neighbor distances, or None if distances were not requested.
    distance: Optional[list]
@singledispatch
def query_neighbors(
    X: Index,
    query: numpy.ndarray,
    threshold: Union[float, Sequence],
    num_threads: int = 1,
    get_index: bool = True,
    get_distance: bool = True,
    **kwargs
) -> QueryNeighborsResults:
    """Find all observations in the search index that lie within a threshold
    distance of each observation in the query dataset.

    This is a :py:func:`~functools.singledispatch` generic function: the
    implementation is selected from the type of ``X``. This base version is
    only reached when no method has been registered for ``type(X)``.

    Args:
        X:
            A prebuilt search index.

        query:
            Matrix of coordinates for the query observations. This should be
            a double-precision row-major NumPy matrix where the rows are
            dimensions and columns are observations. The number of
            dimensions should be consistent with that in ``X``.

        threshold:
            Distance threshold at which to identify neighbors for each
            observation in ``query``. Alternatively, this may be a sequence
            of non-negative floats of length equal to the number of
            observations in ``query``, specifying the distance threshold to
            search for each query observation.

        num_threads:
            Number of threads to use for the search.

        get_index:
            Whether to report the indices of each neighbor.

        get_distance:
            Whether to report the distances to each neighbor.

        kwargs:
            Additional arguments to pass to specific methods.

    Returns:
        Results of the neighbor search.

    Raises:
        NotImplementedError: If no method is registered for the type of ``X``.
    """
    # The signature mirrors the registered methods (X, query, threshold, ...)
    # so that it documents the real calling convention; keyword arguments not
    # listed here (e.g. a legacy ``subset=``) are still absorbed by ``kwargs``.
    raise NotImplementedError("no available method for '" + str(type(X)) + "'")
@query_neighbors.register(GenericIndex)
def _query_neighbors_generic(
    X: GenericIndex,
    query: numpy.ndarray,
    threshold: Union[float, Sequence],
    num_threads: int = 1,
    get_index: bool = True,
    get_distance: bool = True,
    **kwargs
) -> QueryNeighborsResults:
    """Threshold-based query search for a :py:class:`GenericIndex`.

    Normalizes ``threshold`` with ``process_threshold`` and then forwards
    everything to the compiled ``lib.generic_query_all`` routine, using the
    index's ``ptr`` handle. The two values that routine returns are wrapped
    as the ``index`` and ``distance`` fields of a
    :py:class:`QueryNeighborsResults`. Extra keyword arguments are accepted
    for compatibility with the dispatcher but are not used here.
    """
    # Fetch the native handle before normalizing the threshold, matching the
    # original argument-evaluation order.
    index_ptr = X.ptr
    thresholds = process_threshold(threshold)

    found_index, found_distance = lib.generic_query_all(
        index_ptr,
        query,
        thresholds,
        num_threads,
        get_index,
        get_distance,
    )
    return QueryNeighborsResults(index=found_index, distance=found_distance)