[SPARK-14525][SQL] Make DataFrameWrite.save work for jdbc
## What changes were proposed in this pull request?

This change modifies the implementation of DataFrameWriter.save such that it works with jdbc, and the call to jdbc merely delegates to save.
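As a rough illustration of the result (a minimal sketch, assuming an existing `SparkSession` named `spark`, a `DataFrame` named `df`, and placeholder connection values), the generic save path and the `jdbc` method now go through the same code:

```scala
import java.util.Properties

// Generic load/save API: the "jdbc" format can now be used with save().
df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("user", "username")
  .option("password", "password")
  .save()

// Equivalent call through DataFrameWriter.jdbc, which now simply delegates to save().
val connectionProperties = new Properties()
connectionProperties.put("user", "username")
connectionProperties.put("password", "password")
df.write.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
```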

## How was this patch tested?

This was tested via unit tests in JDBCWriteSuite; I added one new test to cover this scenario.
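A round-trip check of roughly this shape exercises the new path (an illustrative sketch only, not the exact test added to JDBCWriteSuite; the H2 URL, table name, and credentials are placeholders):

```scala
// Illustrative sketch: write through the generic save() path, read back via jdbc().
val url = "jdbc:h2:mem:testdb0"                 // placeholder in-memory database URL
val properties = new java.util.Properties()
properties.setProperty("user", "testUser")      // placeholder credentials
properties.setProperty("password", "testPass")

df.write
  .format("jdbc")
  .option("url", url)
  .option("dbtable", "TEST.SAVETEST")           // placeholder table name
  .option("user", "testUser")
  .option("password", "testPass")
  .save()

assert(spark.read.jdbc(url, "TEST.SAVETEST", properties).count() == df.count())
```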

## Additional details

@rxin This seems to have been most recently touched by you and was also commented on in the JIRA.

This contribution is my original work and I license the work to the project under the project's open source license.

Author: Justin Pihony <[email protected]>
Author: Justin Pihony <[email protected]>

Closes apache#12601 from JustinPihony/jdbc_reconciliation.
JustinPihony authored and srowen committed Sep 26, 2016
1 parent ac65139 commit 50b89d0
Showing 9 changed files with 246 additions and 73 deletions.
6 changes: 5 additions & 1 deletion docs/sql-programming-guide.md
@@ -1100,9 +1100,13 @@ CREATE TEMPORARY VIEW jdbcTable
USING org.apache.spark.sql.jdbc
OPTIONS (
url "jdbc:postgresql:dbserver",
dbtable "schema.tablename"
dbtable "schema.tablename",
user 'username',
password 'password'
)

INSERT INTO TABLE jdbcTable
SELECT * FROM resultTable
{% endhighlight %}

</div>
@@ -22,6 +22,7 @@
import java.util.Arrays;
import java.util.List;
// $example off:schema_merging$
import java.util.Properties;

// $example on:basic_parquet_example$
import org.apache.spark.api.java.JavaRDD;
@@ -235,13 +236,33 @@ private static void runJsonDatasetExample(SparkSession spark) {

private static void runJdbcDatasetExample(SparkSession spark) {
// $example on:jdbc_dataset$
// Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods
// Loading data from a JDBC source
Dataset<Row> jdbcDF = spark.read()
.format("jdbc")
.option("url", "jdbc:postgresql:dbserver")
.option("dbtable", "schema.tablename")
.option("user", "username")
.option("password", "password")
.load();

Properties connectionProperties = new Properties();
connectionProperties.put("user", "username");
connectionProperties.put("password", "password");
Dataset<Row> jdbcDF2 = spark.read()
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties);

// Saving data to a JDBC source
jdbcDF.write()
.format("jdbc")
.option("url", "jdbc:postgresql:dbserver")
.option("dbtable", "schema.tablename")
.option("user", "username")
.option("password", "password")
.save();

jdbcDF2.write()
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties);
// $example off:jdbc_dataset$
}
}
19 changes: 19 additions & 0 deletions examples/src/main/python/sql/datasource.py
@@ -143,13 +143,32 @@ def json_dataset_example(spark):

def jdbc_dataset_example(spark):
# $example on:jdbc_dataset$
# Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods
# Loading data from a JDBC source
jdbcDF = spark.read \
.format("jdbc") \
.option("url", "jdbc:postgresql:dbserver") \
.option("dbtable", "schema.tablename") \
.option("user", "username") \
.option("password", "password") \
.load()

jdbcDF2 = spark.read \
.jdbc("jdbc:postgresql:dbserver", "schema.tablename",
properties={"user": "username", "password": "password"})

# Saving data to a JDBC source
jdbcDF.write \
.format("jdbc") \
.option("url", "jdbc:postgresql:dbserver") \
.option("dbtable", "schema.tablename") \
.option("user", "username") \
.option("password", "password") \
.save()

jdbcDF2.write \
.jdbc("jdbc:postgresql:dbserver", "schema.tablename",
properties={"user": "username", "password": "password"})
# $example off:jdbc_dataset$


4 changes: 4 additions & 0 deletions examples/src/main/r/RSparkSQLExample.R
@@ -204,7 +204,11 @@ results <- collect(sql("FROM src SELECT key, value"))


# $example on:jdbc_dataset$
# Loading data from a JDBC source
df <- read.jdbc("jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password")

# Saving data to a JDBC source
write.jdbc(df, "jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password")
# $example off:jdbc_dataset$

# Stop the SparkSession now
@@ -16,6 +16,8 @@
*/
package org.apache.spark.examples.sql

import java.util.Properties

import org.apache.spark.sql.SparkSession

object SQLDataSourceExample {
@@ -148,13 +150,33 @@ object SQLDataSourceExample {

private def runJdbcDatasetExample(spark: SparkSession): Unit = {
// $example on:jdbc_dataset$
// Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods
// Loading data from a JDBC source
val jdbcDF = spark.read
.format("jdbc")
.option("url", "jdbc:postgresql:dbserver")
.option("dbtable", "schema.tablename")
.option("user", "username")
.option("password", "password")
.load()

val connectionProperties = new Properties()
connectionProperties.put("user", "username")
connectionProperties.put("password", "password")
val jdbcDF2 = spark.read
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)

// Saving data to a JDBC source
jdbcDF.write
.format("jdbc")
.option("url", "jdbc:postgresql:dbserver")
.option("dbtable", "schema.tablename")
.option("user", "username")
.option("password", "password")
.save()

jdbcDF2.write
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
// $example off:jdbc_dataset$
}
}
59 changes: 4 additions & 55 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -425,62 +425,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
assertNotPartitioned("jdbc")
assertNotBucketed("jdbc")

// to add required options like URL and dbtable
val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
val jdbcOptions = new JDBCOptions(params)
val jdbcUrl = jdbcOptions.url
val jdbcTable = jdbcOptions.table

val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()

try {
var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)

if (mode == SaveMode.Ignore && tableExists) {
return
}

if (mode == SaveMode.ErrorIfExists && tableExists) {
sys.error(s"Table $jdbcTable already exists.")
}

if (mode == SaveMode.Overwrite && tableExists) {
if (jdbcOptions.isTruncate &&
JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
JdbcUtils.truncateTable(conn, jdbcTable)
} else {
JdbcUtils.dropTable(conn, jdbcTable)
tableExists = false
}
}

// Create the table if the table didn't exist.
if (!tableExists) {
val schema = JdbcUtils.schemaString(df, jdbcUrl)
// To allow certain options to append when create a new table, which can be
// table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
val createtblOptions = jdbcOptions.createTableOptions
val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
val statement = conn.createStatement
try {
statement.executeUpdate(sql)
} finally {
statement.close()
}
}
} finally {
conn.close()
}

JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
this.extraOptions = this.extraOptions ++ (connectionProperties.asScala)
// explicit url and dbtable should override all
this.extraOptions += ("url" -> url, "dbtable" -> table)
format("jdbc").save()
}
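In other words, `jdbc` now just folds its arguments into the writer's options before calling `save()`: options already set via `option()` are kept unless overridden by `connectionProperties`, which are merged in on top, and the explicit `url` and `dbtable` arguments override everything. A hedged usage sketch (the `createTableOptions` value is only an example option):

```scala
// Sketch: an extra writer option survives the delegation to save(); the explicit
// url/dbtable arguments override anything set via option() or connectionProperties.
df.write
  .option("createTableOptions", "ENGINE=InnoDB DEFAULT CHARSET=utf8")
  .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
```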

@@ -27,10 +27,12 @@ class JDBCOptions(
// ------------------------------------------------------------
// Required parameters
// ------------------------------------------------------------
require(parameters.isDefinedAt("url"), "Option 'url' is required.")
require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.")
// a JDBC URL
val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
val url = parameters("url")
// name of table
val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
val table = parameters("dbtable")

// ------------------------------------------------------------
// Optional parameter list
@@ -44,6 +46,11 @@
// the number of partitions
val numPartitions = parameters.getOrElse("numPartitions", null)

require(partitionColumn == null ||
(lowerBound != null && upperBound != null && numPartitions != null),
"If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
" and 'numPartitions' are required.")

// ------------------------------------------------------------
// The options for DataFrameWriter
// ------------------------------------------------------------
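The validation added above means a partitioned JDBC read must supply `partitionColumn` together with `lowerBound`, `upperBound`, and `numPartitions`. A minimal read sketch (the column name and range values are placeholders):

```scala
// Sketch: "id" and the bounds are placeholders; all four partition options go together.
val partitionedDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("user", "username")
  .option("password", "password")
  .option("partitionColumn", "id")
  .option("lowerBound", "1")
  .option("upperBound", "10000")
  .option("numPartitions", "10")
  .load()
```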
@@ -19,37 +19,102 @@ package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
import scala.collection.JavaConverters.mapAsJavaMapConverter

class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

class JdbcRelationProvider extends CreatableRelationProvider
with RelationProvider with DataSourceRegister {

override def shortName(): String = "jdbc"

/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val jdbcOptions = new JDBCOptions(parameters)
if (jdbcOptions.partitionColumn != null
&& (jdbcOptions.lowerBound == null
|| jdbcOptions.upperBound == null
|| jdbcOptions.numPartitions == null)) {
sys.error("Partitioning incompletely specified")
}
val partitionColumn = jdbcOptions.partitionColumn
val lowerBound = jdbcOptions.lowerBound
val upperBound = jdbcOptions.upperBound
val numPartitions = jdbcOptions.numPartitions

val partitionInfo = if (jdbcOptions.partitionColumn == null) {
val partitionInfo = if (partitionColumn == null) {
null
} else {
JDBCPartitioningInfo(
jdbcOptions.partitionColumn,
jdbcOptions.lowerBound.toLong,
jdbcOptions.upperBound.toLong,
jdbcOptions.numPartitions.toInt)
partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
val properties = new Properties() // Additional properties that we will pass to getConnection
parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
}

/*
* The following structure applies to this code:
* | tableExists | !tableExists
*------------------------------------------------------------------------------------
* Ignore | BaseRelation | CreateTable, saveTable, BaseRelation
* ErrorIfExists | ERROR | CreateTable, saveTable, BaseRelation
* Overwrite* | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation
* | saveTable, BaseRelation |
* Append | saveTable, BaseRelation | CreateTable, saveTable, BaseRelation
*
* *Overwrite & tableExists with truncate, will not drop & create, but instead truncate
*/
override def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
val jdbcOptions = new JDBCOptions(parameters)
val url = jdbcOptions.url
val table = jdbcOptions.table

val props = new Properties()
props.putAll(parameters.asJava)
val conn = JdbcUtils.createConnectionFactory(url, props)()

try {
val tableExists = JdbcUtils.tableExists(conn, url, table)

val (doCreate, doSave) = (mode, tableExists) match {
case (SaveMode.Ignore, true) => (false, false)
case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(
s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.")
case (SaveMode.Overwrite, true) =>
if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) {
JdbcUtils.truncateTable(conn, table)
(false, true)
} else {
JdbcUtils.dropTable(conn, table)
(true, true)
}
case (SaveMode.Append, true) => (false, true)
case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," +
" for handling existing tables.")
case (_, false) => (true, true)
}

if (doCreate) {
val schema = JdbcUtils.schemaString(data, url)
// To allow certain options to append when create a new table, which can be
// table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
val createtblOptions = jdbcOptions.createTableOptions
val sql = s"CREATE TABLE $table ($schema) $createtblOptions"
val statement = conn.createStatement
try {
statement.executeUpdate(sql)
} finally {
statement.close()
}
}
if (doSave) JdbcUtils.saveTable(data, url, table, props)
} finally {
conn.close()
}

createRelation(sqlContext, parameters)
}
}
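The SaveMode matrix in the comment above maps onto the writer API in the expected way; for example, an overwrite that truncates rather than drops an existing table looks roughly like this (a sketch reusing the placeholder URL, table, and `connectionProperties` from the examples above; truncation only applies when the dialect reports a non-cascading TRUNCATE):

```scala
import org.apache.spark.sql.SaveMode

// Sketch: SaveMode.Overwrite on an existing table normally drops and recreates it;
// with the "truncate" option it truncates instead, when the dialect allows it.
df.write
  .mode(SaveMode.Overwrite)
  .option("truncate", "true")
  .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
```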