diff --git a/ops/tf-modules/warehouse-cluster/main.tf b/ops/tf-modules/warehouse-cluster/main.tf index 344e14aee..ac14045e4 100644 --- a/ops/tf-modules/warehouse-cluster/main.tf +++ b/ops/tf-modules/warehouse-cluster/main.tf @@ -161,9 +161,9 @@ locals { machine_type = "n1-highmem-16" node_locations = join(",", var.cluster_zones) min_count = 0 - max_count = 10 + max_count = 20 local_ssd_count = 0 - local_ssd_ephemeral_storage_count = 1 + local_ssd_ephemeral_storage_count = 2 spot = false disk_size_gb = 100 disk_type = "pd-standard" @@ -357,7 +357,7 @@ module "vpc" { module "gke" { source = "terraform-google-modules/kubernetes-engine/google" - version = "~> 33.0" + version = "~> 35.0.0" project_id = var.project_id name = var.cluster_name region = var.cluster_region diff --git a/poetry.lock b/poetry.lock index 621468df4..e8f7cbb97 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "agate" @@ -3214,8 +3214,6 @@ optional = false python-versions = "*" files = [ {file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"}, - {file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"}, - {file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"}, ] [package.dependencies] @@ -5168,6 +5166,26 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.7.0" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_settings-2.7.0-py3-none-any.whl", hash = "sha256:e00c05d5fa6cbbb227c84bd7487c5c1065084119b750df7c8c1a554aed236eb5"}, + {file = "pydantic_settings-2.7.0.tar.gz", hash = "sha256:ac4bfd4a36831a48dbf8b2d9325425b549a0a6f18cea118436d728eb4f1c4d66"}, +] + +[package.dependencies] +pydantic = ">=2.7.0" +python-dotenv = ">=0.21.0" + +[package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pyee" version = "12.1.1" @@ -5984,7 +6002,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, @@ -5993,7 +6010,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, @@ -6002,7 +6018,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, @@ -6011,7 +6026,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, @@ -6020,7 +6034,6 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -7671,4 +7684,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.12,<3.13" -content-hash = "9fc640d518aa707b0479720a8c2ac62fba8e170191b36471bd3c426ff2e89ebe" +content-hash = "d34c20fd2b3ef341b2a4fb6d418a4ac512e03eb76ea55210f03322c00477ad83" diff --git a/pyproject.toml b/pyproject.toml index d658ecb1c..189b0d2b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,8 @@ aiotrino = "^0.2.3" pytest-asyncio = "^0.24.0" isort = "^5.13.2" uvicorn = { extras = ["standard"], version = "^0.32.1" } +websockets = "^14.1" +pydantic-settings = "^2.7.0" [tool.poetry.scripts] diff --git a/warehouse/metrics_tools/compute/app.py b/warehouse/metrics_tools/compute/app.py new file mode 100644 index 000000000..bd472b370 --- /dev/null +++ b/warehouse/metrics_tools/compute/app.py @@ -0,0 +1,303 @@ +"""The main definition of the metrics calculation service FastAPI application. + +Please note, this only defines an app factory that returns an app. If you're +looking for the main entrypoint go to server.py +""" + +import asyncio +import logging +import shutil +import tempfile +import typing as t +import uuid +from contextlib import asynccontextmanager + +import aiotrino +from dotenv import load_dotenv +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi.datastructures import State +from metrics_tools.compute.result import ( + DummyImportAdapter, + FakeLocalImportAdapter, + TrinoImportAdapter, +) + +from .cache import setup_fake_cache_export_manager, setup_trino_cache_export_manager +from .cluster import ( + ClusterManager, + KubeClusterFactory, + LocalClusterFactory, + make_new_cluster_with_defaults, +) +from .service import MetricsCalculationService +from .types import ( + AppConfig, + AppLifespanFactory, + ClusterStartRequest, + EmptyResponse, + ExportedTableLoadRequest, + JobStatusResponse, + JobSubmitRequest, +) + +load_dotenv() +logger = logging.getLogger(__name__) + + +def default_lifecycle(config: AppConfig): + @asynccontextmanager + async def initialize_app(app: FastAPI): + logger.info("Metrics calculation service is starting up") + if config.debug_all: + logger.warning("Debugging all services") + + cache_export_manager = None + temp_dir = None + + if config.debug_with_duckdb: + temp_dir = tempfile.mkdtemp() + logger.debug(f"Created temp dir {temp_dir}") + if not config.debug_cache: + trino_connection = aiotrino.dbapi.connect( + host=config.trino_host, + port=config.trino_port, + user=config.trino_user, + catalog=config.trino_catalog, + ) + cache_export_manager = await setup_trino_cache_export_manager( + trino_connection, + config.gcs_bucket, + config.hive_catalog, + config.hive_schema, + ) + import_adapter = TrinoImportAdapter( + db=trino_connection, + gcs_bucket=config.gcs_bucket, + hive_catalog=config.hive_catalog, + hive_schema=config.hive_schema, + ) + else: + if config.debug_with_duckdb: + assert temp_dir is not None + logger.warning("Loading fake cache export manager with duckdb") + import_adapter = FakeLocalImportAdapter(temp_dir) + else: + logger.warning("Loading dummy cache export manager (writes nothing)") + import_adapter = DummyImportAdapter() + cache_export_manager = await setup_fake_cache_export_manager() + + cluster_manager = None + if not config.debug_cluster: + cluster_spec = make_new_cluster_with_defaults(config) + cluster_factory = KubeClusterFactory( + config.cluster_namespace, + cluster_spec=cluster_spec, + shutdown_on_close=not config.debug_cluster_no_shutdown, + ) + cluster_manager = ClusterManager.with_metrics_plugin( + config.gcs_bucket, + config.gcs_key_id, + config.gcs_secret, + config.worker_duckdb_path, + cluster_factory, + ) + else: + logger.warning("Loading fake cluster manager") + cluster_factory = LocalClusterFactory() + cluster_manager = ClusterManager.with_dummy_metrics_plugin( + cluster_factory, + ) + + mcs = MetricsCalculationService.setup( + id=str(uuid.uuid4()), + gcs_bucket=config.gcs_bucket, + result_path_prefix=config.results_path_prefix, + cluster_manager=cluster_manager, + cache_manager=cache_export_manager, + import_adapter=import_adapter, + ) + try: + yield { + "mcs": mcs, + } + finally: + logger.info("Waiting for metrics calculation service to close") + await mcs.close() + if temp_dir: + logger.info("Removing temp dir") + shutil.rmtree(temp_dir, ignore_errors=True) + + return initialize_app + + +def app_factory[T](lifespan_factory: AppLifespanFactory[T], config: T): + logger.debug(f"loading application with config: {config}") + app = setup_app(lifespan=lifespan_factory(config)) + return app + + +class ApplicationStateStorage(t.Protocol): + @property + def state(self) -> State: ... + + +def get_mcs(storage: ApplicationStateStorage) -> MetricsCalculationService: + mcs = storage.state.mcs + assert mcs is not None + return t.cast(MetricsCalculationService, mcs) + + +def setup_app(lifespan: t.Callable[[FastAPI], t.Any]): + # Dependency to get the cluster manager + + app = FastAPI(lifespan=lifespan) + + @app.get("/status") + async def get_status(): + """Liveness endpoint""" + return {"status": "Service is running"} + + @app.post("/cluster/start") + async def start_cluster( + request: Request, + start_request: ClusterStartRequest, + ): + """Start a Dask cluster in an idempotent way. + + If the cluster is already running, it will not be restarted. + """ + state = get_mcs(request) + manager = state.cluster_manager + return await manager.start_cluster( + start_request.min_size, start_request.max_size + ) + + @app.post("/cluster/stop") + async def stop_cluster(request: Request): + """Stop the Dask cluster""" + state = get_mcs(request) + manager = state.cluster_manager + return await manager.stop_cluster() + + @app.get("/cluster/status") + async def get_cluster_status(request: Request): + """Get the current Dask cluster status""" + state = get_mcs(request) + manager = state.cluster_manager + return await manager.get_cluster_status() + + @app.post("/job/submit") + async def submit_job( + request: Request, + input: JobSubmitRequest, + ): + """Submits a Dask job for calculation""" + service = get_mcs(request) + return await service.submit_job(input) + + @app.get("/job/status/{job_id}") + async def get_job_status( + request: Request, + job_id: str, + ): + """Get the status of a job""" + include_stats = ( + request.query_params.get("include_stats", "false").lower() == "true" + ) + service = get_mcs(request) + return await service.get_job_status(job_id, include_stats=include_stats) + + @app.websocket("/job/status/{job_id}/ws") + async def job_status_ws( + websocket: WebSocket, + job_id: str, + ): + """Websocket endpoint for job status updates""" + service = get_mcs(websocket) + update_queue: asyncio.Queue[JobStatusResponse] = asyncio.Queue() + + await websocket.accept() + + async def listener(job_status_response: JobStatusResponse): + logger.debug(f"Received job status update: {job_status_response}") + await update_queue.put(job_status_response) + + stop_listening = service.listen_for_job_updates(job_id, listener) + + try: + while True: + update = await update_queue.get() + await websocket.send_text(update.model_dump_json()) + except WebSocketDisconnect: + stop_listening() + await websocket.close() + + @app.post("/cache/manual") + async def add_existing_exported_table_references( + request: Request, input: ExportedTableLoadRequest + ): + """Add a table export to the cache""" + service = get_mcs(request) + await service.add_existing_exported_table_references(input.map) + return EmptyResponse() + + # @app.websocket("/ws") + # async def websocket_endpoint(websocket: WebSocket): + # await websocket.accept() + + # service = get_mcs(websocket) + + # response_queue: asyncio.Queue[MCSResponseTypes] = asyncio.Queue() + # request_tasks: t.List[asyncio.Task[None]] = [] + + # async def receive_request(): + # return await websocket.receive_text() + + # async def receive_response(): + # return await response_queue.get() + + # async def send_response(response: MCSResponseTypes): + # await websocket.send_text( + # ServiceResponse(type=response.type, response=response).model_dump_json() + # ) + + # async def request_router(mcs_request_str: str): + # try: + # mcs_request = ServiceRequest.model_validate_json(mcs_request_str) + # except pydantic.ValidationError as e: + # await response_queue.put(ErrorResponse(message=str(e))) + # return + # print(mcs_request) + + # try: + # mcs_request_task = asyncio.create_task(receive_request()) + # mcs_response_task = asyncio.create_task(receive_response()) + # pending: t.Set[asyncio.Task[str] | asyncio.Task[MCSResponseTypes]] = { + # mcs_request_task, + # mcs_response_task, + # } + # while True: + # done, pending = await asyncio.wait( + # pending, + # return_when=asyncio.FIRST_COMPLETED, + # ) + # for task in done: + # if task == mcs_request_task: + # mcs_request_str = t.cast(str, await mcs_request_task) + + # request_tasks.append( + # asyncio.create_task(request_router(mcs_request_str)) + # ) + # mcs_request_task = asyncio.create_task(receive_request()) + # pending.add(mcs_request_task) + # else: + # response = t.cast(MCSResponseTypes, await mcs_response_task) + # await send_response(response) + + # mcs_response_task = asyncio.create_task(receive_response()) + # except Exception as e: + # await send_response(ErrorResponse(message=str(e))) + # finally: + # await websocket.close() + + return app diff --git a/warehouse/metrics_tools/compute/client.py b/warehouse/metrics_tools/compute/client.py index a4472db30..5a9c67998 100644 --- a/warehouse/metrics_tools/compute/client.py +++ b/warehouse/metrics_tools/compute/client.py @@ -3,9 +3,11 @@ import logging import time import typing as t +from contextlib import contextmanager from datetime import datetime +from urllib.parse import urljoin -import requests +import httpx from metrics_tools.compute.types import ( ClusterStartRequest, ClusterStatus, @@ -13,14 +15,16 @@ ExportedTableLoadRequest, ExportReference, InspectCacheResponse, + JobStatusResponse, + JobSubmitRequest, + JobSubmitResponse, QueryJobStatus, - QueryJobStatusResponse, - QueryJobSubmitRequest, - QueryJobSubmitResponse, ) from metrics_tools.definition import PeerMetricDependencyRef from pydantic import BaseModel from pydantic_core import to_jsonable_python +from websockets.sync.client import connect +from websockets.sync.connection import Connection logger = logging.getLogger(__name__) @@ -29,14 +33,86 @@ class ResponseObject[T](t.Protocol): def model_validate(self, obj: dict) -> T: ... +class BaseWebsocketConnector: + def receive(self) -> str: + raise NotImplementedError() + + def send(self, data: str): + raise NotImplementedError() + + +class WebsocketConnectFactory(t.Protocol): + def __call__( + self, *, base_url: str, path: str + ) -> t.ContextManager[BaseWebsocketConnector]: ... + + +class WebsocketsConnector(BaseWebsocketConnector): + def __init__(self, connection: Connection): + self.connection = connection + + def receive(self): + data = self.connection.recv() + if isinstance(data, str): + return data + else: + return data.decode() + + def send(self, data: str): + return self.connection.send(data) + + +class ClientRetriesExceeded(Exception): + pass + + +@contextmanager +def default_ws(*, base_url: str, path: str): + url = urljoin(base_url, path) + with connect(url) as ws: + yield WebsocketsConnector(ws) + + class Client: """A metrics calculation service client""" - url: str + url: httpx.Client logger: logging.Logger - def __init__(self, url: str, log_override: t.Optional[logging.Logger] = None): - self.url = url + @classmethod + def from_url( + cls, + url: str, + retries: int = 5, + log_override: t.Optional[logging.Logger] = None, + ): + """Create a client from a base url + + Args: + url (str): The base url + retries (int): The number of retries the client should attempt when connecting + log_override (t.Optional[logging.Logger]): An optional logger override + + Returns: + Client: The client instance + """ + return cls( + httpx.Client(base_url=url), + retries, + default_ws, + log_override=log_override, + ) + + def __init__( + self, + client: httpx.Client, + retries: int, + websocket_connect_factory: WebsocketConnectFactory, + log_override: t.Optional[logging.Logger] = None, + ): + self.client = client + self.retries = retries + self.websocket_connect_factory = websocket_connect_factory self.logger = log_override or logger def calculate_metrics( @@ -50,9 +126,10 @@ def calculate_metrics( ref: PeerMetricDependencyRef, locals: t.Dict[str, t.Any], dependent_tables_map: t.Dict[str, str], + progress_handler: t.Optional[t.Callable[[JobStatusResponse], None]] = None, cluster_min_size: int = 6, cluster_max_size: int = 6, - retries: t.Optional[int] = None, + job_retries: int = 3, ): """Calculate metrics for a given period and write the results to a gcs folder. This method is a high level method that triggers all of the @@ -73,7 +150,7 @@ def calculate_metrics( ref (PeerMetricDependencyRef): The dependency reference locals (t.Dict[str, t.Any]): The local variables to use dependent_tables_map (t.Dict[str, str]): The dependent tables map - retries (t.Optional[int], optional): The number of retries. Defaults to None. + job_retries (int): The number of retries for a given job in the worker queue. Defaults to 3. Returns: ExportReference: The export reference for the resulting calculation @@ -94,33 +171,28 @@ def calculate_metrics( ref, locals, dependent_tables_map, - retries, + job_retries, ) job_id = job_response.job_id export_reference = job_response.export_reference + if not progress_handler: + + def _handler(response: JobStatusResponse): + self.logger.info( + f"job[{job_id}] status: {response.status}, progress: {response.progress}" + ) + + progress_handler = _handler + # Wait for the job to be completed - status_response = self.get_job_status(job_id) - while status_response.status in [ - QueryJobStatus.PENDING, - QueryJobStatus.RUNNING, - ]: - self.logger.info(f"job[{job_id}] is still pending") - status_response = self.get_job_status(job_id) - time.sleep(5) - self.logger.info(f"job[{job_id}] status is {status_response}") - - if status_response.status == QueryJobStatus.FAILED: - self.logger.error( - f"job[{job_id}] failed with status {status_response.status}" - ) - raise Exception( - f"job[{job_id}] failed with status {status_response.status}" - ) + final_status = self.wait_for_job(job_id, progress_handler) - self.logger.info( - f"job[{job_id}] completed with status {status_response.status}" - ) + if final_status.status == QueryJobStatus.FAILED: + self.logger.error(f"job[{job_id}] failed with status {final_status.status}") + raise Exception(f"job[{job_id}] failed with status {final_status.status}") + + self.logger.info(f"job[{job_id}] completed with status {final_status.status}") return export_reference @@ -132,6 +204,24 @@ def start_cluster(self, min_size: int, max_size: int): ) return response + def wait_for_job( + self, job_id: str, progress_handler: t.Callable[[JobStatusResponse], None] + ): + """Connect to the websocket and listen for job updates""" + url = self.client.base_url + with self.websocket_connect_factory( + base_url=f"{url.copy_with(scheme="ws")}", path=f"/job/status/{job_id}/ws" + ) as ws: + while True: + raw_response = ws.receive() + response = JobStatusResponse.model_validate_json(raw_response) + if response.status not in [ + QueryJobStatus.PENDING, + QueryJobStatus.RUNNING, + ]: + return response + progress_handler(response) + def submit_job( self, query_str: str, @@ -143,7 +233,7 @@ def submit_job( ref: PeerMetricDependencyRef, locals: t.Dict[str, t.Any], dependent_tables_map: t.Dict[str, str], - retries: t.Optional[int] = None, + job_retries: t.Optional[int] = None, ): """Submit a job to the metrics calculation service @@ -157,12 +247,12 @@ def submit_job( ref (PeerMetricDependencyRef): The dependency reference locals (t.Dict[str, t.Any]): The local variables to use dependent_tables_map (t.Dict[str, str]): The dependent tables map - retries (t.Optional[int], optional): The number of retries. Defaults to None. + job_retries (int): The number of retries for a given job in the worker queue. Defaults to 3. Returns: QueryJobSubmitResponse: The job response from the metrics calculation service """ - request = QueryJobSubmitRequest( + request = JobSubmitRequest( query_str=query_str, start=start, end=end, @@ -172,17 +262,17 @@ def submit_job( ref=ref, locals=locals, dependent_tables_map=dependent_tables_map, - retries=retries, + retries=job_retries, execution_time=datetime.now(), ) job_response = self.service_post_with_input( - QueryJobSubmitResponse, "/job/submit", request + JobSubmitResponse, "/job/submit", request ) return job_response def get_job_status(self, job_id: str): """Get the status of a job""" - return self.service_get(QueryJobStatusResponse, f"/job/status/{job_id}") + return self.service_get(JobStatusResponse, f"/job/status/{job_id}") def run_cache_manual_load(self, map: t.Dict[str, ExportReference]): """Load a cache with the provided map. This is useful for testing @@ -196,12 +286,46 @@ def inspect_cache(self): def service_request[ T - ](self, method: str, factory: ResponseObject[T], path: str, **kwargs) -> T: - response = requests.request( - method, - f"{self.url}{path}", - **kwargs, - ) + ]( + self, + method: str, + factory: ResponseObject[T], + path: str, + client_retries: t.Optional[int] = None, + **kwargs, + ) -> T: + def make_request(): + return self.client.request( + method, + path, + **kwargs, + ) + + def retry_request(retries: int): + for i in range(retries): + try: + response = make_request() + response.raise_for_status() + return response + except httpx.NetworkError as e: + self.logger.error(f"Failed request with network error, {e}") + except httpx.TimeoutException as e: + self.logger.error(f"Failed request with timeout, {e}") + except httpx.HTTPStatusError as e: + self.logger.error( + f"Failed request with response code: {e.response.status_code}" + ) + if e.response.status_code >= 500: + self.logger.debug("server error, retrying") + elif e.response.status_code == 408: + self.logger.debug("request timeout, retrying") + else: + raise e + time.sleep(2**i) # Exponential backoff + raise ClientRetriesExceeded("Request failed after too many retries") + + client_retries = client_retries or self.retries + response = retry_request(client_retries) return factory.model_validate(response.json()) def service_post_with_input[ diff --git a/warehouse/metrics_tools/compute/cluster.py b/warehouse/metrics_tools/compute/cluster.py index cf67070fb..b8a502c6e 100644 --- a/warehouse/metrics_tools/compute/cluster.py +++ b/warehouse/metrics_tools/compute/cluster.py @@ -11,7 +11,7 @@ from dask.distributed import Future as DaskFuture from dask.distributed import LocalCluster from dask_kubernetes.operator import KubeCluster, make_cluster_spec -from metrics_tools.compute.types import ClusterStatus +from metrics_tools.compute.types import ClusterConfig, ClusterStatus from pyee.asyncio import AsyncIOEventEmitter from .worker import ( @@ -232,6 +232,9 @@ async def start_cluster(self, min_size: int, max_size: int) -> ClusterStatus: async with self._lock: if self._cluster is not None: self.logger.info("cluster already running") + + # Trigger scaling if necessary + return ClusterStatus( status="Cluster already running", is_ready=True, @@ -430,20 +433,19 @@ def make_new_cluster( return spec -def make_new_cluster_with_defaults(): +def make_new_cluster_with_defaults(config: ClusterConfig): # Import here to avoid dependency on constants for all dependents on the # cluster module - from . import constants return make_new_cluster( - image=f"{constants.cluster_image_repo}:{constants.cluster_image_tag}", - cluster_id=constants.cluster_name, - service_account_name=constants.cluster_service_account, - threads=constants.worker_threads, - scheduler_memory_limit=constants.scheduler_memory_limit, - scheduler_memory_request=constants.scheduler_memory_request, - scheduler_pool_type=constants.scheduler_pool_type, - worker_memory_limit=constants.worker_memory_limit, - worker_memory_request=constants.worker_memory_request, - worker_pool_type=constants.worker_pool_type, + image=f"{config.cluster_image_repo}:{config.cluster_image_tag}", + cluster_id=config.cluster_name, + service_account_name=config.cluster_service_account, + threads=config.worker_threads, + scheduler_memory_limit=config.scheduler_memory_limit, + scheduler_memory_request=config.scheduler_memory_request, + scheduler_pool_type=config.scheduler_pool_type, + worker_memory_limit=config.worker_memory_limit, + worker_memory_request=config.worker_memory_request, + worker_pool_type=config.worker_pool_type, ) diff --git a/warehouse/metrics_tools/compute/constants.py b/warehouse/metrics_tools/compute/constants.py deleted file mode 100644 index 4a4291796..000000000 --- a/warehouse/metrics_tools/compute/constants.py +++ /dev/null @@ -1,52 +0,0 @@ -from dotenv import load_dotenv -from metrics_tools.utils import env - -load_dotenv() - -cluster_namespace = env.required_str("METRICS_CLUSTER_NAMESPACE") -cluster_service_account = env.required_str("METRICS_CLUSTER_SERVICE_ACCOUNT") -cluster_name = env.required_str("METRICS_CLUSTER_NAME") -cluster_image_repo = env.required_str( - "METRICS_CLUSTER_IMAGE_REPO", "ghcr.io/opensource-observer/oso" -) -cluster_image_tag = env.required_str("METRICS_CLUSTER_IMAGE_TAG") -scheduler_memory_limit = env.required_str("METRICS_SCHEDULER_MEMORY_LIMIT", "90000Mi") -scheduler_memory_request = env.required_str( - "METRICS_SCHEDULER_MEMORY_REQUEST", "85000Mi" -) -scheduler_pool_type = env.required_str( - "METRICS_SCHEDULER_POOL_TYPE", "sqlmesh-scheduler" -) -worker_threads = env.required_int("METRICS_WORKER_THREADS", 16) -worker_memory_limit = env.required_str("METRICS_WORKER_MEMORY_LIMIT", "90000Mi") -worker_memory_request = env.required_str("METRICS_WORKER_MEMORY_REQUEST", "85000Mi") -worker_pool_type = env.required_str("METRICS_WORKER_POOL_TYPE", "sqlmesh-worker") - -gcs_bucket = env.required_str("METRICS_GCS_BUCKET") -gcs_key_id = env.required_str("METRICS_GCS_KEY_ID") -gcs_secret = env.required_str("METRICS_GCS_SECRET") - -results_path_prefix = env.required_str("METRICS_GCS_RESULTS_PATH_PREFIX", "mcs-results") - -worker_duckdb_path = env.required_str("METRICS_WORKER_DUCKDB_PATH") - -trino_host = env.required_str("METRICS_TRINO_HOST") -trino_port = env.required_str("METRICS_TRINO_PORT") -trino_user = env.required_str("METRICS_TRINO_USER") -trino_catalog = env.required_str("METRICS_TRINO_CATALOG") - -hive_catalog = env.required_str("METRICS_HIVE_CATALOG", "source") -hive_schema = env.required_str("METRICS_HIVE_SCHEMA", "export") - -debug_all = env.ensure_bool("METRICS_DEBUG_ALL", False) -debug_with_duckdb = env.ensure_bool("METRICS_DEBUG_WITH_DUCKDB", False) -if not debug_all: - debug_cache = env.ensure_bool("METRICS_DEBUG_CACHE", False) - debug_cluster = env.ensure_bool("METRICS_DEBUG_CLUSTER", False) - debug_cluster_no_shutdown = env.ensure_bool( - "METRICS_DEBUG_CLUSTER_NO_SHUTDOWN", False - ) -else: - debug_cache = debug_all - debug_cluster = debug_all - debug_cluster_no_shutdown = debug_all diff --git a/warehouse/metrics_tools/compute/debug.py b/warehouse/metrics_tools/compute/debug.py index ee377aee1..b344b0451 100644 --- a/warehouse/metrics_tools/compute/debug.py +++ b/warehouse/metrics_tools/compute/debug.py @@ -6,24 +6,17 @@ from metrics_tools.compute.cluster import ( KubeClusterFactory, make_new_cluster_with_defaults, - start_duckdb_cluster, ) - -from . import constants +from metrics_tools.compute.types import AppConfig logger = logging.getLogger(__name__) -def test_setup_cluster(): - cluster_spec = make_new_cluster_with_defaults() - return start_duckdb_cluster(constants.cluster_namespace, cluster_spec) - - -def async_test_setup_cluster(): - cluster_spec = make_new_cluster_with_defaults() +def async_test_setup_cluster(config: AppConfig): + cluster_spec = make_new_cluster_with_defaults(config=config) cluster_factory = KubeClusterFactory( - constants.cluster_namespace, + config.cluster_namespace, cluster_spec=cluster_spec, log_override=logger, ) diff --git a/warehouse/metrics_tools/compute/manual_testing_utils.py b/warehouse/metrics_tools/compute/manual_testing_utils.py index 62232736c..3724d8f78 100644 --- a/warehouse/metrics_tools/compute/manual_testing_utils.py +++ b/warehouse/metrics_tools/compute/manual_testing_utils.py @@ -63,7 +63,7 @@ def run_local_test( logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) - client = Client(url, log_override=logger) + client = Client.from_url(url, log_override=logger) client.run_cache_manual_load( { diff --git a/warehouse/metrics_tools/compute/server.py b/warehouse/metrics_tools/compute/server.py index 3ff1d9468..d2b170cb0 100644 --- a/warehouse/metrics_tools/compute/server.py +++ b/warehouse/metrics_tools/compute/server.py @@ -1,194 +1,10 @@ -import logging -import os -import tempfile -import typing as t -import uuid -from contextlib import asynccontextmanager +"""This is the main entrypoint for uvicorn or fastapi to load the mcs server.""" -import aiotrino -from dotenv import load_dotenv -from fastapi import FastAPI, Request -from metrics_tools.compute.result import ( - DummyImportAdapter, - FakeLocalImportAdapter, - TrinoImportAdapter, -) +from .app import app_factory, default_lifecycle +from .types import AppConfig -from . import constants -from .cache import setup_fake_cache_export_manager, setup_trino_cache_export_manager -from .cluster import ( - ClusterManager, - KubeClusterFactory, - LocalClusterFactory, - make_new_cluster_with_defaults, -) -from .service import MetricsCalculationService -from .types import ( - ClusterStartRequest, - EmptyResponse, - ExportedTableLoadRequest, - QueryJobSubmitRequest, +app = app_factory( + default_lifecycle, + # App config won't resolve types correctly due to pydantic's BaseSettings + AppConfig(), # type: ignore ) - -load_dotenv() -logger = logging.getLogger(__name__) - - -@asynccontextmanager -async def initialize_app(app: FastAPI): - logger.info("Metrics calculation service is starting up") - if constants.debug_all: - logger.warning("Debugging all services") - - cache_export_manager = None - temp_dir = None - - if constants.debug_with_duckdb: - temp_dir = tempfile.mkdtemp() - logger.debug(f"Created temp dir {temp_dir}") - if not constants.debug_cache: - trino_connection = aiotrino.dbapi.connect( - host=constants.trino_host, - port=constants.trino_port, - user=constants.trino_user, - catalog=constants.trino_catalog, - ) - cache_export_manager = await setup_trino_cache_export_manager( - trino_connection, - constants.gcs_bucket, - constants.hive_catalog, - constants.hive_schema, - ) - import_adapter = TrinoImportAdapter( - db=trino_connection, - gcs_bucket=constants.gcs_bucket, - hive_catalog=constants.hive_catalog, - hive_schema=constants.hive_schema, - ) - else: - if constants.debug_with_duckdb: - assert temp_dir is not None - logger.warning("Loading fake cache export manager with duckdb") - import_adapter = FakeLocalImportAdapter(temp_dir) - else: - logger.warning("Loading dummy cache export manager (writes nothing)") - import_adapter = DummyImportAdapter() - cache_export_manager = await setup_fake_cache_export_manager() - - cluster_manager = None - if not constants.debug_cluster: - cluster_spec = make_new_cluster_with_defaults() - cluster_factory = KubeClusterFactory( - constants.cluster_namespace, - cluster_spec=cluster_spec, - shutdown_on_close=not constants.debug_cluster_no_shutdown, - ) - cluster_manager = ClusterManager.with_metrics_plugin( - constants.gcs_bucket, - constants.gcs_key_id, - constants.gcs_secret, - constants.worker_duckdb_path, - cluster_factory, - ) - else: - logger.warning("Loading fake cluster manager") - cluster_factory = LocalClusterFactory() - cluster_manager = ClusterManager.with_dummy_metrics_plugin( - cluster_factory, - ) - - mcs = MetricsCalculationService.setup( - id=str(uuid.uuid4()), - gcs_bucket=constants.gcs_bucket, - result_path_prefix=constants.results_path_prefix, - cluster_manager=cluster_manager, - cache_manager=cache_export_manager, - import_adapter=import_adapter, - ) - try: - yield { - "mcs": mcs, - } - finally: - logger.info("Waiting for metrics calculation service to close") - await mcs.close() - if temp_dir: - logger.info("Removing temp dir") - os.rmdir(temp_dir) - - -# Dependency to get the cluster manager -def get_mcs(request: Request) -> MetricsCalculationService: - mcs = request.state.mcs - assert mcs is not None - return t.cast(MetricsCalculationService, mcs) - - -app = FastAPI(lifespan=initialize_app) - - -@app.get("/status") -async def get_status(): - """Liveness endpoint""" - return {"status": "Service is running"} - - -@app.post("/cluster/start") -async def start_cluster( - request: Request, - start_request: ClusterStartRequest, -): - """Start a Dask cluster in an idempotent way. - - If the cluster is already running, it will not be restarted. - """ - state = get_mcs(request) - manager = state.cluster_manager - return await manager.start_cluster(start_request.min_size, start_request.max_size) - - -@app.post("/cluster/stop") -async def stop_cluster(request: Request): - """Stop the Dask cluster""" - state = get_mcs(request) - manager = state.cluster_manager - return await manager.stop_cluster() - - -@app.get("/cluster/status") -async def get_cluster_status(request: Request): - """Get the current Dask cluster status""" - state = get_mcs(request) - manager = state.cluster_manager - return await manager.get_cluster_status() - - -@app.post("/job/submit") -async def submit_job( - request: Request, - input: QueryJobSubmitRequest, -): - """Submits a Dask job for calculation""" - service = get_mcs(request) - return await service.submit_job(input) - - -@app.get("/job/status/{job_id}") -async def get_job_status( - request: Request, - job_id: str, -): - """Get the status of a job""" - include_stats = request.query_params.get("include_stats", "false").lower() == "true" - service = get_mcs(request) - return await service.get_job_status(job_id, include_stats=include_stats) - - -@app.post("/cache/manual") -async def add_existing_exported_table_references( - request: Request, input: ExportedTableLoadRequest -): - """Add a table export to the cache""" - service = get_mcs(request) - await service.add_existing_exported_table_references(input.map) - return EmptyResponse() diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py index b37e2f9a8..05d6a1862 100644 --- a/warehouse/metrics_tools/compute/service.py +++ b/warehouse/metrics_tools/compute/service.py @@ -22,12 +22,12 @@ ColumnsDefinition, ExportReference, ExportType, + JobStatusResponse, + JobSubmitRequest, + JobSubmitResponse, QueryJobProgress, QueryJobState, QueryJobStatus, - QueryJobStatusResponse, - QueryJobSubmitRequest, - QueryJobSubmitResponse, QueryJobUpdate, ) @@ -109,7 +109,7 @@ async def handle_query_job_submit_request( self, job_id: str, result_path_base: str, - input: QueryJobSubmitRequest, + input: JobSubmitRequest, calculation_export: ExportReference, final_export: ExportReference, ): @@ -125,7 +125,7 @@ async def _handle_query_job_submit_request( self, job_id: str, result_path_base: str, - input: QueryJobSubmitRequest, + input: JobSubmitRequest, calculation_export: ExportReference, final_export: ExportReference, ): @@ -173,7 +173,7 @@ async def _batch_query_to_scheduler( self, job_id: str, result_path_base: str, - input: QueryJobSubmitRequest, + input: JobSubmitRequest, exported_dependent_tables_map: t.Dict[str, ExportReference], ): """Given a query job: break down into batches and submit to the scheduler""" @@ -253,7 +253,7 @@ async def start_cluster(self, start_request: ClusterStartRequest) -> ClusterStat async def get_cluster_status(self): return self.cluster_manager.get_cluster_status() - async def submit_job(self, input: QueryJobSubmitRequest): + async def submit_job(self, input: JobSubmitRequest): """Submit a job to the cluster to compute the metrics""" self.logger.debug("submitting job") job_id = f"export_{str(uuid.uuid4().hex)}" @@ -300,7 +300,7 @@ async def submit_job(self, input: QueryJobSubmitRequest): async with self.job_state_lock: self.job_tasks[job_id] = task - return QueryJobSubmitResponse( + return JobSubmitResponse( job_id=job_id, export_reference=final_expected_reference, ) @@ -371,6 +371,13 @@ async def _set_job_state( or update.status == QueryJobStatus.FAILED ): del self.job_tasks[job_id] + updated_state = copy.deepcopy(self.job_state[job_id]) + + self.logger.info("emitting job update events") + # Some things listen to all job updates + self.emitter.emit("job_update", job_id, updated_state) + # Some things listen to specific job updates + self.emitter.emit(f"job_update:{job_id}", updated_state) async def _get_job_state(self, job_id: str): """Get the current state of a job as a deep copy (to prevent @@ -379,9 +386,7 @@ async def _get_job_state(self, job_id: str): state = copy.deepcopy(self.job_state.get(job_id)) return state - async def generate_query_batches( - self, input: QueryJobSubmitRequest, batch_size: int - ): + async def generate_query_batches(self, input: JobSubmitRequest, batch_size: int): runner = MetricsRunner.from_engine_adapter( FakeEngineAdapter("duckdb"), input.query_as("duckdb"), @@ -403,7 +408,7 @@ async def generate_query_batches( if len(batch) > 0: yield (batch_num, batch) - async def resolve_dependent_tables(self, input: QueryJobSubmitRequest): + async def resolve_dependent_tables(self, input: JobSubmitRequest): """Resolve the dependent tables for the given input and returns the associate export references""" @@ -439,15 +444,27 @@ async def resolve_dependent_tables(self, input: QueryJobSubmitRequest): async def get_job_status( self, job_id: str, include_stats: bool = False - ) -> QueryJobStatusResponse: + ) -> JobStatusResponse: state = await self._get_job_state(job_id) if not state: raise ValueError(f"Job {job_id} not found") return state.as_response(include_stats=include_stats) - async def wait_for_job_status(self, job_id: str) -> QueryJobStatusResponse: - # self.emitter.once(f"job_status:{job_id}") - raise NotImplementedError("This method is not implemented yet") + def listen_for_job_updates( + self, job_id: str, handler: t.Callable[[JobStatusResponse], t.Awaitable[None]] + ): + self.logger.info("Listening for job updates") + + async def convert_to_response(state: QueryJobState): + self.logger.debug("converting to response") + return await handler(state.as_response()) + + handle = self.emitter.add_listener(f"job_update:{job_id}", convert_to_response) + + def remove_listener(): + return self.emitter.remove_listener(f"job_update:{job_id}", handle) + + return remove_listener async def add_existing_exported_table_references( self, update: t.Dict[str, ExportReference] diff --git a/warehouse/metrics_tools/compute/test_app.py b/warehouse/metrics_tools/compute/test_app.py new file mode 100644 index 000000000..207b7e297 --- /dev/null +++ b/warehouse/metrics_tools/compute/test_app.py @@ -0,0 +1,106 @@ +from contextlib import contextmanager +from datetime import datetime +from unittest.mock import MagicMock + +import pytest +from fastapi.testclient import TestClient +from metrics_tools.compute.app import app_factory, default_lifecycle +from metrics_tools.compute.client import BaseWebsocketConnector, Client +from metrics_tools.compute.types import AppConfig +from metrics_tools.definition import PeerMetricDependencyRef +from metrics_tools.utils.logging import setup_module_logging +from starlette.testclient import WebSocketTestSession + + +@pytest.fixture +def app_client_with_all_debugging(): + setup_module_logging("metrics_tools") + app = app_factory( + default_lifecycle, + AppConfig( + cluster_namespace="namespace", + cluster_service_account="service_account", + cluster_name="name", + worker_duckdb_path="path", + trino_catalog="catalog", + trino_host="trino", + trino_user="trino", + trino_port=8080, + hive_catalog="catalog", + hive_schema="schema", + gcs_bucket="bucket", + gcs_key_id="key", + gcs_secret="secret", + debug_all=True, + ), + ) + + with TestClient(app) as client: + yield client + + +class TestWebsocketConnector(BaseWebsocketConnector): + def __init__(self, session: WebSocketTestSession): + self.session = session + + def receive(self): + return self.session.receive_text() + + def send(self, data: str): + return self.session.send_text(data) + + +class TestClientWebsocketConnectFactory(BaseWebsocketConnector): + def __init__(self, client: TestClient): + self.client = client + + @contextmanager + def __call__(self, *, base_url: str, path: str): + with self.client.websocket_connect(path) as ws: + yield TestWebsocketConnector(ws) + + +def test_app_with_all_debugging(app_client_with_all_debugging): + client = Client( + app_client_with_all_debugging, + retries=1, + websocket_connect_factory=TestClientWebsocketConnectFactory( + app_client_with_all_debugging + ), + ) + + start = "2021-01-01" + end = "2021-01-03" + batch_size = 1 + cluster_size = 1 + + mock_handler = MagicMock() + + reference = client.calculate_metrics( + query_str="""SELECT * FROM test""", + start=datetime.strptime(start, "%Y-%m-%d"), + end=datetime.strptime(end, "%Y-%m-%d"), + dialect="duckdb", + columns=[ + ("bucket_day", "TIMESTAMP"), + ("to_artifact_id", "VARCHAR"), + ("from_artifact_id", "VARCHAR"), + ("event_source", "VARCHAR"), + ("event_type", "VARCHAR"), + ("amount", "NUMERIC"), + ], + ref=PeerMetricDependencyRef( + name="", entity_type="artifact", window=30, unit="day" + ), + locals={}, + dependent_tables_map={ + "metrics.events_daily_to_artifact": "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958" + }, + batch_size=batch_size, + cluster_max_size=cluster_size, + cluster_min_size=cluster_size, + progress_handler=mock_handler, + ) + + assert mock_handler.call_count == 6 + assert reference is not None diff --git a/warehouse/metrics_tools/compute/test_service.py b/warehouse/metrics_tools/compute/test_service.py index deb25a459..f87dba30f 100644 --- a/warehouse/metrics_tools/compute/test_service.py +++ b/warehouse/metrics_tools/compute/test_service.py @@ -11,8 +11,8 @@ ColumnsDefinition, ExportReference, ExportType, + JobSubmitRequest, QueryJobStatus, - QueryJobSubmitRequest, ) from metrics_tools.definition import PeerMetricDependencyRef @@ -41,7 +41,7 @@ async def test_metrics_calculation_service(): } ) response = await service.submit_job( - QueryJobSubmitRequest( + JobSubmitRequest( query_str="SELECT * FROM ref.table123", start=datetime(2021, 1, 1), end=datetime(2021, 1, 3), diff --git a/warehouse/metrics_tools/compute/types.py b/warehouse/metrics_tools/compute/types.py index c4716e873..762d7c997 100644 --- a/warehouse/metrics_tools/compute/types.py +++ b/warehouse/metrics_tools/compute/types.py @@ -4,15 +4,17 @@ from enum import Enum import pandas as pd +from fastapi import FastAPI from metrics_tools.definition import PeerMetricDependencyRef -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict from sqlmesh.core.dialect import parse_one logger = logging.getLogger(__name__) class EmptyResponse(BaseModel): - pass + type: t.Literal["EmptyResponse"] = "EmptyResponse" class ExportType(str, Enum): @@ -123,7 +125,13 @@ class ClusterStatus(BaseModel): workers: int -class QueryJobSubmitRequest(BaseModel): +class ClusterStatusResponse(BaseModel): + type: t.Literal["ClusterStatusResponse"] = "ClusterStatusResponse" + status: ClusterStatus + + +class JobSubmitRequest(BaseModel): + type: t.Literal["JobSubmitRequest"] = "JobSubmitRequest" query_str: str start: datetime end: datetime @@ -144,12 +152,14 @@ def columns_def(self) -> ColumnsDefinition: return ColumnsDefinition(columns=self.columns, dialect=self.dialect) -class QueryJobSubmitResponse(BaseModel): +class JobSubmitResponse(BaseModel): + type: t.Literal["JobSubmitResponse"] = "JobSubmitResponse" job_id: str export_reference: ExportReference -class QueryJobStatusResponse(BaseModel): +class JobStatusResponse(BaseModel): + type: t.Literal["JobStatusResponse"] = "JobStatusResponse" job_id: str created_at: datetime updated_at: datetime @@ -166,7 +176,7 @@ class QueryJobState(BaseModel): def latest_update(self) -> QueryJobUpdate: return self.updates[-1] - def as_response(self, include_stats: bool = False) -> QueryJobStatusResponse: + def as_response(self, include_stats: bool = False) -> JobStatusResponse: # Turn update events into stats stats = {} if include_stats: @@ -208,7 +218,7 @@ def as_response(self, include_stats: bool = False) -> QueryJobStatusResponse: else None ) - return QueryJobStatusResponse( + return JobStatusResponse( job_id=self.job_id, created_at=self.created_at, updated_at=self.latest_update().updated_at, @@ -219,13 +229,128 @@ def as_response(self, include_stats: bool = False) -> QueryJobStatusResponse: class ClusterStartRequest(BaseModel): + type: t.Literal["ClusterStartRequest"] = "ClusterStartRequest" min_size: int max_size: int +class ClusterStatusRequest(BaseModel): + type: t.Literal["ClusterStatusRequest"] = "ClusterStatusRequest" + + +class JobStatusRequest(BaseModel): + type: t.Literal["JobStatusRequest"] = "JobStatusRequest" + job_id: str + include_stats: bool + + class ExportedTableLoadRequest(BaseModel): + type: t.Literal["ExportedTableLoadRequest"] = "ExportedTableLoadRequest" map: t.Dict[str, ExportReference] +class InspectCacheRequest(BaseModel): + type: t.Literal["InspectCacheRequest"] = "InspectCacheRequest" + + class InspectCacheResponse(BaseModel): + type: t.Literal["InspectCacheResponse"] = "InspectCacheResponse" map: t.Dict[str, ExportReference] + + +class ErrorResponse(BaseModel): + type: t.Literal["ErrorResponse"] = "ErrorResponse" + message: str + + +ServiceRequestTypes = t.Union[ + ClusterStartRequest, + ClusterStatusRequest, + JobStatusRequest, + ExportedTableLoadRequest, +] + + +class ServiceRequest(BaseModel): + type: str + request: ServiceRequestTypes = Field(discriminator="type") + + +ServiceResponseTypes = t.Union[ + ClusterStatusResponse, + JobStatusResponse, + EmptyResponse, + InspectCacheResponse, + ErrorResponse, +] + + +class ServiceResponse(BaseModel): + type: str + response: ServiceResponseTypes = Field(discriminator="type") + + +class ClusterConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="metrics_") + + cluster_namespace: str + cluster_service_account: str + cluster_name: str + cluster_image_repo: str = "ghcr.io/opensource-observer/oso" + cluster_image_tag: str = "latest" + + scheduler_memory_limit: str = "90000Mi" + scheduler_memory_request: str = "85000Mi" + scheduler_pool_type: str = "sqlmesh-scheduler" + + worker_threads: int = 16 + worker_memory_limit: str = "90000Mi" + worker_memory_request: str = "85000Mi" + worker_pool_type: str = "sqlmesh-worker" + worker_duckdb_path: str + + +class GCSConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="metrics_") + + gcs_bucket: str + gcs_key_id: str + gcs_secret: str + + +class TrinoCacheExportConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="metrics_") + + trino_host: str + trino_port: int + trino_user: str + trino_catalog: str + hive_catalog: str = "source" + hive_schema: str = "export" + + +class AppConfig(ClusterConfig, TrinoCacheExportConfig, GCSConfig): + model_config = SettingsConfigDict(env_prefix="metrics_") + + results_path_prefix: str = "mcs-results" + + debug_all: bool = False + debug_with_duckdb: bool = False + debug_cache: bool = False + debug_cluster: bool = False + debug_cluster_no_shutdown: bool = False + + @model_validator(mode="after") + def handle_debugging(self): + if self.debug_all: + self.debug_cache = True + self.debug_cluster = True + self.debug_with_duckdb = True + return self + + +AppLifespan = t.Callable[[FastAPI], t.Any] + +ConfigType = t.TypeVar("ConfigType") + +AppLifespanFactory = t.Callable[[ConfigType], AppLifespan] diff --git a/warehouse/metrics_tools/factory/factory.py b/warehouse/metrics_tools/factory/factory.py index cf3fc77a1..357690e17 100644 --- a/warehouse/metrics_tools/factory/factory.py +++ b/warehouse/metrics_tools/factory/factory.py @@ -413,7 +413,7 @@ def generate_rolling_python_model_for_rendered_query( columns = METRICS_COLUMNS_BY_ENTITY[ref["entity_type"]] - kind_common = {"batch_concurrency": 1} + kind_common = {"batch_size": 365, "batch_concurrency": 1} partitioned_by = ("day(metrics_sample_date)",) window = ref.get("window") assert window is not None @@ -661,7 +661,7 @@ def generated_rolling_query( logger.info("metrics calculation service enabled") mcs_url = env.required_str("SQLMESH_MCS_URL") - mcs_client = Client(url=mcs_url) + mcs_client = Client.from_url(url=mcs_url) columns = [ (col_name, col_type.sql(dialect="duckdb")) @@ -682,6 +682,8 @@ def generated_rolling_query( dependent_tables_map=create_dependent_tables_map( context, rendered_query_str ), + cluster_min_size=env.ensure_int("SQLMESH_MCS_CLUSTER_MIN_SIZE", 0), + cluster_max_size=env.ensure_int("SQLMESH_MCS_CLUSTER_MAX_SIZE", 30), ) column_names = list(map(lambda col: col[0], columns))