From 9ea5839c52bf0606aaa0b174d9a974992e0ea328 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 24 Dec 2024 13:47:09 +0100 Subject: [PATCH] refactor: Add a `TypeCheckRule` to the optimizer (#20425) --- crates/polars-lazy/src/frame/mod.rs | 6 +++ crates/polars-plan/src/dsl/expr_dyn_fn.rs | 6 +++ .../src/dsl/functions/horizontal.rs | 17 +++---- crates/polars-plan/src/frame/opt_state.rs | 4 +- .../src/plans/conversion/dsl_to_ir.rs | 3 +- .../polars-plan/src/plans/conversion/join.rs | 3 ++ .../polars-plan/src/plans/conversion/mod.rs | 28 +++++++++++ .../src/plans/conversion/stack_opt.rs | 34 +++++++++---- .../src/plans/conversion/type_check/mod.rs | 48 +++++++++++++++++++ .../src/plans/conversion/type_coercion/mod.rs | 32 ++----------- crates/polars-plan/src/plans/expr_ir.rs | 4 ++ .../plans/optimizer/collapse_and_project.rs | 37 +++++++------- .../src/plans/optimizer/count_star.rs | 46 +++++++++--------- .../src/plans/optimizer/delay_rechunk.rs | 8 ++-- .../src/plans/optimizer/flatten_union.rs | 9 ++-- crates/polars-plan/src/plans/optimizer/mod.rs | 2 +- .../src/plans/optimizer/stack_opt.rs | 6 +-- crates/polars-python/src/lazyframe/general.rs | 2 + py-polars/polars/functions/lazy.py | 8 ++++ py-polars/polars/lazyframe/frame.py | 34 +++++++++++++ py-polars/tests/unit/test_errors.py | 2 +- 21 files changed, 240 insertions(+), 99 deletions(-) create mode 100644 crates/polars-plan/src/plans/conversion/type_check/mod.rs diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index acfa4b77d8f2..535a557b0c02 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -172,6 +172,12 @@ impl LazyFrame { self } + /// Toggle type check optimization. + pub fn with_type_check(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::TYPE_CHECK, toggle); + self + } + /// Toggle expression simplification optimization on or off. pub fn with_simplify_expr(mut self, toggle: bool) -> Self { self.opt_state.set(OptFlags::SIMPLIFY_EXPR, toggle); diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 483dafcc83f1..039f00b4f2ba 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -295,6 +295,12 @@ impl GetOutput { Default::default() } + pub fn first() -> Self { + SpecialEq::new(Arc::new( + |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()), + )) + } + pub fn from_type(dt: DataType) -> Self { SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { Ok(Field::new(flds[0].name().clone(), dt.clone())) diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 8d01d3696086..841c5279f28d 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -25,15 +25,15 @@ where F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, E: AsRef<[Expr]>, { - let mut exprs = exprs.as_ref().to_vec(); - exprs.push(acc); + let mut exprs_v = Vec::with_capacity(exprs.as_ref().len() + 1); + exprs_v.push(acc); + exprs_v.extend(exprs.as_ref().iter().cloned()); + let exprs = exprs_v; let function = new_column_udf(move |columns: &mut [Column]| { - let mut columns = columns.to_vec(); - let mut acc = columns.pop().unwrap(); - - for c in columns { - if let Some(a) = f(acc.clone(), c)? { + let mut acc = columns.first().unwrap().clone(); + for c in &columns[1..] { + if let Some(a) = f(acc.clone(), c.clone())? { acc = a } } @@ -43,7 +43,8 @@ where Expr::AnonymousFunction { input: exprs, function, - output_type: GetOutput::super_type(), + // Take the type of the accumulator. + output_type: GetOutput::first(), options: FunctionOptions { collect_groups: ApplyOptions::GroupWise, flags: FunctionFlags::default() diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs index 04586e774a04..23e4d28e44b0 100644 --- a/crates/polars-plan/src/frame/opt_state.rs +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -38,12 +38,14 @@ bitflags! { /// Check if operations are order dependent and unset maintaining_order if /// the order would not be observed. const CHECK_ORDER_OBSERVE = 1 << 16; + /// Do type checking of the IR. + const TYPE_CHECK = 1 << 17; } } impl OptFlags { pub fn schema_only() -> Self { - Self::TYPE_COERCION + Self::TYPE_COERCION | Self::TYPE_CHECK } } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index ab9d2f71f78a..90fa93c5a120 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -63,12 +63,13 @@ pub fn to_alp( lp: DslPlan, expr_arena: &mut Arena, lp_arena: &mut Arena, - // Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected. + // Only `SIMPLIFY_EXPR`, `TYPE_COERCION`, `TYPE_CHECK` are respected. opt_flags: &mut OptFlags, ) -> PolarsResult { let conversion_optimizer = ConversionOptimizer::new( opt_flags.contains(OptFlags::SIMPLIFY_EXPR), opt_flags.contains(OptFlags::TYPE_COERCION), + opt_flags.contains(OptFlags::TYPE_CHECK), ); let mut ctxt = DslConversionContext { diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 9aea4cd686fb..6eff0f951338 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -120,6 +120,9 @@ pub fn resolve_join( .coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right) .map_err(|e| e.context("'join' failed".into()))?; + let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); + let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); + // Not a closure to avoid borrow issues because we mutate expr_arena as well. macro_rules! get_dtype { ($expr:expr, $schema:expr) => { diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index 15cc416d49f4..1d3363d20842 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -12,16 +12,20 @@ mod ir_to_dsl; mod scans; mod stack_opt; +use std::borrow::Cow; use std::sync::{Arc, Mutex}; pub use dsl_to_ir::*; pub use expr_to_ir::*; pub use ir_to_dsl::*; use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; use polars_utils::vec::ConvertVec; use recursive::recursive; mod functions; mod join; +pub(crate) mod type_check; pub(crate) mod type_coercion; pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection}; @@ -266,3 +270,27 @@ impl IR { } } } + +fn get_input(lp_arena: &Arena, lp_node: Node) -> UnitVec { + let plan = lp_arena.get(lp_node); + let mut inputs: UnitVec = unitvec!(); + + // Used to get the schema of the input. + if is_scan(plan) { + inputs.push(lp_node); + } else { + plan.copy_inputs(&mut inputs); + }; + inputs +} + +fn get_schema(lp_arena: &Arena, lp_node: Node) -> Cow<'_, SchemaRef> { + let inputs = get_input(lp_arena, lp_node); + if inputs.is_empty() { + // Files don't have an input, so we must take their schema. + Cow::Borrowed(lp_arena.get(lp_node).scan_schema()) + } else { + let input = inputs[0]; + lp_arena.get(input).schema(lp_arena) + } +} diff --git a/crates/polars-plan/src/plans/conversion/stack_opt.rs b/crates/polars-plan/src/plans/conversion/stack_opt.rs index 8db4e82659d5..3401a892ced3 100644 --- a/crates/polars-plan/src/plans/conversion/stack_opt.rs +++ b/crates/polars-plan/src/plans/conversion/stack_opt.rs @@ -1,12 +1,15 @@ use std::borrow::Borrow; +use self::type_check::TypeCheckRule; use super::*; /// Applies expression simplification and type coercion during conversion to IR. pub(super) struct ConversionOptimizer { scratch: Vec, + simplify: Option, coerce: Option, + check: Option, // IR's can be cached in the DSL. // But if they are used multiple times in DSL (e.g. concat/join) // then it can occur that we take a slot multiple times. @@ -16,7 +19,7 @@ pub(super) struct ConversionOptimizer { } impl ConversionOptimizer { - pub(super) fn new(simplify: bool, type_coercion: bool) -> Self { + pub(super) fn new(simplify: bool, type_coercion: bool, type_check: bool) -> Self { let simplify = if simplify { Some(SimplifyExprRule {}) } else { @@ -29,10 +32,17 @@ impl ConversionOptimizer { None }; + let check = if type_check { + Some(TypeCheckRule) + } else { + None + }; + ConversionOptimizer { scratch: Vec::with_capacity(8), simplify, coerce, + check, used_arenas: Default::default(), } } @@ -54,29 +64,35 @@ impl ConversionOptimizer { pub(super) fn coerce_types( &mut self, expr_arena: &mut Arena, - lp_arena: &Arena, + ir_arena: &mut Arena, current_node: Node, ) -> PolarsResult<()> { // Different from the stack-opt in the optimizer phase, this does a single pass until fixed point per expression. + if let Some(rule) = &mut self.check { + while let Some(x) = rule.optimize_plan(ir_arena, expr_arena, current_node)? { + ir_arena.replace(current_node, x); + } + } + // process the expressions on the stack and apply optimizations. while let Some(current_expr_node) = self.scratch.pop() { - { - let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; - if expr.is_leaf() { - continue; - } + let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; + + if expr.is_leaf() { + continue; } + if let Some(rule) = &mut self.simplify { while let Some(x) = - rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)? + rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)? { expr_arena.replace(current_expr_node, x); } } if let Some(rule) = &mut self.coerce { while let Some(x) = - rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)? + rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)? { expr_arena.replace(current_expr_node, x); } diff --git a/crates/polars-plan/src/plans/conversion/type_check/mod.rs b/crates/polars-plan/src/plans/conversion/type_check/mod.rs new file mode 100644 index 000000000000..cf5c353688cb --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_check/mod.rs @@ -0,0 +1,48 @@ +use polars_core::error::{polars_ensure, PolarsResult}; +use polars_core::prelude::DataType; +use polars_utils::arena::{Arena, Node}; + +use super::{AExpr, OptimizationRule, IR}; +use crate::plans::conversion::get_schema; +use crate::plans::Context; + +pub struct TypeCheckRule; + +impl OptimizationRule for TypeCheckRule { + fn optimize_plan( + &mut self, + ir_arena: &mut Arena, + expr_arena: &mut Arena, + node: Node, + ) -> PolarsResult> { + let ir = ir_arena.get(node); + match ir { + IR::Scan { + predicate: Some(predicate), + .. + } => { + let input_schema = get_schema(ir_arena, node); + let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?; + + polars_ensure!( + matches!(dtype, DataType::Boolean | DataType::Unknown(_)), + InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`" + ); + + Ok(None) + }, + IR::Filter { predicate, .. } => { + let input_schema = get_schema(ir_arena, node); + let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?; + + polars_ensure!( + matches!(dtype, DataType::Boolean | DataType::Unknown(_)), + InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`" + ); + + Ok(None) + }, + _ => Ok(None), + } + } +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index c37f792ce797..366b297ec6bb 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -3,16 +3,13 @@ mod functions; #[cfg(feature = "is_in")] mod is_in; -use std::borrow::Cow; - use binary::process_binary; use polars_compute::cast::temporal::{time_unit_multiple, SECONDS_IN_DAY}; use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int}; -use polars_utils::idx_vec::UnitVec; +use polars_utils::format_list; use polars_utils::itertools::Itertools; -use polars_utils::{format_list, unitvec}; use super::*; @@ -69,30 +66,6 @@ fn modify_supertype( st } -fn get_input(lp_arena: &Arena, lp_node: Node) -> UnitVec { - let plan = lp_arena.get(lp_node); - let mut inputs: UnitVec = unitvec!(); - - // Used to get the schema of the input. - if is_scan(plan) { - inputs.push(lp_node); - } else { - plan.copy_inputs(&mut inputs); - }; - inputs -} - -fn get_schema(lp_arena: &Arena, lp_node: Node) -> Cow<'_, SchemaRef> { - let inputs = get_input(lp_arena, lp_node); - if inputs.is_empty() { - // Files don't have an input, so we must take their schema. - Cow::Borrowed(lp_arena.get(lp_node).scan_schema()) - } else { - let input = inputs[0]; - lp_arena.get(input).schema(lp_arena) - } -} - fn get_aexpr_and_type<'a>( expr_arena: &'a Arena, e: Node, @@ -515,6 +488,7 @@ fn cast_expr_ir( if let AExpr::Literal(lv) = expr_arena.get(e.node()) { if let Some(literal) = try_inline_literal_cast(lv, to_dtype, strict)? { e.set_node(expr_arena.add(AExpr::Literal(literal))); + e.set_dtype(to_dtype.clone()); return Ok(()); } } @@ -524,6 +498,8 @@ fn cast_expr_ir( dtype: to_dtype.clone(), options: CastOptions::Strict, })); + e.set_dtype(to_dtype.clone()); + Ok(()) } diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index c61aa21f03e6..01bb76a2de5c 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -106,6 +106,10 @@ impl ExprIR { self } + pub(crate) fn set_dtype(&mut self, dtype: DataType) { + self.output_dtype = OnceLock::from(dtype); + } + pub fn from_node(node: Node, arena: &Arena) -> Self { let mut out = Self { node, diff --git a/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs b/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs index 4bd0079a4827..912c33e9c77b 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs @@ -33,7 +33,7 @@ impl OptimizationRule for SimpleProjectionAndCollapse { lp_arena: &mut Arena, expr_arena: &mut Arena, node: Node, - ) -> Option { + ) -> PolarsResult> { use IR::*; let lp = lp_arena.get(node); @@ -47,22 +47,25 @@ impl OptimizationRule for SimpleProjectionAndCollapse { matches!(expr_arena.get(e.node()), AExpr::Column(_)) && !e.has_alias() }) { self.processed.insert(node); - return None; + return Ok(None); } let exprs = expr .iter() .map(|e| e.output_name().clone()) .collect::>(); - let alp = IRBuilder::new(*input, expr_arena, lp_arena) + let Some(alp) = IRBuilder::new(*input, expr_arena, lp_arena) .project_simple(exprs.iter().cloned()) - .ok()? - .build(); + .ok() + else { + return Ok(None); + }; + let alp = alp.build(); - Some(alp) + Ok(Some(alp)) } else { self.processed.insert(node); - None + Ok(None) } }, SimpleProjection { columns, input } if !self.eager => { @@ -70,10 +73,10 @@ impl OptimizationRule for SimpleProjectionAndCollapse { // If there are 2 subsequent fast projections, flatten them and only take the last SimpleProjection { input: prev_input, .. - } => Some(SimpleProjection { + } => Ok(Some(SimpleProjection { input: *prev_input, columns: columns.clone(), - }), + })), // Cleanup projections set in projection pushdown just above caches // they are not needed. cache_lp @ Cache { .. } if self.processed.contains(&node) => { @@ -83,9 +86,9 @@ impl OptimizationRule for SimpleProjectionAndCollapse { |(left_name, right_name)| left_name.as_str() == right_name.as_str(), ) { - Some(cache_lp.clone()) + Ok(Some(cache_lp.clone())) } else { - None + Ok(None) } }, // If a projection does nothing, remove it. @@ -93,10 +96,10 @@ impl OptimizationRule for SimpleProjectionAndCollapse { let input_schema = other.schema(lp_arena); // This will fail fast if lengths are not equal if *input_schema.as_ref() == *columns { - Some(other.clone()) + Ok(Some(other.clone())) } else { self.processed.insert(node); - None + Ok(None) } }, } @@ -113,17 +116,17 @@ impl OptimizationRule for SimpleProjectionAndCollapse { cache_hits, } = lp_arena.get(*input) { - Some(Cache { + Ok(Some(Cache { input: *prev_input, id: *id, // ensure the counts are updated cache_hits: cache_hits.saturating_add(*outer_cache_hits), - }) + })) } else { - None + Ok(None) } }, - _ => None, + _ => Ok(None), } } } diff --git a/crates/polars-plan/src/plans/optimizer/count_star.rs b/crates/polars-plan/src/plans/optimizer/count_star.rs index e0643028e0fe..d465112b30ac 100644 --- a/crates/polars-plan/src/plans/optimizer/count_star.rs +++ b/crates/polars-plan/src/plans/optimizer/count_star.rs @@ -19,30 +19,32 @@ impl OptimizationRule for CountStar { lp_arena: &mut Arena, expr_arena: &mut Arena, node: Node, - ) -> Option { - visit_logical_plan_for_scan_paths(node, lp_arena, expr_arena, false).map( - |count_star_expr| { - // MapFunction needs a leaf node, hence we create a dummy placeholder node - let placeholder = IR::DataFrameScan { - df: Arc::new(Default::default()), - schema: Arc::new(Default::default()), - output_schema: None, - filter: None, - }; - let placeholder_node = lp_arena.add(placeholder); + ) -> PolarsResult> { + Ok( + visit_logical_plan_for_scan_paths(node, lp_arena, expr_arena, false).map( + |count_star_expr| { + // MapFunction needs a leaf node, hence we create a dummy placeholder node + let placeholder = IR::DataFrameScan { + df: Arc::new(Default::default()), + schema: Arc::new(Default::default()), + output_schema: None, + filter: None, + }; + let placeholder_node = lp_arena.add(placeholder); - let alp = IR::MapFunction { - input: placeholder_node, - function: FunctionIR::FastCount { - sources: count_star_expr.sources, - scan_type: count_star_expr.scan_type, - alias: count_star_expr.alias, - }, - }; + let alp = IR::MapFunction { + input: placeholder_node, + function: FunctionIR::FastCount { + sources: count_star_expr.sources, + scan_type: count_star_expr.scan_type, + alias: count_star_expr.alias, + }, + }; - lp_arena.replace(count_star_expr.node, alp.clone()); - alp - }, + lp_arena.replace(count_star_expr.node, alp.clone()); + alp + }, + ), ) } } diff --git a/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs b/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs index 6d598f57f799..319e47fb79c2 100644 --- a/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs +++ b/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs @@ -19,14 +19,14 @@ impl OptimizationRule for DelayRechunk { lp_arena: &mut Arena, _expr_arena: &mut Arena, node: Node, - ) -> Option { + ) -> PolarsResult> { match lp_arena.get(node) { // An aggregation can be partitioned, its wasteful to rechunk before that partition. #[allow(unused_mut)] IR::GroupBy { input, keys, .. } => { // Multiple keys on multiple chunks is much slower, so rechunk. if !self.processed.insert(node.0) || keys.len() > 1 { - return None; + return Ok(None); }; use IR::*; @@ -62,9 +62,9 @@ impl OptimizationRule for DelayRechunk { } }; - None + Ok(None) }, - _ => None, + _ => Ok(None), } } } diff --git a/crates/polars-plan/src/plans/optimizer/flatten_union.rs b/crates/polars-plan/src/plans/optimizer/flatten_union.rs index 42ab5fd1525f..8f8de84217da 100644 --- a/crates/polars-plan/src/plans/optimizer/flatten_union.rs +++ b/crates/polars-plan/src/plans/optimizer/flatten_union.rs @@ -1,3 +1,4 @@ +use polars_core::error::PolarsResult; use polars_utils::arena::{Arena, Node}; use IR::*; @@ -19,7 +20,7 @@ impl OptimizationRule for FlattenUnionRule { lp_arena: &mut polars_utils::arena::Arena, _expr_arena: &mut polars_utils::arena::Arena, node: polars_utils::arena::Node, - ) -> Option { + ) -> PolarsResult> { let lp = lp_arena.get(node); match lp { @@ -41,12 +42,12 @@ impl OptimizationRule for FlattenUnionRule { } options.flattened_by_opt = true; - Some(Union { + Ok(Some(Union { inputs: new_inputs, options, - }) + })) }, - _ => None, + _ => Ok(None), } } } diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs index c712badfab45..7661ec7c5368 100644 --- a/crates/polars-plan/src/plans/optimizer/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -161,7 +161,7 @@ pub fn optimize( if projection_pushdown_opt.is_count_star { let mut count_star_opt = CountStar::new(); - count_star_opt.optimize_plan(lp_arena, expr_arena, lp_top); + count_star_opt.optimize_plan(lp_arena, expr_arena, lp_top)?; } } diff --git a/crates/polars-plan/src/plans/optimizer/stack_opt.rs b/crates/polars-plan/src/plans/optimizer/stack_opt.rs index 8e4f807c139a..5468960661cf 100644 --- a/crates/polars-plan/src/plans/optimizer/stack_opt.rs +++ b/crates/polars-plan/src/plans/optimizer/stack_opt.rs @@ -31,7 +31,7 @@ impl StackOptimizer { // Apply rules for rule in rules.iter_mut() { // keep iterating over same rule - while let Some(x) = rule.optimize_plan(lp_arena, expr_arena, current_node) { + while let Some(x) = rule.optimize_plan(lp_arena, expr_arena, current_node)? { lp_arena.replace(current_node, x); changed = true; } @@ -93,8 +93,8 @@ pub trait OptimizationRule { _lp_arena: &mut Arena, _expr_arena: &mut Arena, _node: Node, - ) -> Option { - None + ) -> PolarsResult> { + Ok(None) } fn optimize_expr( &mut self, diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 5e0ce0e2e0a4..ea6c39f8a2b9 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -484,6 +484,7 @@ impl PyLazyFrame { fn optimization_toggle( &self, type_coercion: bool, + type_check: bool, predicate_pushdown: bool, projection_pushdown: bool, simplify_expression: bool, @@ -500,6 +501,7 @@ impl PyLazyFrame { let ldf = self.ldf.clone(); let mut ldf = ldf .with_type_coercion(type_coercion) + .with_type_check(type_check) .with_predicate_pushdown(predicate_pushdown) .with_simplify_expr(simplify_expression) .with_slice_pushdown(slice_pushdown) diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 7fadb8517aaa..e2705a3d54de 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1619,6 +1619,7 @@ def collect_all( lazy_frames: Iterable[LazyFrame], *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1694,8 +1695,10 @@ def collect_all( prepared = [] for lf in lazy_frames: + type_check = _type_check ldf = lf._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, @@ -1725,6 +1728,7 @@ def collect_all_async( *, gevent: Literal[True], type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1744,6 +1748,7 @@ def collect_all_async( *, gevent: Literal[False] = False, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1763,6 +1768,7 @@ def collect_all_async( *, gevent: bool = False, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1861,8 +1867,10 @@ def collect_all_async( prepared = [] for lf in lazy_frames: + type_check = _type_check ldf = lf._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index fa9e10165f35..887a7a63faa8 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1013,6 +1013,7 @@ def explain( format: ExplainFormat = "plain", optimized: bool = True, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1096,8 +1097,10 @@ def explain( issue_unstable_warning("Streaming mode is considered unstable.") if optimized: + type_check = _type_check ldf = self._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, @@ -1130,6 +1133,7 @@ def show_graph( raw_output: bool = False, figsize: tuple[float, float] = (16.0, 12.0), type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1193,8 +1197,10 @@ def show_graph( ... "a" ... ).show_graph() # doctest: +SKIP """ + type_check = _type_check _ldf = self._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, @@ -1602,6 +1608,7 @@ def profile( self, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1699,8 +1706,10 @@ def profile( cluster_with_columns = False collapse_joins = False + type_check = _type_check ldf = self._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, @@ -1761,6 +1770,7 @@ def collect( self, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1782,6 +1792,7 @@ def collect( self, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -1802,6 +1813,7 @@ def collect( self, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2003,8 +2015,10 @@ def collect( # Don't run on GPU in _eager mode (but don't warn) is_gpu = False + type_check = _type_check ldf = self._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, @@ -2048,6 +2062,7 @@ def collect_async( *, gevent: Literal[True], type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2066,6 +2081,7 @@ def collect_async( *, gevent: Literal[False] = False, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2083,6 +2099,7 @@ def collect_async( *, gevent: bool = False, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2203,8 +2220,10 @@ def collect_async( if streaming: issue_unstable_warning("Streaming mode is considered unstable.") + type_check = _type_check ldf = self._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, @@ -2269,6 +2288,7 @@ def sink_parquet( data_page_size: int | None = None, maintain_order: bool = True, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2383,6 +2403,7 @@ def sink_parquet( """ lf = self._set_sink_optimizations( type_coercion=type_coercion, + _type_check=_type_check, predicate_pushdown=predicate_pushdown, projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, @@ -2441,6 +2462,7 @@ def sink_ipc( compression: str | None = "zstd", maintain_order: bool = True, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2522,6 +2544,7 @@ def sink_ipc( """ lf = self._set_sink_optimizations( type_coercion=type_coercion, + _type_check=_type_check, predicate_pushdown=predicate_pushdown, projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, @@ -2571,6 +2594,7 @@ def sink_csv( quote_style: CsvQuoteStyle | None = None, maintain_order: bool = True, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2707,6 +2731,7 @@ def sink_csv( lf = self._set_sink_optimizations( type_coercion=type_coercion, + _type_check=_type_check, predicate_pushdown=predicate_pushdown, projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, @@ -2755,6 +2780,7 @@ def sink_ndjson( *, maintain_order: bool = True, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2833,6 +2859,7 @@ def sink_ndjson( """ lf = self._set_sink_optimizations( type_coercion=type_coercion, + _type_check=_type_check, predicate_pushdown=predicate_pushdown, projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, @@ -2865,6 +2892,7 @@ def _set_sink_optimizations( self, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2881,6 +2909,7 @@ def _set_sink_optimizations( return self._ldf.optimization_toggle( type_coercion=type_coercion, + type_check=_type_check, predicate_pushdown=predicate_pushdown, projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, @@ -2905,6 +2934,7 @@ def fetch( n_rows: int = 500, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -2939,6 +2969,7 @@ def fetch( return self._fetch( n_rows=n_rows, type_coercion=type_coercion, + _type_check=_type_check, predicate_pushdown=predicate_pushdown, projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, @@ -2956,6 +2987,7 @@ def _fetch( n_rows: int = 500, *, type_coercion: bool = True, + _type_check: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, simplify_expression: bool = True, @@ -3055,8 +3087,10 @@ def _fetch( if streaming: issue_unstable_warning("Streaming mode is considered unstable.") + type_check = _type_check lf = self._ldf.optimization_toggle( type_coercion, + type_check, predicate_pushdown, projection_pushdown, simplify_expression, diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 79ed64dc88ac..f5f7f6a3d985 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -233,7 +233,7 @@ def test_error_on_double_agg() -> None: def test_filter_not_of_type_bool() -> None: df = pl.DataFrame({"json_val": ['{"a":"hello"}', None, '{"a":"world"}']}) with pytest.raises( - ComputeError, match="filter predicate must be of type `Boolean`, got" + InvalidOperationError, match="filter predicate must be of type `Boolean`, got" ): df.filter(pl.col("json_val").str.json_path_match("$.a"))