Skip to content

Commit

Permalink
PERF: faster groupby diff (pandas-dev#45575)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Feb 28, 2022
1 parent 77d9237 commit e70d310
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 3 deletions.
6 changes: 4 additions & 2 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

method_blocklist = {
"object": {
"diff",
"median",
"prod",
"sem",
Expand Down Expand Up @@ -405,7 +406,7 @@ class GroupByMethods:

param_names = ["dtype", "method", "application", "ncols"]
params = [
["int", "float", "object", "datetime", "uint"],
["int", "int16", "float", "object", "datetime", "uint"],
[
"all",
"any",
Expand All @@ -417,6 +418,7 @@ class GroupByMethods:
"cumprod",
"cumsum",
"describe",
"diff",
"ffill",
"first",
"head",
Expand Down Expand Up @@ -478,7 +480,7 @@ def setup(self, dtype, method, application, ncols):
values = rng.take(taker, axis=0)
if dtype == "int":
key = np.random.randint(0, size, size=size)
elif dtype == "uint":
elif dtype in ("int16", "uint"):
key = np.random.randint(0, size, size=size, dtype=dtype)
elif dtype == "float":
key = np.concatenate(
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :meth:`.GroupBy.transform` for some user-defined DataFrame -> Series functions (:issue:`45387`)
- Performance improvement in :meth:`DataFrame.duplicated` when subset consists of only one column (:issue:`45236`)
- Performance improvement in :meth:`.GroupBy.diff` (:issue:`16706`)
- Performance improvement in :meth:`.GroupBy.transform` when broadcasting values for user-defined functions (:issue:`45708`)
- Performance improvement in :meth:`.GroupBy.transform` for user-defined functions when only a single group exists (:issue:`44977`)
- Performance improvement in :meth:`MultiIndex.get_locs` (:issue:`45681`, :issue:`46040`)
Expand Down
41 changes: 41 additions & 0 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3456,6 +3456,47 @@ def shift(self, periods=1, freq=None, axis=0, fill_value=None):
)
return res

@final
@Substitution(name="groupby")
@Appender(_common_see_also)
def diff(self, periods: int = 1, axis: int = 0) -> Series | DataFrame:
"""
First discrete difference of element.
Calculates the difference of each element compared with another
element in the group (default is element in previous row).
Parameters
----------
periods : int, default 1
Periods to shift for calculating difference, accepts negative values.
axis : axis to shift, default 0
Take difference over rows (0) or columns (1).
Returns
-------
Series or DataFrame
First differences.
"""
if axis != 0:
return self.apply(lambda x: x.diff(periods=periods, axis=axis))

obj = self._obj_with_exclusions
shifted = self.shift(periods=periods, axis=axis)

# GH45562 - to retain existing behavior and match behavior of Series.diff(),
# int8 and int16 are coerced to float32 rather than float64.
dtypes_to_f32 = ["int8", "int16"]
if obj.ndim == 1:
if obj.dtype in dtypes_to_f32:
shifted = shifted.astype("float32")
else:
to_coerce = [c for c, dtype in obj.dtypes.items() if dtype in dtypes_to_f32]
if len(to_coerce):
shifted = shifted.astype({c: "float32" for c in to_coerce})

return obj - shifted

@final
@Substitution(name="groupby")
@Appender(_common_see_also)
Expand Down
25 changes: 24 additions & 1 deletion pandas/tests/groupby/test_groupby_shift_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_group_shift_lose_timezone():
tm.assert_series_equal(result, expected)


def test_group_diff_real(any_real_numpy_dtype):
def test_group_diff_real_series(any_real_numpy_dtype):
df = DataFrame(
{"a": [1, 2, 3, 3, 2], "b": [1, 2, 3, 4, 5]},
dtype=any_real_numpy_dtype,
Expand All @@ -82,6 +82,29 @@ def test_group_diff_real(any_real_numpy_dtype):
tm.assert_series_equal(result, expected)


def test_group_diff_real_frame(any_real_numpy_dtype):
df = DataFrame(
{
"a": [1, 2, 3, 3, 2],
"b": [1, 2, 3, 4, 5],
"c": [1, 2, 3, 4, 6],
},
dtype=any_real_numpy_dtype,
)
result = df.groupby("a").diff()
exp_dtype = "float"
if any_real_numpy_dtype in ["int8", "int16", "float32"]:
exp_dtype = "float32"
expected = DataFrame(
{
"b": [np.nan, np.nan, np.nan, 1.0, 3.0],
"c": [np.nan, np.nan, np.nan, 1.0, 4.0],
},
dtype=exp_dtype,
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"data",
[
Expand Down

0 comments on commit e70d310

Please sign in to comment.