Skip to content

Commit

Permalink
BUG: nullable groupby result dtypes (pandas-dev#46197)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Mar 5, 2022
1 parent aec0f57 commit 9303302
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 31 deletions.
20 changes: 20 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,3 +1417,23 @@ def construct_array_type(cls) -> type_t[BaseMaskedArray]:
type
"""
raise NotImplementedError

@classmethod
def from_numpy_dtype(cls, dtype: np.dtype) -> BaseMaskedDtype:
"""
Construct the MaskedDtype corresponding to the given numpy dtype.
"""
if dtype.kind == "b":
from pandas.core.arrays.boolean import BooleanDtype

return BooleanDtype()
elif dtype.kind in ["i", "u"]:
from pandas.core.arrays.integer import INT_STR_TO_DTYPE

return INT_STR_TO_DTYPE[dtype.name]
elif dtype.kind == "f":
from pandas.core.arrays.floating import FLOAT_STR_TO_DTYPE

return FLOAT_STR_TO_DTYPE[dtype.name]
else:
raise NotImplementedError(dtype)
49 changes: 19 additions & 30 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Iterator,
Sequence,
final,
overload,
)

import numpy as np
Expand Down Expand Up @@ -57,7 +56,6 @@
is_timedelta64_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.missing import (
isna,
maybe_fill,
Expand All @@ -70,14 +68,8 @@
TimedeltaArray,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import (
Float64Dtype,
FloatingDtype,
)
from pandas.core.arrays.integer import (
Int64Dtype,
IntegerDtype,
)
from pandas.core.arrays.floating import FloatingDtype
from pandas.core.arrays.integer import IntegerDtype
from pandas.core.arrays.masked import (
BaseMaskedArray,
BaseMaskedDtype,
Expand Down Expand Up @@ -277,41 +269,27 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
out_dtype = "object"
return np.dtype(out_dtype)

@overload
def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
... # pragma: no cover

@overload
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
... # pragma: no cover

# TODO: general case implementation overridable by EAs.
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
"""
Get the desired dtype of a result based on the
input dtype and how it was computed.
Parameters
----------
dtype : np.dtype or ExtensionDtype
Input dtype.
dtype : np.dtype
Returns
-------
np.dtype or ExtensionDtype
np.dtype
The desired dtype of the result.
"""
how = self.how

if how in ["add", "cumsum", "sum", "prod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif isinstance(dtype, (BooleanDtype, IntegerDtype)):
return Int64Dtype()
elif how in ["mean", "median", "var"]:
if isinstance(dtype, (BooleanDtype, IntegerDtype)):
return Float64Dtype()
elif is_float_dtype(dtype) or is_complex_dtype(dtype):
if is_float_dtype(dtype) or is_complex_dtype(dtype):
return dtype
elif is_numeric_dtype(dtype):
return np.dtype(np.float64)
Expand Down Expand Up @@ -390,8 +368,18 @@ def _reconstruct_ea_result(
Construct an ExtensionArray result from an ndarray result.
"""

if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)):
dtype = self._get_result_dtype(values.dtype)
if isinstance(values.dtype, StringDtype):
dtype = values.dtype
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

elif isinstance(values.dtype, BaseMaskedDtype):
new_dtype = self._get_result_dtype(values.dtype.numpy_dtype)
# error: Incompatible types in assignment (expression has type
# "BaseMaskedDtype", variable has type "StringDtype")
dtype = BaseMaskedDtype.from_numpy_dtype( # type: ignore[assignment]
new_dtype
)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

Expand Down Expand Up @@ -433,7 +421,8 @@ def _masked_ea_wrap_cython_operation(
**kwargs,
)

dtype = self._get_result_dtype(orig_values.dtype)
new_dtype = self._get_result_dtype(orig_values.dtype.numpy_dtype)
dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype)
# TODO: avoid cast as res_values *should* already have the right
# dtype; last attempt ran into trouble on 32bit linux build
res_values = res_values.astype(dtype.type, copy=False)
Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/groupby/aggregate/test_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_float_dtype
from pandas.core.dtypes.common import (
is_float_dtype,
is_integer_dtype,
)

import pandas as pd
from pandas import (
Expand Down Expand Up @@ -369,6 +372,9 @@ def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na):
# for any int/bool use Int64, for float preserve dtype
if is_float_dtype(data.dtype):
expected_dtype = data.dtype
elif is_integer_dtype(data.dtype):
# match the numpy dtype we'd get with the non-nullable analogue
expected_dtype = data.dtype
else:
expected_dtype = pd.Int64Dtype()
elif action == "always_float":
Expand Down

0 comments on commit 9303302

Please sign in to comment.