checkpoint

ehigham committed Oct 1, 2024
1 parent 33a991e commit 7c2810f
Showing 1 changed file with 90 additions and 46 deletions.
136 changes: 90 additions & 46 deletions hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
@@ -6,10 +6,7 @@ import is.hail.backend._
 import is.hail.backend.caching.BlockMatrixCache
 import is.hail.backend.spark.SparkBackend
 import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
-import is.hail.expr.ir.{
-  BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser,
-  Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
-}
+import is.hail.expr.ir.{BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
 import is.hail.expr.ir.IRParser.parseType
 import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
 import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -26,12 +23,10 @@ import is.hail.variant.ReferenceGenome
 import scala.collection.mutable
 import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}
-
 import java.io.Closeable
 import java.net.InetSocketAddress
 import java.util
 import java.util.concurrent._
-
 import com.google.api.client.http.HttpStatusCodes
 import com.sun.net.httpserver.{HttpExchange, HttpServer}
 import org.apache.hadoop
@@ -42,41 +37,40 @@ import org.json4s._
 import org.json4s.jackson.{JsonMethods, Serialization}
 import sourcecode.Enclosing
 
+import javax.annotation.Nullable
+
 final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
 
-  private[this] val tmpdir: String = ???
-  private[this] val localTmpdir: String = ???
-  private[this] val longLifeTempFileManager = null: TempFileManager
   private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
-  private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
+  private[this] val hcl = new HailClassLoader(getClass.getClassLoader)
   private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)
   private[this] val bmCache = new BlockMatrixCache()
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
   private[this] val persistedIr = mutable.Map[Int, BaseIR]()
   private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
 
-  private[this] def cloudfsConfig = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
-
-  def fs: FS = backend match {
-    case s: SparkBackend =>
-      val conf = new Configuration(s.sc.hadoopConfiguration)
-      cloudfsConfig.google.flatMap(_.requester_pays_config).foreach {
-        case RequesterPaysConfig(prj, bkts) =>
-          bkts
-            .map { buckets =>
-              conf.set("fs.gs.requester.pays.mode", "CUSTOM")
-              conf.set("fs.gs.requester.pays.project.id", prj)
-              conf.set("fs.gs.requester.pays.buckets", buckets.mkString(","))
-            }
-            .getOrElse {
-              conf.set("fs.gs.requester.pays.mode", "AUTO")
-              conf.set("fs.gs.requester.pays.project.id", prj)
-            }
-      }
-      new HadoopFS(new SerializableHadoopConfiguration(conf))
+  private[this] var irID: Int = 0
+  private[this] var tmpdir: String = _
+  private[this] var localTmpdir: String = _
 
-    case _ =>
-      RouterFS.buildRoutes(cloudfsConfig)
+  private[this] object tmpFileManager extends TempFileManager {
+    private[this] var fs = newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
+    private[this] var manager = new OwningTempFileManager(fs)
+
+    def setFs(fs: FS): Unit = {
+      close()
+      this.fs = fs
+      manager = new OwningTempFileManager(fs)
+    }
+
+    def getFs: FS =
+      fs
+
+    override def newTmpPath(tmpdir: String, prefix: String, extension: String): String =
+      manager.newTmpPath(tmpdir, prefix, extension)
+
+    override def close(): Unit =
+      manager.close()
   }
 
   def pyGetFlag(name: String): String =
@@ -88,24 +82,40 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
   def pyAvailableFlags: java.util.ArrayList[String] =
     flags.available
 
-  private[this] var irID: Int = 0
+  def pySetTmpdir(tmp: String): Unit =
+    tmpdir = tmp
 
-  private[this] def nextIRID(): Int = {
-    irID += 1
-    irID
-  }
+  def pySetLocalTmp(tmp: String): Unit =
+    localTmpdir = tmp
 
-  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
-    val id = nextIRID()
-    ctx.IrCache += (id -> ir)
-    id
+  def pySetRequesterPays(@Nullable project: String, @Nullable buckets: util.List[String]): Unit = {
+    val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)
+
+    val rpConfig: Option[RequesterPaysConfig] =
+      (Option(project).filter(_.nonEmpty), Option(buckets)) match {
+        case (Some(project), buckets) => Some(RequesterPaysConfig(project, buckets.map(_.asScala.toSet)))
+        case (None, Some(_)) => fatal("A non-empty, non-null requester pays google project is required to configure requester pays buckets.")
+        case (None, None) => None
+      }
+
+    val fs = newFs(
+      cloudfsConf.copy(
+        google = (cloudfsConf.google, rpConfig) match {
+          case (Some(gconf), _) => Some(gconf.copy(requester_pays_config = rpConfig))
+          case (None, Some(_)) => Some(GoogleStorageFSConfig(None, rpConfig))
+          case _ => None
+        }
+      )
+    )
+
+    tmpFileManager.setFs(fs)
   }
 
   def pyRemoveJavaIR(id: Int): Unit =
     persistedIr.remove(id)
 
   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
