Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Order observability optimizations #20396

Merged
merged 10 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ impl LazyFrame {
self
}

/// Toggle the order-observability check.
///
/// When enabled, the optimizer checks whether operations are order dependent
/// and unsets `maintain_order` where the order would not be observed.
pub fn with_check_order(self, toggle: bool) -> Self {
    let mut frame = self;
    frame
        .opt_state
        .set(OptFlags::CHECK_ORDER_OBSERVE, toggle);
    frame
}

/// Toggle predicate pushdown optimization.
pub fn with_predicate_pushdown(mut self, toggle: bool) -> Self {
self.opt_state.set(OptFlags::PREDICATE_PUSHDOWN, toggle);
Expand Down
12 changes: 2 additions & 10 deletions crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,7 @@ pub(crate) fn insert_streaming_nodes(
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
},
HStack { input, exprs, .. }
if exprs
.iter()
.all(|e| is_elementwise_rec(expr_arena.get(e.node()), expr_arena)) =>
{
HStack { input, exprs, .. } if all_elementwise(exprs, expr_arena) => {
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
Expand All @@ -198,11 +194,7 @@ pub(crate) fn insert_streaming_nodes(
state.operators_sinks.push(PipelineNode::Sink(root));
stack.push(StackFrame::new(*input, state, current_idx))
},
Select { input, expr, .. }
if expr
.iter()
.all(|e| is_elementwise_rec(expr_arena.get(e.node()), expr_arena)) =>
{
Select { input, expr, .. } if all_elementwise(expr, expr_arena) => {
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-plan/src/frame/opt_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ bitflags! {
const FAST_PROJECTION = 1 << 14;
/// Collapse slower joins with filters into faster joins.
const COLLAPSE_JOINS = 1 << 15;
/// Check if operations are order dependent and unset maintaining_order if
/// the order would not be observed.
const CHECK_ORDER_OBSERVE = 1 << 16;
}
}

Expand Down
9 changes: 9 additions & 0 deletions crates/polars-plan/src/plans/aexpr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<
true
}

pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
where
Node: From<&'a N>,
{
nodes
.iter()
.all(|n| is_elementwise_rec(expr_arena.get(n.into()), expr_arena))
}

/// Recursive variant of `is_elementwise`
pub fn is_elementwise_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
let mut stack = unitvec![];
Expand Down
7 changes: 1 addition & 6 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,9 @@ pub fn resolve_join(
}
// Every expression must be elementwise so that we are
// guaranteed the keys for a join are all the same length.
let all_elementwise = |aexprs: &[ExprIR]| {
aexprs
.iter()
.all(|e| is_elementwise_rec(ctxt.expr_arena.get(e.node()), ctxt.expr_arena))
};

polars_ensure!(
all_elementwise(&left_on) && all_elementwise(&right_on),
all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena),
InvalidOperation: "all join key expressions must be elementwise."
);

Expand Down
23 changes: 0 additions & 23 deletions crates/polars-plan/src/plans/optimizer/collapse_and_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,29 +123,6 @@ impl OptimizationRule for SimpleProjectionAndCollapse {
None
}
},
// Remove double sorts
Sort {
input,
by_column,
slice,
sort_options:
sort_options @ SortMultipleOptions {
maintain_order: false, // `maintain_order=True` is influenced by result of earlier sorts
..
},
} => match lp_arena.get(*input) {
Sort {
input: inner,
slice: None,
..
} => Some(Sort {
input: *inner,
by_column: by_column.clone(),
slice: *slice,
sort_options: sort_options.clone(),
}),
_ => None,
},
_ => None,
}
}
Expand Down
15 changes: 15 additions & 0 deletions crates/polars-plan/src/plans/optimizer/collect_members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub(super) struct MemberCollector {
pub(crate) has_cache: bool,
pub(crate) has_ext_context: bool,
pub(crate) has_filter_with_join_input: bool,
pub(crate) has_distinct: bool,
pub(crate) has_sort: bool,
pub(crate) has_group_by: bool,
#[cfg(feature = "cse")]
scans: UniqueScans,
}
Expand All @@ -38,6 +41,9 @@ impl MemberCollector {
has_cache: false,
has_ext_context: false,
has_filter_with_join_input: false,
has_distinct: false,
has_sort: false,
has_group_by: false,
#[cfg(feature = "cse")]
scans: UniqueScans::default(),
}
Expand All @@ -50,6 +56,15 @@ impl MemberCollector {
Filter { input, .. } => {
self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how.is_cross())
},
Distinct { .. } => {
self.has_distinct = true;
},
GroupBy { .. } => {
self.has_group_by = true;
},
Sort { .. } => {
self.has_sort = true;
},
Cache { .. } => self.has_cache = true,
ExtContext { .. } => self.has_ext_context = true,
#[cfg(feature = "cse")]
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-plan/src/plans/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod fused;
mod join_utils;
mod predicate_pushdown;
mod projection_pushdown;
mod set_order;
mod simplify_expr;
mod slice_pushdown_expr;
mod slice_pushdown_lp;
Expand All @@ -34,6 +35,7 @@ use slice_pushdown_lp::SlicePushDown;
pub use stack_opt::{OptimizationRule, StackOptimizer};

