Skip to content

Commit

Permalink
allow using __array_function__ as a fallback for missing Array API …
Browse files Browse the repository at this point in the history
…functions (#9530)

* don't pass along `out` for `nanprod`

We don't do this anywhere elsewhere, so it doesn't make sense to do
this only for `nanprod`.

* add tests for `as_indexable`

* allow using `__array_function__` as a fallback for missing array API funcs

* also check dask

* don't try to create a `dask` array if `dask` is not installed
  • Loading branch information
keewis authored Sep 30, 2024
1 parent cde720f commit a14d202
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
4 changes: 2 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,10 +878,10 @@ def as_indexable(array):
return PandasIndexingAdapter(array)
if is_duck_dask_array(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0):

def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
mask = isnull(a)
result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
result = nputils.nanprod(a, axis=axis, dtype=dtype)
if min_count is not None:
return _maybe_null_out(result, axis, mask, min_count)
else:
Expand Down
68 changes: 68 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,74 @@ def test_posify_mask_subindexer(indices, expected) -> None:
np.testing.assert_array_equal(expected, actual)


class ArrayWithNamespace:
def __array_namespace__(self, version=None):
pass


class ArrayWithArrayFunction:
def __array_function__(self, func, types, args, kwargs):
pass


class ArrayWithNamespaceAndArrayFunction:
def __array_namespace__(self, version=None):
pass

def __array_function__(self, func, types, args, kwargs):
pass


def as_dask_array(arr, chunks):
try:
import dask.array as da
except ImportError:
return None

return da.from_array(arr, chunks=chunks)


@pytest.mark.parametrize(
["array", "expected_type"],
(
pytest.param(
indexing.CopyOnWriteArray(np.array([1, 2])),
indexing.CopyOnWriteArray,
id="ExplicitlyIndexed",
),
pytest.param(
np.array([1, 2]), indexing.NumpyIndexingAdapter, id="numpy.ndarray"
),
pytest.param(
pd.Index([1, 2]), indexing.PandasIndexingAdapter, id="pandas.Index"
),
pytest.param(
as_dask_array(np.array([1, 2]), chunks=(1,)),
indexing.DaskIndexingAdapter,
id="dask.array",
marks=requires_dask,
),
pytest.param(
ArrayWithNamespace(), indexing.ArrayApiIndexingAdapter, id="array_api"
),
pytest.param(
ArrayWithArrayFunction(),
indexing.NdArrayLikeIndexingAdapter,
id="array_like",
),
pytest.param(
ArrayWithNamespaceAndArrayFunction(),
indexing.ArrayApiIndexingAdapter,
id="array_api_with_fallback",
),
),
)
def test_as_indexable(array, expected_type):
actual = indexing.as_indexable(array)

assert isinstance(actual, expected_type)


def test_indexing_1d_object_array() -> None:
items = (np.arange(3), np.arange(6))
arr = DataArray(np.array(items, dtype=object))
Expand Down

0 comments on commit a14d202

Please sign in to comment.