diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 45e5409ae9ac..feccf5679efb 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -28,11 +28,10 @@ use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; use datafusion_common::hash_utils::combine_hashes; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, -}; -use datafusion_common::{ - internal_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, Result, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, }; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ @@ -144,6 +143,23 @@ pub struct CommonSubexprEliminate { random_state: RandomState, } +/// The result of potentially rewriting a list of expressions to eliminate common +/// subexpressions. +#[derive(Debug)] +enum FoundCommonExprs { + /// No common expressions were found + No { original_exprs_list: Vec> }, + /// Common expressions were found + Yes { + /// extracted common expressions + common_exprs: Vec<(Expr, String)>, + /// new expressions with common subexpressions replaced + new_exprs_list: Vec>, + /// original expressions + original_exprs_list: Vec>, + }, +} + impl CommonSubexprEliminate { pub fn new() -> Self { Self { @@ -217,8 +233,7 @@ impl CommonSubexprEliminate { expr_stats: &ExprStats<'n>, common_exprs: &mut CommonExprs<'n>, alias_generator: &AliasGenerator, - ) -> Result>>> { - let mut transformed = false; + ) -> Result>> { exprs_list .into_iter() .zip(arrays_list.iter()) @@ -227,69 +242,65 @@ impl CommonSubexprEliminate { .into_iter() .zip(arrays.iter()) .map(|(expr, id_array)| { - let replaced = replace_common_expr( + replace_common_expr( expr, id_array, expr_stats, common_exprs, alias_generator, - )?; - // remember if this expression was actually replaced - transformed |= replaced.transformed; - Ok(replaced.data) + ) }) .collect::>>() }) .collect::>>() - .map(|rewritten_exprs_list| { - // propagate back transformed information - Transformed::new_transformed(rewritten_exprs_list, transformed) - }) } - /// Rewrites the expression in `exprs_list` with common sub-expressions - /// replaced with a new column and adds a ProjectionExec on top of `input` - /// which computes any replaced common sub-expressions. + /// Extracts common sub-expressions and rewrites `exprs_list`. /// - /// Returns a tuple of: - /// 1. The rewritten expressions - /// 2. A `LogicalPlan::Projection` with input of `input` that computes any - /// common sub-expressions that were used - fn rewrite_expr( + /// Returns `FoundCommonExprs` recording the result of the extraction + fn find_common_exprs( &self, exprs_list: Vec>, - arrays_list: Vec>, - input: LogicalPlan, - expr_stats: &ExprStats, config: &dyn OptimizerConfig, - ) -> Result>, LogicalPlan)>> { - let mut transformed = false; - let mut common_exprs = CommonExprs::new(); - - let rewrite_exprs = self.rewrite_exprs_list( - exprs_list, - arrays_list, - expr_stats, - &mut common_exprs, - &config.alias_generator(), - )?; - transformed |= rewrite_exprs.transformed; + expr_mask: ExprMask, + ) -> Result> { + let mut found_common = false; + let mut expr_stats = ExprStats::new(); + let id_arrays_list = exprs_list + .iter() + .map(|exprs| { + self.to_arrays(exprs, &mut expr_stats, expr_mask).map( + |(fc, id_arrays)| { + found_common |= fc; - let new_input = self.rewrite(input, config)?; - transformed |= new_input.transformed; - let mut new_input = new_input.data; + id_arrays + }, + ) + }) + .collect::>>()?; + if found_common { + let mut common_exprs = CommonExprs::new(); + let new_exprs_list = self.rewrite_exprs_list( + // Must clone as Identifiers use references to original expressions so we have + // to keep the original expressions intact. + exprs_list.clone(), + id_arrays_list, + &expr_stats, + &mut common_exprs, + &config.alias_generator(), + )?; + assert!(!common_exprs.is_empty()); - if !common_exprs.is_empty() { - assert!(transformed); - new_input = build_common_expr_project_plan(new_input, common_exprs)?; + Ok(Transformed::yes(FoundCommonExprs::Yes { + common_exprs: common_exprs.into_values().collect(), + new_exprs_list, + original_exprs_list: exprs_list, + })) + } else { + Ok(Transformed::no(FoundCommonExprs::No { + original_exprs_list: exprs_list, + })) } - - // return the transformed information - - Ok(Transformed::new_transformed( - (rewrite_exprs.data, new_input), - transformed, - )) } fn try_optimize_proj( @@ -353,96 +364,86 @@ impl CommonSubexprEliminate { window: Window, config: &dyn OptimizerConfig, ) -> Result> { - // collect all window expressions from any number of LogicalPlanWindow - let (mut window_exprs, mut window_schemas, mut plan) = + // Collects window expressions from consecutive `LogicalPlan::Window` nodes into + // a list. + let (window_expr_list, window_schemas, input) = get_consecutive_window_exprs(window); - let mut found_common = false; - let mut expr_stats = ExprStats::new(); - let arrays_per_window = window_exprs - .iter() - .map(|window_expr| { - self.to_arrays(window_expr, &mut expr_stats, ExprMask::Normal) - .map(|(fc, id_arrays)| { - found_common |= fc; - - id_arrays + // Extract common sub-expressions from the list. + self.find_common_exprs(window_expr_list, config, ExprMask::Normal)? + .map_data(|common| match common { + // If there are common sub-expressions, then the insert a projection node + // with the common expressions between the new window nodes and the + // original input. + FoundCommonExprs::Yes { + common_exprs, + new_exprs_list, + original_exprs_list, + } => { + build_common_expr_project_plan(input, common_exprs).map(|new_input| { + (new_exprs_list, new_input, Some(original_exprs_list)) }) - }) - .collect::>>()?; - - if found_common { - // save the original names - let name_preserver = NamePreserver::new(&plan); - let mut saved_names = window_exprs - .iter() - .map(|exprs| { - exprs - .iter() - .map(|expr| name_preserver.save(expr)) - .collect::>>() + } + FoundCommonExprs::No { + original_exprs_list, + } => Ok((original_exprs_list, input, None)), + })? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok((new_window_expr_list, new_input, window_expr_list)) }) - .collect::>>()?; - - assert_eq!(window_exprs.len(), arrays_per_window.len()); - let num_window_exprs = window_exprs.len(); - let rewritten_window_exprs = self.rewrite_expr( - // Must clone as Identifiers use references to original expressions so we - // have to keep the original expressions intact. - window_exprs.clone(), - arrays_per_window, - plan, - &expr_stats, - config, - )?; - let transformed = rewritten_window_exprs.transformed; - assert!(transformed); - - let (mut new_expr, new_input) = rewritten_window_exprs.data; - - let mut plan = new_input; - - // Construct consecutive window operator, with their corresponding new - // window expressions. - // - // Note this iterates over, `new_expr` and `saved_names` which are the - // same length, in reverse order - assert_eq!(num_window_exprs, new_expr.len()); - assert_eq!(num_window_exprs, saved_names.len()); - while let (Some(new_window_expr), Some(saved_names)) = - (new_expr.pop(), saved_names.pop()) - { - assert_eq!(new_window_expr.len(), saved_names.len()); - - // Rename re-written window expressions with original name, to - // preserve the output schema - let new_window_expr = new_window_expr - .into_iter() - .zip(saved_names.into_iter()) - .map(|(new_window_expr, saved_name)| { - saved_name.restore(new_window_expr) - }) - .collect::>>()?; - plan = LogicalPlan::Window(Window::try_new( - new_window_expr, - Arc::new(plan), - )?); - } - - Ok(Transformed::new_transformed(plan, transformed)) - } else { - while let (Some(window_expr), Some(schema)) = - (window_exprs.pop(), window_schemas.pop()) - { - plan = LogicalPlan::Window(Window { - input: Arc::new(plan), - window_expr, - schema, - }); - } - - Ok(Transformed::no(plan)) - } + })? + // Rebuild the consecutive window nodes. + .map_data(|(new_window_expr_list, new_input, window_expr_list)| { + // If there were common expressions extracted, then we need to make sure + // we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around extracted + // common expressions this doesn't mean that the original column names + // (schema) are preserved due to the inserted aliases are not always at + // the top of the expression. + // Let's consider improving `find_common_exprs()` to always keep column + // names and get rid of additional name preserving logic here. + if let Some(window_expr_list) = window_expr_list { + let name_preserver = NamePreserver::new_for_projection(); + let saved_names = window_expr_list + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>>() + }) + .collect::>>()?; + new_window_expr_list.into_iter().zip(saved_names).try_rfold( + new_input, + |plan, (new_window_expr, saved_names)| { + let new_window_expr = new_window_expr + .into_iter() + .zip(saved_names) + .map(|(new_window_expr, saved_name)| { + saved_name.restore(new_window_expr) + }) + .collect::>>()?; + Window::try_new(new_window_expr, Arc::new(plan)) + .map(LogicalPlan::Window) + }, + ) + } else { + new_window_expr_list + .into_iter() + .zip(window_schemas) + .try_rfold(new_input, |plan, (new_window_expr, schema)| { + Window::try_new_with_schema( + new_window_expr, + Arc::new(plan), + schema, + ) + .map(LogicalPlan::Window) + }) + } + }) } fn try_optimize_aggregate( @@ -454,136 +455,184 @@ impl CommonSubexprEliminate { group_expr, aggr_expr, input, - schema: orig_schema, + schema, .. } = aggregate; - // track transformed information - let mut transformed = false; - - let name_perserver = NamePreserver::new_for_projection(); - let saved_names = aggr_expr - .iter() - .map(|expr| name_perserver.save(expr)) - .collect::>>()?; - - let mut expr_stats = ExprStats::new(); - // rewrite inputs - let (group_found_common, group_arrays) = - self.to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?; - let (aggr_found_common, aggr_arrays) = - self.to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?; - let (new_aggr_expr, new_group_expr, new_input) = - if group_found_common || aggr_found_common { - // rewrite both group exprs and aggr_expr - let rewritten = self.rewrite_expr( - // Must clone as Identifiers use references to original expressions so - // we have to keep the original expressions intact. - vec![group_expr.clone(), aggr_expr.clone()], - vec![group_arrays, aggr_arrays], - unwrap_arc(input), - &expr_stats, - config, - )?; - assert!(rewritten.transformed); - transformed |= rewritten.transformed; - let (mut new_expr, new_input) = rewritten.data; - - // note the reversed pop order. - let new_aggr_expr = pop_expr(&mut new_expr)?; - let new_group_expr = pop_expr(&mut new_expr)?; - - (new_aggr_expr, new_group_expr, Arc::new(new_input)) - } else { - (aggr_expr, group_expr, input) - }; + let input = unwrap_arc(input); + // Extract common sub-expressions from the aggregate and grouping expressions. + self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)? + .map_data(|common| { + match common { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. + FoundCommonExprs::Yes { + common_exprs, + mut new_exprs_list, + mut original_exprs_list, + } => { + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); + + build_common_expr_project_plan(input, common_exprs).map( + |new_input| { + let aggr_expr = original_exprs_list.pop().unwrap(); + ( + new_aggr_expr, + new_group_expr, + new_input, + Some(aggr_expr), + ) + }, + ) + } - // create potential projection on top - let mut expr_stats = ExprStats::new(); - let (aggr_found_common, aggr_arrays) = self.to_arrays( - &new_aggr_expr, - &mut expr_stats, - ExprMask::NormalAndAggregates, - )?; - if aggr_found_common { - let mut common_exprs = CommonExprs::new(); - let mut rewritten_exprs = self.rewrite_exprs_list( - // Must clone as Identifiers use references to original expressions so we - // have to keep the original expressions intact. - vec![new_aggr_expr.clone()], - vec![aggr_arrays], - &expr_stats, - &mut common_exprs, - &config.alias_generator(), - )?; - assert!(rewritten_exprs.transformed); - let rewritten = pop_expr(&mut rewritten_exprs.data)?; + FoundCommonExprs::No { + mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); - assert!(!common_exprs.is_empty()); - let mut agg_exprs = common_exprs - .into_values() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); - - let new_input_schema = Arc::clone(new_input.schema()); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &new_input_schema, &mut proj_exprs)? - } - for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { - agg_exprs.push(expr.alias(&name)); - proj_exprs.push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = config.alias_generator().next(CSE_PREFIX); - let (qualifier, field) = - expr_rewritten.to_field(&new_input_schema)?; - let out_name = qualified_name(qualifier.as_ref(), field.name()); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)).alias(out_name), - ); + Ok((new_aggr_expr, new_group_expr, input, None)) } - } else { - proj_exprs.push(expr_rewritten); } - } - - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - new_input, - new_group_expr, - agg_exprs, - )?); - - Projection::try_new(proj_exprs, Arc::new(agg)) - .map(LogicalPlan::Projection) - .map(Transformed::yes) - } else { - // TODO: How exactly can the name or the schema change in this case? - // In theory `new_aggr_expr` and `new_group_expr` are either the original expressions or they were crafted via `rewrite_expr()`, that keeps the original expression names. - // If this is really needed can we have UT for it? - // Alias aggregation expressions if they have changed - let new_aggr_expr = new_aggr_expr - .into_iter() - .zip(saved_names.into_iter()) - .map(|(new_expr, saved_name)| saved_name.restore(new_expr)) - .collect::>>()?; - // Since group_expr may have changed, schema may also. Use try_new method. - let new_agg = if transformed { - Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)? - } else { - Aggregate::try_new_with_schema( - new_input, - new_group_expr, - new_aggr_expr, - orig_schema, + })? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok(( + new_aggr_expr, + new_group_expr, + aggr_expr, + Arc::new(new_input), + )) + }) + })? + // Try extracting common aggregate expressions and rebuild the aggregate node. + .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { + // Extract common aggregate sub-expressions from the aggregate expressions. + self.find_common_exprs( + vec![new_aggr_expr], + config, + ExprMask::NormalAndAggregates, )? - }; - let new_agg = LogicalPlan::Aggregate(new_agg); - - Ok(Transformed::new_transformed(new_agg, transformed)) - } + .map_data(|common| { + match common { + FoundCommonExprs::Yes { + common_exprs, + mut new_exprs_list, + mut original_exprs_list, + } => { + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); + + let mut agg_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); + + let new_input_schema = Arc::clone(new_input.schema()); + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions( + expr, + &new_input_schema, + &mut proj_exprs, + )? + } + for (expr_rewritten, expr_orig) in + rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = + expr_rewritten + { + agg_exprs.push(expr.alias(&name)); + proj_exprs + .push(Expr::Column(Column::from_name(name))); + } else { + let expr_alias = + config.alias_generator().next(CSE_PREFIX); + let (qualifier, field) = + expr_rewritten.to_field(&new_input_schema)?; + let out_name = qualified_name( + qualifier.as_ref(), + field.name(), + ); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)) + .alias(out_name), + ); + } + } else { + proj_exprs.push(expr_rewritten); + } + } + + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + new_input, + new_group_expr, + agg_exprs, + )?); + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(LogicalPlan::Projection) + } + + // If there aren't any common aggregate sub-expressions, then just + // rebuild the aggregate node. + FoundCommonExprs::No { + mut original_exprs_list, + } => { + let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); + + // If there were common expressions extracted, then we need to + // make sure we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around + // extracted common expressions this doesn't mean that the + // original column names (schema) are preserved due to the + // inserted aliases are not always at the top of the + // expression. + // Let's consider improving `find_common_exprs()` to always + // keep column names and get rid of additional name + // preserving logic here. + if let Some(aggr_expr) = aggr_expr { + let name_perserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_perserver.save(expr)) + .collect::>>()?; + let new_aggr_expr = rewritten_aggr_expr + .into_iter() + .zip(saved_names.into_iter()) + .map(|(new_expr, saved_name)| { + saved_name.restore(new_expr) + }) + .collect::>>()?; + + // Since `group_expr` may have changed, schema may also. + // Use `try_new()` method. + Aggregate::try_new( + new_input, + new_group_expr, + new_aggr_expr, + ) + .map(LogicalPlan::Aggregate) + } else { + Aggregate::try_new_with_schema( + new_input, + new_group_expr, + rewritten_aggr_expr, + schema, + ) + .map(LogicalPlan::Aggregate) + } + } + } + }) + }) } /// Rewrites the expr list and input to remove common subexpressions @@ -602,32 +651,35 @@ impl CommonSubexprEliminate { /// that computes the common subexpressions fn try_unary_plan( &self, - expr: Vec, + exprs: Vec, input: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result, LogicalPlan)>> { - let mut expr_stats = ExprStats::new(); - let (found_common, id_arrays) = - self.to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?; - - if found_common { - let rewritten = self.rewrite_expr( - // Must clone as Identifiers use references to original expressions so we - // have to keep the original expressions intact. - vec![expr.clone()], - vec![id_arrays], - input, - &expr_stats, - config, - )?; - assert!(rewritten.transformed); - rewritten.map_data(|(mut new_expr, new_input)| { - assert_eq!(new_expr.len(), 1); - Ok((new_expr.pop().unwrap(), new_input)) + // Extract common sub-expressions from the expressions. + self.find_common_exprs(vec![exprs], config, ExprMask::Normal)? + .map_data(|common| match common { + FoundCommonExprs::Yes { + common_exprs, + mut new_exprs_list, + original_exprs_list: _, + } => { + let new_exprs = new_exprs_list.pop().unwrap(); + build_common_expr_project_plan(input, common_exprs) + .map(|new_input| (new_exprs, new_input)) + } + FoundCommonExprs::No { + mut original_exprs_list, + } => { + let new_exprs = original_exprs_list.pop().unwrap(); + Ok((new_exprs, input)) + } + })? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_exprs, new_input)| { + self.rewrite(new_input, config)? + .map_data(|new_input| Ok((new_exprs, new_input))) }) - } else { - Ok(Transformed::no((expr, input))) - } } } @@ -665,7 +717,7 @@ impl CommonSubexprEliminate { fn get_consecutive_window_exprs( window: Window, ) -> (Vec>, Vec, LogicalPlan) { - let mut window_exprs = vec![]; + let mut window_expr_list = vec![]; let mut window_schemas = vec![]; let mut plan = LogicalPlan::Window(window); while let LogicalPlan::Window(Window { @@ -674,12 +726,12 @@ fn get_consecutive_window_exprs( schema, }) = plan { - window_exprs.push(window_expr); + window_expr_list.push(window_expr); window_schemas.push(schema); plan = unwrap_arc(input); } - (window_exprs, window_schemas, plan) + (window_expr_list, window_schemas, plan) } impl OptimizerRule for CommonSubexprEliminate { @@ -688,7 +740,10 @@ impl OptimizerRule for CommonSubexprEliminate { } fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) + // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner. + // This is because in some cases adjacent nodes are collected (e.g. `Window`) and + // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule. + None } fn rewrite( @@ -726,8 +781,9 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { - // ApplyOrder::TopDown handles recursion - Transformed::no(plan) + // This rule handles recursion itself in a `ApplyOrder::TopDown` like + // manner. + plan.map_children(|c| self.rewrite(c, config))? } }; @@ -753,12 +809,6 @@ impl Default for CommonSubexprEliminate { } } -fn pop_expr(new_expr: &mut Vec>) -> Result> { - new_expr - .pop() - .ok_or_else(|| internal_datafusion_err!("Failed to pop expression")) -} - /// Build the "intermediate" projection plan that evaluates the extracted common /// expressions. /// @@ -771,11 +821,11 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { /// expr_stats: the set of common subexpressions fn build_common_expr_project_plan( input: LogicalPlan, - common_exprs: CommonExprs, + common_exprs: Vec<(Expr, String)>, ) -> Result { let mut fields_set = BTreeSet::new(); let mut project_exprs = common_exprs - .into_values() + .into_iter() .map(|(expr, expr_alias)| { fields_set.insert(expr_alias.clone()); Ok(expr.alias(expr_alias)) @@ -1147,7 +1197,7 @@ fn replace_common_expr<'n>( expr_stats: &ExprStats<'n>, common_exprs: &mut CommonExprs<'n>, alias_generator: &AliasGenerator, -) -> Result> { +) -> Result { if id_array.is_empty() { Ok(Transformed::no(expr)) } else { @@ -1160,6 +1210,7 @@ fn replace_common_expr<'n>( alias_generator, }) } + .data() } #[cfg(test)] @@ -1178,42 +1229,22 @@ mod test { }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use super::*; use crate::optimizer::OptimizerContext; use crate::test::*; + use crate::Optimizer; use datafusion_expr::test::function_stub::{avg, sum}; - use super::*; - - fn assert_non_optimized_plan_eq( - expected: &str, - plan: LogicalPlan, - config: Option<&dyn OptimizerConfig>, - ) { - assert_eq!(expected, format!("{plan}"), "Unexpected starting plan"); - let optimizer = CommonSubexprEliminate::new(); - let default_config = OptimizerContext::new(); - let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.rewrite(plan, config).unwrap(); - assert!(!optimized_plan.transformed, "unexpectedly optimize plan"); - let optimized_plan = optimized_plan.data; - assert_eq!( - expected, - format!("{optimized_plan}"), - "Unexpected optimized plan" - ); - } - fn assert_optimized_plan_eq( expected: &str, plan: LogicalPlan, config: Option<&dyn OptimizerConfig>, ) { - let optimizer = CommonSubexprEliminate::new(); + let optimizer = + Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); let default_config = OptimizerContext::new(); let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.rewrite(plan, config).unwrap(); - assert!(optimized_plan.transformed, "failed to optimize plan"); - let optimized_plan = optimized_plan.data; + let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); let formatted_plan = format!("{optimized_plan}"); assert_eq!(expected, formatted_plan); } @@ -1603,7 +1634,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_non_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1621,41 +1652,25 @@ mod test { \n Projection: Int32(1) + test.a, test.a\ \n TableScan: test"; - assert_non_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } - fn test_identifier(hash: u64, expr: &Expr) -> Identifier { - Identifier { hash, expr } - } - #[test] fn redundant_project_fields() { let table_scan = test_table_scan().unwrap(); let c_plus_a = col("c") + col("a"); let b_plus_a = col("b") + col("a"); - let common_exprs_1 = CommonExprs::from([ - ( - test_identifier(0, &c_plus_a), - (c_plus_a.clone(), format!("{CSE_PREFIX}_1")), - ), - ( - test_identifier(1, &b_plus_a), - (b_plus_a.clone(), format!("{CSE_PREFIX}_2")), - ), - ]); + let common_exprs_1 = vec![ + (c_plus_a, format!("{CSE_PREFIX}_1")), + (b_plus_a, format!("{CSE_PREFIX}_2")), + ]; let c_plus_a_2 = col(format!("{CSE_PREFIX}_1")); let b_plus_a_2 = col(format!("{CSE_PREFIX}_2")); - let common_exprs_2 = CommonExprs::from([ - ( - test_identifier(3, &c_plus_a_2), - (c_plus_a_2.clone(), format!("{CSE_PREFIX}_3")), - ), - ( - test_identifier(4, &b_plus_a_2), - (b_plus_a_2.clone(), format!("{CSE_PREFIX}_4")), - ), - ]); + let common_exprs_2 = vec![ + (c_plus_a_2, format!("{CSE_PREFIX}_3")), + (b_plus_a_2, format!("{CSE_PREFIX}_4")), + ]; let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap(); let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap(); @@ -1676,28 +1691,16 @@ mod test { .unwrap(); let c_plus_a = col("test1.c") + col("test1.a"); let b_plus_a = col("test1.b") + col("test1.a"); - let common_exprs_1 = CommonExprs::from([ - ( - test_identifier(0, &c_plus_a), - (c_plus_a.clone(), format!("{CSE_PREFIX}_1")), - ), - ( - test_identifier(1, &b_plus_a), - (b_plus_a.clone(), format!("{CSE_PREFIX}_2")), - ), - ]); + let common_exprs_1 = vec![ + (c_plus_a, format!("{CSE_PREFIX}_1")), + (b_plus_a, format!("{CSE_PREFIX}_2")), + ]; let c_plus_a_2 = col(format!("{CSE_PREFIX}_1")); let b_plus_a_2 = col(format!("{CSE_PREFIX}_2")); - let common_exprs_2 = CommonExprs::from([ - ( - test_identifier(3, &c_plus_a_2), - (c_plus_a_2.clone(), format!("{CSE_PREFIX}_3")), - ), - ( - test_identifier(4, &b_plus_a_2), - (b_plus_a_2.clone(), format!("{CSE_PREFIX}_4")), - ), - ]); + let common_exprs_2 = vec![ + (c_plus_a_2, format!("{CSE_PREFIX}_3")), + (b_plus_a_2, format!("{CSE_PREFIX}_4")), + ]; let project = build_common_expr_project_plan(join, common_exprs_1).unwrap(); let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap(); @@ -1963,6 +1966,52 @@ mod test { Ok(()) } + #[test] + fn test_non_top_level_common_expression() -> Result<()> { + let table_scan = test_table_scan()?; + + let common_expr = col("a") + col("b"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + common_expr.clone().alias("c1"), + common_expr.alias("c2"), + ])? + .project(vec![col("c1"), col("c2")])? + .build()?; + + let expected = "Projection: c1, c2\ + \n Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_nested_common_expression() -> Result<()> { + let table_scan = test_table_scan()?; + + let nested_common_expr = col("a") + col("b"); + let common_expr = nested_common_expr.clone() * nested_common_expr; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + common_expr.clone().alias("c1"), + common_expr.alias("c2"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ + \n Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\ + \n Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + /// returns a "random" function that is marked volatile (aka each invocation /// returns a different value) ///