Skip to content

Commit

Permalink
[SPARK-17551][SQL] Add DataFrame API for null ordering
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This pull request adds Scala/Java DataFrame API for null ordering (NULLS FIRST | LAST).

Also did some minor cleanup of related code (e.g. fixing incorrect indentation), and renamed "orderby-nulls-ordering.sql" to be consistent with existing test files.

## How was this patch tested?
Added a new test case in DataFrameSuite.

Author: petermaxlee <[email protected]>
Author: Xin Wu <[email protected]>

Closes apache#15123 from petermaxlee/SPARK-17551.
  • Loading branch information
xwu0226 authored and hvanhovell committed Sep 25, 2016
1 parent 7945dae commit de333d1
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ case object NullsLast extends NullOrdering{
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
case class SortOrder(
child: Expression,
direction: SortDirection,
nullOrdering: NullOrdering)
case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
extends UnaryExpression with Unevaluable {

/** Sort order is not foldable because we don't have an eval for it. */
Expand Down Expand Up @@ -94,34 +91,23 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {

val nullValue = child.child.dataType match {
case BooleanType | DateType | TimestampType | _: IntegralType =>
if (nullAsSmallest) {
Long.MinValue
} else {
Long.MaxValue
}
if (nullAsSmallest) Long.MinValue else Long.MaxValue
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
if (nullAsSmallest) {
Long.MinValue
} else {
Long.MaxValue
}
if (nullAsSmallest) Long.MinValue else Long.MaxValue
case _: DecimalType =>
if (nullAsSmallest) {
DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
} else {
DoublePrefixComparator.computePrefix(Double.NaN)
}
case _ =>
if (nullAsSmallest) {
0L
} else {
-1L
}
if (nullAsSmallest) 0L else -1L
}

private def nullAsSmallest: Boolean = (child.isAscending && child.nullOrdering == NullsFirst) ||
private def nullAsSmallest: Boolean = {
(child.isAscending && child.nullOrdering == NullsFirst) ||
(!child.isAscending && child.nullOrdering == NullsLast)

}

/** Direct evaluation is not supported: calling this always throws UnsupportedOperationException. */
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
// Nothing
} else if ($isNullA) {
return ${
order.nullOrdering match {
case NullsFirst => "-1"
case NullsLast => "1"
}};
order.nullOrdering match {
case NullsFirst => "-1"
case NullsLast => "1"
}};
} else if ($isNullB) {
return ${
order.nullOrdering match {
case NullsFirst => "1"
case NullsLast => "-1"
}};
order.nullOrdering match {
case NullsFirst => "1"
case NullsLast => "-1"
}};
} else {
int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
if (comp != 0) {
Expand Down
64 changes: 62 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
/**
* Returns an ordering used in sorting.
* {{{
* // Scala: sort a DataFrame by age column in descending order.
* // Scala
* df.sort(df("age").desc)
*
* // Java
Expand All @@ -1020,7 +1020,37 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def desc: Column = withExpr { SortOrder(expr, Descending) }

/**
* Returns a descending ordering used in sorting, where null values appear before non-null values.
* {{{
* // Scala: sort a DataFrame by age in descending order, with null values appearing first.
* df.sort(df("age").desc_nulls_first)
*
* // Java
* df.sort(df.col("age").desc_nulls_first());
* }}}
*
* @group expr_ops
* @since 2.1.0
*/
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }

/**
* Returns a descending ordering used in sorting, where null values appear after non-null values.
* {{{
* // Scala: sort a DataFrame by age in descending order, with null values appearing last.
* df.sort(df("age").desc_nulls_last)
*
* // Java
* df.sort(df.col("age").desc_nulls_last());
* }}}
*
* @group expr_ops
* @since 2.1.0
*/
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) }

/**
* Returns an ascending ordering used in sorting.
* {{{
* // Scala: sort a DataFrame by age column in ascending order.
* df.sort(df("age").asc)
Expand All @@ -1034,6 +1064,36 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def asc: Column = withExpr { SortOrder(expr, Ascending) }

/**
* Returns an ascending ordering used in sorting, where null values appear before non-null values.
* {{{
* // Scala: sort a DataFrame by age in ascending order, with null values appearing first.
* df.sort(df("age").asc_nulls_first)
*
* // Java
* df.sort(df.col("age").asc_nulls_first());
* }}}
*
* @group expr_ops
* @since 2.1.0
*/
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) }

