From c78031acbef4dd1329bd3555651981d78b6e6da0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:32:11 +0200 Subject: [PATCH] Add set functions --- xarray/namedarray/_array_api/__init__.py | 15 ++++ .../namedarray/_array_api/_set_functions.py | 77 +++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 xarray/namedarray/_array_api/_set_functions.py diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 99628406abd..72ca9984815 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -307,6 +307,21 @@ "where", ] +from ._set_functions import ( + unique_all, + unique_counts, + unique_inverse, + unique_values, +) + +__all__ += [ + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", +] + + from xarray.namedarray._array_api._sorting_functions import argsort, sort __all__ += ["argsort", "sort"] diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py new file mode 100644 index 00000000000..61e3747a277 --- /dev/null +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple, Literal + +from xarray.namedarray._array_api._utils import ( + _dims_to_axis, + _get_data_namespace, + _get_remaining_dims, + _infer_dims, +) +from xarray.namedarray._typing import ( + Default, + _default, + _Dims, +) +from xarray.namedarray.core import NamedArray + + +class UniqueAllResult(NamedTuple): + values: NamedArray + indices: NamedArray + inverse_indices: NamedArray + counts: NamedArray + + +class UniqueCountsResult(NamedTuple): + values: NamedArray + counts: NamedArray + + +class UniqueInverseResult(NamedTuple): + values: NamedArray + inverse_indices: NamedArray + + +def unique_all(x: NamedArray, /) -> UniqueAllResult: + xp = _get_data_namespace(x) + values, indices, inverse_indices, counts = xp.unique_all(x._data) + _dims_values = _infer_dims(values.shape) # TODO: Fix + _dims_indices = _infer_dims(indices.shape) # TODO: Fix dims + _dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims + _dims_counts = _infer_dims(counts.shape) # TODO: Fix dims + return UniqueAllResult( + NamedArray(_dims_values, values), + NamedArray(_dims_indices, indices), + NamedArray(_dims_inverse_indices, inverse_indices), + NamedArray(_dims_counts, counts), + ) + + +def unique_counts(x: NamedArray, /) -> UniqueCountsResult: + xp = _get_data_namespace(x) + values, counts = xp.unique_counts(x._data) + _dims_values = _infer_dims(values.shape) # TODO: Fix dims + _dims_counts = _infer_dims(counts.shape) # TODO: Fix dims + return UniqueCountsResult( + NamedArray(_dims_values, values), + NamedArray(_dims_counts, counts), + ) + + +def unique_inverse(x: NamedArray, /) -> UniqueInverseResult: + xp = _get_data_namespace(x) + values, inverse_indices = xp.unique_inverse(x._data) + _dims_values = _infer_dims(values.shape) # TODO: Fix + _dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims + return UniqueInverseResult( + NamedArray(_dims_values, values), + NamedArray(_dims_inverse_indices, inverse_indices), + ) + + +def unique_values(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.unique_values(x._data) + _dims = _infer_dims(_data.shape) # TODO: Fix + return x._new(_dims, _data)