Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Lift backend state into {Service|Py4J}BackendApi #14698

Open
wants to merge 4 commits into
base: ehigham/ctx-coercer-cache
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions hail/python/hail/backend/local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,21 @@ def __init__(
hail_package = getattr(self._gateway.jvm, 'is').hail

jbackend = hail_package.backend.local.LocalBackend.apply(
tmpdir,
log,
True,
append,
skip_logging_configuration,
)
jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)

super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc)
super().__init__(self._gateway.jvm, jbackend, jhc, tmpdir, tmpdir)
self.gcs_requester_pays_configuration = gcs_requester_pays_configuration
self._fs = self._exit_stack.enter_context(
RouterFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration})
)

self._logger = None

flags = {}
if gcs_requester_pays_configuration is not None:
if isinstance(gcs_requester_pays_configuration, str):
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration
else:
assert isinstance(gcs_requester_pays_configuration, tuple)
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0]
flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1])

self._initialize_flags(flags)
self._initialize_flags({})

# Validate that `uri` points at a usable file location for this backend's
# filesystem, driving the async `validate_file` helper to completion.
# Raises (rather than returning a value) when the URI is invalid or the
# underlying async filesystem rejects it.
def validate_file(self, uri: str) -> None:
async_to_blocking(validate_file(uri, self._fs.afs))
Expand Down
45 changes: 38 additions & 7 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

import hail
from hail.expr import construct_expr
from hail.fs.hadoop_fs import HadoopFS
from hail.ir import JavaIR
from hail.utils.java import Env, FatalError, scala_package_object
from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration

from ..hail_logging import Logger
from .backend import ActionTag, Backend, fatal_error_from_java_error_triplet
Expand Down Expand Up @@ -170,8 +172,15 @@ def parse(node):

class Py4JBackend(Backend):
@abc.abstractmethod
def __init__(self, jvm: JVMView, jbackend: JavaObject, jhc: JavaObject):
super(Py4JBackend, self).__init__()
def __init__(
self,
jvm: JVMView,
jbackend: JavaObject,
jhc: JavaObject,
tmpdir: str,
remote_tmpdir: str,
):
super().__init__()
import base64

def decode_bytearray(encoded):
Expand All @@ -184,14 +193,19 @@ def decode_bytearray(encoded):
self._jvm = jvm
self._hail_package = getattr(self._jvm, 'is').hail
self._utils_package_object = scala_package_object(self._hail_package.utils)
self._jbackend = jbackend
self._jhc = jhc

self._backend_server = self._hail_package.backend.BackendServer(self._jbackend)
self._backend_server_port: int = self._backend_server.port()
self._backend_server.start()
self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend)
self._jbackend.pySetLocalTmp(tmpdir)
self._jbackend.pySetRemoteTmp(remote_tmpdir)

self._jhttp_server = self._jbackend.pyHttpServer()
self._backend_server_port: int = self._jhttp_server.port()
self._requests_session = requests.Session()

self._gcs_requester_pays_config = None
self._fs = None

# This has to go after creating the SparkSession. Unclear why.
# Maybe it does its own patch?
install_exception_handler()
Expand All @@ -215,6 +229,23 @@ def hail_package(self):
def utils_package_object(self):
return self._utils_package_object

@property
def gcs_requester_pays_configuration(self) -> Optional[GCSRequesterPaysConfiguration]:
# Last value assigned through the setter below; None until configured.
return self._gcs_requester_pays_config

@gcs_requester_pays_configuration.setter
def gcs_requester_pays_configuration(self, config: Optional[GCSRequesterPaysConfiguration]):
# `config` is one of: None (disable), a bare project string, or a
# (project, buckets) tuple — normalize to a (project, buckets) pair
# before pushing it to the JVM-side backend.
self._gcs_requester_pays_config = config
project, buckets = (None, None) if config is None else (config, None) if isinstance(config, str) else config
self._jbackend.pySetGcsRequesterPaysConfig(project, buckets)
self._fs = None  # stale: the `fs` property will rebuild it from the JVM on next access

@property
def fs(self):
# Lazily construct the Hadoop-backed filesystem wrapper from the
# JVM-side backend. `_fs` is reset to None when the requester-pays
# configuration changes, so this rebuilds with the current config.
if self._fs is None:
self._fs = HadoopFS(self._utils_package_object, self._jbackend.pyFs())
return self._fs

@property
def logger(self):
if self._logger is None:
Expand Down Expand Up @@ -289,7 +320,7 @@ def _to_java_blockmatrix_ir(self, ir):
return self._parse_blockmatrix_ir(self._render_ir(ir))

def stop(self):
self._backend_server.close()
self._jhttp_server.close()
self._jbackend.close()
self._jhc.stop()
self._jhc = None
Expand Down
29 changes: 6 additions & 23 deletions hail/python/hail/backend/spark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import pyspark.sql

