Skip to content

Commit

Permalink
refactor: Add a TypeCheckRule to the optimizer (#20425)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Dec 24, 2024
1 parent 93ceacc commit 9ea5839
Show file tree
Hide file tree
Showing 21 changed files with 240 additions and 99 deletions.
6 changes: 6 additions & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ impl LazyFrame {
self
}

/// Toggle the type check optimization on or off.
///
/// Sets (`toggle == true`) or clears (`toggle == false`) the `OptFlags::TYPE_CHECK`
/// flag in this frame's `opt_state`, then returns the frame by value.
pub fn with_type_check(mut self, toggle: bool) -> Self {
self.opt_state.set(OptFlags::TYPE_CHECK, toggle);
self
}

/// Toggle expression simplification optimization on or off.
pub fn with_simplify_expr(mut self, toggle: bool) -> Self {
self.opt_state.set(OptFlags::SIMPLIFY_EXPR, toggle);
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ impl GetOutput {
Default::default()
}

/// Output-field resolver that yields a clone of the first input field.
///
/// The schema and context arguments are ignored. The closure indexes `fields[0]`
/// directly, so it panics if it is invoked with an empty field slice.
pub fn first() -> Self {
    let take_first = |_schema: &Schema, _ctx: Context, fields: &[Field]| Ok(fields[0].clone());
    SpecialEq::new(Arc::new(take_first))
}

pub fn from_type(dt: DataType) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
Ok(Field::new(flds[0].name().clone(), dt.clone()))
Expand Down
17 changes: 9 additions & 8 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ where
F: 'static + Fn(Column, Column) -> PolarsResult<Option<Column>> + Send + Sync,
E: AsRef<[Expr]>,
{
let mut exprs = exprs.as_ref().to_vec();
exprs.push(acc);
let mut exprs_v = Vec::with_capacity(exprs.as_ref().len() + 1);
exprs_v.push(acc);
exprs_v.extend(exprs.as_ref().iter().cloned());
let exprs = exprs_v;

let function = new_column_udf(move |columns: &mut [Column]| {
let mut columns = columns.to_vec();
let mut acc = columns.pop().unwrap();

for c in columns {
if let Some(a) = f(acc.clone(), c)? {
let mut acc = columns.first().unwrap().clone();
for c in &columns[1..] {
if let Some(a) = f(acc.clone(), c.clone())? {
acc = a
}
}
Expand All @@ -43,7 +43,8 @@ where
Expr::AnonymousFunction {
input: exprs,
function,
output_type: GetOutput::super_type(),
// Take the type of the accumulator.
output_type: GetOutput::first(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default()
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-plan/src/frame/opt_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ bitflags! {
/// Check if operations are order dependent and unset maintaining_order if
/// the order would not be observed.
const CHECK_ORDER_OBSERVE = 1 << 16;
/// Do type checking of the IR.
const TYPE_CHECK = 1 << 17;
}
}

impl OptFlags {
pub fn schema_only() -> Self {
Self::TYPE_COERCION
Self::TYPE_COERCION | Self::TYPE_CHECK
}
}

Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ pub fn to_alp(
lp: DslPlan,
expr_arena: &mut Arena<AExpr>,
lp_arena: &mut Arena<IR>,
// Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected.
// Only `SIMPLIFY_EXPR`, `TYPE_COERCION`, `TYPE_CHECK` are respected.
opt_flags: &mut OptFlags,
) -> PolarsResult<Node> {
let conversion_optimizer = ConversionOptimizer::new(
opt_flags.contains(OptFlags::SIMPLIFY_EXPR),
opt_flags.contains(OptFlags::TYPE_COERCION),
opt_flags.contains(OptFlags::TYPE_CHECK),
);

let mut ctxt = DslConversionContext {
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ pub fn resolve_join(
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right)
.map_err(|e| e.context("'join' failed".into()))?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
Expand Down
28 changes: 28 additions & 0 deletions crates/polars-plan/src/plans/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,20 @@ mod ir_to_dsl;
mod scans;
mod stack_opt;

use std::borrow::Cow;
use std::sync::{Arc, Mutex};

pub use dsl_to_ir::*;
pub use expr_to_ir::*;
pub use ir_to_dsl::*;
use polars_core::prelude::*;
use polars_utils::idx_vec::UnitVec;
use polars_utils::unitvec;
use polars_utils::vec::ConvertVec;
use recursive::recursive;
mod functions;
mod join;
pub(crate) mod type_check;
pub(crate) mod type_coercion;

pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection};
Expand Down Expand Up @@ -266,3 +270,27 @@ impl IR {
}
}
}

/// Collect the nodes from which the input schema of `lp_node` is resolved.
///
/// When `is_scan` reports true for the plan at `lp_node`, the node itself is
/// returned; otherwise the plan's own inputs are copied into the result.
fn get_input(lp_arena: &Arena<IR>, lp_node: Node) -> UnitVec<Node> {
    let ir = lp_arena.get(lp_node);
    let mut nodes: UnitVec<Node> = unitvec!();

    if !is_scan(ir) {
        ir.copy_inputs(&mut nodes);
        return nodes;
    }

    // Scan case: the node stands in as its own input for schema lookup.
    nodes.push(lp_node);
    nodes
}

/// Resolve the input schema for the plan stored at `lp_node`.
///
/// Uses the schema of the first node returned by `get_input`. If that list is
/// empty, the node's own `scan_schema()` is borrowed instead.
fn get_schema(lp_arena: &Arena<IR>, lp_node: Node) -> Cow<'_, SchemaRef> {
    let inputs = get_input(lp_arena, lp_node);

    if !inputs.is_empty() {
        return lp_arena.get(inputs[0]).schema(lp_arena);
    }

    // No input to consult (the original notes this is the file case), so the
    // schema must come from the node itself.
    Cow::Borrowed(lp_arena.get(lp_node).scan_schema())
}
34 changes: 25 additions & 9 deletions crates/polars-plan/src/plans/conversion/stack_opt.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::borrow::Borrow;

use self::type_check::TypeCheckRule;
use super::*;

/// Applies expression simplification, type coercion and type checking during conversion to IR.
pub(super) struct ConversionOptimizer {
scratch: Vec<Node>,

simplify: Option<SimplifyExprRule>,
coerce: Option<TypeCoercionRule>,
check: Option<TypeCheckRule>,
// IR's can be cached in the DSL.
// But if they are used multiple times in DSL (e.g. concat/join)
// then it can occur that we take a slot multiple times.
Expand All @@ -16,7 +19,7 @@ pub(super) struct ConversionOptimizer {
}

impl ConversionOptimizer {
pub(super) fn new(simplify: bool, type_coercion: bool) -> Self {
pub(super) fn new(simplify: bool, type_coercion: bool, type_check: bool) -> Self {
let simplify = if simplify {
Some(SimplifyExprRule {})
} else {
Expand All @@ -29,10 +32,17 @@ impl ConversionOptimizer {
None
};

let check = if type_check {
Some(TypeCheckRule)
} else {
None
};

ConversionOptimizer {
scratch: Vec::with_capacity(8),
simplify,
coerce,
check,
used_arenas: Default::default(),
}
}
Expand All @@ -54,29 +64,35 @@ impl ConversionOptimizer {
pub(super) fn coerce_types(
&mut self,
expr_arena: &mut Arena<AExpr>,
lp_arena: &Arena<IR>,
ir_arena: &mut Arena<IR>,
current_node: Node,
) -> PolarsResult<()> {
// Different from the stack-opt in the optimizer phase, this does a single pass until fixed point per expression.

if let Some(rule) = &mut self.check {
while let Some(x) = rule.optimize_plan(ir_arena, expr_arena, current_node)? {
ir_arena.replace(current_node, x);
}
}

// process the expressions on the stack and apply optimizations.
while let Some(current_expr_node) = self.scratch.pop() {
{
let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
if expr.is_leaf() {
continue;
}
let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };

if expr.is_leaf() {
continue;
}

if let Some(rule) = &mut self.simplify {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
}
}
if let Some(rule) = &mut self.coerce {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
}
Expand Down
48 changes: 48 additions & 0 deletions crates/polars-plan/src/plans/conversion/type_check/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use polars_core::error::{polars_ensure, PolarsResult};
use polars_core::prelude::DataType;
use polars_utils::arena::{Arena, Node};

use super::{AExpr, OptimizationRule, IR};
use crate::plans::conversion::get_schema;
use crate::plans::Context;

/// Stateless rule that type-checks IR nodes.
///
/// Its `OptimizationRule` implementation verifies that the predicates of `IR::Scan`
/// and `IR::Filter` nodes resolve to `Boolean` (an `Unknown` dtype is also accepted).
pub struct TypeCheckRule;

impl OptimizationRule for TypeCheckRule {
    /// Validate the filter predicate attached to the IR node at `node`, if any.
    ///
    /// Only `IR::Scan` nodes carrying `Some` predicate and `IR::Filter` nodes are
    /// inspected; every other node is accepted as-is. The predicate's dtype is
    /// resolved against the schema returned by `get_schema` and must be `Boolean`
    /// or `Unknown`.
    ///
    /// # Errors
    ///
    /// Propagates any error from resolving the predicate dtype, and returns an
    /// `InvalidOperation` error when the dtype is neither `Boolean` nor `Unknown`.
    ///
    /// This rule never rewrites the plan: on success it always returns `Ok(None)`.
    fn optimize_plan(
        &mut self,
        ir_arena: &mut Arena<IR>,
        expr_arena: &mut Arena<AExpr>,
        node: Node,
    ) -> PolarsResult<Option<IR>> {
        // Both predicate-carrying variants are checked the same way, so pick the
        // predicate out first and bail early for everything else.
        let predicate = match ir_arena.get(node) {
            IR::Scan {
                predicate: Some(predicate),
                ..
            }
            | IR::Filter { predicate, .. } => predicate,
            _ => return Ok(None),
        };

        let input_schema = get_schema(ir_arena, node);
        let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?;

        polars_ensure!(
            matches!(dtype, DataType::Boolean | DataType::Unknown(_)),
            InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`"
        );

        Ok(None)
    }
}
32 changes: 4 additions & 28 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ mod functions;
#[cfg(feature = "is_in")]
mod is_in;

use std::borrow::Cow;

use binary::process_binary;
use polars_compute::cast::temporal::{time_unit_multiple, SECONDS_IN_DAY};
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int};
use polars_utils::idx_vec::UnitVec;
use polars_utils::format_list;
use polars_utils::itertools::Itertools;
use polars_utils::{format_list, unitvec};

use super::*;

Expand Down Expand Up @@ -69,30 +66,6 @@ fn modify_supertype(
st
}

/// Collect the nodes from which the input schema of `lp_node` is resolved.
///
/// When `is_scan` reports true for the plan at `lp_node`, the node itself is
/// returned; otherwise the plan's own inputs are copied into the result.
fn get_input(lp_arena: &Arena<IR>, lp_node: Node) -> UnitVec<Node> {
    let ir = lp_arena.get(lp_node);
    let mut nodes: UnitVec<Node> = unitvec!();

    if !is_scan(ir) {
        ir.copy_inputs(&mut nodes);
        return nodes;
    }

    // Scan case: the node stands in as its own input for schema lookup.
    nodes.push(lp_node);
    nodes
}

/// Resolve the input schema for the plan stored at `lp_node`.
///
/// Uses the schema of the first node returned by `get_input`. If that list is
/// empty, the node's own `scan_schema()` is borrowed instead.
fn get_schema(lp_arena: &Arena<IR>, lp_node: Node) -> Cow<'_, SchemaRef> {
    let inputs = get_input(lp_arena, lp_node);

    if !inputs.is_empty() {
        return lp_arena.get(inputs[0]).schema(lp_arena);
    }

    // No input to consult (the original notes this is the file case), so the
    // schema must come from the node itself.
    Cow::Borrowed(lp_arena.get(lp_node).scan_schema())
}

fn get_aexpr_and_type<'a>(
expr_arena: &'a Arena<AExpr>,
e: Node,
Expand Down Expand Up @@ -515,6 +488,7 @@ fn cast_expr_ir(
if let AExpr::Literal(lv) = expr_arena.get(e.node()) {
if let Some(literal) = try_inline_literal_cast(lv, to_dtype, strict)? {
e.set_node(expr_arena.add(AExpr::Literal(literal)));
e.set_dtype(to_dtype.clone());
return Ok(());
}
}
Expand All @@ -524,6 +498,8 @@ fn cast_expr_ir(
dtype: to_dtype.clone(),
options: CastOptions::Strict,
}));
e.set_dtype(to_dtype.clone());

Ok(())
}

Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/plans/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ impl ExprIR {
self
}

/// Overwrite the stored output dtype of this expression.
///
/// Replaces the `output_dtype` cell with a new `OnceLock` that is already
/// initialized to `dtype`, discarding whatever value the cell held before.
pub(crate) fn set_dtype(&mut self, dtype: DataType) {
self.output_dtype = OnceLock::from(dtype);
}

pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
let mut out = Self {
node,
Expand Down
Loading

0 comments on commit 9ea5839

Please sign in to comment.