Skip to content

Commit

Permalink
BUG: DataFrame[dt64].where downcasting (pandas-dev#45837)
Browse files Browse the repository at this point in the history
* BUG: DataFrame[dt64].where downcasting

* whatsnew

* GH ref

* punt on dt64

* mypy fixup
  • Loading branch information
jbrockmendel authored Feb 11, 2022
1 parent 5c89090 commit c1a492c
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 16 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ Indexing
- Bug in setting a NA value (``None`` or ``np.nan``) into a :class:`Series` with int-based :class:`IntervalDtype` incorrectly casting to object dtype instead of a float-based :class:`IntervalDtype` (:issue:`45568`)
- Bug in :meth:`Series.__setitem__` with a non-integer :class:`Index` when using an integer key to set a value that cannot be set inplace where a ``ValueError`` was raised instead of casting to a common dtype (:issue:`45070`)
- Bug in :meth:`Series.__setitem__` when setting incompatible values into a ``PeriodDtype`` or ``IntervalDtype`` :class:`Series` raising when indexing with a boolean mask but coercing when indexing with otherwise-equivalent indexers; these now consistently coerce, along with :meth:`Series.mask` and :meth:`Series.where` (:issue:`45768`)
- Bug in :meth:`DataFrame.where` with multiple columns with datetime-like dtypes failing to downcast results consistent with other dtypes (:issue:`45837`)
- Bug in :meth:`Series.loc.__setitem__` and :meth:`Series.loc.__getitem__` not raising when using multiple keys without using a :class:`MultiIndex` (:issue:`13831`)
- Bug when setting a value too large for a :class:`Series` dtype failing to coerce to a common type (:issue:`26049`, :issue:`32878`)
- Bug in :meth:`loc.__setitem__` treating ``range`` keys as positional instead of label-based (:issue:`45479`)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def coerce_to_array(
raise TypeError("Need to pass bool-like values")

if mask is None and mask_values is None:
mask = np.zeros(len(values), dtype=bool)
mask = np.zeros(values.shape, dtype=bool)
elif mask is None:
mask = mask_values
else:
Expand Down
10 changes: 9 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
dtype = "int64"
elif inferred_type == "datetime64":
dtype = "datetime64[ns]"
elif inferred_type == "timedelta64":
elif inferred_type in ["timedelta", "timedelta64"]:
dtype = "timedelta64[ns]"

# try to upcast here
Expand Down Expand Up @@ -290,6 +290,14 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
result = result.astype(dtype)

elif dtype.kind == "m" and result.dtype == _dtype_obj:
# test_where_downcast_to_td64
result = cast(np.ndarray, result)
result = array_to_timedelta64(result)

elif dtype == "M8[ns]" and result.dtype == _dtype_obj:
return np.asarray(maybe_cast_to_datetime(result, dtype=dtype))

return result


Expand Down
43 changes: 31 additions & 12 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,21 +1426,40 @@ def where(self, other, cond, _downcast="infer") -> list[Block]:
except (ValueError, TypeError) as err:
_catch_deprecated_value_error(err)

if is_interval_dtype(self.dtype):
# TestSetitemFloatIntervalWithIntIntervalValues
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, downcast=_downcast)
if self.ndim == 1 or self.shape[0] == 1:

elif isinstance(self, NDArrayBackedExtensionBlock):
# NB: not (yet) the same as
# isinstance(values, NDArrayBackedExtensionArray)
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, "infer")
if is_interval_dtype(self.dtype):
# TestSetitemFloatIntervalWithIntIntervalValues
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, downcast=_downcast)

elif isinstance(self, NDArrayBackedExtensionBlock):
# NB: not (yet) the same as
# isinstance(values, NDArrayBackedExtensionArray)
blk = self.coerce_to_target_dtype(orig_other)
nbs = blk.where(orig_other, orig_cond)
return self._maybe_downcast(nbs, downcast=_downcast)

