Skip to content

Commit

Permalink
REF: simplify factorize, fix factorize_array return type (pandas-dev#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Mar 3, 2022
1 parent 5efb570 commit 4985b89
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 57 deletions.
63 changes: 26 additions & 37 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Hashable,
Literal,
Sequence,
Union,
cast,
final,
)
Expand All @@ -30,7 +29,6 @@
ArrayLike,
DtypeObj,
IndexLabel,
Scalar,
TakeIndexer,
npt,
)
Expand Down Expand Up @@ -105,9 +103,7 @@
)
from pandas.core.arrays import (
BaseMaskedArray,
DatetimeArray,
ExtensionArray,
TimedeltaArray,
)


Expand Down Expand Up @@ -539,13 +535,24 @@ def factorize_array(
codes : ndarray[np.intp]
uniques : ndarray
"""
original = values
if values.dtype.kind in ["m", "M"]:
# _get_hashtable_algo will cast dt64/td64 to i8 via _ensure_data, so we
# need to do the same to na_value. We are assuming here that the passed
# na_value is an appropriately-typed NaT.
# e.g. test_where_datetimelike_categorical
na_value = iNaT

hash_klass, values = _get_hashtable_algo(values)

table = hash_klass(size_hint or len(values))
uniques, codes = table.factorize(
values, na_sentinel=na_sentinel, na_value=na_value, mask=mask
)

# re-cast e.g. i8->dt64/td64, uint8->bool
uniques = _reconstruct_data(uniques, original.dtype, original)

codes = ensure_platform_int(codes)
return codes, uniques

Expand Down Expand Up @@ -720,33 +727,18 @@ def factorize(
isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray))
and values.freq is not None
):
# The presence of 'freq' means we can fast-path sorting and know there
# aren't NAs
codes, uniques = values.factorize(sort=sort)
if isinstance(original, ABCIndex):
uniques = original._shallow_copy(uniques, name=None)
elif isinstance(original, ABCSeries):
from pandas import Index

uniques = Index(uniques)
return codes, uniques
return _re_wrap_factorize(original, uniques, codes)

if not isinstance(values.dtype, np.dtype):
# i.e. ExtensionDtype
codes, uniques = values.factorize(na_sentinel=na_sentinel)
dtype = original.dtype
else:
dtype = values.dtype
values = _ensure_data(values)
na_value: Scalar | None

if original.dtype.kind in ["m", "M"]:
# Note: factorize_array will cast NaT bc it has a __int__
# method, but will not cast the more-correct dtype.type("nat")
na_value = iNaT
else:
na_value = None

values = np.asarray(values) # convert DTA/TDA/MultiIndex
codes, uniques = factorize_array(
values, na_sentinel=na_sentinel, size_hint=size_hint, na_value=na_value
values, na_sentinel=na_sentinel, size_hint=size_hint
)

if sort and len(uniques) > 0:
Expand All @@ -759,23 +751,20 @@ def factorize(
# na_value is set based on the dtype of uniques, and compat set to False is
# because we do not want na_value to be 0 for integers
na_value = na_value_for_dtype(uniques.dtype, compat=False)
# Argument 2 to "append" has incompatible type "List[Union[str, float, Period,
# Timestamp, Timedelta, Any]]"; expected "Union[_SupportsArray[dtype[Any]],
# _NestedSequence[_SupportsArray[dtype[Any]]]
# , bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int,
# float, complex, str, bytes]]]" [arg-type]
uniques = np.append(uniques, [na_value]) # type: ignore[arg-type]
uniques = np.append(uniques, [na_value])
codes = np.where(code_is_na, len(uniques) - 1, codes)

uniques = _reconstruct_data(uniques, dtype, original)
uniques = _reconstruct_data(uniques, original.dtype, original)

return _re_wrap_factorize(original, uniques, codes)


# return original tenor
def _re_wrap_factorize(original, uniques, codes: np.ndarray):
"""
Wrap factorize results in Series or Index depending on original type.
"""
if isinstance(original, ABCIndex):
if original.dtype.kind in ["m", "M"] and isinstance(uniques, np.ndarray):
original._data = cast(
"Union[DatetimeArray, TimedeltaArray]", original._data
)
uniques = type(original._data)._simple_new(uniques, dtype=original.dtype)
uniques = ensure_wrapped_if_datetimelike(uniques)
uniques = original._shallow_copy(uniques, name=None)
elif isinstance(original, ABCSeries):
from pandas import Index
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def equals(self, other) -> bool:
return False
return bool(array_equivalent(self._ndarray, other._ndarray))

# Rebuild an array from factorized backing values by delegating to
# original._from_backing_data. The assert requires that values already has
# the same dtype as original._ndarray, so no cast happens here.
# NOTE(review): the first parameter is named cls but no @classmethod decorator
# is visible on this definition, unlike the other _from_factorized definitions
# shown on this page; confirm whether the decorator is intended.
def _from_factorized(cls, values, original):
assert values.dtype == original._ndarray.dtype
return original._from_backing_data(values)

# Return self._ndarray as-is; no copy or conversion is made in this method.
def _values_for_argsort(self) -> np.ndarray:
return self._ndarray

Expand Down
6 changes: 0 additions & 6 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,12 +2305,6 @@ def unique(self):
# Return a 2-tuple: self._ndarray and the integer -1.
# NOTE(review): -1 is presumably the missing-value marker paired with the
# backing values; confirm against the consumers of _values_for_factorize.
def _values_for_factorize(self):
return self._ndarray, -1

@classmethod
def _from_factorized(cls, uniques, original):
# ensure we have the same itemsize for codes
codes = coerce_indexer_dtype(uniques, original.dtype.categories)
return original._from_backing_data(codes)

def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
# make sure we have correct itemsize for resulting codes
res_values = coerce_indexer_dtype(res_values, self.dtype.categories)
Expand Down
9 changes: 1 addition & 8 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,14 +550,7 @@ def copy(self: DatetimeLikeArrayT, order="C") -> DatetimeLikeArrayT:
return new_obj

def _values_for_factorize(self):
# int64 instead of int ensures we have a "view" method
return self._ndarray, np.int64(iNaT)

@classmethod
def _from_factorized(
cls: type[DatetimeLikeArrayT], values, original: DatetimeLikeArrayT
) -> DatetimeLikeArrayT:
return cls(values, dtype=original.dtype)
return self._ndarray, self._internal_fill_value

# ------------------------------------------------------------------
# Validation Methods
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,8 +874,9 @@ def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:

codes, uniques = factorize_array(arr, na_sentinel=na_sentinel, mask=mask)

# the hashtables don't handle all different types of bits
uniques = uniques.astype(self.dtype.numpy_dtype, copy=False)
# check that factorize_array correctly preserves dtype.
assert uniques.dtype == self.dtype.numpy_dtype, (uniques.dtype, self.dtype)

uniques_ea = type(self)(uniques, np.zeros(len(uniques), dtype=bool))
return codes, uniques_ea

Expand Down
4 changes: 0 additions & 4 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ def _from_sequence(
result = result.copy()
return cls(result)

@classmethod
def _from_factorized(cls, values, original) -> PandasArray:
return original._from_backing_data(values)

# Wrap arr in a new instance built with type(self)(arr), so the result has
# the same class as self (including subclasses).
def _from_backing_data(self, arr: np.ndarray) -> PandasArray:
return type(self)(arr)

Expand Down

0 comments on commit 4985b89

Please sign in to comment.