Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 7, 2024
1 parent 3bc51bd commit 58d01b2
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 9 deletions.
10 changes: 3 additions & 7 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,13 +835,9 @@ def chunked_nanlast(darray, axis):

def shuffle_array(array, indices: list[list[int]], axis: int):
# TODO: do chunk manager dance here.
if is_duck_dask_array(array):
if not module_available("dask", minversion="2024.08.0"):
raise ValueError(
"This method is very inefficient on dask<2024.08.0. Please upgrade."
)
# TODO: handle dimensions
return array.shuffle(indexer=indices, axis=axis)
if is_chunked_array(array):
chunkmanager = get_chunked_array_type(array)
return chunkmanager.shuffle(array, indexer=indices, axis=axis)
else:
indexer = np.concatenate(indices)
# TODO: Do the array API thing here.
Expand Down
11 changes: 10 additions & 1 deletion xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,11 @@ def shuffle(self) -> None:
if all(isinstance(idx, slice) for idx in self._group_indices):
return

if TYPE_CHECKING:
for idx in self._group_indices:
assert not isinstance(idx, slice)
indices: tuple[list[int]] = self._group_indices # type: ignore[assignment]

was_array = isinstance(self._obj, DataArray)
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj

Expand All @@ -547,20 +552,24 @@ def shuffle(self) -> None:
shuffled[name] = var
continue
shuffled_data = shuffle_array(
var._data, list(self._group_indices), axis=var.get_axis_num(dim)
var._data, list(indices), axis=var.get_axis_num(dim)
)
shuffled[name] = var._replace(data=shuffled_data)

# Replace self._group_indices with slices
slices = []
start = 0
for idxr in self._group_indices:
if TYPE_CHECKING:
assert not isinstance(idxr, slice)
slices.append(slice(start, start + len(idxr)))
start += len(idxr)
# TODO: we have now broken the invariant
# self._group_indices ≠ self.groupers[0].group_indices
self._group_indices = tuple(slices)
if was_array:
if TYPE_CHECKING:
assert isinstance(self._obj, DataArray)
self._obj = self._obj._from_temp_dataset(shuffled)
else:
self._obj = shuffled
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def copy(
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
GroupIndex = Union[slice, list[int]]
GroupIndices = tuple[GroupIndex, ...]
Bins = Union[
int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
Expand Down
9 changes: 9 additions & 0 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,12 @@ def store(
targets=targets,
**kwargs,
)

def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray:
import dask.array

if not module_available("dask", minversion="2024.08.0"):
raise ValueError(
"This method is very inefficient on dask<2024.08.0. Please upgrade."
)
return dask.array.shuffle(x, indexer, axis)
5 changes: 5 additions & 0 deletions xarray/namedarray/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ def compute(
"""
raise NotImplementedError()

def shuffle(
self, x: T_ChunkedArray, indexer: list[list[int]], axis: int
) -> T_ChunkedArray:
raise NotImplementedError()

@property
def array_api(self) -> Any:
"""
Expand Down

0 comments on commit 58d01b2

Please sign in to comment.