Skip to content

Commit

Permalink
BUG: .transform(...) with "first" and "last" fail when axis=1 (pandas…
Browse files Browse the repository at this point in the history
  • Loading branch information
rhshadrach authored Feb 26, 2022
1 parent 7dea5ae commit e932ec9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 41 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`)
- Bug in :meth:`.DataFrameGroupBy.size` and :meth:`.DataFrameGroupBy.transform` with ``func="size"`` produced incorrect results when ``axis=1`` (:issue:`45715`)
- Bug in :meth:`.DataFrameGroupby.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)

Reshaping
^^^^^^^^^
Expand Down
9 changes: 0 additions & 9 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,6 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
result.name = self.obj.name
return result

def _can_use_transform_fast(self, func: str, result) -> bool:
return True

def filter(self, func, dropna: bool = True, *args, **kwargs):
"""
Return a copy of a Series excluding elements from groups that
Expand Down Expand Up @@ -1184,12 +1181,6 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

def _can_use_transform_fast(self, func: str, result) -> bool:
return func == "size" or (
isinstance(result, DataFrame)
and result.columns.equals(self._obj_with_exclusions.columns)
)

def _define_paths(self, func, *args, **kwargs):
if isinstance(func, str):
fast_path = lambda group: getattr(group, func)(*args, **kwargs)
Expand Down
6 changes: 1 addition & 5 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,11 +1650,7 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
with com.temp_setattr(self, "observed", True):
result = getattr(self, func)(*args, **kwargs)

if self._can_use_transform_fast(func, result):
return self._wrap_transform_fast_result(result)

# only reached for DataFrameGroupBy
return self._transform_general(func, *args, **kwargs)
return self._wrap_transform_fast_result(result)

@final
def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:
Expand Down
40 changes: 13 additions & 27 deletions pandas/tests/groupby/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
)
import pandas._testing as tm
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
from pandas.core.groupby.generic import (
DataFrameGroupBy,
SeriesGroupBy,
)
from pandas.core.groupby.generic import DataFrameGroupBy


def assert_fp_equal(a, b):
Expand Down Expand Up @@ -195,10 +192,8 @@ def test_transform_axis_1_reducer(request, reduction_func):
# GH#45715
if reduction_func in (
"corrwith",
"first",
"idxmax",
"idxmin",
"last",
"ngroup",
"nth",
):
Expand Down Expand Up @@ -418,45 +413,36 @@ def test_transform_select_columns(df):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("duplicates", [True, False])
def test_transform_exclude_nuisance(df, duplicates):
def test_transform_exclude_nuisance(df):
# case that goes through _transform_item_by_item

if duplicates:
# make sure we work with duplicate columns GH#41427
df.columns = ["A", "C", "C", "D"]
df.columns = ["A", "B", "B", "D"]

# this also tests orderings in transform between
# series/frame to make sure it's consistent
expected = {}
grouped = df.groupby("A")

gbc = grouped["C"]
warn = FutureWarning if duplicates else None
with tm.assert_produces_warning(warn, match="Dropping invalid columns"):
expected["C"] = gbc.transform(np.mean)
if duplicates:
# squeeze 1-column DataFrame down to Series
expected["C"] = expected["C"]["C"]
gbc = grouped["B"]
with tm.assert_produces_warning(FutureWarning, match="Dropping invalid columns"):
expected["B"] = gbc.transform(lambda x: np.mean(x))
# squeeze 1-column DataFrame down to Series
expected["B"] = expected["B"]["B"]

assert isinstance(gbc.obj, DataFrame)
assert isinstance(gbc, DataFrameGroupBy)
else:
assert isinstance(gbc, SeriesGroupBy)
assert isinstance(gbc.obj, Series)
assert isinstance(gbc.obj, DataFrame)
assert isinstance(gbc, DataFrameGroupBy)

expected["D"] = grouped["D"].transform(np.mean)
expected = DataFrame(expected)
with tm.assert_produces_warning(FutureWarning, match="Dropping invalid columns"):
result = df.groupby("A").transform(np.mean)
result = df.groupby("A").transform(lambda x: np.mean(x))

tm.assert_frame_equal(result, expected)


def test_transform_function_aliases(df):
with tm.assert_produces_warning(FutureWarning, match="Dropping invalid columns"):
result = df.groupby("A").transform("mean")
expected = df.groupby("A").transform(np.mean)
result = df.groupby("A").transform("mean")
expected = df.groupby("A").transform(np.mean)
tm.assert_frame_equal(result, expected)

result = df.groupby("A")["C"].transform("mean")
Expand Down

0 comments on commit e932ec9

Please sign in to comment.