Skip to content

Commit

Permalink
PERF: avoid object dtype cast for Categorical in _ensure_data (pandas…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Mar 3, 2022
1 parent c2cf93f commit 87803d0
Showing 1 changed file with 13 additions and 20 deletions.
33 changes: 13 additions & 20 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
# extract_array would raise
values = extract_array(values, extract_numpy=True)

# we check some simple dtypes first
if is_object_dtype(values.dtype):
return ensure_object(np.asarray(values))

Expand All @@ -149,17 +148,19 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
return _ensure_data(values._data)
return np.asarray(values)

elif is_categorical_dtype(values.dtype):
# NB: cases that go through here should NOT be using _reconstruct_data
# on the back-end.
values = cast("Categorical", values)
return values.codes

elif is_bool_dtype(values.dtype):
if isinstance(values, np.ndarray):
# i.e. actually dtype == np.dtype("bool")
return np.asarray(values).view("uint8")
else:
# i.e. all-bool Categorical, BooleanArray
try:
return np.asarray(values).astype("uint8", copy=False)
except (TypeError, ValueError):
# GH#42107 we have pd.NAs present
return np.asarray(values)
# e.g. Sparse[bool, False] # TODO: no test cases get here
return np.asarray(values).astype("uint8", copy=False)

elif is_integer_dtype(values.dtype):
return np.asarray(values)
Expand All @@ -174,10 +175,7 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
return np.asarray(values)

elif is_complex_dtype(values.dtype):
# Incompatible return value type (got "Tuple[Union[Any, ExtensionArray,
# ndarray[Any, Any]], Union[Any, ExtensionDtype]]", expected
# "Tuple[ndarray[Any, Any], Union[dtype[Any], ExtensionDtype]]")
return values # type: ignore[return-value]
return cast(np.ndarray, values)

# datetimelike
elif needs_i8_conversion(values.dtype):
Expand All @@ -187,11 +185,6 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
npvalues = cast(np.ndarray, npvalues)
return npvalues

elif is_categorical_dtype(values.dtype):
values = cast("Categorical", values)
values = values.codes
return values

# we have failed, return object
values = np.asarray(values, dtype=object)
return ensure_object(values)
Expand All @@ -218,7 +211,8 @@ def _reconstruct_data(
return values

if not isinstance(dtype, np.dtype):
# i.e. ExtensionDtype
# i.e. ExtensionDtype; note we have ruled out above the possibility
# that values.dtype == dtype
cls = dtype.construct_array_type()

values = cls._from_sequence(values, dtype=dtype)
Expand Down Expand Up @@ -938,9 +932,8 @@ def mode(values: ArrayLike, dropna: bool = True) -> ArrayLike:
if needs_i8_conversion(values.dtype):
# Got here with ndarray; dispatch to DatetimeArray/TimedeltaArray.
values = ensure_wrapped_if_datetimelike(values)
# error: Item "ndarray[Any, Any]" of "Union[ExtensionArray,
# ndarray[Any, Any]]" has no attribute "_mode"
return values._mode(dropna=dropna) # type: ignore[union-attr]
values = cast("ExtensionArray", values)
return values._mode(dropna=dropna)

values = _ensure_data(values)

Expand Down

0 comments on commit 87803d0

Please sign in to comment.