Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add a TypeCheckRule to the optimizer #20425

Merged
merged 4 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ impl LazyFrame {
self
}

/// Toggle type check optimization.
pub fn with_type_check(mut self, toggle: bool) -> Self {
self.opt_state.set(OptFlags::TYPE_CHECK, toggle);
self
}

/// Toggle expression simplification optimization on or off.
pub fn with_simplify_expr(mut self, toggle: bool) -> Self {
self.opt_state.set(OptFlags::SIMPLIFY_EXPR, toggle);
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ impl GetOutput {
Default::default()
}

pub fn first() -> Self {
SpecialEq::new(Arc::new(
|_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
))
}

pub fn from_type(dt: DataType) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
Ok(Field::new(flds[0].name().clone(), dt.clone()))
Expand Down
17 changes: 9 additions & 8 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ where
F: 'static + Fn(Column, Column) -> PolarsResult<Option<Column>> + Send + Sync,
E: AsRef<[Expr]>,
{
let mut exprs = exprs.as_ref().to_vec();
exprs.push(acc);
let mut exprs_v = Vec::with_capacity(exprs.as_ref().len() + 1);
exprs_v.push(acc);
exprs_v.extend(exprs.as_ref().iter().cloned());
let exprs = exprs_v;

let function = new_column_udf(move |columns: &mut [Column]| {
let mut columns = columns.to_vec();
let mut acc = columns.pop().unwrap();

for c in columns {
if let Some(a) = f(acc.clone(), c)? {
let mut acc = columns.first().unwrap().clone();
for c in &columns[1..] {
if let Some(a) = f(acc.clone(), c.clone())? {
acc = a
}
}
Expand All @@ -43,7 +43,8 @@ where
Expr::AnonymousFunction {
input: exprs,
function,
output_type: GetOutput::super_type(),
// Take the type of the accumulator.
output_type: GetOutput::first(),
options: FunctionOptions {
collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default()
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-plan/src/frame/opt_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ bitflags! {
/// Check if operations are order dependent and unset maintaining_order if
/// the order would not be observed.
const CHECK_ORDER_OBSERVE = 1 << 16;
/// Do type checking of the IR.
const TYPE_CHECK = 1 << 17;
}
}

impl OptFlags {
pub fn schema_only() -> Self {
Self::TYPE_COERCION
Self::TYPE_COERCION | Self::TYPE_CHECK
}
}

Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ pub fn to_alp(
lp: DslPlan,
expr_arena: &mut Arena<AExpr>,
lp_arena: &mut Arena<IR>,
// Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected.
// Only `SIMPLIFY_EXPR`, `TYPE_COERCION`, `TYPE_CHECK` are respected.
opt_flags: &mut OptFlags,
) -> PolarsResult<Node> {
let conversion_optimizer = ConversionOptimizer::new(
opt_flags.contains(OptFlags::SIMPLIFY_EXPR),
opt_flags.contains(OptFlags::TYPE_COERCION),
opt_flags.contains(OptFlags::TYPE_CHECK),
);

let mut ctxt = DslConversionContext {
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ pub fn resolve_join(
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right)
.map_err(|e| e.context("'join' failed".into()))?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
Expand Down
28 changes: 28 additions & 0 deletions crates/polars-plan/src/plans/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,20 @@ mod ir_to_dsl;
mod scans;
mod stack_opt;

use std::borrow::Cow;
use std::sync::{Arc, Mutex};

pub use dsl_to_ir::*;
pub use expr_to_ir::*;
pub use ir_to_dsl::*;
use polars_core::prelude::*;
use polars_utils::idx_vec::UnitVec;
use polars_utils::unitvec;
use polars_utils::vec::ConvertVec;
use recursive::recursive;
mod functions;
mod join;
pub(crate) mod type_check;
pub(crate) mod type_coercion;

pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection};
Expand Down Expand Up @@ -266,3 +270,27 @@ impl IR {
}
}
}

fn get_input(lp_arena: &Arena<IR>, lp_node: Node) -> UnitVec<Node> {
let plan = lp_arena.get(lp_node);
let mut inputs: UnitVec<Node> = unitvec!();

// Used to get the schema of the input.
if is_scan(plan) {
inputs.push(lp_node);
} else {
plan.copy_inputs(&mut inputs);
};
inputs
}

fn get_schema(lp_arena: &Arena<IR>, lp_node: Node) -> Cow<'_, SchemaRef> {
let inputs = get_input(lp_arena, lp_node);
if inputs.is_empty() {
// Files don't have an input, so we must take their schema.
Cow::Borrowed(lp_arena.get(lp_node).scan_schema())
} else {
let input = inputs[0];
lp_arena.get(input).schema(lp_arena)
}
}
34 changes: 25 additions & 9 deletions crates/polars-plan/src/plans/conversion/stack_opt.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::borrow::Borrow;

use self::type_check::TypeCheckRule;
use super::*;

/// Applies expression simplification and type coercion during conversion to IR.
pub(super) struct ConversionOptimizer {
scratch: Vec<Node>,

simplify: Option<SimplifyExprRule>,
coerce: Option<TypeCoercionRule>,
check: Option<TypeCheckRule>,
// IR's can be cached in the DSL.
// But if they are used multiple times in DSL (e.g. concat/join)
// then it can occur that we take a slot multiple times.
Expand All @@ -16,7 +19,7 @@ pub(super) struct ConversionOptimizer {
}

impl ConversionOptimizer {
pub(super) fn new(simplify: bool, type_coercion: bool) -> Self {
pub(super) fn new(simplify: bool, type_coercion: bool, type_check: bool) -> Self {
let simplify = if simplify {
Some(SimplifyExprRule {})
} else {
Expand All @@ -29,10 +32,17 @@ impl ConversionOptimizer {
None
};

let check = if type_check {
Some(TypeCheckRule)
} else {
None
};

ConversionOptimizer {
scratch: Vec::with_capacity(8),
simplify,
coerce,
check,
used_arenas: Default::default(),
}
}
Expand All @@ -54,29 +64,35 @@ impl ConversionOptimizer {
pub(super) fn coerce_types(
&mut self,
expr_arena: &mut Arena<AExpr>,
lp_arena: &Arena<IR>,
ir_arena: &mut Arena<IR>,
current_node: Node,
) -> PolarsResult<()> {
// Different from the stack-opt in the optimizer phase, this does a single pass until fixed point per expression.

if let Some(rule) = &mut self.check {
while let Some(x) = rule.optimize_plan(ir_arena, expr_arena, current_node)? {
ir_arena.replace(current_node, x);
}
}

// process the expressions on the stack and apply optimizations.
while let Some(current_expr_node) = self.scratch.pop() {
{
let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
if expr.is_leaf() {
continue;
}
let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };

if expr.is_leaf() {
continue;
}

if let Some(rule) = &mut self.simplify {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
}
}
if let Some(rule) = &mut self.coerce {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
}
Expand Down
48 changes: 48 additions & 0 deletions crates/polars-plan/src/plans/conversion/type_check/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use polars_core::error::{polars_ensure, PolarsResult};
use polars_core::prelude::DataType;
use polars_utils::arena::{Arena, Node};

use super::{AExpr, OptimizationRule, IR};
use crate::plans::conversion::get_schema;
use crate::plans::Context;

pub struct TypeCheckRule;

impl OptimizationRule for TypeCheckRule {
fn optimize_plan(
&mut self,
ir_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
node: Node,
) -> PolarsResult<Option<IR>> {
let ir = ir_arena.get(node);
match ir {
IR::Scan {
predicate: Some(predicate),
..
} => {
let input_schema = get_schema(ir_arena, node);
let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?;

polars_ensure!(
matches!(dtype, DataType::Boolean | DataType::Unknown(_)),
InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`"
);

Ok(None)
},
IR::Filter { predicate, .. } => {
let input_schema = get_schema(ir_arena, node);
let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?;

polars_ensure!(
matches!(dtype, DataType::Boolean | DataType::Unknown(_)),
InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`"
);

Ok(None)
},
_ => Ok(None),
}
}
}
32 changes: 4 additions & 28 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ mod functions;
#[cfg(feature = "is_in")]
mod is_in;

use std::borrow::Cow;

use binary::process_binary;
use polars_compute::cast::temporal::{time_unit_multiple, SECONDS_IN_DAY};
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int};
use polars_utils::idx_vec::UnitVec;
use polars_utils::format_list;
use polars_utils::itertools::Itertools;
use polars_utils::{format_list, unitvec};

use super::*;

Expand Down Expand Up @@ -69,30 +66,6 @@ fn modify_supertype(
st
}

fn get_input(lp_arena: &Arena<IR>, lp_node: Node) -> UnitVec<Node> {
let plan = lp_arena.get(lp_node);
let mut inputs: UnitVec<Node> = unitvec!();

// Used to get the schema of the input.
if is_scan(plan) {
inputs.push(lp_node);
} else {
plan.copy_inputs(&mut inputs);
};
inputs
}

fn get_schema(lp_arena: &Arena<IR>, lp_node: Node) -> Cow<'_, SchemaRef> {
let inputs = get_input(lp_arena, lp_node);
if inputs.is_empty() {
// Files don't have an input, so we must take their schema.
Cow::Borrowed(lp_arena.get(lp_node).scan_schema())
} else {
let input = inputs[0];
lp_arena.get(input).schema(lp_arena)
}
}

fn get_aexpr_and_type<'a>(
expr_arena: &'a Arena<AExpr>,
e: Node,
Expand Down Expand Up @@ -515,6 +488,7 @@ fn cast_expr_ir(
if let AExpr::Literal(lv) = expr_arena.get(e.node()) {
if let Some(literal) = try_inline_literal_cast(lv, to_dtype, strict)? {
e.set_node(expr_arena.add(AExpr::Literal(literal)));
e.set_dtype(to_dtype.clone());
return Ok(());
}
}
Expand All @@ -524,6 +498,8 @@ fn cast_expr_ir(
dtype: to_dtype.clone(),
options: CastOptions::Strict,
}));
e.set_dtype(to_dtype.clone());

Ok(())
}

Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/plans/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ impl ExprIR {
self
}

pub(crate) fn set_dtype(&mut self, dtype: DataType) {
self.output_dtype = OnceLock::from(dtype);
}

pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
let mut out = Self {
node,
Expand Down
Loading
Loading