Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust): Too-strict SQL UDF schema validation #20202

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 3 additions & 45 deletions crates/polars-plan/src/dsl/udf.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
use arrow::legacy::error::{polars_bail, PolarsResult};
use polars_core::prelude::Field;
use polars_core::schema::Schema;
use polars_utils::pl_str::PlSmallStr;

use super::{ColumnsUdf, Expr, GetOutput, OpaqueColumnUdf};
use crate::prelude::{new_column_udf, Context, FunctionOptions};
use crate::prelude::{new_column_udf, FunctionOptions};

/// Represents a user-defined function
#[derive(Clone)]
pub struct UserDefinedFunction {
/// name
pub name: PlSmallStr,
/// The function signature.
pub input_fields: Vec<Field>,
/// The function output type.
pub return_type: GetOutput,
/// The function implementation.
Expand All @@ -25,7 +20,6 @@ impl std::fmt::Debug for UserDefinedFunction {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("UserDefinedFunction")
.field("name", &self.name)
.field("signature", &self.input_fields)
.field("fun", &"<FUNC>")
.field("options", &self.options)
.finish()
Expand All @@ -34,53 +28,17 @@ impl std::fmt::Debug for UserDefinedFunction {

impl UserDefinedFunction {
/// Create a new UserDefinedFunction
pub fn new(
name: PlSmallStr,
input_fields: Vec<Field>,
return_type: GetOutput,
fun: impl ColumnsUdf + 'static,
) -> Self {
pub fn new(name: PlSmallStr, return_type: GetOutput, fun: impl ColumnsUdf + 'static) -> Self {
Self {
name,
input_fields,
return_type,
fun: new_column_udf(fun),
options: FunctionOptions::default(),
}
}

/// creates a logical expression with a call of the UDF
/// This utility allows using the UDF without requiring access to the registry.
/// The schema is validated and the query will fail if the schema is invalid.
pub fn call(self, args: Vec<Expr>) -> PolarsResult<Expr> {
if args.len() != self.input_fields.len() {
polars_bail!(InvalidOperation: "expected {} arguments, got {}", self.input_fields.len(), args.len())
}
let schema = Schema::from_iter(self.input_fields);

if args
.iter()
.map(|e| e.to_field(&schema, Context::Default))
.collect::<PolarsResult<Vec<_>>>()
.is_err()
{
polars_bail!(InvalidOperation: "unexpected field in UDF \nexpected: {:?}\n received {:?}", schema, args)
};

Ok(Expr::AnonymousFunction {
input: args,
function: self.fun,
output_type: self.return_type,
options: self.options,
})
}

/// creates a logical expression with a call of the UDF
/// This does not do any schema validation and is therefore faster.
///
/// Only use this if you are certain that the schema is correct.
/// If the schema is invalid, the query will fail at runtime.
pub fn call_unchecked(self, args: Vec<Expr>) -> Expr {
pub fn call(self, args: Vec<Expr>) -> Expr {
Expr::AnonymousFunction {
input: args,
function: self.fun,
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1454,11 +1454,12 @@ impl SQLFunctionVisitor<'_> {
})
.collect::<PolarsResult<Vec<_>>>()?;

self.ctx
Ok(self
.ctx
.function_registry
.get_udf(func_name)?
.ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
.call(args)
.call(args))
}

/// Window specs without partition bys are essentially cumulative functions
Expand Down
13 changes: 0 additions & 13 deletions crates/polars-sql/tests/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ impl FunctionRegistry for MyFunctionRegistry {
fn test_udfs() -> PolarsResult<()> {
let my_custom_sum = UserDefinedFunction::new(
"my_custom_sum".into(),
vec![
Field::new("a".into(), DataType::Int32),
Field::new("b".into(), DataType::Int32),
],
GetOutput::same_type(),
move |c: &mut [Column]| {
let first = c[0].as_materialized_series().clone();
Expand All @@ -61,18 +57,9 @@ fn test_udfs() -> PolarsResult<()> {
let res = ctx.execute("SELECT a, b, my_custom_sum(a, b) FROM foo");
assert!(res.is_ok());

// schema is invalid so it will fail
assert!(ctx
.execute("SELECT a, b, my_custom_sum(c) as invalid FROM foo")
.is_err());

// create a new UDF to be registered on the context
let my_custom_divide = UserDefinedFunction::new(
"my_custom_divide".into(),
vec![
Field::new("a".into(), DataType::Int32),
Field::new("b".into(), DataType::Int32),
],
GetOutput::same_type(),
move |c: &mut [Column]| {
let first = c[0].as_materialized_series().clone();
Expand Down
Loading