Skip to content

Commit

Permalink
Update _sorting_functions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 25, 2024
1 parent 0470089 commit 5ff5f66
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions xarray/namedarray/_array_api/_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ def argsort(
axis: int = -1,
) -> NamedArray:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dim, axis)
_axis = _dims_to_axis(x, dim, axis)[0]
if not descending:
_data = xp.argsort(x._data, axis=_axis, stable=stable)
else:
# As NumPy has no native descending sort, we imitate it here. Note that
# simply flipping the results of np.argsort(x._array, ...) would not
# respect the relative order like it would in native descending sorts.
_data = xp.flip(
xp.argsort(xp.flip(x._data, axis=axis), stable=stable, axis=axis),
axis=axis,
xp.argsort(xp.flip(x._data, axis=_axis), stable=stable, axis=_axis),
axis=_axis,
)
# Rely on flip()/argsort() to validate axis
normalised_axis = axis if axis >= 0 else x.ndim + axis
normalised_axis = _axis if _axis >= 0 else x.ndim + _axis
max_i = x.shape[normalised_axis] - 1
_data = max_i - _data
return x._new(data=_data)
Expand All @@ -47,8 +47,8 @@ def sort(
axis: int = -1,
) -> NamedArray:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dim, axis)
_axis = _dims_to_axis(x, dim, axis)[0]
_data = xp.sort(x._data, axis=_axis, stable=stable)
if descending:
_data = xp.flip(_data, axis=axis)
_data = xp.flip(_data, axis=_axis)
return x._new(data=_data)

0 comments on commit 5ff5f66

Please sign in to comment.