Skip to content

Commit

Permalink
Fix bug with NaNs in by and method='blockwise'
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 14, 2024
1 parent f0ce343 commit 15eab72
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
26 changes: 17 additions & 9 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,12 +1797,13 @@ def dask_groupby_agg(
output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks
new_axes = dict(zip(new_inds, new_dims_shape))

if method == "blockwise" and len(axis) > 1:
# The final results are available but the blocks along axes
# need to be reshaped to axis=-1
# I don't know that this is possible with blockwise
# All other code paths benefit from an unmaterialized Blockwise layer
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)
if method == "blockwise":
if len(axis) > 1:
# The final results are available but the blocks along axes
# need to be reshaped to axis=-1
# I don't know that this is possible with blockwise
# All other code paths benefit from an unmaterialized Blockwise layer
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)

# Can't use map_blocks because it forces concatenate=True along drop_axes,
result = dask.array.blockwise(
Expand All @@ -1817,7 +1818,6 @@ def dask_groupby_agg(
concatenate=False,
new_axes=new_axes,
)

return (result, groups)


Expand Down Expand Up @@ -2663,10 +2663,18 @@ def groupby_reduce(
groups = (groups[0][sorted_idx],)

if factorize_early:
assert len(groups) == 1
(groups_,) = groups
# nan group labels are factorized to -1, and preserved
# now we get rid of them by reindexing
# This also handles bins with no data
result = reindex_(result, from_=groups[0], to=expected_, fill_value=fill_value).reshape(
# First, for "blockwise", we can have -1 repeated in different blocks
# This breaks the reindexing so remove those first.
if method == "blockwise" and (mask := groups_ == -1).sum(axis=-1) > 1:
result = result[..., ~mask]
groups_ = groups_[..., ~mask]

# This reindex also handles bins with no data
result = reindex_(result, from_=groups_, to=expected_, fill_value=fill_value).reshape(
result.shape[:-1] + grp_shape
)
groups = final_groups
Expand Down
14 changes: 14 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,3 +1928,17 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
expected = flox.groupby_scan(array.compute(), by, func=func)
actual = flox.groupby_scan(array, by, func=func)
assert_equal(expected, actual)


@requires_dask
def test_blockwise_nans():
array = dask.array.ones((1, 10), chunks=2)
by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
actual, actual_groups = flox.groupby_reduce(
array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
)
expected, expected_groups = flox.groupby_reduce(
array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
)
assert_equal(expected_groups, actual_groups)
assert_equal(expected, actual)

0 comments on commit 15eab72

Please sign in to comment.