From ecb47d6005da928a19af4225c15b52bd01d48d9d Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 22 Sep 2024 15:01:46 +0200 Subject: [PATCH] allow using `__array_function__` as a fallback for missing array API funcs --- xarray/core/indexing.py | 4 ++-- xarray/tests/test_indexing.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 67912908a2b..6985d6681a5 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -875,10 +875,10 @@ def as_indexable(array): return PandasIndexingAdapter(array) if is_duck_dask_array(array): return DaskIndexingAdapter(array) - if hasattr(array, "__array_function__"): - return NdArrayLikeIndexingAdapter(array) if hasattr(array, "__array_namespace__"): return ArrayApiIndexingAdapter(array) + if hasattr(array, "__array_function__"): + return NdArrayLikeIndexingAdapter(array) raise TypeError(f"Invalid array type: {type(array)}") diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 39ac09bf802..4bf05471d71 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -904,6 +904,14 @@ def __array_function__(self, func, types, args, kwargs): pass +class ArrayWithNamespaceAndArrayFunction: + def __array_namespace__(self, version=None): + pass + + def __array_function__(self, func, types, args, kwargs): + pass + + @pytest.mark.parametrize( ["array", "expected_type"], ( @@ -926,6 +934,11 @@ def __array_function__(self, func, types, args, kwargs): indexing.NdArrayLikeIndexingAdapter, id="array_like", ), + pytest.param( + ArrayWithNamespaceAndArrayFunction(), + indexing.ArrayApiIndexingAdapter, + id="array_api_with_fallback", + ), ), ) def test_as_indexable(array, expected_type):