Add PerfSource to measure performance per partition
EnricoMi committed Jun 24, 2020
1 parent 8fc08f7 commit 1da946e
Showing 6 changed files with 348 additions and 1 deletion.
src/main/scala/uk/co/gresearch/spark/dgraph/connector/encoder/PerfEncoder.scala
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package uk.co.gresearch.spark.dgraph.connector.encoder

import com.google.gson.Gson
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import uk.co.gresearch.spark.dgraph.connector.{Json, PerfJson}

/**
* Encodes a perf JSON response into Spark InternalRows.
*/
case class PerfEncoder() extends JsonNodeInternalRowEncoder {

/**
* Returns the schema of this table. If the table is not readable and doesn't have a schema, an
* empty schema can be returned here.
* From: org.apache.spark.sql.connector.catalog.Table.schema
*/
override def schema(): StructType = StructType(Seq(
StructField("partitionTargets", ArrayType(StringType, containsNull = false), nullable=false),
StructField("partitionPredicates", ArrayType(StringType, containsNull = false), nullable=true),
StructField("partitionUidsFirst", LongType, nullable=true),
StructField("partitionUidsLength", LongType, nullable=true),

StructField("sparkStageId", IntegerType, nullable=false),
StructField("sparkStageAttemptNumber", IntegerType, nullable=false),
StructField("sparkPartitionId", IntegerType, nullable=false),
StructField("sparkAttemptNumber", IntegerType, nullable=false),
StructField("sparkTaskAttemptId", LongType, nullable=false),

StructField("dgraphAssignTimestamp", LongType, nullable=true),
StructField("dgraphParsing", LongType, nullable=true),
StructField("dgraphProcessing", LongType, nullable=true),
StructField("dgraphEncoding", LongType, nullable=true),
StructField("dgraphTotal", LongType, nullable=true)
))

/**
* Returns the actual schema of this data source scan, which may be different from the physical
* schema of the underlying storage, as column pruning or other optimizations may happen.
* From: org.apache.spark.sql.connector.read.Scan.readSchema
*/
override def readSchema(): StructType = schema()

/**
* Encodes the given perf json result into InternalRows.
*
* @param json perf json result
* @param member member in the json that has the result
* @return internal rows
*/
override def fromJson(json: Json, member: String): Iterator[InternalRow] = {
val perf = new Gson().fromJson(json.string, classOf[PerfJson])
Iterator(InternalRow(
new GenericArrayData(perf.partitionTargets.map(UTF8String.fromString)),
Option(perf.partitionPredicates).map(p => new GenericArrayData(p.map(UTF8String.fromString))).orNull,
perf.partitionUidsFirst,
perf.partitionUidsLength,

perf.sparkStageId,
perf.sparkStageAttemptNumber,
perf.sparkPartitionId,
perf.sparkAttemptNumber,
perf.sparkTaskAttemptId,

perf.dgraphAssignTimestamp,
perf.dgraphParsing,
perf.dgraphProcessing,
perf.dgraphEncoding,
perf.dgraphTotal
))
}

}
src/main/scala/uk/co/gresearch/spark/dgraph/connector/executor/DgraphPerfExecutor.scala
@@ -0,0 +1,54 @@
package uk.co.gresearch.spark.dgraph.connector.executor

import com.google.gson.Gson
import io.dgraph.DgraphClient
import io.dgraph.DgraphProto.Response
import io.grpc.ManagedChannel
import org.apache.spark.TaskContext
import uk.co.gresearch.spark.dgraph.connector.{GraphQl, Json, Partition, PerfJson, getClientFromChannel, toChannel}

case class DgraphPerfExecutor(partition: Partition) extends JsonGraphQlExecutor {

/**
* Executes the given GraphQL query and returns the query result as JSON.
*
* @param query GraphQL query
* @return query result as JSON
*/
override def query(query: GraphQl): Json = {
val channels: Seq[ManagedChannel] = partition.targets.map(toChannel)
try {
val client: DgraphClient = getClientFromChannel(channels)
val response: Response = client.newReadOnlyTransaction().query(query.string)
toJson(response)
} finally {
channels.foreach(_.shutdown())
}
}

/**
* Serializes the partition, the Spark task context, and the Dgraph response
* latency into a Json value, one row per partition.
*/
def toJson(response: Response): Json = {
val task = TaskContext.get()
val latency = Some(response).filter(_.hasLatency).map(_.getLatency)
val perf =
new PerfJson(
partition.targets.map(_.target).toArray,
partition.predicates.map(_.map(_.predicateName).toArray).orNull,
partition.uids.map(_.first.asInstanceOf[java.lang.Long]).orNull,
partition.uids.map(_.length.asInstanceOf[java.lang.Long]).orNull,

task.stageId(),
task.stageAttemptNumber(),
task.partitionId(),
task.attemptNumber(),
task.taskAttemptId(),

latency.map(_.getAssignTimestampNs.asInstanceOf[java.lang.Long]).orNull,
latency.map(_.getParsingNs.asInstanceOf[java.lang.Long]).orNull,
latency.map(_.getProcessingNs.asInstanceOf[java.lang.Long]).orNull,
latency.map(_.getEncodingNs.asInstanceOf[java.lang.Long]).orNull,
latency.map(_.getTotalNs.asInstanceOf[java.lang.Long]).orNull
)
val json = new Gson().toJson(perf)
Json(json)
}

}
src/main/scala/uk/co/gresearch/spark/dgraph/connector/executor/DgraphPerfExecutorProvider.scala
@@ -0,0 +1,15 @@
package uk.co.gresearch.spark.dgraph.connector.executor

