Skip to content

Commit

Permalink
fix(python): Fix decimal series dispatch (#20400)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 21, 2024
1 parent 234810d commit 2f2bb92
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 5 deletions.
28 changes: 24 additions & 4 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,13 +1035,15 @@ def __add__(self, other: Expr) -> Expr: ...
@overload
def __add__(self, other: Any) -> Self: ...

def __add__(self, other: Any) -> Self | DataFrame | Expr:
def __add__(self, other: Any) -> Series | DataFrame | Expr:
if isinstance(other, str):
other = Series("", [other])
elif isinstance(other, pl.DataFrame):
return other + self
elif isinstance(other, pl.Expr):
return F.lit(self) + other
if self.dtype.is_decimal() and isinstance(other, (float, int)):
return self.to_frame().select(F.col(self.name) + other).to_series()
return self._arithmetic(other, "add", "add_<>")

@overload
Expand All @@ -1050,9 +1052,11 @@ def __sub__(self, other: Expr) -> Expr: ...
@overload
def __sub__(self, other: Any) -> Self: ...

def __sub__(self, other: Any) -> Self | Expr:
def __sub__(self, other: Any) -> Series | Expr:
if isinstance(other, pl.Expr):
return F.lit(self) - other
if self.dtype.is_decimal() and isinstance(other, (float, int)):
return self.to_frame().select(F.col(self.name) - other).to_series()
return self._arithmetic(other, "sub", "sub_<>")

def _recursive_cast_to_dtype(self, leaf_dtype: PolarsDataType) -> Series:
Expand Down Expand Up @@ -1083,6 +1087,8 @@ def __truediv__(self, other: Any) -> Series | Expr:
if self.dtype.is_temporal():
msg = "first cast to integer before dividing datelike dtypes"
raise TypeError(msg)
if self.dtype.is_decimal() and isinstance(other, (float, int)):
return self.to_frame().select(F.col(self.name) / other).to_series()

self = (
self
Expand Down Expand Up @@ -1111,6 +1117,8 @@ def __floordiv__(self, other: Any) -> Series | Expr:
if self.dtype.is_temporal():
msg = "first cast to integer before dividing datelike dtypes"
raise TypeError(msg)
if self.dtype.is_decimal() and isinstance(other, (float, int)):
return self.to_frame().select(F.col(self.name) // other).to_series()

if not isinstance(other, pl.Expr):
other = F.lit(other)
Expand All @@ -1134,6 +1142,8 @@ def __mul__(self, other: Any) -> Series | DataFrame | Expr:
if self.dtype.is_temporal():
msg = "first cast to integer before multiplying datelike dtypes"
raise TypeError(msg)
if self.dtype.is_decimal() and isinstance(other, (float, int)):
return self.to_frame().select(F.col(self.name) * other).to_series()
elif isinstance(other, pl.DataFrame):
return other * self
else:
Expand All @@ -1151,6 +1161,8 @@ def __mod__(self, other: Any) -> Series | Expr:
if self.dtype.is_temporal():
msg = "first cast to integer before applying modulo on datelike dtypes"
raise TypeError(msg)
if self.dtype.is_decimal() and isinstance(other, (float, int)):
return self.to_frame().select(F.col(self.name) % other).to_series()
return self._arithmetic(other, "rem", "rem_<>")

def __rmod__(self, other: Any) -> Series:
Expand All @@ -1160,11 +1172,15 @@ def __rmod__(self, other: Any) -> Series:
return self._arithmetic(other, "rem", "rem_<>_rhs")

def __radd__(self, other: Any) -> Series:
if isinstance(other, str):
return (other + self.to_frame()).to_series()
if isinstance(other, str) or (
isinstance(other, (int, float)) and self.dtype.is_decimal()
):
return self.to_frame().select(other + F.col(self.name)).to_series()
return self._arithmetic(other, "add", "add_<>_rhs")

def __rsub__(self, other: Any) -> Series:
if isinstance(other, (int, float)) and self.dtype.is_decimal():
return self.to_frame().select(other - F.col(self.name)).to_series()
return self._arithmetic(other, "sub", "sub_<>_rhs")

def __rtruediv__(self, other: Any) -> Series:
Expand All @@ -1173,6 +1189,8 @@ def __rtruediv__(self, other: Any) -> Series:
raise TypeError(msg)
if self.dtype.is_float():
self.__rfloordiv__(other)
if isinstance(other, (int, float)) and self.dtype.is_decimal():
return self.to_frame().select(other / F.col(self.name)).to_series()

if isinstance(other, int):
other = float(other)
Expand All @@ -1188,6 +1206,8 @@ def __rmul__(self, other: Any) -> Series:
if self.dtype.is_temporal():
msg = "first cast to integer before multiplying datelike dtypes"
raise TypeError(msg)
if isinstance(other, (int, float)) and self.dtype.is_decimal():
return self.to_frame().select(other * F.col(self.name)).to_series()
return self._arithmetic(other, "mul", "mul_<>")

def __pow__(self, exponent: int | float | Series) -> Series:
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2038,6 +2038,7 @@ def test_add_string() -> None:
expected = pl.DataFrame(
{"a": ["hello hi", "hello there"], "b": ["hello hello", "hello world"]}
)
print(expected)
assert_frame_equal(("hello " + df), expected)


Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,26 @@ def test_decimal_arithmetic_schema() -> None:
assert q1.collect_schema() == q1.collect().schema
q1 = q.select(pl.col.x + pl.col.x)
assert q1.collect_schema() == q1.collect().schema


def test_decimal_arithmetic_schema_float_20369() -> None:
s = pl.Series("x", [1.0], dtype=pl.Decimal(15, 2))
assert_series_equal((s - 1.0), pl.Series("x", [0.0], dtype=pl.Decimal(None, 2)))
assert_series_equal(
(3.0 - s), pl.Series("literal", [2.0], dtype=pl.Decimal(None, 2))
)
assert_series_equal(
(3.0 / s), pl.Series("literal", [3.0], dtype=pl.Decimal(None, 6))
)
assert_series_equal(
(s / 3.0), pl.Series("x", [0.333333], dtype=pl.Decimal(None, 6))
)

assert_series_equal((s + 1.0), pl.Series("x", [2.0], dtype=pl.Decimal(None, 2)))
assert_series_equal(
(1.0 + s), pl.Series("literal", [2.0], dtype=pl.Decimal(None, 2))
)
assert_series_equal((s * 1.0), pl.Series("x", [1.0], dtype=pl.Decimal(None, 4)))
assert_series_equal(
(1.0 * s), pl.Series("literal", [1.0], dtype=pl.Decimal(None, 4))
)
3 changes: 2 additions & 1 deletion py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,11 @@ def test_categorical_agg(s: pl.Series, min: str | None, max: str | None) -> None
def test_add_string() -> None:
s = pl.Series(["hello", "weird"])
result = s + " world"
print(result)
assert_series_equal(result, pl.Series(["hello world", "weird world"]))

result = "pfx:" + s
assert_series_equal(result, pl.Series(["pfx:hello", "pfx:weird"]))
assert_series_equal(result, pl.Series("literal", ["pfx:hello", "pfx:weird"]))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 2f2bb92

Please sign in to comment.