
Support reading shape keys
Kontinuation committed Aug 21, 2024
1 parent cf3a0f7 commit 570cb28
Showing 7 changed files with 187 additions and 26 deletions.
1 change: 1 addition & 0 deletions .github/linters/codespell.txt
@@ -1,3 +1,4 @@
LOD
actualy
afterall
atmost
ShapefileDataSource.scala
@@ -30,6 +30,18 @@ import java.util.Locale
import scala.collection.JavaConverters._
import scala.util.Try

/**
* A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the
* following components of shapefiles:
*
* <ul> <li>.shp: the main file <li>.dbf: (optional) the attribute file <li>.shx: (optional) the
* index file <li>.cpg: (optional) the code page file <li>.prj: (optional) the projection file
* </ul>
*
* <p>The load path can be a directory containing the shapefiles, or a path to the .shp file. If
* the path refers to a .shp file, the data source will also read other components such as .dbf
* and .shx files in the same directory.
*/
class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister {

override def shortName(): String = "shapefile"
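A minimal usage sketch for the data source described above, assuming a SparkSession named spark with Sedona registered (the input path is hypothetical):

val df = spark.read
  .format("shapefile")
  .option("key.name", "fid") // optional: expose each record's shape key as a LongType column
  .load("/path/to/shapefiles") // a directory of shapefiles, or a single .shp file
df.printSchema() // geometry column first, then the key column if requested, then DBF attributes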
ShapefilePartitionReader.scala
@@ -24,7 +24,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataInputStream
import org.apache.hadoop.fs.Path
import org.apache.sedona.common.FunctionsGeoTools
import org.apache.sedona.core.formatMapper.shapefileParser.shapes.CombineShapeReader
import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader
import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape
import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader
@@ -39,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.shapefile.ShapefilePartitionReader.logger
import org.apache.spark.sql.execution.datasources.shapefile.ShapefilePartitionReader.openStream
import org.apache.spark.sql.execution.datasources.shapefile.ShapefilePartitionReader.tryOpenStream
import org.apache.spark.sql.execution.datasources.shapefile.ShapefileUtils.baseSchema
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.StructType
import org.locationtech.jts.geom.GeometryFactory
@@ -122,10 +122,7 @@ class ShapefilePartitionReader(
ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq)
}
.getOrElse(Seq.empty)
geometryField match {
case Some(geomField) => StructType(geomField +: dbfFields)
case None => StructType(dbfFields)
}
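// New layout: the base schema (geometry plus the optional shape key column) comes first, then the DBF attributes.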
StructType(baseSchema(options).fields ++ dbfFields)
}

// projection from shpSchema to readDataSchema
@@ -226,10 +223,11 @@
Seq.fill(fieldValueConverters.length)(null)
}

val shpRow = if (geometryField.isDefined) {
InternalRow.fromSeq(geometry.map(GeometryUDT.serialize).orNull +: attrValues.toSeq)
val serializedGeom = geometry.map(GeometryUDT.serialize).orNull
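// With "key.name" set, the record index from the shape key follows the geometry column.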
val shpRow = if (options.keyFieldName.isDefined) {
InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq)
} else {
InternalRow.fromSeq(attrValues.toSeq)
InternalRow.fromSeq(serializedGeom +: attrValues.toSeq)
}
currentRow = projection(shpRow)
true
ShapefileReadOptions.scala
@@ -20,12 +20,26 @@ package org.apache.spark.sql.execution.datasources.shapefile

import org.apache.spark.sql.util.CaseInsensitiveStringMap

case class ShapefileReadOptions(geometryFieldName: String, charset: Option[String])
/**
* Options for reading Shapefiles.
* @param geometryFieldName
* The name of the geometry field.
* @param keyFieldName
* The name of the shape key field.
* @param charset
* The charset of non-spatial attributes.
*/
case class ShapefileReadOptions(
geometryFieldName: String,
keyFieldName: Option[String],
charset: Option[String])

object ShapefileReadOptions {
def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = {
val geometryFieldName = options.getOrDefault("geometry.name", "geometry")
val keyFieldName =
if (options.containsKey("key.name")) Some(options.get("key.name")) else None
val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None
ShapefileReadOptions(geometryFieldName, charset)
ShapefileReadOptions(geometryFieldName, keyFieldName, charset)
}
}
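A short sketch of how these options are parsed from Spark's CaseInsensitiveStringMap (the map contents here are hypothetical):

import java.util.{HashMap => JHashMap}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val raw = new JHashMap[String, String]()
raw.put("geometry.name", "geom")
raw.put("key.name", "fid") // "charset" left unset, so it parses to None
val opts = ShapefileReadOptions.parse(new CaseInsensitiveStringMap(raw))
// opts == ShapefileReadOptions("geom", Some("fid"), None)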
ShapefileTable.scala
@@ -20,18 +20,15 @@ package org.apache.spark.sql.execution.datasources.shapefile

import org.apache.hadoop.fs.FileStatus
import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.TableCapability
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.LogicalWriteInfo
import org.apache.spark.sql.connector.write.WriteBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.shapefile.ShapefileUtils.fieldDescriptorsToSchema
import org.apache.spark.sql.execution.datasources.shapefile.ShapefileUtils.mergeSchemas
import org.apache.spark.sql.execution.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas}
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration

import java.util.Locale
@@ -66,14 +63,14 @@ case class ShapefileTable(

if (!files.exists(isShpFile)) None
else {
val geometryFieldName = ShapefileReadOptions.parse(options).geometryFieldName
val readOptions = ShapefileReadOptions.parse(options)
val resolver = sparkSession.sessionState.conf.resolver
val dbfFiles = files.filter(isDbfFile)
if (dbfFiles.isEmpty) {
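// No .dbf attribute files: the schema is just the base schema (geometry plus the optional key column).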
Some(StructType(StructField(geometryFieldName, GeometryUDT) :: Nil))
Some(baseSchema(readOptions, Some(resolver)))
} else {
val serializableConf = new SerializableConfiguration(
sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap))
val resolver = sparkSession.sessionState.conf.resolver
val partiallyMergedSchemas = sparkSession.sparkContext
.parallelize(dbfFiles)
.mapPartitions { iter =>
@@ -84,10 +81,7 @@
val dbfParser = new DbfParseUtil()
dbfParser.parseFileHead(stream)
val fieldDescriptors = dbfParser.getFieldDescriptors
fieldDescriptorsToSchema(
fieldDescriptors.asScala.toSeq,
geometryFieldName,
resolver)
fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver)
} finally {
stream.close()
}
ShapefileUtils.scala
@@ -98,15 +98,36 @@ object ShapefileUtils {

def fieldDescriptorsToSchema(
fieldDescriptors: Seq[FieldDescriptor],
geometryFieldName: String,
options: ShapefileReadOptions,
resolver: Resolver): StructType = {
val structFields = fieldDescriptorsToStructFields(fieldDescriptors)
val geometryFieldName = options.geometryFieldName
if (structFields.exists(f => resolver(f.name, geometryFieldName))) {
throw new IllegalArgumentException(
s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " +
"Please specify a different field name for geometry using the 'geometry.name' option.")
}
StructType(StructField(geometryFieldName, GeometryUDT) +: structFields)
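// The shape key column name must not collide with any DBF attribute name either.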
options.keyFieldName.foreach { name =>
if (structFields.exists(f => resolver(f.name, name))) {
throw new IllegalArgumentException(
s"Field name $name is reserved for shape key but appears in non-spatial attributes. " +
"Please specify a different field name for shape key using the 'key.name' option.")
}
}
StructType(baseSchema(options, Some(resolver)).fields ++ structFields)
}

def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = {
options.keyFieldName match {
case Some(name) =>
if (resolver.exists(_(name, options.geometryFieldName))) {
throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same")
}
StructType(
Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType)))
case _ =>
StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil)
}
}
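
// Sketch of what baseSchema yields for the two option shapes (field names hypothetical):
//   baseSchema(ShapefileReadOptions("geometry", Some("fid"), None))
//     => StructType(StructField("geometry", GeometryUDT) :: StructField("fid", LongType) :: Nil)
//   baseSchema(ShapefileReadOptions("geometry", None, None))
//     => StructType(StructField("geometry", GeometryUDT) :: Nil)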

def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = {
ShapefileTests.scala
@@ -413,6 +413,50 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll {
}
}

it("read partitioned directory") {
FileUtils.cleanDirectory(new File(temporaryLocation))
Files.createDirectory(new File(temporaryLocation + "/part=1").toPath)
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath)
FileUtils.copyFile(
new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"),
new File(temporaryLocation + "/part=1/datatypes1.shp"))
FileUtils.copyFile(
new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"),
new File(temporaryLocation + "/part=1/datatypes1.dbf"))
FileUtils.copyFile(
new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"),
new File(temporaryLocation + "/part=1/datatypes1.cpg"))
FileUtils.copyFile(
new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"),
new File(temporaryLocation + "/part=2/datatypes2.shp"))
FileUtils.copyFile(
new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"),
new File(temporaryLocation + "/part=2/datatypes2.dbf"))
FileUtils.copyFile(
new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"),
new File(temporaryLocation + "/part=2/datatypes2.cpg"))

val shapefileDf = sparkSession.read
.format("shapefile")
.load(temporaryLocation)
.select("part", "id", "aInt", "aUnicode", "geometry")
val rows = shapefileDf.collect()
assert(rows.length == 9)
rows.foreach { row =>
assert(row.getAs[Geometry]("geometry").isInstanceOf[Point])
val id = row.getAs[Long]("id")
assert(row.getAs[Long]("aInt") == id)
if (id < 10) {
assert(row.getAs[Int]("part") == 1)
} else {
assert(row.getAs[Int]("part") == 2)
}
if (id > 0) {
assert(row.getAs[String]("aUnicode") == s"测试$id")
}
}
}

it("read with custom geometry column name") {
val shapefileDf = sparkSession.read
.format("shapefile")
@@ -447,6 +491,83 @@ class ShapefileTests extends TestBaseScala with BeforeAndAfterAll {
"osm_id is reserved for geometry but appears in non-spatial attributes"))
}

it("read with shape key column") {
val shapefileDf = sparkSession.read
.format("shapefile")
.option("key.name", "fid")
.load(resourceFolder + "shapefiles/datatypes")
.select("id", "fid", "geometry", "aUnicode")
val schema = shapefileDf.schema
assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT)
assert(schema.find(_.name == "id").get.dataType == LongType)
assert(schema.find(_.name == "fid").get.dataType == LongType)
assert(schema.find(_.name == "aUnicode").get.dataType == StringType)
val rows = shapefileDf.collect()
assert(rows.length == 9)
shapefileDf.show()
rows.foreach { row =>
val geom = row.getAs[Geometry]("geometry")
assert(geom.isInstanceOf[Point])
val id = row.getAs[Long]("id")
if (id > 0) {
assert(row.getAs[Long]("fid") == id % 10)
assert(row.getAs[String]("aUnicode") == s"测试$id")
} else {
assert(row.getAs[Long]("fid") == 5)
}
}
}

it("read with both custom geometry column and shape key column") {
val shapefileDf = sparkSession.read
.format("shapefile")
.option("geometry.name", "g")
.option("key.name", "fid")
.load(resourceFolder + "shapefiles/datatypes")
.select("id", "fid", "g", "aUnicode")
val schema = shapefileDf.schema
assert(schema.find(_.name == "g").get.dataType == GeometryUDT)
assert(schema.find(_.name == "id").get.dataType == LongType)
assert(schema.find(_.name == "fid").get.dataType == LongType)
assert(schema.find(_.name == "aUnicode").get.dataType == StringType)
val rows = shapefileDf.collect()
assert(rows.length == 9)
shapefileDf.show()
rows.foreach { row =>
val geom = row.getAs[Geometry]("g")
assert(geom.isInstanceOf[Point])
val id = row.getAs[Long]("id")
if (id > 0) {
assert(row.getAs[Long]("fid") == id % 10)
assert(row.getAs[String]("aUnicode") == s"测试$id")
} else {
assert(row.getAs[Long]("fid") == 5)
}
}
}

it("read with invalid shape key column") {
val exception = intercept[Exception] {
sparkSession.read
.format("shapefile")
.option("geometry.name", "g")
.option("key.name", "aDate")
.load(resourceFolder + "shapefiles/datatypes")
}
assert(
exception.getMessage.contains(
"aDate is reserved for shape key but appears in non-spatial attributes"))

val exception2 = intercept[Exception] {
sparkSession.read
.format("shapefile")
.option("geometry.name", "g")
.option("key.name", "g")
.load(resourceFolder + "shapefiles/datatypes")
}
assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same"))
}

it("read with custom charset") {
FileUtils.cleanDirectory(new File(temporaryLocation))
FileUtils.copyFile(
