Skip to content

Commit

Permalink
REF: implement _reconstruct_ea_result, _ea_to_cython_values (pandas-d…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Feb 28, 2022
1 parent 0cceffd commit 367f8a1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
8 changes: 8 additions & 0 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,10 @@ cdef group_min_max(iu_64_floating_t[:, ::1] out,
else:
if uses_mask:
result_mask[i, j] = True
# set out[i, j] to 0 to be deterministic, as
# it was initialized with np.empty. Also ensures
# we can downcast out if appropriate.
out[i, j] = 0
else:
out[i, j] = nan_val
else:
Expand Down Expand Up @@ -1494,6 +1498,10 @@ cdef group_cummin_max(iu_64_floating_t[:, ::1] out,
if not skipna and na_possible and seen_na[lab, j]:
if uses_mask:
mask[i, j] = 1 # FIXME: shouldn't alter inplace
# Set to 0 ensures that we are deterministic and can
# downcast if appropriate
out[i, j] = 0

else:
out[i, j] = na_val
else:
Expand Down
47 changes: 28 additions & 19 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,26 @@ def _ea_wrap_cython_operation(
**kwargs,
)

npvalues = self._ea_to_cython_values(values)

res_values = self._cython_op_ndim_compat(
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)

if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods don't go through cython_operation
return res_values

return self._reconstruct_ea_result(values, res_values)

def _ea_to_cython_values(self, values: ExtensionArray):
# GH#43682
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
Expand All @@ -356,22 +376,7 @@ def _ea_wrap_cython_operation(
raise NotImplementedError(
f"function is not implemented for this dtype: {values.dtype}"
)

res_values = self._cython_op_ndim_compat(
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)

if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods don't go through cython_operation
return res_values

return self._reconstruct_ea_result(values, res_values)
return npvalues

def _reconstruct_ea_result(self, values, res_values):
"""
Expand All @@ -387,6 +392,7 @@ def _reconstruct_ea_result(self, values, res_values):
return cls._from_sequence(res_values, dtype=dtype)

elif needs_i8_conversion(values.dtype):
assert res_values.dtype.kind != "f" # just to be on the safe side
i8values = res_values.view("i8")
return type(values)(i8values, dtype=values.dtype)

Expand Down Expand Up @@ -577,9 +583,12 @@ def _call_cython_op(
cutoff = max(1, min_count)
empty_groups = counts < cutoff
if empty_groups.any():
# Note: this conversion could be lossy, see GH#40767
result = result.astype("float64")
result[empty_groups] = np.nan
if result_mask is not None and self.uses_mask():
assert result_mask[empty_groups].all()
else:
# Note: this conversion could be lossy, see GH#40767
result = result.astype("float64")
result[empty_groups] = np.nan

result = result.T

Expand Down

0 comments on commit 367f8a1

Please sign in to comment.