Skip to content

Commit

Permalink
Add doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 19, 2024
1 parent f4f3268 commit 656ab62
Showing 1 changed file with 33 additions and 21 deletions.
54 changes: 33 additions & 21 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,28 @@ def _infer_dims(
(None,)
>>> _infer_dims((1,), ("x",))
('x',)
>>> _infer_dims((1, 3), ("x",))
('dim_0', 'x')
>>> _infer_dims((1, 1, 3), ("x",))
('dim_1', 'dim_0', 'x')
"""
if isinstance(dims, Default):
ndim = len(shape)
return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim))

_dims = _normalize_dimensions(dims)
diff = len(shape) - len(_dims)
if diff > 0:
# TODO: Leads to ('dim_0', 'x'), should it be ('dim_1', 'x')?
return _infer_dims(shape[:diff], _default) + _dims
else:
return _normalize_dimensions(dims)
return _dims


def _normalize_axis_index(axis: int, ndim: int) -> int:
"""
Normalize axis index to positive values.
Parameters
----------
axis : int
Expand All @@ -139,31 +151,35 @@ def _normalize_axis_index(axis: int, ndim: int) -> int:
normalized_axis : int
The normalized axis index, such that `0 <= normalized_axis < ndim`
Raises
------
AxisError
If the axis index is invalid, when `-ndim <= axis < ndim` is false.
Examples
--------
>>> _normalize_axis_index(0, ndim=3)
0
>>> _normalize_axis_index(1, ndim=3)
1
>>> _normalize_axis_index(2, ndim=3)
2
>>> _normalize_axis_index(-1, ndim=3)
2
>>> _normalize_axis_index(-2, ndim=3)
1
>>> _normalize_axis_index(-3, ndim=3)
0
Errors
>>> _normalize_axis_index(3, ndim=3)
Traceback (most recent call last):
...
AxisError: axis 3 is out of bounds for array of dimension 3
>>> _normalize_axis_index(-4, ndim=3, msg_prefix='axes_arg')
...
ValueError: axis 3 is out of bounds for array of dimension 3
>>> _normalize_axis_index(-4, ndim=3)
Traceback (most recent call last):
...
AxisError: axes_arg: axis -4 is out of bounds for array of dimension 3
...
ValueError: axis -4 is out of bounds for array of dimension 3
"""

if -ndim > axis >= ndim:
if -ndim > axis or axis >= ndim:
raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}")

return axis % ndim
Expand Down Expand Up @@ -248,7 +264,7 @@ def _dims_to_axis(
--------
Convert to dims to axis values
>>> import numpy as np
>>> x = NamedArray(("x", "y", "z"), np.zeros((1, 2, 3)))
>>> _dims_to_axis(x, ("y", "x"), None)
(1, 0)
Expand All @@ -257,7 +273,7 @@ def _dims_to_axis(
>>> _dims_to_axis(x, _default, 0)
(0,)
>>> type(_dims_to_axis(x, _default, None))
NoneType
<class 'NoneType'>
Normalizes negative integers
Expand Down Expand Up @@ -381,6 +397,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]
Examples
--------
>>> import numpy as np
>>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4)))
>>> _get_broadcasted_dims(a)
(('x', 'y', 'z'), (5, 3, 4))
Expand Down Expand Up @@ -455,12 +472,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]
return out_dims, out_shape


if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1:
raise ValueError(
f"operands could not be broadcast together with {dims = } and {shapes = }"
)

out_dims += (_d[0],)
out_shape += (dim,)
if __name__ == "__main__":
import doctest

return out_dims[::-1], out_shape[::-1]
doctest.testmod()

0 comments on commit 656ab62

Please sign in to comment.