Skip to content

Commit

Permalink
Add overloads to get_axis_num (pydata#8547)
Browse files Browse the repository at this point in the history
Co-authored-by: Anderson Banihirwe <[email protected]>
  • Loading branch information
Illviljan and andersy005 committed Jan 30, 2024
1 parent b0b5b2f commit 81f38f3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
8 changes: 8 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,14 @@ def __iter__(self: Any) -> Iterator[Any]:
raise TypeError("iteration over a 0-d array")
return self._iter()

@overload
def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]:
...

@overload
def get_axis_num(self, dim: Hashable) -> int:
...

def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
"""Return axis number(s) corresponding to dimension(s) in this array.
Expand Down
8 changes: 8 additions & 0 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,14 @@ def _dask_finalize(
data = array_func(results, *args, **kwargs)
return type(self)(self._dims, data, attrs=self._attrs)

@overload
def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]:
...

@overload
def get_axis_num(self, dim: Hashable) -> int:
...

def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
"""Return axis number(s) corresponding to dimension(s) in this array.
Expand Down

0 comments on commit 81f38f3

Please sign in to comment.