Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable binops between Decimal and Integer columns #7859

Open
wants to merge 14 commits into
base: branch-21.08
Choose a base branch
from
7 changes: 7 additions & 0 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def to_arrow(self):
)

def binary_operator(self, op, other, reflect=False):
if isinstance(other, cudf.core.column.NumericalColumn):
if other.dtype.kind in 'ui':
other = other.as_decimal_column(other._decimal_dtype())
elif other.dtype.kind == 'f':
return self.as_numerical_column(other.dtype).binary_operator(
op, other, reflect=reflect
)
if reflect:
self, other = other, self

Expand Down
59 changes: 45 additions & 14 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from numbers import Number
from decimal import Decimal
from typing import Any, Callable, Sequence, Tuple, Union, cast

import numpy as np
Expand Down Expand Up @@ -90,16 +91,7 @@ def unary_operator(self, unaryop: str) -> ColumnBase:
def binary_operator(
self, binop: str, rhs: BinaryOperand, reflect: bool = False,
) -> ColumnBase:
int_dtypes = [
np.dtype("int8"),
np.dtype("int16"),
np.dtype("int32"),
np.dtype("int64"),
np.dtype("uint8"),
np.dtype("uint16"),
np.dtype("uint32"),
np.dtype("uint64"),
]

if rhs is None:
out_dtype = self.dtype
else:
Expand All @@ -117,14 +109,23 @@ def binary_operator(
msg = "{!r} operator not supported between {} and {}"
raise TypeError(msg.format(binop, type(self), type(rhs)))
if isinstance(rhs, cudf.core.column.DecimalColumn):
lhs = self.as_decimal_column(
Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0)
)
# For arithmetic ops between decimals and numeric non decimals,
# an integral column will be cast to decimal, whereas a float
# column will cause the decimal operand to cast to float
if self.dtype.kind in 'ui':
lhs = self.as_decimal_column(self._decimal_dtype())
elif self.dtype.kind == 'f':
lhs = self
rhs = rhs.astype(self.dtype)
else:
lhs = self.as_decimal_column(
Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0)
)
return lhs.binary_operator(binop, rhs)
out_dtype = np.result_type(self.dtype, rhs.dtype)
if binop in ["mod", "floordiv"]:
tmp = self if reflect else rhs
if (tmp.dtype in int_dtypes) and (
if (tmp.dtype.kind in 'ui') and (
(np.isscalar(tmp) and (0 == tmp))
or ((isinstance(tmp, NumericalColumn)) and (0.0 in tmp))
):
Expand Down Expand Up @@ -218,6 +219,36 @@ def as_timedelta_column(
),
)

def _decimal_dtype(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems more suitable to go to utils/dtypes.py, and naming should be decimal_dtype_from_numerical?

"""
Derive a Decimal64Dtype corresponding to an integral
dtype column. Dtypes are derived using the following
rules, which are based off of the maximum sized ints
that can be contained within that dtype. As such, we
do not allow for the uint64 and int64 cases, because
the resulting type would require more than 18 digits
of precision:

uint8 -> Decimal64Dtype(3, 1) ✓
uint16 -> Decimal64Dtype(5, 1) ✓
uint32 -> Decimal64Dtype(10, 1) ✓
uint64 -> Decimal64Dtype(20, 1) x
int8 -> Decimal64Dtype(3, 1) ✓
int16 -> Decimal64Dtype(5, 1) ✓
int32 -> Decimal64Dtype(10, 1) ✓
int64 -> Decimal64Dtype(19, 1) x
"""
if self.dtype in {np.dtype("int64"), np.dtype("uint64")}:
raise TypeError(
f"Can not implicitly cast integer column of "
f"dtype {self.dtype} to Decimal64Dtype, as "
f"integers could contain more than 18 digits"
)
return cudf.Decimal64Dtype._from_decimal(
Decimal(np.iinfo(self.dtype).max)
)


def as_decimal_column(
self, dtype: Dtype, **kwargs
) -> "cudf.core.column.DecimalColumn":
Expand Down
Loading