Skip to content

Commit

Permalink
TYP: GroupBy (pandas-dev#43806)
Browse files: browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Sep 30, 2021
1 parent 742ab04 commit dfe958c
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 23 deletions.
3 changes: 2 additions & 1 deletion pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ from typing import (
Any,
Callable,
Generator,
Hashable,
Literal,
overload,
)
Expand Down Expand Up @@ -197,7 +198,7 @@ def indices_fast(
labels: np.ndarray, # const int64_t[:]
keys: list,
sorted_labels: list[npt.NDArray[np.int64]],
) -> dict: ...
) -> dict[Hashable, npt.NDArray[np.intp]]: ...
def generate_slices(
labels: np.ndarray, ngroups: int # const intp_t[:]
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: ...
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pandas._typing import (
ArrayLike,
NDFrameT,
npt,
)
from pandas.errors import InvalidIndexError
from pandas.util._decorators import cache_readonly
Expand Down Expand Up @@ -604,7 +605,7 @@ def ngroups(self) -> int:
return len(self.group_index)

@cache_readonly
def indices(self):
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
# we have a list of groupers
if isinstance(self.grouping_vector, ops.BaseGrouper):
return self.grouping_vector.indices
Expand Down
20 changes: 12 additions & 8 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def apply(
return result_values, mutated

@cache_readonly
def indices(self):
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
"""dict {group name -> group indices}"""
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
# This shows unused categories in indices GH#38642
Expand Down Expand Up @@ -807,7 +807,7 @@ def is_monotonic(self) -> bool:
return Index(self.group_info[0]).is_monotonic

@cache_readonly
def group_info(self):
def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
comp_ids, obs_group_ids = self._get_compressed_codes()

ngroups = len(obs_group_ids)
Expand All @@ -817,22 +817,26 @@ def group_info(self):

@final
@cache_readonly
def codes_info(self) -> np.ndarray:
def codes_info(self) -> npt.NDArray[np.intp]:
# return the codes of items in original grouped axis
ids, _, _ = self.group_info
if self.indexer is not None:
sorter = np.lexsort((ids, self.indexer))
ids = ids[sorter]
ids = ensure_platform_int(ids)
# TODO: if numpy annotates np.lexsort, this ensure_platform_int
# may become unnecessary
return ids

@final
def _get_compressed_codes(self) -> tuple[np.ndarray, np.ndarray]:
def _get_compressed_codes(self) -> tuple[np.ndarray, npt.NDArray[np.intp]]:
# The first returned ndarray may have any signed integer dtype
if len(self.groupings) > 1:
group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
return compress_group_index(group_index, sort=self._sort)

ping = self.groupings[0]
return ping.codes, np.arange(len(ping.group_index))
return ping.codes, np.arange(len(ping.group_index), dtype=np.intp)

@final
@cache_readonly
Expand Down Expand Up @@ -1017,7 +1021,7 @@ class BinGrouper(BaseGrouper):
"""

bins: np.ndarray # np.ndarray[np.int64]
bins: npt.NDArray[np.int64]
binlabels: Index
mutated: bool

Expand Down Expand Up @@ -1101,9 +1105,9 @@ def indices(self):
return indices

@cache_readonly
def group_info(self):
def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
ngroups = self.ngroups
obs_group_ids = np.arange(ngroups, dtype=np.int64)
obs_group_ids = np.arange(ngroups, dtype=np.intp)
rep = np.diff(np.r_[0, self.bins])

rep = ensure_platform_int(rep)
Expand Down
22 changes: 11 additions & 11 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __repr__(self) -> str:
def __len__(self) -> int:
return len(self.values)

def _slice(self, slicer):
def _slice(self, slicer) -> ArrayLike:
"""return a slice of my values"""

return self.values[slicer]
Expand Down Expand Up @@ -344,7 +344,7 @@ def dtype(self) -> DtypeObj:
def iget(self, i):
return self.values[i]

def set_inplace(self, locs, values):
def set_inplace(self, locs, values) -> None:
"""
Modify block values in-place with new item value.
Expand Down Expand Up @@ -563,13 +563,13 @@ def _downcast_2d(self) -> list[Block]:
return [self.make_block(new_values)]

@final
def astype(self, dtype, copy: bool = False, errors: str = "raise"):
def astype(self, dtype: DtypeObj, copy: bool = False, errors: str = "raise"):
"""
Coerce to the new dtype.
Parameters
----------
dtype : str, dtype convertible
dtype : np.dtype or ExtensionDtype
copy : bool, default False
copy if indicated
errors : str, {'raise', 'ignore'}, default 'raise'
Expand Down Expand Up @@ -1441,7 +1441,7 @@ def iget(self, col):
raise IndexError(f"{self} only contains one item")
return self.values

def set_inplace(self, locs, values):
def set_inplace(self, locs, values) -> None:
# NB: This is a misnomer, is supposed to be inplace but is not,
# see GH#33457
assert locs.tolist() == [0]
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def setitem(self, indexer, value):
# https://github.com/pandas-dev/pandas/issues/24020
# Need a dedicated setitem until GH#24020 (type promotion in setitem
# for extension arrays) is designed and implemented.
return self.astype(object).setitem(indexer, value)
return self.astype(_dtype_obj).setitem(indexer, value)

if isinstance(indexer, tuple):
# TODO(EA2D): not needed with 2D EAs
Expand Down Expand Up @@ -1547,7 +1547,7 @@ def take_nd(

return self.make_block_same_class(new_values, new_mgr_locs)

def _slice(self, slicer):
def _slice(self, slicer) -> ExtensionArray:
"""
Return a slice of my values.
Expand All @@ -1558,7 +1558,7 @@ def _slice(self, slicer):
Returns
-------
np.ndarray or ExtensionArray
ExtensionArray
"""
# return same dims as we currently have
if not isinstance(slicer, tuple) and self.ndim == 2:
Expand Down Expand Up @@ -1736,7 +1736,7 @@ def is_view(self) -> bool:
def setitem(self, indexer, value):
if not self._can_hold_element(value):
# TODO: general case needs casting logic.
return self.astype(object).setitem(indexer, value)
return self.astype(_dtype_obj).setitem(indexer, value)

values = self.values
if self.ndim > 1:
Expand All @@ -1750,7 +1750,7 @@ def putmask(self, mask, new) -> list[Block]:
mask = extract_bool_array(mask)

if not self._can_hold_element(new):
return self.astype(object).putmask(mask, new)
return self.astype(_dtype_obj).putmask(mask, new)

arr = self.values
arr.T.putmask(mask, new)
Expand Down Expand Up @@ -1808,7 +1808,7 @@ def fillna(
# We support filling a DatetimeTZ with a `value` whose timezone
# is different by coercing to object.
# TODO: don't special-case td64
return self.astype(object).fillna(value, limit, inplace, downcast)
return self.astype(_dtype_obj).fillna(value, limit, inplace, downcast)

values = self.values
values = values if inplace else values.copy()
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Callable,
DefaultDict,
Hashable,
Iterable,
Sequence,
)
Expand Down Expand Up @@ -576,7 +577,7 @@ def get_flattened_list(

def get_indexer_dict(
label_list: list[np.ndarray], keys: list[Index]
) -> dict[str | tuple, np.ndarray]:
) -> dict[Hashable, npt.NDArray[np.intp]]:
"""
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/groupby/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def test_groupby_empty(self):
)

tm.assert_numpy_array_equal(
gr.grouper.group_info[1], np.array([], dtype=np.dtype("int"))
gr.grouper.group_info[1], np.array([], dtype=np.dtype(np.intp))
)

assert gr.grouper.group_info[2] == 0
Expand Down

0 comments on commit dfe958c

Please sign in to comment.