From 9de7648fc91b0b435f8fa2f034bb855da66fdfcc Mon Sep 17 00:00:00 2001 From: Reuven Gonzales Date: Thu, 12 Dec 2024 14:01:35 -0800 Subject: [PATCH] SQLMesh + Metrics Calculation Service (#2628) * random cleanup * server clean up for mcs * Ready for testing in a real deployment * clean up * further cleanup before deployment * love me some docs * re-enable all metrics * remove custom materialization * small fix --- poetry.lock | 12 +- pyproject.toml | 1 + .../metrics_mesh/models/metrics_factories.py | 25 +-- warehouse/metrics_tools/compute/cache.py | 48 ++-- warehouse/metrics_tools/compute/client.py | 1 + warehouse/metrics_tools/compute/cluster.py | 4 +- warehouse/metrics_tools/compute/constants.py | 13 +- .../metrics_tools/compute/log_config.yaml | 43 ++++ .../compute/manual_testing_utils.py | 7 +- warehouse/metrics_tools/compute/result.py | 205 +++++++++++++++++ warehouse/metrics_tools/compute/server.py | 99 +++++---- warehouse/metrics_tools/compute/service.py | 137 +++++++++--- warehouse/metrics_tools/compute/test_cache.py | 13 +- .../metrics_tools/compute/test_service.py | 9 +- warehouse/metrics_tools/compute/types.py | 80 ++++++- warehouse/metrics_tools/compute/worker.py | 2 +- warehouse/metrics_tools/definition.py | 9 - warehouse/metrics_tools/factory/factory.py | 206 ++++++++++-------- warehouse/metrics_tools/models.py | 7 +- 19 files changed, 674 insertions(+), 247 deletions(-) create mode 100644 warehouse/metrics_tools/compute/log_config.yaml create mode 100644 warehouse/metrics_tools/compute/result.py diff --git a/poetry.lock b/poetry.lock index 47cd26bf4..c293da683 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7053,20 +7053,20 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.32.0" +version = "0.32.1" description = "The lightning-fast ASGI server." 
optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82"}, - {file = "uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e"}, + {file = "uvicorn-0.32.1-py3-none-any.whl", hash = "sha256:82ad92fd58da0d12af7482ecdb5f2470a04c9c9a53ced65b9bbb4a205377602e"}, + {file = "uvicorn-0.32.1.tar.gz", hash = "sha256:ee9519c246a72b1c084cea8d3b44ed6026e78a4a309cbedae9c37e4cb9fbb175"}, ] [package.dependencies] click = ">=7.0" colorama = {version = ">=0.4", optional = true, markers = "sys_platform == \"win32\" and extra == \"standard\""} h11 = ">=0.8" -httptools = {version = ">=0.5.0", optional = true, markers = "extra == \"standard\""} +httptools = {version = ">=0.6.3", optional = true, markers = "extra == \"standard\""} python-dotenv = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} pyyaml = {version = ">=5.1", optional = true, markers = "extra == \"standard\""} uvloop = {version = ">=0.14.0,<0.15.0 || >0.15.0,<0.15.1 || >0.15.1", optional = true, markers = "(sys_platform != \"win32\" and sys_platform != \"cygwin\") and platform_python_implementation != \"PyPy\" and extra == \"standard\""} @@ -7074,7 +7074,7 @@ watchfiles = {version = ">=0.13", optional = true, markers = "extra == \"standar websockets = {version = ">=10.4", optional = true, markers = "extra == \"standard\""} [package.extras] -standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] +standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] [[package]] name = "uvloop" @@ -7699,4 +7699,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.12,<3.13" -content-hash = "028048c5dc20685000a5eb0b290d69be45f247ccd606b2bb06eb15e55b828c2a" +content-hash = "26ed7870bfdb868a65db3093bcf04bf48ec2bb9cf4c7c9a766b0ce9f8a237ceb" diff --git a/pyproject.toml b/pyproject.toml index bf63fba0d..2b86735f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ pyee = "^12.1.1" aiotrino = "^0.2.3" pytest-asyncio = "^0.24.0" isort = "^5.13.2" +uvicorn = { extras = ["standard"], version = "^0.32.1" } [tool.poetry.scripts] diff --git a/warehouse/metrics_mesh/models/metrics_factories.py b/warehouse/metrics_mesh/models/metrics_factories.py index 808f29240..2af63c518 100644 --- a/warehouse/metrics_mesh/models/metrics_factories.py +++ b/warehouse/metrics_mesh/models/metrics_factories.py @@ -1,8 +1,4 @@ -from metrics_tools.factory import ( - timeseries_metrics, - MetricQueryDef, - RollingConfig, -) +from metrics_tools.factory import MetricQueryDef, RollingConfig, timeseries_metrics timeseries_metrics( start="2015-01-01", @@ -148,7 +144,7 @@ "commits_rolling": MetricQueryDef( ref="commits.sql", rolling=RollingConfig( - windows=[180], + windows=[10], unit="day", cron="@daily", ), @@ -261,23 +257,6 @@ ), entity_types=["artifact", "project", "collection"], ), - # "libin": MetricQueryDef( - # ref="libin.sql", - # vars={ - # "activity_event_types": [ - # "COMMIT_CODE", - # "ISSUE_OPENED", - # "PULL_REQUEST_OPENED", - # "PULL_REQUEST_MERGED", - # ], - # }, - # rolling=RollingConfig( - # windows=[30, 90, 180], - # unit="day", - # cron="@daily", - # ), - # entity_types=["artifact"], - # ), 
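# For reference, a rolling metric entry in this factory call takes roughly the
# following shape; the metric name and SQL file below are illustrative and are
# not part of this change.
from metrics_tools.factory import MetricQueryDef, RollingConfig

example_rolling_metric = MetricQueryDef(
    ref="example.sql",  # hypothetical query file resolved from the queries directory
    rolling=RollingConfig(
        windows=[30, 90, 180],
        unit="day",
        cron="@daily",
    ),
    entity_types=["artifact", "project", "collection"],
)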
"funding_received": MetricQueryDef( ref="funding_received.sql", rolling=RollingConfig( diff --git a/warehouse/metrics_tools/compute/cache.py b/warehouse/metrics_tools/compute/cache.py index 8ca6356dd..2f51f7f4c 100644 --- a/warehouse/metrics_tools/compute/cache.py +++ b/warehouse/metrics_tools/compute/cache.py @@ -5,6 +5,7 @@ import queue import typing as t import uuid +from datetime import datetime from aiotrino.dbapi import Connection from pydantic import BaseModel @@ -12,7 +13,7 @@ from sqlglot import exp from sqlmesh.core.dialect import parse_one -from .types import ExportReference, ExportType +from .types import ColumnsDefinition, ExportReference, ExportType logger = logging.getLogger(__name__) @@ -23,11 +24,14 @@ class ExportCacheCompletedQueueItem(BaseModel): class ExportCacheQueueItem(BaseModel): + execution_time: datetime table: str class DBExportAdapter(abc.ABC): - async def export_table(self, table: str) -> ExportReference: + async def export_table( + self, table: str, execution_time: datetime + ) -> ExportReference: raise NotImplementedError() async def clean_export_table(self, table: str): @@ -38,10 +42,15 @@ class FakeExportAdapter(DBExportAdapter): def __init__(self, log_override: t.Optional[logging.Logger] = None): self.logger = log_override or logger - async def export_table(self, table: str) -> ExportReference: + async def export_table( + self, table: str, execution_time: datetime + ) -> ExportReference: self.logger.info(f"fake exporting table: {table}") return ExportReference( - table=table, type=ExportType.GCS, payload={"gcs_path": "fake_path:{table}"} + table_name=table, + type=ExportType.GCS, + payload={"gcs_path": "fake_path:{table}"}, + columns=ColumnsDefinition(columns=[]), ) async def clean_export_table(self, table: str): @@ -63,7 +72,9 @@ def __init__( self.hive_schema = hive_schema self.logger = log_override or logger - async def export_table(self, table: str) -> ExportReference: + async def export_table( + self, table: str, execution_time: datetime + ) -> ExportReference: columns: t.List[t.Tuple[str, str]] = [] col_result = await self.run_query(f"SHOW COLUMNS FROM {table}") @@ -77,7 +88,9 @@ async def export_table(self, table: str) -> ExportReference: self.logger.debug(f"retrieved columns for {table} export: {columns}") export_table_name = f"export_{table_exp.this.this}_{uuid.uuid4().hex}" - gcs_path = f"gs://{self.gcs_bucket}/trino-export/{export_table_name}/" + # We make cleaning easier by using the execution time to allow listing + # of the export tables + gcs_path = f"gs://{self.gcs_bucket}/trino-export/{execution_time.strftime('%Y/%m/%d/%H')}/{export_table_name}/" # We use a little bit of a hybrid templating+sqlglot magic to generate # the create and insert queries. 
This saves us having to figure out the @@ -135,7 +148,10 @@ async def export_table(self, table: str) -> ExportReference: await self.run_query(insert_query.sql(dialect="trino")) return ExportReference( - table=table, type=ExportType.GCS, payload={"gcs_path": gcs_path} + table_name=table, + type=ExportType.GCS, + payload={"gcs_path": gcs_path}, + columns=ColumnsDefinition(columns=columns, dialect="trino"), ) async def run_query(self, query: str): @@ -237,9 +253,9 @@ async def stop(self): async def export_queue_loop(self): in_progress: t.Set[str] = set() - async def export_table(table: str): + async def export_table(table: str, execution_time: datetime): try: - return await self._export_table_for_cache(table) + return await self._export_table_for_cache(table, execution_time) except Exception as e: self.logger.error(f"Error exporting table {table}: {e}") in_progress.remove(table) @@ -253,7 +269,7 @@ async def export_table(table: str): # The table is already being exported. Skip this in the queue continue in_progress.add(item.table) - export_reference = await export_table(item.table) + export_reference = await export_table(item.table, item.execution_time) self.event_emitter.emit( "exported_table", table=item.table, export_reference=export_reference ) @@ -281,17 +297,19 @@ async def get_export_table_reference(self, table: str): return None return copy.deepcopy(reference) - async def _export_table_for_cache(self, table: str): + async def _export_table_for_cache(self, table: str, execution_time: datetime): """Triggers an export of a table to a cache location in GCS. This does this by using the Hive catalog in trino to create a new table with the same schema as the original table, but with a different name. This new table is then used as the cache location for the original table.""" - export_reference = await self.export_adapter.export_table(table) + export_reference = await self.export_adapter.export_table(table, execution_time) self.logger.info(f"exported table: {table} -> {export_reference}") return export_reference - async def resolve_export_references(self, tables: t.List[str]): + async def resolve_export_references( + self, tables: t.List[str], execution_time: datetime + ): """Resolves any required export table references or queues up a list of tables to be exported to a cache location. Once ready, the map of tables is resolved.""" @@ -331,5 +349,7 @@ async def handle_exported_table( "exported_table", handle_exported_table ) for table in tables_to_export: - self.export_queue.put_nowait(ExportCacheQueueItem(table=table)) + self.export_queue.put_nowait( + ExportCacheQueueItem(table=table, execution_time=execution_time) + ) return await future diff --git a/warehouse/metrics_tools/compute/client.py b/warehouse/metrics_tools/compute/client.py index aaaaca9fc..a4472db30 100644 --- a/warehouse/metrics_tools/compute/client.py +++ b/warehouse/metrics_tools/compute/client.py @@ -173,6 +173,7 @@ def submit_job( locals=locals, dependent_tables_map=dependent_tables_map, retries=retries, + execution_time=datetime.now(), ) job_response = self.service_post_with_input( QueryJobSubmitResponse, "/job/submit", request diff --git a/warehouse/metrics_tools/compute/cluster.py b/warehouse/metrics_tools/compute/cluster.py index 98dadc12d..8ee488a05 100644 --- a/warehouse/metrics_tools/compute/cluster.py +++ b/warehouse/metrics_tools/compute/cluster.py @@ -431,10 +431,10 @@ def make_new_cluster_with_defaults(): from . 
import constants return make_new_cluster( - f"{constants.cluster_worker_image_repo}:{constants.cluster_worker_image_tag}", + f"{constants.cluster_image_repo}:{constants.cluster_image_tag}", constants.cluster_name, constants.cluster_namespace, - threads=constants.cluster_worker_threads, + threads=constants.worker_threads, scheduler_memory_limit=constants.scheduler_memory_limit, scheduler_memory_request=constants.scheduler_memory_request, worker_memory_limit=constants.worker_memory_limit, diff --git a/warehouse/metrics_tools/compute/constants.py b/warehouse/metrics_tools/compute/constants.py index 17eeaa1c8..38ed37385 100644 --- a/warehouse/metrics_tools/compute/constants.py +++ b/warehouse/metrics_tools/compute/constants.py @@ -5,15 +5,15 @@ cluster_namespace = env.required_str("METRICS_CLUSTER_NAMESPACE") cluster_name = env.required_str("METRICS_CLUSTER_NAME") -cluster_worker_threads = env.required_int("METRICS_CLUSTER_WORKER_THREADS", 16) -cluster_worker_image_repo = env.required_str( - "METRICS_CLUSTER_WORKER_IMAGE_REPO", "ghcr.io/opensource-observer/dagster-dask" +cluster_image_repo = env.required_str( + "METRICS_CLUSTER_IMAGE_REPO", "ghcr.io/opensource-observer/dagster-dask" ) -cluster_worker_image_tag = env.required_str("METRICS_CLUSTER_WORKER_IMAGE_TAG") +cluster_image_tag = env.required_str("METRICS_CLUSTER_IMAGE_TAG") scheduler_memory_limit = env.required_str("METRICS_SCHEDULER_MEMORY_LIMIT", "90000Mi") scheduler_memory_request = env.required_str( "METRICS_SCHEDULER_MEMORY_REQUEST", "85000Mi" ) +worker_threads = env.required_int("METRICS_WORKER_THREADS", 16) worker_memory_limit = env.required_str("METRICS_WORKER_MEMORY_LIMIT", "90000Mi") worker_memory_request = env.required_str("METRICS_WORKER_MEMORY_REQUEST", "85000Mi") @@ -21,9 +21,7 @@ gcs_key_id = env.required_str("METRICS_GCS_KEY_ID") gcs_secret = env.required_str("METRICS_GCS_SECRET") -results_path_prefix = env.required_str( - "METRICS_GCS_RESULTS_PATH_PREFIX", "metrics-calc-service-results" -) +results_path_prefix = env.required_str("METRICS_GCS_RESULTS_PATH_PREFIX", "mcs-results") worker_duckdb_path = env.required_str("METRICS_WORKER_DUCKDB_PATH") @@ -36,6 +34,7 @@ hive_schema = env.required_str("METRICS_HIVE_SCHEMA", "export") debug_all = env.ensure_bool("METRICS_DEBUG_ALL", False) +debug_with_duckdb = env.ensure_bool("METRICS_DEBUG_WITH_DUCKDB", False) if not debug_all: debug_cache = env.ensure_bool("METRICS_DEBUG_CACHE", False) debug_cluster = env.ensure_bool("METRICS_DEBUG_CLUSTER", False) diff --git a/warehouse/metrics_tools/compute/log_config.yaml b/warehouse/metrics_tools/compute/log_config.yaml new file mode 100644 index 000000000..27a6aa17b --- /dev/null +++ b/warehouse/metrics_tools/compute/log_config.yaml @@ -0,0 +1,43 @@ +# Default log configuration for the metrics calculation service. This can be +# used by uvicorn Thanks to: +# https://gist.github.com/liviaerxin/d320e33cbcddcc5df76dd92948e5be3b for a +# starting point. 
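# A minimal sketch of wiring this file up from the command line, assuming the
# FastAPI app in server.py is importable as metrics_tools.compute.server:app:
#
#   uvicorn metrics_tools.compute.server:app --log-config log_config.yaml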
+version: 1 +disable_existing_loggers: False +formatters: + default: + # "()": uvicorn.logging.DefaultFormatter + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + access: + # "()": uvicorn.logging.AccessFormatter + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +handlers: + default: + formatter: default + class: logging.StreamHandler + stream: ext://sys.stderr + access: + formatter: access + class: logging.StreamHandler + stream: ext://sys.stdout +loggers: + uvicorn.error: + level: INFO + handlers: + - default + propagate: no + uvicorn.access: + level: INFO + handlers: + - access + propagate: no + metrics_tools: + level: DEBUG + handlers: + - default + propagate: no +root: + level: ERROR + handlers: + - default + propagate: no \ No newline at end of file diff --git a/warehouse/metrics_tools/compute/manual_testing_utils.py b/warehouse/metrics_tools/compute/manual_testing_utils.py index bfba44453..62232736c 100644 --- a/warehouse/metrics_tools/compute/manual_testing_utils.py +++ b/warehouse/metrics_tools/compute/manual_testing_utils.py @@ -14,6 +14,7 @@ from ..definition import PeerMetricDependencyRef from .types import ( ClusterStartRequest, + ColumnsDefinition, ExportedTableLoadRequest, ExportReference, ExportType, @@ -32,11 +33,12 @@ def run_cache_load(url: str): req = ExportedTableLoadRequest( map={ "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958": ExportReference( - table="export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", + table_name="export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", type=ExportType.GCS, payload={ "gcs_path": "gs://oso-dataset-transfer-bucket/trino-export/export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958" }, + columns=ColumnsDefinition(columns=[]), ), } ) @@ -66,11 +68,12 @@ def run_local_test( client.run_cache_manual_load( { "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958": ExportReference( - table="export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", + table_name="export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", type=ExportType.GCS, payload={ "gcs_path": "gs://oso-dataset-transfer-bucket/trino-export/export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958" }, + columns=ColumnsDefinition(columns=[]), ), } ) diff --git a/warehouse/metrics_tools/compute/result.py b/warehouse/metrics_tools/compute/result.py new file mode 100644 index 000000000..ff2a04268 --- /dev/null +++ b/warehouse/metrics_tools/compute/result.py @@ -0,0 +1,205 @@ +""" +For now the results for the metrics calculations are stored in a gcs bucket. We +can list all of those results and deterministically resolve those to trino +tables as well. 
+""" + +import abc +import logging +import os +import typing as t +from datetime import datetime + +import numpy as np +import pandas as pd +from aiotrino.dbapi import Connection +from sqlglot import exp +from sqlmesh.core.dialect import parse_one + +from .types import ExportReference, ExportType + +logger = logging.getLogger(__name__) + + +class DBImportAdapter(abc.ABC): + async def import_reference(self, reference: ExportReference) -> ExportReference: + raise NotImplementedError() + + async def translate_reference(self, reference: ExportReference) -> ExportReference: + raise NotImplementedError() + + async def clean(self, table: str): + raise NotImplementedError() + + async def clean_expired(self, expiration: datetime): + """Used to clean old imported tables that might not be needed. This is + not required to do anything all import adapters""" + return + + +class DummyImportAdapter(DBImportAdapter): + """A dummy import adapter that does nothing. This is useful for testing + basic operations of the service""" + + async def import_reference(self, reference: ExportReference) -> ExportReference: + return reference + + async def translate_reference(self, reference: ExportReference) -> ExportReference: + return reference + + async def clean(self, table: str): + pass + + async def clean_expired(self, expiration: datetime): + pass + + +class FakeLocalImportAdapter(DBImportAdapter): + """A fake import adapter that writes random data to a temporary directory. + This allows us to use this with duckdb for testing purposes""" + + def __init__( + self, + temp_dir: str, + log_override: t.Optional[logging.Logger] = None, + ): + self.temp_dir = temp_dir + self.logger = log_override or logger + + async def import_reference(self, reference: ExportReference) -> ExportReference: + self.logger.info(f"Importing reference {reference}") + translated_ref = await self.translate_reference(reference) + + # Convert reference.columns into pandas DataFrame columns + df = reference.columns.to_pandas() + self.logger.info(f"Created DataFrame with columns: {df.dtypes}") + + # Convert duckdb types to pandas types + self.logger.info(f"Converted DataFrame types: {df.dtypes}") + + # Generate random data for each column based on its type + fake_data_size = 100 + for column_name, column_type in reference.columns.columns_as_pandas_dtypes(): + if column_type.upper() == "bool": + df[column_name] = np.random.choice([True, False], size=fake_data_size) + elif column_type.upper() in ["int", "int8", "int16", "int32", "int64"]: + df[column_name] = np.random.randint(0, 100, size=fake_data_size) + elif column_type.upper() in ["float", "float32", "float64"]: + df[column_name] = np.random.random(size=fake_data_size) + elif column_type.upper() == ["object"]: + df[column_name] = np.random.choice( + ["oso", "random", "fake", "data", "foo", "bar", "baz"], + size=fake_data_size, + ) + elif column_type.upper() in ["datetime64[ns]"]: + df[column_name] = pd.to_datetime( + np.random.choice( + pd.date_range("2024-01-01", "2025-01-01", periods=100), + size=fake_data_size, + ) + ) + else: + df[column_name] = np.random.choice(["unknown"], size=fake_data_size) + + # Write the DataFrame to a parquet file in the temporary directory + parquet_file_path = translated_ref.payload["local_path"] + df.to_parquet(parquet_file_path) + self.logger.debug(f"Written DataFrame to parquet file: {parquet_file_path}") + + # Update the reference payload with the parquet file path + reference.payload["parquet_file_path"] = parquet_file_path + + return ExportReference( + 
table_name=reference.table_name, + type=ExportType.LOCALFS, + columns=reference.columns, + payload={"local_path": parquet_file_path}, + ) + + async def translate_reference(self, reference: ExportReference) -> ExportReference: + self.logger.info(f"Translating reference {reference}") + parquet_file_path = f"{self.temp_dir}/{reference.table_name}.parquet" + return ExportReference( + table_name=reference.table_name, + type=ExportType.LOCALFS, + columns=reference.columns, + payload={"local_path": parquet_file_path}, + ) + + +class TrinoImportAdapter(DBImportAdapter): + def __init__( + self, + db: Connection, + gcs_bucket: str, + hive_catalog: str, + hive_schema: str, + log_override: t.Optional[logging.Logger] = None, + ): + self.db = db + self.gcs_bucket = gcs_bucket + self.hive_catalog = hive_catalog + self.hive_schema = hive_schema + self.logger = log_override or logger + + async def import_reference(self, reference: ExportReference) -> ExportReference: + self.logger.info(f"Importing reference {reference}") + if reference.type != ExportType.GCS: + raise NotImplementedError(f"Unsupported reference type {reference.type}") + + # Import the table from gcs into trino using the hive catalog + import_path = reference.payload["gcs_path"] + # If we are using a wildcard path, we need to remove the wildcard for + # trino and keep a trailing slash + if os.path.basename(import_path) == "*.parquet": + import_path = f"{os.path.dirname(import_path)}/" + + base_create_query = f""" + CREATE table "{self.hive_catalog}"."{self.hive_schema}"."{reference.table_name}" ( + placeholder VARCHAR, + ) WITH ( + format = 'PARQUET', + external_location = '{import_path}/' + ) + """ + create_query = parse_one(base_create_query) + create_query.this.set( + "expressions", + [ + exp.ColumnDef( + this=exp.to_identifier(column_name), + kind=parse_one(column_type, into=exp.DataType), + ) + for column_name, column_type in reference.columns + ], + ) + await self.run_query(create_query.sql(dialect="trino")) + + return ExportReference( + catalog_name=self.hive_catalog, + schema_name=self.hive_schema, + table_name=reference.table_name, + type=ExportType.TRINO, + columns=reference.columns, + payload={}, + ) + + async def translate_reference(self, reference: ExportReference) -> ExportReference: + self.logger.info(f"Translating reference {reference}") + if reference.type != ExportType.GCS: + raise NotImplementedError(f"Unsupported reference type {reference.type}") + + return ExportReference( + catalog_name=self.hive_catalog, + schema_name=self.hive_schema, + table_name=reference.table_name, + type=ExportType.TRINO, + columns=reference.columns, + payload={}, + ) + + async def run_query(self, query: str): + cursor = await self.db.cursor() + self.logger.info(f"EXECUTING: {query}") + await cursor.execute(query) + return await cursor.fetchall() diff --git a/warehouse/metrics_tools/compute/server.py b/warehouse/metrics_tools/compute/server.py index ff9f10f71..3ff1d9468 100644 --- a/warehouse/metrics_tools/compute/server.py +++ b/warehouse/metrics_tools/compute/server.py @@ -1,4 +1,6 @@ import logging +import os +import tempfile import typing as t import uuid from contextlib import asynccontextmanager @@ -6,7 +8,11 @@ import aiotrino from dotenv import load_dotenv from fastapi import FastAPI, Request -from metrics_tools.utils.logging import setup_module_logging +from metrics_tools.compute.result import ( + DummyImportAdapter, + FakeLocalImportAdapter, + TrinoImportAdapter, +) from . 
import constants from .cache import setup_fake_cache_export_manager, setup_trino_cache_export_manager @@ -25,21 +31,21 @@ ) load_dotenv() -logger = logging.getLogger("uvicorn.error.application") +logger = logging.getLogger(__name__) @asynccontextmanager async def initialize_app(app: FastAPI): - # logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) - setup_module_logging("metrics_tools") - - logger.setLevel(logging.DEBUG) - logger.info("Metrics calculation service is starting up") if constants.debug_all: logger.warning("Debugging all services") cache_export_manager = None + temp_dir = None + + if constants.debug_with_duckdb: + temp_dir = tempfile.mkdtemp() + logger.debug(f"Created temp dir {temp_dir}") if not constants.debug_cache: trino_connection = aiotrino.dbapi.connect( host=constants.trino_host, @@ -52,13 +58,22 @@ async def initialize_app(app: FastAPI): constants.gcs_bucket, constants.hive_catalog, constants.hive_schema, - log_override=logger, ) - else: - logger.warning("Loading fake cache export manager") - cache_export_manager = await setup_fake_cache_export_manager( - log_override=logger + import_adapter = TrinoImportAdapter( + db=trino_connection, + gcs_bucket=constants.gcs_bucket, + hive_catalog=constants.hive_catalog, + hive_schema=constants.hive_schema, ) + else: + if constants.debug_with_duckdb: + assert temp_dir is not None + logger.warning("Loading fake cache export manager with duckdb") + import_adapter = FakeLocalImportAdapter(temp_dir) + else: + logger.warning("Loading dummy cache export manager (writes nothing)") + import_adapter = DummyImportAdapter() + cache_export_manager = await setup_fake_cache_export_manager() cluster_manager = None if not constants.debug_cluster: @@ -66,7 +81,6 @@ async def initialize_app(app: FastAPI): cluster_factory = KubeClusterFactory( constants.cluster_namespace, cluster_spec=cluster_spec, - log_override=logger, shutdown_on_close=not constants.debug_cluster_no_shutdown, ) cluster_manager = ClusterManager.with_metrics_plugin( @@ -75,14 +89,12 @@ async def initialize_app(app: FastAPI): constants.gcs_secret, constants.worker_duckdb_path, cluster_factory, - log_override=logger, ) else: logger.warning("Loading fake cluster manager") cluster_factory = LocalClusterFactory() cluster_manager = ClusterManager.with_dummy_metrics_plugin( cluster_factory, - log_override=logger, ) mcs = MetricsCalculationService.setup( @@ -91,21 +103,25 @@ async def initialize_app(app: FastAPI): result_path_prefix=constants.results_path_prefix, cluster_manager=cluster_manager, cache_manager=cache_export_manager, - log_override=logger, + import_adapter=import_adapter, ) try: yield { - "mca": mcs, + "mcs": mcs, } finally: + logger.info("Waiting for metrics calculation service to close") await mcs.close() + if temp_dir: + logger.info("Removing temp dir") + os.rmdir(temp_dir) # Dependency to get the cluster manager -def get_mca(request: Request) -> MetricsCalculationService: - mca = request.state.mca - assert mca is not None - return t.cast(MetricsCalculationService, mca) +def get_mcs(request: Request) -> MetricsCalculationService: + mcs = request.state.mcs + assert mcs is not None + return t.cast(MetricsCalculationService, mcs) app = FastAPI(lifespan=initialize_app) @@ -113,9 +129,7 @@ def get_mca(request: Request) -> MetricsCalculationService: @app.get("/status") async def get_status(): - """ - Liveness endpoint - """ + """Liveness endpoint""" return {"status": "Service is running"} @@ -124,30 +138,27 @@ async def start_cluster( request: Request, start_request: 
ClusterStartRequest, ): + """Start a Dask cluster in an idempotent way. + + If the cluster is already running, it will not be restarted. """ - Start a Dask cluster in an idempotent way - """ - state = get_mca(request) + state = get_mcs(request) manager = state.cluster_manager return await manager.start_cluster(start_request.min_size, start_request.max_size) @app.post("/cluster/stop") async def stop_cluster(request: Request): - """ - Stop the Dask cluster - """ - state = get_mca(request) + """Stop the Dask cluster""" + state = get_mcs(request) manager = state.cluster_manager return await manager.stop_cluster() @app.get("/cluster/status") async def get_cluster_status(request: Request): - """ - Get the current Dask cluster status - """ - state = get_mca(request) + """Get the current Dask cluster status""" + state = get_mcs(request) manager = state.cluster_manager return await manager.get_cluster_status() @@ -157,10 +168,8 @@ async def submit_job( request: Request, input: QueryJobSubmitRequest, ): - """ - Submits a Dask job for calculation - """ - service = get_mca(request) + """Submits a Dask job for calculation""" + service = get_mcs(request) return await service.submit_job(input) @@ -169,11 +178,9 @@ async def get_job_status( request: Request, job_id: str, ): - """ - Get the status of a job - """ + """Get the status of a job""" include_stats = request.query_params.get("include_stats", "false").lower() == "true" - service = get_mca(request) + service = get_mcs(request) return await service.get_job_status(job_id, include_stats=include_stats) @@ -181,9 +188,7 @@ async def get_job_status( async def add_existing_exported_table_references( request: Request, input: ExportedTableLoadRequest ): - """ - Add a table export to the cache - """ - service = get_mca(request) + """Add a table export to the cache""" + service = get_mcs(request) await service.add_existing_exported_table_references(input.map) return EmptyResponse() diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py index fd6ace95b..9147b5015 100644 --- a/warehouse/metrics_tools/compute/service.py +++ b/warehouse/metrics_tools/compute/service.py @@ -9,6 +9,7 @@ from datetime import datetime from dask.distributed import CancelledError, Future +from metrics_tools.compute.result import DBImportAdapter from metrics_tools.compute.worker import execute_duckdb_load from metrics_tools.runner import FakeEngineAdapter, MetricsRunner @@ -17,6 +18,7 @@ from .types import ( ClusterStartRequest, ClusterStatus, + ColumnsDefinition, ExportReference, ExportType, QueryJobProgress, @@ -32,6 +34,22 @@ logger.setLevel(logging.DEBUG) +class JobFailed(Exception): + pass + + +class JobTasksFailed(JobFailed): + exceptions: t.List[Exception] + failures: int + + def __init__(self, job_id: str, failures: int, exceptions: t.List[Exception]): + self.failures = failures + self.exceptions = exceptions + super().__init__( + f"job[{job_id}] failed with {failures} failures and {len(exceptions)} exceptions" + ) + + class MetricsCalculationService: id: str gcs_bucket: str @@ -50,6 +68,7 @@ def setup( result_path_prefix: str, cluster_manager: ClusterManager, cache_manager: CacheExportManager, + import_adapter: DBImportAdapter, log_override: t.Optional[logging.Logger] = None, ): service = cls( @@ -58,9 +77,9 @@ def setup( result_path_prefix, cluster_manager, cache_manager, + import_adapter=import_adapter, log_override=log_override, ) - # service.start_job_state_listener() return service def __init__( @@ -70,6 +89,7 @@ def __init__( 
result_path_prefix: str, cluster_manager: ClusterManager, cache_manager: CacheExportManager, + import_adapter: DBImportAdapter, log_override: t.Optional[logging.Logger] = None, ): self.id = id @@ -77,16 +97,23 @@ def __init__( self.result_path_prefix = result_path_prefix self.cluster_manager = cluster_manager self.cache_manager = cache_manager + self.import_adapter = import_adapter self.job_state = {} self.job_tasks = {} self.job_state_lock = asyncio.Lock() self.logger = log_override or logger async def handle_query_job_submit_request( - self, job_id: str, result_path_base: str, input: QueryJobSubmitRequest + self, + job_id: str, + result_path_base: str, + input: QueryJobSubmitRequest, + export_reference: ExportReference, ): try: - await self._handle_query_job_submit_request(job_id, result_path_base, input) + await self._handle_query_job_submit_request( + job_id, result_path_base, input, export_reference + ) except Exception as e: self.logger.error(f"job[{job_id}] failed with exception: {e}") await self._notify_job_failed(job_id, 0, 0) @@ -96,25 +123,61 @@ async def _handle_query_job_submit_request( job_id: str, result_path_base: str, input: QueryJobSubmitRequest, + export_reference: ExportReference, ): - self.logger.info(f"job[{job_id}] waiting for cluster to be ready") + self.logger.debug(f"job[{job_id}] waiting for cluster to be ready") await self.cluster_manager.wait_for_ready() - self.logger.info(f"job[{job_id}] cluster ready") + self.logger.debug(f"job[{job_id}] cluster ready") - client = await self.cluster_manager.client - self.logger.info(f"job[{job_id}] waiting for dependencies to be exported") + self.logger.debug(f"job[{job_id}] waiting for dependencies to be exported") exported_dependent_tables_map = await self.resolve_dependent_tables(input) - self.logger.info(f"job[{job_id}] dependencies exported") + self.logger.debug(f"job[{job_id}] dependencies exported") + + tasks = await self._batch_query_to_scheduler( + job_id, result_path_base, input, exported_dependent_tables_map + ) + + total = len(tasks) + completed = 0 + + # In the future we should replace this with the python 3.13 version of + # this. 
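# The monitoring step below follows the same pattern as
# _monitor_query_task_progress: drain the submitted futures with
# asyncio.as_completed, count completions and failures, and surface any
# failures as a single JobTasksFailed. A stripped-down, self-contained sketch
# of that pattern (plain awaitables stand in for Dask futures; the names are
# illustrative):
import asyncio
import typing as t

from metrics_tools.compute.service import JobTasksFailed


async def monitor_tasks(job_id: str, tasks: t.Sequence[t.Awaitable[str]]) -> int:
    completed, failures = 0, 0
    exceptions: t.List[Exception] = []
    for finished in asyncio.as_completed(list(tasks)):
        try:
            await finished
            completed += 1
        except Exception as e:
            failures += 1
            exceptions.append(e)
    if failures > 0:
        # Same contract as the JobTasksFailed exception introduced above
        raise JobTasksFailed(job_id, failures, exceptions)
    return completed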
+ try: + await self._monitor_query_task_progress(job_id, tasks) + except JobTasksFailed as e: + exceptions = e.exceptions + self.logger.error(e) + await self._notify_job_failed(job_id, completed, total) + if len(exceptions) > 0: + for e in exceptions: + self.logger.error(f"job[{job_id}] exception received: {e}") + raise e + + # Import the final result into the database + self.logger.info("job[{job_id}]: importing final result into the database") + await self.import_adapter.import_reference(export_reference) + self.logger.debug(f"job[{job_id}]: notifying job completed") + await self._notify_job_completed(job_id, completed, total) + + async def _batch_query_to_scheduler( + self, + job_id: str, + result_path_base: str, + input: QueryJobSubmitRequest, + exported_dependent_tables_map: t.Dict[str, ExportReference], + ): + """Given a query job: break down into batches and submit to the scheduler""" tasks: t.List[Future] = [] + client = await self.cluster_manager.client async for batch_id, batch in self.generate_query_batches( input, input.batch_size ): task_id = f"{job_id}-{batch_id}" result_path = os.path.join(result_path_base, job_id, f"{batch_id}.parquet") - self.logger.info(f"job[{job_id}]: Submitting task {task_id}") + self.logger.debug(f"job[{job_id}]: Submitting task {task_id}") # dependencies = { # table: to_jsonable_python(reference) @@ -131,9 +194,11 @@ async def _handle_query_job_submit_request( retries=input.retries, ) - self.logger.info(f"job[{job_id}]: Submitted task {task_id}") + self.logger.debug(f"job[{job_id}]: Submitted task {task_id}") tasks.append(task) + return tasks + async def _monitor_query_task_progress(self, job_id: str, tasks: t.List[Future]): total = len(tasks) completed = 0 failures = 0 @@ -141,14 +206,13 @@ async def _handle_query_job_submit_request( # In the future we should replace this with the python 3.13 version of # this. - for finished in asyncio.as_completed(tasks): try: task_id = await finished completed += 1 self.logger.info(f"job[{job_id}] progress: {completed}/{total}") await self._notify_job_updated(job_id, completed, total) - self.logger.info( + self.logger.debug( f"job[{job_id}] finished notifying update: {completed}/{total}" ) except CancelledError as e: @@ -160,21 +224,12 @@ async def _handle_query_job_submit_request( exceptions.append(e) self.logger.error(f"job[{job_id}] task failed with exception: {e}") continue - self.logger.info(f"job[{job_id}] awaiting finished") + self.logger.debug(f"job[{job_id}] awaiting finished") await self._notify_job_updated(job_id, completed, total) self.logger.info(f"job[{job_id}] task_id={task_id} finished") if failures > 0: - self.logger.error( - f"job[{job_id}] {failures} tasks failed. 
received {len(exceptions)} exceptions" - ) - await self._notify_job_failed(job_id, completed, total) - if len(exceptions) > 0: - for e in exceptions: - self.logger.error(f"job[{job_id}] exception received: {e}") - else: - self.logger.info(f"job[{job_id}]: done") - await self._notify_job_completed(job_id, completed, total) + raise JobTasksFailed(job_id, failures, exceptions) async def close(self): await self.cluster_manager.close() @@ -194,25 +249,43 @@ async def submit_job(self, input: QueryJobSubmitRequest): self.logger.debug("submitting job") job_id = str(uuid.uuid4()) - result_path_base = os.path.join(self.result_path_prefix, job_id) + # Files are organized in a way that can be searched by date such that we + # can easily clean old files + result_path_base = os.path.join( + self.result_path_prefix, + input.execution_time.strftime("%Y/%m/%d/%H"), + job_id, + ) result_path = os.path.join( - f"gs://{self.gcs_bucket}", result_path_base, "*.parquet" + f"gs://{self.gcs_bucket}", + result_path_base, + "*.parquet", + ) + + final_expected_reference = await self.import_adapter.translate_reference( + ExportReference( + table_name=job_id, + type=ExportType.GCS, + columns=ColumnsDefinition(columns=input.columns, dialect=input.dialect), + payload={"gcs_path": result_path}, + ) ) await self._notify_job_pending(job_id, 1) task = asyncio.create_task( - self.handle_query_job_submit_request(job_id, result_path_base, input) + self.handle_query_job_submit_request( + job_id, + result_path_base, + input, + final_expected_reference, + ) ) async with self.job_state_lock: self.job_tasks[job_id] = task return QueryJobSubmitResponse( job_id=job_id, - export_reference=ExportReference( - table=job_id, - type=ExportType.GCS, - payload={"gcs_path": result_path}, - ), + export_reference=final_expected_reference, ) async def _notify_job_pending(self, job_id: str, total: int): @@ -335,7 +408,7 @@ async def resolve_dependent_tables(self, input: QueryJobSubmitRequest): # First use the cache manager to resolve the export references references = await self.cache_manager.resolve_export_references( - tables_to_export + tables_to_export, input.execution_time ) self.logger.debug(f"resolved references: {references}") diff --git a/warehouse/metrics_tools/compute/test_cache.py b/warehouse/metrics_tools/compute/test_cache.py index 860b82b85..17a22ae50 100644 --- a/warehouse/metrics_tools/compute/test_cache.py +++ b/warehouse/metrics_tools/compute/test_cache.py @@ -1,23 +1,26 @@ import asyncio +from datetime import datetime from unittest.mock import AsyncMock import pytest from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter -from metrics_tools.compute.types import ExportReference, ExportType +from metrics_tools.compute.types import ColumnsDefinition, ExportReference, ExportType @pytest.mark.asyncio async def test_cache_export_manager(): adapter_mock = AsyncMock(FakeExportAdapter) adapter_mock.export_table.return_value = ExportReference( - table="test", + table_name="test", type=ExportType.GCS, + columns=ColumnsDefinition(columns=[]), payload={}, ) cache = await CacheExportManager.setup(adapter_mock) + execution_time = datetime.now() export_table_0 = await asyncio.wait_for( - cache.resolve_export_references(["table1", "table2"]), timeout=1 + cache.resolve_export_references(["table1", "table2"], execution_time), timeout=1 ) assert export_table_0.keys() == {"table1", "table2"} @@ -25,7 +28,9 @@ async def test_cache_export_manager(): # Attempt to export tables again but this should be mostly cache hits except # 
for table3 export_table_1 = await asyncio.wait_for( - cache.resolve_export_references(["table1", "table2", "table1", "table3"]), + cache.resolve_export_references( + ["table1", "table2", "table1", "table3"], execution_time + ), timeout=1, ) assert export_table_1.keys() == {"table1", "table2", "table3"} diff --git a/warehouse/metrics_tools/compute/test_service.py b/warehouse/metrics_tools/compute/test_service.py index cea67002a..deb25a459 100644 --- a/warehouse/metrics_tools/compute/test_service.py +++ b/warehouse/metrics_tools/compute/test_service.py @@ -4,9 +4,11 @@ import pytest from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter from metrics_tools.compute.cluster import ClusterManager, LocalClusterFactory +from metrics_tools.compute.result import DummyImportAdapter from metrics_tools.compute.service import MetricsCalculationService from metrics_tools.compute.types import ( ClusterStartRequest, + ColumnsDefinition, ExportReference, ExportType, QueryJobStatus, @@ -23,13 +25,17 @@ async def test_metrics_calculation_service(): "result_path_prefix", ClusterManager.with_dummy_metrics_plugin(LocalClusterFactory()), await CacheExportManager.setup(FakeExportAdapter()), + DummyImportAdapter(), ) await service.start_cluster(ClusterStartRequest(min_size=1, max_size=1)) await service.add_existing_exported_table_references( { "source.table123": ExportReference( - table="export_table123", + table_name="export_table123", type=ExportType.GCS, + columns=ColumnsDefinition( + columns=[("col1", "INT"), ("col2", "TEXT")], dialect="duckdb" + ), payload={"gcs_path": "gs://bucket/result_path_prefix/export_table123"}, ), } @@ -48,6 +54,7 @@ async def test_metrics_calculation_service(): window=30, unit="day", ), + execution_time=datetime.now(), locals={}, dependent_tables_map={"source.table123": "source.table123"}, ) diff --git a/warehouse/metrics_tools/compute/types.py b/warehouse/metrics_tools/compute/types.py index 92f1e9259..c4716e873 100644 --- a/warehouse/metrics_tools/compute/types.py +++ b/warehouse/metrics_tools/compute/types.py @@ -3,6 +3,7 @@ from datetime import datetime from enum import Enum +import pandas as pd from metrics_tools.definition import PeerMetricDependencyRef from pydantic import BaseModel, Field from sqlmesh.core.dialect import parse_one @@ -17,13 +18,85 @@ class EmptyResponse(BaseModel): class ExportType(str, Enum): ICEBERG = "iceberg" GCS = "gcs" + TRINO = "trino" + LOCALFS = "localfs" + + +DUCKDB_TO_PANDAS_TYPE_MAP = { + "BOOLEAN": "bool", + "BOOL": "bool", + "TINYINT": "int8", + "INT1": "int8", + "SMALLINT": "int16", + "INT2": "int16", + "INTEGER": "int32", + "INT4": "int32", + "BIGINT": "int64", + "INT8": "int64", + "FLOAT": "float32", + "FLOAT4": "float32", + "DOUBLE": "float64", + "FLOAT8": "float64", + "DATE": "datetime64[ns]", + "TIMESTAMP": "datetime64[ns]", + "DATETIME": "datetime64[ns]", + "VARCHAR": "object", + "CHAR": "object", + "BPCHAR": "object", + "TEXT": "object", + "BLOB": "bytes", + "BYTEA": "bytes", + "NUMERIC": "float64", +} + + +class ColumnsDefinition(BaseModel): + columns: t.List[t.Tuple[str, str]] + dialect: str = "duckdb" + + def columns_as(self, dialect: str) -> t.List[t.Tuple[str, str]]: + return [ + (col_name, parse_one(col_type, dialect=self.dialect).sql(dialect=dialect)) + for col_name, col_type in self.columns + ] + + def __iter__(self): + for col_name, col_type in self.columns: + yield (col_name, col_type) + + def to_pandas(self) -> pd.DataFrame: + """Creates a basic dataframe with the columns defined in this definition + 
coerced to panda datatypes""" + columns_as_pandas_dtypes = self.columns_as_pandas_dtypes() + df = pd.DataFrame({col_name: [] for col_name, _ in columns_as_pandas_dtypes}) + for col_name, col_type in columns_as_pandas_dtypes: + df[col_name] = df[col_name].astype(col_type) # type: ignore + return df + + def columns_as_pandas_dtypes(self) -> t.List[t.Tuple[str, str]]: + return [ + (col_name, DUCKDB_TO_PANDAS_TYPE_MAP[col_type.upper()]) + for col_name, col_type in self.columns_as("duckdb") + ] class ExportReference(BaseModel): - table: str + catalog_name: t.Optional[str] = None + schema_name: t.Optional[str] = None + columns: ColumnsDefinition + table_name: str type: ExportType payload: t.Dict[str, t.Any] + def table_fqn(self) -> str: + names = [] + if self.catalog_name: + names.append(self.catalog_name) + if self.schema_name: + names.append(self.schema_name) + names.append(self.table_name) + return ".".join(names) + class QueryJobStatus(str, Enum): PENDING = "pending" @@ -61,10 +134,15 @@ class QueryJobSubmitRequest(BaseModel): locals: t.Dict[str, t.Any] dependent_tables_map: t.Dict[str, str] retries: t.Optional[int] = None + execution_time: datetime def query_as(self, dialect: str) -> str: return parse_one(self.query_str, self.dialect).sql(dialect=dialect) + @property + def columns_def(self) -> ColumnsDefinition: + return ColumnsDefinition(columns=self.columns, dialect=self.dialect) + class QueryJobSubmitResponse(BaseModel): job_id: str diff --git a/warehouse/metrics_tools/compute/worker.py b/warehouse/metrics_tools/compute/worker.py index d974bc25d..dfb7c7937 100644 --- a/warehouse/metrics_tools/compute/worker.py +++ b/warehouse/metrics_tools/compute/worker.py @@ -110,7 +110,7 @@ def get_for_cache( ): """Checks if a table is cached in the local duckdb""" logger.info( - f"[{self._uuid}] got a cache request for {table_ref_name}:{export_reference.table}" + f"[{self._uuid}] got a cache request for {table_ref_name}:{export_reference.table_name}" ) assert export_reference.type == ExportType.GCS, "Only GCS exports are supported" assert ( diff --git a/warehouse/metrics_tools/definition.py b/warehouse/metrics_tools/definition.py index 5e91091eb..aa0e3f759 100644 --- a/warehouse/metrics_tools/definition.py +++ b/warehouse/metrics_tools/definition.py @@ -402,12 +402,3 @@ class TimeseriesMetricsOptions(t.TypedDict): start: TimeLike timeseries_sources: t.NotRequired[t.List[str]] queries_dir: t.NotRequired[str] - - -class GeneratedArtifactConfig(t.TypedDict): - query_reference_name: str - query_def_as_input: MetricQueryInput - default_dialect: str - peer_table_tuples: t.List[t.Tuple[str, str]] - ref: PeerMetricDependencyRef - timeseries_sources: t.List[str] diff --git a/warehouse/metrics_tools/factory/factory.py b/warehouse/metrics_tools/factory/factory.py index 4ecf23b4a..495e6cd19 100644 --- a/warehouse/metrics_tools/factory/factory.py +++ b/warehouse/metrics_tools/factory/factory.py @@ -1,49 +1,49 @@ import contextlib -from datetime import datetime import inspect import logging import os -from queue import PriorityQueue -import typing as t import textwrap -from metrics_tools.runner import MetricsRunner -from metrics_tools.transformer.tables import ExecutionContextTableTransform -from metrics_tools.utils.logging import add_metrics_tools_to_sqlmesh_logging -import pandas as pd +import typing as t from dataclasses import dataclass, field +from datetime import datetime +from queue import PriorityQueue -from sqlmesh import ExecutionContext -from sqlmesh.core.macros import MacroEvaluator -from 
sqlmesh.core.model import ModelKindName +import pandas as pd import sqlglot as sql -from sqlglot import exp - -from metrics_tools.joiner import JoinerTransform -from metrics_tools.transformer import ( - SQLTransformer, - IntermediateMacroEvaluatorTransform, -) -from metrics_tools.transformer.qualify import QualifyTransform +from metrics_tools.compute.client import Client +from metrics_tools.compute.types import ExportType from metrics_tools.definition import ( MetricQuery, PeerMetricDependencyRef, TimeseriesMetricsOptions, reference_to_str, ) -from metrics_tools.models import ( - GeneratedModel, - GeneratedPythonModel, -) +from metrics_tools.joiner import JoinerTransform from metrics_tools.macros import ( metrics_end, + metrics_entity_type_alias, metrics_entity_type_col, metrics_name, + metrics_peer_ref, metrics_sample_date, metrics_start, relative_window_sample_date, - metrics_entity_type_alias, - metrics_peer_ref, ) +from metrics_tools.models import GeneratedModel, GeneratedPythonModel +from metrics_tools.runner import MetricsRunner +from metrics_tools.transformer import ( + IntermediateMacroEvaluatorTransform, + SQLTransformer, +) +from metrics_tools.transformer.qualify import QualifyTransform +from metrics_tools.transformer.tables import ExecutionContextTableTransform +from metrics_tools.utils import env +from metrics_tools.utils.logging import add_metrics_tools_to_sqlmesh_logging +from sqlglot import exp +from sqlmesh import ExecutionContext +from sqlmesh.core.dialect import parse_one +from sqlmesh.core.macros import MacroEvaluator +from sqlmesh.core.model import ModelKindName logger = logging.getLogger(__name__) @@ -381,14 +381,9 @@ def generate_model_for_rendered_query( query = query_config["query"] match query.metric_type: case "rolling": - if query.use_python_model: - self.generate_rolling_python_model_for_rendered_query( - calling_file, query_config, dependencies - ) - else: - self.generate_rolling_model_for_rendered_query( - calling_file, query_config, dependencies - ) + self.generate_rolling_python_model_for_rendered_query( + calling_file, query_config, dependencies + ) case "time_aggregation": self.generate_time_aggregation_model_for_rendered_query( calling_file, query_config, dependencies @@ -444,53 +439,6 @@ def generate_rolling_python_model_for_rendered_query( imports={"pd": pd, "generated_rolling_query": generated_rolling_query}, ) - def generate_rolling_model_for_rendered_query( - self, - calling_file: str, - query_config: MetricQueryConfig, - dependencies: t.Set[str], - ): - config = self.serializable_config(query_config) - - ref = query_config["ref"] - query = query_config["query"] - - columns = METRICS_COLUMNS_BY_ENTITY[ref["entity_type"]] - - kind_common = {"batch_size": 1, "batch_concurrency": 1} - partitioned_by = ("day(metrics_sample_date)",) - window = ref.get("window") - assert window is not None - assert query._source.rolling - cron = query._source.rolling["cron"] - - grain = [ - "metric", - f"to_{ref['entity_type']}_id", - "from_artifact_id", - "event_source", - "metrics_sample_date", - ] - - GeneratedModel.create( - func=generated_query, - entrypoint_path=calling_file, - config=config, - name=f"{self.catalog}.{query_config['table_name']}", - kind={ - "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, - "time_column": "metrics_sample_date", - **kind_common, - }, - dialect="clickhouse", - columns=columns, - grain=grain, - cron=cron, - start=self._raw_options["start"], - additional_macros=self.generated_model_additional_macros, - partitioned_by=partitioned_by, - ) 
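# The rolling-query path below renders the sqlglot column schema in
# METRICS_COLUMNS_BY_ENTITY into plain (name, type-string) tuples before
# handing it to the metrics calculation service. A self-contained sketch of
# that conversion (the two columns here are illustrative, assuming the schema
# maps column names to sqlglot DataType expressions):
from sqlglot import exp

example_schema = {
    "metrics_sample_date": exp.DataType.build("date"),
    "amount": exp.DataType.build("double"),
}
columns = [
    (col_name, col_type.sql(dialect="duckdb"))
    for col_name, col_type in example_schema.items()
]
print(columns)  # -> [('metrics_sample_date', 'DATE'), ('amount', 'DOUBLE')]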
- def generate_time_aggregation_model_for_rendered_query( self, calling_file: str, @@ -638,7 +586,6 @@ def generated_query( vars: t.Dict[str, t.Any], ): """Simple generated query executor for metrics queries""" - from sqlmesh.core.dialect import parse_one with metric_ref_evaluator_context(evaluator, ref, vars): result = evaluator.transform(parse_one(rendered_query_str)) @@ -655,28 +602,96 @@ def generated_rolling_query( rendered_query_str: str, table_name: str, sqlmesh_vars: t.Dict[str, t.Any], + gateway: str | None, *_ignored, ): + """Generates a rolling query for the given metrics query + + If SQLMESH_MCS_ENABLED is set to true, the query will be sent to the metrics + calculation service. Otherwise, the query will be executed as a rolling + query using dataframes (this can be very slow on remote data sources of + non-trivial size). + + This currently takes advantage of a potential hack in sqlmesh. The snapshot + evaluator nor the scheduler in sqlmesh seem to care what the response is + from the python model as long as it's either a dataframe OR a sqlglot + expression. This means we can return a sqlglot expression that takes the + ExportReference from the metrics calculation service and use it as a table + in the sqlmesh query. + + If configured correctly, the metrics calculation service will export a trino + database in production and if testing, it can export a "LOCALFS" export + which gives you a path to parquet file with random data (that satisfies the + requested schema). + """ # Transform the query for the current context transformer = SQLTransformer(transforms=[ExecutionContextTableTransform(context)]) query = transformer.transform(rendered_query_str) locals = vars.copy() locals.update(sqlmesh_vars) - runner = MetricsRunner.from_sqlmesh_context(context, query, ref, locals) - df = runner.run_rolling(start, end) - # If the rolling window is empty we need to yield from an empty tuple - # otherwise sqlmesh fails. See: - # https://sqlmesh.readthedocs.io/en/latest/concepts/models/python_models/#returning-empty-dataframes - total = 0 - if df.empty: - yield from () + mcs_enabled = env.ensure_bool("SQLMESH_MCS_ENABLED", False) + if not mcs_enabled: + runner = MetricsRunner.from_sqlmesh_context(context, query, ref, locals) + df = runner.run_rolling(start, end) + # If the rolling window is empty we need to yield from an empty tuple + # otherwise sqlmesh fails. 
See: + # https://sqlmesh.readthedocs.io/en/latest/concepts/models/python_models/#returning-empty-dataframes + total = 0 + if df.empty: + yield from () + else: + count = len(df) + total += count + logger.debug(f"table={table_name} yielding rows {count}") + yield df + logger.debug(f"table={table_name} yielded rows{total}") else: - count = len(df) - total += count - logger.debug(f"table={table_name} yielding rows {count}") - yield df - logger.debug(f"table={table_name} yielded rows{total}") + logger.info("metrics calculation service enabled") + + mcs_url = env.required_str("SQLMESH_MCS_URL") + mcs_client = Client(url=mcs_url) + + columns = [ + (col_name, col_type.sql(dialect="duckdb")) + for col_name, col_type in METRICS_COLUMNS_BY_ENTITY[ + ref["entity_type"] + ].items() + ] + + response = mcs_client.calculate_metrics( + query_str=rendered_query_str, + start=start, + end=end, + dialect="clickhouse", + batch_size=1, + columns=columns, + ref=ref, + locals=sqlmesh_vars, + dependent_tables_map={}, + ) + + column_names = list(map(lambda col: col[0], columns)) + engine_dialect = context.engine_adapter.dialect + + if engine_dialect == "duckdb": + if response.type not in [ExportType.GCS, ExportType.LOCALFS]: + raise Exception(f"ExportType={response.type} not supported for duckdb") + # Create a select query from the exported data + path = response.payload.get("local_path", response.payload.get("gcs_path")) + select_query = exp.select(*column_names).from_( + exp.Anonymous( + this="read_parquet", + expressions=[exp.Literal(this=path, is_string=True)], + ), + ) + elif engine_dialect == "trino": + if response.type not in [ExportType.TRINO]: + raise Exception(f"ExportType={response.type} not supported for trino") + select_query = exp.select(*column_names).from_(response.table_fqn()) + else: + raise Exception(f"Dialect={context.engine_adapter.dialect} not supported") + yield select_query def generated_rolling_query_proxy( @@ -690,7 +705,7 @@ def generated_rolling_query_proxy( table_name: str, sqlmesh_vars: t.Dict[str, t.Any], **kwargs, -) -> t.Iterator[pd.DataFrame]: +) -> t.Iterator[pd.DataFrame | exp.Expression]: """This acts as the proxy to the actual function that we'd call for the metrics model.""" @@ -704,6 +719,7 @@ def generated_rolling_query_proxy( rendered_query_str, table_name, sqlmesh_vars, + context.gateway, # Change the following variable to force reevaluation. Hack for now. 
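# The duckdb branch above returns a sqlglot expression rather than a dataframe,
# so sqlmesh reads the exported parquet files directly. A self-contained sketch
# of the expression it builds (the path and column names here are illustrative):
from sqlglot import exp

path = "gs://oso-dataset-transfer-bucket/mcs-results/2024/12/12/14/some-job-id/*.parquet"
select_query = exp.select("metrics_sample_date", "metric", "amount").from_(
    exp.Anonymous(
        this="read_parquet",
        expressions=[exp.Literal(this=path, is_string=True)],
    )
)
print(select_query.sql(dialect="duckdb"))
# prints roughly: SELECT metrics_sample_date, metric, amount FROM READ_PARQUET('gs://.../*.parquet')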
- "version=v4", + "version=v5", ) diff --git a/warehouse/metrics_tools/models.py b/warehouse/metrics_tools/models.py index e36c0682a..22fff8646 100644 --- a/warehouse/metrics_tools/models.py +++ b/warehouse/metrics_tools/models.py @@ -4,22 +4,22 @@ import re import textwrap import typing as t -from pathlib import Path import uuid +from pathlib import Path from sqlglot import exp from sqlmesh.core import constants as c from sqlmesh.core.dialect import MacroFunc from sqlmesh.core.macros import ExecutableOrMacro, MacroRegistry, macro from sqlmesh.core.model.decorator import model -from sqlmesh.core.model.definition import create_sql_model, create_python_model +from sqlmesh.core.model.definition import create_python_model, create_sql_model from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.utils.metaprogramming import ( Executable, ExecutableKind, build_env, - serialize_env, normalize_source, + serialize_env, ) logger = logging.getLogger(__name__) @@ -100,6 +100,7 @@ def model( all_vars = self._variables.copy() global_variables = variables or {} + all_vars.update(global_variables) all_vars["sqlmesh_vars"] = global_variables common_kwargs: t.Dict[str, t.Any] = dict(