Skip to content

Commit

Permalink
PERF: avoid cast in algos.rank (pandas-dev#46175)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Feb 28, 2022
1 parent e70d310 commit d801a4b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 45 deletions.
4 changes: 2 additions & 2 deletions pandas/_libs/algos.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def is_monotonic(
# ----------------------------------------------------------------------

def rank_1d(
values: np.ndarray, # ndarray[iu_64_floating_obj_t, ndim=1]
values: np.ndarray, # ndarray[numeric_object_t, ndim=1]
labels: np.ndarray | None = ..., # const int64_t[:]=None
is_datetimelike: bool = ...,
ties_method=...,
Expand All @@ -111,7 +111,7 @@ def rank_1d(
na_option=...,
) -> np.ndarray: ... # np.ndarray[float64_t, ndim=1]
def rank_2d(
in_arr: np.ndarray, # ndarray[iu_64_floating_obj_t, ndim=2]
in_arr: np.ndarray, # ndarray[numeric_object_t, ndim=2]
axis: int = ...,
is_datetimelike: bool = ...,
ties_method=...,
Expand Down
106 changes: 68 additions & 38 deletions pandas/_libs/algos.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ cnp.import_array()

cimport pandas._libs.util as util
from pandas._libs.dtypes cimport (
iu_64_floating_obj_t,
numeric_object_t,
numeric_t,
)
Expand Down Expand Up @@ -821,30 +820,54 @@ def is_monotonic(ndarray[numeric_object_t, ndim=1] arr, bint timelike):
# rank_1d, rank_2d
# ----------------------------------------------------------------------

cdef iu_64_floating_obj_t get_rank_nan_fill_val(
bint rank_nans_highest,
iu_64_floating_obj_t[:] _=None
cdef numeric_object_t get_rank_nan_fill_val(
bint rank_nans_highest,
numeric_object_t[:] _=None
):
"""
Return the value we'll use to represent missing values when sorting depending
on if we'd like missing values to end up at the top/bottom. (The second parameter
is unused, but needed for fused type specialization)
"""
if rank_nans_highest:
if iu_64_floating_obj_t is object:
if numeric_object_t is object:
return Infinity()
elif iu_64_floating_obj_t is int64_t:
elif numeric_object_t is int64_t:
return util.INT64_MAX
elif iu_64_floating_obj_t is uint64_t:
elif numeric_object_t is int32_t:
return util.INT32_MAX
elif numeric_object_t is int16_t:
return util.INT16_MAX
elif numeric_object_t is int8_t:
return util.INT8_MAX
elif numeric_object_t is uint64_t:
return util.UINT64_MAX
elif numeric_object_t is uint32_t:
return util.UINT32_MAX
elif numeric_object_t is uint16_t:
return util.UINT16_MAX
elif numeric_object_t is uint8_t:
return util.UINT8_MAX
else:
return np.inf
else:
if iu_64_floating_obj_t is object:
if numeric_object_t is object:
return NegInfinity()
elif iu_64_floating_obj_t is int64_t:
elif numeric_object_t is int64_t:
return NPY_NAT
elif iu_64_floating_obj_t is uint64_t:
elif numeric_object_t is int32_t:
return util.INT32_MIN
elif numeric_object_t is int16_t:
return util.INT16_MIN
elif numeric_object_t is int8_t:
return util.INT8_MIN
elif numeric_object_t is uint64_t:
return 0
elif numeric_object_t is uint32_t:
return 0
elif numeric_object_t is uint16_t:
return 0
elif numeric_object_t is uint8_t:
return 0
else:
return -np.inf
Expand All @@ -853,7 +876,7 @@ cdef iu_64_floating_obj_t get_rank_nan_fill_val(
@cython.wraparound(False)
@cython.boundscheck(False)
def rank_1d(
ndarray[iu_64_floating_obj_t, ndim=1] values,
ndarray[numeric_object_t, ndim=1] values,
const intp_t[:] labels=None,
bint is_datetimelike=False,
ties_method="average",
Expand All @@ -866,7 +889,7 @@ def rank_1d(
Parameters
----------
values : array of iu_64_floating_obj_t values to be ranked
values : array of numeric_object_t values to be ranked
labels : np.ndarray[np.intp] or None
Array containing unique label for each group, with its ordering
matching up to the corresponding record in `values`. If not called
Expand Down Expand Up @@ -896,11 +919,11 @@ def rank_1d(
int64_t[::1] grp_sizes
intp_t[:] lexsort_indexer
float64_t[::1] out
ndarray[iu_64_floating_obj_t, ndim=1] masked_vals
iu_64_floating_obj_t[:] masked_vals_memview
ndarray[numeric_object_t, ndim=1] masked_vals
numeric_object_t[:] masked_vals_memview
uint8_t[:] mask
bint keep_na, nans_rank_highest, check_labels, check_mask
iu_64_floating_obj_t nan_fill_val
numeric_object_t nan_fill_val

tiebreak = tiebreakers[ties_method]
if tiebreak == TIEBREAK_FIRST:
Expand All @@ -921,22 +944,26 @@ def rank_1d(
check_labels = labels is not None

# For cases where a mask is not possible, we can avoid mask checks
check_mask = not (iu_64_floating_obj_t is uint64_t or
(iu_64_floating_obj_t is int64_t and not is_datetimelike))
check_mask = (
numeric_object_t is float32_t
or numeric_object_t is float64_t
or numeric_object_t is object
or (numeric_object_t is int64_t and is_datetimelike)
)

# Copy values into new array in order to fill missing data
# with mask, without obfuscating location of missing data
# in values array
if iu_64_floating_obj_t is object and values.dtype != np.object_:
if numeric_object_t is object and values.dtype != np.object_:
masked_vals = values.astype('O')
else:
masked_vals = values.copy()

if iu_64_floating_obj_t is object:
if numeric_object_t is object:
mask = missing.isnaobj(masked_vals)
elif iu_64_floating_obj_t is int64_t and is_datetimelike:
elif numeric_object_t is int64_t and is_datetimelike:
mask = (masked_vals == NPY_NAT).astype(np.uint8)
elif iu_64_floating_obj_t is float64_t:
elif numeric_object_t is float64_t or numeric_object_t is float32_t:
mask = np.isnan(masked_vals).astype(np.uint8)
else:
mask = np.zeros(shape=len(masked_vals), dtype=np.uint8)
Expand All @@ -948,7 +975,7 @@ def rank_1d(
# will flip the ordering to still end up with lowest rank.
# Symmetric logic applies to `na_option == 'bottom'`
nans_rank_highest = ascending ^ (na_option == 'top')
nan_fill_val = get_rank_nan_fill_val[iu_64_floating_obj_t](nans_rank_highest)
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
if nans_rank_highest:
order = [masked_vals, mask]
else:
Expand Down Expand Up @@ -994,8 +1021,8 @@ cdef void rank_sorted_1d(
float64_t[::1] out,
int64_t[::1] grp_sizes,
const intp_t[:] sort_indexer,
# Can make const with cython3 (https://github.com/cython/cython/issues/3222)
iu_64_floating_obj_t[:] masked_vals,
# TODO(cython3): make const (https://github.com/cython/cython/issues/3222)
numeric_object_t[:] masked_vals,
const uint8_t[:] mask,
bint check_mask,
Py_ssize_t N,
Expand All @@ -1019,7 +1046,7 @@ cdef void rank_sorted_1d(
if labels is None.
sort_indexer : intp_t[:]
Array of indices which sorts masked_vals
masked_vals : iu_64_floating_obj_t[:]
masked_vals : numeric_object_t[:]
The values input to rank_1d, with missing values replaced by fill values
mask : uint8_t[:]
Array where entries are True if the value is missing, False otherwise.
Expand Down Expand Up @@ -1051,7 +1078,7 @@ cdef void rank_sorted_1d(
# that sorted value for retrieval back from the original
# values / masked_vals arrays
# TODO(cython3): de-duplicate once cython supports conditional nogil
if iu_64_floating_obj_t is object:
if numeric_object_t is object:
with gil:
for i in range(N):
at_end = i == N - 1
Expand Down Expand Up @@ -1259,7 +1286,7 @@ cdef void rank_sorted_1d(


def rank_2d(
ndarray[iu_64_floating_obj_t, ndim=2] in_arr,
ndarray[numeric_object_t, ndim=2] in_arr,
int axis=0,
bint is_datetimelike=False,
ties_method="average",
Expand All @@ -1274,13 +1301,13 @@ def rank_2d(
Py_ssize_t k, n, col
float64_t[::1, :] out # Column-major so columns are contiguous
int64_t[::1] grp_sizes
ndarray[iu_64_floating_obj_t, ndim=2] values
iu_64_floating_obj_t[:, :] masked_vals
ndarray[numeric_object_t, ndim=2] values
numeric_object_t[:, :] masked_vals
intp_t[:, :] sort_indexer
uint8_t[:, :] mask
TiebreakEnumType tiebreak
bint check_mask, keep_na, nans_rank_highest
iu_64_floating_obj_t nan_fill_val
numeric_object_t nan_fill_val

tiebreak = tiebreakers[ties_method]
if tiebreak == TIEBREAK_FIRST:
Expand All @@ -1290,29 +1317,32 @@ def rank_2d(
keep_na = na_option == 'keep'

# For cases where a mask is not possible, we can avoid mask checks
check_mask = not (iu_64_floating_obj_t is uint64_t or
(iu_64_floating_obj_t is int64_t and not is_datetimelike))
check_mask = (
numeric_object_t is float32_t
or numeric_object_t is float64_t
or numeric_object_t is object
or (numeric_object_t is int64_t and is_datetimelike)
)

if axis == 1:
values = np.asarray(in_arr).T.copy()
else:
values = np.asarray(in_arr).copy()

if iu_64_floating_obj_t is object:
if numeric_object_t is object:
if values.dtype != np.object_:
values = values.astype('O')

nans_rank_highest = ascending ^ (na_option == 'top')
if check_mask:
nan_fill_val = get_rank_nan_fill_val[iu_64_floating_obj_t](nans_rank_highest)
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)

if iu_64_floating_obj_t is object:
if numeric_object_t is object:
mask = missing.isnaobj2d(values).view(np.uint8)
elif iu_64_floating_obj_t is float64_t:
elif numeric_object_t is float64_t or numeric_object_t is float32_t:
mask = np.isnan(values).view(np.uint8)

# int64 and datetimelike
else:
# i.e. int64 and datetimelike
mask = (values == NPY_NAT).view(np.uint8)
np.putmask(values, mask, nan_fill_val)
else:
Expand Down
5 changes: 0 additions & 5 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,11 +1003,6 @@ def rank(
is_datetimelike = needs_i8_conversion(values.dtype)
values = _ensure_data(values)

if values.dtype.kind in ["i", "u", "f"]:
# rank_t includes only object, int64, uint64, float64
dtype = values.dtype.kind + "8"
values = values.astype(dtype, copy=False)

if values.ndim == 1:
ranks = algos.rank_1d(
values,
Expand Down

0 comments on commit d801a4b

Please sign in to comment.