Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add how argument to join_where to support different join types #19962

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
23 changes: 20 additions & 3 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2158,9 +2158,26 @@ impl JoinBuilder {
}

// Finish with join predicates
pub fn join_where(self, predicates: Vec<Expr>) -> LazyFrame {
pub fn join_where(self, predicates: Vec<Expr>) -> PolarsResult<LazyFrame> {
let mut opt_state = self.lf.opt_state;
let other = self.other.expect("with not set");
let other = self
.other
.ok_or_else(|| polars_err!(oos = "with parameter for join_where not set"))?;

// join_where supports a subset of the full set of join types
if !matches!(
self.how,
JoinType::Inner
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::Semi
| JoinType::Anti
) {
return Err(polars_err!(
oos = format!("Invalid join type '{}' for join_where", self.how)
));
}

// If any of the nodes reads from files we must activate this plan as well.
if other.opt_state.contains(OptFlags::FILE_CACHING) {
Expand Down Expand Up @@ -2249,6 +2266,6 @@ impl JoinBuilder {
options: Arc::from(options),
};

LazyFrame::from_logical_plan(lp, opt_state)
Ok(LazyFrame::from_logical_plan(lp, opt_state))
}
}
44 changes: 44 additions & 0 deletions crates/polars-mem-engine/src/executors/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@ use polars_ops::frame::DataFrameJoinOps;

use super::*;

/// Physical (executable) form of a single extra join predicate.
///
/// Pairs one expression per join side with the operator that relates them.
/// The operator is kept in its plan-level `Operator` form here; the executor
/// converts it when it materializes the predicate inputs.
pub struct JoinPredicateExec {
/// Expression evaluated against the left input frame.
pub left_on: Arc<dyn PhysicalExpr>,
/// Expression evaluated against the right input frame.
pub right_on: Arc<dyn PhysicalExpr>,
/// Operator relating the `left_on` result to the `right_on` result.
pub op: Operator,
}

/// Executor node that joins the outputs of two child executors.
pub struct JoinExec {
// Child executor producing the left frame.
// NOTE(review): stored as `Option`, presumably so it can be taken out during execution — confirm.
input_left: Option<Box<dyn Executor>>,
// Child executor producing the right frame (same `Option` caveat as above).
input_right: Option<Box<dyn Executor>>,
// Join key expressions for the left side.
left_on: Vec<Arc<dyn PhysicalExpr>>,
// Join key expressions for the right side.
right_on: Vec<Arc<dyn PhysicalExpr>>,
// Additional predicates; each side is evaluated on its own input frame and
// the materialized results are handed to the join call alongside the keys.
extra_predicates: Vec<JoinPredicateExec>,
// NOTE(review): presumably controls whether both inputs run in parallel — confirm against `execute`.
parallel: bool,
// Join arguments, cloned and forwarded to the join call.
args: JoinArgs,
}
Expand All @@ -18,6 +25,7 @@ impl JoinExec {
input_right: Box<dyn Executor>,
left_on: Vec<Arc<dyn PhysicalExpr>>,
right_on: Vec<Arc<dyn PhysicalExpr>>,
extra_predicates: Vec<JoinPredicateExec>,
parallel: bool,
args: JoinArgs,
) -> Self {
Expand All @@ -26,6 +34,7 @@ impl JoinExec {
input_right: Some(input_right),
left_on,
right_on,
extra_predicates,
parallel,
args,
}
Expand Down Expand Up @@ -97,6 +106,21 @@ impl Executor for JoinExec {
.map(|e| e.evaluate(&df_right, state))
.collect::<PolarsResult<Vec<_>>>()?;

let extra_predicates_inputs = self
.extra_predicates
.iter()
.map(|ep| {
let left_on = ep.left_on.evaluate(&df_left, state)?.take_materialized_series();
let right_on = ep.right_on.evaluate(&df_right, state)?.take_materialized_series();
let op = operator_to_join_predicate_op(ep.op)?;
Ok(MaterializedJoinPredicate {
left_on,
right_on,
op,
})
})
.collect::<PolarsResult<Vec<_>>>()?;

// prepare the tolerance
// we must ensure that we use the right units
#[cfg(feature = "asof_join")]
Expand Down Expand Up @@ -142,6 +166,7 @@ impl Executor for JoinExec {
left_on_series.into_iter().map(|c| c.take_materialized_series()).collect(),
right_on_series.into_iter().map(|c| c.take_materialized_series()).collect(),
self.args.clone(),
extra_predicates_inputs,
true,
state.verbose(),
);
Expand All @@ -154,3 +179,22 @@ impl Executor for JoinExec {
}, profile_name)
}
}

/// Translates a plan-level `Operator` into the `JoinComparisonOperator`
/// carried by a materialized join predicate.
///
/// # Errors
///
/// Returns a `ComputeError` when `op` is not one of the equality, ordering
/// or logical (`And` / `Or` / `Xor`) operators handled below.
fn operator_to_join_predicate_op(op: Operator) -> PolarsResult<JoinComparisonOperator> {
    let mapped = match op {
        // Equality family.
        Operator::Eq => JoinComparisonOperator::Eq,
        Operator::EqValidity => JoinComparisonOperator::EqValidity,
        Operator::NotEq => JoinComparisonOperator::NotEq,
        Operator::NotEqValidity => JoinComparisonOperator::NotEqValidity,
        // Ordering family.
        Operator::Lt => JoinComparisonOperator::Lt,
        Operator::LtEq => JoinComparisonOperator::LtEq,
        Operator::Gt => JoinComparisonOperator::Gt,
        Operator::GtEq => JoinComparisonOperator::GtEq,
        // Logical family.
        Operator::And => JoinComparisonOperator::And,
        Operator::Or => JoinComparisonOperator::Or,
        Operator::Xor => JoinComparisonOperator::Xor,
        // Everything else has no join-predicate counterpart.
        unsupported => {
            return Err(polars_err!(
                ComputeError: format!("Invalid operator for join predicate: {:?}", unsupported)
            ));
        },
    };

    Ok(mapped)
}
27 changes: 27 additions & 0 deletions crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use polars_plan::plans::expr_ir::ExprIR;

use super::super::executors::{self, Executor};
use super::*;
use crate::executors::JoinPredicateExec;
use crate::utils::*;

fn partitionable_gb(
Expand Down Expand Up @@ -580,6 +581,7 @@ fn create_physical_plan_impl(
left_on,
right_on,
options,
extra_predicates,
..
} => {
let parallel = if options.force_parallel {
Expand Down Expand Up @@ -616,12 +618,37 @@ fn create_physical_plan_impl(
&schema_right,
&mut ExpressionConversionState::new(true, state.expr_depth),
)?;
let extra_predicates = extra_predicates
.into_iter()
.map(|jc| {
let left_on = create_physical_expr(
&jc.left_on,
Context::Default,
expr_arena,
&schema_left,
&mut ExpressionConversionState::new(true, state.expr_depth),
)?;
let right_on = create_physical_expr(
&jc.right_on,
Context::Default,
expr_arena,
&schema_right,
&mut ExpressionConversionState::new(true, state.expr_depth),
)?;
Ok(JoinPredicateExec {
left_on,
right_on,
op: jc.op,
})
})
.collect::<PolarsResult<Vec<_>>>()?;
let options = Arc::try_unwrap(options).unwrap_or_else(|options| (*options).clone());
Ok(Box::new(executors::JoinExec::new(
input_left,
input_right,
left_on,
right_on,
extra_predicates,
parallel,
options.args,
)))
Expand Down
Loading
Loading