use self::flatten_union::FlattenUnionRule;
use self::set_order::set_order_flags;
pub use crate::frame::{AllowedOptimizations, OptFlags};
pub use crate::plans::conversion::type_coercion::TypeCoercionRule;
use crate::plans::optimizer::count_star::CountStar;
Expand Down Expand Up @@ -116,6 +118,13 @@ pub fn optimize(
members.collect(lp_top, lp_arena, expr_arena)
}

// Run before slice pushdown
if opt_state.contains(OptFlags::CHECK_ORDER_OBSERVE)
&& members.has_group_by | members.has_sort | members.has_distinct
{
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
}

if simplify_expr {
#[cfg(feature = "fused")]
rules.push(Box::new(fused::FusedArithmetic {}));
Expand Down
156 changes: 156 additions & 0 deletions crates/polars-plan/src/plans/optimizer/set_order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use polars_utils::unitvec;

use super::*;

// Decides, looking only at `ae` itself (not its inputs), whether the
// expression is treated as order dependent.
//
// Can give false positives: anything that is not an aggregation or a plain
// column is reported as order dependent.
fn is_order_dependent_top_level(ae: &AExpr, ctx: Context) -> bool {
    let agg = match ae {
        AExpr::Agg(agg) => agg,
        // A bare column only counts as order dependent in aggregation context.
        AExpr::Column(_) => return matches!(ctx, Context::Aggregation),
        _ => return true,
    };

    // Kept exhaustive (no wildcard) so that a new aggregation variant forces
    // an explicit decision here.
    match agg {
        IRAggExpr::First(_)
        | IRAggExpr::Last(_)
        | IRAggExpr::Implode(_)
        | IRAggExpr::AggGroups(_) => true,
        IRAggExpr::Min { .. }
        | IRAggExpr::Max { .. }
        | IRAggExpr::Median(_)
        | IRAggExpr::NUnique(_)
        | IRAggExpr::Mean(_)
        | IRAggExpr::Quantile { .. }
        | IRAggExpr::Sum(_)
        | IRAggExpr::Count(_, _)
        | IRAggExpr::Std(_, _)
        | IRAggExpr::Var(_, _) => false,
    }
}

// Reports whether `ae` is treated as order dependent in context `ctx`.
//
// Can give false positives.
//
// NOTE(review): nothing in this function ever pushes onto `pending`, so the
// loop body runs for the root expression only and the inputs of `ae` are
// never visited. Confirm whether the inputs were meant to be walked as well.
fn is_order_dependent<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>, ctx: Context) -> bool {
    let mut pending = unitvec![];

    while !is_order_dependent_top_level(ae, ctx) {
        // Nothing left to inspect: no order-dependent expression was found.
        let Some(next) = pending.pop() else {
            return false;
        };

        ae = expr_arena.get(next);
    }

    true
}