else:
raise

else:
raise
# Same pattern we use in Block.putmask
is_array = isinstance(orig_other, (np.ndarray, ExtensionArray))

res_blocks = []
nbs = self._split()
for i, nb in enumerate(nbs):
n = orig_other
if is_array:
# we have a different value per-column
n = orig_other[:, i : i + 1]

submask = orig_cond[:, i : i + 1]
rbs = nb.where(n, submask)
res_blocks.extend(rbs)
return res_blocks

nb = self.make_block_same_class(res_values)
return [nb]
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/arrays/boolean/test_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,11 @@ def test_coerce_to_array():
values = np.array([True, False, True, False], dtype="bool")
mask = np.array([False, False, False, True], dtype="bool")

# passing 2D values is OK as long as no mask
coerce_to_array(values.reshape(1, -1))

with pytest.raises(ValueError, match="values.shape and mask.shape must match"):
coerce_to_array(values.reshape(1, -1))
coerce_to_array(values.reshape(1, -1), mask=mask)

with pytest.raises(ValueError, match="values.shape and mask.shape must match"):
coerce_to_array(values, mask=mask.reshape(1, -1))
Expand Down
12 changes: 11 additions & 1 deletion pandas/tests/dtypes/cast/test_downcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from pandas.core.dtypes.cast import maybe_downcast_to_dtype

from pandas import Series
from pandas import (
Series,
Timedelta,
)
import pandas._testing as tm


Expand Down Expand Up @@ -34,6 +37,13 @@
"int64",
np.array([decimal.Decimal(0.0)]),
),
(
# GH#45837
np.array([Timedelta(days=1), Timedelta(days=2)], dtype=object),
"infer",
np.array([1, 2], dtype="m8[D]").astype("m8[ns]"),
),
# TODO: similar for dt64, dt64tz, Period, Interval?
],
)
def test_downcast(arr, expected, dtype):
Expand Down
55 changes: 55 additions & 0 deletions pandas/tests/frame/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,3 +965,58 @@ def test_where_inplace_casting(data):
df_copy = df.where(pd.notnull(df), None).copy()
df.where(pd.notnull(df), None, inplace=True)
tm.assert_equal(df, df_copy)


def test_where_downcast_to_td64():
ser = Series([1, 2, 3])

mask = np.array([False, False, False])

td = pd.Timedelta(days=1)

res = ser.where(mask, td)
expected = Series([td, td, td], dtype="m8[ns]")
tm.assert_series_equal(res, expected)


def _check_where_equivalences(df, mask, other, expected):
# similar to tests.series.indexing.test_setitem.SetitemCastingEquivalences
# but with DataFrame in mind and less fleshed-out
res = df.where(mask, other)
tm.assert_frame_equal(res, expected)

res = df.mask(~mask, other)
tm.assert_frame_equal(res, expected)

# Note: we cannot do the same with frame.mask(~mask, other, inplace=True)
# bc that goes through Block.putmask which does *not* downcast.


def test_where_dt64_2d():
dti = date_range("2016-01-01", periods=6)
dta = dti._data.reshape(3, 2)
other = dta - dta[0, 0]

df = DataFrame(dta, columns=["A", "B"])

mask = np.asarray(df.isna())
mask[:, 1] = True

# setting all of one column, none of the other
expected = DataFrame({"A": other[:, 0], "B": dta[:, 1]})
_check_where_equivalences(df, mask, other, expected)

# setting part of one column, none of the other
mask[1, 0] = True
expected = DataFrame(
{
"A": np.array([other[0, 0], dta[1, 0], other[2, 0]], dtype=object),
"B": dta[:, 1],
}
)
_check_where_equivalences(df, mask, other, expected)

# setting nothing in either column
mask[:] = True
expected = df
_check_where_equivalences(df, mask, other, expected)

0 comments on commit c1a492c

Please sign in to comment.