diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 072caa6a4b74f..d05be3b003f67 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -1417,3 +1417,23 @@ def construct_array_type(cls) -> type_t[BaseMaskedArray]: type """ raise NotImplementedError + + @classmethod + def from_numpy_dtype(cls, dtype: np.dtype) -> BaseMaskedDtype: + """ + Construct the MaskedDtype corresponding to the given numpy dtype. + """ + if dtype.kind == "b": + from pandas.core.arrays.boolean import BooleanDtype + + return BooleanDtype() + elif dtype.kind in ["i", "u"]: + from pandas.core.arrays.integer import INT_STR_TO_DTYPE + + return INT_STR_TO_DTYPE[dtype.name] + elif dtype.kind == "f": + from pandas.core.arrays.floating import FLOAT_STR_TO_DTYPE + + return FLOAT_STR_TO_DTYPE[dtype.name] + else: + raise NotImplementedError(dtype) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index e984279c17a65..ca6adcda4396e 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -16,7 +16,6 @@ Iterator, Sequence, final, - overload, ) import numpy as np @@ -57,7 +56,6 @@ is_timedelta64_dtype, needs_i8_conversion, ) -from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core.dtypes.missing import ( isna, maybe_fill, @@ -70,14 +68,8 @@ TimedeltaArray, ) from pandas.core.arrays.boolean import BooleanDtype -from pandas.core.arrays.floating import ( - Float64Dtype, - FloatingDtype, -) -from pandas.core.arrays.integer import ( - Int64Dtype, - IntegerDtype, -) +from pandas.core.arrays.floating import FloatingDtype +from pandas.core.arrays.integer import IntegerDtype from pandas.core.arrays.masked import ( BaseMaskedArray, BaseMaskedDtype, @@ -277,28 +269,18 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype: out_dtype = "object" return np.dtype(out_dtype) - @overload def _get_result_dtype(self, dtype: np.dtype) -> np.dtype: - ... # pragma: no cover - - @overload - def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype: - ... # pragma: no cover - - # TODO: general case implementation overridable by EAs. - def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: """ Get the desired dtype of a result based on the input dtype and how it was computed. Parameters ---------- - dtype : np.dtype or ExtensionDtype - Input dtype. + dtype : np.dtype Returns ------- - np.dtype or ExtensionDtype + np.dtype The desired dtype of the result. """ how = self.how @@ -306,12 +288,8 @@ def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: if how in ["add", "cumsum", "sum", "prod"]: if dtype == np.dtype(bool): return np.dtype(np.int64) - elif isinstance(dtype, (BooleanDtype, IntegerDtype)): - return Int64Dtype() elif how in ["mean", "median", "var"]: - if isinstance(dtype, (BooleanDtype, IntegerDtype)): - return Float64Dtype() - elif is_float_dtype(dtype) or is_complex_dtype(dtype): + if is_float_dtype(dtype) or is_complex_dtype(dtype): return dtype elif is_numeric_dtype(dtype): return np.dtype(np.float64) @@ -390,8 +368,18 @@ def _reconstruct_ea_result( Construct an ExtensionArray result from an ndarray result. """ - if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)): - dtype = self._get_result_dtype(values.dtype) + if isinstance(values.dtype, StringDtype): + dtype = values.dtype + cls = dtype.construct_array_type() + return cls._from_sequence(res_values, dtype=dtype) + + elif isinstance(values.dtype, BaseMaskedDtype): + new_dtype = self._get_result_dtype(values.dtype.numpy_dtype) + # error: Incompatible types in assignment (expression has type + # "BaseMaskedDtype", variable has type "StringDtype") + dtype = BaseMaskedDtype.from_numpy_dtype( # type: ignore[assignment] + new_dtype + ) cls = dtype.construct_array_type() return cls._from_sequence(res_values, dtype=dtype) @@ -433,7 +421,8 @@ def _masked_ea_wrap_cython_operation( **kwargs, ) - dtype = self._get_result_dtype(orig_values.dtype) + new_dtype = self._get_result_dtype(orig_values.dtype.numpy_dtype) + dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype) # TODO: avoid cast as res_values *should* already have the right # dtype; last attempt ran into trouble on 32bit linux build res_values = res_values.astype(dtype.type, copy=False) diff --git a/pandas/tests/groupby/aggregate/test_cython.py b/pandas/tests/groupby/aggregate/test_cython.py index d9372ba5cbb50..96009be5d12e3 100644 --- a/pandas/tests/groupby/aggregate/test_cython.py +++ b/pandas/tests/groupby/aggregate/test_cython.py @@ -5,7 +5,10 @@ import numpy as np import pytest -from pandas.core.dtypes.common import is_float_dtype +from pandas.core.dtypes.common import ( + is_float_dtype, + is_integer_dtype, +) import pandas as pd from pandas import ( @@ -369,6 +372,9 @@ def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na): # for any int/bool use Int64, for float preserve dtype if is_float_dtype(data.dtype): expected_dtype = data.dtype + elif is_integer_dtype(data.dtype): + # match the numpy dtype we'd get with the non-nullable analogue + expected_dtype = data.dtype else: expected_dtype = pd.Int64Dtype() elif action == "always_float":