From a14d202ed91e5bbc93035b25ffc3d334193c9590 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Mon, 30 Sep 2024 19:51:58 +0200 Subject: [PATCH] allow using `__array_function__` as a fallback for missing Array API functions (#9530) * don't pass along `out` for `nanprod` We don't do this anywhere elsewhere, so it doesn't make sense to do this only for `nanprod`. * add tests for `as_indexable` * allow using `__array_function__` as a fallback for missing array API funcs * also check dask * don't try to create a `dask` array if `dask` is not installed --- xarray/core/indexing.py | 4 +-- xarray/core/nanops.py | 2 +- xarray/tests/test_indexing.py | 68 +++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 08b1d0be290..d8727c38c48 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -878,10 +878,10 @@ def as_indexable(array): return PandasIndexingAdapter(array) if is_duck_dask_array(array): return DaskIndexingAdapter(array) - if hasattr(array, "__array_function__"): - return NdArrayLikeIndexingAdapter(array) if hasattr(array, "__array_namespace__"): return ArrayApiIndexingAdapter(array) + if hasattr(array, "__array_function__"): + return NdArrayLikeIndexingAdapter(array) raise TypeError(f"Invalid array type: {type(array)}") diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index fc7240139aa..7fbb63068c0 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -162,7 +162,7 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0): def nanprod(a, axis=None, dtype=None, out=None, min_count=None): mask = isnull(a) - result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out) + result = nputils.nanprod(a, axis=axis, dtype=dtype) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 985fb02a87e..92c21cc32fb 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -894,6 +894,74 @@ def test_posify_mask_subindexer(indices, expected) -> None: np.testing.assert_array_equal(expected, actual) +class ArrayWithNamespace: + def __array_namespace__(self, version=None): + pass + + +class ArrayWithArrayFunction: + def __array_function__(self, func, types, args, kwargs): + pass + + +class ArrayWithNamespaceAndArrayFunction: + def __array_namespace__(self, version=None): + pass + + def __array_function__(self, func, types, args, kwargs): + pass + + +def as_dask_array(arr, chunks): + try: + import dask.array as da + except ImportError: + return None + + return da.from_array(arr, chunks=chunks) + + +@pytest.mark.parametrize( + ["array", "expected_type"], + ( + pytest.param( + indexing.CopyOnWriteArray(np.array([1, 2])), + indexing.CopyOnWriteArray, + id="ExplicitlyIndexed", + ), + pytest.param( + np.array([1, 2]), indexing.NumpyIndexingAdapter, id="numpy.ndarray" + ), + pytest.param( + pd.Index([1, 2]), indexing.PandasIndexingAdapter, id="pandas.Index" + ), + pytest.param( + as_dask_array(np.array([1, 2]), chunks=(1,)), + indexing.DaskIndexingAdapter, + id="dask.array", + marks=requires_dask, + ), + pytest.param( + ArrayWithNamespace(), indexing.ArrayApiIndexingAdapter, id="array_api" + ), + pytest.param( + ArrayWithArrayFunction(), + indexing.NdArrayLikeIndexingAdapter, + id="array_like", + ), + pytest.param( + ArrayWithNamespaceAndArrayFunction(), + indexing.ArrayApiIndexingAdapter, + id="array_api_with_fallback", + ), + ), +) +def test_as_indexable(array, expected_type): + actual = indexing.as_indexable(array) + + assert isinstance(actual, expected_type) + + def test_indexing_1d_object_array() -> None: items = (np.arange(3), np.arange(6)) arr = DataArray(np.array(items, dtype=object))