Skip to content

Commit

Permalink
more places to simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 25, 2024
1 parent c59fb54 commit 889feaa
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 86 deletions.
16 changes: 7 additions & 9 deletions xarray/namedarray/_array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from xarray.namedarray._array_api._utils import (
_dim_to_optional_axis,
_get_data_namespace,
_get_remaining_dims,
_infer_dims,
_reduce_dims,
)
from xarray.namedarray._typing import (
Default,
Expand All @@ -32,10 +32,9 @@ def argmax(
) -> NamedArray[Any, Any]:
xp = _get_data_namespace(x)
_axis = _dim_to_optional_axis(x, dim, axis)
_data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
_dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
return x._new(dims=_dims, data=data_)
_data = xp.argmax(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def argmin(
Expand All @@ -48,10 +47,9 @@ def argmin(
) -> NamedArray[Any, Any]:
xp = _get_data_namespace(x)
_axis = _dim_to_optional_axis(x, dim, axis)
_data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later
# TODO: Why do we need to do the keepdims ourselves?
_dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims)
return x._new(dims=_dims, data=data_)
_data = xp.argmin(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def nonzero(x: NamedArray[Any, Any], /) -> tuple[NamedArray[Any, Any], ...]:
Expand Down
20 changes: 9 additions & 11 deletions xarray/namedarray/_array_api/_utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from xarray.namedarray._array_api._utils import (
_dims_to_axis,
_get_data_namespace,
_get_remaining_dims,
_reduce_dims,
)
from xarray.namedarray._typing import (
Default,
Expand All @@ -25,11 +25,10 @@ def all(
axis: _AxisLike | None = None,
) -> NamedArray[Any, Any]:
xp = _get_data_namespace(x)
axis_ = _dims_to_axis(x, dims, axis)
d = xp.all(x._data, axis=axis_, keepdims=False)
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_axis = _dims_to_axis(x, dims, axis)
_data = xp.all(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)


def any(
Expand All @@ -41,8 +40,7 @@ def any(
axis: _AxisLike | None = None,
) -> NamedArray[Any, Any]:
xp = _get_data_namespace(x)
axis_ = _dims_to_axis(x, dims, axis)
d = xp.any(x._data, axis=axis_, keepdims=False)
dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims)
out = x._new(dims=dims_, data=data_)
return out
_axis = _dims_to_axis(x, dims, axis)
_data = xp.any(x._data, axis=_axis, keepdims=keepdims)
_dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims)
return x._new(dims=_dims, data=_data)
100 changes: 34 additions & 66 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,72 +337,6 @@ def _dim_to_axis(x: NamedArray[Any, Any], dim: _Dim | Default, axis: int) -> int
return _axis


def _get_remaining_dims(
x: NamedArray[Any, _DType],
data: duckarray[Any, _DType],
axis: _AxisLike | None,
*,
keepdims: bool,
) -> tuple[_Dims, duckarray[Any, _DType]]:
"""
Get the reamining dims after a reduce operation.
"""
if data.shape == x.shape:
return x.dims, data

removed_axes: tuple[int, ...]
if axis is None:
removed_axes = tuple(v for v in range(x.ndim))
else:
removed_axes = _normalize_axis_tuple(axis, x.ndim)

if keepdims:
# Insert None (aka newaxis) for removed dims
slices = tuple(
None if i in removed_axes else slice(None, None) for i in range(x.ndim)
)
data = data[slices]
dims = x.dims
else:
dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes)

return dims, data


def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims:
"""
Reduce dims according to axis.
Examples
--------
>>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False)
()
>>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False)
('x', 'z')
>>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False)
('x', 'y')
keepdims retains the same dims
>>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True)
('x', 'y', 'z')
"""
if keepdims:
return dims

ndim = len(dims)
if axis is None:
_axis = tuple(v for v in range(ndim))
else:
_axis = _normalize_axis_tuple(axis, ndim)

key = [slice(None)] * ndim
for i, v in enumerate(_axis):
key[v] = 0

return _dims_from_tuple_indexing(dims, tuple(key))


def _new_unique_dim_name(dims: _Dims, i: int | None = None) -> _Dim:
"""
Get a new unique dimension name.
Expand Down Expand Up @@ -565,6 +499,40 @@ def _atleast1d_dims(dims: _Dims) -> _Dims:
return (_new_unique_dim_name(dims),) if len(dims) < 1 else dims


def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims:
"""
Reduce dims according to axis.
Examples
--------
>>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False)
()
>>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False)
('x', 'z')
>>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False)
('x', 'y')
keepdims retains the same dims
>>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True)
('x', 'y', 'z')
"""
if keepdims:
return dims

ndim = len(dims)
if axis is None:
_axis = tuple(v for v in range(ndim))
else:
_axis = _normalize_axis_tuple(axis, ndim)

key = [slice(None)] * ndim
for i, v in enumerate(_axis):
key[v] = 0

return _dims_from_tuple_indexing(dims, tuple(key))


def _raise_if_any_duplicate_dimensions(
dims: _Dims, err_context: str = "This function"
) -> None:
Expand Down

0 comments on commit 889feaa

Please sign in to comment.