Skip to content

Commit

Permalink
Add set functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 25, 2024
1 parent b799a9f commit c78031a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
15 changes: 15 additions & 0 deletions xarray/namedarray/_array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,21 @@
"where",
]

from ._set_functions import (
unique_all,
unique_counts,
unique_inverse,
unique_values,
)

__all__ += [
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
]


from xarray.namedarray._array_api._sorting_functions import argsort, sort

__all__ += ["argsort", "sort"]
Expand Down
77 changes: 77 additions & 0 deletions xarray/namedarray/_array_api/_set_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import TYPE_CHECKING, NamedTuple, Literal

from xarray.namedarray._array_api._utils import (
_dims_to_axis,
_get_data_namespace,
_get_remaining_dims,
_infer_dims,
)
from xarray.namedarray._typing import (
Default,
_default,
_Dims,
)
from xarray.namedarray.core import NamedArray


class UniqueAllResult(NamedTuple):
values: NamedArray
indices: NamedArray
inverse_indices: NamedArray
counts: NamedArray


class UniqueCountsResult(NamedTuple):
values: NamedArray
counts: NamedArray


class UniqueInverseResult(NamedTuple):
values: NamedArray
inverse_indices: NamedArray


def unique_all(x: NamedArray, /) -> UniqueAllResult:
xp = _get_data_namespace(x)
values, indices, inverse_indices, counts = xp.unique_all(x._data)
_dims_values = _infer_dims(values.shape) # TODO: Fix
_dims_indices = _infer_dims(indices.shape) # TODO: Fix dims
_dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims
_dims_counts = _infer_dims(counts.shape) # TODO: Fix dims
return UniqueAllResult(
NamedArray(_dims_values, values),
NamedArray(_dims_indices, indices),
NamedArray(_dims_inverse_indices, inverse_indices),
NamedArray(_dims_counts, counts),
)


def unique_counts(x: NamedArray, /) -> UniqueCountsResult:
xp = _get_data_namespace(x)
values, counts = xp.unique_counts(x._data)
_dims_values = _infer_dims(values.shape) # TODO: Fix dims
_dims_counts = _infer_dims(counts.shape) # TODO: Fix dims
return UniqueCountsResult(
NamedArray(_dims_values, values),
NamedArray(_dims_counts, counts),
)


def unique_inverse(x: NamedArray, /) -> UniqueInverseResult:
xp = _get_data_namespace(x)
values, inverse_indices = xp.unique_inverse(x._data)
_dims_values = _infer_dims(values.shape) # TODO: Fix
_dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims
return UniqueInverseResult(
NamedArray(_dims_values, values),
NamedArray(_dims_inverse_indices, inverse_indices),
)


def unique_values(x: NamedArray, /) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.unique_values(x._data)
_dims = _infer_dims(_data.shape) # TODO: Fix
return x._new(_dims, _data)

0 comments on commit c78031a

Please sign in to comment.