Skip to content

Commit

Permalink
fix(python): Properly raise UDF errors (#20417)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 23, 2024
1 parent f8671e8 commit 8df0cbe
Show file tree
Hide file tree
Showing 6 changed files with 644 additions and 535 deletions.
20 changes: 10 additions & 10 deletions crates/polars-python/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,16 +555,16 @@ impl PyDataFrame {
use apply_lambda_with_primitive_out_type as apply;
#[rustfmt::skip]
let out = match output_type.map(|dt| dt.0) {
Some(DataType::Int32) => apply::<Int32Type>(df, py, lambda, 0, None).into_series(),
Some(DataType::Int64) => apply::<Int64Type>(df, py, lambda, 0, None).into_series(),
Some(DataType::UInt32) => apply::<UInt32Type>(df, py, lambda, 0, None).into_series(),
Some(DataType::UInt64) => apply::<UInt64Type>(df, py, lambda, 0, None).into_series(),
Some(DataType::Float32) => apply::<Float32Type>(df, py, lambda, 0, None).into_series(),
Some(DataType::Float64) => apply::<Float64Type>(df, py, lambda, 0, None).into_series(),
Some(DataType::Date) => apply::<Int32Type>(df, py, lambda, 0, None).into_date().into_series(),
Some(DataType::Datetime(tu, tz)) => apply::<Int64Type>(df, py, lambda, 0, None).into_datetime(tu, tz).into_series(),
Some(DataType::Boolean) => apply_lambda_with_bool_out_type(df, py, lambda, 0, None).into_series(),
Some(DataType::String) => apply_lambda_with_string_out_type(df, py, lambda, 0, None).into_series(),
Some(DataType::Int32) => apply::<Int32Type>(df, py, lambda, 0, None)?.into_series(),
Some(DataType::Int64) => apply::<Int64Type>(df, py, lambda, 0, None)?.into_series(),
Some(DataType::UInt32) => apply::<UInt32Type>(df, py, lambda, 0, None)?.into_series(),
Some(DataType::UInt64) => apply::<UInt64Type>(df, py, lambda, 0, None)?.into_series(),
Some(DataType::Float32) => apply::<Float32Type>(df, py, lambda, 0, None)?.into_series(),
Some(DataType::Float64) => apply::<Float64Type>(df, py, lambda, 0, None)?.into_series(),
Some(DataType::Date) => apply::<Int32Type>(df, py, lambda, 0, None)?.into_date().into_series(),
Some(DataType::Datetime(tu, tz)) => apply::<Int64Type>(df, py, lambda, 0, None)?.into_datetime(tu, tz).into_series(),
Some(DataType::Boolean) => apply_lambda_with_bool_out_type(df, py, lambda, 0, None)?.into_series(),
Some(DataType::String) => apply_lambda_with_string_out_type(df, py, lambda, 0, None)?.into_series(),
_ => return apply_lambda_unknown(df, py, lambda, inference_size),
};

Expand Down
59 changes: 33 additions & 26 deletions crates/polars-python/src/map/dataframe.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars::prelude::*;
use polars_core::frame::row::{rows_to_schema_first_non_null, Row};
use polars_core::series::SeriesIter;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
Expand Down Expand Up @@ -47,7 +48,7 @@ pub fn apply_lambda_unknown<'a>(
let first_value = out.extract::<bool>().ok();
return Ok((
PySeries::new(
apply_lambda_with_bool_out_type(df, py, lambda, null_count, first_value)
apply_lambda_with_bool_out_type(df, py, lambda, null_count, first_value)?
.into_series(),
)
.into_py_any(py)?,
Expand All @@ -64,7 +65,7 @@ pub fn apply_lambda_unknown<'a>(
lambda,
null_count,
first_value,
)
)?
.into_series(),
)
.into_py_any(py)?,
Expand All @@ -80,7 +81,7 @@ pub fn apply_lambda_unknown<'a>(
lambda,
null_count,
first_value,
)
)?
.into_series(),
)
.into_py_any(py)?,
Expand All @@ -90,7 +91,7 @@ pub fn apply_lambda_unknown<'a>(
let first_value = out.extract::<PyBackedStr>().ok();
return Ok((
PySeries::new(
apply_lambda_with_string_out_type(df, py, lambda, null_count, first_value)
apply_lambda_with_string_out_type(df, py, lambda, null_count, first_value)?
.into_series(),
)
.into_py_any(py)?,
Expand Down Expand Up @@ -145,18 +146,15 @@ fn apply_iter<'a, T>(
lambda: Bound<'a, PyAny>,
init_null_count: usize,
skip: usize,
) -> impl Iterator<Item = Option<T>> + 'a
) -> impl Iterator<Item = PyResult<Option<T>>> + 'a
where
T: FromPyObject<'a>,
{
let mut iters = get_iters_skip(df, init_null_count + skip);
((init_null_count + skip)..df.height()).map(move |_| {
let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
let tpl = (PyTuple::new(py, iter).unwrap(),);
match lambda.call1(tpl) {
Ok(val) => val.extract::<T>().ok(),
Err(e) => panic!("python function failed {e}"),
}
lambda.call1(tpl).map(|v| v.extract().ok())
})
}

