Skip to content

Commit

Permalink
TYP: groupby, sorting (pandas-dev#46133)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Feb 27, 2022
1 parent ffeb205 commit 661d88e
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 18 deletions.
6 changes: 3 additions & 3 deletions pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ class Grouping:
* groups : dict of {group -> label_list}
"""

_codes: np.ndarray | None = None
_codes: npt.NDArray[np.signedinteger] | None = None
_group_index: Index | None = None
_passed_categorical: bool
_all_grouper: Categorical | None
Expand Down Expand Up @@ -614,7 +614,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
return values._reverse_indexer()

@property
def codes(self) -> np.ndarray:
def codes(self) -> npt.NDArray[np.signedinteger]:
if self._codes is not None:
# _codes is set in __init__ for MultiIndex cases
return self._codes
Expand Down Expand Up @@ -657,7 +657,7 @@ def group_index(self) -> Index:
return Index._with_infer(uniques, name=self.name)

@cache_readonly
def _codes_and_uniques(self) -> tuple[np.ndarray, ArrayLike]:
def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
if self._passed_categorical:
# we make a CategoricalIndex out of the cat grouper
# preserving the categories / ordered attributes
Expand Down
9 changes: 6 additions & 3 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:

@final
@property
def codes(self) -> list[np.ndarray]:
def codes(self) -> list[npt.NDArray[np.signedinteger]]:
return [ping.codes for ping in self.groupings]

@property
Expand Down Expand Up @@ -860,11 +860,14 @@ def codes_info(self) -> npt.NDArray[np.intp]:
return ids

@final
def _get_compressed_codes(self) -> tuple[np.ndarray, npt.NDArray[np.intp]]:
def _get_compressed_codes(
self,
) -> tuple[npt.NDArray[np.signedinteger], npt.NDArray[np.intp]]:
# The first returned ndarray may have any signed integer dtype
if len(self.groupings) > 1:
group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
return compress_group_index(group_index, sort=self._sort)
# FIXME: compress_group_index's second return value is int64, not intp

ping = self.groupings[0]
return ping.codes, np.arange(len(ping.group_index), dtype=np.intp)
Expand All @@ -875,7 +878,7 @@ def ngroups(self) -> int:
return len(self.result_index)

@property
def reconstructed_codes(self) -> list[np.ndarray]:
def reconstructed_codes(self) -> list[npt.NDArray[np.intp]]:
codes = self.codes
ids, obs_ids, _ = self.group_info
return decons_obs_group_ids(ids, obs_ids, self.shape, codes, xnull=True)
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,9 @@ def _drop_level_numbers(self, levnums: list[int]):
verify_integrity=False,
)

def _get_grouper_for_level(self, mapper, *, level=None):
def _get_grouper_for_level(
self, mapper, *, level=None
) -> tuple[Index, npt.NDArray[np.signedinteger] | None, Index | None]:
"""
Get index grouper corresponding to an index level
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,9 @@ def _set_names(self, names, *, level=None, validate: bool = True):
# --------------------------------------------------------------------

@doc(Index._get_grouper_for_level)
def _get_grouper_for_level(self, mapper, *, level):
def _get_grouper_for_level(
self, mapper, *, level=None
) -> tuple[Index, npt.NDArray[np.signedinteger] | None, Index | None]:
indexer = self.codes[level]
level_index = self.levels[level]

Expand Down
1 change: 0 additions & 1 deletion pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def _make_selectors(self):

self.group_index = comp_index
self.mask = mask
self.unique_groups = obs_ids
self.compressor = comp_index.searchsorted(np.arange(ngroups))

@cache_readonly
Expand Down
24 changes: 17 additions & 7 deletions pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def is_int64_overflow_possible(shape: Shape) -> bool:
return the_prod >= lib.i8max


def decons_group_index(comp_labels, shape: Shape):
def _decons_group_index(
comp_labels: npt.NDArray[np.intp], shape: Shape
) -> list[npt.NDArray[np.intp]]:
# reconstruct labels
if is_int64_overflow_possible(shape):
# at some point group indices are factorized,
Expand All @@ -233,7 +235,7 @@ def decons_group_index(comp_labels, shape: Shape):

label_list = []
factor = 1
y = 0
y = np.array(0)
x = comp_labels
for i in reversed(range(len(shape))):
labels = (x - y) % (factor * shape[i]) // factor
Expand All @@ -245,24 +247,32 @@ def decons_group_index(comp_labels, shape: Shape):


def decons_obs_group_ids(
comp_ids: npt.NDArray[np.intp], obs_ids, shape: Shape, labels, xnull: bool
):
comp_ids: npt.NDArray[np.intp],
obs_ids: npt.NDArray[np.intp],
shape: Shape,
labels: Sequence[npt.NDArray[np.signedinteger]],
xnull: bool,
) -> list[npt.NDArray[np.intp]]:
"""
Reconstruct labels from observed group ids.
Parameters
----------
comp_ids : np.ndarray[np.intp]
obs_ids: np.ndarray[np.intp]
shape : tuple[int]
labels : Sequence[np.ndarray[np.signedinteger]]
xnull : bool
If nulls are excluded; i.e. -1 labels are passed through.
"""
if not xnull:
lift = np.fromiter(((a == -1).any() for a in labels), dtype="i8")
shape = np.asarray(shape, dtype="i8") + lift
lift = np.fromiter(((a == -1).any() for a in labels), dtype=np.intp)
arr_shape = np.asarray(shape, dtype=np.intp) + lift
shape = tuple(arr_shape)

if not is_int64_overflow_possible(shape):
# obs ids are deconstructable! take the fast route!
out = decons_group_index(obs_ids, shape)
out = _decons_group_index(obs_ids, shape)
return out if xnull or not lift.any() else [x - y for x, y in zip(out, lift)]

indexer = unique_label_indices(comp_ids)
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pandas.core.algorithms import safe_sort
import pandas.core.common as com
from pandas.core.sorting import (
decons_group_index,
_decons_group_index,
get_group_index,
is_int64_overflow_possible,
lexsort_indexer,
Expand Down Expand Up @@ -389,7 +389,7 @@ def align(df):
)
def test_decons(codes_list, shape):
group_index = get_group_index(codes_list, shape, sort=True, xnull=True)
codes_list2 = decons_group_index(group_index, shape)
codes_list2 = _decons_group_index(group_index, shape)

for a, b in zip(codes_list, codes_list2):
tm.assert_numpy_array_equal(a, b)
Expand Down

0 comments on commit 661d88e

Please sign in to comment.