feat(errors): better error handling
- create a custom `SparkError` error enum
- update `SparkSession` methods to return the new error type
- update the `.sql` command to return a `Result` (caller-side sketch below)
sjrusso8 committed Mar 13, 2024
1 parent d82454a commit 635c321
Showing 6 changed files with 144 additions and 33 deletions.
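
As a caller-side illustration (not part of the diff): a rough sketch of using the now-fallible `sql`, assuming the crate's tokio runtime and builder API; the import paths, the `SparkSessionBuilder::remote` entry point, and the connection string are assumptions, not taken from this commit.

use spark_connect_rs::session::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // hypothetical connection string; any builder entry point works
    let mut spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002/")
        .build()
        .await?;

    // `sql` now returns Result<DataFrame, SparkError>: propagate with `?` ...
    let mut df = spark.sql("SELECT 1 AS id").await?;
    df.show(Some(1), None, None).await?;

    // ... or handle the failure explicitly via the error's Display impl.
    if let Err(e) = spark.sql("SELECT * FROM missing_table").await {
        eprintln!("query failed: {e}");
    }

    Ok(())
}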
2 changes: 1 addition & 1 deletion examples/sql.rs
@@ -14,7 +14,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let mut df = spark
.sql("SELECT * FROM json.`/opt/spark/examples/src/main/resources/employees.json`")
.await;
.await?;

df.filter("salary >= 3500")
.select("*")
6 changes: 4 additions & 2 deletions src/catalog.rs
@@ -188,7 +188,8 @@ mod tests {
spark
.clone()
.sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
.await;
.await
.unwrap();

let value = spark.catalog().listDatabases(None).await;

@@ -203,7 +204,8 @@
spark
.clone()
.sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
.await;
.await
.unwrap();

let value = spark
.catalog()
6 changes: 4 additions & 2 deletions src/dataframe.rs
@@ -3,6 +3,7 @@
use std::collections::HashMap;

use crate::column::Column;
use crate::errors::SparkError;
use crate::expressions::{ToFilterExpr, ToVecExpr};
use crate::plan::LogicalPlanBuilder;
pub use crate::readwriter::{DataFrameReader, DataFrameWriter};
@@ -334,7 +335,7 @@ impl DataFrame {
num_rows: Option<i32>,
truncate: Option<i32>,
vertical: Option<bool>,
) -> Result<(), ArrowError> {
) -> Result<(), SparkError> {
let show_expr = RelType::ShowString(Box::new(spark::ShowString {
input: self.logical_plan.clone().relation_input(),
num_rows: num_rows.unwrap_or(10),
@@ -346,7 +347,8 @@

let rows = self.spark_session.consume_plan(Some(plan)).await?;

pretty::print_batches(rows.as_slice())
pretty::print_batches(rows.as_slice())?;
Ok(())
}

/// Returns the last `n` rows as vector of [RecordBatch]
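
The trailing `?` on `pretty::print_batches` works because of the `From<ArrowError> for SparkError` impl added in src/errors.rs below. A minimal sketch of that conversion, assuming the same imports as dataframe.rs (the helper name is made up):

use arrow::record_batch::RecordBatch;
use arrow::util::pretty;

// Hypothetical helper: any ArrowError from print_batches is converted into
// SparkError::ArrowError automatically by the `?` operator.
fn print_all(batches: &[RecordBatch]) -> Result<(), SparkError> {
    pretty::print_batches(batches)?;
    Ok(())
}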
114 changes: 114 additions & 0 deletions src/errors.rs
@@ -0,0 +1,114 @@
//! Defines `SparkError` for representing failures in various Spark operations.
//! Most variants wrap tonic or arrow error messages.
use std::fmt::{Debug, Display, Formatter};
use std::io::Write;

use std::error::Error;

use arrow::error::ArrowError;

/// Many different operations in the `Spark` crate return this error type.
#[derive(Debug)]
pub enum SparkError {
/// Returned when functionality is not yet available.
NotYetImplemented(String),
/// Wraps an error from an external source.
ExternalError(Box<dyn Error + Send + Sync>),
/// Returned for analysis failures; also wraps tonic status and UTF-8 decoding errors.
AnalysisException(String),
/// Wraps a `std::io::Error` along with its message.
IoError(String, std::io::Error),
/// Wraps an error from Apache Arrow.
ArrowError(ArrowError),
}

impl SparkError {
/// Wraps an external error in a `SparkError`.
pub fn from_external_error(error: Box<dyn Error + Send + Sync>) -> Self {
Self::ExternalError(error)
}
}

impl From<std::io::Error> for SparkError {
fn from(error: std::io::Error) -> Self {
SparkError::IoError(error.to_string(), error)
}
}

impl From<std::str::Utf8Error> for SparkError {
fn from(error: std::str::Utf8Error) -> Self {
SparkError::AnalysisException(error.to_string())
}
}

impl From<std::string::FromUtf8Error> for SparkError {
fn from(error: std::string::FromUtf8Error) -> Self {
SparkError::AnalysisException(error.to_string())
}
}

impl From<ArrowError> for SparkError {
fn from(error: ArrowError) -> Self {
SparkError::ArrowError(error)
}
}

impl From<tonic::Status> for SparkError {
fn from(status: tonic::Status) -> Self {
SparkError::AnalysisException(status.message().to_string())
}
}

impl<W: Write> From<std::io::IntoInnerError<W>> for SparkError {
fn from(error: std::io::IntoInnerError<W>) -> Self {
SparkError::IoError(error.to_string(), error.into())
}
}

impl Display for SparkError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
SparkError::ExternalError(source) => write!(f, "External error: {}", &source),
SparkError::AnalysisException(desc) => write!(f, "Analysis error: {desc}"),
SparkError::IoError(desc, _) => write!(f, "Io error: {desc}"),
SparkError::ArrowError(desc) => write!(f, "Apache Arrow error: {desc}"),
SparkError::NotYetImplemented(source) => write!(f, "Not yet implemented: {source}"),
}
}
}

impl Error for SparkError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
if let Self::ExternalError(e) = self {
Some(e.as_ref())
} else {
None
}
}
}

// #[cfg(test)]
// mod test {
// use super::*;
//
// #[test]
// fn error_source() {
// let e1 = SparkError::DivideByZero;
// assert!(e1.source().is_none());
//
// // one level of wrapping
// let e2 = SparkError::ExternalError(Box::new(e1));
// let source = e2.source().unwrap().downcast_ref::<SparkError>().unwrap();
// assert!(matches!(source, SparkError::DivideByZero));
//
// // two levels of wrapping
// let e3 = SparkError::ExternalError(Box::new(e2));
// let source = e3
// .source()
// .unwrap()
// .downcast_ref::<SparkError>()
// .unwrap()
// .source()
// .unwrap()
// .downcast_ref::<SparkError>()
// .unwrap();
//
// assert!(matches!(source, SparkError::DivideByZero));
// }
// }
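
The commented-out test above references a `DivideByZero` variant that this enum does not define; a sketch of the same source-chain check adapted to the variants that do exist (illustrative only, not part of the commit):

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn error_source() {
        // a leaf error with no underlying source
        let e1 = SparkError::NotYetImplemented("df.explain".to_string());
        assert!(e1.source().is_none());

        // one level of wrapping through ExternalError
        let e2 = SparkError::ExternalError(Box::new(e1));
        let source = e2.source().unwrap().downcast_ref::<SparkError>().unwrap();
        assert!(matches!(source, SparkError::NotYetImplemented(_)));
    }
}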
3 changes: 2 additions & 1 deletion src/lib.rs
@@ -18,7 +18,7 @@
//! .build()
//! .await?;
//!
//! let mut df = spark.sql("SELECT * FROM json.`/opt/spark/examples/src/main/resources/employees.json`");
//! let mut df = spark.sql("SELECT * FROM json.`/opt/spark/examples/src/main/resources/employees.json`").await?;
//!
//! df.filter("salary > 3000").show(Some(5), None, None).await?;
//!
@@ -80,6 +80,7 @@ pub mod session;
mod catalog;
mod client;
pub mod column;
mod errors;
pub mod expressions;
pub mod functions;
mod handler;
46 changes: 19 additions & 27 deletions src/session.rs
@@ -7,11 +7,11 @@ use std::sync::Arc;
use crate::catalog::Catalog;
pub use crate::client::SparkSessionBuilder;
use crate::dataframe::{DataFrame, DataFrameReader};
use crate::errors::SparkError;
use crate::handler::ResponseHandler;
use crate::plan::LogicalPlanBuilder;
use crate::spark;

use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

use spark::spark_connect_service_client::SparkConnectServiceClient;
@@ -82,7 +82,11 @@ impl SparkSession {
}

/// Returns a [DataFrame] representing the result of the given query
pub async fn sql(&mut self, sql_query: &str) -> DataFrame {
pub async fn sql(&mut self, sql_query: &str) -> Result<DataFrame, SparkError> {
let error_msg = SparkError::AnalysisException(
"Failed to get command response from Spark Connect Server".to_string(),
);

let sql_cmd = spark::command::CommandType::SqlCommand(spark::SqlCommand {
sql: sql_query.to_string(),
args: HashMap::default(),
@@ -91,23 +95,16 @@

let plan = LogicalPlanBuilder::build_plan_cmd(sql_cmd);

// !TODO this is gross and needs to be handled WAY better
let resp = self
.execute_plan(Some(plan))
.await
.unwrap()
.message()
.await
.unwrap()
.unwrap();
let resp = self.execute_plan(Some(plan)).await?.message().await?;

match resp.response_type {
match resp.ok_or(error_msg)?.response_type {
Some(spark::execute_plan_response::ResponseType::SqlCommandResult(sql_result)) => {
let logical_plan = LogicalPlanBuilder::new(sql_result.relation.unwrap());
DataFrame::new(self.clone(), logical_plan)
Ok(DataFrame::new(self.clone(), logical_plan))
}
Some(_) => todo!("not implemented"),
None => todo!("got none as a response for SQL Command"),
_ => Err(SparkError::NotYetImplemented(
"Response type not implemented".to_string(),
)),
}
}

@@ -146,7 +143,7 @@
pub async fn execute_plan(
&mut self,
plan: Option<spark::Plan>,
) -> Result<Streaming<ExecutePlanResponse>, tonic::Status> {
) -> Result<Streaming<ExecutePlanResponse>, SparkError> {
let exc_plan = self.build_execute_plan_request(plan);

let mut client = self.client.lock().await;
@@ -163,25 +160,20 @@
pub async fn consume_plan(
&mut self,
plan: Option<spark::Plan>,
) -> Result<Vec<RecordBatch>, ArrowError> {
let mut stream = self.execute_plan(plan).await.map_err(|err| {
ArrowError::IoError(
err.to_string(),
Error::new(std::io::ErrorKind::Other, err.to_string()),
)
})?;
) -> Result<Vec<RecordBatch>, SparkError> {
let mut stream = self.execute_plan(plan).await?;

let mut handler = ResponseHandler::new();

while let Some(resp) = stream.message().await.map_err(|err| {
ArrowError::IoError(
SparkError::IoError(
err.to_string(),
Error::new(std::io::ErrorKind::Other, err.to_string()),
)
})? {
let _ = handler.handle_response(&resp);
}
handler.records()
Ok(handler.records().unwrap())
}
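
Since src/errors.rs also defines `From<tonic::Status>` and `From<ArrowError>` impls, the `map_err` and the final `unwrap` above could arguably be replaced by plain `?`; a sketch of that alternative (not what this commit does, and assuming `handler.records()` still returns a `Result` over `ArrowError`, as the old signature suggests):

// Alternative body for the loop and return value in consume_plan, relying on
// the From impls in src/errors.rs to convert tonic::Status and ArrowError.
while let Some(resp) = stream.message().await? {
    let _ = handler.handle_response(&resp);
}
Ok(handler.records()?)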

pub async fn analyze_plan(
@@ -200,13 +192,13 @@
let result = self
.consume_plan(plan)
.await
.expect("failed to get a result from spark connect");
.expect("Failed to get a result from Spark Connect");

let col = result[0].column(0);

let data: &arrow::array::StringArray = match col.data_type() {
arrow::datatypes::DataType::Utf8 => col.as_any().downcast_ref().unwrap(),
_ => unimplemented!("only Utf8 data types are currently handled."),
_ => unimplemented!("only Utf8 data types are currently handled currently."),
};

Some(data.value(0).to_string())
