Add Catalog & more DataFrame traits (#4)
* feat: implement basic catalog

* feat: columns & dtypes
sjrusso8 authored Mar 12, 2024
1 parent 778f7c0 commit d82454a
Showing 7 changed files with 317 additions and 5 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -67,10 +67,10 @@ is not yet working with this Spark Connect implementation.
| range | ![done] | |
| sql | ![done] | Does not include the new Spark Connect 3.5 feature with "position arguments" |
| read | ![done] | |
-| readStream | ![done] | |
+| readStream | ![open] | |
| createDataFrame | ![open] | |
| getActiveSession | ![open] | |
-| catalog | ![open] | |
+| catalog | ![open] | Partial. List/Get functions are implemented |


### DataFrame
@@ -87,6 +87,8 @@ is not yet working with this Spark Connect implementation.
| sample | ![done] | |
| repartition | ![done] | |
| offset | ![done] | |
+| dtypes | ![done] | |
+| columns | ![done] | |
| schema | ![done] | The output needs to be handled better |
| explain | ![done] | The output needs to be handled better |
| show | ![done] | |
@@ -122,7 +124,7 @@ required jars
Spark [Column](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/column.html) type object and its implemented traits


-| DataFrame | API | Comment |
+| Column | API | Comment |
|------------------|---------|------------------------------------------------------------------------------|
| alias | ![done] | |
| asc | ![done] | |
@@ -162,7 +164,7 @@ Spark [Column](https://spark.apache.org/docs/latest/api/python/reference/pyspark

Only a few of the functions are covered by unit tests.

-| DataFrame | API | Comment |
+| Functions | API | Comment |
|-----------------------------|---------|---------|
| abs | ![done] | |
| acos | ![open] | |
216 changes: 216 additions & 0 deletions src/catalog.rs
@@ -0,0 +1,216 @@
//! Spark Catalog representation through which the user may create, drop, alter or query underlying databases, tables, functions, etc.

use arrow::array::RecordBatch;

use crate::plan::LogicalPlanBuilder;
use crate::session::SparkSession;
use crate::spark;

#[derive(Debug, Clone)]
pub struct Catalog {
    spark_session: SparkSession,
}

impl Catalog {
    pub fn new(spark_session: SparkSession) -> Self {
        Self { spark_session }
    }

    /// Returns the current default catalog in this session
    #[allow(non_snake_case)]
    pub async fn currentCatalog(&mut self) -> String {
        let cat_type = Some(spark::catalog::CatType::CurrentCatalog(
            spark::CurrentCatalog {},
        ));

        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });

        let plan = LogicalPlanBuilder::from(rel_type).clone().build_plan_root();

        self.spark_session
            .clone()
            .consume_plan_and_fetch(Some(plan))
            .await
            .unwrap()
    }

    /// Returns the current default database in this session
    #[allow(non_snake_case)]
    pub async fn currentDatabase(&mut self) -> String {
        let cat_type = Some(spark::catalog::CatType::CurrentDatabase(
            spark::CurrentDatabase {},
        ));

        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });

        let plan = LogicalPlanBuilder::from(rel_type).clone().build_plan_root();

        self.spark_session
            .clone()
            .consume_plan_and_fetch(Some(plan))
            .await
            .unwrap()
    }

    /// Returns a list of catalogs in this session
    #[allow(non_snake_case)]
    pub async fn listCatalogs(&mut self, pattern: Option<String>) -> Vec<RecordBatch> {
        let cat_type = Some(spark::catalog::CatType::ListCatalogs(spark::ListCatalogs {
            pattern,
        }));

        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });

        let plan = LogicalPlanBuilder::from(rel_type).clone().build_plan_root();

        self.spark_session
            .clone()
            .consume_plan(Some(plan))
            .await
            .unwrap()
    }

    /// Returns a list of databases in this session
    #[allow(non_snake_case)]
    pub async fn listDatabases(&mut self, pattern: Option<String>) -> Vec<RecordBatch> {
        let cat_type = Some(spark::catalog::CatType::ListDatabases(
            spark::ListDatabases { pattern },
        ));

        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });

        let plan = LogicalPlanBuilder::from(rel_type).clone().build_plan_root();

        self.spark_session
            .clone()
            .consume_plan(Some(plan))
            .await
            .unwrap()
    }

    /// Returns a list of tables/views in the specified database
    #[allow(non_snake_case)]
    pub async fn listTables(
        &mut self,
        dbName: Option<String>,
        pattern: Option<String>,
    ) -> Vec<RecordBatch> {
        let cat_type = Some(spark::catalog::CatType::ListTables(spark::ListTables {
            db_name: dbName,
            pattern,
        }));

        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });

        let plan = LogicalPlanBuilder::from(rel_type).clone().build_plan_root();

        self.spark_session
            .clone()
            .consume_plan(Some(plan))
            .await
            .unwrap()
    }

    /// Returns a list of columns for the given table/view in the specified database
    #[allow(non_snake_case)]
    pub async fn listColumns(
        &mut self,
        tableName: String,
        dbName: Option<String>,
    ) -> Vec<RecordBatch> {
        let cat_type = Some(spark::catalog::CatType::ListColumns(spark::ListColumns {
            table_name: tableName,
            db_name: dbName,
        }));

        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });

        let plan = LogicalPlanBuilder::from(rel_type).clone().build_plan_root();

        self.spark_session
            .clone()
            .consume_plan(Some(plan))
            .await
            .unwrap()
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    use crate::SparkSessionBuilder;

    async fn setup() -> SparkSession {
        println!("SparkSession Setup");

        let connection = "sc://127.0.0.1:15002/;user_id=rust_test".to_string();

        SparkSessionBuilder::remote(connection)
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_current_catalog() {
        let spark = setup().await;

        let value = spark.catalog().currentCatalog().await;

        assert_eq!(value, "spark_catalog".to_string())
    }

    #[tokio::test]
    async fn test_current_database() {
        let spark = setup().await;

        let value = spark.catalog().currentDatabase().await;

        assert_eq!(value, "default".to_string());
    }

    #[tokio::test]
    async fn test_list_catalogs() {
        let spark = setup().await;

        let value = spark.catalog().listCatalogs(None).await;

        assert_eq!(2, value[0].num_columns());
        assert_eq!(1, value[0].num_rows());
    }

    #[tokio::test]
    async fn test_list_databases() {
        let spark = setup().await;

        spark
            .clone()
            .sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
            .await;

        let value = spark.catalog().listDatabases(None).await;

        assert_eq!(4, value[0].num_columns());
        assert_eq!(2, value[0].num_rows());
    }

    #[tokio::test]
    async fn test_list_databases_pattern() {
        let spark = setup().await;

        spark
            .clone()
            .sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
            .await;

        let value = spark
            .catalog()
            .listDatabases(Some("*rust".to_string()))
            .await;

        assert_eq!(4, value[0].num_columns());
        assert_eq!(1, value[0].num_rows());
    }
}
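
For reference, here is a minimal usage sketch of the new `Catalog` methods, mirroring the unit tests above. The crate name `spark_connect_rs`, the `#[tokio::main]` entry point, and a root-level re-export of `SparkSessionBuilder` are assumptions; the connection string copies the test helper and should point at your own Spark Connect server.

```rust
use spark_connect_rs::SparkSessionBuilder;

#[tokio::main]
async fn main() {
    // Connection string copied from the test helper; adjust host/port for your server.
    let spark = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example".to_string())
        .build()
        .await
        .unwrap();

    // currentCatalog / currentDatabase resolve to plain Strings.
    let catalog = spark.clone().catalog().currentCatalog().await;
    let database = spark.clone().catalog().currentDatabase().await;
    println!("catalog = {catalog}, database = {database}");

    // The list* methods return Vec<RecordBatch> (Arrow record batches).
    let databases = spark.clone().catalog().listDatabases(None).await;
    println!("listDatabases returned {} row(s)", databases[0].num_rows());

    // List tables in the current database (no name filter).
    let tables = spark.clone().catalog().listTables(None, None).await;
    println!("listTables returned {} batch(es)", tables.len());
}
```
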
39 changes: 39 additions & 0 deletions src/dataframe.rs
@@ -197,6 +197,45 @@ impl DataFrame {
}
}

    /// Retrieves the names of all columns in the DataFrame as a `Vec<String>`.
    /// The order of the column names in the list reflects their order in the DataFrame.
    pub async fn columns(&mut self) -> Vec<String> {
        let schema = self.schema().await.schema.unwrap();

        let struct_val = schema.kind.unwrap();

        match struct_val {
            spark::data_type::Kind::Struct(val) => val
                .fields
                .iter()
                .map(|field| field.name.to_string())
                .collect(),
            _ => unimplemented!("Unexpected schema response"),
        }
    }

    /// Returns all column names and their data types as a `Vec` of tuples,
    /// pairing each field name (`String`) with its [spark::data_type::Kind]
    pub async fn dtypes(&mut self) -> Vec<(String, Option<spark::data_type::Kind>)> {
        let schema = self.schema().await.schema.unwrap();

        let struct_val = schema.kind.unwrap();

        match struct_val {
            spark::data_type::Kind::Struct(val) => val
                .fields
                .iter()
                .map(|field| {
                    (
                        field.name.to_string(),
                        field.data_type.clone().unwrap().kind,
                    )
                })
                .collect(),
            _ => unimplemented!("Unexpected schema response"),
        }
    }

    /// Prints the [spark::Plan] to the console
    ///
    /// # Arguments:
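
Likewise, a hypothetical sketch of the new `columns` and `dtypes` accessors. The CSV path and reader options are lifted from the integration test in src/lib.rs, and the session setup mirrors the test helper; the crate name and entry point are assumed as in the catalog sketch above.

```rust
use spark_connect_rs::SparkSessionBuilder;

#[tokio::main]
async fn main() {
    let spark = SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example".to_string())
        .build()
        .await
        .unwrap();

    // Path and reader options copied from the test in src/lib.rs.
    let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];

    let mut df = spark
        .read()
        .format("csv")
        .option("header", "True")
        .option("delimiter", ";")
        .load(paths);

    // Vec<String> of column names, in DataFrame order.
    let cols = df.columns().await;
    println!("columns: {cols:?}");

    // Vec<(String, Option<spark::data_type::Kind>)>: each name paired with its type kind.
    for (name, kind) in df.dtypes().await {
        println!("{name}: {kind:?}");
    }
}
```
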
2 changes: 2 additions & 0 deletions src/handler/mod.rs
@@ -9,6 +9,7 @@ use arrow_ipc::reader::StreamReader;
use spark::execute_plan_response::{ArrowBatch, Metrics};
use spark::{DataType, ExecutePlanResponse};

#[derive(Debug, Clone)]
pub struct ResponseHandler {
    pub schema: Option<DataType>,
    pub data: Vec<Option<ArrowBatch>>,
@@ -63,6 +64,7 @@ impl ResponseHandler {
}
}

#[derive(Debug, Clone)]
struct ArrowBatchReader {
    batch: ArrowBatch,
}
25 changes: 25 additions & 0 deletions src/lib.rs
@@ -77,6 +77,7 @@ pub mod plan;
pub mod readwriter;
pub mod session;

mod catalog;
mod client;
pub mod column;
pub mod expressions;
@@ -236,4 +237,28 @@

        assert_eq!(total, 1000)
    }

    #[tokio::test]
    async fn test_dataframe_columns() {
        let spark = setup().await;

        let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];

        let cols = spark
            .read()
            .format("csv")
            .option("header", "True")
            .option("delimiter", ";")
            .load(paths)
            .columns()
            .await;

        let expected = vec![
            String::from("name"),
            String::from("age"),
            String::from("job"),
        ];

        assert_eq!(cols, expected)
    }
}
7 changes: 6 additions & 1 deletion src/plan.rs
@@ -225,7 +225,12 @@ impl LogicalPlanBuilder {
        let order = cols
            .iter()
            .map(|col| {
-               if let ExprType::SortOrder(ord) = col.expression.clone().expr_type.unwrap() {
+               if let ExprType::SortOrder(ord) = col
+                   .expression
+                   .clone()
+                   .expr_type
+                   .expect("provided column set is not sortable")
+               {
                    *ord
                } else {
                    // TODO don't make this a panic but actually raise an error