From 76ae3d609818c1f8030d9367ea0793fad461d7f2 Mon Sep 17 00:00:00 2001
From: Reuven Gonzales
Date: Tue, 17 Dec 2024 12:01:08 -0800
Subject: [PATCH] More sqlmesh run tweaks (#2657)

* allow job retries to be configurable
* set flux to this branch
* temp uninstall
* reinstall
* ensure that we use scopes to resolve table deps
* go spot or go home
* more tests
* Improve progress reporting and error logging
* restore main
---
 ops/tf-modules/warehouse-cluster/main.tf | 2 +-
 warehouse/metrics_tools/compute/client.py | 6 +
 warehouse/metrics_tools/compute/service.py | 294 ++++++++++++-------
 warehouse/metrics_tools/compute/test_app.py | 3 +-
 warehouse/metrics_tools/compute/types.py | 131 +++++++--
 warehouse/metrics_tools/factory/factory.py | 1 +
 warehouse/metrics_tools/utils/tables.py | 41 ++-
 warehouse/metrics_tools/utils/test_tables.py | 44 +++
 8 files changed, 376 insertions(+), 146 deletions(-)

diff --git a/ops/tf-modules/warehouse-cluster/main.tf b/ops/tf-modules/warehouse-cluster/main.tf
index ac14045e..2589bc87 100644
--- a/ops/tf-modules/warehouse-cluster/main.tf
+++ b/ops/tf-modules/warehouse-cluster/main.tf
@@ -164,7 +164,7 @@ locals {
       max_count = 20
       local_ssd_count = 0
       local_ssd_ephemeral_storage_count = 2
-      spot = false
+      spot = true
       disk_size_gb = 100
       disk_type = "pd-standard"
       image_type = "COS_CONTAINERD"
diff --git a/warehouse/metrics_tools/compute/client.py b/warehouse/metrics_tools/compute/client.py
index 5a9c6799..1da49f1e 100644
--- a/warehouse/metrics_tools/compute/client.py
+++ b/warehouse/metrics_tools/compute/client.py
@@ -190,6 +190,12 @@ def _handler(response: JobStatusResponse):
         if final_status.status == QueryJobStatus.FAILED:
             self.logger.error(f"job[{job_id}] failed with status {final_status.status}")
+            if final_status.exceptions:
+                self.logger.error(f"job[{job_id}] failed with exceptions")
+
+                for exc in final_status.exceptions:
+                    self.logger.error(f"job[{job_id}] failed with exception {exc}")
+
             raise Exception(f"job[{job_id}] failed with status {final_status.status}")
         self.logger.info(f"job[{job_id}] completed with status {final_status.status}")
diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py
index 05d6a186..378c0fd3 100644
--- a/warehouse/metrics_tools/compute/service.py
+++ b/warehouse/metrics_tools/compute/service.py
@@ -8,7 +8,7 @@
 import uuid
 from datetime import datetime

-from dask.distributed import CancelledError, Future
+from dask.distributed import CancelledError
 from metrics_tools.compute.result import DBImportAdapter
 from metrics_tools.compute.worker import execute_duckdb_load
 from metrics_tools.runner import FakeEngineAdapter, MetricsRunner
@@ -25,10 +25,13 @@
     JobStatusResponse,
     JobSubmitRequest,
     JobSubmitResponse,
-    QueryJobProgress,
     QueryJobState,
+    QueryJobStateUpdate,
     QueryJobStatus,
+    QueryJobTaskStatus,
+    QueryJobTaskUpdate,
     QueryJobUpdate,
+    QueryJobUpdateScope,
 )

 logger = logging.getLogger(__name__)
@@ -119,7 +122,7 @@ async def handle_query_job_submit_request(
             )
         except Exception as e:
             self.logger.error(f"job[{job_id}] failed with exception: {e}")
-            await self._notify_job_failed(job_id, 0, 0)
+            await self._notify_job_failed(job_id, False, e)

     async def _handle_query_job_submit_request(
         self,
@@ -138,7 +141,7 @@ async def _handle_query_job_submit_request(
             exported_dependent_tables_map = await self.resolve_dependent_tables(input)
         except Exception as e:
             self.logger.error(f"job[{job_id}] failed to export dependencies: {e}")
-            await self._notify_job_failed(job_id, 0, 0)
+            await self._notify_job_failed(job_id, False, e)
             return

         self.logger.debug(f"job[{job_id}] dependencies exported")
@@ -147,27 +150,30 @@ async def _handle_query_job_submit_request(
         )

         total = len(tasks)
-        completed = 0
+        if total != input.batch_count():
+            self.logger.warning(f"job[{job_id}] batch count mismatch")

-        # In the future we should replace this with the python 3.13 version of
-        # this.
-        try:
-            await self._monitor_query_task_progress(job_id, tasks)
-        except JobTasksFailed as e:
-            exceptions = e.exceptions
-            self.logger.error(e)
-            await self._notify_job_failed(job_id, completed, total)
-            if len(exceptions) > 0:
-                for e in exceptions:
-                    self.logger.error(f"job[{job_id}] exception received: {e}")
-            raise e
+        exceptions = []
+
+        for next_task in asyncio.as_completed(tasks):
+            try:
+                await next_task
+            except Exception as e:
+                self.logger.error(
+                    f"job[{job_id}] task failed with uncaught exception: {e}"
+                )
+                exceptions.append(e)
+                await self._notify_job_failed(job_id, True, e)

         # Import the final result into the database
         self.logger.info("job[{job_id}]: importing final result into the database")
         await self.import_adapter.import_reference(calculation_export, final_export)

         self.logger.debug(f"job[{job_id}]: notifying job completed")
-        await self._notify_job_completed(job_id, completed, total)
+        await self._notify_job_completed(job_id)
+
+        if len(exceptions) > 0:
+            raise JobTasksFailed(job_id, len(exceptions), exceptions)

     async def _batch_query_to_scheduler(
         self,
@@ -177,68 +183,70 @@ async def _batch_query_to_scheduler(
         exported_dependent_tables_map: t.Dict[str, ExportReference],
     ):
         """Given a query job: break down into batches and submit to the scheduler"""
-        tasks: t.List[Future] = []
-        client = await self.cluster_manager.client
+        tasks: t.List[asyncio.Task] = []
+        count = 0

         async for batch_id, batch in self.generate_query_batches(
             input, input.batch_size
         ):
+            if count == 0:
+                await self._notify_job_running(job_id)
+
             task_id = f"{job_id}-{batch_id}"
             result_path = os.path.join(result_path_base, f"{batch_id}.parquet")

             self.logger.debug(f"job[{job_id}]: Submitting task {task_id}")

-            # dependencies = {
-            #     table: to_jsonable_python(reference)
-            #     for table, reference in exported_dependent_tables_map.items()
-            # }
-
-            task = client.submit(
-                execute_duckdb_load,
-                job_id,
-                task_id,
-                result_path,
-                batch,
-                exported_dependent_tables_map,
-                retries=input.retries,
+            task = asyncio.create_task(
+                self._submit_query_task_to_scheduler(
+                    job_id,
+                    task_id,
+                    result_path,
+                    batch,
+                    exported_dependent_tables_map,
+                    retries=3,
+                )
             )
+            tasks.append(task)

             self.logger.debug(f"job[{job_id}]: Submitted task {task_id}")
-            tasks.append(task)
+            count += 1

         return tasks

-    async def _monitor_query_task_progress(self, job_id: str, tasks: t.List[Future]):
-        total = len(tasks)
-        completed = 0
-        failures = 0
-        exceptions = []
+    async def _submit_query_task_to_scheduler(
+        self,
+        job_id: str,
+        task_id: str,
+        result_path: str,
+        batch: t.List[str],
+        exported_dependent_tables_map: t.Dict[str, ExportReference],
+        retries: int,
+    ):
+        """Submit a single query task to the scheduler"""
+        client = await self.cluster_manager.client

-        # In the future we should replace this with the python 3.13 version of
-        # this.
- for finished in asyncio.as_completed(tasks): - try: - task_id = await finished - completed += 1 - self.logger.info(f"job[{job_id}] progress: {completed}/{total}") - await self._notify_job_updated(job_id, completed, total) - self.logger.debug( - f"job[{job_id}] finished notifying update: {completed}/{total}" - ) - except CancelledError as e: - failures += 1 - self.logger.error(f"job[{job_id}] task cancelled {e.args}") - continue - except Exception as e: - failures += 1 - exceptions.append(e) - self.logger.error(f"job[{job_id}] task failed with exception: {e}") - continue - self.logger.debug(f"job[{job_id}] awaiting finished") + task_future = client.submit( + execute_duckdb_load, + job_id, + task_id, + result_path, + batch, + exported_dependent_tables_map, + retries=retries, + key=task_id, + ) - await self._notify_job_updated(job_id, completed, total) - self.logger.info(f"job[{job_id}] task_id={task_id} finished") - if failures > 0: - raise JobTasksFailed(job_id, failures, exceptions) + try: + await task_future + self.logger.info(f"job[{job_id}] task_id={task_id} completed") + await self._notify_job_task_completed(job_id, task_id) + except CancelledError as e: + self.logger.error(f"job[{job_id}] task cancelled {e.args}") + await self._notify_job_task_cancelled(job_id, task_id) + except Exception as e: + self.logger.error(f"job[{job_id}] task failed with exception: {e}") + await self._notify_job_task_failed(job_id, task_id, e) + return task_id async def close(self): await self.cluster_manager.close() @@ -287,7 +295,7 @@ async def submit_job(self, input: JobSubmitRequest): calculation_export ) - await self._notify_job_pending(job_id, 1) + await self._notify_job_pending(job_id, input) task = asyncio.create_task( self.handle_query_job_submit_request( job_id, @@ -305,79 +313,139 @@ async def submit_job(self, input: JobSubmitRequest): export_reference=final_expected_reference, ) - async def _notify_job_pending(self, job_id: str, total: int): - await self._set_job_state( + async def _notify_job_pending(self, job_id: str, input: JobSubmitRequest): + await self._create_job_state( + job_id, + input, + ) + + async def _notify_job_running(self, job_id: str): + await self._update_job_state( + job_id, + QueryJobUpdate( + time=datetime.now(), + scope=QueryJobUpdateScope.JOB, + payload=QueryJobStateUpdate( + status=QueryJobStatus.RUNNING, + has_remaining_tasks=True, + ), + ), + ) + + async def _notify_job_task_completed(self, job_id: str, task_id: str): + await self._update_job_state( + job_id, + QueryJobUpdate( + time=datetime.now(), + scope=QueryJobUpdateScope.TASK, + payload=QueryJobTaskUpdate( + task_id=task_id, + status=QueryJobTaskStatus.SUCCEEDED, + ), + ), + ) + + async def _notify_job_task_failed( + self, job_id: str, task_id: str, exception: Exception + ): + await self._update_job_state( job_id, QueryJobUpdate( - updated_at=datetime.now(), - status=QueryJobStatus.PENDING, - progress=QueryJobProgress(completed=0, total=total), + time=datetime.now(), + scope=QueryJobUpdateScope.TASK, + payload=QueryJobTaskUpdate( + task_id=task_id, + status=QueryJobTaskStatus.FAILED, + exception=str(exception), + ), ), ) - async def _notify_job_updated(self, job_id: str, completed: int, total: int): - await self._set_job_state( + async def _notify_job_task_cancelled(self, job_id: str, task_id: str): + await self._update_job_state( job_id, QueryJobUpdate( - updated_at=datetime.now(), - status=QueryJobStatus.RUNNING, - progress=QueryJobProgress(completed=completed, total=total), + time=datetime.now(), + 
scope=QueryJobUpdateScope.TASK, + payload=QueryJobTaskUpdate( + task_id=task_id, + status=QueryJobTaskStatus.CANCELLED, + ), ), ) - async def _notify_job_completed(self, job_id: str, completed: int, total: int): - await self._set_job_state( + async def _notify_job_completed(self, job_id: str): + await self._update_job_state( job_id, QueryJobUpdate( - updated_at=datetime.now(), - status=QueryJobStatus.COMPLETED, - progress=QueryJobProgress(completed=completed, total=total), + time=datetime.now(), + scope=QueryJobUpdateScope.JOB, + payload=QueryJobStateUpdate( + status=QueryJobStatus.COMPLETED, + has_remaining_tasks=False, + ), ), ) - async def _notify_job_failed(self, job_id: str, completed: int, total: int): - await self._set_job_state( + async def _notify_job_failed( + self, + job_id: str, + has_remaining_tasks: bool, + exception: t.Optional[Exception] = None, + ): + await self._update_job_state( job_id, QueryJobUpdate( - updated_at=datetime.now(), - status=QueryJobStatus.FAILED, - progress=QueryJobProgress(completed=completed, total=total), + time=datetime.now(), + scope=QueryJobUpdateScope.JOB, + payload=QueryJobStateUpdate( + status=QueryJobStatus.FAILED, + has_remaining_tasks=has_remaining_tasks, + exception=str(exception) if exception else None, + ), ), ) - async def _set_job_state( + async def _create_job_state(self, job_id: str, input: JobSubmitRequest): + async with self.job_state_lock: + now = datetime.now() + self.job_state[job_id] = QueryJobState( + job_id=job_id, + created_at=now, + tasks_count=input.batch_count(), + updates=[ + QueryJobUpdate( + time=now, + scope=QueryJobUpdateScope.JOB, + payload=QueryJobStateUpdate( + status=QueryJobStatus.PENDING, + has_remaining_tasks=True, + ), + ) + ], + ) + + state = self.job_state[job_id] + self.emit_job_state(job_id, state) + + async def _update_job_state( self, job_id: str, update: QueryJobUpdate, ): - self.logger.debug(f"job[{job_id}] status={update.status}") + self.logger.debug(f"job[{job_id}] status={update.payload.status}") async with self.job_state_lock: - if update.status == QueryJobStatus.PENDING: - self.job_state[job_id] = QueryJobState( - job_id=job_id, - created_at=update.updated_at, - updates=[update], - ) - else: - state = self.job_state.get(job_id) - if not state: - raise ValueError(f"Job {job_id} not found") - - state.updates.append(update) - self.job_state[job_id] = state - - if ( - update.status == QueryJobStatus.COMPLETED - or update.status == QueryJobStatus.FAILED - ): - del self.job_tasks[job_id] - updated_state = copy.deepcopy(self.job_state[job_id]) - - self.logger.info("emitting job update events") - # Some things listen to all job updates - self.emitter.emit("job_update", job_id, updated_state) - # Some things listen to specific job updates - self.emitter.emit(f"job_update:{job_id}", updated_state) + state = self.job_state.get(job_id) + assert state is not None, f"job[{job_id}] not found" + state.update(update) + self.job_state[job_id] = state + self.emit_job_state(job_id, state) + + def emit_job_state(self, job_id: str, state: QueryJobState): + copied_state = copy.deepcopy(state) + self.logger.info("emitting job update events") + self.emitter.emit("job_update", job_id, copied_state) + self.emitter.emit(f"job_update:{job_id}", copied_state) async def _get_job_state(self, job_id: str): """Get the current state of a job as a deep copy (to prevent diff --git a/warehouse/metrics_tools/compute/test_app.py b/warehouse/metrics_tools/compute/test_app.py index 207b7e29..00983b9a 100644 --- 
a/warehouse/metrics_tools/compute/test_app.py +++ b/warehouse/metrics_tools/compute/test_app.py @@ -102,5 +102,6 @@ def test_app_with_all_debugging(app_client_with_all_debugging): progress_handler=mock_handler, ) - assert mock_handler.call_count == 6 + # The pending to running update, and the 3 completion updates + assert mock_handler.call_count == 4 assert reference is not None diff --git a/warehouse/metrics_tools/compute/types.py b/warehouse/metrics_tools/compute/types.py index 762d7c99..13d6c21b 100644 --- a/warehouse/metrics_tools/compute/types.py +++ b/warehouse/metrics_tools/compute/types.py @@ -1,4 +1,5 @@ import logging +import math import typing as t from datetime import datetime from enum import Enum @@ -107,15 +108,46 @@ class QueryJobStatus(str, Enum): FAILED = "failed" +class QueryJobTaskStatus(str, Enum): + SUCCEEDED = "SUCCEEDED" + CANCELLED = "cancelled" + FAILED = "failed" + + +class QueryJobUpdateScope(str, Enum): + TASK = "task" + JOB = "job" + + class QueryJobProgress(BaseModel): completed: int total: int -class QueryJobUpdate(BaseModel): - updated_at: datetime +class QueryJobTaskUpdate(BaseModel): + type: t.Literal[QueryJobUpdateScope.TASK] = QueryJobUpdateScope.TASK + status: QueryJobTaskStatus + task_id: str + exception: t.Optional[str] = None + + +class QueryJobStateUpdate(BaseModel): + type: t.Literal[QueryJobUpdateScope.JOB] = QueryJobUpdateScope.JOB status: QueryJobStatus - progress: QueryJobProgress + has_remaining_tasks: bool + exception: t.Optional[str] = None + + +QueryJobUpdateTypes = t.Union[ + QueryJobTaskUpdate, + QueryJobStateUpdate, +] + + +class QueryJobUpdate(BaseModel): + time: datetime + scope: QueryJobUpdateScope + payload: QueryJobUpdateTypes = Field(discriminator="type") class ClusterStatus(BaseModel): @@ -151,6 +183,15 @@ def query_as(self, dialect: str) -> str: def columns_def(self) -> ColumnsDefinition: return ColumnsDefinition(columns=self.columns, dialect=self.dialect) + def batch_count(self): + """The expected number of batches for this job" + + This is calculated by getting the range (inclusive) between the start + and end and dividing by the batch size. 
+ """ + inclusive_day_length = (self.end - self.start).days + 1 + return math.ceil(inclusive_day_length / self.batch_size) + class JobSubmitResponse(BaseModel): type: t.Literal["JobSubmitResponse"] = "JobSubmitResponse" @@ -166,17 +207,45 @@ class JobStatusResponse(BaseModel): status: QueryJobStatus progress: QueryJobProgress stats: t.Dict[str, float] = Field(default_factory=dict) + exceptions: t.List[str] = Field(default_factory=list) class QueryJobState(BaseModel): job_id: str created_at: datetime + tasks_count: int + tasks_completed: int = 0 + has_remaining_tasks: bool = True + status: QueryJobStatus = QueryJobStatus.PENDING updates: t.List[QueryJobUpdate] def latest_update(self) -> QueryJobUpdate: return self.updates[-1] - def as_response(self, include_stats: bool = False) -> JobStatusResponse: + def update(self, update: QueryJobUpdate): + """Add an update to the job state and change any relevant job state""" + self.updates.append(update) + if update.scope == QueryJobUpdateScope.JOB: + payload = t.cast(QueryJobStateUpdate, update.payload) + if payload.status == QueryJobStatus.COMPLETED: + if self.status != QueryJobStatus.FAILED: + self.status = QueryJobStatus.COMPLETED + self.has_remaining_tasks = False + elif payload.status == QueryJobStatus.FAILED: + self.has_remaining_tasks = payload.has_remaining_tasks + elif payload.status == QueryJobStatus.RUNNING: + self.status = payload.status + else: + payload = t.cast(QueryJobTaskUpdate, update.payload) + if payload.status == QueryJobTaskStatus.FAILED: + self.status = QueryJobStatus.FAILED + elif payload.status == QueryJobTaskStatus.CANCELLED: + self.status = QueryJobStatus.FAILED + self.tasks_completed += 1 + + def as_response( + self, include_stats: bool = False, include_exceptions_count: int = 5 + ) -> JobStatusResponse: # Turn update events into stats stats = {} if include_stats: @@ -186,20 +255,22 @@ def as_response(self, include_stats: bool = False) -> JobStatusResponse: running_to_failed = None for update in self.updates: - if ( - update.status == QueryJobStatus.RUNNING - and pending_to_running is None - ): - pending_to_running = update.updated_at - elif ( - update.status == QueryJobStatus.COMPLETED - and running_to_completed is None - ): - running_to_completed = update.updated_at - elif ( - update.status == QueryJobStatus.FAILED and running_to_failed is None - ): - running_to_failed = update.updated_at + match update.scope: + case QueryJobUpdateScope.JOB: + if update.payload.status == QueryJobStatus.RUNNING: + pending_to_running = update.time + elif update.payload.status == QueryJobStatus.COMPLETED: + running_to_completed = update.time + elif update.payload.status == QueryJobStatus.FAILED: + if running_to_failed is None: + running_to_failed = update.time + case QueryJobUpdateScope.TASK: + if update.payload.status == QueryJobTaskStatus.FAILED: + if running_to_failed is None: + running_to_failed = update.time + elif update.payload.status == QueryJobTaskStatus.CANCELLED: + if running_to_failed is None: + running_to_failed = update.time if pending_to_running: stats["pending_to_running_seconds"] = ( @@ -217,14 +288,32 @@ def as_response(self, include_stats: bool = False) -> JobStatusResponse: if pending_to_running else None ) + exceptions: t.List[str] = [] + if self.status == QueryJobStatus.FAILED: + for update in reversed(self.updates): + if update.scope == QueryJobUpdateScope.TASK: + payload = t.cast(QueryJobTaskUpdate, update.payload) + if payload.status == QueryJobTaskStatus.FAILED: + if payload.exception: + 
exceptions.append(payload.exception) + else: + payload = t.cast(QueryJobStateUpdate, update.payload) + if payload.exception: + exceptions.append(payload.exception) + if len(exceptions) >= include_exceptions_count: + break return JobStatusResponse( job_id=self.job_id, created_at=self.created_at, - updated_at=self.latest_update().updated_at, - status=self.latest_update().status, - progress=self.latest_update().progress, + updated_at=self.latest_update().time, + status=self.status, + progress=QueryJobProgress( + completed=self.tasks_completed, + total=self.tasks_count, + ), stats=stats, + exceptions=exceptions, ) diff --git a/warehouse/metrics_tools/factory/factory.py b/warehouse/metrics_tools/factory/factory.py index 357690e1..757d437d 100644 --- a/warehouse/metrics_tools/factory/factory.py +++ b/warehouse/metrics_tools/factory/factory.py @@ -682,6 +682,7 @@ def generated_rolling_query( dependent_tables_map=create_dependent_tables_map( context, rendered_query_str ), + job_retries=env.ensure_int("SQLMESH_MCS_JOB_RETRIES", 5), cluster_min_size=env.ensure_int("SQLMESH_MCS_CLUSTER_MIN_SIZE", 0), cluster_max_size=env.ensure_int("SQLMESH_MCS_CLUSTER_MAX_SIZE", 30), ) diff --git a/warehouse/metrics_tools/utils/tables.py b/warehouse/metrics_tools/utils/tables.py index e0104203..798a7734 100644 --- a/warehouse/metrics_tools/utils/tables.py +++ b/warehouse/metrics_tools/utils/tables.py @@ -1,6 +1,7 @@ import typing as t from sqlglot import exp +from sqlglot.optimizer.scope import Scope, build_scope from sqlmesh import ExecutionContext from sqlmesh.core.dialect import parse_one @@ -13,19 +14,39 @@ def resolve_identifier_or_string(i: exp.Expression | str) -> t.Optional[str]: return None +def resolve_table_name( + context: ExecutionContext, table: exp.Table +) -> t.Tuple[str, str]: + table_name_parts = map(resolve_identifier_or_string, table.parts) + table_name_parts = filter(None, table_name_parts) + table_fqn = ".".join(table_name_parts) + + return (table_fqn, context.table(table_fqn)) + + +def resolve_table_map_from_scope( + context: ExecutionContext, scope: Scope +) -> t.Dict[str, str]: + current_tables_map = {} + for source_name, source in scope.sources.items(): + if isinstance(source, exp.Table): + local_name, actual_name = resolve_table_name(context, source) + current_tables_map[local_name] = actual_name + elif isinstance(source, Scope): + parent_tables_map = resolve_table_map_from_scope(context, source) + current_tables_map.update(parent_tables_map) + else: + raise ValueError(f"Unsupported source type: {type(source)}") + return current_tables_map + + def create_dependent_tables_map( context: ExecutionContext, query_str: str ) -> t.Dict[str, str]: query = parse_one(query_str) - tables = query.find_all(exp.Table) - - tables_map: t.Dict[str, str] = {} - - for table in tables: - table_name_parts = map(resolve_identifier_or_string, table.parts) - table_name_parts = filter(None, table_name_parts) - table_fqn = ".".join(table_name_parts) - - tables_map[table_fqn] = context.table(table_fqn) + scope = build_scope(query) + if not scope: + raise ValueError("Failed to build scope") + tables_map = resolve_table_map_from_scope(context, scope) return tables_map diff --git a/warehouse/metrics_tools/utils/test_tables.py b/warehouse/metrics_tools/utils/test_tables.py index a448a871..7d4c0a74 100644 --- a/warehouse/metrics_tools/utils/test_tables.py +++ b/warehouse/metrics_tools/utils/test_tables.py @@ -1,5 +1,7 @@ +import typing as t from unittest.mock import MagicMock +import pytest from 
metrics_tools.utils.tables import create_dependent_tables_map
@@ -12,3 +14,45 @@ def test_create_dependent_tables_map():
         "foo": "test_table",
     }
     assert actual_tables_map == expected_tables_map
+
+
+@pytest.mark.parametrize(
+    "input,expected",
+    [
+        ("select * from foo.bar", {"foo.bar": "test_table"}),
+        ("select * from foo", {"foo": "test_table"}),
+        (
+            "select * from foo.bar, bar.foo",
+            {"foo.bar": "test_table", "bar.foo": "test_table"},
+        ),
+        (
+            """
+            with foo as (
+                select * from bar
+            )
+            select * from foo
+            """,
+            {"bar": "test_table"},
+        ),
+        (
+            """
+            with grandfoo as (
+                select * from main.source
+            ),
+            foo as (
+                select * from grandfoo
+            )
+            select * from foo
+            """,
+            {"main.source": "test_table"},
+        ),
+    ],
+)
+def test_create_dependent_tables_map_parameterized(
+    input: str, expected: t.Dict[str, str]
+):
+    mock = MagicMock(name="context")
+    mock.table.return_value = "test_table"
+
+    actual_tables_map = create_dependent_tables_map(mock, input)
+    assert actual_tables_map == expected
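
Note: the patch relies on the new JobSubmitRequest.batch_count() both for the task-count sanity check in the service and for the progress totals reported to clients. A minimal standalone sketch of that calculation, assuming plain datetime arguments rather than the actual Pydantic model (the function name here simply mirrors the method added in types.py):

import math
from datetime import datetime

def batch_count(start: datetime, end: datetime, batch_size: int) -> int:
    # Inclusive day range: a job covering 2024-01-01..2024-01-10 spans 10 days.
    inclusive_day_length = (end - start).days + 1
    # Round up so a final partial batch still counts as one expected task.
    return math.ceil(inclusive_day_length / batch_size)

# 10 days split into batches of 3 days -> 4 expected tasks (3 + 3 + 3 + 1).
assert batch_count(datetime(2024, 1, 1), datetime(2024, 1, 10), 3) == 4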