Skip to content

Commit

Permalink
REF: dispatch Block.fillna to putmask/where (pandas-dev#45911)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Feb 11, 2022
1 parent 6f8d279 commit 769fc54
Showing 1 changed file with 41 additions and 50 deletions.
91 changes: 41 additions & 50 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,14 +1019,16 @@ def putmask(self, mask, new) -> list[Block]:
res_blocks.extend(rbs)
return res_blocks

def where(self, other, cond) -> list[Block]:
def where(self, other, cond, _downcast="infer") -> list[Block]:
"""
evaluate the block; return result block(s) from the result
Parameters
----------
other : a ndarray/object
cond : np.ndarray[bool], SparseArray[bool], or BooleanArray
_downcast : str or None, default "infer"
Private because we only specify it when calling from fillna.
Returns
-------
Expand Down Expand Up @@ -1066,7 +1068,7 @@ def where(self, other, cond) -> list[Block]:

block = self.coerce_to_target_dtype(other)
blocks = block.where(orig_other, cond)
return self._maybe_downcast(blocks, "infer")
return self._maybe_downcast(blocks, downcast=_downcast)

else:
# since _maybe_downcast would split blocks anyway, we
Expand All @@ -1083,7 +1085,7 @@ def where(self, other, cond) -> list[Block]:
oth = other[:, i : i + 1]

submask = cond[:, i : i + 1]
rbs = nb.where(oth, submask)
rbs = nb.where(oth, submask, _downcast=_downcast)
res_blocks.extend(rbs)
return res_blocks

Expand Down Expand Up @@ -1158,21 +1160,19 @@ def fillna(
if limit is not None:
mask[mask.cumsum(self.ndim - 1) > limit] = False

if self._can_hold_element(value):
nb = self if inplace else self.copy()
putmask_inplace(nb.values, mask, value)
return nb._maybe_downcast([nb], downcast)

elif self.ndim == 1 or self.shape[0] == 1:
blk = self.coerce_to_target_dtype(value)
# bc we have already cast, inplace=True may avoid an extra copy
return blk.fillna(value, limit=limit, inplace=True, downcast=None)

if inplace:
nbs = self.putmask(mask.T, value)
else:
# operate column-by-column
return self.split_and_operate(
type(self).fillna, value, limit=limit, inplace=inplace, downcast=None
)
# without _downcast, we would break
# test_fillna_dtype_conversion_equiv_replace
nbs = self.where(value, ~mask.T, _downcast=False)

# Note: blk._maybe_downcast vs self._maybe_downcast(nbs)
# makes a difference bc blk may have object dtype, which has
# different behavior in _maybe_downcast.
return extend_blocks(
[blk._maybe_downcast([blk], downcast=downcast) for blk in nbs]
)

def interpolate(
self,
Expand Down Expand Up @@ -1401,7 +1401,8 @@ def setitem(self, indexer, value):
else:
return self

def where(self, other, cond) -> list[Block]:
def where(self, other, cond, _downcast="infer") -> list[Block]:
# _downcast private bc we only specify it when calling from fillna
arr = self.values.T

cond = extract_bool_array(cond)
Expand Down Expand Up @@ -1429,7 +1430,7 @@ def where(self, other, cond) -> list[Block]:
# TestSetitemFloatIntervalWithIntIntervalValues
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, "infer")
return self._maybe_downcast(nbs, downcast=_downcast)

elif isinstance(self, NDArrayBackedExtensionBlock):
# NB: not (yet) the same as
Expand Down Expand Up @@ -1485,39 +1486,29 @@ def fillna(
) -> list[Block]:
# Caller is responsible for validating limit; if int it is strictly positive

try:
new_values = self.values.fillna(value=value, limit=limit)
except (TypeError, ValueError) as err:
_catch_deprecated_value_error(err)

if is_interval_dtype(self.dtype):
# Discussion about what we want to support in the general
# case GH#39584
blk = self.coerce_to_target_dtype(value)
return blk.fillna(value, limit, inplace, downcast)

elif isinstance(self, NDArrayBackedExtensionBlock):
# We support filling a DatetimeTZ with a `value` whose timezone
# is different by coercing to object.
if self.dtype.kind == "m":
# GH#45746
warnings.warn(
"The behavior of fillna with timedelta64[ns] dtype and "
f"an incompatible value ({type(value)}) is deprecated. "
"In a future version, this will cast to a common dtype "
"(usually object) instead of raising, matching the "
"behavior of other dtypes.",
FutureWarning,
stacklevel=find_stack_level(),
)
raise
blk = self.coerce_to_target_dtype(value)
return blk.fillna(value, limit, inplace, downcast)

else:
if self.dtype.kind == "m":
try:
res_values = self.values.fillna(value, limit=limit)
except (ValueError, TypeError):
# GH#45746
warnings.warn(
"The behavior of fillna with timedelta64[ns] dtype and "
f"an incompatible value ({type(value)}) is deprecated. "
"In a future version, this will cast to a common dtype "
"(usually object) instead of raising, matching the "
"behavior of other dtypes.",
FutureWarning,
stacklevel=find_stack_level(),
)
raise
else:
res_blk = self.make_block(res_values)
return [res_blk]

return [self.make_block_same_class(values=new_values)]
# TODO: since this now dispatches to super, which in turn dispatches
# to putmask, it may *actually* respect 'inplace=True'. If so, add
# tests for this.
return super().fillna(value, limit=limit, inplace=inplace, downcast=downcast)

def delete(self, loc) -> Block:
# This will be unnecessary if/when __array_function__ is implemented
Expand Down

0 comments on commit 769fc54

Please sign in to comment.