Added some testing for the service
ravenac95 committed Dec 10, 2024
1 parent 14e6a58 commit a7a8ce0
Showing 7 changed files with 196 additions and 4 deletions.
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -69,6 +69,7 @@ bokeh = "^3.6.1"
 fastapi = {extras = ["standard"], version = "^0.115.6"}
 pyee = "^12.1.1"
 aiotrino = "^0.2.3"
+pytest-asyncio = "^0.24.0"
 
 
 [tool.poetry.scripts]
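pytest-asyncio is the piece that lets the new test modules run coroutine tests under pytest. A minimal sketch of the per-test opt-in the suite uses (the test name here is illustrative, not from this commit):

import asyncio

import pytest


@pytest.mark.asyncio
async def test_event_loop_runs():
    # pytest-asyncio supplies the event loop and awaits the coroutine.
    await asyncio.sleep(0)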
8 changes: 6 additions & 2 deletions warehouse/metrics_tools/compute/cache.py
@@ -245,7 +245,10 @@ async def export_table(table: str):
             in_progress.remove(table)
 
         while not self.stop_signal.is_set():
-            item = await self.export_queue.get()
+            try:
+                item = await asyncio.wait_for(self.export_queue.get(), timeout=1)
+            except asyncio.TimeoutError:
+                continue
             if item.table in in_progress:
                 # The table is already being exported. Skip this in the queue
                 continue
@@ -300,7 +303,8 @@ async def resolve_export_references(self, tables: t.List[str]):
         registration = None
         export_map: t.Dict[str, ExportReference] = {}
 
-        for table in tables:
+        # Ensure we are only comparing unique tables
+        for table in set(tables):
            reference = await self.get_export_table_reference(table)
            if reference is not None:
                export_map[table] = reference
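The timeout change in the export worker loop matters for shutdown: with a bare queue.get() the worker could block forever on an empty queue and never re-check stop_signal, so stop() would hang. A standalone sketch of the wake-up-and-recheck pattern (names here are illustrative, not the module's own):

import asyncio


async def drain(queue: asyncio.Queue, stop: asyncio.Event) -> None:
    while not stop.is_set():
        try:
            # Wake at least once per second so the stop event is honored
            # even while the queue stays empty.
            item = await asyncio.wait_for(queue.get(), timeout=1)
        except asyncio.TimeoutError:
            continue
        print("processing", item)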
3 changes: 2 additions & 1 deletion warehouse/metrics_tools/compute/service.py
@@ -104,7 +104,6 @@ def __init__(
     async def handle_query_job_submit_request(
         self, job_id: str, result_path_base: str, input: QueryJobSubmitRequest
     ):
-        await self._notify_job_pending(job_id, 1)
         try:
             await self._handle_query_job_submit_request(job_id, result_path_base, input)
         except Exception as e:
@@ -198,6 +197,7 @@ async def _handle_query_job_submit_request(
 
     async def close(self):
         await self.cluster_manager.close()
+        await self.cache_manager.stop()
         self.stop_event.set()
 
     async def start_cluster(self, start_request: ClusterStartRequest) -> ClusterStatus:
@@ -219,6 +219,7 @@ async def submit_job(self, input: QueryJobSubmitRequest):
             f"gs://{self.gcs_bucket}", result_path_base, "*.parquet"
         )
 
+        await self._notify_job_pending(job_id, 1)
         task = asyncio.create_task(
             self.handle_query_job_submit_request(job_id, result_path_base, input)
         )
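Moving _notify_job_pending out of the spawned task and into submit_job changes observable behavior: the PENDING state is now recorded before submit_job returns, instead of racing the scheduled task. A self-contained sketch of why the ordering matters (all names here are hypothetical):

import asyncio

statuses: dict[str, str] = {}


async def run_job(job_id: str) -> None:
    await asyncio.sleep(0)  # stand-in for real work
    statuses[job_id] = "completed"


async def submit(job_id: str) -> str:
    # Record PENDING before scheduling the worker; create_task() does not
    # run the coroutine until the next await point, so if run_job() wrote
    # the status instead, a caller polling immediately could see nothing.
    statuses[job_id] = "pending"
    asyncio.create_task(run_job(job_id))
    return job_id


async def main() -> None:
    job_id = await submit("job-1")
    assert statuses[job_id] == "pending"


asyncio.run(main())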
33 changes: 33 additions & 0 deletions warehouse/metrics_tools/compute/test_cache.py
@@ -0,0 +1,33 @@
import asyncio
from unittest.mock import AsyncMock
from metrics_tools.compute.types import ExportReference, ExportType
import pytest

from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter


@pytest.mark.asyncio
async def test_cache_export_manager():
    adapter_mock = AsyncMock(FakeExportAdapter)
    adapter_mock.export_table.return_value = ExportReference(
        table="test",
        type=ExportType.GCS,
        payload={},
    )
    cache = await CacheExportManager.setup(adapter_mock)

    export_table_0 = await asyncio.wait_for(
        cache.resolve_export_references(["table1", "table2"]), timeout=1
    )

    assert export_table_0.keys() == {"table1", "table2"}

    # Attempt to export the tables again; this should be mostly cache hits
    # except for table3
    export_table_1 = await asyncio.wait_for(
        cache.resolve_export_references(["table1", "table2", "table1", "table3"]),
        timeout=1,
    )
    assert export_table_1.keys() == {"table1", "table2", "table3"}

    assert adapter_mock.export_table.call_count == 3
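The final assertion relies on two things: AsyncMock(FakeExportAdapter) records every await of export_table, and the set(tables) dedup in resolve_export_references means the repeated "table1" never reaches the adapter, so only table3 triggers a third export. The mock behavior in isolation (a standalone sketch, not the project's code):

import asyncio
from unittest.mock import AsyncMock


async def main() -> None:
    mock = AsyncMock()
    mock.export_table.return_value = "ref"
    await mock.export_table("table1")
    await mock.export_table("table3")
    # Each await is recorded, so repeated calls are countable.
    assert mock.export_table.call_count == 2
    mock.export_table.assert_awaited_with("table3")


asyncio.run(main())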
69 changes: 69 additions & 0 deletions warehouse/metrics_tools/compute/test_cluster.py
@@ -0,0 +1,69 @@
import asyncio

from dask.distributed import Client
import pytest

from metrics_tools.compute.cluster import (
    ClusterManager,
    ClusterProxy,
    ClusterFactory,
    ClusterStatus,
)


class FakeClient(Client):
    def __init__(self):
        pass

    def close(self, *args, **kwargs):
        pass

    def register_worker_plugin(self, *args, **kwargs):
        pass


class FakeClusterProxy(ClusterProxy):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    async def client(self) -> Client:
        return await FakeClient()

    async def status(self):
        return ClusterStatus(
            status="running",
            is_ready=True,
            dashboard_url="",
            workers=1,
        )

    async def stop(self):
        return

    @property
    def dashboard_link(self):
        return "http://fake-dashboard.com"

    @property
    def workers(self):
        return 1


class FakeClusterFactory(ClusterFactory):
    async def create_cluster(self, min_size: int, max_size: int):
        return FakeClusterProxy(min_size, max_size)


@pytest.mark.asyncio
async def test_cluster_manager_reports_ready():
    cluster_manager = ClusterManager.with_dummy_metrics_plugin(FakeClusterFactory())

    ready_future = cluster_manager.wait_for_ready()

    await cluster_manager.start_cluster(1, 1)

    try:
        await asyncio.wait_for(ready_future, timeout=1)
    except asyncio.TimeoutError:
        pytest.fail("Cluster never reported ready")
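One detail of the test worth highlighting: wait_for_ready() is called before start_cluster(), so the waiter exists before the readiness signal can fire. The same subscribe-then-act ordering with a bare future (illustrative only):

import asyncio


async def main() -> None:
    loop = asyncio.get_running_loop()
    ready: asyncio.Future[None] = loop.create_future()

    # The waiter exists before the signal fires, so the result is never missed.
    loop.call_soon(ready.set_result, None)

    await asyncio.wait_for(ready, timeout=1)


asyncio.run(main())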
66 changes: 66 additions & 0 deletions warehouse/metrics_tools/compute/test_service.py
@@ -0,0 +1,66 @@
from metrics_tools.compute.types import (
    ClusterStartRequest,
    ExportReference,
    ExportType,
    QueryJobStatus,
    QueryJobSubmitRequest,
)
from metrics_tools.definition import PeerMetricDependencyRef
import pytest
import asyncio

from metrics_tools.compute.service import MetricsCalculationService
from metrics_tools.compute.cluster import ClusterManager, LocalClusterFactory
from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter
from datetime import datetime


@pytest.mark.asyncio
async def test_metrics_calculation_service():
    service = MetricsCalculationService.setup(
        "someid",
        "bucket",
        "result_path_prefix",
        ClusterManager.with_dummy_metrics_plugin(LocalClusterFactory()),
        await CacheExportManager.setup(FakeExportAdapter()),
    )
    await service.start_cluster(ClusterStartRequest(min_size=1, max_size=1))
    await service.add_existing_exported_table_references(
        {
            "source.table123": ExportReference(
                table="export_table123",
                type=ExportType.GCS,
                payload={"gcs_path": "gs://bucket/result_path_prefix/export_table123"},
            ),
        }
    )
    response = await service.submit_job(
        QueryJobSubmitRequest(
            query_str="SELECT * FROM ref.table123",
            start=datetime(2021, 1, 1),
            end=datetime(2021, 1, 3),
            dialect="duckdb",
            batch_size=1,
            columns=[("col1", "int"), ("col2", "string")],
            ref=PeerMetricDependencyRef(
                name="test",
                entity_type="artifact",
                window=30,
                unit="day",
            ),
            locals={},
            dependent_tables_map={"source.table123": "source.table123"},
        )
    )

    async def wait_for_job_to_complete():
        status = await service.get_job_status(response.job_id)
        while status.status in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]:
            status = await service.get_job_status(response.job_id)
            await asyncio.sleep(1)

    await asyncio.wait_for(asyncio.create_task(wait_for_job_to_complete()), timeout=60)
    status = await service.get_job_status(response.job_id)
    assert status.status == QueryJobStatus.COMPLETED

    await service.close()
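The completion wait layers two timeouts: the helper polls job status on a 1-second cadence, while asyncio.wait_for bounds the whole wait at 60 seconds. The same poll-until-terminal shape in isolation (hypothetical names):

import asyncio


async def poll_until(predicate, interval: float = 0.05) -> None:
    # Re-check on a fixed cadence; the caller bounds the total wait.
    while not predicate():
        await asyncio.sleep(interval)


async def main() -> None:
    state = {"done": False}
    asyncio.get_running_loop().call_later(0.1, lambda: state.update(done=True))
    await asyncio.wait_for(poll_until(lambda: state["done"]), timeout=5)


asyncio.run(main())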
