Skip to content

Commit

Permalink
feat: add support for producing SQL in multiple dialects (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
sardination authored Mar 4, 2024
1 parent f81967d commit cc69b79
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 34 deletions.
4 changes: 4 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ datafusion-federation-flight-sql.path = "../sources/flight-sql"
connectorx = { git = "https://github.com/sfu-db/connector-x.git", rev = "fa0fc7bc", features = [
"dst_arrow",
"src_sqlite",
"src_postgres",
] }
tonic = "0.10.2"

[dependencies]
async-std = "1.12.0"
70 changes: 70 additions & 0 deletions examples/examples/postgres-partial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::sync::Arc;
use tokio::task;

use datafusion::{
catalog::schema::SchemaProvider,
error::Result,
execution::context::{SessionContext, SessionState},
};
use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule};
use datafusion_federation_sql::connectorx::CXExecutor;
use datafusion_federation_sql::{MultiSchemaProvider, SQLFederationProvider, SQLSchemaProvider};

/// Example: plan and print a join that spans two PostgreSQL-backed schema providers.
///
/// The session state gets the federation analyzer rule and the federated query planner
/// installed. Two providers are then created (table `class` with compute context `conn1`,
/// table `teacher` with compute context `conn2`), merged through a `MultiSchemaProvider`,
/// and registered as the default schema before the join query is planned and shown.
///
/// Provider setup, planning and `show()` are driven with `async_std::task::block_on` inside
/// `tokio::task::spawn_blocking` instead of being awaited directly.
/// NOTE(review): presumably this keeps the blocking connectorx calls off the async worker
/// threads — confirm before changing it.
///
/// # Errors
/// Returns the error produced while creating a provider, registering the schema, planning
/// the query, or showing the result.
///
/// # Panics
/// Panics only if one of the blocking tasks itself panics or is cancelled.
#[tokio::main]
async fn main() -> Result<()> {
    // Register FederationAnalyzer
    // TODO: Interaction with other analyzers & optimizers.
    let state = SessionContext::new()
        .state()
        .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new()))
        .with_query_planner(Arc::new(FederatedQueryPlanner::new()));

    let df = task::spawn_blocking(move || -> Result<_> {
        // Register schema
        let pg_provider_1 =
            async_std::task::block_on(create_postgres_provider(vec!["class"], "conn1"))?;
        let pg_provider_2 =
            async_std::task::block_on(create_postgres_provider(vec!["teacher"], "conn2"))?;
        let provider = MultiSchemaProvider::new(vec![pg_provider_1, pg_provider_2]);

        overwrite_default_schema(&state, Arc::new(provider))?;

        // Run query
        let ctx = SessionContext::new_with_state(state);
        let query = r#"SELECT class.name AS classname, teacher.name AS teachername FROM class JOIN teacher ON class.id = teacher.class_id"#;
        async_std::task::block_on(ctx.sql(query))
    })
    .await
    .expect("blocking planning task panicked or was cancelled")?;

    task::spawn_blocking(move || async_std::task::block_on(df.show()))
        .await
        .expect("blocking execution task panicked or was cancelled")
}

/// Builds a `SQLSchemaProvider` backed by a connectorx executor for PostgreSQL.
///
/// * `known_tables` - names of the tables the provider is created with.
/// * `context` - value handed to `CXExecutor::context` for this connection.
///
/// # Errors
/// Returns the error from `CXExecutor::new` or from `SQLSchemaProvider::new_with_tables`.
async fn create_postgres_provider(
    known_tables: Vec<&str>,
    context: &str,
) -> Result<Arc<SQLSchemaProvider>> {
    // Placeholder DSN: fill in real credentials, port and database name before running.
    let dsn = "postgresql://<username>:<password>@localhost:<port>/<dbname>".to_string();

    let mut cx_executor = CXExecutor::new(dsn)?;
    cx_executor.context(context.to_string());
    let federation = Arc::new(SQLFederationProvider::new(Arc::new(cx_executor)));

    let table_names: Vec<String> = known_tables.into_iter().map(String::from).collect();
    let schema = SQLSchemaProvider::new_with_tables(federation, table_names).await?;
    Ok(Arc::new(schema))
}