// Can give false negatives.
pub(crate) fn all_order_independent<'a, N>(
nodes: &'a [N],
expr_arena: &Arena<AExpr>,
ctx: Context,
) -> bool
where
Node: From<&'a N>,
{
!nodes
.iter()
.any(|n| is_order_dependent(expr_arena.get(n.into()), expr_arena, ctx))
}

/// Walks the plan from `root` downwards and relaxes `maintain_order` flags
/// (or drops whole `Sort` nodes) where the node above does not need the
/// row order to be maintained.
///
/// Should run before slice pushdown: the `debug_assert!`s below expect that
/// no slice/limit has been attached to `Sort`, `Distinct`, `Union` or
/// `GroupBy` yet.
///
/// * `root` - node at which the traversal starts.
/// * `ir_arena` - plan arena; nodes are mutated (and swapped) in place.
/// * `expr_arena` - expression arena, only read.
/// * `scratch` - work stack for the traversal; cleared on entry.
///
/// NOTE(review): `maintain_order_above` is a single flag shared by the whole
/// stack-based traversal. For nodes with several inputs, the value seen by a
/// later-visited input is whatever the previously walked branch left behind
/// rather than the state at the common parent - confirm this is acceptable.
pub(super) fn set_order_flags(
root: Node,
ir_arena: &mut Arena<IR>,
expr_arena: &Arena<AExpr>,
scratch: &mut Vec<Node>,
) {
scratch.clear();
scratch.push(root);

// Whether the order produced by the current node must be maintained for the
// nodes already visited. Starts as `true` for the root.
let mut maintain_order_above = true;

while let Some(node) = scratch.pop() {
let ir = ir_arena.get_mut(node);
// Presumably queues this node's inputs on `scratch` so they are visited next.
ir.copy_inputs(scratch);

match ir {
IR::Sort {
input,
sort_options,
..
} => {
debug_assert!(sort_options.limit.is_none());
// This sort can be removed
if !maintain_order_above {
// Drop the entry that was just queued (assumes `Sort` queued exactly
// one input), then move the input's IR into this node's slot and
// revisit the same slot, so the sort is bypassed.
scratch.pop();
scratch.push(node);
let input = *input;
ir_arena.swap(node, input);
continue;
}

if !sort_options.maintain_order {
maintain_order_above = false; // `maintain_order=True` is influenced by result of earlier sorts
}
},
IR::Distinct { options, .. } => {
debug_assert!(options.slice.is_none());
// Order is not needed above: the distinct need not maintain it either.
if !maintain_order_above {
options.maintain_order = false;
continue;
}
// A distinct that does not maintain order means the order of its input
// does not need to be maintained.
if !options.maintain_order {
maintain_order_above = false;
}
},
IR::Union { options, .. } => {
debug_assert!(options.slice.is_none());
// The union maintains order exactly when the order is needed above it;
// the flag itself is passed on unchanged.
options.maintain_order = maintain_order_above;
},
IR::GroupBy {
keys,
aggs,
maintain_order,
options,
apply,
..
} => {
debug_assert!(options.slice.is_none());
// Order is not needed above: stop maintaining it here and keep the flag
// cleared for the nodes below.
if !maintain_order_above && *maintain_order {
*maintain_order = false;
continue;
}

// These group-by flavours are treated as needing their input order.
if apply.is_some()
|| *maintain_order
|| options.is_rolling()
|| options.is_dynamic()
{
maintain_order_above = true;
continue;
}
// Elementwise keys with order-independent aggregations: the input order
// does not need to be maintained.
if all_elementwise(keys, expr_arena)
&& all_order_independent(aggs, expr_arena, Context::Aggregation)
{
maintain_order_above = false;
continue;
}
maintain_order_above = true;
},
// Conservative now.
IR::HStack { exprs, .. } | IR::Select { expr: exprs, .. } => {
// Purely elementwise projections leave a cleared flag cleared; anything
// else resets it to `true`.
if !maintain_order_above && all_elementwise(exprs, expr_arena) {
continue;
}
maintain_order_above = true;
},
_ => {
// For any node kind not handled above we do not know whether order
// matters, so conservatively require it to be maintained.
// Known case that ends up here: slice
maintain_order_above = true;
},
}
}
}
Loading
Loading