from hail.expr.table_type import ttable
from hail.fs.hadoop_fs import HadoopFS
from hail.ir import BaseIR
from hail.ir.renderer import CSERenderer
from hail.table import Table
from hail.utils import copy_log
from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration
from hailtop.aiotools.router_fs import RouterAsyncFS
from hailtop.aiotools.validators import validate_file
from hailtop.utils import async_to_blocking
Expand Down Expand Up @@ -47,12 +47,9 @@ def __init__(
skip_logging_configuration,
optimizer_iterations,
*,
gcs_requester_pays_project: Optional[str] = None,
gcs_requester_pays_buckets: Optional[str] = None,
gcs_requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None,
copy_log_on_error: bool = False,
):
assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None

try:
local_jar_info = local_jar_information()
except ValueError:
Expand Down Expand Up @@ -120,10 +117,6 @@ def __init__(
append,
skip_logging_configuration,
min_block_size,
tmpdir,
local_tmpdir,
gcs_requester_pays_project,
gcs_requester_pays_buckets,
)
jhc = hail_package.HailContext.getOrCreate(jbackend, branching_factor, optimizer_iterations)
else:
Expand All @@ -137,10 +130,6 @@ def __init__(
append,
skip_logging_configuration,
min_block_size,
tmpdir,
local_tmpdir,
gcs_requester_pays_project,
gcs_requester_pays_buckets,
)
jhc = hail_package.HailContext.apply(jbackend, branching_factor, optimizer_iterations)

Expand All @@ -149,12 +138,12 @@ def __init__(
self.sc = sc
else:
self.sc = pyspark.SparkContext(gateway=self._gateway, jsc=jvm.JavaSparkContext(self._jsc))
self._jspark_session = jbackend.sparkSession()
self._jspark_session = jbackend.sparkSession().apply()
self._spark_session = pyspark.sql.SparkSession(self.sc, self._jspark_session)

super(SparkBackend, self).__init__(jvm, jbackend, jhc)
super().__init__(jvm, jbackend, jhc, local_tmpdir, tmpdir)
self.gcs_requester_pays_configuration = gcs_requester_pays_config

self._fs = None
self._logger = None

if not quiet:
Expand All @@ -167,7 +156,7 @@ def __init__(
self._initialize_flags({})

self._router_async_fs = RouterAsyncFS(
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project}
gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_config}
)

self._tmpdir = tmpdir
Expand All @@ -181,12 +170,6 @@ def stop(self):
self.sc.stop()
self.sc = None

@property
def fs(self):
if self._fs is None:
self._fs = HadoopFS(self._utils_package_object, self._jbackend.fs())
return self._fs

def from_spark(self, df, key):
result_tuple = self._jbackend.pyFromDF(df._jdf, key)
tir_id, type_json = result_tuple._1(), result_tuple._2()
Expand Down
13 changes: 4 additions & 9 deletions hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,10 @@ def init_spark(
optimizer_iterations = get_env_or_default(_optimizer_iterations, 'HAIL_OPTIMIZER_ITERATIONS', 3)

app_name = app_name or 'Hail'
(
gcs_requester_pays_project,
gcs_requester_pays_buckets,
) = convert_gcs_requester_pays_configuration_to_hadoop_conf_style(
get_gcs_requester_pays_configuration(
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
)
gcs_requester_pays_configuration = get_gcs_requester_pays_configuration(
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
)

backend = SparkBackend(
idempotent,
sc,
Expand All @@ -498,8 +494,7 @@ def init_spark(
local_tmpdir,
skip_logging_configuration,
optimizer_iterations,
gcs_requester_pays_project=gcs_requester_pays_project,
gcs_requester_pays_buckets=gcs_requester_pays_buckets,
gcs_requester_pays_config=gcs_requester_pays_configuration,
copy_log_on_error=copy_log_on_error,
)
if not backend.fs.exists(tmpdir):
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import is.hail.io.fs.FS
import is.hail.types.RTable
import is.hail.types.encoded.EType
import is.hail.types.physical.PTuple
import is.hail.utils.ExecutionTimer.Timings
import is.hail.utils.fatal

import scala.reflect.ClassTag
Expand Down Expand Up @@ -54,6 +53,7 @@ trait BackendContext {
}

abstract class Backend extends Closeable {

// From https://github.com/hail-is/hail/issues/14580 :
// IR can get quite big, especially as it can contain an arbitrary
// amount of encoded literals from the user's python session. This
Expand Down Expand Up @@ -119,7 +119,7 @@ abstract class Backend extends Closeable {
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage

def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings)

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]

def backendContext(ctx: ExecuteContext): BackendContext
}
4 changes: 0 additions & 4 deletions hail/src/main/scala/is/hail/backend/BackendRpc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,4 @@ trait HttpLikeBackendRpc[A] extends BackendRpc {
)
}
}

implicit protected def Ask: Routing
implicit protected def Write: Write[A]
implicit protected def Context: Context[A]
}
123 changes: 0 additions & 123 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala

This file was deleted.

Loading