Skip to content

Commit

Permalink
allow using __array_function__ as a fallback for missing array API …
Browse files Browse the repository at this point in the history
…funcs
  • Loading branch information
keewis committed Sep 22, 2024
1 parent a10e97a commit ecb47d6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,10 @@ def as_indexable(array):
return PandasIndexingAdapter(array)
if is_duck_dask_array(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,14 @@ def __array_function__(self, func, types, args, kwargs):
pass


class ArrayWithNamespaceAndArrayFunction:
def __array_namespace__(self, version=None):
pass

def __array_function__(self, func, types, args, kwargs):
pass


@pytest.mark.parametrize(
["array", "expected_type"],
(
Expand All @@ -926,6 +934,11 @@ def __array_function__(self, func, types, args, kwargs):
indexing.NdArrayLikeIndexingAdapter,
id="array_like",
),
pytest.param(
ArrayWithNamespaceAndArrayFunction(),
indexing.ArrayApiIndexingAdapter,
id="array_api_with_fallback",
),
),
)
def test_as_indexable(array, expected_type):
Expand Down

0 comments on commit ecb47d6

Please sign in to comment.