Skip to content

Commit

Permalink
feat: Add AST analyzer middleware to optionally rewrite queries befor…
Browse files Browse the repository at this point in the history
…e execution (#55)

Co-authored-by: peasee <[email protected]>
  • Loading branch information
phillipleblanc and peasee authored Sep 6, 2024
1 parent cf6c6dd commit ed418c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
8 changes: 7 additions & 1 deletion datafusion-federation/src/sql/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use async_trait::async_trait;
use core::fmt;
use datafusion::{
arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream,
sql::unparser::dialect::Dialect,
sql::sqlparser::ast, sql::unparser::dialect::Dialect,
};
use std::sync::Arc;

pub type SQLExecutorRef = Arc<dyn SQLExecutor>;
pub type AstAnalyzer = Box<dyn Fn(ast::Statement) -> Result<ast::Statement>>;

#[async_trait]
pub trait SQLExecutor: Sync + Send {
Expand All @@ -20,6 +21,11 @@ pub trait SQLExecutor: Sync + Send {
// The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight')
fn dialect(&self) -> Arc<dyn Dialect>;

/// Returns an AST analyzer specific for this engine to modify the AST before execution
fn ast_analyzer(&self) -> Option<AstAnalyzer> {
None
}

// Execution
/// Execute a SQL query
fn execute(&self, query: &str, schema: SchemaRef) -> Result<SendableRecordBatchStream>;
Expand Down
10 changes: 8 additions & 2 deletions datafusion-federation/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,8 +645,13 @@ impl VirtualExecutionPlan {
fn sql(&self) -> Result<String> {
// Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table.
let mut known_rewrites = HashMap::new();
let ast = Unparser::new(self.executor.dialect().as_ref())
let mut ast = Unparser::new(self.executor.dialect().as_ref())
.plan_to_sql(&rewrite_table_scans(&self.plan, &mut known_rewrites)?)?;

if let Some(analyzer) = self.executor.ast_analyzer() {
ast = analyzer(ast)?;
}

Ok(format!("{ast}"))
}
}
Expand All @@ -660,7 +665,8 @@ impl DisplayAs for VirtualExecutionPlan {
write!(f, " name={}", self.executor.name())?;
if let Some(ctx) = self.executor.compute_context() {
write!(f, " compute_context={ctx}")?;
}
};

write!(f, " sql={ast}")?;
if let Ok(query) = self.sql() {
write!(f, " rewritten_sql={query}")?;
Expand Down

0 comments on commit ed418c5

Please sign in to comment.