RFC: add topk and / or argpartition #629
Note: since …
Thanks for the proposal @ogrisel. It's actually surprising that coverage and performance across array libraries are so spotty. I dug up the NumPy mailing list discussion, and it seemed more or less positive, just unfinished, and the name to use is a nicely sized bikeshed. Is this function something you already have in scikit-learn internally, or are you looking for something more efficient than the …
In scikit-learn, for k-nearest neighbors (the brute-force exact method for medium- to high-dimensional spaces), we use a routine optimized for multicore CPUs using Cython + OpenMP for pairwise distances (similar to scipy's …). This code can only be called as a reduction fused into the multithreaded pairwise distance computation kernel. It is orchestrated via: …

For CPU, I doubt that any Array API based solution will be able to compete on both speed and memory usage. However, we are interested in implementing Array API support for an alternative numpy code path in order to provide GPU support, e.g. via PyTorch or CuPy. The reducer used in the numpy code path is here: … It's based on …

Note that to efficiently implement k-nearest neighbors in scikit-learn using the Array API, we would also need the Array API to provide scipy.spatial.distance.cdist. I have not opened an issue to discuss …
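For illustration, here is a minimal NumPy sketch of the brute-force k-nearest-neighbors pattern described above, assuming squared Euclidean distances. The query rows are processed in chunks so that only a chunk_size × len(Y) block of distances exists at any time, and each block is immediately reduced to its k smallest entries per row with numpy.argpartition. This is not scikit-learn's Cython/OpenMP kernel; knn_bruteforce and chunk_size are hypothetical names.

```python
import numpy as np


def knn_bruteforce(X, Y, k, chunk_size=1024):
    """Hypothetical helper: squared Euclidean distances and indices of the
    k nearest rows of Y for every row of X, sorted nearest first.

    Only a (chunk_size, len(Y)) block of distances is held in memory at a
    time, mimicking a top-k reduction fused with the distance computation.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    n, m = X.shape[0], Y.shape[0]
    if not 1 <= k <= m:
        raise ValueError("k must be in [1, len(Y)]")
    y_sq = np.einsum("ij,ij->i", Y, Y)  # squared norms of the rows of Y
    out_dist = np.empty((n, k), dtype=np.float64)
    out_idx = np.empty((n, k), dtype=np.intp)
    for start in range(0, n, chunk_size):
        Xc = X[start:start + chunk_size]
        stop = start + Xc.shape[0]
        x_sq = np.einsum("ij,ij->i", Xc, Xc)
        # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, clipped at 0 against round-off
        d = x_sq[:, None] - 2.0 * (Xc @ Y.T) + y_sq[None, :]
        np.maximum(d, 0.0, out=d)
        # O(m) average selection of the k smallest per row (unordered) ...
        part = np.argpartition(d, k - 1, axis=1)[:, :k]
        part_d = np.take_along_axis(d, part, axis=1)
        # ... then sort only those k candidates per row
        order = np.argsort(part_d, axis=1)
        out_idx[start:stop] = np.take_along_axis(part, order, axis=1)
        out_dist[start:stop] = np.take_along_axis(part_d, order, axis=1)
    return out_dist, out_idx
```

Chunking bounds peak memory at the cost of a Python-level loop; with a topk function in the Array API, the argpartition/argsort pair in the loop body could become a single call.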
JAX also has an approximate top-k implementation specifically tuned for TPUs: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.approx_max_k.html
I am not sure we want to include non-exact methods in the spec. I have the feeling that there are many ways to compute such approximations and that they will require different and evolving parametrizations with different speed-accuracy trade-offs.
A PR has now been opened which proposes adding …
NumPy provides an indirect way to compute the indices of the k smallest (or largest) values of an array: numpy.argpartition.
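As a rough sketch of how an (arg)topk can be derived from numpy.argpartition: partitioning places the k selected entries in the first (or last) k positions in O(n) average time but leaves them unordered, so only those k entries need sorting afterwards. The name topk and its axis/largest parameters are hypothetical; this is not an existing NumPy or Array API function.

```python
import numpy as np


def topk(x, k, axis=-1, largest=True):
    """Hypothetical helper: (values, indices) of the k largest (or smallest)
    entries of x along `axis`, sorted from best to worst."""
    x = np.asarray(x)
    n = x.shape[axis]
    if not 1 <= k <= n:
        raise ValueError("k must be in [1, x.shape[axis]]")
    if largest:
        # After partitioning at n - k, positions n-k..n-1 hold the k largest
        idx = np.argpartition(x, n - k, axis=axis)
        idx = np.take(idx, np.arange(n - k, n), axis=axis)
    else:
        # After partitioning at k - 1, positions 0..k-1 hold the k smallest
        idx = np.argpartition(x, k - 1, axis=axis)
        idx = np.take(idx, np.arange(k), axis=axis)
    # Sort only the k selected entries: O(k log(k)) instead of O(n log(n))
    order = np.argsort(np.take_along_axis(x, idx, axis=axis), axis=axis)
    if largest:
        order = np.flip(order, axis=axis)  # descending for the largest values
    idx = np.take_along_axis(idx, order, axis=axis)
    return np.take_along_axis(x, idx, axis=axis), idx
```

For example, topk(np.array([3, 1, 4, 1, 5, 9, 2, 6]), 3) returns the values [9, 6, 5] with indices [5, 7, 4]. When entries tie, which of the tied indices are returned is unspecified.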
There is also a proposal to provide a higher-level API, namely (arg)topk, in NumPy: …
This PR relies on numpy.argpartition internally, but it could probably later be optimized to avoid allocating a result array the size of the input array when k is small.

Here is a quick review of some available implementations in related libraries:
- … torch.argpartition …
- … cupy.argsort, which makes it very inefficient for large arrays and small k: O(n log(n)) instead of O(n).

Motivation: (arg)topk is needed by popular baseline data-science workloads (e.g. k-nearest neighbors classification in scikit-learn) and is surprisingly non-trivial to implement efficiently. For instance, on GPUs, the fastest implementations are based on some kind of partial radix sort, while CPU implementations would use more traditional partial sorting algorithms (as implemented in std::partial_sort or std::nth_element).
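To make the O(n) versus O(n log(n)) point concrete, here is a pure-Python quickselect sketch of the selection step behind CPU partial sorting: it returns the indices of the k smallest entries, unordered, in expected linear time, whereas a full argsort costs O(n log(n)). It only illustrates the idea behind std::nth_element (real implementations partition in place and use introselect to bound the worst case) and is not a GPU radix-select; argselect is a hypothetical name.

```python
import random


def argselect(values, k, rng=None):
    """Hypothetical helper: indices of the k smallest entries of `values`,
    in no particular order (quickselect, random pivot, expected O(n))."""
    n = len(values)
    if not 0 <= k <= n:
        raise ValueError("k must be in [0, len(values)]")
    if rng is None:
        rng = random.Random(0)
    idx = list(range(n))
    # Invariant: values at idx[:lo] <= values at idx[lo:hi] <= values at
    # idx[hi:], and lo <= k <= hi, so only idx[lo:hi] still needs work.
    lo, hi = 0, n
    while hi - lo > 1:
        pivot = values[idx[rng.randrange(lo, hi)]]
        # Three-way partition around the pivot; new lists are built for
        # readability, whereas std::nth_element partitions in place.
        window = idx[lo:hi]
        less = [i for i in window if values[i] < pivot]
        equal = [i for i in window if values[i] == pivot]
        greater = [i for i in window if values[i] > pivot]
        idx[lo:hi] = less + equal + greater
        lt = lo + len(less)
        gt = lt + len(equal)
        if k < lt:
            hi = lt  # the k-th boundary lies inside the "less" block
        elif k > gt:
            lo = gt  # the k-th boundary lies inside the "greater" block
        else:
            break  # the boundary falls within the pivot-equal run: done
    return idx[:k]
```

For example, argselect([3, 1, 4, 1, 5, 9, 2, 6], 3) returns the indices 1, 3 and 6 in some order. Sorting the k selected entries afterwards adds only O(k log(k)), which is why selection followed by a small sort wins when k is much smaller than n.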