Add GroupBy.shuffle() #9320

Draft
wants to merge 40 commits into base: main
Changes from 1 commit
Commits (40)
3bc51bd
Add GroupBy.shuffle()
dcherian Aug 7, 2024
60d7619
Cleanup
dcherian Aug 7, 2024
d1429cd
Cleanup
dcherian Aug 7, 2024
31fc00e
fix
dcherian Aug 7, 2024
4583853
return groupby instance from shuffle
dcherian Aug 13, 2024
abd9dd2
Fix nD by
dcherian Aug 13, 2024
6b820aa
Merge branch 'main' into groupby-shuffle
dcherian Aug 14, 2024
0d70656
Skip if no dask
dcherian Aug 14, 2024
fafb937
fix tests
dcherian Aug 14, 2024
939db9a
Merge branch 'main' into groupby-shuffle
dcherian Aug 14, 2024
a08450e
Add `chunks` to signature
dcherian Aug 14, 2024
d0cd218
FIx self
dcherian Aug 14, 2024
4edc976
Another Self fix
dcherian Aug 14, 2024
0b42be4
Forward chunks too
dcherian Aug 14, 2024
c52734d
[revert]
dcherian Aug 14, 2024
8180625
undo flox limit
dcherian Aug 14, 2024
7897c91
[revert]
dcherian Aug 14, 2024
7773548
fix types
dcherian Aug 14, 2024
51a7723
Add DataArray.shuffle_by, Dataset.shuffle_by
dcherian Aug 15, 2024
cc95513
Add doctest
dcherian Aug 15, 2024
18f4a40
Refactor
dcherian Aug 15, 2024
f489bcf
tweak docstrings
dcherian Aug 15, 2024
ead1bb4
fix typing
dcherian Aug 15, 2024
75115d0
Fix
dcherian Aug 15, 2024
390863a
fix docstring
dcherian Aug 15, 2024
a408cb0
bump min version to dask>=2024.08.1
dcherian Aug 17, 2024
7038f37
Merge branch 'main' into groupby-shuffle
dcherian Aug 17, 2024
05a0fb4
Fix typing
dcherian Aug 17, 2024
b8e7f62
Fix types
dcherian Aug 17, 2024
6d9ed1c
Merge branch 'main' into groupby-shuffle
dcherian Aug 22, 2024
20a8cd9
Merge branch 'main' into groupby-shuffle
dcherian Aug 30, 2024
7a99c8f
remove shuffle_by for now.
dcherian Aug 30, 2024
5e2fdfb
Add tests
dcherian Aug 30, 2024
a22c7ed
Support shuffling with multiple groupers
dcherian Aug 30, 2024
2d48690
Revert "remove shuffle_by for now."
dcherian Sep 11, 2024
0679d2b
Merge branch 'main' into groupby-shuffle
dcherian Sep 12, 2024
63b3e77
Merge branch 'main' into groupby-shuffle
dcherian Sep 17, 2024
7dc5dd1
bad merge
dcherian Sep 17, 2024
bad0744
Merge branch 'main' into groupby-shuffle
dcherian Sep 18, 2024
91e4bd8
Add a test
dcherian Sep 18, 2024
15 changes: 15 additions & 0 deletions xarray/core/duck_array_ops.py
@@ -831,3 +831,18 @@ def chunked_nanfirst(darray, axis):

def chunked_nanlast(darray, axis):
    return _chunked_first_or_last(darray, axis, op=nputils.nanlast)


def shuffle_array(array, indices: list[list[int]], axis: int):
    # TODO: do chunk manager dance here.
    if is_duck_dask_array(array):
        if not module_available("dask", minversion="2024.08.0"):
            raise ValueError(
                "This method is very inefficient on dask<2024.08.0. Please upgrade."
            )
        # TODO: handle dimensions
        return array.shuffle(indexer=indices, axis=axis)
    else:
        indexer = np.concatenate(indices)
        # TODO: Do the array API thing here.
        return np.take(array, indices=indexer, axis=axis)
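
For reference, a minimal numpy-only sketch of what the non-dask branch above does; the array and group indices below are made up for illustration:

import numpy as np

# Three interleaved groups along axis 0. shuffle_array concatenates the
# per-group index lists and gathers with np.take, so each group's members
# end up contiguous along that axis.
array = np.arange(6)                  # [0 1 2 3 4 5]
indices = [[0, 3], [1, 4], [2, 5]]    # per-group indices, interleaved

indexer = np.concatenate(indices)     # [0 3 1 4 2 5]
shuffled = np.take(array, indices=indexer, axis=0)
print(shuffled)                       # [0 3 1 4 2 5]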
48 changes: 48 additions & 0 deletions xarray/core/groupby.py
@@ -517,6 +517,54 @@ def sizes(self) -> Mapping[Hashable, int]:
            self._sizes = self._obj.isel({self._group_dim: index}).sizes
        return self._sizes

    def shuffle(self) -> None:
"""
Shuffle the underlying object so that all members in a group occur sequentially.

The order of appearance is not guaranteed. This method modifies the underlying Xarray
object in place.

Use this method first if you need to map a function that requires all members of a group
be in a single chunk.
"""
        from xarray.core.dataarray import DataArray
        from xarray.core.dataset import Dataset
        from xarray.core.duck_array_ops import shuffle_array

        (grouper,) = self.groupers
        dim = self._group_dim

        # Slices mean this is already sorted. E.g. resampling ops, _DummyGroup
        if all(isinstance(idx, slice) for idx in self._group_indices):
            return

        was_array = isinstance(self._obj, DataArray)
        as_dataset = self._obj._to_temp_dataset() if was_array else self._obj

        shuffled = Dataset()
        for name, var in as_dataset._variables.items():
            if dim not in var.dims:
                shuffled[name] = var
                continue
            shuffled_data = shuffle_array(
                var._data, list(self._group_indices), axis=var.get_axis_num(dim)
            )
            shuffled[name] = var._replace(data=shuffled_data)

        # Replace self._group_indices with slices
        slices = []
        start = 0
        for idxr in self._group_indices:
            slices.append(slice(start, start + len(idxr)))
            start += len(idxr)
        # TODO: we have now broken the invariant
        # self._group_indices ≠ self.groupers[0].group_indices
        self._group_indices = tuple(slices)
        if was_array:
            self._obj = self._obj._from_temp_dataset(shuffled)
        else:
            self._obj = shuffled

    def map(
        self,
        func: Callable,
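
For orientation, a hypothetical usage sketch of the method above, assuming dask is installed; the data, chunking, and follow-up map call are made up. In this commit, shuffle() returns None and rearranges the GroupBy's underlying object in place:

import numpy as np
import xarray as xr

# Hypothetical example: group labels interleaved across dask chunks.
da = xr.DataArray(
    np.arange(12),
    dims="x",
    coords={"labels": ("x", list("abcabcabcabc"))},
).chunk({"x": 3})

gb = da.groupby("labels")
gb.shuffle()  # members of each group become contiguous (one slice per group)

# A mapped function now sees each group's members gathered together.
result = gb.map(lambda g: g - g.mean())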
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
@@ -106,6 +106,7 @@ def _importorskip(
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
has_dask_ge_2024_08_0, _ = _importorskip("dask", minversion="2024.08.0")
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
18 changes: 17 additions & 1 deletion xarray/tests/test_groupby.py
@@ -21,6 +21,7 @@
    assert_identical,
    create_test_data,
    has_cftime,
    has_dask_ge_2024_08_0,
    has_flox,
    requires_cftime,
    requires_dask,
@@ -1293,11 +1294,26 @@ def test_groupby_sum(self) -> None:
assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y"))
assert_allclose(expected_sum_axis1, grouped.sum("y"))

    @pytest.mark.parametrize(
        "shuffle",
        [
            pytest.param(
                True,
                marks=pytest.mark.skipif(
                    not has_dask_ge_2024_08_0, reason="dask too old"
                ),
            ),
            False,
        ],
    )
    @pytest.mark.parametrize("method", ["sum", "mean", "median"])
    def test_groupby_reductions(self, method: str, shuffle: bool) -> None:
        array = self.da
        grouped = array.groupby("abc")

        if shuffle:
            grouped.shuffle()

        reduction = getattr(np, method)
        expected = Dataset(
            {