Skip to content

Commit

Permalink
[query] Move LoweredTableReaderCoercer into ExecuteContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Oct 1, 2024
1 parent b02c83b commit 63191ba
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 380 deletions.
2 changes: 0 additions & 2 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ abstract class Backend extends Closeable {
def asSpark(implicit E: Enclosing): SparkBackend =
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")

def shouldCacheQueryInfo: Boolean = true

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -74,6 +75,7 @@ object ExecuteContext {
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
irCache: mutable.Map[Int, BaseIR],
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -95,6 +97,7 @@ object ExecuteContext {
blockMatrixCache,
codeCache,
irCache,
coercerCache,
))(f(_))
}
}
Expand Down Expand Up @@ -127,6 +130,7 @@ class ExecuteContext(
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
val IrCache: mutable.Map[Int, BaseIR],
val CoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -199,6 +203,7 @@ class ExecuteContext(
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
coercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.CoercerCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -218,5 +223,6 @@ class ExecuteContext(
blockMatrixCache,
codeCache,
irCache,
coercerCache,
))(f)
}
16 changes: 16 additions & 0 deletions hail/src/main/scala/is/hail/backend/caching/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package is.hail.backend

import scala.collection.mutable

package object caching {
private[this] object NoCachingInstance extends mutable.AbstractMap[Any, Any] {
override def +=(kv: (Any, Any)): NoCachingInstance.this.type = this
override def -=(key: Any): NoCachingInstance.this.type = this
override def get(key: Any): Option[Any] = None
override def iterator: Iterator[(Any, Any)] = Iterator.empty
override def getOrElseUpdate(key: Any, op: => Any): Any = op
}

def NoCaching[K, V]: mutable.Map[K, V] =
NoCachingInstance.asInstanceOf[mutable.Map[K, V]]
}
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import is.hail.backend._
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
Expand Down Expand Up @@ -81,6 +82,7 @@ class LocalBackend(
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -106,6 +108,7 @@ class LocalBackend(
ImmutableMap.empty,
codeCache,
persistedIR,
coercerCache,
)(f)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package is.hail.backend.py4j
import is.hail.HailFeatureFlags
import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
import is.hail.expr.ir.{BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
import is.hail.expr.ir.{
BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR,
MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
}
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
Expand Down
10 changes: 5 additions & 5 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags}
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend._
import is.hail.backend.caching.NoCaching
import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections
import is.hail.expr.Validate
import is.hail.expr.ir.{
Expand Down Expand Up @@ -63,8 +64,6 @@ class ServiceBackend(
private[this] var stageCount = 0
private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections)

override def shouldCacheQueryInfo: Boolean = false

def defaultParallelism: Int = 4

def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = {
Expand Down Expand Up @@ -316,9 +315,10 @@ class ServiceBackend(
),
IrMetadata(None),
references,
ImmutableMap.empty,
mutable.Map.empty,
ImmutableMap.empty,
NoCaching,
NoCaching,
NoCaching,
NoCaching,
)(f)
}
}
Expand Down
8 changes: 6 additions & 2 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend._
import is.hail.backend.caching.BlockMatrixCache
import is.hail.backend.caching.{BlockMatrixCache, NoCaching}
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
Expand Down Expand Up @@ -343,6 +344,7 @@ class SparkBackend(
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -365,8 +367,9 @@ class SparkBackend(
IrMetadata(None),
references,
ImmutableMap.empty,
mutable.Map.empty,
NoCaching,
ImmutableMap.empty,
NoCaching,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) =
Expand All @@ -389,6 +392,7 @@ class SparkBackend(
bmCache,
codeCache,
persistedIr,
coercerCache,
)(f)
}

Expand Down
35 changes: 12 additions & 23 deletions hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.functions.UtilFunctions
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.expr.ir.streams.StreamProducer
Expand Down Expand Up @@ -143,16 +144,6 @@ class PartitionIteratorLongReader(
)
}

abstract class LoweredTableReaderCoercer {
def coerce(
ctx: ExecuteContext,
globals: IR,
contextType: Type,
contexts: IndexedSeq[Any],
body: IR => IR,
): TableStage
}

class GenericTableValue(
val fullTableType: TableType,
val uidFieldName: String,
Expand All @@ -168,12 +159,11 @@ class GenericTableValue(
assert(contextType.hasField("partitionIndex"))
assert(contextType.fieldType("partitionIndex") == TInt32)

private var ltrCoercer: LoweredTableReaderCoercer = _

private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any)
: LoweredTableReaderCoercer = {
if (ltrCoercer == null) {
ltrCoercer = LoweredTableReader.makeCoercer(
: LoweredTableReaderCoercer =
ctx.CoercerCache.getOrElseUpdate(
(1, contextType, fullTableType.key, cacheKey),
LoweredTableReader.makeCoercer(
ctx,
fullTableType.key,
1,
Expand All @@ -184,11 +174,8 @@ class GenericTableValue(
bodyPType,
body,
context,
cacheKey,
)
}
ltrCoercer
}
),
)

def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any)
: TableStage = {
Expand Down Expand Up @@ -217,11 +204,13 @@ class GenericTableValue(
val contextsIR = ToStream(Literal(TArray(contextType), contexts))
TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody)
} else {
getLTVCoercer(ctx, context, cacheKey).coerce(
getLTVCoercer(ctx, context, cacheKey)(
ctx,
globalsIR,
contextType, contexts,
requestedBody)
contextType,
contexts,
requestedBody,
)
}
}
}
Loading

0 comments on commit 63191ba

Please sign in to comment.