Skip to content

Commit

Permalink
Update _set_functions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 27, 2024
1 parent 5496da9 commit 6499fb7
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions xarray/namedarray/_array_api/_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,20 @@ def unique_counts(x: NamedArray[Any, Any], /) -> UniqueCountsResult:
>>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int))
>>> x_unique = unique_counts(x)
>>> x_unique.values
<xarray.NamedArray (x: 3)> Size: 24B
array([0, 1, 2])
>>> x_unique.counts
<xarray.NamedArray (x: 3)> Size: 24B
array([1, 1, 2])
>>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2)))
>>> x_unique = unique_counts(x)
>>> x_unique.values
<xarray.NamedArray (('x', 'y'): 3)> Size: 24B
array([0, 1, 2])
>>> x_unique.counts
<xarray.NamedArray (('x', 'y'): 3)> Size: 24B
array([1, 1, 2])
"""
xp = _get_data_namespace(x)
values, counts = xp.unique_counts(x._data)
Expand All @@ -77,17 +85,27 @@ def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult:
>>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int))
>>> x_unique = unique_inverse(x)
>>> x_unique.values
<xarray.NamedArray (x: 3)> Size: 24B
array([0, 1, 2])
>>> x_unique.inverse_indices
<xarray.NamedArray (x: 4)> Size: 32B
array([0, 1, 2, 2])
>>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2)))
>>> x_unique = unique_inverse(x)
>>> x_unique.dims, x_unique.shape
(('x',), (3,))
>>> x_unique.values
<xarray.NamedArray (('x', 'y'): 3)> Size: 24B
array([0, 1, 2])
>>> x_unique.inverse_indices
<xarray.NamedArray (x: 2, y: 2)> Size: 32B
array([[0, 1],
[2, 2]])
"""
xp = _get_data_namespace(x)
values, inverse_indices = xp.unique_inverse(x._data)
return UniqueInverseResult(
NamedArray(_flatten_dims(_atleast1d_dims(x.dims)), values),
NamedArray(_flatten_dims(x.dims), inverse_indices),
NamedArray(x.dims, inverse_indices),
)


Expand All @@ -99,22 +117,28 @@ def unique_values(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]:
--------
>>> import numpy as np
>>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int))
>>> x_unique = unique_values(x)
>>> x_unique.dims, x_unique.shape
(('x',), (3,))
>>> unique_values(x)
<xarray.NamedArray (x: 3)> Size: 24B
array([0, 1, 2])
>>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2)))
>>> x_unique = unique_values(x)
>>> x_unique.dims, x_unique.shape
(('x',), (3,))
>>> unique_values(x)
<xarray.NamedArray (('x', 'y'): 3)> Size: 24B
array([0, 1, 2])
# Scalars becomes 1-dimensional
>>> x = NamedArray((), np.array(0, dtype=int))
x_unique = unique_values(x)
>>> x_unique.dims, x_unique.shape
(('dim_0',), (1,))
>>> unique_values(x)
<xarray.NamedArray (dim_0: 1)> Size: 8B
array([0])
"""
xp = _get_data_namespace(x)
_data = xp.unique_values(x._data)
_dims = _flatten_dims(_atleast1d_dims(x.dims))
return x._new(_dims, _data)


if __name__ == "__main__":
import doctest

doctest.testmod()

0 comments on commit 6499fb7

Please sign in to comment.