From 386a4309763c5aeac6fc0c574b10029c5fdc061c Mon Sep 17 00:00:00 2001 From: Steve Russo <64294847+sjrusso8@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:04:26 -0400 Subject: [PATCH] feat(dataframe): implement missing methods (#67) * feat(dataframe): stat & na methods * update README * update rust docs * flaky test --- README.md | 45 ++- core/src/dataframe.rs | 643 +++++++++++++++++++++++++++++++++++++++--- core/src/plan.rs | 212 ++++++++++++-- core/src/window.rs | 6 +- 4 files changed, 815 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index fb6abca..dffc9f1 100644 --- a/README.md +++ b/README.md @@ -16,11 +16,11 @@ The `spark-connect-rs` aims to provide an entrypoint to [Spark Connect](https:// ## Project Layout ``` -├── core <- core implementation in Rust -│ └─ spark <- git submodule for apache/spark -├── rust <- shim for 'spark-connect-rs' from core -├── examples <- examples of using different aspects of the crate -├── datasets <- sample files from the main spark repo +├── core <- core implementation in Rust +│ └─ protobuf <- connect protobuf for apache/spark +├── rust <- shim for 'spark-connect-rs' from core +├── examples <- examples of using different aspects of the crate +├── datasets <- sample files from the main spark repo ``` Future state would be to have additional bindings for other languages along side the top level `rust` folder. @@ -37,7 +37,6 @@ This section explains how run Spark Connect Rust locally starting from 0. ```bash git clone https://github.com/sjrusso8/spark-connect-rs.git -git submodule update --init --recursive cargo build ``` @@ -266,13 +265,13 @@ Spark [DataFrame](https://spark.apache.org/docs/latest/api/python/reference/pysp |-------------------------------|---------|------------------------------------------------------------| | agg | ![done] | | | alias | ![done] | | -| approxQuantile | ![open] | | +| approxQuantile | ![done] | | | cache | ![done] | | -| checkpoint | ![open] | | +| checkpoint | ![open] | Not part of Spark Connect | | coalesce | ![done] | | | colRegex | ![done] | | | collect | ![done] | | -| columns | ![done] | | +| columns | ![done] | | | corr | ![done] | | | count | ![done] | | | cov | ![done] | | @@ -287,13 +286,13 @@ Spark [DataFrame](https://spark.apache.org/docs/latest/api/python/reference/pysp | distinct | ![done] | | | drop | ![done] | | | dropDuplicates | ![done] | | -| dropDuplicatesWithinWatermark | ![open] | Windowing functions are currently in progress | +| dropDuplicatesWithinWatermark | ![done] | | | drop_duplicates | ![done] | | | dropna | ![done] | | | dtypes | ![done] | | | exceptAll | ![done] | | | explain | ![done] | | -| fillna | ![open] | | +| fillna | ![done] | | | filter | ![done] | | | first | ![done] | | | foreach | ![open] | | @@ -306,29 +305,29 @@ Spark [DataFrame](https://spark.apache.org/docs/latest/api/python/reference/pysp | intersect | ![done] | | | intersectAll | ![done] | | | isEmpty | ![done] | | -| isLocal | ![open] | | +| isLocal | ![done] | | | isStreaming | ![done] | | | join | ![done] | | | limit | ![done] | | -| localCheckpoint | ![open] | | +| localCheckpoint | ![open] | Not part of Spark Connect | | mapInPandas | ![open] | TBD on this exact implementation | | mapInArrow | ![open] | TBD on this exact implementation | | melt | ![done] | | -| na | ![open] | | +| na | ![done] | | | observe | ![open] | | | offset | ![done] | | | orderBy | ![done] | | | persist | ![done] | | | printSchema | ![done] | | -| randomSplit | ![open] | | -| registerTempTable | ![open] | | +| 
randomSplit | ![done] | | +| registerTempTable | ![done] | | | repartition | ![done] | | -| repartitionByRange | ![open] | | -| replace | ![open] | | +| repartitionByRange | ![done] | | +| replace | ![done] | | | rollup | ![done] | | | sameSemantics | ![done] | | | sample | ![done] | | -| sampleBy | ![open] | | +| sampleBy | ![done] | | | schema | ![done] | | | select | ![done] | | | selectExpr | ![done] | | @@ -340,7 +339,7 @@ Spark [DataFrame](https://spark.apache.org/docs/latest/api/python/reference/pysp | stat | ![done] | | | storageLevel | ![done] | | | subtract | ![done] | | -| summary | ![open] | | +| summary | ![done] | | | tail | ![done] | | | take | ![done] | | | to | ![done] | | @@ -358,10 +357,10 @@ Spark [DataFrame](https://spark.apache.org/docs/latest/api/python/reference/pysp | where | ![done] | use `filter` instead, `where` is a keyword for rust | | withColumn | ![done] | | | withColumns | ![done] | | -| withColumnRenamed | ![open] | | +| withColumnRenamed | ![done] | | | withColumnsRenamed | ![done] | | -| withMetadata | ![open] | | -| withWatermark | ![open] | | +| withMetadata | ![done] | | +| withWatermark | ![done] | | | write | ![done] | | | writeStream | ![done] | | | writeTo | ![done] | | diff --git a/core/src/dataframe.rs b/core/src/dataframe.rs index 28df6a2..09c1cd3 100644 --- a/core/src/dataframe.rs +++ b/core/src/dataframe.rs @@ -2,7 +2,7 @@ use crate::column::Column; use crate::errors::SparkError; -use crate::expressions::{ToExpr, ToFilterExpr, ToVecExpr}; +use crate::expressions::{ToExpr, ToFilterExpr, ToLiteral, ToVecExpr}; use crate::group::GroupedData; use crate::plan::LogicalPlanBuilder; use crate::session::SparkSession; @@ -24,6 +24,8 @@ use arrow::json::ArrayWriter; use arrow::record_batch::RecordBatch; use arrow::util::pretty; +use rand::random; + #[cfg(feature = "datafusion")] use datafusion::execution::context::SessionContext; @@ -85,7 +87,7 @@ pub struct DataFrame { impl DataFrame { /// create default DataFrame based on a spark session and initial logical plan - pub fn new(spark_session: SparkSession, plan: LogicalPlanBuilder) -> DataFrame { + pub(crate) fn new(spark_session: SparkSession, plan: LogicalPlanBuilder) -> DataFrame { DataFrame { spark_session: Box::new(spark_session), plan, @@ -102,7 +104,7 @@ impl DataFrame { Ok(()) } - /// Aggregate on the entire [DataFrame] without groups (shorthand for `df.groupBy().agg()`) + /// Aggregate on the entire [DataFrame] without groups (shorthand for `df.group_by().agg()`) pub fn agg(self, exprs: T) -> DataFrame { self.group_by::(None).agg(exprs) } @@ -117,6 +119,35 @@ impl DataFrame { } } + /// Calculates the approximate quantiles of numerical columns of a [DataFrame]. + pub async fn approx_quantile<'a, I, P>( + self, + cols: I, + probabilities: P, + relative_error: f64, + ) -> Result + where + I: IntoIterator, + P: IntoIterator, + { + if relative_error < 0.0 { + return Err(SparkError::AnalysisException( + "Relative Error Negative Value".to_string(), + )); + } + + let plan = self + .plan + .approx_quantile(cols, probabilities, relative_error); + + let df = DataFrame { + spark_session: self.spark_session, + plan, + }; + + df.collect().await + } + /// Persists the [DataFrame] with the default [storage::StorageLevel::MemoryAndDiskDeser] (MEMORY_AND_DISK_DESER). 
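+    ///
+    /// # Example:
+    /// An illustrative sketch (assumes an async context and an existing `DataFrame` named `df`):
+    /// ```rust
+    /// // marks the DataFrame to be persisted with MEMORY_AND_DISK_DESER
+    /// let df = df.cache().await;
+    /// // subsequent actions reuse the cached data
+    /// let total = df.count().await?;
+    /// ```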
pub async fn cache(self) -> DataFrame { self.persist(storage::StorageLevel::MemoryAndDiskDeser) @@ -128,19 +159,6 @@ impl DataFrame { self.repartition(num_partitions, Some(false)) } - /// Returns the number of rows in this [DataFrame] - pub async fn count(self) -> Result { - let res = self.group_by::(None).count().collect().await?; - let col = res.column(0); - - let data: &arrow::array::Int64Array = match col.data_type() { - arrow::datatypes::DataType::Int64 => col.as_any().downcast_ref().unwrap(), - _ => unimplemented!("only Utf8 data types are currently handled currently."), - }; - - Ok(data.value(0)) - } - /// Selects column based on the column name specified as a regex and returns it as [Column]. pub fn col_regex(self, col_name: &str) -> Column { let expr = spark::Expression { @@ -212,6 +230,19 @@ impl DataFrame { Ok(data.value(0)) } + /// Returns the number of rows in this [DataFrame] + pub async fn count(self) -> Result { + let res = self.group_by::(None).count().collect().await?; + let col = res.column(0); + + let data: &arrow::array::Int64Array = match col.data_type() { + arrow::datatypes::DataType::Int64 => col.as_any().downcast_ref().unwrap(), + _ => unimplemented!("only Utf8 data types are currently handled currently."), + }; + + Ok(data.value(0)) + } + /// Calculate the sample covariance for the given columns, specified by their names, as a f64 pub async fn cov(self, col1: &str, col2: &str) -> Result { let plan = self.plan.cov(col1, col2); @@ -236,24 +267,26 @@ impl DataFrame { Ok(data.value(0)) } - /// Creates a local temporary view with this DataFrame. - pub async fn create_temp_view(self, name: &str) -> Result<(), SparkError> { - self.create_view_cmd(name, false, false).await - } - + /// Creates a global temporary view with this [DataFrame]. pub async fn create_global_temp_view(self, name: &str) -> Result<(), SparkError> { self.create_view_cmd(name, true, false).await } + /// Creates or replaces a global temporary view using the given name. pub async fn create_or_replace_global_temp_view(self, name: &str) -> Result<(), SparkError> { self.create_view_cmd(name, true, true).await } - /// Creates or replaces a local temporary view with this DataFrame + /// Creates or replaces a local temporary view with this [DataFrame] pub async fn create_or_replace_temp_view(self, name: &str) -> Result<(), SparkError> { self.create_view_cmd(name, false, true).await } + /// Creates a local temporary view with this [DataFrame] + pub async fn create_temp_view(self, name: &str) -> Result<(), SparkError> { + self.create_view_cmd(name, false, false).await + } + async fn create_view_cmd( self, name: &str, @@ -296,7 +329,7 @@ impl DataFrame { } } - /// Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them. + /// Create a multi-dimensional cube for the current [DataFrame] using the specified columns, so we can run aggregations on them. pub fn cube(self, cols: T) -> GroupedData { GroupedData::new(self, GroupType::Cube, cols.to_vec_expr(), None, None) } @@ -341,7 +374,7 @@ impl DataFrame { /// If no columns are supplied then it all columns are used /// pub fn drop_duplicates(self, cols: Option>) -> DataFrame { - let plan = self.plan.drop_duplicates(cols); + let plan = self.plan.drop_duplicates(cols, false); DataFrame { spark_session: self.spark_session, @@ -349,7 +382,21 @@ impl DataFrame { } } - /// Returns a new DataFrame omitting rows with null values. 
+ /// Return a new [DataFrame] with duplicate rows removed, + /// optionally only considering certain columns, within watermark. + /// + /// This only works with streaming [DataFrame], and watermark for the input [DataFrame] must be set via `with_watermark()`. + /// + pub fn drop_duplicates_within_waterwmark(self, cols: Option>) -> DataFrame { + let plan = self.plan.drop_duplicates(cols, true); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + + /// Returns a new [DataFrame] omitting rows with null values. pub fn dropna(self, how: &str, threshold: Option, subset: Option>) -> DataFrame { let plan = self.plan.dropna(how, threshold, subset); @@ -359,8 +406,8 @@ impl DataFrame { } } - /// Returns all column names and their data types as a Vec containing - /// the field name as a String and the [spark::data_type::Kind] enum + /// Returns all column names and their data types as a `Vec` containing + /// the field name as a `String` and the [spark::data_type::Kind] enum pub async fn dtypes(self) -> Result, SparkError> { let schema = self.schema().await?; @@ -383,7 +430,7 @@ impl DataFrame { Ok(dtypes) } - /// Return a new DataFrame containing rows in this DataFrame but not in another DataFrame while preserving duplicates. + /// Return a new [DataFrame] containing rows in this [DataFrame] but not in another [DataFrame] while preserving duplicates. pub fn except_all(self, other: DataFrame) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -428,6 +475,20 @@ impl DataFrame { Ok(explain) } + /// Replace null values, alias for `df.na().fill()`. + pub fn fillna<'a, I, T>(self, cols: Option, values: T) -> DataFrame + where + I: IntoIterator, + T: IntoIterator>, + { + let plan = self.plan.fillna(cols, values); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + /// Filters rows using a given conditions and returns a new [DataFrame] /// /// # Example: @@ -463,7 +524,7 @@ impl DataFrame { } } - /// Groups the DataFrame using the specified columns, and returns a [GroupedData] object + /// Groups the [DataFrame] using the specified columns, and returns a [GroupedData] object pub fn group_by(self, cols: Option) -> GroupedData { let grouping_cols = match cols { Some(cols) => cols.to_vec_expr(), @@ -477,7 +538,7 @@ impl DataFrame { self.limit(n.unwrap_or(1)).collect().await } - /// Specifies some hint on the current DataFrame. + /// Specifies some hint on the current [DataFrame]. pub fn hint(self, name: &str, parameters: Option) -> DataFrame { let plan = self.plan.hint(name, parameters); @@ -487,7 +548,7 @@ impl DataFrame { } } - /// Returns a best-effort snapshot of the files that compose this DataFrame + /// Returns a best-effort snapshot of the files that compose this [DataFrame] pub async fn input_files(self) -> Result, SparkError> { let input_files = spark::analyze_plan_request::Analyze::InputFiles( spark::analyze_plan_request::InputFiles { @@ -500,7 +561,7 @@ impl DataFrame { client.analyze(input_files).await?.input_files() } - /// Return a new DataFrame containing rows only in both this DataFrame and another DataFrame. + /// Return a new [DataFrame] containing rows only in both this [DataFrame] and another [DataFrame]. pub fn intersect(self, other: DataFrame) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -512,6 +573,7 @@ impl DataFrame { } } + /// Return a new [DataFrame] containing rows in both this [DataFrame] and another [DataFrame] while preserving duplicates. 
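+    ///
+    /// # Example:
+    /// An illustrative sketch (assumes `df1` and `df2` were created from the same `SparkSession`):
+    /// ```rust
+    /// // rows present in both inputs are kept, including duplicates
+    /// let common = df1.intersect_all(df2);
+    /// let rows = common.collect().await?;
+    /// ```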
pub fn intersect_all(self, other: DataFrame) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -530,7 +592,19 @@ impl DataFrame { Ok(val.num_rows() == 0) } - /// Returns True if this DataFrame contains one or more sources that continuously return data as it arrives. + /// Returns `true` if the `collect()` and `take()` methods can be run locally (without any Spark executors). + pub async fn is_local(self) -> Result { + let is_local = + spark::analyze_plan_request::Analyze::IsLocal(spark::analyze_plan_request::IsLocal { + plan: Some(LogicalPlanBuilder::plan_root(self.plan)), + }); + + let mut client = self.spark_session.client(); + + client.analyze(is_local).await?.is_local() + } + + /// Returns `true` if this [DataFrame] contains one or more sources that continuously return data as it arrives. pub async fn is_streaming(self) -> Result { let is_streaming = spark::analyze_plan_request::Analyze::IsStreaming( spark::analyze_plan_request::IsStreaming { @@ -543,7 +617,7 @@ impl DataFrame { client.analyze(is_streaming).await?.is_streaming() } - /// Joins with another DataFrame, using the given join expression. + /// Joins with another [DataFrame], using the given join expression. /// /// # Example: /// ```rust @@ -598,6 +672,13 @@ impl DataFrame { self.unpivot(ids, values, variable_column_name, value_column_name) } + /// Returns a [DataFrameNaFunctions] for handling missing values. + pub fn na(self) -> DataFrameNaFunctions { + DataFrameNaFunctions::new(self) + } + + // !TODO observe + /// Returns a new [DataFrame] by skiping the first n rows pub fn offset(self, num: i32) -> DataFrame { let plan = self.plan.offset(num); @@ -608,11 +689,12 @@ impl DataFrame { } } + /// Returns a new [DataFrame] sorted by the specified column(s). pub fn order_by(self, cols: I) -> DataFrame where I: IntoIterator, { - let plan = self.plan.sort(cols); + let plan = self.plan.sort(cols, false); DataFrame { spark_session: self.spark_session, @@ -620,6 +702,7 @@ impl DataFrame { } } + /// Sets the storage level to persist the contents of the [DataFrame] across operations after the first time it is computed. pub async fn persist(self, storage_level: storage::StorageLevel) -> DataFrame { let analyze = spark::analyze_plan_request::Analyze::Persist(spark::analyze_plan_request::Persist { @@ -653,6 +736,47 @@ impl DataFrame { client.analyze(tree_string).await?.tree_string() } + /// Randomly splits this [DataFrame] with the provided weights. 
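+    ///
+    /// # Example:
+    /// An illustrative sketch (the weights are normalized to sum to 1.0, and a fixed seed makes the split reproducible):
+    /// ```rust
+    /// let splits = df.random_split([1.0, 2.0], Some(24));
+    /// let train = splits[0].clone();
+    /// let test = splits[1].clone();
+    /// ```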
+ pub fn random_split(self, weights: I, seed: Option) -> Vec + where + I: IntoIterator + Clone, + { + let seed = seed.unwrap_or(random::()); + let total: f64 = weights.clone().into_iter().sum(); + + let proportions: Vec = weights.into_iter().map(|v| v / total).collect(); + + let mut normalized_cum_weights = vec![0.0]; + + for &v in &proportions { + let prior_val = *normalized_cum_weights.last().unwrap(); + normalized_cum_weights.push(prior_val + v); + } + + let mut i = 1; + let length = normalized_cum_weights.len(); + let mut splits: Vec = vec![]; + + while i < length { + let lower_bound = *normalized_cum_weights.get(i - 1).unwrap(); + let upper_bound = *normalized_cum_weights.get(i).unwrap(); + + let plan = + self.clone() + .plan + .sample(lower_bound, upper_bound, Some(false), Some(seed), true); + + let df = DataFrame { + spark_session: self.clone().spark_session, + plan, + }; + splits.push(df); + i += 1; + } + + splits + } + /// Returns a new [DataFrame] partitioned by the given partition number and shuffle option /// /// # Arguments @@ -669,6 +793,34 @@ impl DataFrame { } } + /// Returns a new [DataFrame] partitioned by the given partitioning expressions. + pub fn repartition_by_range( + self, + num_partitions: Option, + cols: T, + ) -> DataFrame { + let plan = self.plan.repartition_by_range(num_partitions, cols); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + + /// Returns a new [DataFrame] replacing a value with another value. + pub fn replace<'a, I, T>(self, to_replace: T, value: T, subset: Option) -> DataFrame + where + I: IntoIterator, + T: IntoIterator>, + { + let plan = self.plan.replace(to_replace, value, subset); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + /// Create a multi-dimensional rollup for the current DataFrame using the specified columns, /// and returns a [GroupedData] object pub fn rollup(self, cols: T) -> GroupedData { @@ -702,7 +854,24 @@ impl DataFrame { ) -> DataFrame { let plan = self .plan - .sample(lower_bound, upper_bound, with_replacement, seed); + .sample(lower_bound, upper_bound, with_replacement, seed, false); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + + /// Returns a stratified sample without replacement based on the fraction given on each stratum. + pub fn sample_by(self, col: T, fractions: I, seed: Option) -> DataFrame + where + K: ToLiteral, + T: ToExpr, + I: IntoIterator, + { + let seed = seed.unwrap_or(random::()); + + let plan = self.plan.sample_by(col, fractions, seed); DataFrame { spark_session: self.spark_session, @@ -710,8 +879,8 @@ impl DataFrame { } } - /// Returns the schema of this DataFrame as a [spark::DataType] - /// which contains the schema of a DataFrame + /// Returns the schema of this [DataFrame] as a [spark::DataType] + /// which contains the schema of a [DataFrame] pub async fn schema(self) -> Result { let plan = LogicalPlanBuilder::plan_root(self.plan); @@ -768,6 +937,7 @@ impl DataFrame { } } + /// Returns a hash code of the logical query plan against this [DataFrame]. pub async fn semantic_hash(self) -> Result { let plan = LogicalPlanBuilder::plan_root(self.plan); @@ -808,11 +978,12 @@ impl DataFrame { Ok(pretty::print_batches(&[rows])?) } + /// Returns a new [DataFrame] sorted by the specified column(s). 
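+    ///
+    /// # Example:
+    /// An illustrative sketch (assumes `col` is in scope from the crate's functions module):
+    /// ```rust
+    /// // performs a global sort; use `sort_within_partitions` for a per-partition sort
+    /// let sorted = df.sort([col("age"), col("name")]);
+    /// let rows = sorted.collect().await?;
+    /// ```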
pub fn sort(self, cols: I) -> DataFrame where I: IntoIterator, { - let plan = self.plan.sort(cols); + let plan = self.plan.sort(cols, true); DataFrame { spark_session: self.spark_session, @@ -820,10 +991,30 @@ impl DataFrame { } } + /// Returns a new [DataFrame] with each partition sorted by the specified column(s). + pub fn sort_within_partitions(self, cols: I) -> DataFrame + where + I: IntoIterator, + { + let plan = self.plan.sort(cols, false); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + + /// Returns Spark session that created this DataFrame. pub fn spark_session(self) -> Box { self.spark_session } + /// Returns a DataFrameStatFunctions for statistic functions. + pub fn stat(self) -> DataFrameStatFunctions { + DataFrameStatFunctions::new(self) + } + + /// Get the DataFrame’s current storage level. pub async fn storage_level(self) -> Result { let storage_level = spark::analyze_plan_request::Analyze::GetStorageLevel( spark::analyze_plan_request::GetStorageLevel { @@ -837,6 +1028,7 @@ impl DataFrame { Ok(storage?.into()) } + /// Return a new [DataFrame] containing rows in this [DataFrame] but not in another [DataFrame]. pub fn subtract(self, other: DataFrame) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -848,6 +1040,31 @@ impl DataFrame { } } + /// Computes specified statistics for numeric and string columns. + /// Available statistics are: + /// - count + /// - mean + /// - stddev + /// - min + /// - max + /// - arbitrary approximate percentiles specified as a percentage (e.g., 75%) + /// + /// If no statistics are given, this function computes count, mean, stddev, min, + /// approximate quartiles (percentiles at 25%, 50%, and 75%), and max + /// + pub fn summary(self, statistics: Option) -> DataFrame + where + T: AsRef, + I: IntoIterator, + { + let plan = self.plan.summary(statistics); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + /// Returns the last `n` rows as a [RecordBatch] /// /// Running tail requires moving the data and results in an action @@ -868,10 +1085,12 @@ impl DataFrame { df.collect().await } + /// Returns the first `num` rows as a RecordBatch. pub async fn take(self, n: i32) -> Result { self.limit(n).collect().await } + /// Returns a new [DataFrame] that with new specified column names pub fn to_df<'a, I>(self, cols: I) -> DataFrame where I: IntoIterator, @@ -954,6 +1173,7 @@ impl DataFrame { func(self) } + /// Return a new [DataFrame] containing the union of rows in this and another [DataFrame]. pub fn union(self, other: DataFrame) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -965,6 +1185,7 @@ impl DataFrame { } } + /// Return a new [DataFrame] containing the union of rows in this and another [DataFrame]. pub fn union_all(self, other: DataFrame) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -976,6 +1197,7 @@ impl DataFrame { } } + /// Returns a new [DataFrame] containing union of rows in this and another [DataFrame]. pub fn union_by_name(self, other: DataFrame, allow_missing_columns: Option) -> DataFrame { self.check_same_session(&other).unwrap(); @@ -987,6 +1209,7 @@ impl DataFrame { } } + /// Marks the [DataFrame] as non-persistent, and remove all blocks for it from memory and disk. 
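+    ///
+    /// # Example:
+    /// An illustrative sketch (assumes `df` was previously cached or persisted):
+    /// ```rust
+    /// let df = df.cache().await;
+    /// // ... run some actions against the cached data ...
+    /// // `Some(true)` asks Spark to block until the blocks are removed
+    /// let df = df.unpersist(Some(true)).await;
+    /// ```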
pub async fn unpersist(self, blocking: Option) -> DataFrame { let unpersist = spark::analyze_plan_request::Analyze::Unpersist( spark::analyze_plan_request::Unpersist { @@ -1031,6 +1254,7 @@ impl DataFrame { } } + /// Returns a new [DataFrame] by adding a column or replacing the existing column that has the same name. pub fn with_column(self, col_name: &str, col: Column) -> DataFrame { let plan = self.plan.with_column(col_name, col); @@ -1040,12 +1264,13 @@ impl DataFrame { } } + /// Returns a new [DataFrame] by adding multiple columns or replacing the existing columns that have the same names. pub fn with_columns(self, col_map: I) -> DataFrame where I: IntoIterator, - K: ToString, + K: AsRef, { - let plan = self.plan.with_columns(col_map); + let plan = self.plan.with_columns(col_map, None::>); DataFrame { spark_session: self.spark_session, @@ -1053,6 +1278,15 @@ impl DataFrame { } } + /// Returns a new [DataFrame] by renaming an existing column. + pub fn with_column_renamed(self, existing: K, new: V) -> DataFrame + where + K: AsRef, + V: AsRef, + { + self.with_columns_renamed([(existing, new)]) + } + /// Returns a new [DataFrame] by renaming multiple columns from a /// an iterator of containing a key/value pair with the key as the `existing` /// column name and the value as the `new` column name. @@ -1069,20 +1303,140 @@ impl DataFrame { plan, } } + + /// Returns a new [DataFrame] by updating an existing column with metadata. + pub fn with_metadata(self, col: &str, metadata: &str) -> DataFrame { + let col_map = vec![(col, col)]; + + let plan = self.plan.with_columns(col_map, Some(vec![metadata])); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + + /// Defines an event time watermark for this [DataFrame]. + pub fn with_watermark(self, event_time: &str, delay_threshold: &str) -> DataFrame { + let plan = self.plan.with_watermark(event_time, delay_threshold); + + DataFrame { + spark_session: self.spark_session, + plan, + } + } + /// Returns a [DataFrameWriter] struct based on the current [DataFrame] pub fn write(self) -> DataFrameWriter { DataFrameWriter::new(self) } - pub fn write_to(self, table: &str) -> DataFrameWriterV2 { - DataFrameWriterV2::new(self, table) - } - /// Interface for [DataStreamWriter] to save the content of the streaming DataFrame out /// into external storage. pub fn write_stream(self) -> DataStreamWriter { DataStreamWriter::new(self) } + + /// Create a write configuration builder for v2 sources with [DataFrameWriterV2]. + pub fn write_to(self, table: &str) -> DataFrameWriterV2 { + DataFrameWriterV2::new(self, table) + } +} + +/// Functionality for working with missing data in [DataFrame]. +#[derive(Clone, Debug)] +pub struct DataFrameStatFunctions { + df: DataFrame, +} + +impl DataFrameStatFunctions { + pub(crate) fn new(df: DataFrame) -> DataFrameStatFunctions { + DataFrameStatFunctions { df } + } + + /// Calculates the approximate quantiles of numerical columns of a [DataFrame]. + pub async fn approx_quantile<'a, I, P>( + self, + cols: I, + probabilities: P, + relative_error: f64, + ) -> Result + where + I: IntoIterator, + P: IntoIterator, + { + self.df + .approx_quantile(cols, probabilities, relative_error) + .await + } + + /// Calculates the correlation of two columns of a [DataFrame] as a double value. + pub async fn corr(self, col1: &str, col2: &str) -> Result { + self.df.corr(col1, col2).await + } + + /// Calculate the sample covariance for the given columns, specified by their names, as a double value. 
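+    ///
+    /// # Example:
+    /// An illustrative sketch (assumes `df` has numeric columns named `age` and `height`):
+    /// ```rust
+    /// let cov: f64 = df.stat().cov("age", "height").await?;
+    /// ```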
+ pub async fn cov(self, col1: &str, col2: &str) -> Result { + self.df.cov(col1, col2).await + } + + /// Computes a pair-wise frequency table of the given columns. + pub fn crosstab(self, col1: &str, col2: &str) -> DataFrame { + self.df.crosstab(col1, col2) + } + + /// Finding frequent items for columns, possibly with false positives. + pub fn freq_items<'a, I>(self, cols: I, support: Option) -> DataFrame + where + I: IntoIterator, + { + self.df.freq_items(cols, support) + } + + /// Returns a stratified sample without replacement based on the fraction given on each stratum. + pub fn sample_by(self, col: T, fractions: I, seed: Option) -> DataFrame + where + K: ToLiteral, + T: ToExpr, + I: IntoIterator, + { + self.df.sample_by(col, fractions, seed) + } +} + +/// Functionality for statistic functions with [DataFrame]. +#[derive(Clone, Debug)] +pub struct DataFrameNaFunctions { + df: DataFrame, +} + +impl DataFrameNaFunctions { + pub(crate) fn new(df: DataFrame) -> DataFrameNaFunctions { + DataFrameNaFunctions { df } + } + + /// Returns a new [DataFrame] omitting rows with null values. + pub fn drop(self, how: &str, threshold: Option, subset: Option>) -> DataFrame { + self.df.dropna(how, threshold, subset) + } + + /// Replace null values, alias for `df.na().fill()`. + pub fn fill<'a, I, T>(self, cols: Option, values: T) -> DataFrame + where + I: IntoIterator, + T: IntoIterator>, + { + self.df.fillna(cols, values) + } + + /// Returns a new [DataFrame] replacing a value with another value. + pub fn replace<'a, I, T>(self, to_replace: T, value: T, subset: Option) -> DataFrame + where + I: IntoIterator, + T: IntoIterator>, + { + self.df.replace(to_replace, value, subset) + } } #[cfg(test)] @@ -2160,4 +2514,203 @@ mod tests { assert!(val_clone.contains("== Physical Plan ==")); Ok(()) } + + #[tokio::test] + async fn test_df_random_split() -> Result<(), SparkError> { + let spark = setup().await; + + let name: ArrayRef = Arc::new(StringArray::from(vec![ + Some("Alice"), + Some("Bob"), + Some("Tom"), + None, + ])); + + let age: ArrayRef = Arc::new(Int64Array::from(vec![Some(10), Some(5), None, None])); + let height: ArrayRef = Arc::new(Int64Array::from(vec![Some(80), None, None, None])); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + Field::new("height", DataType::Int64, true), + ]); + + let data = RecordBatch::try_new(Arc::new(schema), vec![name, age, height])?; + + let df = spark.create_dataframe(&data)?; + + let splits = df.random_split([1.0, 2.0], Some(24)); + + let df_one = splits.get(0).unwrap().clone().count().await?; + let df_two = splits.get(1).unwrap().clone().count().await?; + + assert_eq!(2, df_one); + assert_eq!(2, df_two); + Ok(()) + } + + #[tokio::test] + async fn test_df_fillna() -> Result<(), SparkError> { + let spark = setup().await; + + let name: ArrayRef = Arc::new(StringArray::from(vec![Some("Alice"), None, Some("Tom")])); + + let age: ArrayRef = Arc::new(Int64Array::from(vec![Some(10), None, None])); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + ]); + + let data = RecordBatch::try_new(Arc::new(schema.clone()), vec![name, age])?; + + let df = spark.create_dataframe(&data)?; + + let output = df + .fillna( + None::>, + vec![Box::new(80_i64) as Box], + ) + .collect() + .await?; + + let name: ArrayRef = Arc::new(StringArray::from(vec![Some("Alice"), None, Some("Tom")])); + + let age: ArrayRef = Arc::new(Int64Array::from(vec![10, 
80, 80])); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, false), + ]); + + let expected = RecordBatch::try_new(Arc::new(schema), vec![name, age])?; + + assert_eq!(expected, output); + Ok(()) + } + + #[tokio::test] + async fn test_df_replace() -> Result<(), SparkError> { + let spark = setup().await; + + let name: ArrayRef = Arc::new(StringArray::from(vec![ + Some("Alice"), + Some("Bob"), + Some("Tom"), + None, + ])); + + let age: ArrayRef = Arc::new(Int64Array::from(vec![Some(10), Some(5), None, None])); + let height: ArrayRef = Arc::new(Int64Array::from(vec![Some(80), None, Some(10), None])); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + Field::new("height", DataType::Int64, true), + ]); + + let data = RecordBatch::try_new(Arc::new(schema), vec![name, age, height])?; + + let df = spark.create_dataframe(&data)?; + + let df = df.replace( + vec![Box::new(10) as Box], + vec![Box::new(20) as Box], + None::>, + ); + + let output = df + .filter("name in ('Alice', 'Tom')") + .select(["name", "age", "height"]) + .collect() + .await?; + + let name: ArrayRef = Arc::new(StringArray::from(vec![Some("Alice"), Some("Tom")])); + + let age: ArrayRef = Arc::new(Int64Array::from(vec![Some(20), None])); + let height: ArrayRef = Arc::new(Int64Array::from(vec![Some(80), Some(20)])); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + Field::new("height", DataType::Int64, true), + ]); + + let expected = RecordBatch::try_new(Arc::new(schema), vec![name, age, height])?; + + assert_eq!(expected, output); + Ok(()) + } + + #[tokio::test] + async fn test_df_summary() -> Result<(), SparkError> { + let spark = setup().await; + + let data = mock_data(); + + let df = spark.create_dataframe(&data)?; + + let output = df + .select(["age"]) + .summary(None::>) + .select(["summary"]) + .collect() + .await?; + + let summary: ArrayRef = Arc::new(StringArray::from(vec![ + "count", "mean", "stddev", "min", "25%", "50%", "75%", "max", + ])); + + let expected = + RecordBatch::try_from_iter_with_nullable(vec![("summary", summary, true)]).unwrap(); + + assert_eq!(expected, output); + + Ok(()) + } + + #[tokio::test] + async fn test_df_sample_by() -> Result<(), SparkError> { + let spark = setup().await; + + let df = spark + .range(Some(0), 100, 1, None) + .select([(col("id") % lit(3)).alias("key")]); + + let sampled = df.sample_by("key", [(0, 0.1), (1, 0.2)], Some(0)); + + let output = sampled.group_by(Some(["key"])).count().collect().await?; + + assert_eq!(output.num_rows(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_df_with_metadata() -> Result<(), SparkError> { + let spark = setup().await; + + let data = mock_data(); + + let df = spark.create_dataframe(&data)?; + + let metadata_val = "{\"foo\":\"bar\"}"; + + let val = df + .clone() + .with_metadata("name", metadata_val) + .select(["name"]) + .schema() + .await?; + + let output = match val.kind.unwrap() { + spark::data_type::Kind::Struct(val) => { + val.fields.get(0).unwrap().metadata.clone().unwrap() + } + _ => unimplemented!(), + }; + + assert_eq!(metadata_val.to_string(), output); + Ok(()) + } } diff --git a/core/src/plan.rs b/core/src/plan.rs index 573cfe5..c2ffdb2 100644 --- a/core/src/plan.rs +++ b/core/src/plan.rs @@ -5,7 +5,7 @@ use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering::SeqCst; use crate::errors::SparkError; -use 
crate::expressions::{ToExpr, ToFilterExpr, ToVecExpr}; +use crate::expressions::{ToExpr, ToFilterExpr, ToLiteral, ToVecExpr}; use crate::spark; use arrow::array::RecordBatch; @@ -220,7 +220,7 @@ impl LogicalPlanBuilder { } pub fn distinct(self) -> LogicalPlanBuilder { - self.drop_duplicates::>(None) + self.drop_duplicates::>(None, false) } pub fn drop(self, cols: T) -> LogicalPlanBuilder { @@ -233,27 +233,33 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(drop_expr) } - pub fn drop_duplicates<'a, I>(self, cols: Option) -> LogicalPlanBuilder + pub fn drop_duplicates<'a, I>( + self, + cols: Option, + within_watermark: bool, + ) -> LogicalPlanBuilder where I: IntoIterator + std::default::Default, { let drop_expr = match cols { - Some(cols) => RelType::Deduplicate(Box::new(spark::Deduplicate { + Some(cols) => spark::Deduplicate { input: self.relation_input(), column_names: cols.into_iter().map(|col| col.to_string()).collect(), all_columns_as_keys: Some(false), - within_watermark: Some(false), - })), + within_watermark: Some(within_watermark), + }, - None => RelType::Deduplicate(Box::new(spark::Deduplicate { + None => spark::Deduplicate { input: self.relation_input(), column_names: vec![], all_columns_as_keys: Some(true), - within_watermark: Some(false), - })), + within_watermark: Some(within_watermark), + }, }; - LogicalPlanBuilder::from(drop_expr) + let rel_type = RelType::Deduplicate(Box::new(drop_expr)); + + LogicalPlanBuilder::from(rel_type) } pub fn dropna<'a, I>( @@ -290,6 +296,30 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(dropna_rel) } + pub fn fillna<'a, I, T>(self, cols: Option, values: T) -> LogicalPlanBuilder + where + I: IntoIterator, + T: IntoIterator>, + { + let cols: Vec = match cols { + Some(cols) => cols.into_iter().map(|v| v.to_string()).collect(), + None => vec![], + }; + + let values: Vec = + values.into_iter().map(|v| v.to_literal()).collect(); + + println!("{:?}", values); + + let fillna = RelType::FillNa(Box::new(spark::NaFill { + input: self.relation_input(), + cols, + values, + })); + + LogicalPlanBuilder::from(fillna) + } + pub fn to_df<'a, I>(self, cols: I) -> LogicalPlanBuilder where I: IntoIterator, @@ -403,6 +433,28 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(rel_type) } + pub fn approx_quantile<'a, I, P>( + self, + cols: I, + probabilities: P, + relative_error: f64, + ) -> LogicalPlanBuilder + where + I: IntoIterator, + P: IntoIterator, + { + let approx_quantile = spark::StatApproxQuantile { + input: self.relation_input(), + cols: cols.into_iter().map(|col| col.to_string()).collect(), + probabilities: probabilities.into_iter().collect(), + relative_error, + }; + + let approx_quantile_rel = RelType::ApproxQuantile(Box::new(approx_quantile)); + + LogicalPlanBuilder::from(approx_quantile_rel) + } + pub fn freq_items<'a, I>(self, cols: I, support: Option) -> LogicalPlanBuilder where I: IntoIterator, @@ -493,12 +545,58 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(repart_expr) } + pub fn repartition_by_range( + self, + num_partitions: Option, + cols: T, + ) -> LogicalPlanBuilder { + let repart_expr = + RelType::RepartitionByExpression(Box::new(spark::RepartitionByExpression { + input: self.relation_input(), + num_partitions, + partition_exprs: cols.to_vec_expr(), + })); + + LogicalPlanBuilder::from(repart_expr) + } + + pub fn replace<'a, I, T>(self, to_replace: T, value: T, subset: Option) -> LogicalPlanBuilder + where + I: IntoIterator, + T: IntoIterator>, + { + let cols: Vec = match subset { + Some(subset) => 
subset.into_iter().map(|v| v.to_string()).collect(), + None => vec![], + }; + + let replacements = to_replace + .into_iter() + .zip(value) + .map(|(a, b)| spark::na_replace::Replacement { + old_value: Some(a.to_literal()), + new_value: Some(b.to_literal()), + }) + .collect(); + + let replace = spark::NaReplace { + input: self.relation_input(), + cols, + replacements, + }; + + let replace_expr = RelType::Replace(Box::new(replace)); + + LogicalPlanBuilder::from(replace_expr) + } + pub fn sample( self, lower_bound: f64, upper_bound: f64, with_replacement: Option, seed: Option, + deterministic_order: bool, ) -> LogicalPlanBuilder { let sample_expr = RelType::Sample(Box::new(spark::Sample { input: self.relation_input(), @@ -506,7 +604,31 @@ impl LogicalPlanBuilder { upper_bound, with_replacement, seed, - deterministic_order: false, + deterministic_order, + })); + + LogicalPlanBuilder::from(sample_expr) + } + + pub fn sample_by(self, col: T, fractions: I, seed: i64) -> LogicalPlanBuilder + where + K: ToLiteral, + T: ToExpr, + I: IntoIterator, + { + let fractions = fractions + .into_iter() + .map(|(k, v)| spark::stat_sample_by::Fraction { + stratum: Some(k.to_literal()), + fraction: v, + }) + .collect(); + + let sample_expr = RelType::SampleBy(Box::new(spark::StatSampleBy { + input: self.relation_input(), + col: Some(col.to_expr()), + fractions, + seed: Some(seed), })); LogicalPlanBuilder::from(sample_expr) @@ -521,16 +643,17 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(rel_type) } - pub fn select_expr<'a, I>(self, cols: I) -> LogicalPlanBuilder + pub fn select_expr(self, cols: I) -> LogicalPlanBuilder where - I: IntoIterator, + T: AsRef, + I: IntoIterator, { let expressions = cols .into_iter() .map(|col| spark::Expression { expr_type: Some(spark::expression::ExprType::ExpressionString( spark::expression::ExpressionString { - expression: col.to_string(), + expression: col.as_ref().to_string(), }, )), }) @@ -544,7 +667,7 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(rel_type) } - pub fn sort(self, cols: I) -> LogicalPlanBuilder + pub fn sort(self, cols: I, is_global: bool) -> LogicalPlanBuilder where T: ToExpr, I: IntoIterator, @@ -553,12 +676,39 @@ impl LogicalPlanBuilder { let sort_type = RelType::Sort(Box::new(spark::Sort { order, input: self.relation_input(), - is_global: None, + is_global: Some(is_global), })); LogicalPlanBuilder::from(sort_type) } + pub fn summary(self, statistics: Option) -> LogicalPlanBuilder + where + T: AsRef, + I: IntoIterator, + { + let statistics = match statistics { + Some(stats) => stats.into_iter().map(|s| s.as_ref().to_string()).collect(), + None => vec![ + "count".to_string(), + "mean".to_string(), + "stddev".to_string(), + "min".to_string(), + "25%".to_string(), + "50%".to_string(), + "75%".to_string(), + "max".to_string(), + ], + }; + + let stats = RelType::Summary(Box::new(spark::StatSummary { + input: self.relation_input(), + statistics, + })); + + LogicalPlanBuilder::from(stats) + } + pub fn with_column(self, col_name: &str, col: T) -> LogicalPlanBuilder { let aliases: Vec = vec![spark::expression::Alias { expr: Some(Box::new(col.to_expr())), @@ -574,21 +724,29 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(with_col) } - pub fn with_columns(self, col_map: I) -> LogicalPlanBuilder + pub fn with_columns(self, col_map: I, metadata: Option) -> LogicalPlanBuilder where T: ToExpr, - K: ToString, + K: AsRef, + N: AsRef, I: IntoIterator, + M: IntoIterator, { - let aliases: Vec = col_map + let mut aliases: Vec = col_map .into_iter() 
.map(|(name, col)| spark::expression::Alias { expr: Some(Box::new(col.to_expr())), - name: vec![name.to_string()], + name: vec![name.as_ref().to_string()], metadata: None, }) .collect(); + if let Some(meta_iter) = metadata { + for (alias, meta) in aliases.iter_mut().zip(meta_iter) { + alias.metadata = Some(meta.as_ref().to_string()); + } + }; + let with_col = RelType::WithColumns(Box::new(spark::WithColumns { input: self.relation_input(), aliases, @@ -615,6 +773,20 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(rename_expr) } + + pub fn with_watermark(self, event_time: T, delay_threshold: D) -> LogicalPlanBuilder + where + T: AsRef, + D: AsRef, + { + let watermark_expr = RelType::WithWatermark(Box::new(spark::WithWatermark { + input: self.relation_input(), + event_time: event_time.as_ref().to_string(), + delay_threshold: delay_threshold.as_ref().to_string(), + })); + + LogicalPlanBuilder::from(watermark_expr) + } } pub(crate) fn sort_order(cols: I) -> Vec diff --git a/core/src/window.rs b/core/src/window.rs index 39fbbc3..44fd329 100644 --- a/core/src/window.rs +++ b/core/src/window.rs @@ -155,7 +155,7 @@ impl Window { /// Both start and end are relative from the current row. For example, “0” means “current row”, /// while “-1” means one off before the current row, and “5” means the five off after the current row. /// - /// Recommended to use [Window::unboundedPreceding], [Window::unboundedFollowing], and [Window::currentRow] + /// Recommended to use [Window::unbounded_preceding], [Window::unbounded_following], and [Window::current_row] /// to specify special boundary values, rather than using integral values directly. /// /// # Example @@ -164,7 +164,7 @@ impl Window { /// let window = Window::new() /// .partition_by(col("name")) /// .order_by([col("age")]) - /// .range_between(Window::unboundedPreceding(), Window::currentRow()); + /// .range_between(Window::unbounded_preceding(), Window::current_row()); /// /// let df = df.with_column("rank", rank().over(window.clone())) /// .with_column("min", min("age").over(window)); @@ -180,7 +180,7 @@ impl Window { /// Both start and end are relative from the current row. For example, “0” means “current row”, /// while “-1” means one off before the current row, and “5” means the five off after the current row. /// - /// Recommended to use [Window::unboundedPreceding], [Window::unboundedFollowing], and [Window::currentRow] + /// Recommended to use [Window::unbounded_preceding], [Window::unbounded_following], and [Window::current_row] /// to specify special boundary values, rather than using integral values directly. /// /// # Example