/**
* Returns an ascending ordering used in sorting, where null values appear after non-null values.
* {{{
* // Scala: sort a DataFrame by age in ascending order, with null values appearing last.
* df.sort(df("age").asc_nulls_last)
*
* // Java
* df.sort(df.col("age").asc_nulls_last());
* }}}
*
* @group expr_ops
* @since 2.1.0
*/
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }

/**
* Prints the expression to the console for debugging purpose.
*
Expand Down
51 changes: 49 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ object functions {
/**
* Returns a sort expression based on ascending order of the column.
* {{{
* // Sort by dept in ascending order, and then age in descending order.
* df.sort(asc("dept"), desc("age"))
* }}}
*
Expand All @@ -118,10 +117,33 @@ object functions {
*/
def asc(columnName: String): Column = Column(columnName).asc

/**
* Returns a sort expression based on ascending order of the column,
* with null values appearing before non-null values.
* Equivalent to `Column(columnName).asc_nulls_first`.
* {{{
* df.sort(asc_nulls_first("dept"), desc("age"))
* }}}
*
* @group sort_funcs
* @since 2.1.0
*/
def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first

/**
* Returns a sort expression based on ascending order of the column,
* with null values appearing after non-null values.
* Equivalent to `Column(columnName).asc_nulls_last`.
* {{{
* df.sort(asc_nulls_last("dept"), desc("age"))
* }}}
*
* @group sort_funcs
* @since 2.1.0
*/
def asc_nulls_last(columnName: String): Column = Column(columnName).asc_nulls_last

/**
* Returns a sort expression based on the descending order of the column.
* {{{
* // Sort by dept in ascending order, and then age in descending order.
* df.sort(asc("dept"), desc("age"))
* }}}
*
Expand All @@ -130,6 +152,31 @@ object functions {
*/
def desc(columnName: String): Column = Column(columnName).desc

/**
* Returns a sort expression based on the descending order of the column,
* with null values appearing before non-null values.
* Equivalent to `Column(columnName).desc_nulls_first`.
* {{{
* df.sort(asc("dept"), desc_nulls_first("age"))
* }}}
*
* @group sort_funcs
* @since 2.1.0
*/
def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first

/**
* Returns a sort expression based on the descending order of the column,
* with null values appearing after non-null values.
* Equivalent to `Column(columnName).desc_nulls_last`.
* {{{
* df.sort(asc("dept"), desc_nulls_last("age"))
* }}}
*
* @group sort_funcs
* @since 2.1.0
*/
def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last


//////////////////////////////////////////////////////////////////////////////////////////////
// Aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
18 changes: 18 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(6))
}

test("sorting with null ordering") {
  val data = Seq[java.lang.Integer](2, 1, null).toDF("key")

  // Expected row sequences, named by sort direction and null placement.
  val ascNullsFirst = Seq(Row(null), Row(1), Row(2))
  val ascNullsLast = Seq(Row(1), Row(2), Row(null))
  val descNullsFirst = Seq(Row(null), Row(2), Row(1))
  val descNullsLast = Seq(Row(2), Row(1), Row(null))

  // Plain ascending is expected to match the explicit nulls-first variant.
  checkAnswer(data.orderBy('key.asc), ascNullsFirst)
  checkAnswer(data.orderBy(asc("key")), ascNullsFirst)
  checkAnswer(data.orderBy('key.asc_nulls_first), ascNullsFirst)
  checkAnswer(data.orderBy(asc_nulls_first("key")), ascNullsFirst)
  checkAnswer(data.orderBy('key.asc_nulls_last), ascNullsLast)
  checkAnswer(data.orderBy(asc_nulls_last("key")), ascNullsLast)

  // Plain descending is expected to match the explicit nulls-last variant.
  checkAnswer(data.orderBy('key.desc), descNullsLast)
  checkAnswer(data.orderBy(desc("key")), descNullsLast)
  checkAnswer(data.orderBy('key.desc_nulls_first), descNullsFirst)
  checkAnswer(data.orderBy(desc_nulls_first("key")), descNullsFirst)
  checkAnswer(data.orderBy('key.desc_nulls_last), descNullsLast)
  checkAnswer(data.orderBy(desc_nulls_last("key")), descNullsLast)
}

test("global sorting") {
checkAnswer(
testData2.orderBy('a.asc, 'b.asc),
Expand Down

0 comments on commit de333d1

Please sign in to comment.