Expand All @@ -167,14 +165,17 @@ pub fn apply_lambda_with_primitive_out_type<'a, D>(
lambda: Bound<'a, PyAny>,
init_null_count: usize,
first_value: Option<D::Native>,
) -> ChunkedArray<D>
) -> PyResult<ChunkedArray<D>>
where
D: PyArrowPrimitiveType,
D::Native: IntoPyObject<'a> + FromPyObject<'a>,
{
let skip = usize::from(first_value.is_some());
if init_null_count == df.height() {
ChunkedArray::full_null(PlSmallStr::from_static("map"), df.height())
Ok(ChunkedArray::full_null(
PlSmallStr::from_static("map"),
df.height(),
))
} else {
let iter = apply_iter(df, py, lambda, init_null_count, skip);
iterator_to_primitive(
Expand All @@ -194,10 +195,13 @@ pub fn apply_lambda_with_bool_out_type<'a>(
lambda: Bound<'a, PyAny>,
init_null_count: usize,
first_value: Option<bool>,
) -> ChunkedArray<BooleanType> {
) -> PyResult<ChunkedArray<BooleanType>> {
let skip = usize::from(first_value.is_some());
if init_null_count == df.height() {
ChunkedArray::full_null(PlSmallStr::from_static("map"), df.height())
Ok(ChunkedArray::full_null(
PlSmallStr::from_static("map"),
df.height(),
))
} else {
let iter = apply_iter(df, py, lambda, init_null_count, skip);
iterator_to_bool(
Expand All @@ -217,10 +221,13 @@ pub fn apply_lambda_with_string_out_type<'a>(
lambda: Bound<'a, PyAny>,
init_null_count: usize,
first_value: Option<PyBackedStr>,
) -> StringChunked {
) -> PyResult<StringChunked> {
let skip = usize::from(first_value.is_some());
if init_null_count == df.height() {
ChunkedArray::full_null(PlSmallStr::from_static("map"), df.height())
Ok(ChunkedArray::full_null(
PlSmallStr::from_static("map"),
df.height(),
))
} else {
let iter = apply_iter::<PyBackedStr>(df, py, lambda, init_null_count, skip);
iterator_to_string(
Expand Down Expand Up @@ -253,18 +260,18 @@ pub fn apply_lambda_with_list_out_type<'a>(
let iter = ((init_null_count + skip)..df.height()).map(|_| {
let iter = iters.iter_mut().map(|it| Wrap(it.next().unwrap()));
let tpl = (PyTuple::new(py, iter).unwrap(),);
match lambda.call1(tpl) {
Ok(val) => match val.getattr("_s") {
Ok(val) => val.extract::<PySeries>().ok().map(|ps| ps.series),
Err(_) => {
if val.is_none() {
None
} else {
panic!("should return a Series, got a {val:?}")
}
},
let val = lambda.call1(tpl)?;
match val.getattr("_s") {
Ok(val) => val.extract::<PySeries>().map(|s| Some(s.series)),
Err(_) => {
if val.is_none() {
Ok(None)
} else {
Err(PyValueError::new_err(
"should return a Series, got a {val:?}",
))
}
},
Err(e) => panic!("python function failed {e}"),
}
});
iterator_to_list(
Expand Down
101 changes: 69 additions & 32 deletions crates/polars-python/src/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl PyArrowPrimitiveType for Float64Type {}

fn iterator_to_struct<'a>(
py: Python,
it: impl Iterator<Item = Option<Bound<'a, PyAny>>>,
it: impl Iterator<Item = PyResult<Option<Bound<'a, PyAny>>>>,
init_null_count: usize,
first_value: AnyValue<'a>,
name: PlSmallStr,
Expand Down Expand Up @@ -72,7 +72,7 @@ fn iterator_to_struct<'a>(
}

for dict in it {
match dict {
match dict? {
None => {
for field_items in struct_fields.values_mut() {
field_items.push(AnyValue::Null);
Expand Down Expand Up @@ -134,127 +134,164 @@ fn iterator_to_struct<'a>(
}

/// Collects per-row UDF results into a `ChunkedArray<T>` named `name`.
///
/// The values are `init_null_count` leading nulls and `first_value` (emitted
/// when there are leading nulls or when it is `Some`), followed by the items
/// of `it`. Items are passed through `catch_err`; if an error was recorded it
/// is returned after collection instead of the array.
///
/// `capacity` is the length handed to `trust_my_length` in the first two
/// branches and is checked against the result in debug builds.
fn iterator_to_primitive<T>(
it: impl Iterator<Item = Option<T::Native>>,
it: impl Iterator<Item = PyResult<Option<T::Native>>>,
init_null_count: usize,
first_value: Option<T::Native>,
name: PlSmallStr,
capacity: usize,
) -> ChunkedArray<T>
) -> PyResult<ChunkedArray<T>>
where
T: PyArrowPrimitiveType,
{
let mut error = None;
// SAFETY: we know the iterators len.
let ca: ChunkedArray<T> = unsafe {
if init_null_count > 0 {
(0..init_null_count)
.map(|_| None)
.chain(std::iter::once(first_value))
.map(|_| Ok(None))
.chain(std::iter::once(Ok(first_value)))
.chain(it)
.trust_my_length(capacity)
.map(|v| catch_err(&mut error, v))
.collect_trusted()
} else if first_value.is_some() {
std::iter::once(first_value)
std::iter::once(Ok(first_value))
.chain(it)
.trust_my_length(capacity)
.map(|v| catch_err(&mut error, v))
.collect_trusted()
} else {
it.collect()
it.map(|v| catch_err(&mut error, v)).collect()
}
};
debug_assert_eq!(ca.len(), capacity);
ca.with_name(name)

// Surface the first error recorded by `catch_err`, if any.
if let Some(err) = error {
let _ = err?;
}
Ok(ca.with_name(name))
}

/// Collects per-row UDF results into a boolean `ChunkedArray` named `name`.
///
/// The values are `init_null_count` leading nulls and `first_value` (emitted
/// when there are leading nulls or when it is `Some`), followed by the items
/// of `it`. Items are passed through `catch_err`; if an error was recorded it
/// is returned after collection instead of the array.
fn iterator_to_bool(
it: impl Iterator<Item = Option<bool>>,
it: impl Iterator<Item = PyResult<Option<bool>>>,
init_null_count: usize,
first_value: Option<bool>,
name: PlSmallStr,
capacity: usize,
) -> ChunkedArray<BooleanType> {
) -> PyResult<ChunkedArray<BooleanType>> {
let mut error = None;
// SAFETY: we know the iterators len.
let ca: BooleanChunked = unsafe {
if init_null_count > 0 {
(0..init_null_count)
.map(|_| None)
.chain(std::iter::once(first_value))
.map(|_| Ok(None))
.chain(std::iter::once(Ok(first_value)))
.chain(it)
.trust_my_length(capacity)
.map(|v| catch_err(&mut error, v))
.collect_trusted()
} else if first_value.is_some() {
std::iter::once(first_value)
std::iter::once(Ok(first_value))
.chain(it)
.trust_my_length(capacity)
.map(|v| catch_err(&mut error, v))
.collect_trusted()
} else {
it.collect()
it.map(|v| catch_err(&mut error, v)).collect()
}
};
// Surface the first error recorded by `catch_err`, if any.
if let Some(err) = error {
let _ = err?;
}
debug_assert_eq!(ca.len(), capacity);
ca.with_name(name)
Ok(ca.with_name(name))
}

#[cfg(feature = "object")]
/// Collects per-row UDF results into an `ObjectChunked<ObjectValue>` named `name`.
///
/// The values are `init_null_count` leading nulls and `first_value` (emitted
/// when there are leading nulls or when it is `Some`), followed by the items
/// of `it`. Items are passed through `catch_err`; if an error was recorded it
/// is returned after collection instead of the array.
fn iterator_to_object(
it: impl Iterator<Item = Option<ObjectValue>>,
it: impl Iterator<Item = PyResult<Option<ObjectValue>>>,
init_null_count: usize,
first_value: Option<ObjectValue>,
name: PlSmallStr,
capacity: usize,
) -> ObjectChunked<ObjectValue> {
) -> PyResult<ObjectChunked<ObjectValue>> {
let mut error = None;
// SAFETY: we know the iterators len.
let ca: ObjectChunked<ObjectValue> = unsafe {
if init_null_count > 0 {
(0..init_null_count)
.map(|_| None)
.chain(std::iter::once(first_value))
.map(|_| Ok(None))
.chain(std::iter::once(Ok(first_value)))
.chain(it)
.map(|v| catch_err(&mut error, v))
.trust_my_length(capacity)
.collect_trusted()
} else if first_value.is_some() {
std::iter::once(first_value)
std::iter::once(Ok(first_value))
.chain(it)
.map(|v| catch_err(&mut error, v))
.trust_my_length(capacity)
.collect_trusted()
} else {
it.collect()
it.map(|v| catch_err(&mut error, v)).collect()
}
};
// Surface the first error recorded by `catch_err`, if any.
if let Some(err) = error {
let _ = err?;
}
debug_assert_eq!(ca.len(), capacity);
ca.with_name(name)
Ok(ca.with_name(name))
}

/// Unwraps one UDF result for collection while remembering the first failure.
///
/// An `Ok` result passes its inner option through unchanged. An `Err` result
/// is stored in `error` unless a failure is already recorded there, and `None`
/// is returned in place of the value.
fn catch_err<K>(error: &mut Option<PyResult<Option<K>>>, result: PyResult<Option<K>>) -> Option<K> {
    let failure = match result {
        Ok(value) => return value,
        Err(failure) => failure,
    };
    // Only the first failure is kept; later ones are dropped.
    if error.is_none() {
        *error = Some(Err(failure));
    }
    None
}

fn iterator_to_string<S: AsRef<str>>(
it: impl Iterator<Item = Option<S>>,
it: impl Iterator<Item = PyResult<Option<S>>>,
init_null_count: usize,
first_value: Option<S>,
name: PlSmallStr,
capacity: usize,
) -> StringChunked {
) -> PyResult<StringChunked> {
let mut error = None;
// SAFETY: we know the iterators len.
let ca: StringChunked = unsafe {
if init_null_count > 0 {
(0..init_null_count)
.map(|_| None)
.chain(std::iter::once(first_value))
.map(|_| Ok(None))
.chain(std::iter::once(Ok(first_value)))
.trust_my_length(capacity)
.map(|v| catch_err(&mut error, v))
.collect_trusted()
} else if first_value.is_some() {
std::iter::once(first_value)
std::iter::once(Ok(first_value))
.chain(it)
.trust_my_length(capacity)
.map(|v| catch_err(&mut error, v))
.collect_trusted()
} else {
it.collect()
it.map(|v| catch_err(&mut error, v)).collect()
}
};
debug_assert_eq!(ca.len(), capacity);
ca.with_name(name)
if let Some(err) = error {
let _ = err?;
}
Ok(ca.with_name(name))
}

fn iterator_to_list(
dt: &DataType,
it: impl Iterator<Item = Option<Series>>,
it: impl Iterator<Item = PyResult<Option<Series>>>,
init_null_count: usize,
first_value: Option<&Series>,
name: PlSmallStr,
Expand All @@ -270,7 +307,7 @@ fn iterator_to_list(
.map_err(PyPolarsErr::from)?;
}
for opt_val in it {
match opt_val {
match opt_val? {
None => builder.append_null(),
Some(s) => {
if s.len() == 0 && s.dtype() != dt {
Expand Down
Loading

0 comments on commit 8df0cbe

Please sign in to comment.