Skip to content

Commit

Permalink
[query] Lift backend state into {Service|Py4J}BackendApi
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Sep 20, 2024
1 parent 16d54bd commit fb3cc0d
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 362 deletions.
123 changes: 0 additions & 123 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala

This file was deleted.

154 changes: 154 additions & 0 deletions hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package is.hail.backend.py4j

import com.google.api.client.http.HttpStatusCodes
import com.sun.net.{httpserver => http}
import is.hail.HailFeatureFlags
import is.hail.asm4s.HailClassLoader
import is.hail.backend.caching.BlockMatrixCache
import is.hail.backend.{Backend, BackendContext, ExecuteContext, HttpLikeBackendRpc, TempFileManager}
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
import is.hail.linalg.BlockMatrix
import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value}
import is.hail.utils.ExecutionTimer.Timings
import is.hail.utils._
import is.hail.variant.ReferenceGenome
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}

import java.io.Closeable
import java.net.InetSocketAddress
import java.util.concurrent._
import scala.collection.mutable

final class Py4JBackendApi(override val backend: Backend) extends Py4JBackendExtensions with Closeable {

override val references: mutable.Map[String, ReferenceGenome] =
mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)

override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIr: mutable.Map[Int, BaseIR] = mutable.Map()
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)


override def close(): Unit = {
HttpServer.close()
backend.close()
}

object HttpServer extends HttpLikeBackendRpc[http.HttpExchange] with Closeable {
// 0 => let the OS pick an available port
private[this] val httpServer = http.HttpServer.create(new InetSocketAddress(0), 10)

Check failure on line 47 in hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala#L47

Server-Side Request Forgery occur when a web server executes a request to a user supplied destination parameter that is not validated.

private[this] val thread = {
// This HTTP server *must not* start non-daemon threads because such threads keep the JVM
// alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest
// when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of the
// JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel
// explicitly regardless of the JVM). It *does* manifest when submitting jobs with
//
// gcloud dataproc submit ...
//
// or
//
// spark-submit
//
// setExecutor(null) ensures the server creates no new threads:
//
/* > If this method is not called (before start()) or if it is called with a null Executor,
* then */
// > a default implementation is used, which uses the thread which was created by the start()
// > method.
//
/* Source:
* https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */
//
httpServer.createContext("/", runRpc(_: http.HttpExchange))
httpServer.setExecutor(null)
val t = Executors.defaultThreadFactory().newThread(() => httpServer.start())
t.setDaemon(true)
t
}

def port: Int = httpServer.getAddress.getPort

override def close(): Unit =
httpServer.stop(10)

thread.start()

implicit private object Handler
extends Routing with Write[http.HttpExchange] with Context[http.HttpExchange]
with ErrorHandling {

override def route(req: http.HttpExchange): Route =
req.getRequestURI.getPath match {
case "/value/type" => Routes.TypeOf(Value)
case "/table/type" => Routes.TypeOf(Table)
case "/matrixtable/type" => Routes.TypeOf(Matrix)
case "/blockmatrix/type" => Routes.TypeOf(BlockMatrix)
case "/execute" => Routes.Execute
case "/vcf/metadata/parse" => Routes.ParseVcfMetadata
case "/fam/import" => Routes.ImportFam
case "/references/load" => Routes.LoadReferencesFromDataset
case "/references/from_fasta" => Routes.LoadReferencesFromFASTA
}

override def payload(req: http.HttpExchange): JValue =
using(req.getRequestBody)(JsonMethods.parse(_))

override def timings(req: http.HttpExchange)(t: Timings): Unit = {
val ts = Serialization.write(Map("timings" -> t))
req.getResponseHeaders.add("X-Hail-Timings", ts)
}

override def result(req: http.HttpExchange)(result: Array[Byte]): Unit =
respond(req)(HttpStatusCodes.STATUS_CODE_OK, result)

override def error(req: http.HttpExchange)(t: Throwable): Unit =
respond(req)(
HttpStatusCodes.STATUS_CODE_SERVER_ERROR,
jsonToBytes {
val (shortMessage, expandedMessage, errorId) = handleForPython(t)
JObject(
"short" -> JString(shortMessage),
"expanded" -> JString(expandedMessage),
"error_id" -> JInt(errorId),
)
},
)

private[this] def respond(req: http.HttpExchange)(code: Int, payload: Array[Byte]): Unit = {
req.sendResponseHeaders(code, payload.length)
using(req.getResponseBody)(_.write(payload))
}

override def scoped[A](req: http.HttpExchange)(f: ExecuteContext => A): (A, Timings) = {
ExecutionTimer.time { timer =>
ExecuteContext.scoped(
tmpdir = String,
localTmpdir = String,
backend = Py4JBackendApi.this.backend,
fs = Py4JBackendApi.this.fs,
timer = timer,
tempFileManager = null,
theHailClassLoader = Py4JBackendApi.this.theHailClassLoader,
flags = Py4JBackendApi.this.flags,
irMetadata = IrMetadata(None),
references = Py4JBackendApi.this.references,
blockMatrixCache = Py4JBackendApi.this.bmCache,
codeCache = Py4JBackendApi.this.codeCache,
irCache = Py4JBackendApi.this.persistedIr,
coercerCache = Py4JBackendApi.this.coercerCache,
)(f)
}
}
}
}
}
Loading

0 comments on commit fb3cc0d

Please sign in to comment.