Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise on __eq__ binary operations between unsupported types. #11609

Draft
wants to merge 4 commits into
base: branch-24.06
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cudf._typing import ColumnBinaryOperand, Dtype
from cudf.api.types import is_integer_dtype, is_scalar
from cudf.core.buffer import as_buffer
from cudf.core.column import ColumnBase
from cudf.core.column import ColumnBase, as_column
from cudf.core.dtypes import (
Decimal32Dtype,
Decimal64Dtype,
Expand Down Expand Up @@ -103,9 +103,14 @@ def __rtruediv__(self, other):

def _binaryop(self, other: ColumnBinaryOperand, op: str):
reflect, op = self._check_reflected_op(op)
other = self._wrap_binop_normalization(other)
if other is NotImplemented:
normalized_other = self._wrap_binop_normalization(other)
if normalized_other is NotImplemented:
if op in {"__eq__", "__ne__"} and isinstance(other, ColumnBase):
return as_column(
op != "__eq__", length=len(self), dtype="bool"
)
return NotImplemented
other = normalized_other
lhs, rhs = (other, self) if reflect else (self, other)

# Binary Arithmetics between decimal columns. `Scale` and `precision`
Expand Down
36 changes: 23 additions & 13 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing_extensions import Self

import cudf
import cudf._lib as libcudf
from cudf._lib.copying import segmented_gather
from cudf._lib.lists import (
concatenate_list_elements,
Expand All @@ -36,7 +37,7 @@

class ListColumn(ColumnBase):
dtype: ListDtype
_VALID_BINARY_OPERATIONS = {"__add__", "__radd__"}
_VALID_BINARY_OPERATIONS = {"__add__", "__radd__", "__eq__", "__ne__"}

def __init__(
self,
Expand Down Expand Up @@ -109,19 +110,28 @@ def base_size(self):
def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
# Lists only support __add__, which concatenates lists.
reflect, op = self._check_reflected_op(op)
other = self._wrap_binop_normalization(other)
if other is NotImplemented:
return NotImplemented
if isinstance(other.dtype, ListDtype):
if op == "__add__":
return concatenate_rows([self, other])
else:
raise NotImplementedError(
"Lists concatenation for this operation is not yet"
"supported"

normalized_other = self._wrap_binop_normalization(other)
if normalized_other is NotImplemented:
if op in {"__eq__", "__ne__"} and isinstance(other, ColumnBase):
return as_column(
op != "__eq__", length=len(self), dtype="bool"
)
else:
raise TypeError("can only concatenate list to list")
return NotImplemented
other = normalized_other

lhs, rhs = (other, self) if reflect else (self, other)

if isinstance(other.dtype, ListDtype) and op == "__add__":
return concatenate_rows([lhs, rhs])
elif op in {"__eq__", "__ne__"}:
return libcudf.binaryop.binaryop(
lhs=lhs, rhs=rhs, op=op, dtype="bool"
)
raise TypeError(
f"'{op}' not supported between instances of "
f"'{type(self).__name__}' and '{type(other).__name__}'"
)

@property
def elements(self):
Expand Down
16 changes: 15 additions & 1 deletion python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,22 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
return self.astype(truediv_type)._binaryop(other, op)

reflect, op = self._check_reflected_op(op)
if (other := self._wrap_binop_normalization(other)) is NotImplemented:
normalized_other = self._wrap_binop_normalization(other)
if normalized_other is NotImplemented:
# Non-numerical columns cannot be compared to numerical columns,
# but are expected to return all False values for equality and all
# True values for inequality. We have to exclude
# NumericalBaseColumn to ensure that decimal columns pass through.
if (
op in {"__eq__", "__ne__"}
and not isinstance(other, NumericalBaseColumn)
and isinstance(other, ColumnBase)
):
return as_column(
op != "__eq__", length=len(self), dtype="bool"
)
return NotImplemented
other = normalized_other
out_dtype = self.dtype
if other is not None:
out_dtype = np.result_type(self.dtype, other.dtype)
Expand Down
11 changes: 8 additions & 3 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from cudf._lib.types import size_type_dtype
from cudf.api.types import is_integer, is_scalar, is_string_dtype
from cudf.core.buffer import Buffer
from cudf.core.column import column, datetime
from cudf.core.column import as_column, column, datetime
from cudf.core.column.column import ColumnBase
from cudf.core.column.methods import ColumnMethods
from cudf.utils.docutils import copy_docstring
Expand Down Expand Up @@ -5910,9 +5910,14 @@ def _binaryop(
elif op == "__ne__":
return self.isnull()

other = self._wrap_binop_normalization(other)
if other is NotImplemented:
normalized_other = self._wrap_binop_normalization(other)
if normalized_other is NotImplemented:
if op in {"__eq__", "__ne__"} and isinstance(other, ColumnBase):
return as_column(
op != "__eq__", length=len(self), dtype="bool"
)
return NotImplemented
other = normalized_other

if isinstance(other, (StringColumn, str, cudf.Scalar)):
if isinstance(other, cudf.Scalar) and other.dtype != "O":
Expand Down
29 changes: 27 additions & 2 deletions python/cudf/cudf/core/column/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import pyarrow as pa

import cudf
from cudf._typing import Dtype
from cudf.core.column import ColumnBase
import cudf._lib as libcudf
from cudf._typing import ColumnBinaryOperand, Dtype
from cudf.core.column import ColumnBase, as_column
from cudf.core.column.methods import ColumnMethods
from cudf.core.dtypes import StructDtype
from cudf.core.missing import NA
Expand All @@ -25,6 +26,7 @@ class StructColumn(ColumnBase):
"""

dtype: StructDtype
_VALID_BINARY_OPERATIONS = {"__eq__", "__ne__"}

@property
def base_size(self):
Expand All @@ -33,6 +35,29 @@ def base_size(self):
else:
return self.size + self.offset

def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
reflect, op = self._check_reflected_op(op)

normalized_other = self._wrap_binop_normalization(other)
if normalized_other is NotImplemented:
if op in {"__eq__", "__ne__"} and isinstance(other, ColumnBase):
return as_column(
op != "__eq__", length=len(self), dtype="bool"
)
return NotImplemented
other = normalized_other

lhs, rhs = (other, self) if reflect else (self, other)

if op in {"__eq__", "__ne__"}:
return libcudf.binaryop.binaryop(
lhs=lhs, rhs=rhs, op=op, dtype="bool"
)
raise TypeError(
f"'{op}' not supported between instances of "
f"'{type(self).__name__}' and '{type(other).__name__}'"
)

def to_arrow(self):
children = [
pa.nulls(len(child))
Expand Down
3 changes: 0 additions & 3 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,9 +1572,6 @@ def _colwise_binop(
else:
assert False, "At least one operand must be a column."

# TODO: Disable logical and binary operators between columns that
# are not numerical using the new binops mixin.

outcol = (
getattr(operator, fn)(right_column, left_column)
if reflect
Expand Down
20 changes: 17 additions & 3 deletions python/cudf/cudf/core/mixins/binops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.

from .mixin_factory import _create_delegating_mixin
from .mixin_factory import Operation, _create_delegating_mixin

BinaryOperand = _create_delegating_mixin(
"BinaryOperand",
Expand Down Expand Up @@ -59,7 +59,13 @@ def _binaryop(self, other, op: str):
Must be overridden by subclasses, the default implementation raises a
NotImplementedError.
"""
raise NotImplementedError
if op in {"__eq__", "__ne__"}:
op_string = "==" if op == "__eq__" else "!="
raise TypeError(
f"'{op_string}' not supported between instances of "
f"'{type(self).__name__}' and '{type(other).__name__}'"
)
raise NotImplementedError()


def _check_reflected_op(op):
Expand All @@ -70,3 +76,11 @@ def _check_reflected_op(op):

BinaryOperand._binaryop = _binaryop
BinaryOperand._check_reflected_op = staticmethod(_check_reflected_op)

# It is necessary to override the default object.__eq__ so that objects don't
# automatically support equality binops using the wrong operator implementation
# (falling back to ``object``). We must override the object.__eq__ with an
# Operation rather than a plain function/method, so that the BinaryOperand
# mixin overrides it for classes that define their own __eq__ in _binaryop.
BinaryOperand.__eq__ = Operation("__eq__", {}, BinaryOperand._binaryop)
BinaryOperand.__ne__ = Operation("__ne__", {}, BinaryOperand._binaryop)
7 changes: 3 additions & 4 deletions python/cudf/cudf/core/mixins/mixin_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.

import inspect

Expand Down Expand Up @@ -27,9 +27,8 @@ class Operation:
----------
name : str
The name of the operation.
docstring_format_args : str
The attribute of the owning class from which to pull format parameters
for this operation's docstring.
docstring_format_args : dict
The dict of format parameters for this operation's docstring.
base_operation : str
The underlying operation function to be invoked when operation `name`
is called on the owning class.
Expand Down
41 changes: 41 additions & 0 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import cudf
from cudf import Series
from cudf.api.types import is_decimal_dtype, is_float_dtype
from cudf.core._compat import PANDAS_CURRENT_SUPPORTED_VERSION, PANDAS_VERSION
from cudf.core.buffer.spill_manager import get_global_manager
from cudf.core.index import as_index
Expand Down Expand Up @@ -3046,6 +3047,46 @@ def test_equality_ops_index_mismatch(fn):
utils.assert_eq(expected, actual)


TYPE_MISMATCH_INPUTS = [
[1, 2, 3],
["a", "b", "c"],
[0.0, 3.14, np.nan],
[
decimal.Decimal("1.324324"),
decimal.Decimal("2.71"),
decimal.Decimal("3.14"),
],
[pd.Timedelta(10), pd.Timedelta(123456789), pd.Timedelta(-99999999999)],
[pd.Timestamp(10), pd.Timestamp(123456789), pd.Timestamp(-99999999999)],
# TODO: Categorical columns need significant work in _binaryop
# pd.Categorical(["a", "b", "c"]),
# TODO: Need binary operations for (in)equality in libcudf between list columns
# [[1, 2], [3, 4, 5], [6, 7, 8, 9]],
# TODO: Need binary operations for (in)equality in libcudf between struct columns
# [{"a": 0}, {"a": 1}, {"a": 2}]
]


@pytest.mark.parametrize("left", TYPE_MISMATCH_INPUTS)
@pytest.mark.parametrize("right", TYPE_MISMATCH_INPUTS)
@pytest.mark.parametrize("fn", ["eq", "ne"])
def test_equality_ops_type_mismatch(left, right, fn):
a = cudf.Series(left, nan_as_null=False)
b = cudf.Series(right, nan_as_null=False)

if (is_float_dtype(a) and is_decimal_dtype(b)) or (
is_float_dtype(b) and is_decimal_dtype(a)
):
pytest.skip("Cannot compare float/decimal types.")

pa = a.to_pandas()
pb = b.to_pandas()
expected = getattr(operator, fn)(pa, pb)
actual = getattr(operator, fn)(a, b)

utils.assert_eq(expected, actual)


def generate_test_null_equals_columnops_data():
# Generate tuples of:
# (left_data, right_data, compare_bool
Expand Down
Loading