Skip to content

Commit

Permalink
Add temporal horizontal mean
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Dec 24, 2024
1 parent cd2da5b commit c1a87b6
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 17 deletions.
90 changes: 76 additions & 14 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;

use arrow::temporal_conversions::MILLISECONDS_IN_DAY;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::series::arithmetic::coerce_lhs_rhs;
Expand Down Expand Up @@ -267,24 +268,67 @@ pub fn mean_horizontal(
) -> PolarsResult<Option<Column>> {
validate_column_lengths(columns)?;

let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
let dtype = s.dtype();
dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
});
let first_dtype = columns[0].dtype();
let is_temporal = first_dtype.is_temporal();
let columns = if is_temporal {
// All remaining must be the same temporal dtype.
for col in &columns[1..] {
if col.dtype() != first_dtype {
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
columns[0].name(),
first_dtype,
col.name(),
col.dtype(),
);
};
}

if !non_numeric_columns.is_empty() {
let col = non_numeric_columns.first().cloned();
// Convert to physical
columns
.into_iter()
.map(|c| c.cast(&DataType::Int64).unwrap())
.collect::<Vec<_>>()
} else if first_dtype.is_numeric()
|| first_dtype.is_decimal()
|| first_dtype.is_bool()
|| first_dtype.is_null()
|| first_dtype.is_temporal()
{
// All remaining must be numeric.
for col in &columns[1..] {
let dtype = col.dtype();
if !(dtype.is_numeric()
|| dtype.is_decimal()
|| dtype.is_bool()
|| dtype.is_null()
|| dtype.is_temporal())
{
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
columns[0].name(),
first_dtype,
col.name(),
dtype,
);
}
}
columns
.into_iter()
.map(|s| s.cast(&DataType::Int64).unwrap().into_column())
.collect::<Vec<_>>()
} else {
polars_bail!(
InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
col.unwrap().name(),
col.unwrap().dtype(),
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={})",
columns[0].name(),
first_dtype,
);
}
let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
};

match columns.len() {
0 => Ok(None),
1 => Ok(Some(match columns[0].dtype() {
dt if dt != &DataType::Float32 && !dt.is_decimal() => {
dt if dt != &DataType::Float32 && !dt.is_temporal() && !dt.is_decimal() => {
columns[0].cast(&DataType::Float64)?
},
_ => columns[0].clone(),
Expand Down Expand Up @@ -331,8 +375,26 @@ pub fn mean_horizontal(
.into_column()
.cast(&DataType::Float64)?;

sum.map(|sum| std::ops::Div::div(&sum, &value_length))
.transpose()
let out = sum.map(|sum| std::ops::Div::div(&sum, &value_length));

let x = out.map(|opt| {
opt.and_then(|value| {
if is_temporal {
if first_dtype == &DataType::Date {
// Cast to DateTime(us)
(value * MILLISECONDS_IN_DAY)
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
} else {
// Cast to original
value.cast(&first_dtype)
}
} else {
Ok(value)
}
})
});

x.transpose()
},
}
}
Expand Down
53 changes: 50 additions & 3 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import datetime
from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import pytest
Expand All @@ -12,7 +12,7 @@
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType
from polars._typing import PolarsDataType, TimeUnit


def test_any_expr(fruits_cars: pl.DataFrame) -> None:
Expand Down Expand Up @@ -340,7 +340,7 @@ def test_sum_dtype_12028() -> None:
[
pl.Series(
"sum_duration",
[datetime.timedelta(seconds=10)],
[timedelta(seconds=10)],
dtype=pl.Duration(time_unit="us"),
),
]
Expand Down Expand Up @@ -442,6 +442,53 @@ def test_mean_horizontal_all_null() -> None:
assert_frame_equal(result, expected)


@pytest.mark.parametrize("tz", [None, "UTC", "Asia/Kathmandu"])
@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
def test_mean_horizontal_temporal(tu: TimeUnit, tz: str) -> None:
dt1 = [date(2024, 1, 1), date(2024, 1, 3)]
dt2 = [date(2024, 1, 2), date(2024, 1, 4)]
dur1 = [timedelta(hours=1), timedelta(hours=3)]
dur2 = [timedelta(hours=2), timedelta(hours=4)]
lf = pl.LazyFrame(
{
"date1": pl.Series(dt1, dtype=pl.Date),
"date2": pl.Series(dt2, dtype=pl.Date),
"datetime1": pl.Series(dt1, dtype=pl.Datetime(time_unit=tu, time_zone=tz)),
"datetime2": pl.Series(dt2, dtype=pl.Datetime(time_unit=tu, time_zone=tz)),
"time1": [time(1), time(3)],
"time2": [time(2), time(4)],
"duration1": pl.Series(dur1, dtype=pl.Duration(time_unit=tu)),
"duration2": pl.Series(dur2, dtype=pl.Duration(time_unit=tu)),
}
)
out = lf.select(
pl.mean_horizontal("date1", "date2").alias("date"),
pl.mean_horizontal("datetime1", "datetime2").alias("datetime"),
pl.mean_horizontal("time1", "time2").alias("time"),
pl.mean_horizontal("duration1", "duration2").alias("duration"),
).collect()

expected = pl.DataFrame(
{
"date": pl.Series(
[datetime(2024, 1, 1, 12), datetime(2024, 1, 3, 12)],
dtype=pl.Datetime("ms"),
),
"datetime": pl.Series(
[datetime(2024, 1, 1, 12), datetime(2024, 1, 3, 12)],
dtype=pl.Datetime(tu, tz),
),
"time": [time(hour=1, minute=30), time(hour=3, minute=30)],
"duration": pl.Series(
[timedelta(hours=1, minutes=30), timedelta(hours=3, minutes=30)],
dtype=pl.Duration(time_unit=tu),
),
}
)

assert_frame_equal(out, expected)


@pytest.mark.parametrize(
("in_dtype", "out_dtype"),
[
Expand Down

0 comments on commit c1a87b6

Please sign in to comment.