import uk.co.gresearch.spark.dgraph.connector.Partition

case class DgraphPerfExecutorProvider() extends ExecutorProvider {

/**
* Provide an executor for the given partition.
*
* @param partition a partition
* @return an executor
*/
override def getExecutor(partition: Partition): JsonGraphQlExecutor =
DgraphPerfExecutor(partition)

}
src/main/scala/uk/co/gresearch/spark/dgraph/connector/package.scala

@@ -23,7 +23,7 @@ import io.dgraph.DgraphGrpc.DgraphStub
import io.dgraph.{DgraphClient, DgraphGrpc}
import io.grpc.ManagedChannel
import io.grpc.netty.NettyChannelBuilder
import org.apache.spark.sql.{DataFrame, DataFrameReader, Dataset, Encoder, Encoders}
import org.apache.spark.sql.{DataFrame, DataFrameReader, Encoder, Encoders}

package object connector {

@@ -92,6 +92,42 @@ package object connector {
case class GraphQl(string: String) // technically not GraphQl but GraphQl+: https://dgraph.io/docs/query-language/
case class Json(string: String)

case class Perf(partitionTargets: Seq[String],
partitionPredicates: Option[Seq[String]],
partitionUidsFirst: Option[Long],
partitionUidsLength: Option[Long],

sparkStageId: Int,
sparkStageAttemptNumber: Int,
sparkPartitionId: Int,
sparkAttemptNumber: Int,
sparkTaskAttemptId: Long,

dgraphAssignTimestamp: Option[Long],
dgraphParsing: Option[Long],
dgraphProcessing: Option[Long],
dgraphEncoding: Option[Long],
dgraphTotal: Option[Long]
)

class PerfJson(val partitionTargets: Array[String],
val partitionPredicates: Array[String],
val partitionUidsFirst: java.lang.Long,
val partitionUidsLength: java.lang.Long,

val sparkStageId: Int,
val sparkStageAttemptNumber: Int,
val sparkPartitionId: Int,
val sparkAttemptNumber: Int,
val sparkTaskAttemptId: Long,

val dgraphAssignTimestamp: java.lang.Long,
val dgraphParsing: java.lang.Long,
val dgraphProcessing: java.lang.Long,
val dgraphEncoding: java.lang.Long,
val dgraphTotal: java.lang.Long
)

val TargetOption: String = "dgraph.target"
val TargetsOption: String = "dgraph.targets"

src/main/scala/uk/co/gresearch/spark/dgraph/connector/sources/PerfSource.scala
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package uk.co.gresearch.spark.dgraph.connector.sources

import java.util

import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import uk.co.gresearch.spark.dgraph.connector._
import uk.co.gresearch.spark.dgraph.connector.encoder.{PerfEncoder, StringTripleEncoder, TypedTripleEncoder}
import uk.co.gresearch.spark.dgraph.connector.executor.{DgraphExecutorProvider, DgraphPerfExecutorProvider}
import uk.co.gresearch.spark.dgraph.connector.model.TripleTableModel
import uk.co.gresearch.spark.dgraph.connector.partitioner.PartitionerProvider

class PerfSource() extends TableProviderBase
with TargetsConfigParser with SchemaProvider
with ClusterStateProvider with PartitionerProvider {

override def shortName(): String = "dgraph-triples"

override def inferSchema(options: CaseInsensitiveStringMap): StructType =
PerfEncoder().schema()

def getTripleMode(options: CaseInsensitiveStringMap): Option[String] =
getStringOption(TriplesModeOption, options)

override def getTable(schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
val options = new CaseInsensitiveStringMap(properties)
val targets = getTargets(options)
// the Dgraph schema retrieved from the cluster, distinct from the fixed Spark read schema above
val dgraphSchema = getSchema(targets)
val clusterState = getClusterState(targets)
val partitioner = getPartitioner(dgraphSchema, clusterState, options)
val encoder = PerfEncoder()
val execution = DgraphPerfExecutorProvider()
val model = TripleTableModel(execution, encoder)
new TripleTable(partitioner, model, clusterState.cid)
}

}
src/test/scala/uk/co/gresearch/spark/dgraph/connector/sources/TestPerfSource.scala
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package uk.co.gresearch.spark.dgraph.connector.sources

import org.scalatest.FunSpec
import uk.co.gresearch.spark.SparkTestSession
import uk.co.gresearch.spark.dgraph.DgraphTestCluster
import uk.co.gresearch.spark.dgraph.connector._

import scala.collection.Set

class TestPerfSource extends FunSpec
with SparkTestSession with DgraphTestCluster {

import spark.implicits._

describe("PerfSource") {

it("should load via paths") {
val perfs =
spark
.read
.format("uk.co.gresearch.spark.dgraph.connector.sources.PerfSource")
.options(Map(
PredicatePartitionerPredicatesOption -> "2",
UidRangePartitionerUidsPerPartOption -> "2"
))
.load(cluster.grpc)
.as[Perf]
.collect()
.sortBy(_.sparkPartitionId)

// Example:
// Perf(Seq("localhost:9080"), Some(Seq("release_date", "revenue")), Some(0), Some(2), 0, 0, 0, 0, 0, Some(450097), Some(95947), Some(309825), Some(20358), Some(945510)),
// Perf(Seq("localhost:9080"), Some(Seq("release_date", "revenue")), Some(2), Some(2), 0, 0, 1, 0, 1, Some(454868), Some(78280), Some(321795), Some(21096), Some(944941)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.graphql.schema", "starring")), Some(0), Some(2), 0, 0, 2, 0, 2, Some(401568), Some(72688), Some(197814), Some(10291), Some(741997)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.graphql.schema", "starring")), Some(2), Some(2), 0, 0, 3, 0, 3, Some(345644), Some(112154), Some(231255), Some(8781), Some(754845)),
// Perf(Seq("localhost:9080"), Some(Seq("director", "running_time")), Some(0), Some(2), 0, 0, 4, 0, 4, Some(352526), Some(75411), Some(283403), Some(9663), Some(781171)),
// Perf(Seq("localhost:9080"), Some(Seq("director", "running_time")), Some(2), Some(2), 0, 0, 5, 0, 5, Some(315593), Some(66102), Some(256086), Some(10080), Some(703224)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.type", "name")), Some(0), Some(2), 0, 0, 6, 0, 6, Some(381511), Some(71763), Some(216050), Some(11367), Some(731836)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.type", "name")), Some(2), Some(2), 0, 0, 7, 0, 7, Some(330556), Some(68906), Some(249247), Some(13140), Some(721444)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.type", "name")), Some(4), Some(2), 0, 0, 8, 0, 8, Some(393403), Some(92074), Some(216785), Some(10273), Some(779874)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.type", "name")), Some(6), Some(2), 0, 0, 9, 0, 9, Some(394063), Some(86764), Some(238309), Some(10604), Some(797032)),
// Perf(Seq("localhost:9080"), Some(Seq("dgraph.type", "name")), Some(8), Some(2), 0, 0, 10, 0, 10, Some(365052), Some(68823), Some(294403), Some(15287), Some(807023)),

assert(perfs.length === 11)
assert(perfs.forall(_.partitionTargets == Seq("localhost:9080")))
assert(perfs.map(_.partitionPredicates).distinct === Seq(
Some(Seq("release_date", "revenue")),
Some(Seq("dgraph.graphql.schema", "starring")),
Some(Seq("director", "running_time")),
Some(Seq("dgraph.type", "name")),
))
assert(perfs.map(p => (p.partitionPredicates, p.partitionUidsFirst)).groupBy(_._1.get).mapValues(_.map(_._2.get).toSeq) === Map(
Seq("release_date", "revenue") -> Seq(0, 2),
Seq("dgraph.graphql.schema", "starring") -> Seq(0, 2),
Seq("director", "running_time") -> Seq(0, 2),
Seq("dgraph.type", "name") -> Seq(0, 2, 4, 6, 8),
))
assert(perfs.forall(_.partitionUidsLength.contains(2)))

assert(perfs.forall(_.sparkStageId == 0))
assert(perfs.forall(_.sparkStageAttemptNumber == 0))
assert(perfs.zipWithIndex.forall { case (perf, idx) => perf.sparkPartitionId == idx })
assert(perfs.forall(_.sparkAttemptNumber == 0))
assert(perfs.zipWithIndex.forall { case (perf, idx) => perf.sparkTaskAttemptId == idx })

assert(perfs.forall(_.dgraphAssignTimestamp.isDefined))
assert(perfs.forall(_.dgraphParsing.isDefined))
assert(perfs.forall(_.dgraphProcessing.isDefined))
assert(perfs.forall(_.dgraphEncoding.isDefined))
assert(perfs.forall(_.dgraphTotal.isDefined))
}

}

}
