Skip to content

Commit

Permalink
ENH: retain masked EA dtypes in groupby with as_index=False (pandas-d…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Jul 25, 2021
1 parent 4c9ef1b commit daec2e7
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 17 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
- :class:`DataFrameGroupBy` operations with ``as_index=False`` now correctly retain ``ExtensionDtype`` dtypes for columns being grouped on (:issue:`41373`)
- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
self._insert_inaxis_grouper_inplace(result)
result.index = Index(range(len(result)))

return result._convert(datetime=True)
return result

agg = aggregate

Expand Down Expand Up @@ -1684,6 +1684,8 @@ def _wrap_agged_manager(self, mgr: Manager2D) -> DataFrame:
if self.axis == 1:
result = result.T

# Note: we only need to pass datetime=True in order to get numeric
# values converted
return self._reindex_output(result)._convert(datetime=True)

def _iterate_column_groupbys(self, obj: FrameOrSeries):
Expand Down
14 changes: 12 additions & 2 deletions pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,20 @@ def group_arraylike(self) -> ArrayLike:
Analogous to result_index, but holding an ArrayLike to ensure
we can can retain ExtensionDtypes.
"""
if self._group_index is not None:
# _group_index is set in __init__ for MultiIndex cases
return self._group_index._values

elif self._all_grouper is not None:
# retain dtype for categories, including unobserved ones
return self.result_index._values

return self._codes_and_uniques[1]

@cache_readonly
def result_index(self) -> Index:
# TODO: what's the difference between result_index vs group_index?
# result_index retains dtype for categories, including unobserved ones,
# which group_index does not
if self._all_grouper is not None:
group_idx = self.group_index
assert isinstance(group_idx, CategoricalIndex)
Expand All @@ -635,7 +644,8 @@ def group_index(self) -> Index:
if self._group_index is not None:
# _group_index is set in __init__ for MultiIndex cases
return self._group_index
uniques = self.group_arraylike

uniques = self._codes_and_uniques[1]
return Index(uniques, name=self.name)

@cache_readonly
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ def result_arraylike(self) -> ArrayLike:
if len(self.groupings) == 1:
return self.groupings[0].group_arraylike

# result_index is MultiIndex
return self.result_index._values

@cache_readonly
Expand All @@ -903,12 +904,12 @@ def get_group_levels(self) -> list[ArrayLike]:
# Note: only called from _insert_inaxis_grouper_inplace, which
# is only called for BaseGrouper, never for BinGrouper
if len(self.groupings) == 1:
return [self.groupings[0].result_index]
return [self.groupings[0].group_arraylike]

name_list = []
for ping, codes in zip(self.groupings, self.reconstructed_codes):
codes = ensure_platform_int(codes)
levels = ping.result_index.take(codes)
levels = ping.group_arraylike.take(codes)

name_list.append(levels)

Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_grouping_grouper(self, data_for_grouping):
def test_groupby_extension_agg(self, as_index, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
result = df.groupby("B", as_index=as_index).A.mean()
_, index = pd.factorize(data_for_grouping, sort=True)
_, uniques = pd.factorize(data_for_grouping, sort=True)

index = pd.Index(index, name="B")
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
if as_index:
index = pd.Index(uniques, name="B")
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
self.assert_series_equal(result, expected)
else:
expected = expected.reset_index()
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0, 4.0]})
self.assert_frame_equal(result, expected)

def test_groupby_agg_extension(self, data_for_grouping):
Expand Down
4 changes: 0 additions & 4 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,6 @@ def test_groupby_extension_apply(self):
we'll be able to dispatch unique.
"""

@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
super().test_groupby_extension_agg(as_index, data_for_grouping)

@pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
def test_groupby_agg_extension(self, data_for_grouping):
super().test_groupby_agg_extension(data_for_grouping)
Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,14 @@ def test_grouping_grouper(self, data_for_grouping):
def test_groupby_extension_agg(self, as_index, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
result = df.groupby("B", as_index=as_index).A.mean()
_, index = pd.factorize(data_for_grouping, sort=True)
_, uniques = pd.factorize(data_for_grouping, sort=True)

index = pd.Index(index, name="B")
expected = pd.Series([3.0, 1.0], index=index, name="A")
if as_index:
index = pd.Index(uniques, name="B")
expected = pd.Series([3.0, 1.0], index=index, name="A")
self.assert_series_equal(result, expected)
else:
expected = expected.reset_index()
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0]})
self.assert_frame_equal(result, expected)

def test_groupby_agg_extension(self, data_for_grouping):
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,10 @@ def test_ops_not_as_index(reduction_func):
expected = expected.rename("size")
expected = expected.reset_index()

if reduction_func != "size":
# 32 bit compat -> groupby preserves dtype whereas reset_index casts to int64
expected["a"] = expected["a"].astype(df["a"].dtype)

g = df.groupby("a", as_index=False)

result = getattr(g, reduction_func)()
Expand Down

0 comments on commit daec2e7

Please sign in to comment.