diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py index 378c0fd3..89a747a1 100644 --- a/warehouse/metrics_tools/compute/service.py +++ b/warehouse/metrics_tools/compute/service.py @@ -322,9 +322,7 @@ async def _notify_job_pending(self, job_id: str, input: JobSubmitRequest): async def _notify_job_running(self, job_id: str): await self._update_job_state( job_id, - QueryJobUpdate( - time=datetime.now(), - scope=QueryJobUpdateScope.JOB, + QueryJobUpdate.create_job_update( payload=QueryJobStateUpdate( status=QueryJobStatus.RUNNING, has_remaining_tasks=True, @@ -335,9 +333,7 @@ async def _notify_job_running(self, job_id: str): async def _notify_job_task_completed(self, job_id: str, task_id: str): await self._update_job_state( job_id, - QueryJobUpdate( - time=datetime.now(), - scope=QueryJobUpdateScope.TASK, + QueryJobUpdate.create_task_update( payload=QueryJobTaskUpdate( task_id=task_id, status=QueryJobTaskStatus.SUCCEEDED, @@ -350,9 +346,7 @@ async def _notify_job_task_failed( ): await self._update_job_state( job_id, - QueryJobUpdate( - time=datetime.now(), - scope=QueryJobUpdateScope.TASK, + QueryJobUpdate.create_task_update( payload=QueryJobTaskUpdate( task_id=task_id, status=QueryJobTaskStatus.FAILED, @@ -364,9 +358,7 @@ async def _notify_job_task_failed( async def _notify_job_task_cancelled(self, job_id: str, task_id: str): await self._update_job_state( job_id, - QueryJobUpdate( - time=datetime.now(), - scope=QueryJobUpdateScope.TASK, + QueryJobUpdate.create_task_update( payload=QueryJobTaskUpdate( task_id=task_id, status=QueryJobTaskStatus.CANCELLED, @@ -377,9 +369,7 @@ async def _notify_job_task_cancelled(self, job_id: str, task_id: str): async def _notify_job_completed(self, job_id: str): await self._update_job_state( job_id, - QueryJobUpdate( - time=datetime.now(), - scope=QueryJobUpdateScope.JOB, + QueryJobUpdate.create_job_update( payload=QueryJobStateUpdate( status=QueryJobStatus.COMPLETED, has_remaining_tasks=False, @@ -395,9 +385,7 @@ async def _notify_job_failed( ): await self._update_job_state( job_id, - QueryJobUpdate( - time=datetime.now(), - scope=QueryJobUpdateScope.JOB, + QueryJobUpdate.create_job_update( payload=QueryJobStateUpdate( status=QueryJobStatus.FAILED, has_remaining_tasks=has_remaining_tasks, diff --git a/warehouse/metrics_tools/compute/test_service.py b/warehouse/metrics_tools/compute/test_service.py index 0731b678..b71b388a 100644 --- a/warehouse/metrics_tools/compute/test_service.py +++ b/warehouse/metrics_tools/compute/test_service.py @@ -1,4 +1,5 @@ import asyncio +import typing as t from datetime import datetime import pytest @@ -11,6 +12,7 @@ ColumnsDefinition, ExportReference, ExportType, + JobStatusResponse, JobSubmitRequest, QueryJobStatus, ) @@ -62,12 +64,91 @@ async def test_metrics_calculation_service(): ) async def wait_for_job_to_complete(): - status = await service.get_job_status(response.job_id) - while status.status in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]: - status = await service.get_job_status(response.job_id) - await asyncio.sleep(1) + updates: t.List[JobStatusResponse] = [] + future = asyncio.Future() + + async def collect_updates(update: JobStatusResponse): + updates.append(update) + if update.status not in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]: + future.set_result(updates) + + close = service.listen_for_job_updates(response.job_id, collect_updates) + return (close, future) + + close, updates_future = await asyncio.create_task(wait_for_job_to_complete()) + updates = await updates_future + close() + + assert len(updates) == 5 + + status = await service.get_job_status(response.job_id) + assert status.status == QueryJobStatus.COMPLETED + + await service.close() + + +@pytest.mark.asyncio +async def test_metrics_calculation_service_using_monthly_cron(): + service = MetricsCalculationService.setup( + "someid", + "bucket", + "result_path_prefix", + ClusterManager.with_dummy_metrics_plugin(LocalClusterFactory()), + await CacheExportManager.setup(FakeExportAdapter()), + DummyImportAdapter(), + ) + await service.start_cluster(ClusterStartRequest(min_size=1, max_size=1)) + await service.add_existing_exported_table_references( + { + "source.table123": ExportReference( + table_name="export_table123", + type=ExportType.GCS, + columns=ColumnsDefinition( + columns=[("col1", "INT"), ("col2", "TEXT")], dialect="duckdb" + ), + payload={"gcs_path": "gs://bucket/result_path_prefix/export_table123"}, + ), + } + ) + response = await service.submit_job( + JobSubmitRequest( + query_str="SELECT * FROM ref.table123", + start=datetime(2021, 1, 1), + end=datetime(2021, 4, 1), + dialect="duckdb", + batch_size=1, + columns=[("col1", "int"), ("col2", "string")], + ref=PeerMetricDependencyRef( + name="test", + entity_type="artifact", + window=30, + unit="day", + cron="@monthly", + ), + execution_time=datetime.now(), + locals={}, + dependent_tables_map={"source.table123": "source.table123"}, + ) + ) + + async def wait_for_job_to_complete(): + updates: t.List[JobStatusResponse] = [] + future = asyncio.Future() + + async def collect_updates(update: JobStatusResponse): + updates.append(update) + if update.status not in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]: + future.set_result(updates) + + close = service.listen_for_job_updates(response.job_id, collect_updates) + return (close, future) + + close, updates_future = await asyncio.create_task(wait_for_job_to_complete()) + updates = await updates_future + close() + + assert len(updates) == 6 - await asyncio.wait_for(asyncio.create_task(wait_for_job_to_complete()), timeout=60) status = await service.get_job_status(response.job_id) assert status.status == QueryJobStatus.COMPLETED diff --git a/warehouse/metrics_tools/compute/test_types.py b/warehouse/metrics_tools/compute/test_types.py new file mode 100644 index 00000000..5e81a726 --- /dev/null +++ b/warehouse/metrics_tools/compute/test_types.py @@ -0,0 +1,164 @@ +import pytest + +from .types import ( + QueryJobState, + QueryJobStateUpdate, + QueryJobStatus, + QueryJobTaskStatus, + QueryJobTaskUpdate, + QueryJobUpdate, +) + + +@pytest.mark.parametrize( + "description,updates,expected_status,expected_has_remaining_tasks,expected_exceptions_count", + [ + ( + "should fail if job failed", + [ + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.FAILED, + has_remaining_tasks=False, + exception="failed", + ) + ) + ], + QueryJobStatus.FAILED, + False, + 1, + ), + ( + "should still be running if no failure", + [ + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.RUNNING, + has_remaining_tasks=True, + ), + ), + QueryJobUpdate.create_task_update( + QueryJobTaskUpdate( + status=QueryJobTaskStatus.SUCCEEDED, + task_id="task_id", + ) + ), + ], + QueryJobStatus.RUNNING, + True, + 0, + ), + ( + "should fail if task failed and still has remaining tasks", + [ + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.RUNNING, + has_remaining_tasks=True, + ), + ), + QueryJobUpdate.create_task_update( + QueryJobTaskUpdate( + status=QueryJobTaskStatus.FAILED, + task_id="task_id", + exception="failed", + ) + ), + ], + QueryJobStatus.FAILED, + True, + 1, + ), + ( + "should fail if task failed and job failed but no remaining tasks", + [ + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.RUNNING, + has_remaining_tasks=True, + ), + ), + QueryJobUpdate.create_task_update( + QueryJobTaskUpdate( + status=QueryJobTaskStatus.FAILED, + task_id="task_id", + exception="failed", + ) + ), + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.FAILED, + has_remaining_tasks=False, + exception="failed", + ) + ), + ], + QueryJobStatus.FAILED, + False, + 2, + ), + ( + "should fail if task failed and job supposedly completed but no remaining tasks", + [ + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.RUNNING, + has_remaining_tasks=True, + ), + ), + QueryJobUpdate.create_task_update( + QueryJobTaskUpdate( + status=QueryJobTaskStatus.FAILED, + task_id="task_id", + exception="failed", + ) + ), + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.COMPLETED, + has_remaining_tasks=False, + ) + ), + ], + QueryJobStatus.FAILED, + False, + 1, + ), + ( + "should fail if a task is cancelled", + [ + QueryJobUpdate.create_job_update( + QueryJobStateUpdate( + status=QueryJobStatus.RUNNING, + has_remaining_tasks=True, + ), + ), + QueryJobUpdate.create_task_update( + QueryJobTaskUpdate( + status=QueryJobTaskStatus.CANCELLED, + task_id="task_id", + ) + ), + ], + QueryJobStatus.FAILED, + True, + 0, + ), + ], +) +def test_query_job_state( + description, + updates, + expected_status, + expected_has_remaining_tasks, + expected_exceptions_count, +): + state = QueryJobState.start("job_id", 4) + for update in updates: + state.update(update) + assert state.status == expected_status, description + assert state.has_remaining_tasks == expected_has_remaining_tasks, description + + response = state.as_response() + assert response.status == expected_status, description + assert len(response.exceptions) == expected_exceptions_count, description diff --git a/warehouse/metrics_tools/compute/types.py b/warehouse/metrics_tools/compute/types.py index 13d6c21b..b0b2baf4 100644 --- a/warehouse/metrics_tools/compute/types.py +++ b/warehouse/metrics_tools/compute/types.py @@ -149,6 +149,14 @@ class QueryJobUpdate(BaseModel): scope: QueryJobUpdateScope payload: QueryJobUpdateTypes = Field(discriminator="type") + @classmethod + def create_job_update(cls, payload: QueryJobStateUpdate) -> "QueryJobUpdate": + return cls(time=datetime.now(), scope=QueryJobUpdateScope.JOB, payload=payload) + + @classmethod + def create_task_update(cls, payload: QueryJobTaskUpdate) -> "QueryJobUpdate": + return cls(time=datetime.now(), scope=QueryJobUpdateScope.TASK, payload=payload) + class ClusterStatus(BaseModel): status: str @@ -219,6 +227,25 @@ class QueryJobState(BaseModel): status: QueryJobStatus = QueryJobStatus.PENDING updates: t.List[QueryJobUpdate] + @classmethod + def start(cls, job_id: str, tasks_count: int) -> "QueryJobState": + now = datetime.now() + return cls( + job_id=job_id, + created_at=now, + tasks_count=tasks_count, + updates=[ + QueryJobUpdate( + time=now, + scope=QueryJobUpdateScope.JOB, + payload=QueryJobStateUpdate( + status=QueryJobStatus.PENDING, + has_remaining_tasks=True, + ), + ) + ], + ) + def latest_update(self) -> QueryJobUpdate: return self.updates[-1] @@ -233,6 +260,7 @@ def update(self, update: QueryJobUpdate): self.has_remaining_tasks = False elif payload.status == QueryJobStatus.FAILED: self.has_remaining_tasks = payload.has_remaining_tasks + self.status = payload.status elif payload.status == QueryJobStatus.RUNNING: self.status = payload.status else: diff --git a/warehouse/metrics_tools/test_runner.py b/warehouse/metrics_tools/test_runner.py new file mode 100644 index 00000000..0d9957c1 --- /dev/null +++ b/warehouse/metrics_tools/test_runner.py @@ -0,0 +1,28 @@ +from datetime import datetime + +import duckdb +from metrics_tools.definition import PeerMetricDependencyRef +from metrics_tools.runner import MetricsRunner + + +def test_runner_rendering(): + runner = MetricsRunner.create_duckdb_execution_context( + conn=duckdb.connect(), + query=""" + select time from foo + where time between @metrics_start('DATE') + and @metrics_end('DATE') + """, + ref=PeerMetricDependencyRef( + name="test", + entity_type="artifact", + window=30, + unit="day", + cron="@monthly", + ), + locals={}, + ) + start = datetime.strptime("2024-01-01", "%Y-%m-%d") + end = datetime.strptime("2024-12-31", "%Y-%m-%d") + rendered = list(runner.render_rolling_queries(start, end)) + assert len(rendered) == 12