Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 19, 2024
1 parent 656ab62 commit 95f8490
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
1 change: 0 additions & 1 deletion xarray/namedarray/_array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from xarray.namedarray._array_api._manipulation_functions import _arithmetic_broadcast
from xarray.namedarray._array_api._utils import (
_get_broadcasted_dims,
_get_data_namespace,
)
from xarray.namedarray._typing import (
Expand Down
12 changes: 5 additions & 7 deletions xarray/namedarray/_array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,23 @@

from xarray.namedarray._array_api._data_type_functions import result_type
from xarray.namedarray._array_api._utils import (
_dims_to_axis,
_get_broadcasted_dims,
_get_data_namespace,
_infer_dims,
_insert_dim,
_dims_to_axis,
)
from xarray.namedarray._typing import (
Default,
_arrayapi,
_Axes,
_Axis,
_AxisLike,
_default,
_Dim,
_DType,
_ShapeType,
_Dims,
_DimsLike2,
_DType,
_Shape,
_ShapeType,
)
from xarray.namedarray.core import NamedArray

Expand Down Expand Up @@ -54,7 +52,7 @@ def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any]
_dims, _ = _get_broadcasted_dims(*arrays)
_arrays = tuple(a._data for a in arrays)
_datas = xp.broadcast_arrays(*_arrays)
return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas)]
return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas, strict=False)]


def broadcast_to(
Expand Down Expand Up @@ -346,7 +344,7 @@ def _set_dims(

if shape is not None:
# Add dimensions, with same size as shape:
dims_map = dict(zip(dim, shape))
dims_map = dict(zip(dim, shape, strict=False))
expanded_dims = extra_dims + x.dims
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
return permute_dims(broadcast_to(x, tmp_shape, dims=expanded_dims), dims=dim)
Expand Down
4 changes: 3 additions & 1 deletion xarray/namedarray/_array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def nonzero(x: NamedArray[Any, Any], /) -> tuple[NamedArray[Any, Any], ...]:
xp = _get_data_namespace(x)
_datas: tuple[_arrayapi[Any, Any], ...] = xp.nonzero(x._data)
# TODO: Verify that dims and axis matches here:
return tuple(x._new((dim,), data) for dim, data in zip(x.dims, _datas))
return tuple(
x._new((dim,), data) for dim, data in zip(x.dims, _datas, strict=False)
)


def searchsorted(
Expand Down
3 changes: 2 additions & 1 deletion xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,9 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]
for dims, shape in zip(
zip_longest(*map(reversed, arrays_dims), fillvalue=_default),
zip_longest(*map(reversed, arrays_shapes), fillvalue=-1),
strict=False,
):
for d, s in zip(reversed(dims), reversed(shape)):
for d, s in zip(reversed(dims), reversed(shape), strict=False):
if isinstance(d, Default):
continue

Expand Down

0 comments on commit 95f8490

Please sign in to comment.