diff --git a/poetry.lock b/poetry.lock
index 60f83f29c..c71c44c88 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5367,6 +5367,24 @@ pluggy = ">=1.5,<2"
 [package.extras]
 dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.24.0"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
+    {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
+]
+
+[package.dependencies]
+pytest = ">=8.2,<9"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
+
 [[package]]
 name = "python-box"
 version = "7.2.0"
@@ -7667,4 +7685,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.12,<3.13"
-content-hash = "1c8c410d5892963a5afbe7ddebb39c6f04b6363661e8c933f3a6a071835e7097"
+content-hash = "8ae2b5b0266100c2ecb14a2d961dc580442f13afdd2df8df8f4ee962f51cd89c"
diff --git a/pyproject.toml b/pyproject.toml
index b5c3720e4..54280bb9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ bokeh = "^3.6.1"
 fastapi = {extras = ["standard"], version = "^0.115.6"}
 pyee = "^12.1.1"
 aiotrino = "^0.2.3"
+pytest-asyncio = "^0.24.0"
 
 
 [tool.poetry.scripts]
diff --git a/warehouse/metrics_tools/compute/cache.py b/warehouse/metrics_tools/compute/cache.py
index 3918e04d0..d1196b612 100644
--- a/warehouse/metrics_tools/compute/cache.py
+++ b/warehouse/metrics_tools/compute/cache.py
@@ -245,7 +245,10 @@ async def export_table(table: str):
                 in_progress.remove(table)
 
         while not self.stop_signal.is_set():
-            item = await self.export_queue.get()
+            try:
+                item = await asyncio.wait_for(self.export_queue.get(), timeout=1)
+            except asyncio.TimeoutError:
+                continue
             if item.table in in_progress:
                 # The table is already being exported. Skip this in the queue
                 continue
@@ -300,7 +303,8 @@ async def resolve_export_references(self, tables: t.List[str]):
         registration = None
 
         export_map: t.Dict[str, ExportReference] = {}
-        for table in tables:
+        # Ensure we are only comparing unique tables
+        for table in set(tables):
             reference = await self.get_export_table_reference(table)
             if reference is not None:
                 export_map[table] = reference
diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py
index 4aff7c0e7..2d481b541 100644
--- a/warehouse/metrics_tools/compute/service.py
+++ b/warehouse/metrics_tools/compute/service.py
@@ -104,7 +104,6 @@ def __init__(
     async def handle_query_job_submit_request(
         self, job_id: str, result_path_base: str, input: QueryJobSubmitRequest
     ):
-        await self._notify_job_pending(job_id, 1)
         try:
             await self._handle_query_job_submit_request(job_id, result_path_base, input)
         except Exception as e:
@@ -198,6 +197,7 @@ async def _handle_query_job_submit_request(
 
     async def close(self):
         await self.cluster_manager.close()
+        await self.cache_manager.stop()
         self.stop_event.set()
 
     async def start_cluster(self, start_request: ClusterStartRequest) -> ClusterStatus:
@@ -219,6 +219,7 @@ async def submit_job(self, input: QueryJobSubmitRequest):
             f"gs://{self.gcs_bucket}", result_path_base, "*.parquet"
         )
 
+        await self._notify_job_pending(job_id, 1)
         task = asyncio.create_task(
             self.handle_query_job_submit_request(job_id, result_path_base, input)
         )
diff --git a/warehouse/metrics_tools/compute/test_cache.py b/warehouse/metrics_tools/compute/test_cache.py
new file mode 100644
index 000000000..4feb234d8
--- /dev/null
+++ b/warehouse/metrics_tools/compute/test_cache.py
@@ -0,0 +1,33 @@
+import asyncio
+from unittest.mock import AsyncMock
+from metrics_tools.compute.types import ExportReference, ExportType
+import pytest
+
+from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter
+
+
+@pytest.mark.asyncio
+async def test_cache_export_manager():
+    adapter_mock = AsyncMock(FakeExportAdapter)
+    adapter_mock.export_table.return_value = ExportReference(
+        table="test",
+        type=ExportType.GCS,
+        payload={},
+    )
+    cache = await CacheExportManager.setup(adapter_mock)
+
+    export_table_0 = await asyncio.wait_for(
+        cache.resolve_export_references(["table1", "table2"]), timeout=1
+    )
+
+    assert export_table_0.keys() == {"table1", "table2"}
+
+    # Attempt to export tables again but this should be mostly cache hits except
+    # for table3
+    export_table_1 = await asyncio.wait_for(
+        cache.resolve_export_references(["table1", "table2", "table1", "table3"]),
+        timeout=1,
+    )
+    assert export_table_1.keys() == {"table1", "table2", "table3"}
+
+    assert adapter_mock.export_table.call_count == 3
diff --git a/warehouse/metrics_tools/compute/test_cluster.py b/warehouse/metrics_tools/compute/test_cluster.py
new file mode 100644
index 000000000..bdeb195a4
--- /dev/null
+++ b/warehouse/metrics_tools/compute/test_cluster.py
@@ -0,0 +1,69 @@
+import asyncio
+
+from dask.distributed import Client
+import pytest
+
+from metrics_tools.compute.cluster import (
+    ClusterManager,
+    ClusterProxy,
+    ClusterFactory,
+    ClusterStatus,
+)
+
+
+class FakeClient(Client):
+    def __init__(self):
+        pass
+
+    def close(self, *args, **kwargs):
+        pass
+
+    def register_worker_plugin(self, *args, **kwargs):
+        pass
+
+
+class FakeClusterProxy(ClusterProxy):
+    def __init__(self, min_size: int, max_size: int):
+        self.min_size = min_size
+        self.max_size = max_size
+
+    async def client(self) -> Client:
+        return await FakeClient()
+
+    async def status(self):
+        return ClusterStatus(
+            status="running",
+            is_ready=True,
+            dashboard_url="",
+            workers=1,
+        )
+
+    async def stop(self):
+        return
+
+    @property
+    def dashboard_link(self):
+        return "http://fake-dashboard.com"
+
+    @property
+    def workers(self):
+        return 1
+
+
+class FakeClusterFactory(ClusterFactory):
+    async def create_cluster(self, min_size: int, max_size: int):
+        return FakeClusterProxy(min_size, max_size)
+
+
+@pytest.mark.asyncio
+async def test_cluster_manager_reports_ready():
+    cluster_manager = ClusterManager.with_dummy_metrics_plugin(FakeClusterFactory())
+
+    ready_future = cluster_manager.wait_for_ready()
+
+    await cluster_manager.start_cluster(1, 1)
+
+    try:
+        await asyncio.wait_for(ready_future, timeout=1)
+    except asyncio.TimeoutError:
+        pytest.fail("Cluster never reported ready")
diff --git a/warehouse/metrics_tools/compute/test_service.py b/warehouse/metrics_tools/compute/test_service.py
new file mode 100644
index 000000000..463339b3c
--- /dev/null
+++ b/warehouse/metrics_tools/compute/test_service.py
@@ -0,0 +1,66 @@
+from metrics_tools.compute.types import (
+    ClusterStartRequest,
+    ExportReference,
+    ExportType,
+    QueryJobStatus,
+    QueryJobSubmitRequest,
+)
+from metrics_tools.definition import PeerMetricDependencyRef
+import pytest
+import asyncio
+
+from metrics_tools.compute.service import MetricsCalculationService
+from metrics_tools.compute.cluster import ClusterManager, LocalClusterFactory
+from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter
+from datetime import datetime
+
+
+@pytest.mark.asyncio
+async def test_metrics_calculation_service():
+    service = MetricsCalculationService.setup(
+        "someid",
+        "bucket",
+        "result_path_prefix",
+        ClusterManager.with_dummy_metrics_plugin(LocalClusterFactory()),
+        await CacheExportManager.setup(FakeExportAdapter()),
+    )
+    await service.start_cluster(ClusterStartRequest(min_size=1, max_size=1))
+    await service.add_existing_exported_table_references(
+        {
+            "source.table123": ExportReference(
+                table="export_table123",
+                type=ExportType.GCS,
+                payload={"gcs_path": "gs://bucket/result_path_prefix/export_table123"},
+            ),
+        }
+    )
+    response = await service.submit_job(
+        QueryJobSubmitRequest(
+            query_str="SELECT * FROM ref.table123",
+            start=datetime(2021, 1, 1),
+            end=datetime(2021, 1, 3),
+            dialect="duckdb",
+            batch_size=1,
+            columns=[("col1", "int"), ("col2", "string")],
+            ref=PeerMetricDependencyRef(
+                name="test",
+                entity_type="artifact",
+                window=30,
+                unit="day",
+            ),
+            locals={},
+            dependent_tables_map={"source.table123": "source.table123"},
+        )
+    )
+
+    async def wait_for_job_to_complete():
+        status = await service.get_job_status(response.job_id)
+        while status.status in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]:
+            status = await service.get_job_status(response.job_id)
+            await asyncio.sleep(1)
+
+    await asyncio.wait_for(asyncio.create_task(wait_for_job_to_complete()), timeout=60)
+    status = await service.get_job_status(response.job_id)
+    assert status.status == QueryJobStatus.COMPLETED
+
+    await service.close()