From a3bcbf86cdd961fb304dcfd09fe3bad0ecb8a896 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 4 Nov 2021 01:43:43 +0100 Subject: [PATCH] BUG: Period.__eq__ numpy scalar (#44182 (comment)) (#44285) --- pandas/_libs/tslibs/period.pyx | 8 ++++++-- pandas/tests/arithmetic/test_period.py | 4 ++++ pandas/tests/scalar/period/test_period.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pandas/_libs/tslibs/period.pyx b/pandas/_libs/tslibs/period.pyx index 337876d610c5e..f594e0a8bdafd 100644 --- a/pandas/_libs/tslibs/period.pyx +++ b/pandas/_libs/tslibs/period.pyx @@ -1657,8 +1657,12 @@ cdef class _Period(PeriodMixin): elif other is NaT: return _nat_scalar_rules[op] elif util.is_array(other): - # in particular ndarray[object]; see test_pi_cmp_period - return np.array([PyObject_RichCompare(self, x, op) for x in other]) + # GH#44285 + if cnp.PyArray_IsZeroDim(other): + return PyObject_RichCompare(self, other.item(), op) + else: + # in particular ndarray[object]; see test_pi_cmp_period + return np.array([PyObject_RichCompare(self, x, op) for x in other]) return NotImplemented def __hash__(self): diff --git a/pandas/tests/arithmetic/test_period.py b/pandas/tests/arithmetic/test_period.py index d7cb314743e86..41c2cb2cc4f1e 100644 --- a/pandas/tests/arithmetic/test_period.py +++ b/pandas/tests/arithmetic/test_period.py @@ -189,6 +189,10 @@ def test_pi_cmp_period(self): result = idx.values.reshape(10, 2) < idx[10] tm.assert_numpy_array_equal(result, exp.reshape(10, 2)) + # Tests Period.__richcmp__ against ndarray[object, ndim=0] + result = idx < np.array(idx[10]) + tm.assert_numpy_array_equal(result, exp) + # TODO: moved from test_datetime64; de-duplicate with version below def test_parr_cmp_period_scalar2(self, box_with_array): xbox = get_expected_box(box_with_array) diff --git a/pandas/tests/scalar/period/test_period.py b/pandas/tests/scalar/period/test_period.py index f1b8c1cfdd39b..cd1bf21753249 100644 --- a/pandas/tests/scalar/period/test_period.py +++ b/pandas/tests/scalar/period/test_period.py @@ -1148,6 +1148,16 @@ def test_period_cmp_nat(self): assert not left <= right assert not left >= right + @pytest.mark.parametrize( + "zerodim_arr, expected", + ((np.array(0), False), (np.array(Period("2000-01", "M")), True)), + ) + def test_comparison_numpy_zerodim_arr(self, zerodim_arr, expected): + p = Period("2000-01", "M") + + assert (p == zerodim_arr) is expected + assert (zerodim_arr == p) is expected + class TestArithmetic: def test_sub_delta(self):