
feat: implement catalog/schema/table commands on the service (#52)
nathanielc authored Sep 3, 2024
1 parent 003a5dc commit cf6c6dd
Showing 2 changed files with 163 additions and 26 deletions.
1 change: 1 addition & 0 deletions datafusion-flight-sql-server/Cargo.toml
@@ -27,6 +27,7 @@ prost = "0.12.3"
 arrow = "52.0.0"
 arrow-flight = { version = "52.2.0", features = ["flight-sql-experimental"] }
 log = "0.4.22"
+once_cell = "1.19.0"
 
 [dev-dependencies]
 tokio = { version = "1.39.3", features = ["full"] }
188 changes: 162 additions & 26 deletions datafusion-flight-sql-server/src/service.rs
@@ -1,6 +1,11 @@
 use std::{pin::Pin, sync::Arc};
 
-use arrow::{datatypes::SchemaRef, error::ArrowError, ipc::writer::IpcWriteOptions};
+use arrow::{
+    array::{ArrayRef, RecordBatch, StringArray},
+    datatypes::{DataType, Field, SchemaRef},
+    error::ArrowError,
+    ipc::writer::IpcWriteOptions,
+};
 use arrow_flight::sql::{
     self,
     server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
@@ -13,7 +18,7 @@ use arrow_flight::sql::{
     CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
     CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
     CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
-    DoPutPreparedStatementResult, SqlInfo, TicketStatementQuery,
+    DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo, TicketStatementQuery,
 };
 use arrow_flight::{
     encode::FlightDataEncoderBuilder,
@@ -25,6 +30,7 @@ use arrow_flight::{
 use datafusion::{
     common::arrow::datatypes::Schema,
     dataframe::DataFrame,
+    datasource::TableType,
     error::{DataFusionError, Result as DataFusionResult},
     execution::context::{SQLOptions, SessionContext, SessionState},
     logical_expr::LogicalPlan,
@@ -35,7 +41,9 @@ use datafusion_substrait::{
 };
 use futures::{Stream, StreamExt, TryStreamExt};
 use log::info;
+use once_cell::sync::Lazy;
 use prost::bytes::Bytes;
+use prost::Message;
 use tonic::transport::Server;
 use tonic::{Request, Response, Status, Streaming};
 
@@ -92,6 +100,16 @@ impl FlightSqlService {
     }
 }
 
+/// The schema for GetTableTypes
+static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
+    //TODO: Move this into arrow-flight itself, similar to the builder pattern for CommandGetCatalogs and CommandGetDbSchemas
+    Arc::new(Schema::new(vec![Field::new(
+        "table_type",
+        DataType::Utf8,
+        false,
+    )]))
+});
+
 struct FlightSqlSessionContext {
     inner: SessionContext,
 }
@@ -334,48 +352,93 @@ impl ArrowFlightSqlService for FlightSqlService {
 
     async fn get_flight_info_catalogs(
         &self,
-        _query: CommandGetCatalogs,
+        query: CommandGetCatalogs,
         request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>> {
         info!("get_flight_info_catalogs");
-        let (_, _) = self.new_context(request)?;
+        let (request, _ctx) = self.new_context(request)?;
 
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
-        Err(Status::unimplemented("Implement get_flight_info_catalogs"))
+        let flight_info = FlightInfo::new()
+            .try_with_schema(&query.into_builder().schema())
+            .map_err(arrow_error_to_status)?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
+
+        Ok(Response::new(flight_info))
     }
 
     async fn get_flight_info_schemas(
         &self,
-        _query: CommandGetDbSchemas,
+        query: CommandGetDbSchemas,
         request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>> {
         info!("get_flight_info_schemas");
-        let (_, _) = self.new_context(request)?;
+        let (request, _ctx) = self.new_context(request)?;
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
+        let flight_info = FlightInfo::new()
+            .try_with_schema(&query.into_builder().schema())
+            .map_err(arrow_error_to_status)?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
 
-        Err(Status::unimplemented("Implement get_flight_info_schemas"))
+        Ok(Response::new(flight_info))
     }
 
     async fn get_flight_info_tables(
         &self,
-        _query: CommandGetTables,
+        query: CommandGetTables,
         request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>> {
         info!("get_flight_info_tables");
-        let (_, _) = self.new_context(request)?;
+        let (request, _ctx) = self.new_context(request)?;
+
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
+        let flight_info = FlightInfo::new()
+            .try_with_schema(&query.into_builder().schema())
+            .map_err(arrow_error_to_status)?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
 
-        Err(Status::unimplemented("Implement get_flight_info_tables"))
+        Ok(Response::new(flight_info))
     }
 
     async fn get_flight_info_table_types(
         &self,
-        _query: CommandGetTableTypes,
+        query: CommandGetTableTypes,
         request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>> {
         info!("get_flight_info_table_types");
-        let (_, _) = self.new_context(request)?;
+        let (request, _ctx) = self.new_context(request)?;
 
-        Err(Status::unimplemented(
-            "Implement get_flight_info_table_types",
-        ))
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
+        let flight_info = FlightInfo::new()
+            .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
+            .map_err(arrow_error_to_status)?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
+
+        Ok(Response::new(flight_info))
     }
 
     async fn get_flight_info_sql_info(
@@ -478,35 +541,94 @@ impl ArrowFlightSqlService for FlightSqlService {
 
     async fn do_get_catalogs(
         &self,
-        _query: CommandGetCatalogs,
+        query: CommandGetCatalogs,
         request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
         info!("do_get_catalogs");
-        let (_, _) = self.new_context(request)?;
+        let (_request, ctx) = self.new_context(request)?;
+        let catalog_names = ctx.inner.catalog_names();
 
-        Err(Status::unimplemented("Implement do_get_catalogs"))
+        let mut builder = query.into_builder();
+        for catalog_name in &catalog_names {
+            builder.append(catalog_name);
+        }
+        let schema = builder.schema();
+        let batch = builder.build();
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(schema)
+            .build(futures::stream::once(async { batch }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_schemas(
         &self,
-        _query: CommandGetDbSchemas,
+        query: CommandGetDbSchemas,
         request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
         info!("do_get_schemas");
-        let (_, _) = self.new_context(request)?;
+        let (_request, ctx) = self.new_context(request)?;
+        let catalog_name = query.catalog.clone();
+        // Append all schemas to builder, the builder handles applying the filters.
+        let mut builder = query.into_builder();
+        if let Some(catalog_name) = &catalog_name {
+            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
+                for schema_name in &catalog.schema_names() {
+                    builder.append(catalog_name, schema_name);
+                }
+            }
+        };
 
-        Err(Status::unimplemented("Implement do_get_schemas"))
+        let schema = builder.schema();
+        let batch = builder.build();
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(schema)
+            .build(futures::stream::once(async { batch }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_tables(
         &self,
-        _query: CommandGetTables,
+        query: CommandGetTables,
         request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
         info!("do_get_tables");
-        let (_, _) = self.new_context(request)?;
+        let (_request, ctx) = self.new_context(request)?;
+        let catalog_name = query.catalog.clone();
+        let mut builder = query.into_builder();
+        // Append all schemas/tables to builder, the builder handles applying the filters.
+        if let Some(catalog_name) = &catalog_name {
+            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
+                for schema_name in &catalog.schema_names() {
+                    if let Some(schema) = catalog.schema(schema_name) {
+                        for table_name in &schema.table_names() {
+                            if let Some(table) =
+                                schema.table(table_name).await.map_err(df_error_to_status)?
+                            {
+                                builder
+                                    .append(
+                                        catalog_name,
+                                        schema_name,
+                                        table_name,
+                                        table.table_type().to_string(),
+                                        &table.schema(),
+                                    )
+                                    .map_err(flight_error_to_status)?;
+                            }
+                        }
+                    }
+                }
+            }
+        };
 
-        Err(Status::unimplemented("Implement do_get_tables"))
+        let schema = builder.schema();
+        let batch = builder.build();
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(schema)
+            .build(futures::stream::once(async { batch }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_table_types(
@@ -517,7 +639,21 @@ impl ArrowFlightSqlService for FlightSqlService {
         info!("do_get_table_types");
         let (_, _) = self.new_context(request)?;
 
-        Err(Status::unimplemented("Implement do_get_table_types"))
+        // Report all variants of table types that datafusion uses.
+        let table_types: ArrayRef = Arc::new(StringArray::from(
+            vec![TableType::Base, TableType::View, TableType::Temporary]
+                .into_iter()
+                .map(|tt| tt.to_string())
+                .collect::<Vec<String>>(),
+        ));
+
+        let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
+
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
+            .build(futures::stream::once(async { Ok(batch) }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_sql_info(
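For anyone trying out these endpoints, here is a minimal, hedged client-side sketch of the GetFlightInfo/DoGet round trip the handlers above implement. It is not part of this commit: the address is made up, and it assumes the FlightSqlServiceClient behind arrow-flight's flight-sql-experimental feature, whose method names may vary slightly between versions.

// Hedged sketch, not part of the commit: a client walking the new catalog
// metadata endpoint. The address below is hypothetical; tokio, futures, and
// the flight-sql-experimental feature of arrow-flight are assumed.
use arrow_flight::sql::client::FlightSqlServiceClient;
use futures::TryStreamExt;
use tonic::transport::Channel;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let channel = Channel::from_static("http://localhost:50051")
        .connect()
        .await?;
    let mut client = FlightSqlServiceClient::new(channel);

    // GetFlightInfo: the server answers with a FlightInfo whose endpoint
    // carries a ticket wrapping the original CommandGetCatalogs message.
    let info = client.get_catalogs().await?;
    let ticket = info.endpoint[0]
        .ticket
        .clone()
        .ok_or("endpoint returned without a ticket")?;

    // DoGet: the server decodes the ticket and streams the catalog names
    // produced by the CommandGetCatalogs builder.
    let batches: Vec<_> = client.do_get(ticket).await?.try_collect().await?;
    for batch in &batches {
        println!("{batch:?}");
    }
    Ok(())
}

The same pattern should apply to the schema, table, and table-type commands: each GetFlightInfo handler re-encodes its command into the ticket it returns, and the matching do_get_* handler decodes it and streams a record batch built from the SessionContext's catalogs.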
