Skip to content

Commit

Permalink
REF: simplify groupby.ops (pandas-dev#46196)
Browse files: browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Mar 3, 2022
1 parent 4985b89 commit c2cf93f
Showing 1 changed file with 45 additions and 36 deletions.
81 changes: 45 additions & 36 deletions pandas/core/groupby/ops.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -123,14 +123,14 @@ def __init__(self, kind: str, how: str):
"min": "group_min",
"max": "group_max",
"mean": "group_mean",
"median": "group_median",
"median": "group_median_float64",
"var": "group_var",
"first": "group_nth",
"last": "group_last",
"ohlc": "group_ohlc",
},
"transform": {
"cumprod": "group_cumprod",
"cumprod": "group_cumprod_float64",
"cumsum": "group_cumsum",
"cummin": "group_cummin",
"cummax": "group_cummax",
Expand Down Expand Up @@ -161,52 +161,54 @@ def _get_cython_function(
if is_numeric:
return f
elif dtype == object:
if "object" not in f.__signatures__:
if how in ["median", "cumprod"]:
# no fused types -> no __signatures__
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
elif "object" not in f.__signatures__:
# raise NotImplementedError here rather than TypeError later
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
return f
else:
raise NotImplementedError(
"This should not be reached. Please report a bug at "
"github.com/pandas-dev/pandas/",
dtype,
)

def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
"""
Find the appropriate cython function, casting if necessary.
Cast numeric dtypes to float64 for functions that only support that.
Parameters
----------
values : np.ndarray
is_numeric : bool
Returns
-------
func : callable
values : np.ndarray
"""
how = self.how
kind = self.kind

if how in ["median", "cumprod"]:
# these two only have float64 implementations
if is_numeric:
values = ensure_float64(values)
else:
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{values.dtype.name}]"
)
func = getattr(libgroupby, f"group_{how}_float64")
return func, values

func = self._get_cython_function(kind, how, values.dtype, is_numeric)
# We should only get here with is_numeric, as non-numeric cases
# should raise in _get_cython_function
values = ensure_float64(values)

if values.dtype.kind in ["i", "u"]:
elif values.dtype.kind in ["i", "u"]:
if how in ["add", "var", "prod", "mean", "ohlc"]:
# result may still include NaN, so we have to cast
values = ensure_float64(values)

return func, values
return values

# TODO: general case implementation overridable by EAs.
def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
"""
Check if we can do this operation with our cython functions.
Expand Down Expand Up @@ -235,6 +237,7 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
# are not setup for dim transforming
raise NotImplementedError(f"{dtype} dtype not supported")
elif is_datetime64_any_dtype(dtype):
# TODO: same for period_dtype? no for these methods with Period
# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes
if how in ["add", "prod", "cumsum", "cumprod"]:
Expand Down Expand Up @@ -262,7 +265,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
out_shape = (ngroups,) + values.shape[1:]
return out_shape

def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
how = self.how

if how == "rank":
Expand All @@ -282,6 +285,7 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
... # pragma: no cover

# TODO: general case implementation overridable by EAs.
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
"""
Get the desired dtype of a result based on the
Expand Down Expand Up @@ -329,7 +333,6 @@ def _ea_wrap_cython_operation(
If we have an ExtensionArray, unwrap, call _cython_operation, and
re-wrap if appropriate.
"""
# TODO: general case implementation overridable by EAs.
if isinstance(values, BaseMaskedArray) and self.uses_mask():
return self._masked_ea_wrap_cython_operation(
values,
Expand Down Expand Up @@ -357,7 +360,8 @@ def _ea_wrap_cython_operation(

return self._reconstruct_ea_result(values, res_values)

def _ea_to_cython_values(self, values: ExtensionArray):
# TODO: general case implementation overridable by EAs.
def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
# GH#43682
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
# All of the functions implemented here are ordinal, so we can
Expand All @@ -378,23 +382,24 @@ def _ea_to_cython_values(self, values: ExtensionArray):
)
return npvalues

def _reconstruct_ea_result(self, values, res_values):
# TODO: general case implementation overridable by EAs.
def _reconstruct_ea_result(
self, values: ExtensionArray, res_values: np.ndarray
) -> ExtensionArray:
"""
Construct an ExtensionArray result from an ndarray result.
"""
# TODO: allow EAs to override this logic

if isinstance(
values.dtype, (BooleanDtype, IntegerDtype, FloatingDtype, StringDtype)
):
if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)):
dtype = self._get_result_dtype(values.dtype)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

elif needs_i8_conversion(values.dtype):
assert res_values.dtype.kind != "f" # just to be on the safe side
i8values = res_values.view("i8")
return type(values)(i8values, dtype=values.dtype)
# error: Too many arguments for "ExtensionArray"
return type(values)(i8values, dtype=values.dtype) # type: ignore[call-arg]

raise NotImplementedError

Expand Down Expand Up @@ -429,13 +434,16 @@ def _masked_ea_wrap_cython_operation(
)

dtype = self._get_result_dtype(orig_values.dtype)
assert isinstance(dtype, BaseMaskedDtype)
cls = dtype.construct_array_type()
# TODO: avoid cast as res_values *should* already have the right
# dtype; last attempt ran into trouble on 32bit linux build
res_values = res_values.astype(dtype.type, copy=False)

if self.kind != "aggregate":
return cls(res_values.astype(dtype.type, copy=False), mask)
out_mask = mask
else:
return cls(res_values.astype(dtype.type, copy=False), result_mask)
out_mask = result_mask

return orig_values._maybe_mask_result(res_values, out_mask)

@final
def _cython_op_ndim_compat(
Expand Down Expand Up @@ -521,8 +529,9 @@ def _call_cython_op(
result_mask = result_mask.T

out_shape = self._get_output_shape(ngroups, values)
func, values = self.get_cython_func_and_vals(values, is_numeric)
out_dtype = self.get_out_dtype(values.dtype)
func = self._get_cython_function(self.kind, self.how, values.dtype, is_numeric)
values = self._get_cython_vals(values)
out_dtype = self._get_out_dtype(values.dtype)

result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
if self.kind == "aggregate":
Expand Down

0 comments on commit c2cf93f

Please sign in to comment.