further improve task handling
ravenac95 committed Dec 10, 2024
1 parent 47b918f commit e1e15f9
Showing 6 changed files with 55 additions and 23 deletions.
6 changes: 5 additions & 1 deletion warehouse/metrics_tools/compute/client.py
@@ -50,6 +50,8 @@ def calculate_metrics(
ref: PeerMetricDependencyRef,
locals: t.Dict[str, t.Any],
dependent_tables_map: t.Dict[str, str],
cluster_min_size: int = 6,
cluster_max_size: int = 6,
retries: t.Optional[int] = None,
):
"""Calculate metrics for a given period and write the results to a gcs
@@ -77,7 +79,9 @@ def calculate_metrics(
str: The gcs result path from the metrics calculation service
"""
# Trigger the cluster start
status = self.start_cluster(min_size=1, max_size=1)
status = self.start_cluster(
min_size=cluster_min_size, max_size=cluster_max_size
)
self.logger.info(f"cluster status: {status}")

job_response = self.submit_job(
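
Note: this change adds cluster sizing parameters to calculate_metrics and threads them into start_cluster, replacing the hard-coded min_size=1, max_size=1. A minimal sketch of the pattern, using a stand-in class rather than the real client, where passing the same value for both bounds pins the cluster size:

class FakeMetricsClient:
    def start_cluster(self, min_size: int, max_size: int) -> str:
        # Stand-in for the real cluster start request.
        return f"requested {min_size}-{max_size} workers"

    def calculate_metrics(
        self,
        cluster_min_size: int = 6,
        cluster_max_size: int = 6,
    ) -> None:
        status = self.start_cluster(
            min_size=cluster_min_size, max_size=cluster_max_size
        )
        print(f"cluster status: {status}")

FakeMetricsClient().calculate_metrics(cluster_min_size=10, cluster_max_size=10)
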
35 changes: 24 additions & 11 deletions warehouse/metrics_tools/compute/cluster.py
@@ -5,6 +5,7 @@
import logging
import typing as t
import asyncio
import inspect
from pyee.asyncio import AsyncIOEventEmitter

from dask.distributed import Client, LocalCluster, Future as DaskFuture
@@ -27,8 +28,10 @@ def start_duckdb_cluster(
min_size: int = 6,
max_size: int = 6,
quiet: bool = False,
**kwargs: t.Any,
):
options: t.Dict[str, t.Any] = {"namespace": namespace}
options.update(kwargs)
print("starting duckdb cluster")
if cluster_spec:
options["custom_cluster_spec"] = cluster_spec
@@ -45,21 +48,29 @@ async def start_duckdb_cluster_async(
cluster_spec: t.Optional[dict] = None,
min_size: int = 6,
max_size: int = 6,
**kwargs: t.Any,
):
"""The async version of start_duckdb_cluster which wraps the sync version in
a thread. The "async" version of dask's KubeCluster doesn't work as
expected. So for now we do this."""

# options: t.Dict[str, t.Any] = {"namespace": namespace}
# if cluster_spec:
# options["custom_cluster_spec"] = cluster_spec
# cluster = await KubeCluster(quiet=True, asynchronous=True, **options)
# await cluster.adapt(minimum=min_size, maximum=max_size)
# return cluster
options: t.Dict[str, t.Any] = {"namespace": namespace}
options.update(kwargs)
if cluster_spec:
options["custom_cluster_spec"] = cluster_spec

return await asyncio.to_thread(
start_duckdb_cluster, namespace, cluster_spec, min_size, max_size
)
# loop = asyncio.get_running_loop()
cluster = await KubeCluster(asynchronous=True, **options)
print(f"is cluster awaitable?: {inspect.isawaitable(cluster)}")
adapt_response = cluster.adapt(minimum=min_size, maximum=max_size)
print(f"is adapt_response awaitable?: {inspect.isawaitable(adapt_response)}")
if inspect.isawaitable(adapt_response):
await adapt_response
return cluster

# return await asyncio.to_thread(
# start_duckdb_cluster, namespace, cluster_spec, min_size, max_size
# )


class ClusterProxy(abc.ABC):
@@ -129,7 +140,7 @@ async def status(self) -> ClusterStatus:
)

async def stop(self):
self.cluster.close()
await self.cluster.close()

@property
def dashboard_link(self):
@@ -153,14 +164,16 @@ def __init__(
namespace: str,
cluster_spec: t.Optional[dict] = None,
log_override: t.Optional[logging.Logger] = None,
**kwargs: t.Any,
):
self._namespace = namespace
self.logger = log_override or logger
self._cluster_spec = cluster_spec
self.kwargs = kwargs

async def create_cluster(self, min_size: int, max_size: int):
cluster = await start_duckdb_cluster_async(
self._namespace, self._cluster_spec, min_size, max_size
self._namespace, self._cluster_spec, min_size, max_size, **self.kwargs
)
return KubeClusterProxy(cluster)

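
Note: these hunks replace the thread-wrapped synchronous start with a direct KubeCluster(asynchronous=True, ...) call, forward the factory's new **kwargs to the cluster constructor, await cluster.close() in stop() (closing an asynchronous cluster returns a coroutine), and guard adapt() with inspect.isawaitable since its return type can differ between versions and contexts. A self-contained sketch of the "await only if awaitable" guard, using stand-in callables instead of dask:

import asyncio
import inspect
import typing as t

async def maybe_await(value: t.Any) -> t.Any:
    # Await the value only when it is awaitable; otherwise pass it through.
    # This mirrors the guard applied to cluster.adapt(...) above.
    if inspect.isawaitable(value):
        return await value
    return value

async def main() -> None:
    async def adapt_async() -> str:
        return "adaptive scaling configured"

    print(await maybe_await(adapt_async()))   # awaitable path
    print(await maybe_await("already done"))  # plain value path

asyncio.run(main())
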
4 changes: 4 additions & 0 deletions warehouse/metrics_tools/compute/constants.py
@@ -39,6 +39,10 @@
if not debug_all:
debug_cache = env.ensure_bool("METRICS_DEBUG_CACHE", False)
debug_cluster = env.ensure_bool("METRICS_DEBUG_CLUSTER", False)
debug_cluster_no_shutdown = env.ensure_bool(
"METRICS_DEBUG_CLUSTER_NO_SHUTDOWN", False
)
else:
debug_cache = debug_all
debug_cluster = debug_all
debug_cluster_no_shutdown = debug_all
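
Note: the new flag follows the existing pattern in this module: it is read from the environment unless METRICS_DEBUG_ALL forces every debug flag on. A sketch of the assumed behavior of an ensure_bool-style helper (the real implementation lives elsewhere in this repository and may differ):

import os

def ensure_bool(name: str, default: bool = False) -> bool:
    # Hypothetical helper: read an environment variable and coerce it to bool,
    # falling back to the provided default when it is unset.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}

os.environ["METRICS_DEBUG_CLUSTER_NO_SHUTDOWN"] = "true"
print(ensure_bool("METRICS_DEBUG_CLUSTER_NO_SHUTDOWN", False))  # True
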
17 changes: 14 additions & 3 deletions warehouse/metrics_tools/compute/run_get.py
@@ -93,7 +93,9 @@ def run_get_status(url: str, job_id: str):
print(response.json())


def run_local_test(url: str, start: str, end: str, batch_size: int):
def run_local_test(
url: str, start: str, end: str, batch_size: int, cluster_size: int = 6
):
import sys

logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
@@ -143,18 +145,27 @@ def run_local_test(url: str, start: str, end: str, batch_size: int):
"metrics.events_daily_to_artifact": "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958"
},
batch_size=batch_size,
cluster_max_size=cluster_size,
cluster_min_size=cluster_size,
)


@click.command()
@click.option("--url", default="http://localhost:8000")
@click.option("--batch-size", type=click.INT, default=1)
@click.option("--start", default="2024-01-01")
@click.option("--cluster-size", type=click.INT, default=6)
@click.option("--end")
def main(url: str, batch_size: int, start, end):
def main(url: str, batch_size: int, start: str, end: str, cluster_size: int):
if not end:
end = datetime.now().strftime("%Y-%m-%d")
run_local_test(url, start, end, batch_size)
run_local_test(
url,
start,
end,
batch_size,
cluster_size=cluster_size,
)


if __name__ == "__main__":
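
Note: the new --cluster-size option feeds both cluster_min_size and cluster_max_size in run_local_test, so the local test requests a fixed-size cluster. A self-contained sketch that mirrors the new CLI surface (a stand-in command, not an import of the real module), driven through click's test runner:

import click
from click.testing import CliRunner

@click.command()
@click.option("--url", default="http://localhost:8000")
@click.option("--batch-size", type=click.INT, default=1)
@click.option("--start", default="2024-01-01")
@click.option("--cluster-size", type=click.INT, default=6)
@click.option("--end", default="2024-12-31")
def fake_main(url: str, batch_size: int, start: str, end: str, cluster_size: int):
    # The real command forwards cluster_size to run_local_test(...).
    click.echo(f"{url} {start}..{end} batch={batch_size} cluster={cluster_size}")

result = CliRunner().invoke(fake_main, ["--cluster-size", "10"])
print(result.output)
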
1 change: 1 addition & 0 deletions warehouse/metrics_tools/compute/server.py
@@ -68,6 +68,7 @@ async def initialize_app(app: FastAPI):
constants.cluster_namespace,
cluster_spec=cluster_spec,
log_override=logger,
shutdown_on_close=not constants.debug_cluster_no_shutdown,
)
cluster_manager = ClusterManager.with_metrics_plugin(
constants.gcs_bucket,
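
Note: shutdown_on_close travels through KubeClusterFactory's new **kwargs (see cluster.py above) into the KubeCluster constructor; assuming it matches dask_kubernetes's option of the same name, setting METRICS_DEBUG_CLUSTER_NO_SHUTDOWN leaves the Kubernetes cluster running after the server shuts down, which is useful while debugging. A trivial sketch of the negation used here, with a stand-in value for the constant:

# Stand-in for the value normally read in constants.py.
debug_cluster_no_shutdown = True  # METRICS_DEBUG_CLUSTER_NO_SHUTDOWN=true

shutdown_on_close = not debug_cluster_no_shutdown
print(shutdown_on_close)  # False: the cluster is left running on close
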
15 changes: 7 additions & 8 deletions warehouse/metrics_tools/compute/service.py
@@ -151,8 +151,7 @@ async def _handle_query_job_submit_request(
retries=input.retries,
)

self.logger.info(f"job[{job_id}]: Submitted task has: {task._id}")
# task = asyncio.create_task(submit())
self.logger.info(f"job[{job_id}]: Submitted task {task_id}")
tasks.append(task)

total = len(tasks)
@@ -164,14 +163,14 @@
# this.

for finished in asyncio.as_completed(tasks):
completed += 1
self.logger.info(f"job[{job_id}] progress: {completed}/{total}")
await self._notify_job_updated(job_id, completed, total)
self.logger.info(
f"job[{job_id}] finished notifying update: {completed}/{total}"
)
try:
task_id = await finished
completed += 1
self.logger.info(f"job[{job_id}] progress: {completed}/{total}")
await self._notify_job_updated(job_id, completed, total)
self.logger.info(
f"job[{job_id}] finished notifying update: {completed}/{total}"
)
except CancelledError as e:
failures += 1
self.logger.error(f"job[{job_id}] task cancelled {e.args}")
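
Note: this is the heart of the task handling fix: the progress counter and the job-update notification move inside the try block, so a task that is cancelled (for example when a worker is lost) no longer counts as completed, and the submit log no longer reaches into the private task._id. A minimal sketch of the pattern with stand-in work and no external service:

import asyncio

async def work(task_id: str, fail: bool = False) -> str:
    # Stand-in for a per-batch query task.
    if fail:
        raise asyncio.CancelledError("worker lost")
    return task_id

async def run_job(job_id: str) -> None:
    tasks = [
        asyncio.create_task(work("t1")),
        asyncio.create_task(work("t2", fail=True)),
        asyncio.create_task(work("t3")),
    ]
    total, completed, failures = len(tasks), 0, 0
    for finished in asyncio.as_completed(tasks):
        try:
            task_id = await finished
            # Count and notify only after the task actually succeeded.
            completed += 1
            print(f"job[{job_id}] progress: {completed}/{total} ({task_id})")
        except asyncio.CancelledError as e:
            failures += 1
            print(f"job[{job_id}] task cancelled {e.args}")
    print(f"job[{job_id}] done: {completed} ok, {failures} failed")

asyncio.run(run_job("example"))
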