-    references(name).addSequence(IndexedFastaSequenceFile(fs, fastaFile, indexFile))
+    references(name).addSequence(IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile))
 
   def pyRemoveSequence(name: String): Unit =
     references(name).removeSequence()
@@ -239,7 +249,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
     removeReference(name)
 
   def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
-    references(name).addLiftover(references(destRGName), LiftOver(fs, chainFile))
+    references(name).addLiftover(references(destRGName), LiftOver(tmpFileManager.getFs, chainFile))
 
   def pyRemoveLiftover(name: String, destRGName: String): Unit =
     references(name).removeLiftover(destRGName)
Expand Down Expand Up @@ -267,7 +277,7 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
private[this] def removeReference(name: String): Unit =
references -= name

private def withExecuteContext[T](
private[this] def withExecuteContext[T](
selfContainedExecution: Boolean = true
)(
f: ExecuteContext => T
@@ -278,12 +288,12 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
       tmpdir = tmpdir,
       localTmpdir = localTmpdir,
       backend = backend,
-      fs = fs,
+      fs = tmpFileManager.getFs,
       timer = timer,
       tempFileManager =
         if (selfContainedExecution) null
-        else NonOwningTempFileManager(longLifeTempFileManager),
-      theHailClassLoader = theHailClassLoader,
+        else NonOwningTempFileManager(tmpFileManager),
+      theHailClassLoader = hcl,
       flags = flags,
       irMetadata = IrMetadata(None),
       references = references,
@@ -294,6 +304,40 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {
     )(f)
   }
 
+  private[this] def newFs(cloudfsConfig: CloudStorageFSConfig): FS =
+    backend match {
+      case s: SparkBackend =>
+        val conf = new Configuration(s.sc.hadoopConfiguration)
+        cloudfsConfig.google.flatMap(_.requester_pays_config).foreach {
+          case RequesterPaysConfig(prj, bkts) =>
+            bkts
+              .map { buckets =>
+                conf.set("fs.gs.requester.pays.mode", "CUSTOM")
+                conf.set("fs.gs.requester.pays.project.id", prj)
+                conf.set("fs.gs.requester.pays.buckets", buckets.mkString(","))
+              }
+              .getOrElse {
+                conf.set("fs.gs.requester.pays.mode", "AUTO")
+                conf.set("fs.gs.requester.pays.project.id", prj)
+              }
+        }
+        new HadoopFS(new SerializableHadoopConfiguration(conf))
+
+      case _ =>
+        RouterFS.buildRoutes(cloudfsConfig)
+    }
+
+  private[this] def nextIRID(): Int = {
+    irID += 1
+    irID
+  }
+
+  private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
+    val id = nextIRID()
+    ctx.IrCache += (id -> ir)
+    id
+  }
+
   override def close(): Unit =
     synchronized {
       bmCache.close()
