From 43cada30eaf194b436ad26f4962394060b297423 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Fri, 20 Sep 2024 18:06:07 -0400 Subject: [PATCH 1/4] [query] Lift backend state into `{Service|Py4J}BackendApi` --- hail/python/hail/backend/local_backend.py | 1 - hail/python/hail/backend/py4j_backend.py | 20 +- .../main/scala/is/hail/backend/Backend.scala | 6 +- .../scala/is/hail/backend/BackendRpc.scala | 4 - .../scala/is/hail/backend/BackendServer.scala | 123 ------ .../is/hail/backend/ExecuteContext.scala | 12 +- .../is/hail/backend/api/Py4JBackendApi.scala | 404 ++++++++++++++++++ .../hail/backend/api/ServiceBackendApi.scala | 267 ++++++++++++ .../is/hail/backend/local/LocalBackend.scala | 112 +---- .../backend/py4j/Py4JBackendExtensions.scala | 239 ----------- .../scala/is/hail/backend/service/Main.scala | 4 +- .../hail/backend/service/ServiceBackend.scala | 335 ++------------- .../is/hail/backend/service/Worker.scala | 14 +- .../is/hail/backend/spark/SparkBackend.scala | 162 ++----- .../src/main/scala/is/hail/expr/ir/Emit.scala | 2 +- .../main/scala/is/hail/io/vcf/LoadVCF.scala | 2 +- .../main/scala/is/hail/utils/package.scala | 25 +- hail/src/test/scala/is/hail/HailSuite.scala | 270 ++++++------ hail/src/test/scala/is/hail/TestUtils.scala | 104 ++--- .../test/scala/is/hail/TestUtilsSuite.scala | 24 +- .../is/hail/annotations/UnsafeSuite.scala | 2 +- .../is/hail/backend/ServiceBackendSuite.scala | 237 +++++----- .../is/hail/expr/ir/ArrayFunctionsSuite.scala | 1 - .../is/hail/expr/ir/CallFunctionsSuite.scala | 4 +- .../is/hail/expr/ir/DictFunctionsSuite.scala | 1 - .../is/hail/expr/ir/EmitStreamSuite.scala | 16 +- .../is/hail/expr/ir/ForwardLetsSuite.scala | 1 - .../hail/expr/ir/GenotypeFunctionsSuite.scala | 1 - .../test/scala/is/hail/expr/ir/IRSuite.scala | 83 ++-- .../scala/is/hail/expr/ir/IntervalSuite.scala | 1 - .../is/hail/expr/ir/MathFunctionsSuite.scala | 1 - .../scala/is/hail/expr/ir/MatrixIRSuite.scala | 3 +- .../is/hail/expr/ir/MemoryLeakSuite.scala | 32 +- .../scala/is/hail/expr/ir/OrderingSuite.scala | 1 - .../is/hail/expr/ir/RequirednessSuite.scala | 3 +- .../is/hail/expr/ir/SetFunctionsSuite.scala | 1 - .../scala/is/hail/expr/ir/SimplifySuite.scala | 2 +- .../hail/expr/ir/StringFunctionsSuite.scala | 1 - .../is/hail/expr/ir/StringSliceSuite.scala | 1 - .../scala/is/hail/expr/ir/TableIRSuite.scala | 1 - .../scala/is/hail/expr/ir/TrapNodeSuite.scala | 1 - .../is/hail/expr/ir/UtilFunctionsSuite.scala | 1 - .../expr/ir/analyses/SemanticHashSuite.scala | 26 +- .../lowering/LowerDistributedSortSuite.scala | 9 +- .../is/hail/expr/ir/table/TableGenSuite.scala | 14 +- .../is/hail/io/compress/BGzipCodecSuite.scala | 1 - .../test/scala/is/hail/io/fs/FSSuite.scala | 7 +- .../is/hail/linalg/BlockMatrixSuite.scala | 26 +- .../is/hail/methods/LocalLDPruneSuite.scala | 4 +- .../scala/is/hail/methods/SkatSuite.scala | 4 +- .../scala/is/hail/stats/eigSymDSuite.scala | 4 +- .../scala/is/hail/utils/RichArraySuite.scala | 4 +- .../utils/RichDenseMatrixDoubleSuite.scala | 4 +- .../scala/is/hail/variant/GenotypeSuite.scala | 5 +- .../is/hail/variant/LocusIntervalSuite.scala | 14 +- .../hail/variant/ReferenceGenomeSuite.scala | 20 +- 56 files changed, 1256 insertions(+), 1411 deletions(-) delete mode 100644 hail/src/main/scala/is/hail/backend/BackendServer.scala create mode 100644 hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala create mode 100644 hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala delete mode 100644 
hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 7bcb1145259..0908ee06986 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -74,7 +74,6 @@ def __init__( hail_package = getattr(self._gateway.jvm, 'is').hail jbackend = hail_package.backend.local.LocalBackend.apply( - tmpdir, log, True, append, diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 697269b152b..49838ea4a85 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -170,8 +170,15 @@ def parse(node): class Py4JBackend(Backend): @abc.abstractmethod - def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject): - super(Py4JBackend, self).__init__() + def __init__( + self, + jvm: JVMView, + jbackend: JavaObject, + jhc: JavaObject, + tmpdir: str, + remote_tmpdir: str, + ): + super().__init__() import base64 def decode_bytearray(encoded): @@ -184,12 +191,11 @@ def decode_bytearray(encoded): self._jvm = jvm self._hail_package = getattr(self._jvm, 'is').hail self._utils_package_object = scala_package_object(self._hail_package.utils) - self._jbackend = jbackend self._jhc = jhc - self._backend_server = self._hail_package.backend.BackendServer(self._jbackend) - self._backend_server_port: int = self._backend_server.port() - self._backend_server.start() + self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) + self._jhttp_server = self._jbackend.pyHttpServer() + self._backend_server_port: int = self._jhttp_server.port() self._requests_session = requests.Session() # This has to go after creating the SparkSession. Unclear why. @@ -289,7 +295,7 @@ def _to_java_blockmatrix_ir(self, ir): return self._parse_blockmatrix_ir(self._render_ir(ir)) def stop(self): - self._backend_server.close() + self._jhttp_server.close() self._jbackend.close() self._jhc.stop() self._jhc = None diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index 91379ea33a2..4ca7147e25d 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -9,7 +9,6 @@ import is.hail.io.fs.FS import is.hail.types.RTable import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.utils.ExecutionTimer.Timings import is.hail.utils.fatal import scala.reflect.ClassTag @@ -54,6 +53,7 @@ trait BackendContext { } abstract class Backend extends Closeable { + // From https://github.com/hail-is/hail/issues/14580 : // IR can get quite big, especially as it can contain an arbitrary // amount of encoded literals from the user's python session.
This @@ -119,7 +119,7 @@ abstract class Backend extends Closeable { def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage - def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) - def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] + + def backendContext(ctx: ExecuteContext): BackendContext } diff --git a/hail/src/main/scala/is/hail/backend/BackendRpc.scala b/hail/src/main/scala/is/hail/backend/BackendRpc.scala index 33cecbbb0fe..0dd89283c42 100644 --- a/hail/src/main/scala/is/hail/backend/BackendRpc.scala +++ b/hail/src/main/scala/is/hail/backend/BackendRpc.scala @@ -244,8 +244,4 @@ trait HttpLikeBackendRpc[A] extends BackendRpc { ) } } - - implicit protected def Ask: Routing - implicit protected def Write: Write[A] - implicit protected def Context: Context[A] } diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala deleted file mode 100644 index db23bcb310f..00000000000 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ /dev/null @@ -1,123 +0,0 @@ -package is.hail.backend - -import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} -import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings - -import java.io.Closeable -import java.net.InetSocketAddress -import java.util.concurrent._ - -import com.google.api.client.http.HttpStatusCodes -import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} -import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} - -class BackendServer(backend: Backend) extends Closeable { - // 0 => let the OS pick an available port - private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) - - private[this] val thread = { - // This HTTP server *must not* start non-daemon threads because such threads keep the JVM - // alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest - // when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of the - // JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel - // explicitly regardless of the JVM). It *does* manifest when submitting jobs with - // - // gcloud dataproc submit ... - // - // or - // - // spark-submit - // - // setExecutor(null) ensures the server creates no new threads: - // - // > If this method is not called (before start()) or if it is called with a null Executor, then - // > a default implementation is used, which uses the thread which was created by the start() - // > method. 
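The daemon-thread rationale in the comment being deleted here from BackendServer.scala is carried over verbatim into the new Py4JBackendApi.pyHttpServer later in this patch. As a minimal, self-contained sketch of the pattern it describes (serve requests on the thread that called start() via setExecutor(null), and mark that thread as a daemon so a lingering server cannot keep the JVM, and therefore a Spark job, alive), using only JDK classes; all names below are illustrative, not part of the patch:

    import java.net.InetSocketAddress
    import java.util.concurrent.Executors
    import com.sun.net.httpserver.HttpServer

    object DaemonHttpServerSketch {
      def main(args: Array[String]): Unit = {
        // Port 0 lets the OS pick a free port, as in the comment above.
        val server = HttpServer.create(new InetSocketAddress(0), 10)
        server.createContext("/", exchange => {
          val body = "ok".getBytes("UTF-8")
          exchange.sendResponseHeaders(200, body.length.toLong)
          val os = exchange.getResponseBody
          try os.write(body) finally os.close()
        })
        // No executor: requests run on whichever thread called start().
        server.setExecutor(null)
        // Start from a daemon thread so the JVM may exit while the server is still up.
        val t = Executors.defaultThreadFactory().newThread(() => server.start())
        t.setDaemon(true)
        t.start()
        println(s"listening on port ${server.getAddress.getPort}")
      }
    }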
- // - /* Source: - * https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ - // - httpServer.createContext("/", Handler) - httpServer.setExecutor(null) - val t = Executors.defaultThreadFactory().newThread(new Runnable() { - def run(): Unit = - httpServer.start() - }) - t.setDaemon(true) - t - } - - def port: Int = httpServer.getAddress.getPort - - def start(): Unit = - thread.start() - - override def close(): Unit = - httpServer.stop(10) - - private case class Request(exchange: HttpExchange, payload: JValue) - - private[this] object Handler extends HttpHandler with HttpLikeBackendRpc[Request] { - - override def handle(exchange: HttpExchange): Unit = { - val payload = using(exchange.getRequestBody)(JsonMethods.parse(_)) - runRpc(Request(exchange, payload)) - } - - implicit override protected object Ask extends Routing { - - import Routes._ - - override def route(a: Request): Route = - a.exchange.getRequestURI.getPath match { - case "/value/type" => TypeOf(Value) - case "/table/type" => TypeOf(Table) - case "/matrixtable/type" => TypeOf(Matrix) - case "/blockmatrix/type" => TypeOf(BlockMatrix) - case "/execute" => Execute - case "/vcf/metadata/parse" => ParseVcfMetadata - case "/fam/import" => ImportFam - case "/references/load" => LoadReferencesFromDataset - case "/references/from_fasta" => LoadReferencesFromFASTA - } - - override def payload(a: Request): JValue = a.payload - } - - implicit override protected object Write extends Write[Request] with ErrorHandling { - - override def timings(req: Request)(t: Timings): Unit = { - val ts = Serialization.write(Map("timings" -> t)) - req.exchange.getResponseHeaders.add("X-Hail-Timings", ts) - } - - override def result(req: Request)(result: Array[Byte]): Unit = - respond(req)(HttpStatusCodes.STATUS_CODE_OK, result) - - override def error(req: Request)(t: Throwable): Unit = - respond(req)( - HttpStatusCodes.STATUS_CODE_SERVER_ERROR, - jsonToBytes { - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - JObject( - "short" -> JString(shortMessage), - "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId), - ) - }, - ) - - private[this] def respond(req: Request)(code: Int, payload: Array[Byte]): Unit = { - req.exchange.sendResponseHeaders(code, payload.length) - using(req.exchange.getResponseBody)(_.write(payload)) - } - } - - implicit override protected object Context extends Context[Request] { - override def scoped[A](req: Request)(f: ExecuteContext => A): (A, Timings) = - backend.withExecuteContext(f) - } - } -} diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index ac66ed222b4..a4dc9c3aa2f 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -1,6 +1,6 @@ package is.hail.backend -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailFeatureFlags import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext @@ -55,11 +55,6 @@ object NonOwningTempFileManager { } object ExecuteContext { - def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - HailContext.sparkBackend.withExecuteContext( - selfContainedExecution = false - )(f) - def scoped[T]( tmpdir: String, localTmpdir: String, @@ -69,7 +64,6 @@ object ExecuteContext { tempFileManager: TempFileManager, theHailClassLoader: 
HailClassLoader, flags: HailFeatureFlags, - backendContext: BackendContext, irMetadata: IrMetadata, references: mutable.Map[String, ReferenceGenome], blockMatrixCache: mutable.Map[String, BlockMatrix], @@ -91,7 +85,6 @@ object ExecuteContext { tempFileManager, theHailClassLoader, flags, - backendContext, irMetadata, references, blockMatrixCache, @@ -124,7 +117,6 @@ class ExecuteContext( _tempFileManager: TempFileManager, val theHailClassLoader: HailClassLoader, val flags: HailFeatureFlags, - val backendContext: BackendContext, var irMetadata: IrMetadata, val References: mutable.Map[String, ReferenceGenome], val BlockMatrixCache: mutable.Map[String, BlockMatrix], @@ -197,7 +189,6 @@ class ExecuteContext( tempFileManager: TempFileManager = NonOwningTempFileManager(this.tempFileManager), theHailClassLoader: HailClassLoader = this.theHailClassLoader, flags: HailFeatureFlags = this.flags, - backendContext: BackendContext = this.backendContext, references: mutable.Map[String, ReferenceGenome] = this.References, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, @@ -217,7 +208,6 @@ class ExecuteContext( tempFileManager, theHailClassLoader, flags, - backendContext, irMetadata, references, blockMatrixCache, diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala new file mode 100644 index 00000000000..bfce7175d48 --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala @@ -0,0 +1,404 @@ +package is.hail.backend.api + +import is.hail.HailFeatureFlags +import is.hail.asm4s.HailClassLoader +import is.hail.backend._ +import is.hail.backend.caching.BlockMatrixCache +import is.hail.backend.spark.SparkBackend +import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} +import is.hail.expr.ir.{ + BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, + Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue, +} +import is.hail.expr.ir.IRParser.parseType +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs._ +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.linalg.RowMatrix +import is.hail.types.physical.PStruct +import is.hail.types.virtual.{TArray, TInterval} +import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} +import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.collection.mutable +import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} + +import java.io.Closeable +import java.net.InetSocketAddress +import java.util +import java.util.concurrent._ + +import com.google.api.client.http.HttpStatusCodes +import com.sun.net.httpserver.{HttpExchange, HttpServer} +import org.apache.hadoop +import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.DataFrame +import org.json4s +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} +import sourcecode.Enclosing + +final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { + + private[this] val tmpdir: String = ??? + private[this] val localTmpdir: String = ??? 
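The two directory fields above are left as ???, Scala's standard placeholder from Predef: it type-checks as any expected type but throws scala.NotImplementedError the first time it is evaluated. This first patch in the series only introduces the fields; presumably a later patch wires them to the tmpdir and remote_tmpdir arguments that the Python Py4JBackend constructor now receives. A minimal standalone illustration of the placeholder's behaviour (not part of the patch):

    object PlaceholderSketch {
      // ??? compiles anywhere because its type is Nothing, a subtype of every type.
      def tmpdir: String = ???

      def main(args: Array[String]): Unit =
        try println(tmpdir)
        catch { case _: NotImplementedError => println("tmpdir has not been wired up yet") }
    }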
+ private[this] val longLifeTempFileManager = null: TempFileManager + private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() + private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) + private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*) + private[this] val bmCache = new BlockMatrixCache() + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) + private[this] val persistedIr = mutable.Map[Int, BaseIR]() + private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) + + private[this] def cloudfsConfig = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) + + def fs: FS = backend match { + case s: SparkBackend => + val conf = new Configuration(s.sc.hadoopConfiguration) + cloudfsConfig.google.flatMap(_.requester_pays_config).foreach { + case RequesterPaysConfig(prj, bkts) => + bkts + .map { buckets => + conf.set("fs.gs.requester.pays.mode", "CUSTOM") + conf.set("fs.gs.requester.pays.project.id", prj) + conf.set("fs.gs.requester.pays.buckets", buckets.mkString(",")) + } + .getOrElse { + conf.set("fs.gs.requester.pays.mode", "AUTO") + conf.set("fs.gs.requester.pays.project.id", prj) + } + } + new HadoopFS(new SerializableHadoopConfiguration(conf)) + + case _ => + RouterFS.buildRoutes(cloudfsConfig) + } + + def pyGetFlag(name: String): String = + flags.get(name) + + def pySetFlag(name: String, value: String): Unit = + flags.set(name, value) + + def pyAvailableFlags: java.util.ArrayList[String] = + flags.available + + private[this] var irID: Int = 0 + + private[this] def nextIRID(): Int = { + irID += 1 + irID + } + + private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { + val id = nextIRID() + ctx.IrCache += (id -> ir) + id + } + + def pyRemoveJavaIR(id: Int): Unit = + persistedIr.remove(id) + + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = + references(name).addSequence(IndexedFastaSequenceFile(fs, fastaFile, indexFile)) + + def pyRemoveSequence(name: String): Unit = + references(name).removeSequence() + + def pyExportBlockMatrix( + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: java.lang.Integer, + entries: String, + ): Unit = + withExecuteContext() { ctx => + val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) + entries match { + case "full" => + rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "lower" => + rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_lower" => + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "upper" => + rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "strict_upper" => + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + } + } + + def pyRegisterIR( + name: String, + typeParamStrs: java.util.ArrayList[String], + argNameStrs: java.util.ArrayList[String], + argTypeStrs: java.util.ArrayList[String], + returnType: String, + bodyStr: String, + ): Unit = + withExecuteContext() { ctx => + IRFunctionRegistry.registerIR( + ctx, + name, + typeParamStrs.asScala.toArray, + argNameStrs.asScala.toArray, + argTypeStrs.asScala.toArray, + returnType, + bodyStr, + ) + } + + def pyExecuteLiteral(irStr: String): Int = + withExecuteContext() { ctx => + val ir = 
IRParser.parse_value_ir(ctx, irStr) + backend.execute(ctx, ir) match { + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + addJavaIR(ctx, field) + } + }._1 + + def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = + withExecuteContext(selfContainedExecution = false) { ctx => + val key = jKey.asScala.toArray.toFastSeq + val signature = + SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] + val tir = TableLiteral( + TableValue( + ctx, + signature.virtualType, + key, + df.rdd, + Some(signature), + ), + ctx.theHailClassLoader, + ) + val id = addJavaIR(ctx, tir) + (id, JsonMethods.compact(tir.typ.toJSON)) + }._1 + + def pyToDF(s: String): DataFrame = + withExecuteContext() { ctx => + val tir = IRParser.parse_table_ir(ctx, s) + Interpret(tir, ctx).toDF() + }._1 + + def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = + withExecuteContext() { ctx => + log.info("pyReadMultipleMatrixTables: got query") + val kvs = JsonMethods.parse(jsonQuery) match { + case json4s.JObject(values) => values.toMap + } + + val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { + case json4s.JString(s) => s + } + + val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) + val intervalObjects = + JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) + .asInstanceOf[IndexedSeq[Interval]] + + val opts = NativeReaderOptions(intervalObjects, intervalPointType) + val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => + log.info(s"creating MatrixRead node for $p") + val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) + MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR + } + log.info("pyReadMultipleMatrixTables: returning N matrix tables") + matrixReaders.asJava + }._1 + + def pyAddReference(jsonConfig: String): Unit = + addReference(ReferenceGenome.fromJSON(jsonConfig)) + + def pyRemoveReference(name: String): Unit = + removeReference(name) + + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = + references(name).addLiftover(references(destRGName), LiftOver(fs, chainFile)) + + def pyRemoveLiftover(name: String, destRGName: String): Unit = + references(name).removeLiftover(destRGName) + + def parse_blockmatrix_ir(s: String): BlockMatrixIR = + withExecuteContext(selfContainedExecution = false) { ctx => + IRParser.parse_blockmatrix_ir(ctx, s) + }._1 + + private[this] def addReference(rg: ReferenceGenome): Unit = { + references.get(rg.name) match { + case Some(rg2) => + if (rg != rg2) { + fatal( + s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. 
Choose a reference name NOT in the following list:\n " + + s"@1", + references.keys.truncatable("\n "), + ) + } + case None => + references += (rg.name -> rg) + } + } + + private[this] def removeReference(name: String): Unit = + references -= name + + private def withExecuteContext[T]( + selfContainedExecution: Boolean = true + )( + f: ExecuteContext => T + )(implicit E: Enclosing + ): (T, Timings) = + ExecutionTimer.time { timer => + ExecuteContext.scoped( + tmpdir = tmpdir, + localTmpdir = localTmpdir, + backend = backend, + fs = fs, + timer = timer, + tempFileManager = + if (selfContainedExecution) null + else NonOwningTempFileManager(longLifeTempFileManager), + theHailClassLoader = theHailClassLoader, + flags = flags, + irMetadata = IrMetadata(None), + references = references, + blockMatrixCache = bmCache, + codeCache = codeCache, + irCache = persistedIr, + coercerCache = coercerCache, + )(f) + } + + override def close(): Unit = + synchronized { + bmCache.close() + codeCache.clear() + persistedIr.clear() + coercerCache.clear() + backend.close() + + if (backend.isInstanceOf[SparkBackend]) { + // Hadoop does not honor the hadoop configuration as a component of the cache key for file + // systems, so we blow away the cache so that a new configuration can successfully take + // effect. + // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 + hadoop.fs.FileSystem.closeAll() + } + } + + def pyHttpServer: HttpLikeBackendRpc[HttpExchange] with Closeable = + new HttpLikeBackendRpc[HttpExchange] with Closeable { + implicit object Handler extends Routing with Write[HttpExchange] with Context[HttpExchange] { + + override def route(req: HttpExchange): Route = + req.getRequestURI.getPath match { + case "/value/type" => Routes.TypeOf(Value) + case "/table/type" => Routes.TypeOf(Table) + case "/matrixtable/type" => Routes.TypeOf(Matrix) + case "/blockmatrix/type" => Routes.TypeOf(BlockMatrix) + case "/execute" => Routes.Execute + case "/vcf/metadata/parse" => Routes.ParseVcfMetadata + case "/fam/import" => Routes.ImportFam + case "/references/load" => Routes.LoadReferencesFromDataset + case "/references/from_fasta" => Routes.LoadReferencesFromFASTA + } + + override def payload(req: HttpExchange): JValue = + using(req.getRequestBody)(JsonMethods.parse(_)) + + override def timings(req: HttpExchange)(t: Timings): Unit = { + val ts = Serialization.write(Map("timings" -> t)) + req.getResponseHeaders.add("X-Hail-Timings", ts) + } + + override def result(req: HttpExchange)(result: Array[Byte]): Unit = + respond(req)(HttpStatusCodes.STATUS_CODE_OK, result) + + override def error(req: HttpExchange)(t: Throwable): Unit = + respond(req)( + HttpStatusCodes.STATUS_CODE_SERVER_ERROR, + jsonToBytes { + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId), + ) + }, + ) + + private[this] def respond(req: HttpExchange)(code: Int, payload: Array[Byte]): Unit = { + req.sendResponseHeaders(code, payload.length) + using(req.getResponseBody)(_.write(payload)) + } + + override def scoped[A](req: HttpExchange)(f: ExecuteContext => A): (A, Timings) = + withExecuteContext()(f) + } + + // 0 => let the OS pick an available port + private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) + + private[this] val thread = { + // This HTTP server *must not* start non-daemon threads because such threads keep the JVM + // alive. 
A living JVM indicates to Spark that the job is incomplete. This does not manifest + /* when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of + * the */ + // JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel + // explicitly regardless of the JVM). It *does* manifest when submitting jobs with + // + // gcloud dataproc submit ... + // + // or + // + // spark-submit + // + // setExecutor(null) ensures the server creates no new threads: + // + /* > If this method is not called (before start()) or if it is called with a null Executor, + * then */ + /* > a default implementation is used, which uses the thread which was created by the + * start() */ + // > method. + // + // Source: + /* https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ + // + httpServer.createContext("/", runRpc(_: HttpExchange)) + httpServer.setExecutor(null) + val t = Executors.defaultThreadFactory().newThread(() => httpServer.start()) + t.setDaemon(true) + t + } + + def port: Int = httpServer.getAddress.getPort + override def close(): Unit = httpServer.stop(10) + + thread.start() + } +} diff --git a/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala new file mode 100644 index 00000000000..4e5659dc13b --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala @@ -0,0 +1,267 @@ +package is.hail.backend.api + +import is.hail.{HailContext, HailFeatureFlags} +import is.hail.annotations.Memory +import is.hail.asm4s.HailClassLoader +import is.hail.backend.{Backend, ExecuteContext, HttpLikeBackendRpc} +import is.hail.backend.caching.NoCaching +import is.hail.backend.service._ +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS} +import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} +import is.hail.services._ +import is.hail.types.virtual.Kinds +import is.hail.utils.{toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.annotation.switch +import scala.collection.mutable + +import java.io.OutputStream +import java.nio.charset.StandardCharsets +import java.nio.file.Path + +import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods + +object ServiceBackendApi extends HttpLikeBackendRpc[Request] with Logging { + + implicit object Handler + extends Routing with Write[Request] with Context[Request] with ErrorHandling { + import Routes._ + + override def route(a: Request): Route = + (a.action: @switch) match { + case 1 => TypeOf(Kinds.Value) + case 2 => TypeOf(Kinds.Table) + case 3 => TypeOf(Kinds.Matrix) + case 4 => TypeOf(Kinds.BlockMatrix) + case 5 => Execute + case 6 => ParseVcfMetadata + case 7 => ImportFam + case 8 => LoadReferencesFromDataset + case 9 => LoadReferencesFromFASTA + } + + override def payload(a: Request): JValue = a.payload + + // service backend doesn't support sending timings back to the python client + override def timings(env: Request)(t: Timings): Unit = + () + + override def result(env: Request)(result: Array[Byte]): Unit = + retryTransientErrors { + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + + 
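The result override above and the error override below both write to the job's output URL through HailSocketAPIOutputStream (defined at the bottom of this file): a single status byte (1 for success, 0 for failure), followed by either one length-prefixed result payload or two length-prefixed UTF-8 strings plus a 4-byte error id. A client-side reader for that framing might look like the sketch below; it assumes the 4-byte length prefixes are little-endian, since Memory.storeInt writes in native byte order on the usual x86/ARM64 hosts, and all names are illustrative:

    import java.io.{EOFException, InputStream}
    import java.nio.{ByteBuffer, ByteOrder}

    object DriverOutputReaderSketch {
      private def readN(in: InputStream, n: Int): Array[Byte] = {
        val buf = new Array[Byte](n)
        var off = 0
        while (off < n) {
          val r = in.read(buf, off, n - off)
          if (r < 0) throw new EOFException()
          off += r
        }
        buf
      }

      private def readInt(in: InputStream): Int =
        ByteBuffer.wrap(readN(in, 4)).order(ByteOrder.LITTLE_ENDIAN).getInt

      private def readBytes(in: InputStream): Array[Byte] =
        readN(in, readInt(in))

      // Right(payload) on success; Left((shortMessage, expandedMessage, errorId)) on failure.
      def read(in: InputStream): Either[(String, String, Int), Array[Byte]] =
        if (in.read() == 1) Right(readBytes(in))
        else {
          val short = new String(readBytes(in), "UTF-8")
          val expanded = new String(readBytes(in), "UTF-8")
          Left((short, expanded, readInt(in)))
        }
    }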
override def error(env: Request)(t: Throwable): Unit = + retryTransientErrors { + val (shortMessage, expandedMessage, errorId) = + t match { + case t: HailWorkerException => + log.error( + "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", + t, + ) + (t.shortMessage, t.expandedMessage, t.errorId) + case _ => + log.error( + "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job.", + t, + ) + handleForPython(t) + } + + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + + throw t + } + + override def scoped[A](env: Request)(f: ExecuteContext => A): (A, Timings) = + ExecutionTimer.time { timer => + ExecuteContext.scoped( + env.rpcConfig.tmp_dir, + env.rpcConfig.remote_tmpdir, + env.backend, + env.fs, + timer, + null, + env.hcl, + env.flags, + IrMetadata(None), + mutable.Map(env.references.toSeq: _*), + NoCaching, + NoCaching, + NoCaching, + NoCaching, + )(f) + } + } + + def main(argv: Array[String]): Unit = { + assert(argv.length == 7, argv.toFastSeq) + + val scratchDir = argv(0) + // val logFile = argv(1) + val jarLocation = argv(2) + val kind = argv(3) + assert(kind == Main.DRIVER) + val name = argv(4) + val inputURL = argv(5) + val outputURL = argv(6) + + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") + DeployConfig.set(deployConfig) + sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) + + var fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + HailFeatureFlags.fromEnv(), + ) + ) + + val (rpcConfig, jobConfig, action, payload) = + using(fs.openNoCompression(inputURL)) { is => + val input = JsonMethods.parse(is) + ( + (input \ "rpc_config").extract[ServiceBackendRPCPayload], + (input \ "job_config").extract[BatchJobConfig], + (input \ "action").extract[Int], + input \ "payload", + ) + } + + // requester pays config is conveyed in feature flags currently + val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) + fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + featureFlags, + ) + ) + + val references: Map[String, ReferenceGenome] = + ReferenceGenome.builtinReferences() ++ + rpcConfig.custom_references.map(ReferenceGenome.fromJSON).map(rg => rg.name -> rg) + + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } + + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } + + val backend = new ServiceBackend( + name, + BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), + jarLocation, + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), + jobConfig, + ) + + log.info("ServiceBackend allocated.") + if (HailContext.isInitialized) { + HailContext.get.backend = backend + log.info("Default references added to already initialized HailContexet.") + } else { + HailContext(backend, 50, 3) + log.info("HailContexet initialized.") + } + + // FIXME: when can the 
classloader be shared? (optimizer benefits!) + runRpc( + Request( + backend, + featureFlags, + new HailClassLoader(getClass.getClassLoader), + rpcConfig, + fs, + references, + outputURL, + action, + payload, + ) + ) + } +} + +case class Request( + backend: Backend, + flags: HailFeatureFlags, + hcl: HailClassLoader, + rpcConfig: ServiceBackendRPCPayload, + fs: FS, + references: Map[String, ReferenceGenome], + outputUrl: String, + action: Int, + payload: JValue, +) + +private class HailSocketAPIOutputStream( + private[this] val out: OutputStream +) extends AutoCloseable { + private[this] var closed: Boolean = false + private[this] val dummy = new Array[Byte](8) + + def writeBool(b: Boolean): Unit = + out.write(if (b) 1 else 0) + + def writeInt(v: Int): Unit = { + Memory.storeInt(dummy, 0, v) + out.write(dummy, 0, 4) + } + + def writeLong(v: Long): Unit = { + Memory.storeLong(dummy, 0, v) + out.write(dummy) + } + + def writeBytes(bytes: Array[Byte]): Unit = { + writeInt(bytes.length) + out.write(bytes) + } + + def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) + + def close(): Unit = + if (!closed) { + out.close() + closed = true + } +} + +case class SequenceConfig(fasta: String, index: String) + +case class ServiceBackendRPCPayload( + tmp_dir: String, + remote_tmpdir: String, + flags: Map[String, String], + custom_references: Array[String], + liftovers: Map[String, Map[String, String]], + sequences: Map[String, SequenceConfig], +) + +case class BatchJobConfig( + token: String, + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], +) diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index c9f4d1b1d66..53f34869f79 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -1,13 +1,11 @@ package is.hail.backend.local -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{CancellingExecutorService, HailContext} import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ -import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ @@ -17,17 +15,12 @@ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings -import is.hail.variant.ReferenceGenome -import scala.collection.mutable import scala.reflect.ClassTag import java.io.PrintWriter import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop -import sourcecode.Enclosing class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable @@ -35,95 +28,35 @@ class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskC override def attemptNumber(): Int = 0 } -object LocalBackend { - private var theLocalBackend: LocalBackend = _ +object LocalBackend extends Backend { def apply( - tmpdir: String, logFile: String = "hail.log", quiet: Boolean = false, append: Boolean = false, skipLoggingConfiguration: Boolean = false, - ): LocalBackend = synchronized { - 
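This LocalBackend hunk turns the backend from a class tracked through a private theLocalBackend variable into a plain Scala object: apply() now only configures logging and hands the singleton back, which is why its result type is written LocalBackend.type. A small standalone sketch of that pattern, with illustrative names only:

    trait Service extends AutoCloseable {
      def defaultParallelism: Int
    }

    object LocalService extends Service {
      override val defaultParallelism: Int = 1

      // LocalService.type denotes the singleton's own type, so callers get the object itself back.
      def apply(configureLogging: Boolean = true): LocalService.type = synchronized {
        if (configureLogging) println("logging configured")
        this
      }

      override def close(): Unit = ()
    }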
require(theLocalBackend == null) - - if (!skipLoggingConfiguration) - HailContext.configureLogging(logFile, quiet, append) - - theLocalBackend = new LocalBackend( - tmpdir, - mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), - ) - - theLocalBackend - } - - def stop(): Unit = synchronized { - if (theLocalBackend != null) { - theLocalBackend = null - // Hadoop does not honor the hadoop configuration as a component of the cache key for file - // systems, so we blow away the cache so that a new configuration can successfully take - // effect. - // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 - hadoop.fs.FileSystem.closeAll() + ): LocalBackend.type = + synchronized { + if (!skipLoggingConfiguration) HailContext.configureLogging(logFile, quiet, append) + this } - } -} -class LocalBackend( - val tmpdir: String, - override val references: mutable.Map[String, ReferenceGenome], -) extends Backend with Py4JBackendExtensions { - - override def backend: Backend = this - override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - override def longLifeTempFileManager: TempFileManager = null - - private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) - private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) - private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() - private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - - // flags can be set after construction from python - def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = - ExecutionTimer.time { timer => - val fs = this.fs - ExecuteContext.scoped( - tmpdir, - tmpdir, - this, - fs, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }, - IrMetadata(None), - references, - ImmutableMap.empty, - codeCache, - persistedIR, - coercerCache, - )(f) - } + private case class Context(hcl: HailClassLoader, override val executionCache: ExecutionCache) + extends BackendContext def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value) private[this] var stageIdx: Int = 0 - private[this] def nextStageId(): Int = { - val current = stageIdx - stageIdx += 1 - current - } + private[this] def nextStageId(): Int = + synchronized { + val current = stageIdx + stageIdx += 1 + current + } override def parallelizeAndComputeWithIndex( - backendContext: BackendContext, + ctx: BackendContext, fs: FS, contexts: IndexedSeq[Array[Byte]], stageIdentifier: String, @@ -134,22 +67,24 @@ class LocalBackend( ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { val stageId = nextStageId() + val hcl = ctx.asInstanceOf[Context].hcl runAllKeepFirstError(new CancellingExecutorService(MoreExecutors.newDirectExecutorService())) { partitions.getOrElse(contexts.indices).map { i => ( - () => - using(new LocalTaskContext(i, stageId)) { - f(contexts(i), _, theHailClassLoader, fs) - }, + () => using(new LocalTaskContext(i, stageId))(f(contexts(i), _, hcl, fs)), i, ) } } } + override def backendContext(ctx: ExecuteContext): BackendContext = + Context(ctx.theHailClassLoader, ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir)) + def defaultParallelism: Int = 1 - def close(): Unit = LocalBackend.stop() + def close(): Unit = + synchronized { stageIdx = 0 } private[this] def 
_jvmLowerAndExecute( ctx: ExecuteContext, @@ -157,8 +92,7 @@ class LocalBackend( print: Option[PrintWriter] = None, ): Either[Unit, (PTuple, Long)] = ctx.time { - val ir = - LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, ir0).asInstanceOf[IR] + val ir = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All)(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala deleted file mode 100644 index f2d972e570a..00000000000 --- a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ /dev/null @@ -1,239 +0,0 @@ -package is.hail.backend.py4j - -import is.hail.HailFeatureFlags -import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} -import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} -import is.hail.expr.ir.{ - BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, - MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue, -} -import is.hail.expr.ir.IRParser.parseType -import is.hail.expr.ir.functions.IRFunctionRegistry -import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} -import is.hail.linalg.RowMatrix -import is.hail.types.physical.PStruct -import is.hail.types.virtual.{TArray, TInterval} -import is.hail.utils.{fatal, log, toRichIterable, HailException, Interval} -import is.hail.variant.ReferenceGenome - -import scala.collection.mutable -import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} - -import java.util - -import org.apache.spark.sql.DataFrame -import org.json4s -import org.json4s.jackson.JsonMethods -import sourcecode.Enclosing - -trait Py4JBackendExtensions { - def backend: Backend - def references: mutable.Map[String, ReferenceGenome] - def flags: HailFeatureFlags - def longLifeTempFileManager: TempFileManager - - def pyGetFlag(name: String): String = - flags.get(name) - - def pySetFlag(name: String, value: String): Unit = - flags.set(name, value) - - def pyAvailableFlags: java.util.ArrayList[String] = - flags.available - - private[this] var irID: Int = 0 - - private[this] def nextIRID(): Int = { - irID += 1 - irID - } - - private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { - val id = nextIRID() - ctx.IrCache += (id -> ir) - id - } - - def pyRemoveJavaIR(id: Int): Unit = - backend.withExecuteContext(_.IrCache.remove(id)) - - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - backend.withExecuteContext { ctx => - references(name).addSequence(IndexedFastaSequenceFile(ctx.fs, fastaFile, indexFile)) - } - - def pyRemoveSequence(name: String): Unit = - references(name).removeSequence() - - def pyExportBlockMatrix( - pathIn: String, - pathOut: String, - delimiter: String, - header: String, - addIndex: Boolean, - exportType: String, - partitionSize: java.lang.Integer, - entries: String, - ): Unit = - backend.withExecuteContext { ctx => - val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) - entries match { - case "full" => - rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "lower" => - rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_lower" => - rm.exportStrictLowerTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - 
exportType, - ) - case "upper" => - rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) - case "strict_upper" => - rm.exportStrictUpperTriangle( - ctx, - pathOut, - delimiter, - Option(header), - addIndex, - exportType, - ) - } - } - - def pyRegisterIR( - name: String, - typeParamStrs: java.util.ArrayList[String], - argNameStrs: java.util.ArrayList[String], - argTypeStrs: java.util.ArrayList[String], - returnType: String, - bodyStr: String, - ): Unit = - backend.withExecuteContext { ctx => - IRFunctionRegistry.registerIR( - ctx, - name, - typeParamStrs.asScala.toArray, - argNameStrs.asScala.toArray, - argTypeStrs.asScala.toArray, - returnType, - bodyStr, - ) - } - - def pyExecuteLiteral(irStr: String): Int = - backend.withExecuteContext { ctx => - val ir = IRParser.parse_value_ir(ctx, irStr) - backend.execute(ctx, ir) match { - case Left(_) => throw new HailException("Can't create literal") - case Right((pt, addr)) => - val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) - addJavaIR(ctx, field) - } - }._1 - - def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { - val key = jKey.asScala.toArray.toFastSeq - val signature = - SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] - withExecuteContext(selfContainedExecution = false) { ctx => - val tir = TableLiteral( - TableValue( - ctx, - signature.virtualType, - key, - df.rdd, - Some(signature), - ), - ctx.theHailClassLoader, - ) - val id = addJavaIR(ctx, tir) - (id, JsonMethods.compact(tir.typ.toJSON)) - } - } - - def pyToDF(s: String): DataFrame = - backend.withExecuteContext { ctx => - val tir = IRParser.parse_table_ir(ctx, s) - Interpret(tir, ctx).toDF() - }._1 - - def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = - backend.withExecuteContext { ctx => - log.info("pyReadMultipleMatrixTables: got query") - val kvs = JsonMethods.parse(jsonQuery) match { - case json4s.JObject(values) => values.toMap - } - - val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { - case json4s.JString(s) => s - } - - val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) - val intervalObjects = - JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) - .asInstanceOf[IndexedSeq[Interval]] - - val opts = NativeReaderOptions(intervalObjects, intervalPointType) - val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => - log.info(s"creating MatrixRead node for $p") - val mnr = MatrixNativeReader(ctx.fs, p, Some(opts)) - MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR - } - log.info("pyReadMultipleMatrixTables: returning N matrix tables") - matrixReaders.asJava - }._1 - - def pyAddReference(jsonConfig: String): Unit = - addReference(ReferenceGenome.fromJSON(jsonConfig)) - - def pyRemoveReference(name: String): Unit = - removeReference(name) - - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - backend.withExecuteContext { ctx => - references(name).addLiftover(references(destRGName), LiftOver(ctx.fs, chainFile)) - } - - def pyRemoveLiftover(name: String, destRGName: String): Unit = - references(name).removeLiftover(destRGName) - - private[this] def addReference(rg: ReferenceGenome): Unit = { - references.get(rg.name) match { - case Some(rg2) => - if (rg != rg2) { - fatal( - s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. 
Choose a reference name NOT in the following list:\n " + - s"@1", - references.keys.truncatable("\n "), - ) - } - case None => - references += (rg.name -> rg) - } - } - - private[this] def removeReference(name: String): Unit = - references -= name - - def parse_blockmatrix_ir(s: String): BlockMatrixIR = - withExecuteContext(selfContainedExecution = false) { ctx => - IRParser.parse_blockmatrix_ir(ctx, s) - } - - def withExecuteContext[T]( - selfContainedExecution: Boolean = true - )( - f: ExecuteContext => T - )(implicit E: Enclosing - ): T = - backend.withExecuteContext { ctx => - val tempFileManager = longLifeTempFileManager - if (selfContainedExecution && tempFileManager != null) f(ctx) - else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) - }._1 -} diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala index 698f5ffa23c..3558dc9a54e 100644 --- a/hail/src/main/scala/is/hail/backend/service/Main.scala +++ b/hail/src/main/scala/is/hail/backend/service/Main.scala @@ -1,5 +1,7 @@ package is.hail.backend.service +import is.hail.backend.api.ServiceBackendApi + object Main { val WORKER = "worker" val DRIVER = "driver" @@ -7,7 +9,7 @@ object Main { def main(argv: Array[String]): Unit = argv(3) match { case WORKER => Worker.main(argv) - case DRIVER => ServiceBackendAPI.main(argv) + case DRIVER => ServiceBackendApi.main(argv) case kind => throw new RuntimeException(s"unknown kind: $kind") } } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 0fc23636063..923e73535f1 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -1,10 +1,10 @@ package is.hail.backend.service -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{CancellingExecutorService, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.caching.NoCaching +import is.hail.backend.api.BatchJobConfig import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate import is.hail.expr.ir.{ @@ -14,36 +14,20 @@ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ import is.hail.io.fs._ -import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.services._ import is.hail.services.JobGroupStates.Failure import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual.{Kinds, TVoid} +import is.hail.types.virtual.TVoid import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings -import is.hail.variant.ReferenceGenome -import scala.annotation.switch -import scala.collection.mutable import scala.reflect.ClassTag import java.io._ import java.nio.charset.StandardCharsets -import java.nio.file.Path import java.util.concurrent._ -import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods -import sourcecode.Enclosing - -case class ServiceBackendContext( - remoteTmpDir: String, - jobConfig: BatchJobConfig, - override val executionCache: ExecutionCache, -) extends BackendContext with Serializable - object ServiceBackend { val MaxAvailableGcsConnections = 1000 } @@ -52,17 +36,23 @@ class ServiceBackend( val name: String, batchClient: 
BatchClient, jarLocation: String, - theHailClassLoader: HailClassLoader, batchConfig: Option[BatchConfig], - rpcConfig: ServiceBackendRPCPayload, jobConfig: BatchJobConfig, - flags: HailFeatureFlags, - val fs: FS, - references: mutable.Map[String, ReferenceGenome], ) extends Backend with Logging { + case class Context( + remoteTmpDir: String, + batchConfig: BatchConfig, + jobConfig: BatchJobConfig, + flags: HailFeatureFlags, + override val executionCache: ExecutionCache, + ) extends BackendContext + private[this] var stageCount = 0 - private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections) + + private[this] val executor = lazily { + Executors.newFixedThreadPool(MaxAvailableGcsConnections) + } def defaultParallelism: Int = 4 @@ -80,6 +70,27 @@ class ServiceBackend( } } + override def backendContext(ctx: ExecuteContext): BackendContext = + Context( + remoteTmpDir = ctx.tmpdir, + batchConfig = batchConfig.getOrElse { + BatchConfig( + batchId = batchClient.newBatch( + BatchRequest( + billing_project = jobConfig.billing_project, + n_jobs = 0, + token = jobConfig.token, + attributes = Map("name" -> name), + ) + ), + jobGroupId = 0, + ) + }, + jobConfig = jobConfig, + flags = ctx.flags, + executionCache = ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir), + ) + private[this] def readString(in: DataInputStream): String = { val n = in.readInt() val bytes = new Array[Byte](n) @@ -94,8 +105,8 @@ class ServiceBackend( stageIdentifier: String, f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte], ): (String, String, Int) = { - val ServiceBackendContext(remoteTmp, jobConfig, _) = - _backendContext.asInstanceOf[ServiceBackendContext] + val Context(remoteTmp, batchConfig, jobConfig, flags, _) = + _backendContext.asInstanceOf[Context] val n = collection.length val token = tokenUrlSafe val root = s"$remoteTmp/parallelizeAndComputeWithIndex/$token" @@ -128,7 +139,7 @@ class ServiceBackend( val jobGroup = JobGroupRequest( job_group_id = 1, // QoB creates an update for every new stage - absolute_parent_id = batchConfig.map(_.jobGroupId).getOrElse(0), + absolute_parent_id = batchConfig.jobGroupId, attributes = Map("name" -> stageIdentifier), ) @@ -164,29 +175,18 @@ class ServiceBackend( log.info(s"parallelizeAndComputeWithIndex: $token: running job") - val batchId = batchConfig.map(_.batchId).getOrElse { - batchClient.newBatch( - BatchRequest( - billing_project = jobConfig.billing_project, - n_jobs = 0, - token = token, - attributes = Map("name" -> name), - ) - ) - } - - val (updateId, jobGroupId) = batchClient.newJobGroup(batchId, token, jobGroup, jobs) - val response = batchClient.waitForJobGroup(batchId, jobGroupId) + val (updateId, jobGroupId) = batchClient.newJobGroup(batchConfig.batchId, token, jobGroup, jobs) + val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) stageCount += 1 if (response.state == Failure) { - throw new HailBatchFailure(s"Update $updateId for batch $batchId failed") + throw new HailBatchFailure(s"Update $updateId for batch ${batchConfig.batchId} failed") } (token, root, n) } - private[this] def readResult(root: String, i: Int): Array[Byte] = { + private[this] def readResult(fs: FS, root: String, i: Int): Array[Byte] = { val bytes = fs.readNoCompression(s"$root/result.$i") if (bytes(0) != 0) { bytes.slice(1, bytes.length) @@ -223,7 +223,7 @@ class ServiceBackend( val startTime = System.nanoTime() val r @ (error, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) { (partIdxs, 
parts.indices).zipped.map { (partIdx, jobIndex) => - (() => readResult(root, jobIndex), partIdx) + (() => readResult(fs, root, jobIndex), partIdx) } } @@ -237,7 +237,7 @@ class ServiceBackend( } override def close(): Unit = { - executor.shutdownNow() + if (executor.isEvaluated) executor.shutdownNow() batchClient.close() } @@ -296,251 +296,6 @@ class ServiceBackend( def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = - ExecutionTimer.time { timer => - ExecuteContext.scoped( - rpcConfig.tmp_dir, - rpcConfig.remote_tmpdir, - this, - fs, - timer, - null, - theHailClassLoader, - flags, - ServiceBackendContext( - rpcConfig.remote_tmpdir, - jobConfig, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ), - IrMetadata(None), - references, - NoCaching, - NoCaching, - NoCaching, - NoCaching, - )(f) - } } -class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) - -case class Request( - backend: ServiceBackend, - fs: FS, - outputUrl: String, - action: Int, - payload: JValue, -) - -object ServiceBackendAPI extends HttpLikeBackendRpc[Request] with Logging { - - def main(argv: Array[String]): Unit = { - assert(argv.length == 7, argv.toFastSeq) - - val scratchDir = argv(0) - // val logFile = argv(1) - val jarLocation = argv(2) - val kind = argv(3) - assert(kind == Main.DRIVER) - val name = argv(4) - val inputURL = argv(5) - val outputURL = argv(6) - - val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") - DeployConfig.set(deployConfig) - sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - - var fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - HailFeatureFlags.fromEnv(), - ) - ) - - val (rpcConfig, jobConfig, action, payload) = - using(fs.openNoCompression(inputURL)) { is => - val input = JsonMethods.parse(is) - ( - (input \ "rpc_config").extract[ServiceBackendRPCPayload], - (input \ "job_config").extract[BatchJobConfig], - (input \ "action").extract[Int], - input \ "payload", - ) - } - - // requester pays config is conveyed in feature flags currently - val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) - fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - featureFlags, - ) - ) - - val references = mutable.Map[String, ReferenceGenome]() - references ++= ReferenceGenome.builtinReferences() - rpcConfig.custom_references.toFastSeq.view.map(ReferenceGenome.fromJSON).foreach { rg => - references += rg.name -> rg - } - - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) - } - } - - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) - } - - // FIXME: when can the classloader be shared? (optimizer benefits!) 
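Earlier in this file's diff, the fixed pool of MaxAvailableGcsConnections (1000) threads moved behind lazily { ... }, and close() now only shuts it down if executor.isEvaluated. That way the throwaway ServiceBackend(null, null, null, null, null) that Worker.scala constructs later in this patch never allocates a thread pool it will not use. The definition of is.hail.utils.lazily is not shown in this section, so the helper below is only an illustrative stand-in for the idea, not the real implementation:

    import java.util.concurrent.{ExecutorService, Executors}

    // Memoizes a by-name value and reports whether it has been forced yet.
    final class Lazily[A](thunk: () => A) {
      private[this] var cached: Option[A] = None
      def isEvaluated: Boolean = synchronized(cached.isDefined)
      def apply(): A = synchronized {
        cached.getOrElse { val a = thunk(); cached = Some(a); a }
      }
    }

    object Lazily {
      def lazily[A](a: => A): Lazily[A] = new Lazily(() => a)

      def main(args: Array[String]): Unit = {
        val executor = lazily(Executors.newFixedThreadPool(4): ExecutorService)
        // Nothing has been created yet, so there is nothing to shut down.
        assert(!executor.isEvaluated)
        val f = executor().submit(new Runnable { def run(): Unit = println("ran on the lazy pool") })
        f.get() // wait for the task before tearing the pool down
        if (executor.isEvaluated) executor().shutdownNow()
      }
    }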
- val backend = new ServiceBackend( - name, - BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), - jarLocation, - new HailClassLoader(getClass.getClassLoader), - BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), - rpcConfig, - jobConfig, - featureFlags, - fs, - references, - ) - - log.info("ServiceBackend allocated.") - if (HailContext.isInitialized) { - HailContext.get.backend = backend - log.info("Default references added to already initialized HailContexet.") - } else { - HailContext(backend, 50, 3) - log.info("HailContexet initialized.") - } - - runRpc(Request(backend, fs, outputURL, action, payload)) - } - - implicit override protected object Ask extends Routing { - import Routes._ - - override def route(a: Request): Route = - (a.action: @switch) match { - case 1 => TypeOf(Kinds.Value) - case 2 => TypeOf(Kinds.Table) - case 3 => TypeOf(Kinds.Matrix) - case 4 => TypeOf(Kinds.BlockMatrix) - case 5 => Execute - case 6 => ParseVcfMetadata - case 7 => ImportFam - case 8 => LoadReferencesFromDataset - case 9 => LoadReferencesFromFASTA - } - - override def payload(a: Request): JValue = a.payload - } - - implicit override protected object Write extends Write[Request] { - - // service backend doesn't support sending timings back to the python client - override def timings(env: Request)(t: Timings): Unit = - () - - override def result(env: Request)(result: Array[Byte]): Unit = - retryTransientErrors { - using(env.fs.createNoCompression(env.outputUrl)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(true) - output.writeBytes(result) - } - } - - override def error(env: Request)(t: Throwable): Unit = - retryTransientErrors { - val (shortMessage, expandedMessage, errorId) = - t match { - case t: HailWorkerException => - log.error( - "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", - t, - ) - (t.shortMessage, t.expandedMessage, t.errorId) - case _ => - log.error( - "An exception occurred in the driver. 
The exception was written for Python but we will re-throw to fail this driver job.", - t, - ) - handleForPython(t) - } - - using(env.fs.createNoCompression(env.outputUrl)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(shortMessage) - output.writeString(expandedMessage) - output.writeInt(errorId) - } - - throw t - } - } - - implicit override protected object Context extends Context[Request] { - override def scoped[A](env: Request)(f: ExecuteContext => A): (A, Timings) = - env.backend.withExecuteContext(f) - } -} - -private class HailSocketAPIOutputStream( - private[this] val out: OutputStream -) extends AutoCloseable { - private[this] var closed: Boolean = false - private[this] val dummy = new Array[Byte](8) - - def writeBool(b: Boolean): Unit = - out.write(if (b) 1 else 0) - - def writeInt(v: Int): Unit = { - Memory.storeInt(dummy, 0, v) - out.write(dummy, 0, 4) - } - - def writeLong(v: Long): Unit = { - Memory.storeLong(dummy, 0, v) - out.write(dummy) - } - - def writeBytes(bytes: Array[Byte]): Unit = { - writeInt(bytes.length) - out.write(bytes) - } - - def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) - - def close(): Unit = - if (!closed) { - out.close() - closed = true - } -} - -case class SequenceConfig(fasta: String, index: String) - -case class ServiceBackendRPCPayload( - tmp_dir: String, - remote_tmpdir: String, - flags: Map[String, String], - custom_references: Array[String], - liftovers: Map[String, Map[String, String]], - sequences: Map[String, SequenceConfig], -) - -case class BatchJobConfig( - token: String, - billing_project: String, - worker_cores: String, - worker_memory: String, - storage: String, - cloudfuse_configs: Array[CloudfuseConfig], - regions: Array[String], -) diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index cc461dcc5ca..0b3ce24cdc9 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -168,19 +168,7 @@ object Worker { timer.start("executeFunction") // FIXME: workers should not have backends, but some things do need hail contexts - val backend = new ServiceBackend( - null, - null, - null, - new HailClassLoader(getClass().getClassLoader()), - None, - null, - null, - null, - null, - null, - ) - + val backend = new ServiceBackend(null, null, null, null, null) if (HailContext.isInitialized) HailContext.get.backend = backend else HailContext(backend) diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 1c136f3b1ff..1cb02f54509 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -1,14 +1,11 @@ package is.hail.backend.spark -import is.hail.{HailContext, HailFeatureFlags} +import is.hail.HailContext import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.caching.{BlockMatrixCache, NoCaching} -import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ -import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ @@ -20,8 +17,6 @@ import is.hail.types.physical.{PStruct, PTuple} import 
is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.utils.ExecutionTimer.Timings -import is.hail.variant.ReferenceGenome import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -31,8 +26,6 @@ import scala.util.control.NonFatal import java.io.PrintWriter -import org.apache.hadoop -import org.apache.hadoop.conf.Configuration import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -205,15 +198,11 @@ object SparkBackend { append: Boolean = false, skipLoggingConfiguration: Boolean = false, minBlockSize: Long = 1L, - tmpdir: String = "/tmp", - localTmpdir: String = "file:///tmp", - gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { if (theSparkBackend == null) return SparkBackend(sc, appName, master, local, logFile, quiet, append, skipLoggingConfiguration, - minBlockSize, tmpdir, localTmpdir, gcsRequesterPaysProject, gcsRequesterPaysBuckets) + minBlockSize) // there should be only one SparkContext assert(sc == null || (sc eq theSparkBackend.sc)) @@ -249,10 +238,6 @@ object SparkBackend { append: Boolean = false, skipLoggingConfiguration: Boolean = false, minBlockSize: Long = 1L, - tmpdir: String, - localTmpdir: String, - gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { require(theSparkBackend == null) @@ -267,32 +252,19 @@ object SparkBackend { checkSparkConfiguration(sc1) - if (!quiet) - ProgressBarBuilder.build(sc1) + if (!quiet) ProgressBarBuilder.build(sc1) sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) - theSparkBackend = - new SparkBackend( - tmpdir, - localTmpdir, - sc1, - mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*), - gcsRequesterPaysProject, - gcsRequesterPaysBuckets, - ) + theSparkBackend = new SparkBackend(sc1) theSparkBackend } def stop(): Unit = synchronized { if (theSparkBackend != null) { + if (theSparkBackend.sparkSession.isEvaluated) theSparkBackend.sparkSession.close() theSparkBackend.sc.stop() theSparkBackend = null - // Hadoop does not honor the hadoop configuration as a component of the cache key for file - // systems, so we blow away the cache so that a new configuration can successfully take - // effect. 
- // https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443 - hadoop.fs.FileSystem.closeAll() } } } @@ -303,104 +275,31 @@ class AnonymousDependency[T](val _rdd: RDD[T]) extends NarrowDependency[T](_rdd) override def getParents(partitionId: Int): Seq[Int] = Seq.empty } -class SparkBackend( - val tmpdir: String, - val localTmpdir: String, - val sc: SparkContext, - override val references: mutable.Map[String, ReferenceGenome], - gcsRequesterPaysProject: String, - gcsRequesterPaysBuckets: String, -) extends Backend with Py4JBackendExtensions { - - assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) - lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() +class SparkBackend(val sc: SparkContext) extends Backend { - private[this] val theHailClassLoader: HailClassLoader = - new HailClassLoader(getClass().getClassLoader()) - - override def canExecuteParallelTasksOnDriver: Boolean = false + private case class Context( + maxStageParallelism: Int, + override val executionCache: ExecutionCache, + ) extends BackendContext - val fs: HadoopFS = { - val conf = new Configuration(sc.hadoopConfiguration) - if (gcsRequesterPaysProject != null) { - if (gcsRequesterPaysBuckets == null) { - conf.set("fs.gs.requester.pays.mode", "AUTO") - conf.set("fs.gs.requester.pays.project.id", gcsRequesterPaysProject) - } else { - conf.set("fs.gs.requester.pays.mode", "CUSTOM") - conf.set("fs.gs.requester.pays.project.id", gcsRequesterPaysProject) - conf.set("fs.gs.requester.pays.buckets", gcsRequesterPaysBuckets) - } + val sparkSession: Lazy[SparkSession] = + lazily { + SparkSession.builder().config(sc.getConf).getOrCreate() } - new HadoopFS(new SerializableHadoopConfiguration(conf)) - } - override def backend: Backend = this - override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - - override val longLifeTempFileManager: TempFileManager = - new OwningTempFileManager(fs) - - private[this] val bmCache = new BlockMatrixCache() - private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) - private[this] val persistedIr = mutable.Map.empty[Int, BaseIR] - private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - - def createExecuteContextForTests( - timer: ExecutionTimer, - region: Region, - ): ExecuteContext = - new ExecuteContext( - tmpdir, - localTmpdir, - this, - fs, - region, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.forTesting - }, - IrMetadata(None), - references, - ImmutableMap.empty, - NoCaching, - ImmutableMap.empty, - NoCaching, - ) - - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = - ExecutionTimer.time { timer => - ExecuteContext.scoped( - tmpdir, - localTmpdir, - this, - fs, - timer, - null, - theHailClassLoader, - flags, - new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }, - IrMetadata(None), - references, - bmCache, - codeCache, - persistedIr, - coercerCache, - )(f) - } + override def canExecuteParallelTasksOnDriver: Boolean = false def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new SparkBroadcastValue[T](sc.broadcast(value)) + override def backendContext(ctx: ExecuteContext): BackendContext = + Context( + ctx.flags.get(SparkBackend.Flags.MaxStageParallelism).toInt, + ExecutionCache.fromFlags(ctx.flags, ctx.fs, ctx.tmpdir), + ) + override def 
parallelizeAndComputeWithIndex( - backendContext: BackendContext, + ctx: BackendContext, fs: FS, contexts: IndexedSeq[Array[Byte]], stageIdentifier: String, @@ -430,13 +329,12 @@ class SparkBackend( } } - val chunkSize = flags.get(SparkBackend.Flags.MaxStageParallelism).toInt val partsToRun = partitions.getOrElse(contexts.indices) val buffer = new ArrayBuffer[(Array[Byte], Int)](partsToRun.length) var failure: Option[Throwable] = None try { - for (subparts <- partsToRun.grouped(chunkSize)) { + for (subparts <- partsToRun.grouped(ctx.asInstanceOf[Context].maxStageParallelism)) { sc.runJob( rdd, (_: TaskContext, it: Iterator[Array[Byte]]) => it.next(), @@ -456,11 +354,8 @@ class SparkBackend( override def asSpark(implicit E: Enclosing): SparkBackend = this - def close(): Unit = { - bmCache.close() + def close(): Unit = SparkBackend.stop() - longLifeTempFileManager.close() - } def startProgressBar(): Unit = ProgressBarBuilder.build(sc) @@ -493,8 +388,7 @@ class SparkBackend( case (false, true) => DArrayLowering.BMOnly case (false, false) => throw new LowererUnsupportedOperation("no lowering enabled") } - val ir = - LoweringPipeline.darrayLowerer(optimize)(typesToLower).apply(ctx, ir0).asInstanceOf[IR] + val ir = LoweringPipeline.darrayLowerer(optimize)(typesToLower)(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") @@ -532,11 +426,11 @@ class SparkBackend( Validate(ir) ctx.irMetadata = ctx.irMetadata.copy(semhash = SemanticHash(ctx)(ir)) try { - val lowerTable = flags.get("lower") != null - val lowerBM = flags.get("lower_bm") != null + val lowerTable = ctx.flags.get("lower") != null + val lowerBM = ctx.flags.get("lower_bm") != null _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM) } catch { - case e: LowererUnsupportedOperation if flags.get("lower_only") != null => throw e + case e: LowererUnsupportedOperation if ctx.flags.get("lower_only") != null => throw e case _: LowererUnsupportedOperation => CompileAndEvaluate._apply(ctx, ir, optimize = true) } @@ -549,7 +443,7 @@ class SparkBackend( rt: RTable, nPartitions: Option[Int], ): TableReader = { - if (flags.get("use_new_shuffle") != null) + if (ctx.flags.get("use_new_shuffle") != null) return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) val (globals, rvd) = TableStageToRVD(ctx, stage) diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index fed96d91c2e..5a6197b897f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -3385,7 +3385,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { Array[Array[Byte]], ]( "collectDArray", - mb.getObject(ctx.executeContext.backendContext), + mb.getObject(ctx.executeContext.backend.backendContext(ctx.executeContext)), mb.getHailClassLoader, mb.getFS, functionID, diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index 2be62c0d407..2d4eee85606 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1806,7 +1806,7 @@ object MatrixVCFReader { val fsConfigBC = backend.broadcast(fs.getConfiguration()) val (failureOpt, _) = backend.parallelizeAndComputeWithIndex( - ctx.backendContext, + ctx.backend.backendContext(ctx), fs, files.tail.map(_.getBytes), "load_vcf_parse_header", diff --git 
a/hail/src/main/scala/is/hail/utils/package.scala b/hail/src/main/scala/is/hail/utils/package.scala index 201cf59aa75..05d1d0932b4 100644 --- a/hail/src/main/scala/is/hail/utils/package.scala +++ b/hail/src/main/scala/is/hail/utils/package.scala @@ -9,7 +9,7 @@ import scala.collection.{mutable, GenTraversableOnce, TraversableOnce} import scala.collection.generic.CanBuildFrom import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionException -import scala.language.higherKinds +import scala.language.{higherKinds, implicitConversions} import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -1052,6 +1052,29 @@ package object utils def runAllKeepFirstError[A](executor: ExecutorService) : IndexedSeq[(() => A, Int)] => (Option[Throwable], IndexedSeq[(A, Int)]) = runAll[Option, A](executor) { case (opt, e) => opt.orElse(Some(e)) }(None) + + def lazily[A](f: => A): Lazy[A] = + new Lazy(f) + + implicit def evalLazy[A](f: Lazy[A]): A = + f() + + class Lazy[A] private[utils] (f: => A) { + private[this] var option: Option[A] = None + + def apply(): A = + synchronized { + option match { + case Some(a) => a + case None => val a = f; option = Some(a); a + } + } + + def isEvaluated: Boolean = + synchronized { + option.isDefined + } + } } class CancellingExecutorService(delegate: ExecutorService) extends AbstractExecutorService { diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index f00b588f8a8..057d3b77744 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -1,28 +1,38 @@ package is.hail import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations._ -import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.asm4s.HailClassLoader +import is.hail.backend.{Backend, ExecuteContext} +import is.hail.backend.caching.NoCaching import is.hail.backend.spark.SparkBackend import is.hail.expr.ir._ -import is.hail.io.fs.FS +import is.hail.expr.ir.lowering.IrMetadata +import is.hail.io.fs.{FS, HadoopFS} import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.variant.ReferenceGenome import java.io.{File, PrintWriter} import breeze.linalg.DenseMatrix +import org.apache.hadoop +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.sql.Row import org.scalatestplus.testng.TestNGSuite import org.testng.ITestContext -import org.testng.annotations.{AfterMethod, BeforeClass, BeforeMethod} +import org.testng.annotations.{AfterClass, AfterMethod, BeforeClass, BeforeMethod} object HailSuite { - val theHailClassLoader = TestUtils.theHailClassLoader - def withSparkBackend(): HailContext = { + val theHailClassLoader: HailClassLoader = + new HailClassLoader(getClass.getClassLoader) + + val flags: HailFeatureFlags = + HailFeatureFlags.fromEnv(sys.env + ("lower" -> "1")) + + lazy val hc: HailContext = { HailContext.configureLogging("/tmp/hail.log", quiet = false, append = false) val backend = SparkBackend( sc = new SparkContext( @@ -34,61 +44,69 @@ object HailSuite { ) .set("spark.unsafe.exceptionOnMemoryLeak", "true") ), - tmpdir = "/tmp", - localTmpdir = "file:///tmp", skipLoggingConfiguration = true, ) - HailContext(backend) - } - - lazy val hc: HailContext = { - val hc = withSparkBackend() - hc.backend.asSpark.flags.set("lower", "1") + val hc = HailContext(backend) hc.checkRVDKeys = true hc } } -class HailSuite extends TestNGSuite { - val theHailClassLoader = 
HailSuite.theHailClassLoader +class HailSuite extends TestNGSuite with TestUtils { def hc: HailContext = HailSuite.hc - @BeforeClass def ensureHailContextInitialized(): Unit = hc - - def backend: SparkBackend = hc.backend.asSpark - - def sc: SparkContext = backend.sc - - def fs: FS = backend.fs - - def fsBc: BroadcastValue[FS] = fs.broadcast - - var timer: ExecutionTimer = _ + @BeforeClass + def initFs(): Unit = { + val conf = new Configuration(sc.hadoopConfiguration) + fs = new HadoopFS(new SerializableHadoopConfiguration(conf)) + } - var ctx: ExecuteContext = _ + @AfterClass + def closeFS(): Unit = + hadoop.fs.FileSystem.closeAll() + var fs: FS = _ var pool: RegionPool = _ + private[this] var ctx_ : ExecuteContext = _ + + def backend: Backend = ctx.backend + def sc: SparkContext = backend.asSpark.sc + def timer: ExecutionTimer = ctx.timer + def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader + override def ctx: ExecuteContext = ctx_ @BeforeMethod def setupContext(context: ITestContext): Unit = { - assert(timer == null) - timer = new ExecutionTimer("HailSuite") - assert(ctx == null) pool = RegionPool() - ctx = backend.createExecuteContextForTests(timer, Region(pool = pool)) + ctx_ = new ExecuteContext( + tmpdir = "/tmp", + localTmpdir = "file:///tmp", + backend = hc.backend, + fs = fs, + r = Region(pool = pool), + timer = new ExecutionTimer(context.getName), + _tempFileManager = null, + theHailClassLoader = HailSuite.theHailClassLoader, + flags = HailSuite.flags, + irMetadata = IrMetadata(None), + References = ImmutableMap(ReferenceGenome.builtinReferences()), + BlockMatrixCache = NoCaching, + CodeCache = NoCaching, + IrCache = NoCaching, + CoercerCache = NoCaching, + ) } @AfterMethod def tearDownContext(context: ITestContext): Unit = { - ctx.close() - ctx = null - timer.finish() - timer = null + ctx_.timer.finish() + ctx_.close() + ctx_ = null pool.close() - if (backend.sc.isStopped) - throw new RuntimeException(s"method stopped spark context!") + if (sc.isStopped) + throw new RuntimeException(s"'${context.getName}' stopped spark context!") } def assertEvalsTo( @@ -105,73 +123,71 @@ class HailSuite extends TestNGSuite { val t = x.typ assert(t == TVoid || t.typeCheck(expected), s"$t, $expected") - ExecuteContext.scoped { ctx => - val filteredExecStrats: Set[ExecStrategy] = - if (HailContext.backend.isInstanceOf[SparkBackend]) - execStrats - else { - info("skipping interpret and non-lowering compile steps on non-spark backend") - execStrats.intersect(ExecStrategy.backendOnly) - } + val filteredExecStrats: Set[ExecStrategy] = + if (HailContext.backend.isInstanceOf[SparkBackend]) + execStrats + else { + info("skipping interpret and non-lowering compile steps on non-spark backend") + execStrats.intersect(ExecStrategy.backendOnly) + } - filteredExecStrats.foreach { strat => - try { - val res = strat match { - case ExecStrategy.Interpret => - assert(agg.isEmpty) - Interpret[Any](ctx, x, env, args) - case ExecStrategy.InterpretUnoptimized => - assert(agg.isEmpty) - Interpret[Any](ctx, x, env, args, optimize = false) - case ExecStrategy.JvmCompile => - assert(Forall(x, node => Compilable(node))) - eval( - x, - env, - args, - agg, - bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - true, - ctx, - ) - case ExecStrategy.JvmCompileUnoptimized => - assert(Forall(x, node => Compilable(node))) - eval( - x, - env, - args, - agg, - 
bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - optimize = false, - ctx, - ) - case ExecStrategy.LoweredJVMCompile => - loweredExecute(ctx, x, env, args, agg) - } - if (t != TVoid) { - assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") - assert( - t.valuesSimilar(res, expected), - s"\n result=$res\n expect=$expected\n strategy=$strat)", + filteredExecStrats.foreach { strat => + try { + val res = strat match { + case ExecStrategy.Interpret => + assert(agg.isEmpty) + Interpret[Any](ctx, x, env, args) + case ExecStrategy.InterpretUnoptimized => + assert(agg.isEmpty) + Interpret[Any](ctx, x, env, args, optimize = false) + case ExecStrategy.JvmCompile => + assert(Forall(x, node => Compilable(node))) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + true, + ctx, + ) + case ExecStrategy.JvmCompileUnoptimized => + assert(Forall(x, node => Compilable(node))) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + optimize = false, + ctx, ) - } - } catch { - case e: Exception => - error(s"error from strategy $strat") - if (execStrats.contains(strat)) throw e + case ExecStrategy.LoweredJVMCompile => + loweredExecute(ctx, x, env, args, agg) + } + if (t != TVoid) { + assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") + assert( + t.valuesSimilar(res, expected), + s"\n result=$res\n expect=$expected\n strategy=$strat)", + ) } + } catch { + case e: Exception => + error(s"error from strategy $strat") + if (execStrats.contains(strat)) throw e } } } @@ -250,35 +266,33 @@ class HailSuite extends TestNGSuite { expected: DenseMatrix[Double], )(implicit execStrats: Set[ExecStrategy] ): Unit = { - ExecuteContext.scoped { ctx => - val filteredExecStrats: Set[ExecStrategy] = - if (HailContext.backend.isInstanceOf[SparkBackend]) execStrats - else { - info("skipping interpret and non-lowering compile steps on non-spark backend") - execStrats.intersect(ExecStrategy.backendOnly) - } - filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => - try { - val res = strat match { - case ExecStrategy.Interpret => - Interpret(bm, ctx, optimize = true) - case ExecStrategy.InterpretUnoptimized => - Interpret(bm, ctx, optimize = false) - } - assert(res.toBreezeMatrix() == expected) - } catch { - case e: Exception => - error(s"error from strategy $strat") - if (execStrats.contains(strat)) throw e + val filteredExecStrats: Set[ExecStrategy] = + if (HailContext.backend.isInstanceOf[SparkBackend]) execStrats + else { + info("skipping interpret and non-lowering compile steps on non-spark backend") + execStrats.intersect(ExecStrategy.backendOnly) + } + filteredExecStrats.filter(ExecStrategy.interpretOnly).foreach { strat => + try { + val res = strat match { + case ExecStrategy.Interpret => + Interpret(bm, ctx, optimize = true) + case ExecStrategy.InterpretUnoptimized => + Interpret(bm, ctx, optimize = false) } + assert(res.toBreezeMatrix() == expected) + } catch { + case e: Exception => + error(s"error from strategy $strat") + if (execStrats.contains(strat)) 
throw e } - val expectedArray = Array.tabulate(expected.rows)(i => - Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq - ).toFastSeq - assertNDEvals(BlockMatrixCollect(bm), expectedArray)( - filteredExecStrats.filterNot(ExecStrategy.interpretOnly) - ) } + val expectedArray = Array.tabulate(expected.rows)(i => + Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq + ).toFastSeq + assertNDEvals(BlockMatrixCollect(bm), expectedArray)( + filteredExecStrats.filterNot(ExecStrategy.interpretOnly) + ) } def assertAllEvalTo( diff --git a/hail/src/test/scala/is/hail/TestUtils.scala b/hail/src/test/scala/is/hail/TestUtils.scala index 71ab9445a4d..16a3ff21292 100644 --- a/hail/src/test/scala/is/hail/TestUtils.scala +++ b/hail/src/test/scala/is/hail/TestUtils.scala @@ -41,8 +41,9 @@ object ExecStrategy extends Enumeration { val allRelational: Set[ExecStrategy] = interpretOnly.union(lowering) } -object TestUtils { - val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) +trait TestUtils { + + def ctx: ExecuteContext = ??? import org.scalatest.Assertions._ @@ -95,7 +96,7 @@ object TestUtils { def removeConstantCols(A: DenseMatrix[Int]): DenseMatrix[Int] = { val data = (0 until A.cols).flatMap { j => val col = A(::, j) - if (TestUtils.isConstant(col)) + if (isConstant(col)) Array[Int]() else col.toArray @@ -124,18 +125,14 @@ object TestUtils { print = bytecodePrinter) } - def eval(x: IR): Any = ExecuteContext.scoped { ctx => - eval(x, Env.empty, FastSeq(), None, None, true, ctx) - } - def eval( x: IR, - env: Env[(Any, Type)], - args: IndexedSeq[(Any, Type)], - agg: Option[(IndexedSeq[Row], TStruct)], + env: Env[(Any, Type)] = Env.empty, + args: IndexedSeq[(Any, Type)] = FastSeq(), + agg: Option[(IndexedSeq[Row], TStruct)] = None, bytecodePrinter: Option[PrintWriter] = None, optimize: Boolean = true, - ctx: ExecuteContext, + ctx: ExecuteContext = ctx, ): Any = { val inputTypesB = new BoxedArrayBuilder[Type]() val inputsB = new mutable.ArrayBuffer[Any]() @@ -207,26 +204,29 @@ object TestUtils { assert(resultType2.virtualType == resultType) ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 + ctx.local(r = region) { ctx => + val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 + } + rvb.endTuple() + val argsOff = rvb.end() + + rvb.start(aggArrayPType) + rvb.startArray(aggElements.length) + aggElements.foreach(r => rvb.addAnnotation(aggType, r)) + rvb.endArray() + val aggOff = rvb.end() + + ctx.scopedExecution { (hcl, fs, tc, r) => + val off = f(hcl, fs, tc, r)(r, argsOff, aggOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], off).get(0) + } } - rvb.endTuple() - val argsOff = rvb.end() - - rvb.start(aggArrayPType) - rvb.startArray(aggElements.length) - aggElements.foreach(r => rvb.addAnnotation(aggType, r)) - rvb.endArray() - val aggOff = rvb.end() - - val resultOff = - f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff, aggOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) } case None => @@ -246,19 +246,22 @@ object TestUtils { assert(resultType2.virtualType == resultType) ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - 
rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 + ctx.local(r = region) { ctx => + val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 + } + rvb.endTuple() + val argsOff = rvb.end() + ctx.scopedExecution { (hcl, fs, tc, r) => + val resultOff = f(hcl, fs, tc, r)(r, argsOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + } } - rvb.endTuple() - val argsOff = rvb.end() - - val resultOff = f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) } } } @@ -272,7 +275,7 @@ object TestUtils { def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]): Unit = { val t = x.typ - val (i, i2, c) = ExecuteContext.scoped { ctx => + val (i, i2, c) = { val i = Interpret[Any](ctx, x, env, args) val i2 = Interpret[Any](ctx, x, env, args, optimize = false) val c = eval(x, env, args, None, None, true, ctx) @@ -295,12 +298,11 @@ object TestUtils { env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String, - ): Unit = - ExecuteContext.scoped { ctx => - interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) - interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) - interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) - } + ): Unit = { + interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) + interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) + interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) + } def assertFatal(x: IR, regex: String): Unit = assertThrows[HailException](x, regex) @@ -318,9 +320,7 @@ object TestUtils { args: IndexedSeq[(Any, Type)], regex: String, ): Unit = - ExecuteContext.scoped { ctx => - interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) - } + interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) def assertCompiledThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertCompiledThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) diff --git a/hail/src/test/scala/is/hail/TestUtilsSuite.scala b/hail/src/test/scala/is/hail/TestUtilsSuite.scala index a6c91c5aef5..2c2b0c08406 100644 --- a/hail/src/test/scala/is/hail/TestUtilsSuite.scala +++ b/hail/src/test/scala/is/hail/TestUtilsSuite.scala @@ -11,21 +11,21 @@ class TestUtilsSuite extends HailSuite { val V = DenseVector(0d, 1d) val V1 = DenseVector(0d, 0.5d) - TestUtils.assertMatrixEqualityDouble(M, DenseMatrix.eye(2)) - TestUtils.assertMatrixEqualityDouble(M, M1, 0.001) - TestUtils.assertVectorEqualityDouble(V, 2d * V1) + assertMatrixEqualityDouble(M, DenseMatrix.eye(2)) + assertMatrixEqualityDouble(M, M1, 0.001) + assertVectorEqualityDouble(V, 2d * V1) - intercept[Exception](TestUtils.assertVectorEqualityDouble(V, V1)) - intercept[Exception](TestUtils.assertMatrixEqualityDouble(M, M1)) + intercept[Exception](assertVectorEqualityDouble(V, V1)) + intercept[Exception](assertMatrixEqualityDouble(M, M1)) } @Test def constantVectorTest(): Unit = { - assert(TestUtils.isConstant(DenseVector())) - assert(TestUtils.isConstant(DenseVector(0))) - assert(TestUtils.isConstant(DenseVector(0, 0))) - assert(TestUtils.isConstant(DenseVector(0, 0, 0))) - assert(!TestUtils.isConstant(DenseVector(0, 1))) - 
assert(!TestUtils.isConstant(DenseVector(0, 0, 1))) + assert(isConstant(DenseVector())) + assert(isConstant(DenseVector(0))) + assert(isConstant(DenseVector(0, 0))) + assert(isConstant(DenseVector(0, 0, 0))) + assert(!isConstant(DenseVector(0, 1))) + assert(!isConstant(DenseVector(0, 0, 1))) } @Test def removeConstantColsTest(): Unit = { @@ -33,6 +33,6 @@ class TestUtilsSuite extends HailSuite { val M1 = DenseMatrix((0, 1, 0), (1, 0, 1)) - assert(TestUtils.removeConstantCols(M) == M1) + assert(removeConstantCols(M) == M1) } } diff --git a/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala b/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala index 81a2a82e3a7..e071caa65a3 100644 --- a/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala @@ -55,7 +55,7 @@ class UnsafeSuite extends HailSuite { @DataProvider(name = "codecs") def codecs(): Array[Array[Any]] = - ExecuteContext.scoped(ctx => codecs(ctx)) + codecs(ctx) def codecs(ctx: ExecuteContext): Array[Array[Any]] = (BufferSpec.specs ++ Array(TypedCodecSpec( diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index 694280806cb..0524f6eb6fa 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -1,17 +1,13 @@ package is.hail.backend -import is.hail.HailFeatureFlags -import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ - BatchJobConfig, ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload, -} -import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} +import is.hail.HailSuite +import is.hail.backend.api.BatchJobConfig +import is.hail.backend.service.ServiceBackend import is.hail.services._ import is.hail.services.JobGroupStates.Success import is.hail.utils.{tokenUrlSafe, using} -import scala.collection.mutable -import scala.reflect.io.{Directory, Path} +import scala.reflect.io.Directory import scala.util.Random import java.io.Closeable @@ -21,151 +17,120 @@ import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when import org.scalatest.OptionValues import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { +class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { case (rpcConfig, jobConfig) => - val batchClient = mock[BatchClient] - using(ServiceBackend(batchClient, rpcConfig, jobConfig)) { backend => - val contexts = Array.tabulate(1)(_.toString.getBytes) - - // verify that the service backend - // - creates the batch with the correct billing project, and - // - the number of jobs matches the number of partitions, and - // - each job is created in the specified region, and - // - each job's resource configuration matches the rpc config - val batchId = Random.nextInt() - - when(batchClient.newBatch(any[BatchRequest])) thenAnswer { - (batchRequest: BatchRequest) => - batchRequest.billing_project shouldEqual jobConfig.billing_project - batchRequest.n_jobs shouldBe 0 - batchRequest.attributes.get("name").value shouldBe backend.name - batchId - } + withObjectSpied[is.hail.utils.UtilsType] { + // not obvious how to pull out `tokenUrlSafe` and inject this directory + // using a spy 
is a hack and I don't particularly like it.
+      when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN"
+
+      val jobConfig = BatchJobConfig(
+        token = tokenUrlSafe,
+        billing_project = "fancy",
+        worker_cores = "128",
+        worker_memory = "a lot.",
+        storage = "a big ssd?",
+        cloudfuse_configs = Array(),
+        regions = Array("lunar1"),
+      )
-        when(batchClient.newJobGroup(
-          any[Int],
-          any[String],
-          any[JobGroupRequest],
-          any[IndexedSeq[JobRequest]],
-        )) thenAnswer {
-          (id: Int, _: String, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) =>
-            id shouldBe batchId
-            jobGroup.job_group_id shouldBe 1
-            jobGroup.absolute_parent_id shouldBe 0
-            jobs.length shouldEqual contexts.length
-            jobs.foreach { payload =>
-              payload.regions.value shouldBe jobConfig.regions
-              payload.resources.value shouldBe JobResources(
-                preemptible = true,
-                cpu = Some(jobConfig.worker_cores),
-                memory = Some(jobConfig.worker_memory),
-                storage = Some(jobConfig.storage),
+      val batchClient = mock[BatchClient]
+      using(ServiceBackend(batchClient, jobConfig)) { backend =>
+        using(LocalTmpFolder) { tmp =>
+          val contexts = Array.tabulate(1)(_.toString.getBytes)
+
+          // verify that the service backend
+          // - creates the batch with the correct billing project, and
+          // - the number of jobs matches the number of partitions, and
+          // - each job is created in the specified region, and
+          // - each job's resource configuration matches the rpc config
+          val batchId = Random.nextInt()
+
+          when(batchClient.newBatch(any[BatchRequest])) thenAnswer {
+            (batchRequest: BatchRequest) =>
+              batchRequest.billing_project shouldEqual jobConfig.billing_project
+              batchRequest.n_jobs shouldBe 0
+              batchRequest.attributes.get("name").value shouldBe backend.name
+              batchId
+          }
+
+          when(batchClient.newJobGroup(
+            any[Int],
+            any[String],
+            any[JobGroupRequest],
+            any[IndexedSeq[JobRequest]],
+          )) thenAnswer {
+            (id: Int, _: String, jobGroup: JobGroupRequest, jobs: IndexedSeq[JobRequest]) =>
+              id shouldBe batchId
+              jobGroup.job_group_id shouldBe 1
+              jobGroup.absolute_parent_id shouldBe 0
+              jobs.length shouldEqual contexts.length
+              jobs.foreach { payload =>
+                payload.regions.value shouldBe jobConfig.regions
+                payload.resources.value shouldBe JobResources(
+                  preemptible = true,
+                  cpu = Some(jobConfig.worker_cores),
+                  memory = Some(jobConfig.worker_memory),
+                  storage = Some(jobConfig.storage),
+                )
+              }
+
+              (37, 1)
+          }
+
+          // the service backend expects that each job write its output to a well-known
+          // location when it finishes.
+ when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + id shouldEqual batchId + jobGroupId shouldEqual 1 + + val resultsDir = tmp / "parallelizeAndComputeWithIndex" / tokenUrlSafe + resultsDir.createDirectory() + for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") + + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Success, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = contexts.length, + n_failed = 0, + n_cancelled = 0, ) - } - - (37, 1) + } + + ctx.local(tmpdir = tmp.toString()) { ctx => + val (failure, _) = + backend.parallelizeAndComputeWithIndex( + backend.backendContext(ctx), + ctx.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) + + failure.foreach(throw _) + } + + batchClient.newBatch(any) wasCalled once + batchClient.newJobGroup(any, any, any, any) wasCalled once } - - // the service backend expects that each job write its output to a well-known - // location when it finishes. - when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { - (id: Int, jobGroupId: Int) => - id shouldEqual batchId - jobGroupId shouldEqual 1 - - val resultsDir = - Path(rpcConfig.remote_tmpdir) / - "parallelizeAndComputeWithIndex" / - tokenUrlSafe - - resultsDir.createDirectory() - for (i <- contexts.indices) (resultsDir / f"result.$i").toFile.writeAll("11") - JobGroupResponse( - batch_id = id, - job_group_id = jobGroupId, - state = Success, - complete = true, - n_jobs = contexts.length, - n_completed = contexts.length, - n_succeeded = contexts.length, - n_failed = 0, - n_cancelled = 0, - ) - } - - val (failure, _) = - backend.parallelizeAndComputeWithIndex( - ServiceBackendContext( - remoteTmpDir = rpcConfig.remote_tmpdir, - jobConfig = jobConfig, - executionCache = ExecutionCache.noCache, - ), - backend.fs, - contexts, - "stage1", - )((bytes, _, _, _) => bytes) - - failure.foreach(throw _) - - batchClient.newBatch(any) wasCalled once - batchClient.newJobGroup(any, any, any, any) wasCalled once } } - def ServiceBackend( - client: BatchClient, - rpcConfig: ServiceBackendRPCPayload, - jobConfig: BatchJobConfig, - ): ServiceBackend = { - val flags = HailFeatureFlags.fromEnv() - val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) + def ServiceBackend(client: BatchClient, jobConfig: BatchJobConfig): ServiceBackend = new ServiceBackend( name = "name", batchClient = client, jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), batchConfig = None, - rpcConfig = rpcConfig, jobConfig = jobConfig, - flags = flags, - fs = fs, - references = mutable.Map.empty, ) - } - - def withMockDriverContext(test: (ServiceBackendRPCPayload, BatchJobConfig) => Any): Any = - using(LocalTmpFolder) { tmp => - withObjectSpied[is.hail.utils.UtilsType] { - // not obvious how to pull out `tokenUrlSafe` and inject this directory - // using a spy is a hack and i don't particularly like it. 
- when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" - - test( - ServiceBackendRPCPayload( - tmp_dir = tmp.path, - remote_tmpdir = tmp.path, - flags = Map(), - custom_references = Array(), - liftovers = Map(), - sequences = Map(), - ), - BatchJobConfig( - token = tokenUrlSafe, - billing_project = "fancy", - worker_cores = "128", - worker_memory = "a lot.", - storage = "a big ssd?", - cloudfuse_configs = Array(), - regions = Array("lunar1"), - ), - ) - } - } def LocalTmpFolder: Directory with Closeable = new Directory(Directory.makeTemp("hail-testing-tmp").jfile) with Closeable { diff --git a/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala index d76741cb770..d4dfc7b25e5 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.FastSeq diff --git a/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala index e04a181b506..94d1a7af0e8 100644 --- a/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala @@ -1,7 +1,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.expr.ir.TestUtils.IRCall +import is.hail.expr.ir.TestUtils.{IRArray, IRCall} import is.hail.types.virtual.{TArray, TBoolean, TCall, TInt32} import is.hail.variant._ @@ -60,7 +60,7 @@ class CallFunctionsSuite extends HailSuite { assertEvalsTo(invoke("Call", TCall, I32(1), False()), Call1(1, false)) assertEvalsTo(invoke("Call", TCall, I32(0), I32(0), False()), Call2(0, 0, false)) assertEvalsTo( - invoke("Call", TCall, TestUtils.IRArray(0, 1), False()), + invoke("Call", TCall, IRArray(0, 1), False()), CallN(Array(0, 1), false), ) assertEvalsTo(invoke("Call", TCall, Str("0|1")), Call2(0, 1, true)) diff --git a/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala index f0e071c6aa1..674cd278e3c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.FastSeq diff --git a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala index 4ab89a11177..48f4c52cfd6 100644 --- a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala @@ -1,10 +1,8 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ -import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} +import is.hail.annotations.{Region, RegionPool, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ -import is.hail.backend.ExecuteContext import is.hail.expr.ir.agg.{CollectStateSig, PhysicalAggSig, TypedStateSig} import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering.LoweringPipeline @@ -1063,14 +1061,14 @@ class EmitStreamSuite extends HailSuite { def assertMemoryDoesNotScaleWithStreamSize(lowSize: Int = 50, highSize: Int = 2500)(f: 
IR => IR) : Unit = { - val memUsed1 = ExecuteContext.scoped { ctx => - eval(f(lowSize), Env.empty, FastSeq(), None, None, false, ctx) - ctx.r.pool.getHighestTotalUsage + val memUsed1 = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool))(ctx => eval(f(lowSize), optimize = false, ctx = ctx)) + pool.getHighestTotalUsage } - val memUsed2 = ExecuteContext.scoped { ctx => - eval(f(highSize), Env.empty, FastSeq(), None, None, false, ctx) - ctx.r.pool.getHighestTotalUsage + val memUsed2 = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool))(ctx => eval(f(highSize), optimize = false, ctx = ctx)) + pool.getHighestTotalUsage } if (memUsed1 != memUsed2) diff --git a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala index 9c6c5682a44..f3a8b300d7e 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.TestUtils._ import is.hail.expr.Nat import is.hail.types.virtual._ import is.hail.utils._ diff --git a/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala index 521922af0e7..f6ba5f123a3 100644 --- a/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual.TFloat64 import is.hail.utils.FastSeq diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index bcaf5a93d57..05cf37da1dc 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -2,12 +2,12 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ -import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} +import is.hail.annotations.{BroadcastRow, ExtendedOrdering, Region, RegionPool, SafeNDArray} import is.hail.backend.ExecuteContext import is.hail.backend.caching.BlockMatrixCache import is.hail.expr.Nat import is.hail.expr.ir.ArrayZipBehavior.ArrayZipBehavior +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.agg._ import is.hail.expr.ir.functions._ import is.hail.io.{BufferSpec, TypedCodecSpec} @@ -1402,11 +1402,11 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(zip(ArrayZipBehavior.AssumeSameLength, range8, range8)), 8) // https://github.com/hail-is/hail/issues/8359 - is.hail.TestUtils.assertThrows[HailException]( + assertThrows[HailException]( zipToTuple(ArrayZipBehavior.AssertSameLength, range6, range8): IR, "zip: length mismatch": String, ) - is.hail.TestUtils.assertThrows[HailException]( + assertThrows[HailException]( zipToTuple(ArrayZipBehavior.AssertSameLength, range12, lit6): IR, "zip: length mismatch": String, ) @@ -1561,7 +1561,7 @@ class IRSuite extends HailSuite { val na = NA(TDict(TInt32, TString)) assertEvalsTo(LowerBoundOnOrderedCollection(na, I32(0), onKey = true), null) - val dwna = TestUtils.IRDict((1, 3), (3, null), (null, 5)) + val dwna = IRDict((1, 3), (3, null), (null, 5)) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(-1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(1), onKey = 
true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(2), onKey = true), 1) @@ -1569,7 +1569,7 @@ class IRSuite extends HailSuite { assertEvalsTo(LowerBoundOnOrderedCollection(dwna, I32(5), onKey = true), 2) assertEvalsTo(LowerBoundOnOrderedCollection(dwna, NA(TInt32), onKey = true), 2) - val dwoutna = TestUtils.IRDict((1, 3), (3, null)) + val dwoutna = IRDict((1, 3), (3, null)) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, I32(-1), onKey = true), 0) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, I32(4), onKey = true), 2) assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, NA(TInt32), onKey = true), 2) @@ -1800,18 +1800,18 @@ class IRSuite extends HailSuite { @Test def testStreamFold(): Unit = { assertEvalsTo(foldIR(StreamRange(1, 2, 1), NA(TBoolean))((accum, elt) => IsNA(accum)), true) - assertEvalsTo(foldIR(TestUtils.IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) + assertEvalsTo(foldIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) assertEvalsTo( - foldIR(TestUtils.IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt), + foldIR(IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt), null, ) assertEvalsTo( - foldIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt), + foldIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt), null, ) - assertEvalsTo(foldIR(TestUtils.IRStream(1, null, 3), 0)((accum, elt) => accum + elt), null) + assertEvalsTo(foldIR(IRStream(1, null, 3), 0)((accum, elt) => accum + elt), null) assertEvalsTo( - foldIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => I32(5) + I32(5)), + foldIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => I32(5) + I32(5)), 10, ) } @@ -1845,15 +1845,15 @@ class IRSuite extends HailSuite { FastSeq(null, true, false, false), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, 2, 3), 0)((accum, elt) => accum + elt)), + ToArray(streamScanIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt)), FastSeq(0, 1, 3, 6), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt)), + ToArray(streamScanIR(IRStream(1, 2, 3), NA(TInt32))((accum, elt) => accum + elt)), FastSeq(null, null, null, null), ) assertEvalsTo( - ToArray(streamScanIR(TestUtils.IRStream(1, null, 3), NA(TInt32))((accum, elt) => + ToArray(streamScanIR(IRStream(1, null, 3), NA(TInt32))((accum, elt) => accum + elt )), FastSeq(null, null, null, null), @@ -3251,7 +3251,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "valueIRs") def valueIRs(): Array[Array[Object]] = - ExecuteContext.scoped(ctx => valueIRs(ctx)) + valueIRs(ctx) def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { val fs = ctx.fs @@ -3313,7 +3313,7 @@ class IRSuite extends HailSuite { val table = TableRange(100, 10) val mt = MatrixIR.range(20, 2, Some(3)) - val vcf = is.hail.TestUtils.importVCF(ctx, "src/test/resources/sample.vcf") + val vcf = importVCF(ctx, "src/test/resources/sample.vcf") val bgenReader = MatrixBGENReader( ctx, @@ -3596,7 +3596,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "tableIRs") def tableIRs(): Array[Array[TableIR]] = - ExecuteContext.scoped(ctx => tableIRs(ctx)) + tableIRs(ctx) def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { try { @@ -3705,7 +3705,7 @@ class IRSuite extends HailSuite { @DataProvider(name = "matrixIRs") def matrixIRs(): Array[Array[MatrixIR]] = - ExecuteContext.scoped(ctx => matrixIRs(ctx)) + matrixIRs(ctx) def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { try { @@ 
-3729,7 +3729,7 @@ class IRSuite extends HailSuite { val read = MatrixIR.read(fs, "src/test/resources/backward_compatability/1.0.0/matrix_table/0.hmt") val range = MatrixIR.range(3, 7, None) - val vcf = is.hail.TestUtils.importVCF(ctx, "src/test/resources/sample.vcf") + val vcf = importVCF(ctx, "src/test/resources/sample.vcf") val bgenReader = MatrixBGENReader( ctx, @@ -3907,38 +3907,27 @@ class IRSuite extends HailSuite { using(new BlockMatrixCache()) { cache => val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3) - backend.withExecuteContext { ctx => - ctx.local(blockMatrixCache = cache) { ctx => - backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) - } - } - - backend.withExecuteContext { ctx => - ctx.local(blockMatrixCache = cache) { ctx => - val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) + ctx.local(blockMatrixCache = cache) { ctx => + backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) - val s = Pretty.sexprStyle(persist, elideLiterals = false) - val x2 = IRParser.parse_blockmatrix_ir(ctx, s) - assert(x2 == persist) - } + val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) + val s = Pretty.sexprStyle(persist, elideLiterals = false) + val x2 = IRParser.parse_blockmatrix_ir(ctx, s) + assert(x2 == persist) } } @Test def testCachedIR(): Unit = { val cached = Literal(TSet(TInt32), Set(1)) val s = s"(JavaIR 1)" - val x2 = ExecuteContext.scoped { ctx => - ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) - } + val x2 = ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_value_ir(ctx, s)) assert(x2 eq cached) } @Test def testCachedTableIR(): Unit = { val cached = TableRange(1, 1) val s = s"(JavaTable 1)" - val x2 = ExecuteContext.scoped { ctx => - ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) - } + val x2 = ctx.local(irCache = mutable.Map(1 -> cached))(ctx => IRParser.parse_table_ir(ctx, s)) assert(x2 eq cached) } @@ -4220,16 +4209,18 @@ class IRSuite extends HailSuite { val startingArg = SafeNDArray(IndexedSeq[Long](4L, 4L), (0 until 16).toFastSeq) - var memUsed = 0L - - ExecuteContext.scoped { ctx => - eval(ndSum, Env.empty, FastSeq(2 -> TInt32, startingArg -> ndType), None, None, true, ctx) - memUsed = ctx.r.pool.getHighestTotalUsage + val memUsed = RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval(ndSum, Env.empty, FastSeq(2 -> TInt32, startingArg -> ndType), None, None, true, ctx) + pool.getHighestTotalUsage + } } - ExecuteContext.scoped { ctx => - eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> ndType), None, None, true, ctx) - assert(memUsed == ctx.r.pool.getHighestTotalUsage) + RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> ndType), None, None, true, ctx) + assert(memUsed == pool.getHighestTotalUsage) + } } } diff --git a/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala index 2c803273d55..0338029c806 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ import is.hail.utils._ diff --git 
a/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala index 320d8d19e47..4fc10ebd290 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{stats, ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.types.virtual._ import is.hail.utils._ diff --git a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala index 1cf94510f20..c662d02df1c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations.BroadcastRow import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.TestUtils._ @@ -377,7 +376,7 @@ class MatrixIRSuite extends HailSuite { } @Test def testMatrixMultiWriteDifferentTypesRaisesError(): Unit = { - val vcf = is.hail.TestUtils.importVCF(ctx, "src/test/resources/sample.vcf") + val vcf = importVCF(ctx, "src/test/resources/sample.vcf") val range = rangeMatrix(10, 2, None) val path1 = ctx.createTmpPath("test1") val path2 = ctx.createTmpPath("test2") diff --git a/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala index 49695998d14..df23b18897c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala @@ -1,8 +1,7 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.TestUtils.eval -import is.hail.backend.ExecuteContext +import is.hail.annotations.{Region, RegionPool} import is.hail.expr.ir import is.hail.types.virtual.{TArray, TBoolean, TSet, TString} import is.hail.utils._ @@ -17,19 +16,22 @@ class MemoryLeakSuite extends HailSuite { def run(size: Int): Long = { val lit = Literal(TSet(TString), (0 until litSize).map(_.toString).toSet) val queries = Literal(TArray(TString), (0 until size).map(_.toString).toFastSeq) - ExecuteContext.scoped { ctx => - eval( - ToArray( - mapIR(ToStream(queries))(r => ir.invoke("contains", TBoolean, lit, r)) - ), - Env.empty, - FastSeq(), - None, - None, - false, - ctx, - ) - ctx.r.pool.getHighestTotalUsage + RegionPool.scoped { pool => + ctx.local(r = Region(pool = pool)) { ctx => + eval( + ToArray( + mapIR(ToStream(queries))(r => ir.invoke("contains", TBoolean, lit, r)) + ), + Env.empty, + FastSeq(), + None, + None, + false, + ctx, + ) + } + + pool.getHighestTotalUsage } } diff --git a/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala b/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala index c508395138e..e09ff44eee6 100644 --- a/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.check.{Gen, Prop} diff --git a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala index 7c3597d6c4a..2095e0cfe51 100644 --- a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala @@ -1,7 +1,6 @@ package 
is.hail.expr.ir import is.hail.HailSuite -import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.agg.CallStatsState import is.hail.io.{BufferSpec, TypedCodecSpec} @@ -103,7 +102,7 @@ class RequirednessSuite extends HailSuite { def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) @DataProvider(name = "valueIR") - def valueIR(): Array[Array[Any]] = ExecuteContext.scoped { ctx => + def valueIR(): Array[Array[Any]] = { val nodes = new BoxedArrayBuilder[Array[Any]](50) val allRequired = Array( diff --git a/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala index 10c74088d8d..b6489d092c8 100644 --- a/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ diff --git a/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala b/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala index 9b3eb0ea59a..1c5c33ba616 100644 --- a/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala @@ -1,7 +1,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.expr.ir.TestUtils.IRAggCount +import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval} import is.hail.variant.Locus diff --git a/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala index 6c2697b1dfd..f634d51ca75 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.FastSeq diff --git a/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala index ac869b23455..1513589269b 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.types.virtual.TString import is.hail.utils._ diff --git a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala index 5041df1e705..a835c5bd76e 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy -import is.hail.TestUtils._ import is.hail.annotations.SafeNDArray import is.hail.expr.Nat import is.hail.expr.ir.TestUtils._ diff --git a/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala index bab77019356..cf1c84c665d 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.types.virtual._ import 
is.hail.utils._ diff --git a/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala index 024f7cd20fb..d570ec1c4a9 100644 --- a/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.types.virtual.{TBoolean, TInt32, TStream} import org.testng.annotations.Test diff --git a/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala b/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala index 97089297c03..8866db7d034 100644 --- a/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir.analyses import is.hail.{HAIL_PRETTY_VERSION, HailSuite} -import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.io.fs.{FS, FakeFS, FakeURL, FileListEntry} import is.hail.linalg.BlockMatrixMetadata @@ -291,12 +290,14 @@ class SemanticHashSuite extends HailSuite { @Test(dataProvider = "isBaseIRSemanticallyEquivalent") def testSemanticEquivalence(a: BaseIR, b: BaseIR, isEqual: Boolean, comment: String): Unit = - assertResult( - isEqual, - s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", - )( - semhash(fakeFs)(a) == semhash(fakeFs)(b) - ) + ctx.local(fs = fakeFs) { ctx => + assertResult( + isEqual, + s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", + )( + SemanticHash(ctx)(a) == SemanticHash(ctx)(b) + ) + } @Test def testFileNotFoundExceptions(): Unit = { @@ -308,14 +309,13 @@ class SemanticHashSuite extends HailSuite { val ir = importMatrix("gs://fake-bucket/fake-matrix") - assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( - semhash(fs)(ir) - ) + ctx.local(fs = fs) { ctx => + assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( + SemanticHash(ctx)(ir) + ) + } } - def semhash(fs: FS)(ir: BaseIR): Option[SemanticHash.Type] = - ExecuteContext.scoped(_.local(fs = fs)(SemanticHash(_)(ir))) - val fakeFs: FS = new FakeFS { override def eTag(url: FakeURL): Option[String] = diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 827b439e288..2159409407f 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -5,6 +5,7 @@ import is.hail.expr.ir.{ mapIR, Apply, Ascending, Descending, ErrorIDs, GetField, I32, Literal, LoweringAnalyses, MakeStruct, Ref, SelectFields, SortField, TableIR, TableMapRows, TableRange, ToArray, ToStream, } +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.LowerDistributedSort.samplePartition import is.hail.types.RTable import is.hail.types.virtual.{TArray, TInt32, TStruct} @@ -13,7 +14,7 @@ import is.hail.utils.FastSeq import org.apache.spark.sql.Row import org.testng.annotations.Test -class LowerDistributedSortSuite extends HailSuite { +class LowerDistributedSortSuite extends HailSuite with TestUtils { implicit val execStrats = ExecStrategy.compileOnly @Test def testSamplePartition(): Unit = { @@ -69,14 +70,14 @@ class LowerDistributedSortSuite extends HailSuite { val sortedTs = 
LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) .lower(ctx, myTable.typ.copy(key = FastSeq())) val res = - TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ + eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ IndexedSeq[Row] ]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2 - val unsortedCollect = is.hail.expr.ir.TestUtils.collect(myTable) + val unsortedCollect = collect(myTable) val unsortedAnalyses = LoweringAnalyses.apply(unsortedCollect, ctx) - val unsorted = TestUtils.eval(LowerTableIR.apply( + val unsorted = eval(LowerTableIR.apply( unsortedCollect, DArrayLowering.All, ctx, diff --git a/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala b/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala index 325d691f08b..6637496fac9 100644 --- a/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala @@ -1,10 +1,8 @@ package is.hail.expr.ir.table import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils.loweredExecute -import is.hail.backend.ExecuteContext import is.hail.expr.ir._ -import is.hail.expr.ir.TestUtils.IRAggCollect +import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ @@ -95,7 +93,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testLowering(): Unit = { - val table = TestUtils.collect(mkTableGen()) + val table = collect(mkTableGen()) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } @@ -103,13 +101,13 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testNumberOfContextsMatchesPartitions(): Unit = { val errorId = 42 - val table = TestUtils.collect(mkTableGen( + val table = collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), errorId = Some(errorId), )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[HailException] { - ExecuteContext.scoped(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) + loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) } ex.errorId shouldBe errorId ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") @@ -118,7 +116,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("lowering")) def testRowsAreCorrectlyKeyed(): Unit = { val errorId = 56 - val table = TestUtils.collect(mkTableGen( + val table = collect(mkTableGen( partitioner = Some(new RVDPartitioner( ctx.stateManager, TStruct("a" -> TInt32), @@ -131,7 +129,7 @@ class TableGenSuite extends HailSuite { )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[SparkException] { - ExecuteContext.scoped(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) + loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) }.getCause.asInstanceOf[HailException] ex.errorId shouldBe errorId diff --git a/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala b/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala index 9c3a98ae9bc..c5be7ab8989 100644 --- a/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala +++ b/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala @@ -1,7 +1,6 @@ package 
is.hail.io.compress import is.hail.HailSuite -import is.hail.TestUtils._ import is.hail.check.Gen import is.hail.check.Prop.forAll import is.hail.expr.ir.GenericLines diff --git a/hail/src/test/scala/is/hail/io/fs/FSSuite.scala b/hail/src/test/scala/is/hail/io/fs/FSSuite.scala index 3173b646b06..fa834fd5de4 100644 --- a/hail/src/test/scala/is/hail/io/fs/FSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/FSSuite.scala @@ -12,13 +12,10 @@ import org.apache.hadoop.fs.FileAlreadyExistsException import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -trait FSSuite extends TestNGSuite { +trait FSSuite extends TestNGSuite with TestUtils { val root: String = System.getenv("HAIL_TEST_STORAGE_URI") - def fsResourcesRoot: String = System.getenv("HAIL_FS_TEST_CLOUD_RESOURCES_URI") - def tmpdir: String = System.getenv("HAIL_TEST_STORAGE_URI") - def fs: FS /* Structure of src/test/resources/fs: @@ -73,7 +70,7 @@ trait FSSuite extends TestNGSuite { @Test def testFileStatusOnDirIsFailure(): Unit = { val f = r("/dir") - TestUtils.interceptException[FileNotFoundException](f)( + interceptException[FileNotFoundException](f)( fs.fileStatus(f) ) } diff --git a/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala b/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala index cb2a9008a5c..22c76c9be79 100644 --- a/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala +++ b/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala @@ -1,6 +1,6 @@ package is.hail.linalg -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.check._ import is.hail.check.Arbitrary._ import is.hail.check.Gen._ @@ -922,9 +922,9 @@ class BlockMatrixSuite extends HailSuite { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize = 2) val expected = new BDM[Double](2, 3, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) - TestUtils.assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) - TestUtils.assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) - TestUtils.assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) + assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) + assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) + assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) } def filteredEquals(bm1: BlockMatrix, bm2: BlockMatrix): Boolean = @@ -1215,17 +1215,17 @@ class BlockMatrixSuite extends HailSuite { val v0 = Array(0.0, Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) - TestUtils.interceptFatal(notSupported)(bm0 / bm0) - TestUtils.interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) - TestUtils.interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) - TestUtils.interceptFatal(notSupported)(1 / bm0) + interceptFatal(notSupported)(bm0 / bm0) + interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) + interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) + interceptFatal(notSupported)(1 / bm0) - TestUtils.interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) - TestUtils.interceptFatal(notSupported)(bm0.colVectorDiv(v0)) - TestUtils.interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) - TestUtils.interceptFatal("division by scalar 0.0")(bm0 / 0) + interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) + interceptFatal(notSupported)(bm0.colVectorDiv(v0)) + interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) + interceptFatal("division by scalar 0.0")(bm0 / 0) - TestUtils.interceptFatal(notSupported)(bm0.pow(-1)) + 
interceptFatal(notSupported)(bm0.pow(-1)) } @Test diff --git a/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala b/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala index ebee4aa797d..97c24f4aae0 100644 --- a/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala +++ b/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala @@ -1,6 +1,6 @@ package is.hail.methods -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.annotations.Annotation import is.hail.check.{Gen, Properties} import is.hail.check.Prop._ @@ -111,7 +111,7 @@ class LocalLDPruneSuite extends HailSuite { val nCores = 4 lazy val mt = Interpret( - TestUtils.importVCF(ctx, "src/test/resources/sample.vcf.bgz", nPartitions = Option(10)), + importVCF(ctx, "src/test/resources/sample.vcf.bgz", nPartitions = Option(10)), ctx, false, ).toMatrixValue(Array("s")) diff --git a/hail/src/test/scala/is/hail/methods/SkatSuite.scala b/hail/src/test/scala/is/hail/methods/SkatSuite.scala index a9b5e723f21..59ca348b593 100644 --- a/hail/src/test/scala/is/hail/methods/SkatSuite.scala +++ b/hail/src/test/scala/is/hail/methods/SkatSuite.scala @@ -1,6 +1,6 @@ package is.hail.methods -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.expr.ir.DoubleArrayBuilder import is.hail.utils._ @@ -31,6 +31,6 @@ class SkatSuite extends HailSuite { val (qLarge, gramianLarge) = Skat.computeGramianLargeN(st) assert(D_==(qSmall, qLarge)) - TestUtils.assertMatrixEqualityDouble(gramianSmall, gramianLarge) + assertMatrixEqualityDouble(gramianSmall, gramianLarge) } } diff --git a/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala b/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala index 73bcb32e9e4..3a51cb32f4a 100644 --- a/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala +++ b/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala @@ -1,6 +1,6 @@ package is.hail.stats -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils._ import breeze.linalg.{eigSym, svd, DenseMatrix, DenseVector} @@ -109,7 +109,7 @@ class eigSymDSuite extends HailSuite { val x = DenseVector.fill[Double](n)(rand.nextGaussian()) - TestUtils.assertVectorEqualityDouble(x, TriSolve(A, A * x)) + assertVectorEqualityDouble(x, TriSolve(A, A * x)) } } } diff --git a/hail/src/test/scala/is/hail/utils/RichArraySuite.scala b/hail/src/test/scala/is/hail/utils/RichArraySuite.scala index df1a5809f51..b7793450052 100644 --- a/hail/src/test/scala/is/hail/utils/RichArraySuite.scala +++ b/hail/src/test/scala/is/hail/utils/RichArraySuite.scala @@ -1,6 +1,6 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils.richUtils.RichArray import org.testng.annotations.Test @@ -15,7 +15,7 @@ class RichArraySuite extends HailSuite { RichArray.importFromDoubles(fs, file, a2, bufSize = 16) assert(a === a2) - TestUtils.interceptFatal("Premature") { + interceptFatal("Premature") { RichArray.importFromDoubles(fs, file, new Array[Double](101), bufSize = 64) } } diff --git a/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala b/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala index a8d05321bf4..282e71d77e0 100644 --- a/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala +++ b/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala @@ -1,6 +1,6 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.linalg.BlockMatrix import 
is.hail.utils.richUtils.RichDenseMatrixDouble @@ -33,7 +33,7 @@ class RichDenseMatrixDoubleSuite extends HailSuite { val lmT2 = RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 50, rowMajor = true) assert(mT === lmT2) - TestUtils.interceptFatal("Premature") { + interceptFatal("Premature") { RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 100, rowMajor = true) } } diff --git a/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala b/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala index f88eff4b872..4efb2099889 100644 --- a/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala +++ b/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala @@ -9,8 +9,7 @@ import is.hail.utils._ import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test -class GenotypeSuite extends TestNGSuite { - +class GenotypeSuite extends TestNGSuite with TestUtils { val v = Variant("1", 1, "A", "T") @Test def gtPairGtIndexIsId(): Unit = @@ -119,6 +118,6 @@ class GenotypeSuite extends TestNGSuite { assert(Call.parse("0|1") == Call2(0, 1, phased = true)) intercept[UnsupportedOperationException](Call.parse("1/1/1")) intercept[UnsupportedOperationException](Call.parse("1|1|1")) - TestUtils.interceptFatal("invalid call expression:")(Call.parse("0/")) + interceptFatal("invalid call expression:")(Call.parse("0/")) } } diff --git a/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala b/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala index 56027cbb1ac..da59fe8c3c8 100644 --- a/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala +++ b/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala @@ -1,6 +1,6 @@ package is.hail.variant -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.utils._ import org.testng.annotations.Test @@ -161,11 +161,11 @@ class LocusIntervalSuite extends HailSuite { true, true, )) - TestUtils.interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( + interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( "[X:0-5)", rg, )) - TestUtils.interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( + interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( s"[X:1-${xMax + 1}]", rg, )) @@ -208,19 +208,19 @@ class LocusIntervalSuite extends HailSuite { false, )) - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("4::start-5:end", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("4:start-", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("1:1.1111K-2k", rg) } - TestUtils.interceptFatal("invalid interval expression") { + interceptFatal("invalid interval expression") { Locus.parseInterval("1:1.1111111M-2M", rg) } diff --git a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala index f97bce88436..0d6de88b39d 100644 --- a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,6 +1,6 @@ package is.hail.variant -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite import is.hail.backend.HailStateManager import is.hail.check.Prop._ import is.hail.check.Properties @@ -51,18 +51,18 @@ class ReferenceGenomeSuite extends 
HailSuite { } @Test def testAssertions(): Unit = { - TestUtils.interceptFatal("Must have at least one contig in the reference genome.")( + interceptFatal("Must have at least one contig in the reference genome.")( ReferenceGenome("test", Array.empty[String], Map.empty[String, Int]) ) - TestUtils.interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( + interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5), )) - TestUtils.interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( + interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5, "4" -> 100)) ) - TestUtils.interceptFatal("The following X contig names are absent from the reference:")( + interceptFatal("The following X contig names are absent from the reference:")( ReferenceGenome( "test", Array("1", "2", "3"), @@ -70,7 +70,7 @@ class ReferenceGenomeSuite extends HailSuite { xContigs = Set("X"), ) ) - TestUtils.interceptFatal("The following Y contig names are absent from the reference:")( + interceptFatal("The following Y contig names are absent from the reference:")( ReferenceGenome( "test", Array("1", "2", "3"), @@ -78,7 +78,7 @@ class ReferenceGenomeSuite extends HailSuite { yContigs = Set("Y"), ) ) - TestUtils.interceptFatal( + interceptFatal( "The following mitochondrial contig names are absent from the reference:" )(ReferenceGenome( "test", @@ -86,13 +86,13 @@ class ReferenceGenomeSuite extends HailSuite { Map("1" -> 5, "2" -> 5, "3" -> 5), mtContigs = Set("MT"), )) - TestUtils.interceptFatal("The contig name for PAR interval")(ReferenceGenome( + interceptFatal("The contig name for PAR interval")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), parInput = Array((Locus("X", 1), Locus("X", 5))), )) - TestUtils.interceptFatal("in both X and Y contigs.")(ReferenceGenome( + interceptFatal("in both X and Y contigs.")(ReferenceGenome( "test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), @@ -103,7 +103,7 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testContigRemap(): Unit = { val mapping = Map("23" -> "foo") - TestUtils.interceptFatal("have remapped contigs in reference genome")( + interceptFatal("have remapped contigs in reference genome")( ctx.References(ReferenceGenome.GRCh37).validateContigRemap(mapping) ) } From 14b8467a63a14453feaf2705ce9cbf92e7c7845c Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 1 Oct 2024 08:59:33 -0400 Subject: [PATCH 2/4] checkpoint --- .../is/hail/backend/api/Py4JBackendApi.scala | 136 ++++++++++++------ 1 file changed, 90 insertions(+), 46 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala index bfce7175d48..3e695dac918 100644 --- a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala +++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala @@ -6,10 +6,7 @@ import is.hail.backend._ import is.hail.backend.caching.BlockMatrixCache import is.hail.backend.spark.SparkBackend import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} -import is.hail.expr.ir.{ - BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, - Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue, -} +import is.hail.expr.ir.{BaseIR, BlockMatrixIR, 
CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue} import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.functions.IRFunctionRegistry @@ -26,12 +23,10 @@ import is.hail.variant.ReferenceGenome import scala.collection.mutable import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} - import java.io.Closeable import java.net.InetSocketAddress import java.util import java.util.concurrent._ - import com.google.api.client.http.HttpStatusCodes import com.sun.net.httpserver.{HttpExchange, HttpServer} import org.apache.hadoop @@ -42,41 +37,40 @@ import org.json4s._ import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing +import javax.annotation.Nullable + final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { - private[this] val tmpdir: String = ??? - private[this] val localTmpdir: String = ??? - private[this] val longLifeTempFileManager = null: TempFileManager private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() - private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) + private[this] val hcl = new HailClassLoader(getClass.getClassLoader) private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*) private[this] val bmCache = new BlockMatrixCache() private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) private[this] val persistedIr = mutable.Map[Int, BaseIR]() private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) - private[this] def cloudfsConfig = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) - - def fs: FS = backend match { - case s: SparkBackend => - val conf = new Configuration(s.sc.hadoopConfiguration) - cloudfsConfig.google.flatMap(_.requester_pays_config).foreach { - case RequesterPaysConfig(prj, bkts) => - bkts - .map { buckets => - conf.set("fs.gs.requester.pays.mode", "CUSTOM") - conf.set("fs.gs.requester.pays.project.id", prj) - conf.set("fs.gs.requester.pays.buckets", buckets.mkString(",")) - } - .getOrElse { - conf.set("fs.gs.requester.pays.mode", "AUTO") - conf.set("fs.gs.requester.pays.project.id", prj) - } - } - new HadoopFS(new SerializableHadoopConfiguration(conf)) + private[this] var irID: Int = 0 + private[this] var tmpdir: String = _ + private[this] var localTmpdir: String = _ + + private[this] object tmpFileManager extends TempFileManager { + private[this] var fs = newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) + private[this] var manager = new OwningTempFileManager(fs) + + def setFs(fs: FS): Unit = { + close() + this.fs = fs + manager = new OwningTempFileManager(fs) + } + + def getFs: FS = + fs + + override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = + manager.newTmpPath(tmpdir, prefix, extension) - case _ => - RouterFS.buildRoutes(cloudfsConfig) + override def close(): Unit = + manager.close() } def pyGetFlag(name: String): String = @@ -88,24 +82,40 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin def pyAvailableFlags: java.util.ArrayList[String] = flags.available - private[this] var irID: Int = 0 + def pySetTmpdir(tmp: String): Unit = + tmpdir = tmp - private[this] def nextIRID(): Int = { - irID += 1 - irID - } + def pySetLocalTmp(tmp: String): Unit = + localTmpdir = tmp - private[this] 
def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { - val id = nextIRID() - ctx.IrCache += (id -> ir) - id + def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = { + val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) + + val rpConfig: Option[RequesterPaysConfig] = + (Option(project).filter(_.nonEmpty), Option(buckets)) match { + case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets.map(_.asScala.toSet))) + case (None, Some(_)) => fatal("A non-empty, non-null requester pays google project is required to configure requester pays buckets.") + case (None, None) => None + } + + val fs = newFs( + cloudfsConf.copy( + google = (cloudfsConf.google, rpConfig) match { + case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig)) + case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig)) + case _ => None + } + ) + ) + + tmpFileManager.setFs(fs) } def pyRemoveJavaIR(id: Int): Unit = persistedIr.remove(id) def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - references(name).addSequence(IndexedFastaSequenceFile(fs, fastaFile, indexFile)) + references(name).addSequence(IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)) def pyRemoveSequence(name: String): Unit = references(name).removeSequence() @@ -239,7 +249,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin removeReference(name) def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = - references(name).addLiftover(references(destRGName), LiftOver(fs, chainFile)) + references(name).addLiftover(references(destRGName), LiftOver(tmpFileManager.getFs, chainFile)) def pyRemoveLiftover(name: String, destRGName: String): Unit = references(name).removeLiftover(destRGName) @@ -267,7 +277,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin private[this] def removeReference(name: String): Unit = references -= name - private def withExecuteContext[T]( + private[this] def withExecuteContext[T]( selfContainedExecution: Boolean = true )( f: ExecuteContext => T @@ -278,12 +288,12 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin tmpdir = tmpdir, localTmpdir = localTmpdir, backend = backend, - fs = fs, + fs = tmpFileManager.getFs, timer = timer, tempFileManager = if (selfContainedExecution) null - else NonOwningTempFileManager(longLifeTempFileManager), - theHailClassLoader = theHailClassLoader, + else NonOwningTempFileManager(tmpFileManager), + theHailClassLoader = hcl, flags = flags, irMetadata = IrMetadata(None), references = references, @@ -294,6 +304,40 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin )(f) } + private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS = + backend match { + case s: SparkBackend => + val conf = new Configuration(s.sc.hadoopConfiguration) + cloudfsConfig.google.flatMap(_.requester_pays_config).foreach { + case RequesterPaysConfig(prj, bkts) => + bkts + .map { buckets => + conf.set("fs.gs.requester.pays.mode", "CUSTOM") + conf.set("fs.gs.requester.pays.project.id", prj) + conf.set("fs.gs.requester.pays.buckets", buckets.mkString(",")) + } + .getOrElse { + conf.set("fs.gs.requester.pays.mode", "AUTO") + conf.set("fs.gs.requester.pays.project.id", prj) + } + } + new HadoopFS(new SerializableHadoopConfiguration(conf)) + + case _ => + RouterFS.buildRoutes(cloudfsConfig) + } + + private[this] def nextIRID(): Int = { + irID 
+= 1 + irID + } + + private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = { + val id = nextIRID() + ctx.IrCache += (id -> ir) + id + } + override def close(): Unit = synchronized { bmCache.close() From 1bf70ac90ab9bed69a0cd06ec9d970e7fabc5cbb Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 1 Oct 2024 10:38:15 -0400 Subject: [PATCH 3/4] fix warnigns - now builds in "Release Mode" --- .../is/hail/backend/api/Py4JBackendApi.scala | 18 +++++++--- .../main/scala/is/hail/utils/package.scala | 35 ++++++++++--------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala index 3e695dac918..216d55154a1 100644 --- a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala +++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala @@ -38,6 +38,7 @@ import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing import javax.annotation.Nullable +import scala.annotation.nowarn final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { @@ -92,9 +93,14 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) val rpConfig: Option[RequesterPaysConfig] = - (Option(project).filter(_.nonEmpty), Option(buckets)) match { - case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets.map(_.asScala.toSet))) - case (None, Some(_)) => fatal("A non-empty, non-null requester pays google project is required to configure requester pays buckets.") + ( + Option(project).filter(_.nonEmpty), + Option(buckets).map(_.asScala.toSet.filterNot(_.isBlank)).filter(_.nonEmpty), + ) match { + case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets)) + case (None, Some(_)) => fatal( + "A non-empty, non-null requester pays google project is required to configure requester pays buckets." 
+ ) case (None, None) => None } @@ -115,7 +121,9 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin persistedIr.remove(id) def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = - references(name).addSequence(IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)) + references(name).addSequence( + IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile) + ) def pyRemoveSequence(name: String): Unit = references(name).removeSequence() @@ -440,7 +448,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin t } - def port: Int = httpServer.getAddress.getPort + @nowarn def port: Int = httpServer.getAddress.getPort override def close(): Unit = httpServer.stop(10) thread.start() diff --git a/hail/src/main/scala/is/hail/utils/package.scala b/hail/src/main/scala/is/hail/utils/package.scala index 05d1d0932b4..4fc2acbaf26 100644 --- a/hail/src/main/scala/is/hail/utils/package.scala +++ b/hail/src/main/scala/is/hail/utils/package.scala @@ -91,6 +91,24 @@ package utils { b.result() } } + + + class Lazy[A] private[utils] (f: => A) { + private[this] var option: Option[A] = None + + def apply(): A = + synchronized { + option match { + case Some(a) => a + case None => val a = f; option = Some(a); a + } + } + + def isEvaluated: Boolean = + synchronized { + option.isDefined + } + } } package object utils @@ -1058,23 +1076,6 @@ package object utils implicit def evalLazy[A](f: Lazy[A]): A = f() - - class Lazy[A] private[utils] (f: => A) { - private[this] var option: Option[A] = None - - def apply(): A = - synchronized { - option match { - case Some(a) => a - case None => val a = f; option = Some(a); a - } - } - - def isEvaluated: Boolean = - synchronized { - option.isDefined - } - } } class CancellingExecutorService(delegate: ExecutorService) extends AbstractExecutorService { From 2957928782cd2766d1edd4af7765affc356c2736 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 1 Oct 2024 12:08:19 -0400 Subject: [PATCH 4/4] wire up spark, local and py4j backends --- hail/python/hail/backend/local_backend.py | 15 ++-------- hail/python/hail/backend/py4j_backend.py | 29 +++++++++++++++++-- hail/python/hail/backend/spark_backend.py | 29 ++++--------------- hail/python/hail/context.py | 13 +++------ .../is/hail/backend/api/Py4JBackendApi.scala | 20 +++++++++---- .../hail/backend/api/ServiceBackendApi.scala | 4 ++- .../main/scala/is/hail/utils/package.scala | 1 - hail/src/test/scala/is/hail/HailSuite.scala | 4 +-- 8 files changed, 59 insertions(+), 56 deletions(-) diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 0908ee06986..311a5c7e9bb 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -81,23 +81,14 @@ def __init__( ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) - super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc) + super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_configuration self._fs = self._exit_stack.enter_context( RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration}) ) self._logger = None - - flags = {} - if gcs_requester_pays_configuration is not None: - if isinstance(gcs_requester_pays_configuration, str): - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration - else: - assert 
isinstance(gcs_requester_pays_configuration, tuple) - flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0] - flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1]) - - self._initialize_flags(flags) + self._initialize_flags({}) def validate_file(self, uri: str) -> None: async_to_blocking(validate_file(uri, self._fs.afs)) diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 49838ea4a85..d59ae456a93 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -13,8 +13,10 @@ import hail from hail.expr import construct_expr +from hail.fs.hadoop_fs import HadoopFS from hail.ir import JavaIR from hail.utils.java import Env, FatalError, scala_package_object +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from ..hail_logging import Logger from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet @@ -193,11 +195,17 @@ def decode_bytearray(encoded): self._utils_package_object = scala_package_object(self._hail_package.utils) self._jhc = jhc - self._jbackend = self._hail_package.backend.api.P4jBackendApi(jbackend) + self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) + self._jbackend.pySetLocalTmp(tmpdir) + self._jbackend.pySetRemoteTmp(remote_tmpdir) + self._jhttp_server = self._jbackend.pyHttpServer() - self._backend_server_port: int = self._jbackend.HttpServer.port() + self._backend_server_port: int = self._jhttp_server.port() self._requests_session = requests.Session() + self._gcs_requester_pays_config = None + self._fs = None + # This has to go after creating the SparkSession. Unclear why. # Maybe it does its own patch? install_exception_handler() @@ -221,6 +229,23 @@ def hail_package(self): def utils_package_object(self): return self._utils_package_object + @property + def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]: + return self._gcs_requester_pays_config + + @gcs_requester_pays_configuration.setter + def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]): + self._gcs_requester_pays_config = config + project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config + self._jbackend.pySetGcsRequesterPaysConfig(project, buckets) + self._fs = None # stale + + @property + def fs(self): + if self._fs is None: + self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs()) + return self._fs + @property def logger(self): if self._logger is None: diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index 69f53292443..192ca57eb75 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -7,11 +7,11 @@ import pyspark.sql from hail.expr.table_type import ttable -from hail.fs.hadoop_fs import HadoopFS from hail.ir import BaseIR from hail.ir.renderer import CSERenderer from hail.table import Table from hail.utils import copy_log +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.aiotools.validators import validate_file from hailtop.utils import async_to_blocking @@ -47,12 +47,9 @@ def __init__( skip_logging_configuration, optimizer_iterations, *, - gcs_requester_pays_project: Optional[str] = None, - gcs_requester_pays_buckets: Optional[str] = None, + gcs_requester_pays_config: 
Optional[GCSRequesterPaysConfiguration] = None, copy_log_on_error: bool = False, ): - assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None - try: local_jar_info = local_jar_information() except ValueError: @@ -120,10 +117,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations) else: @@ -137,10 +130,6 @@ def __init__( append, skip_logging_configuration, min_block_size, - tmpdir, - local_tmpdir, - gcs_requester_pays_project, - gcs_requester_pays_buckets, ) jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations) @@ -149,12 +138,12 @@ def __init__( self.sc = sc else: self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc)) - self._jspark_session = jbackend.sparkSession() + self._jspark_session = jbackend.sparkSession().apply() self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session) - super(SparkBackend, self).__init__(jvm, jbackend, jhc) + super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir) + self.gcs_requester_pays_configuration = gcs_requester_pays_config - self._fs = None self._logger = None if not quiet: @@ -167,7 +156,7 @@ def __init__( self._initialize_flags({}) self._router_async_fs = RouterAsyncFS( - gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project} + gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config} ) self._tmpdir = tmpdir @@ -181,12 +170,6 @@ def stop(self): self.sc.stop() self.sc = None - @property - def fs(self): - if self._fs is None: - self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs()) - return self._fs - def from_spark(self, df, key): result_tuple = self._jbackend.pyFromDF(df._jdf, key) tir_id, type_json = result_tuple._1(), result_tuple._2() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 5258f27fbc1..3d8689f4be1 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -474,14 +474,10 @@ def init_spark( optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3) app_name = app_name or 'Hail' - ( - gcs_requester_pays_project, - gcs_requester_pays_buckets, - ) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style( - get_gcs_requester_pays_configuration( - gcs_requester_pays_configuration=gcs_requester_pays_configuration, - ) + gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( + gcs_requester_pays_configuration=gcs_requester_pays_configuration, ) + backend = SparkBackend( idempotent, sc, @@ -498,8 +494,7 @@ def init_spark( local_tmpdir, skip_logging_configuration, optimizer_iterations, - gcs_requester_pays_project=gcs_requester_pays_project, - gcs_requester_pays_buckets=gcs_requester_pays_buckets, + gcs_requester_pays_config=gcs_requester_pays_configuration, copy_log_on_error=copy_log_on_error, ) if not backend.fs.exists(tmpdir): diff --git a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala index 216d55154a1..8ea280c08dd 100644 --- a/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala +++ b/hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala @@ -6,7 +6,10 @@ import is.hail.backend._ import is.hail.backend.caching.BlockMatrixCache import is.hail.backend.spark.SparkBackend import 
is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} -import is.hail.expr.ir.{BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue} +import is.hail.expr.ir.{ + BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, + Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue, +} import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.functions.IRFunctionRegistry @@ -21,14 +24,18 @@ import is.hail.utils._ import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome +import scala.annotation.nowarn import scala.collection.mutable import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} + import java.io.Closeable import java.net.InetSocketAddress import java.util import java.util.concurrent._ + import com.google.api.client.http.HttpStatusCodes import com.sun.net.httpserver.{HttpExchange, HttpServer} +import javax.annotation.Nullable import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.DataFrame @@ -37,9 +44,6 @@ import org.json4s._ import org.json4s.jackson.{JsonMethods, Serialization} import sourcecode.Enclosing -import javax.annotation.Nullable -import scala.annotation.nowarn - final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling { private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() @@ -74,6 +78,9 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin manager.close() } + def pyFs: FS = + tmpFileManager.getFs + def pyGetFlag(name: String): String = flags.get(name) @@ -83,13 +90,14 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin def pyAvailableFlags: java.util.ArrayList[String] = flags.available - def pySetTmpdir(tmp: String): Unit = + def pySetRemoteTmp(tmp: String): Unit = tmpdir = tmp def pySetLocalTmp(tmp: String): Unit = localTmpdir = tmp - def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = { + def pySetGcsRequesterPaysConfig(@Nullable project: String, @Nullable buckets: util.List[String]) + : Unit = { val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags) val rpConfig: Option[RequesterPaysConfig] = diff --git a/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala b/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala index 4e5659dc13b..f089d79fb84 100644 --- a/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala +++ b/hail/src/main/scala/is/hail/backend/api/ServiceBackendApi.scala @@ -11,7 +11,9 @@ import is.hail.io.fs.{CloudStorageFSConfig, FS, RouterFS} import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.services._ import is.hail.types.virtual.Kinds -import is.hail.utils.{toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging} +import is.hail.utils.{ + toRichIterable, using, ErrorHandling, ExecutionTimer, HailWorkerException, Logging, +} import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome diff --git a/hail/src/main/scala/is/hail/utils/package.scala b/hail/src/main/scala/is/hail/utils/package.scala index 4fc2acbaf26..36fc0b91c70 100644 --- a/hail/src/main/scala/is/hail/utils/package.scala +++ 
b/hail/src/main/scala/is/hail/utils/package.scala @@ -92,7 +92,6 @@ package utils { } } - class Lazy[A] private[utils] (f: => A) { private[this] var option: Option[A] = None diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index 057d3b77744..b5b40cdf7bd 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -70,8 +70,8 @@ class HailSuite extends TestNGSuite with TestUtils { var pool: RegionPool = _ private[this] var ctx_ : ExecuteContext = _ - def backend: Backend = ctx.backend - def sc: SparkContext = backend.asSpark.sc + def backend: Backend = hc.backend + def sc: SparkContext = hc.backend.asSpark.sc def timer: ExecutionTimer = ctx.timer def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader override def ctx: ExecuteContext = ctx_
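
The net effect of the series is that `Py4JBackendApi` owns the Py4J-facing session state (temp dirs, feature flags, the `FS`, and the GCS requester-pays configuration) that previously lived on the individual backends, with the Python `Py4JBackend` driving it over the gateway. The sketch below illustrates that call sequence from the JVM side. It is a minimal sketch only: the method names are taken from the patch, but the wrapping object, paths, project, and bucket names are hypothetical, and an already-constructed `backend: Backend` is assumed.

import is.hail.backend.Backend
import is.hail.backend.api.Py4JBackendApi

import java.util.Arrays

object Py4JSessionSketch {
  // Mirrors the call sequence hail/python/hail/backend/py4j_backend.py performs
  // after this series: configure tmp dirs and requester-pays state on the API
  // object, then ask it for an FS that reflects that configuration.
  def startSession(backend: Backend): Py4JBackendApi = {
    val api = new Py4JBackendApi(backend)
    api.pySetLocalTmp("/tmp/hail") // hypothetical local scratch dir
    api.pySetRemoteTmp("gs://example-bucket/hail/tmp") // hypothetical remote tmpdir
    // Passing null for both arguments would mean "no requester-pays config";
    // the project and bucket values here are made up for illustration.
    api.pySetGcsRequesterPaysConfig("example-project", Arrays.asList("bucket-a", "bucket-b"))
    val fs = api.pyFs // rebuilt by tmpFileManager whenever the requester-pays config changes
    assert(fs != null)
    // The Python client additionally calls api.pyHttpServer() to serve the
    // REST-style actions; that part is omitted here.
    api
  }
}

Routing every filesystem consumer through `tmpFileManager.getFs` means a requester-pays change rebuilds the `FS` in exactly one place; on the Python side the matching setter simply drops its cached `HadoopFS` wrapper (`self._fs = None`) so the next access picks up the new configuration.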