Skip to content

Commit

Permalink
MCS Bugs (#2667)
Browse files Browse the repository at this point in the history
* Found these MCS bugs

wrote regression tests to cover them

* additional refactoring

* remove unused comments
  • Loading branch information
ravenac95 authored Dec 18, 2024
1 parent e1a8786 commit 1a9849d
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 23 deletions.
24 changes: 6 additions & 18 deletions warehouse/metrics_tools/compute/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,7 @@ async def _notify_job_pending(self, job_id: str, input: JobSubmitRequest):
async def _notify_job_running(self, job_id: str):
await self._update_job_state(
job_id,
QueryJobUpdate(
time=datetime.now(),
scope=QueryJobUpdateScope.JOB,
QueryJobUpdate.create_job_update(
payload=QueryJobStateUpdate(
status=QueryJobStatus.RUNNING,
has_remaining_tasks=True,
Expand All @@ -335,9 +333,7 @@ async def _notify_job_running(self, job_id: str):
async def _notify_job_task_completed(self, job_id: str, task_id: str):
await self._update_job_state(
job_id,
QueryJobUpdate(
time=datetime.now(),
scope=QueryJobUpdateScope.TASK,
QueryJobUpdate.create_task_update(
payload=QueryJobTaskUpdate(
task_id=task_id,
status=QueryJobTaskStatus.SUCCEEDED,
Expand All @@ -350,9 +346,7 @@ async def _notify_job_task_failed(
):
await self._update_job_state(
job_id,
QueryJobUpdate(
time=datetime.now(),
scope=QueryJobUpdateScope.TASK,
QueryJobUpdate.create_task_update(
payload=QueryJobTaskUpdate(
task_id=task_id,
status=QueryJobTaskStatus.FAILED,
Expand All @@ -364,9 +358,7 @@ async def _notify_job_task_failed(
async def _notify_job_task_cancelled(self, job_id: str, task_id: str):
await self._update_job_state(
job_id,
QueryJobUpdate(
time=datetime.now(),
scope=QueryJobUpdateScope.TASK,
QueryJobUpdate.create_task_update(
payload=QueryJobTaskUpdate(
task_id=task_id,
status=QueryJobTaskStatus.CANCELLED,
Expand All @@ -377,9 +369,7 @@ async def _notify_job_task_cancelled(self, job_id: str, task_id: str):
async def _notify_job_completed(self, job_id: str):
await self._update_job_state(
job_id,
QueryJobUpdate(
time=datetime.now(),
scope=QueryJobUpdateScope.JOB,
QueryJobUpdate.create_job_update(
payload=QueryJobStateUpdate(
status=QueryJobStatus.COMPLETED,
has_remaining_tasks=False,
Expand All @@ -395,9 +385,7 @@ async def _notify_job_failed(
):
await self._update_job_state(
job_id,
QueryJobUpdate(
time=datetime.now(),
scope=QueryJobUpdateScope.JOB,
QueryJobUpdate.create_job_update(
payload=QueryJobStateUpdate(
status=QueryJobStatus.FAILED,
has_remaining_tasks=has_remaining_tasks,
Expand Down
91 changes: 86 additions & 5 deletions warehouse/metrics_tools/compute/test_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import typing as t
from datetime import datetime

import pytest
Expand All @@ -11,6 +12,7 @@
ColumnsDefinition,
ExportReference,
ExportType,
JobStatusResponse,
JobSubmitRequest,
QueryJobStatus,
)
Expand Down Expand Up @@ -62,12 +64,91 @@ async def test_metrics_calculation_service():
)

async def wait_for_job_to_complete():
status = await service.get_job_status(response.job_id)
while status.status in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]:
status = await service.get_job_status(response.job_id)
await asyncio.sleep(1)
updates: t.List[JobStatusResponse] = []
future = asyncio.Future()

async def collect_updates(update: JobStatusResponse):
updates.append(update)
if update.status not in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]:
future.set_result(updates)

close = service.listen_for_job_updates(response.job_id, collect_updates)
return (close, future)

close, updates_future = await asyncio.create_task(wait_for_job_to_complete())
updates = await updates_future
close()

assert len(updates) == 5

status = await service.get_job_status(response.job_id)
assert status.status == QueryJobStatus.COMPLETED

await service.close()


@pytest.mark.asyncio
async def test_metrics_calculation_service_using_monthly_cron():
service = MetricsCalculationService.setup(
"someid",
"bucket",
"result_path_prefix",
ClusterManager.with_dummy_metrics_plugin(LocalClusterFactory()),
await CacheExportManager.setup(FakeExportAdapter()),
DummyImportAdapter(),
)
await service.start_cluster(ClusterStartRequest(min_size=1, max_size=1))
await service.add_existing_exported_table_references(
{
"source.table123": ExportReference(
table_name="export_table123",
type=ExportType.GCS,
columns=ColumnsDefinition(
columns=[("col1", "INT"), ("col2", "TEXT")], dialect="duckdb"
),
payload={"gcs_path": "gs://bucket/result_path_prefix/export_table123"},
),
}
)
response = await service.submit_job(
JobSubmitRequest(
query_str="SELECT * FROM ref.table123",
start=datetime(2021, 1, 1),
end=datetime(2021, 4, 1),
dialect="duckdb",
batch_size=1,
columns=[("col1", "int"), ("col2", "string")],
ref=PeerMetricDependencyRef(
name="test",
entity_type="artifact",
window=30,
unit="day",
cron="@monthly",
),
execution_time=datetime.now(),
locals={},
dependent_tables_map={"source.table123": "source.table123"},
)
)

async def wait_for_job_to_complete():
    """Subscribe to the job's updates and expose a future for the terminal one.

    Returns:
        A ``(close, future)`` tuple. ``close`` is the value returned by
        ``service.listen_for_job_updates`` (the caller invokes it to stop
        listening). ``future`` resolves to the list of collected updates
        once an update with a status other than PENDING or RUNNING arrives.
    """
    updates: t.List[JobStatusResponse] = []
    future = asyncio.Future()

    async def collect_updates(update: JobStatusResponse):
        updates.append(update)
        is_terminal = update.status not in [
            QueryJobStatus.PENDING,
            QueryJobStatus.RUNNING,
        ]
        # More than one non-PENDING/RUNNING update may be delivered before
        # the caller closes the listener; calling set_result() on a future
        # that is already done raises asyncio.InvalidStateError.
        if is_terminal and not future.done():
            future.set_result(updates)

    close = service.listen_for_job_updates(response.job_id, collect_updates)
    return (close, future)

close, updates_future = await asyncio.create_task(wait_for_job_to_complete())
updates = await updates_future
close()

assert len(updates) == 6

await asyncio.wait_for(asyncio.create_task(wait_for_job_to_complete()), timeout=60)
status = await service.get_job_status(response.job_id)
assert status.status == QueryJobStatus.COMPLETED

Expand Down
164 changes: 164 additions & 0 deletions warehouse/metrics_tools/compute/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import pytest

from .types import (
QueryJobState,
QueryJobStateUpdate,
QueryJobStatus,
QueryJobTaskStatus,
QueryJobTaskUpdate,
QueryJobUpdate,
)


@pytest.mark.parametrize(
    (
        "description",
        "updates",
        "expected_status",
        "expected_has_remaining_tasks",
        "expected_exceptions_count",
    ),
    [
        (
            "should fail if job failed",
            [
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.FAILED,
                        has_remaining_tasks=False,
                        exception="failed",
                    )
                )
            ],
            QueryJobStatus.FAILED,
            False,
            1,
        ),
        (
            "should still be running if no failure",
            [
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.RUNNING,
                        has_remaining_tasks=True,
                    ),
                ),
                QueryJobUpdate.create_task_update(
                    QueryJobTaskUpdate(
                        status=QueryJobTaskStatus.SUCCEEDED,
                        task_id="task_id",
                    )
                ),
            ],
            QueryJobStatus.RUNNING,
            True,
            0,
        ),
        (
            "should fail if task failed and still has remaining tasks",
            [
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.RUNNING,
                        has_remaining_tasks=True,
                    ),
                ),
                QueryJobUpdate.create_task_update(
                    QueryJobTaskUpdate(
                        status=QueryJobTaskStatus.FAILED,
                        task_id="task_id",
                        exception="failed",
                    )
                ),
            ],
            QueryJobStatus.FAILED,
            True,
            1,
        ),
        (
            "should fail if task failed and job failed but no remaining tasks",
            [
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.RUNNING,
                        has_remaining_tasks=True,
                    ),
                ),
                QueryJobUpdate.create_task_update(
                    QueryJobTaskUpdate(
                        status=QueryJobTaskStatus.FAILED,
                        task_id="task_id",
                        exception="failed",
                    )
                ),
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.FAILED,
                        has_remaining_tasks=False,
                        exception="failed",
                    )
                ),
            ],
            QueryJobStatus.FAILED,
            False,
            2,
        ),
        (
            "should fail if task failed and job supposedly completed but no remaining tasks",
            [
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.RUNNING,
                        has_remaining_tasks=True,
                    ),
                ),
                QueryJobUpdate.create_task_update(
                    QueryJobTaskUpdate(
                        status=QueryJobTaskStatus.FAILED,
                        task_id="task_id",
                        exception="failed",
                    )
                ),
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.COMPLETED,
                        has_remaining_tasks=False,
                    )
                ),
            ],
            QueryJobStatus.FAILED,
            False,
            1,
        ),
        (
            "should fail if a task is cancelled",
            [
                QueryJobUpdate.create_job_update(
                    QueryJobStateUpdate(
                        status=QueryJobStatus.RUNNING,
                        has_remaining_tasks=True,
                    ),
                ),
                QueryJobUpdate.create_task_update(
                    QueryJobTaskUpdate(
                        status=QueryJobTaskStatus.CANCELLED,
                        task_id="task_id",
                    )
                ),
            ],
            QueryJobStatus.FAILED,
            True,
            0,
        ),
    ],
)
def test_query_job_state(
    description,
    updates,
    expected_status,
    expected_has_remaining_tasks,
    expected_exceptions_count,
):
    """Replay ``updates`` onto a fresh job state and check the outcome.

    Each case starts from ``QueryJobState.start("job_id", 4)``, applies the
    updates in order, then checks the state's status and remaining-task
    flag, plus the status and exception count of ``as_response()``.
    ``description`` is used as the assertion message to identify the case.
    """
    job_state = QueryJobState.start("job_id", 4)
    for job_update in updates:
        job_state.update(job_update)

    # Checks on the state object itself.
    assert job_state.status == expected_status, description
    assert job_state.has_remaining_tasks == expected_has_remaining_tasks, description

    # Checks on the response derived from the state.
    job_response = job_state.as_response()
    assert job_response.status == expected_status, description
    assert len(job_response.exceptions) == expected_exceptions_count, description
28 changes: 28 additions & 0 deletions warehouse/metrics_tools/compute/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ class QueryJobUpdate(BaseModel):
scope: QueryJobUpdateScope
payload: QueryJobUpdateTypes = Field(discriminator="type")

@classmethod
def create_job_update(cls, payload: QueryJobStateUpdate) -> "QueryJobUpdate":
    """Wrap a job state change in a job-scoped update.

    Args:
        payload: The job state change to carry.

    Returns:
        A new update with ``scope`` set to ``QueryJobUpdateScope.JOB`` and
        ``time`` set to ``datetime.now()`` at the moment of the call.
    """
    created_at = datetime.now()
    return cls(
        time=created_at,
        scope=QueryJobUpdateScope.JOB,
        payload=payload,
    )

@classmethod
def create_task_update(cls, payload: QueryJobTaskUpdate) -> "QueryJobUpdate":
    """Wrap a task status change in a task-scoped update.

    Args:
        payload: The task status change to carry.

    Returns:
        A new update with ``scope`` set to ``QueryJobUpdateScope.TASK`` and
        ``time`` set to ``datetime.now()`` at the moment of the call.
    """
    created_at = datetime.now()
    return cls(
        time=created_at,
        scope=QueryJobUpdateScope.TASK,
        payload=payload,
    )


class ClusterStatus(BaseModel):
status: str
Expand Down Expand Up @@ -219,6 +227,25 @@ class QueryJobState(BaseModel):
status: QueryJobStatus = QueryJobStatus.PENDING
updates: t.List[QueryJobUpdate]

@classmethod
def start(cls, job_id: str, tasks_count: int) -> "QueryJobState":
    """Create the initial state for a job.

    Args:
        job_id: Identifier of the job the state tracks.
        tasks_count: Value stored in the state's ``tasks_count`` field.

    Returns:
        A new state whose ``updates`` list holds a single job-scoped
        PENDING update with ``has_remaining_tasks=True``.
    """
    # One timestamp is shared so that ``created_at`` and the first update's
    # ``time`` are exactly equal.
    started_at = datetime.now()
    pending_payload = QueryJobStateUpdate(
        status=QueryJobStatus.PENDING,
        has_remaining_tasks=True,
    )
    initial_update = QueryJobUpdate(
        time=started_at,
        scope=QueryJobUpdateScope.JOB,
        payload=pending_payload,
    )
    return cls(
        job_id=job_id,
        created_at=started_at,
        tasks_count=tasks_count,
        updates=[initial_update],
    )

def latest_update(self) -> QueryJobUpdate:
return self.updates[-1]

Expand All @@ -233,6 +260,7 @@ def update(self, update: QueryJobUpdate):
self.has_remaining_tasks = False
elif payload.status == QueryJobStatus.FAILED:
self.has_remaining_tasks = payload.has_remaining_tasks
self.status = payload.status
elif payload.status == QueryJobStatus.RUNNING:
self.status = payload.status
else:
Expand Down
Loading

0 comments on commit 1a9849d

Please sign in to comment.