fn overwrite_default_schema(state: &SessionState, schema: Arc<dyn SchemaProvider>) -> Result<()> {
let options = &state.config().options().catalog;
let catalog = state
.catalog_list()
.catalog(options.default_catalog.as_str())
.unwrap();

catalog.register_schema(options.default_schema.as_str(), schema)?;

Ok(())
}
5 changes: 5 additions & 0 deletions sources/flight-sql/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use async_trait::async_trait;
use datafusion::{
error::{DataFusionError, Result},
physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream},
sql::sqlparser::dialect::{Dialect, GenericDialect},
};
use datafusion_federation_sql::SQLExecutor;
use futures::TryStreamExt;
Expand Down Expand Up @@ -93,6 +94,10 @@ impl SQLExecutor for FlightSQLExecutor {
let schema = flight_info.try_decode_schema().map_err(arrow_error_to_df)?;
Ok(Arc::new(schema))
}

/// Returns the SQL dialect for this executor: always the generic dialect.
fn dialect(&self) -> Arc<dyn Dialect> {
    Arc::new(GenericDialect {}) as Arc<dyn Dialect>
}
}

fn arrow_error_to_df(err: ArrowError) -> DataFusionError {
Expand Down
18 changes: 16 additions & 2 deletions sources/sql/src/connectorx/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ use async_trait::async_trait;
use connectorx::{
destinations::arrow::ArrowDestinationError,
errors::{ConnectorXError, ConnectorXOutError},
prelude::{get_arrow, CXQuery, SourceConn},
prelude::{get_arrow, CXQuery, SourceConn, SourceType},
};
use datafusion::{
arrow::datatypes::{Field, Schema, SchemaRef},
error::{DataFusionError, Result},
physical_plan::{
stream::RecordBatchStreamAdapter, EmptyRecordBatchStream, SendableRecordBatchStream,
},
sql::sqlparser::dialect::{Dialect, GenericDialect, PostgreSqlDialect, SQLiteDialect},
};
use futures::executor::block_on;
use std::sync::Arc;
use tokio::task;

use crate::executor::SQLExecutor;

Expand Down Expand Up @@ -54,7 +57,10 @@ impl SQLExecutor for CXExecutor {
let conn = self.conn.clone();
let query: CXQuery = sql.into();

let mut dst = get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df)?;
let mut dst = block_on(task::spawn_blocking(move || -> Result<_, _> {
get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df)
}))
.map_err(|err| DataFusionError::External(err.to_string().into()))??;
let stream = if let Some(batch) = dst.record_batch().map_err(cx_dst_error_to_df)? {
futures::stream::once(async move { Ok(batch) })
} else {
Expand Down Expand Up @@ -84,6 +90,14 @@ impl SQLExecutor for CXExecutor {
let schema = schema_to_lowercase(dst.arrow_schema());
Ok(schema)
}

/// Picks the SQL dialect that matches the connection's source type.
///
/// Postgres and SQLite connections get their dedicated dialects; every other
/// source type falls back to the generic dialect.
fn dialect(&self) -> Arc<dyn Dialect> {
    let source_type = &self.conn.ty;
    if matches!(source_type, SourceType::Postgres) {
        return Arc::new(PostgreSqlDialect {});
    }
    if matches!(source_type, SourceType::SQLite) {
        return Arc::new(SQLiteDialect {});
    }
    Arc::new(GenericDialect {})
}
}

fn cx_dst_error_to_df(err: ArrowDestinationError) -> DataFusionError {
Expand Down
4 changes: 4 additions & 0 deletions sources/sql/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use async_trait::async_trait;
use core::fmt;
use datafusion::{
arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream,
sql::sqlparser::dialect::Dialect,
};
use std::sync::Arc;

Expand All @@ -16,6 +17,9 @@ pub trait SQLExecutor: Sync + Send {
/// such as authorization or active database.
fn compute_context(&self) -> Option<String>;

/// Returns the SQL dialect of the engine behind this executor
/// (for example a PostgreSQL, SQLite, or generic dialect).
///
/// The dialect is handed to the plan-to-SQL conversion so the generated
/// query text matches what the remote engine accepts.
fn dialect(&self) -> Arc<dyn Dialect>;

// Execution
/// Execute a SQL query
fn execute(&self, query: &str, schema: SchemaRef) -> Result<SendableRecordBatchStream>;
Expand Down
2 changes: 1 addition & 1 deletion sources/sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl ExecutionPlan for VirtualExecutionPlan {
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let ast = query_to_sql(&self.plan)?;
let ast = query_to_sql(&self.plan, self.executor.dialect())?;
let query = format!("{ast}");

self.executor.execute(query.as_str(), self.schema())
Expand Down
Loading

0 comments on commit cc69b79

Please sign in to comment.