diff --git a/.lintstagedrc b/.lintstagedrc index d022da589..ad24a00f7 100644 --- a/.lintstagedrc +++ b/.lintstagedrc @@ -1,19 +1,20 @@ { - "**/!(.eslintrc)*.{js,jsx,ts,tsx,sol}": [ - "eslint --ignore-path .gitignore --max-warnings 0", - "prettier --ignore-path .gitignore --write", - "prettier --ignore-path .gitignore --log-level warn --check" - ], - "**/*.{md,json}": [ - "prettier --ignore-path .gitignore --write", - "prettier --ignore-path .gitignore --log-level warn --check" - ], - "**/*.py": [ - "poetry run ruff check --fix --force-exclude", - "pnpm pyright" - ], - "warehouse/dbt/**/*.sql": [ - "poetry run sqlfluff fix -f", - "poetry run sqlfluff lint" - ] + "**/!(.eslintrc)*.{js,jsx,ts,tsx,sol}": [ + "eslint --ignore-path .gitignore --max-warnings 0", + "prettier --ignore-path .gitignore --write", + "prettier --ignore-path .gitignore --log-level warn --check" + ], + "**/*.{md,json}": [ + "prettier --ignore-path .gitignore --write", + "prettier --ignore-path .gitignore --log-level warn --check" + ], + "**/*.py": [ + "poetry run ruff check --fix --force-exclude", + "poetry run isort", + "pnpm pyright" + ], + "warehouse/dbt/**/*.sql": [ + "poetry run sqlfluff fix -f", + "poetry run sqlfluff lint" + ] } \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index d252ab649..47cd26bf4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -170,6 +170,25 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiotrino" +version = "0.2.3" +description = "ASyncIO Client for the Trino distributed SQL Engine" +optional = false +python-versions = "*" +files = [ + {file = "aiotrino-0.2.3-py3-none-any.whl", hash = "sha256:bef7b55036ea07a371b2d8c2bd47486d1299df0a2c7aeb48f70a97f951747673"}, + {file = "aiotrino-0.2.3.tar.gz", hash = "sha256:ae64015b9cf23370d107a1917e75eb23ca13f830b31b95d2b721c9daffa4882f"}, +] + +[package.dependencies] +aiohttp = "*" + +[package.extras] +all = ["requests-kerberos"] +kerberos = ["requests-kerberos"] +tests = ["aioresponses", "click", "mock", "pytest", "pytest-aiohttp", "pytest-asyncio", "pytest-runner", "pytz", "requests-kerberos"] + [[package]] name = "alembic" version = "1.13.3" @@ -1724,6 +1743,26 @@ sqlalchemy = ["alembic (>1.10.0)", "sqlalchemy (>=1.4)"] synapse = ["adlfs (>=2022.4.0)", "pyarrow (>=12.0.0)", "pyodbc (>=4.0.39)"] weaviate = ["weaviate-client (>=3.22)"] +[[package]] +name = "dnspython" +version = "2.7.0" +description = "DNS toolkit" +optional = false +python-versions = ">=3.9" +files = [ + {file = "dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86"}, + {file = "dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1"}, +] + +[package.extras] +dev = ["black (>=23.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "hypercorn (>=0.16.0)", "mypy (>=1.8)", "pylint (>=3)", "pytest (>=7.4)", "pytest-cov (>=4.1.0)", "quart-trio (>=0.11.0)", "sphinx (>=7.2.0)", "sphinx-rtd-theme (>=2.0.0)", "twine (>=4.0.0)", "wheel (>=0.42.0)"] +dnssec = ["cryptography (>=43)"] +doh = ["h2 (>=4.1.0)", "httpcore (>=1.0.0)", "httpx (>=0.26.0)"] +doq = ["aioquic (>=1.0.0)"] +idna = ["idna (>=3.7)"] +trio = ["trio (>=0.23)"] +wmi = ["wmi (>=1.5.1)"] + [[package]] name = "docstring-parser" version = "0.16" @@ -1807,6 +1846,21 @@ files = [ {file = "durationpy-0.9.tar.gz", hash = "sha256:fd3feb0a69a0057d582ef643c355c40d2fa1c942191f914d12203b1a01ac722a"}, ] +[[package]] +name = "email-validator" +version = "2.2.0" +description = "A robust email 
address syntax and deliverability validation library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631"}, + {file = "email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7"}, +] + +[package.dependencies] +dnspython = ">=2.0.0" +idna = ">=2.0.0" + [[package]] name = "executing" version = "2.1.0" @@ -1821,6 +1875,51 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "fastapi" +version = "0.115.6" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.115.6-py3-none-any.whl", hash = "sha256:e9240b29e36fa8f4bb7290316988e90c381e5092e0cbe84e7818cc3713bcf305"}, + {file = "fastapi-0.115.6.tar.gz", hash = "sha256:9ec46f7addc14ea472958a96aae5b5de65f39721a46aaf5705c480d9a8b76654"}, +] + +[package.dependencies] +email-validator = {version = ">=2.0.0", optional = true, markers = "extra == \"standard\""} +fastapi-cli = {version = ">=0.0.5", extras = ["standard"], optional = true, markers = "extra == \"standard\""} +httpx = {version = ">=0.23.0", optional = true, markers = "extra == \"standard\""} +jinja2 = {version = ">=2.11.2", optional = true, markers = "extra == \"standard\""} +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +python-multipart = {version = ">=0.0.7", optional = true, markers = "extra == \"standard\""} +starlette = ">=0.40.0,<0.42.0" +typing-extensions = ">=4.8.0" +uvicorn = {version = ">=0.12.0", extras = ["standard"], optional = true, markers = "extra == \"standard\""} + +[package.extras] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] + +[[package]] +name = "fastapi-cli" +version = "0.0.6" +description = "Run and manage FastAPI apps from the command line with FastAPI CLI. 🚀" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi_cli-0.0.6-py3-none-any.whl", hash = "sha256:43288efee46338fae8902f9bf4559aed3aed639f9516f5d394a7ff19edcc8faf"}, + {file = "fastapi_cli-0.0.6.tar.gz", hash = "sha256:2835a8f0c44b68e464d5cafe5ec205265f02dc1ad1d640db33a994ba3338003b"}, +] + +[package.dependencies] +rich-toolkit = ">=0.11.1" +typer = ">=0.12.3" +uvicorn = {version = ">=0.15.0", extras = ["standard"]} + +[package.extras] +standard = ["uvicorn[standard] (>=0.15.0)"] + [[package]] name = "filelock" version = "3.16.1" @@ -3084,6 +3183,20 @@ files = [ [package.dependencies] six = "*" +[[package]] +name = "isort" +version = "5.13.2" +description = "A Python utility / library to sort Python imports." 
+optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, + {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, +] + +[package.extras] +colors = ["colorama (>=0.4.6)"] + [[package]] name = "jedi" version = "0.19.1" @@ -5090,6 +5203,23 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pyee" +version = "12.1.1" +description = "A rough port of Node.js's EventEmitter to Python with a few tricks of its own" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyee-12.1.1-py3-none-any.whl", hash = "sha256:18a19c650556bb6b32b406d7f017c8f513aceed1ef7ca618fb65de7bd2d347ef"}, + {file = "pyee-12.1.1.tar.gz", hash = "sha256:bbc33c09e2ff827f74191e3e5bbc6be7da02f627b7ec30d86f5ce1a6fb2424a3"}, +] + +[package.dependencies] +typing-extensions = "*" + +[package.extras] +dev = ["black", "build", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "sphinx", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"] + [[package]] name = "pygments" version = "2.18.0" @@ -5251,6 +5381,24 @@ pluggy = ">=1.5,<2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-box" version = "7.2.0" @@ -5336,6 +5484,17 @@ files = [ {file = "python_jsonpath-1.2.0.tar.gz", hash = "sha256:a29a84ec3ac38e5dcaa62ac2a215de72c4eb60cb1303e10700da980cf7873775"}, ] +[[package]] +name = "python-multipart" +version = "0.0.19" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.19-py3-none-any.whl", hash = "sha256:f8d5b0b9c618575bf9df01c684ded1d94a338839bdd8223838afacfb4bb2082d"}, + {file = "python_multipart-0.0.19.tar.gz", hash = "sha256:905502ef39050557b7a6af411f454bc19526529ca46ae6831508438890ce12cc"}, +] + [[package]] name = "python-slugify" version = "8.0.4" @@ -5684,6 +5843,22 @@ pygments = ">=2.13.0,<3.0.0" [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rich-toolkit" +version = "0.12.0" +description = "Rich toolkit for building command-line applications" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rich_toolkit-0.12.0-py3-none-any.whl", hash = "sha256:a2da4416384410ae871e890db7edf8623e1f5e983341dbbc8cc03603ce24f0ab"}, + {file = "rich_toolkit-0.12.0.tar.gz", hash = "sha256:facb0b40418010309f77abd44e2583b4936656f6ee5c8625da807564806a6c40"}, +] + +[package.dependencies] +click = ">=8.1.7" +rich = ">=13.7.1" +typing-extensions = ">=4.12.2" + [[package]] name = "rpds-py" version = "0.20.0" 
@@ -7524,4 +7699,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.12,<3.13" -content-hash = "7d1a62623fd1b281c059d34ccd159608538429f6546d10b583875658eb9e96cb" +content-hash = "028048c5dc20685000a5eb0b290d69be45f247ccd606b2bb06eb15e55b828c2a" diff --git a/pyproject.toml b/pyproject.toml index 8539117c5..bf63fba0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,11 @@ dagster-k8s = "^0.24.6" pyiceberg = { extras = ["hive"], version = "^0.7.1" } connectorx = "^0.4.0" bokeh = "^3.6.1" +fastapi = { extras = ["standard"], version = "^0.115.6" } +pyee = "^12.1.1" +aiotrino = "^0.2.3" +pytest-asyncio = "^0.24.0" +isort = "^5.13.2" [tool.poetry.scripts] @@ -136,3 +141,6 @@ exclude = [ "warehouse/oso_dagster/dlt_sources/sql_database/**/*.py", "warehouse/oso_dagster/dlt_sources/sql_database/*.py", ] + +[tool.isort] +profile = "black" diff --git a/warehouse/metrics_tools/compute/__init__.py b/warehouse/metrics_tools/compute/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/warehouse/metrics_tools/compute/cache.py b/warehouse/metrics_tools/compute/cache.py new file mode 100644 index 000000000..8ca6356dd --- /dev/null +++ b/warehouse/metrics_tools/compute/cache.py @@ -0,0 +1,335 @@ +import abc +import asyncio +import copy +import logging +import queue +import typing as t +import uuid + +from aiotrino.dbapi import Connection +from pydantic import BaseModel +from pyee.asyncio import AsyncIOEventEmitter +from sqlglot import exp +from sqlmesh.core.dialect import parse_one + +from .types import ExportReference, ExportType + +logger = logging.getLogger(__name__) + + +class ExportCacheCompletedQueueItem(BaseModel): + table: str + export_table: str + + +class ExportCacheQueueItem(BaseModel): + table: str + + +class DBExportAdapter(abc.ABC): + async def export_table(self, table: str) -> ExportReference: + raise NotImplementedError() + + async def clean_export_table(self, table: str): + raise NotImplementedError() + + +class FakeExportAdapter(DBExportAdapter): + def __init__(self, log_override: t.Optional[logging.Logger] = None): + self.logger = log_override or logger + + async def export_table(self, table: str) -> ExportReference: + self.logger.info(f"fake exporting table: {table}") + return ExportReference( + table=table, type=ExportType.GCS, payload={"gcs_path": "fake_path:{table}"} + ) + + async def clean_export_table(self, table: str): + pass + + +class TrinoExportAdapter(DBExportAdapter): + def __init__( + self, + db: Connection, + gcs_bucket: str, + hive_catalog: str, + hive_schema: str, + log_override: t.Optional[logging.Logger] = None, + ): + self.db = db + self.gcs_bucket = gcs_bucket + self.hive_catalog = hive_catalog + self.hive_schema = hive_schema + self.logger = log_override or logger + + async def export_table(self, table: str) -> ExportReference: + columns: t.List[t.Tuple[str, str]] = [] + + col_result = await self.run_query(f"SHOW COLUMNS FROM {table}") + + for row in col_result: + column_name = row[0] + column_type = row[1] + columns.append((column_name, column_type)) + + table_exp = exp.to_table(table) + self.logger.debug(f"retrieved columns for {table} export: {columns}") + export_table_name = f"export_{table_exp.this.this}_{uuid.uuid4().hex}" + + gcs_path = f"gs://{self.gcs_bucket}/trino-export/{export_table_name}/" + + # We use a little bit of a hybrid templating+sqlglot magic to generate + # the create and insert queries. 
This saves us having to figure out the + # exact sqlglot objects + base_create_query = f""" + CREATE table "{self.hive_catalog}"."{self.hive_schema}"."{export_table_name}" ( + placeholder VARCHAR, + ) WITH ( + format = 'PARQUET', + external_location = '{gcs_path}' + ) + """ + + # Parse the create query + create_query = parse_one(base_create_query) + # Rewrite the column definitions we need to rewrite + create_query.this.set( + "expressions", + [ + exp.ColumnDef( + this=exp.to_identifier(column_name), + kind=parse_one(column_type, into=exp.DataType), + ) + for column_name, column_type in columns + ], + ) + + # Execute the create query which will create the export table + await self.run_query(create_query.sql(dialect="trino")) + + # Again using a hybrid templating+sqlglot magic to generate the insert + # for the export table + base_insert_query = f""" + INSERT INTO "{self.hive_catalog}"."{self.hive_schema}"."{export_table_name}" (placeholder) + SELECT placeholder + FROM {table_exp} + """ + + column_identifiers = [ + exp.to_identifier(column_name) for column_name, _ in columns + ] + + # Rewrite the column identifiers in the insert into statement + insert_query = parse_one(base_insert_query) + insert_query.this.set( + "expressions", + column_identifiers, + ) + + # Rewrite the column identifiers in the select statement + select = t.cast(exp.Select, insert_query.expression) + select.set("expressions", column_identifiers) + + # Execute the insert query which will populate the export table + await self.run_query(insert_query.sql(dialect="trino")) + + return ExportReference( + table=table, type=ExportType.GCS, payload={"gcs_path": gcs_path} + ) + + async def run_query(self, query: str): + cursor = await self.db.cursor() + self.logger.info(f"EXECUTING: {query}") + await cursor.execute(query) + return await cursor.fetchall() + + async def clean_export_table(self, table: str): + pass + + +def setup_trino_cache_export_manager( + db: Connection, + gcs_bucket: str, + hive_catalog: str, + hive_schema: str, + preloaded_exported_map: t.Optional[t.Dict[str, ExportReference]] = None, + log_override: t.Optional[logging.Logger] = None, +): + adapter = TrinoExportAdapter( + db, gcs_bucket, hive_catalog, hive_schema, log_override=log_override + ) + return CacheExportManager.setup( + export_adapter=adapter, + preloaded_exported_map=preloaded_exported_map, + log_override=log_override, + ) + + +def setup_fake_cache_export_manager( + preloaded_exported_map: t.Optional[t.Dict[str, ExportReference]] = None, + log_override: t.Optional[logging.Logger] = None, +): + adapter = FakeExportAdapter() + return CacheExportManager.setup( + export_adapter=adapter, + preloaded_exported_map=preloaded_exported_map, + log_override=log_override, + ) + + +class CacheExportManager: + """Manages the export of tables to a cache location. For now this only + supports GCS and can be used easily by duckdb or other compute resources. + This is necessary because pyiceberg and duckdb's iceberg libraries are quite + slow at processing the lakehouse data directly. In the future we'd simply + want to use iceberg directly but for now this is a necessary workaround. + + This class requires a database export adapter. The adapter is called to + trigger the database export. Once the export is completed any consumers of + this export manager can listen for the `exported_table` event to know when + the export is complete. 
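+
+    Illustrative usage from an async context (a sketch only; the table names
+    below are placeholders, not tables that exist in this repository):
+
+        manager = await setup_fake_cache_export_manager()
+        refs = await manager.resolve_export_references(
+            ["db.table_a", "db.table_b"]
+        )
+        # `refs` maps each requested table to an ExportReference once its
+        # export has completed (or immediately if it was already cached).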
+ """ + + export_queue_task: asyncio.Task + export_queue: asyncio.Queue[ExportCacheQueueItem] + event_emitter: AsyncIOEventEmitter + + @classmethod + async def setup( + cls, + export_adapter: DBExportAdapter, + preloaded_exported_map: t.Optional[t.Dict[str, ExportReference]] = None, + log_override: t.Optional[logging.Logger] = None, + ): + cache = cls( + export_adapter, + preloaded_exported_map=preloaded_exported_map, + log_override=log_override, + ) + await cache.start() + return cache + + def __init__( + self, + export_adapter: DBExportAdapter, + preloaded_exported_map: t.Optional[t.Dict[str, ExportReference]] = None, + log_override: t.Optional[logging.Logger] = None, + ): + self.exported_map: t.Dict[str, ExportReference] = preloaded_exported_map or {} + self.export_adapter = export_adapter + self.exported_map_lock = asyncio.Lock() + self.export_queue: asyncio.Queue[ExportCacheQueueItem] = asyncio.Queue() + self.export_completed_queue: queue.Queue[ExportCacheCompletedQueueItem] = ( + queue.Queue() + ) + self.stop_signal = asyncio.Event() + self.logger = log_override or logger + self.event_emitter = AsyncIOEventEmitter() + + async def start(self): + self.export_queue_task = asyncio.create_task(self.export_queue_loop()) + + async def stop(self): + self.stop_signal.set() + await self.export_queue_task + + async def export_queue_loop(self): + in_progress: t.Set[str] = set() + + async def export_table(table: str): + try: + return await self._export_table_for_cache(table) + except Exception as e: + self.logger.error(f"Error exporting table {table}: {e}") + in_progress.remove(table) + + while not self.stop_signal.is_set(): + try: + item = await asyncio.wait_for(self.export_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + if item.table in in_progress: + # The table is already being exported. Skip this in the queue + continue + in_progress.add(item.table) + export_reference = await export_table(item.table) + self.event_emitter.emit( + "exported_table", table=item.table, export_reference=export_reference + ) + self.export_queue.task_done() + + async def add_export_table_reference( + self, table: str, export_reference: ExportReference + ): + await self.add_export_table_references({table: export_reference}) + + async def add_export_table_references( + self, table_map: t.Dict[str, ExportReference] + ): + async with self.exported_map_lock: + self.exported_map.update(table_map) + + async def inspect_export_table_references(self): + async with self.exported_map_lock: + return copy.deepcopy(self.exported_map) + + async def get_export_table_reference(self, table: str): + async with self.exported_map_lock: + reference = self.exported_map.get(table) + if not reference: + return None + return copy.deepcopy(reference) + + async def _export_table_for_cache(self, table: str): + """Triggers an export of a table to a cache location in GCS. This does + this by using the Hive catalog in trino to create a new table with the + same schema as the original table, but with a different name. This new + table is then used as the cache location for the original table.""" + + export_reference = await self.export_adapter.export_table(table) + self.logger.info(f"exported table: {table} -> {export_reference}") + return export_reference + + async def resolve_export_references(self, tables: t.List[str]): + """Resolves any required export table references or queues up a list of + tables to be exported to a cache location. 
Once ready, the map of tables + is resolved.""" + future: asyncio.Future[t.Dict[str, ExportReference]] = ( + asyncio.get_event_loop().create_future() + ) + + tables_to_export = set(tables) + registration = None + export_map: t.Dict[str, ExportReference] = {} + + # Ensure we are only comparing unique tables + for table in set(tables): + reference = await self.get_export_table_reference(table) + if reference is not None: + export_map[table] = reference + tables_to_export.remove(table) + if len(tables_to_export) == 0: + return export_map + + self.logger.info(f"unknown tables to export: {tables_to_export}") + + async def handle_exported_table( + *, table: str, export_reference: ExportReference + ): + self.logger.info(f"exported table ready: {table} -> {export_reference}") + if table in tables_to_export: + tables_to_export.remove(table) + export_map[table] = export_reference + await self.add_export_table_reference(table, export_reference) + if len(tables_to_export) == 0: + future.set_result(export_map) + if registration: + self.event_emitter.remove_listener("exported_table", registration) + + registration = self.event_emitter.add_listener( + "exported_table", handle_exported_table + ) + for table in tables_to_export: + self.export_queue.put_nowait(ExportCacheQueueItem(table=table)) + return await future diff --git a/warehouse/metrics_tools/compute/client.py b/warehouse/metrics_tools/compute/client.py new file mode 100644 index 000000000..aaaaca9fc --- /dev/null +++ b/warehouse/metrics_tools/compute/client.py @@ -0,0 +1,236 @@ +"""Metrics Calculation Service Client""" + +import logging +import time +import typing as t +from datetime import datetime + +import requests +from metrics_tools.compute.types import ( + ClusterStartRequest, + ClusterStatus, + EmptyResponse, + ExportedTableLoadRequest, + ExportReference, + InspectCacheResponse, + QueryJobStatus, + QueryJobStatusResponse, + QueryJobSubmitRequest, + QueryJobSubmitResponse, +) +from metrics_tools.definition import PeerMetricDependencyRef +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +logger = logging.getLogger(__name__) + + +class ResponseObject[T](t.Protocol): + def model_validate(self, obj: dict) -> T: ... + + +class Client: + """A metrics calculation service client""" + + url: str + logger: logging.Logger + + def __init__(self, url: str, log_override: t.Optional[logging.Logger] = None): + self.url = url + self.logger = log_override or logger + + def calculate_metrics( + self, + query_str: str, + start: datetime, + end: datetime, + dialect: str, + batch_size: int, + columns: t.List[t.Tuple[str, str]], + ref: PeerMetricDependencyRef, + locals: t.Dict[str, t.Any], + dependent_tables_map: t.Dict[str, str], + cluster_min_size: int = 6, + cluster_max_size: int = 6, + retries: t.Optional[int] = None, + ): + """Calculate metrics for a given period and write the results to a gcs + folder. This method is a high level method that triggers all of the + necessary calls to complete a metrics calculation. Namely: + + 1. Tell the metrics calculation service to start a compute cluster + 2. Submit a job to the service + 3. Wait for the job to complete (and log progress) + 4. 
Return the gcs result path + + Args: + query_str (str): The query to execute + start (datetime): The start date + end (datetime): The end date + dialect (str): The sql dialect for the provided query + batch_size (int): The batch size + columns (t.List[t.Tuple[str, str]]): The columns to expect + ref (PeerMetricDependencyRef): The dependency reference + locals (t.Dict[str, t.Any]): The local variables to use + dependent_tables_map (t.Dict[str, str]): The dependent tables map + retries (t.Optional[int], optional): The number of retries. Defaults to None. + + Returns: + ExportReference: The export reference for the resulting calculation + """ + # Trigger the cluster start + status = self.start_cluster( + min_size=cluster_min_size, max_size=cluster_max_size + ) + self.logger.info(f"cluster status: {status}") + + job_response = self.submit_job( + query_str, + start, + end, + dialect, + batch_size, + columns, + ref, + locals, + dependent_tables_map, + retries, + ) + job_id = job_response.job_id + export_reference = job_response.export_reference + + # Wait for the job to be completed + status_response = self.get_job_status(job_id) + while status_response.status in [ + QueryJobStatus.PENDING, + QueryJobStatus.RUNNING, + ]: + self.logger.info(f"job[{job_id}] is still pending") + status_response = self.get_job_status(job_id) + time.sleep(5) + self.logger.info(f"job[{job_id}] status is {status_response}") + + if status_response.status == QueryJobStatus.FAILED: + self.logger.error( + f"job[{job_id}] failed with status {status_response.status}" + ) + raise Exception( + f"job[{job_id}] failed with status {status_response.status}" + ) + + self.logger.info( + f"job[{job_id}] completed with status {status_response.status}" + ) + + return export_reference + + def start_cluster(self, min_size: int, max_size: int): + """Start a compute cluster with the given min and max size""" + request = ClusterStartRequest(min_size=min_size, max_size=max_size) + response = self.service_post_with_input( + ClusterStatus, "/cluster/start", request + ) + return response + + def submit_job( + self, + query_str: str, + start: datetime, + end: datetime, + dialect: str, + batch_size: int, + columns: t.List[t.Tuple[str, str]], + ref: PeerMetricDependencyRef, + locals: t.Dict[str, t.Any], + dependent_tables_map: t.Dict[str, str], + retries: t.Optional[int] = None, + ): + """Submit a job to the metrics calculation service + + Args: + query_str (str): The query to execute + start (datetime): The start date + end (datetime): The end date + dialect (str): The sql dialect for the provided query + batch_size (int): The batch size + columns (t.List[t.Tuple[str, str]]): The columns to expect + ref (PeerMetricDependencyRef): The dependency reference + locals (t.Dict[str, t.Any]): The local variables to use + dependent_tables_map (t.Dict[str, str]): The dependent tables map + retries (t.Optional[int], optional): The number of retries. Defaults to None. 
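+
+        Example (an illustrative sketch; the URL, dates, and table mapping
+        below are placeholders rather than values used by the service):
+
+            client = Client("http://localhost:8000")
+            response = client.submit_job(
+                query_str="SELECT bucket_day, SUM(amount) AS amount FROM metrics.events_daily_to_artifact GROUP BY bucket_day",
+                start=datetime(2024, 1, 1),
+                end=datetime(2024, 1, 31),
+                dialect="duckdb",
+                batch_size=10,
+                columns=[("bucket_day", "TIMESTAMP"), ("amount", "NUMERIC")],
+                ref=PeerMetricDependencyRef(
+                    name="", entity_type="artifact", window=30, unit="day"
+                ),
+                locals={},
+                dependent_tables_map={
+                    "metrics.events_daily_to_artifact": "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958"
+                },
+            )
+            status = client.get_job_status(response.job_id)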
+ + Returns: + QueryJobSubmitResponse: The job response from the metrics calculation service + """ + request = QueryJobSubmitRequest( + query_str=query_str, + start=start, + end=end, + dialect=dialect, + batch_size=batch_size, + columns=columns, + ref=ref, + locals=locals, + dependent_tables_map=dependent_tables_map, + retries=retries, + ) + job_response = self.service_post_with_input( + QueryJobSubmitResponse, "/job/submit", request + ) + return job_response + + def get_job_status(self, job_id: str): + """Get the status of a job""" + return self.service_get(QueryJobStatusResponse, f"/job/status/{job_id}") + + def run_cache_manual_load(self, map: t.Dict[str, ExportReference]): + """Load a cache with the provided map. This is useful for testing + purposes but generally shouldn't be used in production""" + req = ExportedTableLoadRequest(map=map) + return self.service_post_with_input(EmptyResponse, "/cache/manual", req) + + def inspect_cache(self): + """Inspect the cached export tables for the service""" + return self.service_get(InspectCacheResponse, "/cache/inspect") + + def service_request[ + T + ](self, method: str, factory: ResponseObject[T], path: str, **kwargs) -> T: + response = requests.request( + method, + f"{self.url}{path}", + **kwargs, + ) + return factory.model_validate(response.json()) + + def service_post_with_input[ + T + ]( + self, + factory: ResponseObject[T], + path: str, + input: BaseModel, + params: t.Optional[t.Dict[str, t.Any]] = None, + ) -> T: + return self.service_request( + "POST", + factory, + path, + json=to_jsonable_python(input), + params=params, + ) + + def service_get[ + T + ]( + self, + factory: ResponseObject[T], + path: str, + params: t.Optional[t.Dict[str, t.Any]] = None, + ) -> T: + return self.service_request( + "GET", + factory, + path, + params=params, + ) diff --git a/warehouse/metrics_tools/compute/cluster.py b/warehouse/metrics_tools/compute/cluster.py index bdb29b661..98dadc12d 100644 --- a/warehouse/metrics_tools/compute/cluster.py +++ b/warehouse/metrics_tools/compute/cluster.py @@ -1,20 +1,442 @@ """Sets up a dask cluster """ +import abc +import asyncio +import inspect +import logging import typing as t -from dask_kubernetes.operator import KubeCluster + +from dask.distributed import Client +from dask.distributed import Future as DaskFuture +from dask.distributed import LocalCluster +from dask_kubernetes.operator import KubeCluster, make_cluster_spec +from metrics_tools.compute.types import ClusterStatus +from pyee.asyncio import AsyncIOEventEmitter + +from .worker import ( + DuckDBMetricsWorkerPlugin, + DummyMetricsWorkerPlugin, + MetricsWorkerPlugin, +) + +logger = logging.getLogger(__name__) def start_duckdb_cluster( namespace: str, - gcs_key_id: str, - gcs_secret: str, - duckdb_path: str, cluster_spec: t.Optional[dict] = None, + min_size: int = 6, + max_size: int = 6, + quiet: bool = False, + **kwargs: t.Any, +): + options: t.Dict[str, t.Any] = {"namespace": namespace} + options.update(kwargs) + print("starting duckdb cluster") + if cluster_spec: + options["custom_cluster_spec"] = cluster_spec + print(cluster_spec) + print("starting duckdb cluster1") + cluster = KubeCluster(quiet=quiet, **options) + print(f"starting duckdb cluster with min_size={min_size} and max_size={max_size}") + cluster.adapt(minimum=min_size, maximum=max_size) + return cluster + + +async def start_duckdb_cluster_async( + namespace: str, + cluster_spec: t.Optional[dict] = None, + min_size: int = 6, + max_size: int = 6, + **kwargs: t.Any, ): + """The async version of 
start_duckdb_cluster which wraps the sync version in + a thread. The "async" version of dask's KubeCluster doesn't work as + expected. So for now we do this.""" + options: t.Dict[str, t.Any] = {"namespace": namespace} + options.update(kwargs) if cluster_spec: options["custom_cluster_spec"] = cluster_spec - cluster = KubeCluster(**options) - cluster.adapt(minimum=6, maximum=6) + + # loop = asyncio.get_running_loop() + cluster = await KubeCluster(asynchronous=True, **options) + print(f"is cluster awaitable?: {inspect.isawaitable(cluster)}") + adapt_response = cluster.adapt(minimum=min_size, maximum=max_size) + print(f"is adapt_response awaitable?: {inspect.isawaitable(adapt_response)}") + if inspect.isawaitable(adapt_response): + await adapt_response return cluster + + # return await asyncio.to_thread( + # start_duckdb_cluster, namespace, cluster_spec, min_size, max_size + # ) + + +class ClusterProxy(abc.ABC): + async def client(self) -> Client: + raise NotImplementedError("client not implemented") + + async def status(self) -> ClusterStatus: + raise NotImplementedError("status not implemented") + + async def stop(self): + raise NotImplementedError("stop not implemented") + + @property + def dashboard_link(self): + raise NotImplementedError("dashboard_link not implemented") + + @property + def workers(self): + raise NotImplementedError("workers not implemented") + + +class ClusterFactory(abc.ABC): + async def create_cluster(self, min_size: int, max_size: int) -> ClusterProxy: + raise NotImplementedError("start_cluster not implemented") + + +class LocalClusterProxy(ClusterProxy): + def __init__(self, cluster: LocalCluster): + self.cluster = cluster + + async def client(self) -> Client: + return await Client(self.cluster, asynchronous=True) + + async def status(self) -> ClusterStatus: + return ClusterStatus( + status="Cluster running", + is_ready=True, + dashboard_url=self.cluster.dashboard_link, + workers=len(self.cluster.scheduler_info["workers"]), + ) + + async def stop(self): + self.cluster.close() + + @property + def dashboard_link(self): + return self.cluster.dashboard_link + + @property + def workers(self): + return len(self.cluster.scheduler_info["workers"]) + + +class KubeClusterProxy(ClusterProxy): + def __init__(self, cluster: KubeCluster): + self.cluster = cluster + + async def client(self) -> Client: + return await Client(self.cluster, asynchronous=True) + + async def status(self) -> ClusterStatus: + return ClusterStatus( + status="Cluster running", + is_ready=True, + dashboard_url=self.cluster.dashboard_link, + workers=len(self.cluster.scheduler_info["workers"]), + ) + + async def stop(self): + await self.cluster.close() + + @property + def dashboard_link(self): + return self.cluster.dashboard_link + + @property + def workers(self): + return len(self.cluster.scheduler_info["workers"]) + + +class LocalClusterFactory(ClusterFactory): + async def create_cluster(self, min_size: int, max_size: int) -> ClusterProxy: + return LocalClusterProxy( + await LocalCluster(n_workers=max_size, asynchronous=True) + ) + + +class KubeClusterFactory(ClusterFactory): + def __init__( + self, + namespace: str, + cluster_spec: t.Optional[dict] = None, + log_override: t.Optional[logging.Logger] = None, + **kwargs: t.Any, + ): + self._namespace = namespace + self.logger = log_override or logger + self._cluster_spec = cluster_spec + self.kwargs = kwargs + + async def create_cluster(self, min_size: int, max_size: int): + cluster = await start_duckdb_cluster_async( + self._namespace, self._cluster_spec, 
min_size, max_size, **self.kwargs + ) + return KubeClusterProxy(cluster) + + +class ClusterManager: + """Internal metrics worker cluster manager""" + + event_emitter: AsyncIOEventEmitter + + @classmethod + def with_metrics_plugin( + cls, + gcs_bucket: str, + gcs_key_id: str, + gcs_secret: str, + duckdb_path: str, + cluster_factory: ClusterFactory, + log_override: t.Optional[logging.Logger] = None, + ): + def plugin_factory(): + return DuckDBMetricsWorkerPlugin( + gcs_bucket, gcs_key_id, gcs_secret, duckdb_path + ) + + return cls(plugin_factory, cluster_factory, log_override) + + @classmethod + def with_dummy_metrics_plugin( + cls, + cluster_factory: ClusterFactory, + log_override: t.Optional[logging.Logger] = None, + ): + def plugin_factory(): + return DummyMetricsWorkerPlugin() + + return cls(plugin_factory, cluster_factory, log_override) + + def __init__( + self, + plugin_factory: t.Callable[[], MetricsWorkerPlugin], + cluster_factory: ClusterFactory, + log_override: t.Optional[logging.Logger] = None, + ): + self.plugin_factory = plugin_factory + self._cluster: t.Optional[ClusterProxy] = None + self._client: t.Optional[Client] = None + self.logger = log_override or logger + self._starting = False + self._lock = asyncio.Lock() + self.factory = cluster_factory + self._start_task: t.Optional[asyncio.Task] = None + self.event_emitter = AsyncIOEventEmitter() + + async def start_cluster(self, min_size: int, max_size: int) -> ClusterStatus: + async with self._lock: + if self._cluster is not None: + self.logger.info("cluster already running") + return ClusterStatus( + status="Cluster already running", + is_ready=True, + dashboard_url=self._cluster.dashboard_link, + workers=self._cluster.workers, + ) + if self._starting: + self.logger.info("cluster already starting") + return ClusterStatus( + status="Cluster starting", + is_ready=False, + dashboard_url="", + workers=0, + ) + self.logger.info("cluster not running, starting") + self._starting = True + self._start_task = asyncio.create_task( + self._start_cluster_internal(min_size, max_size) + ) + return ClusterStatus( + status="Cluster starting", + is_ready=False, + dashboard_url="", + workers=0, + ) + + async def _start_cluster_internal(self, min_size: int, max_size: int): + self.logger.info("starting cluster") + try: + self.logger.debug("calling create_cluster factory") + cluster = await self.factory.create_cluster(min_size, max_size) + self.logger.debug("getting the client") + client = await cluster.client() + registration = client.register_worker_plugin( + self.plugin_factory(), + name="metrics", + ) + self.logger.debug(f"What type is {type(registration)}") + if isinstance(registration, asyncio.Future): + self.logger.debug("registration is an async future") + await registration + if isinstance(registration, DaskFuture): + self.logger.debug("registration is a dask future") + await registration + if isinstance(registration, t.Coroutine): + self.logger.debug("registration is a coroutine") + await registration + self.logger.debug("done registring the client plugin") + async with self._lock: + self._cluster = cluster + self._client = client + except Exception as e: + self.logger.error(f"Failed to start cluster: {e}") + async with self._lock: + self._starting = False + raise e + async with self._lock: + self._starting = False + self.logger.debug("emitting cluster_ready") + self.event_emitter.emit("cluster_ready") + + async def stop_cluster(self): + self.logger.info("stopping cluster") + async with self._lock: + if self._cluster is not None: + await 
self._cluster.stop() + self._cluster = None + self._client = None + + async def get_cluster_status(self): + async with self._lock: + if self._cluster is None: + return ClusterStatus( + status="Cluster not started", + is_ready=False, + dashboard_url="", + workers=0, + ) + return ClusterStatus( + status="Cluster running", + is_ready=True, + dashboard_url=self._cluster.dashboard_link, + workers=self._cluster.workers, + ) + + @property + async def client(self): + self.logger.debug("getting client") + async with self._lock: + client = self._client + assert client is not None, "Client hasn't been initialized" + return client + + async def close(self): + if self._start_task and not self._start_task.done(): + await self._start_task + + if self._cluster: + await self._cluster.stop() + + async def wait_for_ready(self) -> bool: + async with self._lock: + if self._cluster is not None: + self.logger.debug("no wait needed, cluster is ready") + return True + + future: asyncio.Future[bool] = asyncio.Future() + + def cluster_ready(): + self.logger.info("cluster is ready received") + future.set_result(True) + + self.event_emitter.once("cluster_ready", cluster_ready) + return await future + + +def make_new_cluster( + image: str, + cluster_id: str, + service_account_name: str, + threads: int, + scheduler_memory_request: str, + scheduler_memory_limit: str, + worker_memory_request: str, + worker_memory_limit: str, +): + spec = make_cluster_spec( + name=f"{cluster_id}", + resources={ + "requests": {"memory": scheduler_memory_request}, + "limits": {"memory": scheduler_memory_limit}, + }, + image=image, + ) + spec["spec"]["scheduler"]["spec"]["tolerations"] = [ + { + "key": "pool_type", + "effect": "NoSchedule", + "operator": "Equal", + "value": "sqlmesh-worker", + } + ] + spec["spec"]["scheduler"]["spec"]["nodeSelector"] = {"pool_type": "sqlmesh-worker"} + + spec["spec"]["worker"]["spec"]["tolerations"] = [ + { + "key": "pool_type", + "effect": "NoSchedule", + "operator": "Equal", + "value": "sqlmesh-worker", + } + ] + spec["spec"]["worker"]["spec"]["nodeSelector"] = {"pool_type": "sqlmesh-worker"} + + # Give the workers a different resource allocation + for container in spec["spec"]["worker"]["spec"]["containers"]: + container["resources"] = { + "limits": { + "memory": worker_memory_limit, + }, + "requests": { + "memory": worker_memory_request, + }, + } + volume_mounts = container.get("volumeMounts", []) + volume_mounts.append( + { + "mountPath": "/scratch", + "name": "scratch", + } + ) + if container["name"] == "worker": + args: t.List[str] = container["args"] + args.append("--nthreads") + args.append(f"{threads}") + args.append("--nworkers") + args.append("1") + args.append("--memory-limit") + args.append("0") + container["volumeMounts"] = volume_mounts + volumes = spec["spec"]["worker"]["spec"].get("volumes", []) + volumes.append( + { + "name": "scratch", + "emptyDir": {}, + } + ) + spec["spec"]["worker"]["spec"]["volumes"] = volumes + spec["spec"]["worker"]["spec"]["serviceAccountName"] = service_account_name + + return spec + + +def make_new_cluster_with_defaults(): + # Import here to avoid dependency on constants for all dependents on the + # cluster module + from . 
import constants + + return make_new_cluster( + f"{constants.cluster_worker_image_repo}:{constants.cluster_worker_image_tag}", + constants.cluster_name, + constants.cluster_namespace, + threads=constants.cluster_worker_threads, + scheduler_memory_limit=constants.scheduler_memory_limit, + scheduler_memory_request=constants.scheduler_memory_request, + worker_memory_limit=constants.worker_memory_limit, + worker_memory_request=constants.worker_memory_request, + ) diff --git a/warehouse/metrics_tools/compute/constants.py b/warehouse/metrics_tools/compute/constants.py new file mode 100644 index 000000000..17eeaa1c8 --- /dev/null +++ b/warehouse/metrics_tools/compute/constants.py @@ -0,0 +1,48 @@ +from dotenv import load_dotenv +from metrics_tools.utils import env + +load_dotenv() + +cluster_namespace = env.required_str("METRICS_CLUSTER_NAMESPACE") +cluster_name = env.required_str("METRICS_CLUSTER_NAME") +cluster_worker_threads = env.required_int("METRICS_CLUSTER_WORKER_THREADS", 16) +cluster_worker_image_repo = env.required_str( + "METRICS_CLUSTER_WORKER_IMAGE_REPO", "ghcr.io/opensource-observer/dagster-dask" +) +cluster_worker_image_tag = env.required_str("METRICS_CLUSTER_WORKER_IMAGE_TAG") +scheduler_memory_limit = env.required_str("METRICS_SCHEDULER_MEMORY_LIMIT", "90000Mi") +scheduler_memory_request = env.required_str( + "METRICS_SCHEDULER_MEMORY_REQUEST", "85000Mi" +) +worker_memory_limit = env.required_str("METRICS_WORKER_MEMORY_LIMIT", "90000Mi") +worker_memory_request = env.required_str("METRICS_WORKER_MEMORY_REQUEST", "85000Mi") + +gcs_bucket = env.required_str("METRICS_GCS_BUCKET") +gcs_key_id = env.required_str("METRICS_GCS_KEY_ID") +gcs_secret = env.required_str("METRICS_GCS_SECRET") + +results_path_prefix = env.required_str( + "METRICS_GCS_RESULTS_PATH_PREFIX", "metrics-calc-service-results" +) + +worker_duckdb_path = env.required_str("METRICS_WORKER_DUCKDB_PATH") + +trino_host = env.required_str("METRICS_TRINO_HOST") +trino_port = env.required_str("METRICS_TRINO_PORT") +trino_user = env.required_str("METRICS_TRINO_USER") +trino_catalog = env.required_str("METRICS_TRINO_CATALOG") + +hive_catalog = env.required_str("METRICS_HIVE_CATALOG", "source") +hive_schema = env.required_str("METRICS_HIVE_SCHEMA", "export") + +debug_all = env.ensure_bool("METRICS_DEBUG_ALL", False) +if not debug_all: + debug_cache = env.ensure_bool("METRICS_DEBUG_CACHE", False) + debug_cluster = env.ensure_bool("METRICS_DEBUG_CLUSTER", False) + debug_cluster_no_shutdown = env.ensure_bool( + "METRICS_DEBUG_CLUSTER_NO_SHUTDOWN", False + ) +else: + debug_cache = debug_all + debug_cluster = debug_all + debug_cluster_no_shutdown = debug_all diff --git a/warehouse/metrics_tools/compute/debug.py b/warehouse/metrics_tools/compute/debug.py new file mode 100644 index 000000000..ee377aee1 --- /dev/null +++ b/warehouse/metrics_tools/compute/debug.py @@ -0,0 +1,30 @@ +"""Random manual debugging utilities""" + +import asyncio +import logging + +from metrics_tools.compute.cluster import ( + KubeClusterFactory, + make_new_cluster_with_defaults, + start_duckdb_cluster, +) + +from . 
import constants + +logger = logging.getLogger(__name__) + + +def test_setup_cluster(): + cluster_spec = make_new_cluster_with_defaults() + return start_duckdb_cluster(constants.cluster_namespace, cluster_spec) + + +def async_test_setup_cluster(): + cluster_spec = make_new_cluster_with_defaults() + + cluster_factory = KubeClusterFactory( + constants.cluster_namespace, + cluster_spec=cluster_spec, + log_override=logger, + ) + asyncio.run(cluster_factory.create_cluster(2, 2)) diff --git a/warehouse/metrics_tools/compute/flight.py b/warehouse/metrics_tools/compute/flight.py deleted file mode 100644 index d7fbfcaa6..000000000 --- a/warehouse/metrics_tools/compute/flight.py +++ /dev/null @@ -1,706 +0,0 @@ -"""A python arrow service that is used to as a proxy to the cluster of compute for the -metrics tools. This allows us to change the underlying compute infrastructure -while maintaining the same interface to the sqlmesh runner. -""" - -import concurrent.futures -import logging -import sys -import typing as t -import json -import click -import time -import concurrent -from metrics_tools.compute.cluster import start_duckdb_cluster -from metrics_tools.compute.worker import MetricsWorkerPlugin -from metrics_tools.definition import PeerMetricDependencyRef -from metrics_tools.runner import FakeEngineAdapter, MetricsRunner -from metrics_tools.transformer.tables import MapTableTransform -from metrics_tools.transformer.transformer import SQLTransformer -import pyarrow as pa -import pyarrow.flight as fl -import asyncio -import pandas as pd -import threading -import trino -import queue -from sqlglot import exp -from sqlmesh.core.dialect import parse_one -from trino.dbapi import Connection, Cursor -import abc -import uuid -from pydantic import BaseModel -from datetime import datetime -from dask.distributed import Client, get_worker, Future, as_completed, print as dprint -from dask_kubernetes.operator import KubeCluster, make_cluster_spec -from dask_kubernetes.operator.kubecluster.kubecluster import CreateMode -from dataclasses import dataclass - - -logger = logging.getLogger(__name__) - - -type_mapping = { - "INTEGER": "int64", - "BIGINT": "int64", - "SMALLINT": "int32", - "NUMERIC": "float64", - "REAL": "float32", - "DOUBLE PRECISION": "float64", - "VARCHAR": "object", - "TEXT": "object", - "BOOLEAN": "bool", - "DATE": "datetime64[ns]", - "TIMESTAMP": "datetime64[ns]", - # Add more mappings as needed -} - -arrow_type_mapping = { - "INTEGER": pa.int32(), - "BIGINT": pa.int64(), - "SMALLINT": pa.int16(), - "NUMERIC": pa.float64(), - "REAL": pa.float32(), - "DOUBLE PRECISION": pa.float64(), - "VARCHAR": pa.string(), - "TEXT": pa.string(), - "BOOLEAN": pa.bool_(), - "DATE": pa.date32(), - "TIMESTAMP": pa.timestamp("us"), -} - - -class QueryInput(BaseModel): - query_str: str - start: datetime - end: datetime - dialect: str - batch_size: int - columns: t.List[t.Tuple[str, str]] - ref: PeerMetricDependencyRef - locals: t.Dict[str, t.Any] - dependent_tables_map: t.Dict[str, str] - - def to_ticket(self) -> fl.Ticket: - return fl.Ticket(self.model_dump_json()) - - def to_column_names(self) -> pd.Series: - return pd.Series(list(map(lambda a: a[0], self.columns))) - - def to_arrow_schema(self) -> pa.Schema: - schema_input = [ - (col_name, arrow_type_mapping[col_type]) - for col_name, col_type in self.columns - ] - print(schema_input) - return pa.schema(schema_input) - - # def coerce_datetimes(self, df: pd.DataFrame) -> pd.DataFrame: - # for col_name, col_type in self.columns: - # if col_type == - - -class 
Engine(abc.ABC): - def run_query(self, query: str) -> Cursor: - raise NotImplementedError("run_query not implemented") - - -class TrinoEngine(Engine): - @classmethod - def create(cls, host: str, port: int, user: str, catalog: str): - conn = trino.dbapi.connect( - host=host, - port=port, - user=user, - catalog=catalog, - ) - return cls(conn) - - def __init__(self, conn: Connection): - self._conn = conn - - def run_query(self, query: str) -> Cursor: - cursor = self._conn.cursor() - logger.info(f"EXECUTING: {query}") - return cursor.execute(query) - - -def start_loop(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - -def run_coroutine_in_thread(coro): - loop = asyncio.new_event_loop() - thread = threading.Thread(target=start_loop, args=(loop,)) - thread.start() - - -def execute_duckdb_load( - id: int, gcs_path: str, queries: t.List[str], dependencies: t.Dict[str, str] -): - dprint("Starting duckdb load") - worker = get_worker() - plugin = t.cast(MetricsWorkerPlugin, worker.plugins["metrics"]) - for ref, actual in dependencies.items(): - dprint(f"Loading cache for {ref}:{actual}") - plugin.get_for_cache(ref, actual) - conn = plugin.connection - results: t.List[pd.DataFrame] = [] - for query in queries: - result = conn.execute(query).df() - results.append(result) - - return DuckdbLoadedItem( - id=id, - df=pd.concat(results, ignore_index=True, sort=False), - ) - - -@dataclass(kw_only=True) -class DuckdbLoadedItem: - id: int - df: pd.DataFrame - - -@dataclass(kw_only=True) -class ResultQueueItem: - id: int - record_batch: pa.RecordBatch - - -class MetricsCalculatorFlightServer(fl.FlightServerBase): - def __init__( - self, - cluster: KubeCluster, - engine: TrinoEngine, - gcs_bucket: str, - location: str = "grpc://0.0.0.0:8815", - exported_map: t.Optional[t.Dict[str, str]] = None, - downloaders: int = 64, - queue_size: int = 100, - ): - super().__init__(location) - self.data = pa.Table.from_pydict({"col1": [1, 2, 3]}) - self.loop_loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread( - target=start_loop, - args=(self.loop_loop,), - ) - self.loop_thread.start() - self.engine = engine - self.cluster = cluster - self.exported_map: t.Dict[str, str] = exported_map or {} - self.gcs_bucket = gcs_bucket - self.queue_size = queue_size - self.downloader_count = downloaders - - def run_initialization( - self, hive_uri: str, gcs_key_id: str, gcs_secret: str, duckdb_path: str - ): - logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) - client = Client(self.cluster) - self.client = client - client.register_plugin( - MetricsWorkerPlugin( - self.gcs_bucket, - hive_uri, - gcs_key_id, - gcs_secret, - duckdb_path, - ), - name="metrics", - ) - - def finalizer(self): - self.client.close() - - def _ticket_to_query_input(self, ticket: fl.Ticket) -> QueryInput: - return QueryInput(**json.loads(ticket.ticket)) - - def table_rewrite(self, query: str, rewrite_map: t.Dict[str, str]): - transformer = SQLTransformer( - transforms=[ - MapTableTransform(rewrite_map), - ] - ) - return transformer.transform(query) - - def export_table_for_cache(self, table: str): - # Using the actual name - # Export with trino - if table in self.exported_map: - logger.debug(f"CACHE HIT FOR {table}") - return self.exported_map[table] - - columns: t.List[t.Tuple[str, str]] = [] - - col_result = self.engine.run_query(f"SHOW COLUMNS FROM {table}").fetchall() - for row in col_result: - column_name = row[0] - column_type = row[1] - columns.append((column_name, column_type)) - - table_exp = exp.to_table(table) - 
logger.info(f"RETREIVED COLUMNS: {columns}") - export_table_name = f"export_{table_exp.this.this}_{uuid.uuid4().hex}" - - base_create_query = f""" - CREATE table "source"."export"."{export_table_name}" ( - placeholder VARCHAR, - ) WITH ( - format = 'PARQUET', - external_location = 'gs://{self.gcs_bucket}/trino-export/{export_table_name}/' - ) - """ - create_query = parse_one(base_create_query) - create_query.this.set( - "expressions", - [ - exp.ColumnDef( - this=exp.to_identifier(column_name), - kind=parse_one(column_type, into=exp.DataType), - ) - for column_name, column_type in columns - ], - ) - - self.engine.run_query(create_query.sql(dialect="trino")) - - base_insert_query = f""" - INSERT INTO "source"."export"."{export_table_name}" (placeholder) - SELECT placeholder - FROM {table_exp} - """ - - column_identifiers = [ - exp.to_identifier(column_name) for column_name, _ in columns - ] - - insert_query = parse_one(base_insert_query) - insert_query.this.set( - "expressions", - column_identifiers, - ) - select = t.cast(exp.Select, insert_query.expression) - select.set("expressions", column_identifiers) - - self.engine.run_query(insert_query.sql(dialect="trino")) - - self.exported_map[table] = export_table_name - return self.exported_map[table] - - # def shutdown(self): - # pass - - def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): - input = self._ticket_to_query_input(ticket) - - exported_dependent_tables_map: t.Dict[str, str] = {} - - # Parse the query - for ref_name, actual_name in input.dependent_tables_map.items(): - # Any deps, use trino to export to gcs - exported_table_name = self.export_table_for_cache(actual_name) - exported_dependent_tables_map[ref_name] = exported_table_name - - # rewrite the query for the temporary caches made by trino - # ex = self.table_rewrite(input.query_str, exported_dependent_tables_map) - # if len(ex) != 1: - # raise Exception("unexpected number of expressions") - - rewritten_query = parse_one(input.query_str).sql(dialect="duckdb") - # columns = input.to_column_names() - - # def gen(): - # futures: t.List[concurrent.futures.Future[pd.DataFrame]] = [] - # for rendered_query in runner.render_rolling_queries(input.start, input.end): - # future = asyncio.run_coroutine_threadsafe( - # async_gen_batch(self.engine, rendered_query, columns), - # self.loop_loop, - # ) - # futures.append(future) - # for res in concurrent.futures.as_completed(futures): - # yield pa.RecordBatch.from_pandas(res.result()) - - def gen_with_dask( - rewritten_query: str, - input: QueryInput, - exported_dependent_tables_map: t.Dict[str, str], - download_queue: queue.Queue[Future], - ): - client = self.client - futures: t.List[Future] = [] - current_batch: t.List[str] = [] - task_ids: t.List[int] = [] - - runner = MetricsRunner.from_engine_adapter( - FakeEngineAdapter("duckdb"), - rewritten_query, - input.ref, - input.locals, - ) - - task_id = 0 - for rendered_query in runner.render_rolling_queries(input.start, input.end): - current_batch.append(rendered_query) - if len(current_batch) >= input.batch_size: - future = client.submit( - execute_duckdb_load, - task_id, - current_batch[:], - exported_dependent_tables_map, - ) - futures.append(future) - current_batch = [] - task_ids.append(task_id) - task_id += 1 - if len(current_batch) > 0: - future = client.submit( - execute_duckdb_load, - task_id, - current_batch[:], - exported_dependent_tables_map, - ) - futures.append(future) - task_ids.append(task_id) - task_id += 1 - - completed_batches = 0 - total_batches = len(futures) 
- for future in as_completed(futures): - completed_batches += 1 - logger.info(f"progress received [{completed_batches}/{total_batches}]") - future = t.cast(Future, future) - if future.cancelled: - if future.done(): - logger.info("future actually done???") - else: - logger.error("future cancelled. skipping for now?") - print(future) - print(future.result() is not None) - continue - download_queue.put(future) - return task_ids - - def downloader( - kill_event: threading.Event, - download_queue: queue.Queue[Future], - res_queue: queue.Queue[ResultQueueItem], - ): - logger.debug("waiting for download") - while True: - try: - future = download_queue.get(timeout=0.1) - try: - item = t.cast(DuckdbLoadedItem, future.result()) - record_batch = pa.RecordBatch.from_pandas(item.df) - res_queue.put( - ResultQueueItem( - id=item.id, - record_batch=record_batch, - ) - ) - logger.debug("download completed") - finally: - download_queue.task_done() - except queue.Empty: - if kill_event.is_set(): - logger.debug("shutting down downloader") - return - if kill_event.is_set() and not download_queue.empty(): - logger.debug("shutting down downloader prematurely") - return - - def gen_record_batches(size: int): - download_queue: queue.Queue[Future] = queue.Queue(maxsize=size) - res_queue: queue.Queue[ResultQueueItem] = queue.Queue(maxsize=size) - kill_event = threading.Event() - result_queue_timeout = 5.0 - max_result_timeout = 300 - - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.downloader_count + 5 - ) as executor: - dask_thread = executor.submit( - gen_with_dask, - rewritten_query, - input, - exported_dependent_tables_map, - download_queue, - ) - downloaders = [] - for i in range(self.downloader_count): - downloaders.append( - executor.submit( - downloader, kill_event, download_queue, res_queue - ) - ) - - wait_retries = 0 - - completed_task_ids: t.Set[int] = set() - task_ids: t.Optional[t.Set[int]] = None - - while task_ids != completed_task_ids: - try: - result = res_queue.get(timeout=result_queue_timeout) - wait_retries = 0 - logger.debug("sending batch to client") - - completed_task_ids.add(result.id) - - yield result.record_batch - except queue.Empty: - wait_retries += 1 - if task_ids is None: - # If the dask thread is done we know if we can check for completion - if dask_thread.done(): - task_ids = set(dask_thread.result()) - else: - # If we have waited longer then 15 mins let's stop waiting - current_wait_time = wait_retries * result_queue_timeout - if current_wait_time > max_result_timeout: - logger.debug( - "record batches might be completed. 
with some kind of error" - ) - break - kill_event.set() - logger.debug("waiting for the downloaders to shutdown") - executor.shutdown(cancel_futures=True) - - logger.debug( - f"Distributing query for {input.start} to {input.end}: {rewritten_query}" - ) - try: - return fl.GeneratorStream( - input.to_arrow_schema(), - gen_record_batches(size=self.queue_size), - ) - except Exception as e: - print("caught error") - logger.error("Caught error generating stream", exc_info=e) - raise e - - -def run_get( - start: str, - end: str, - batch_size: int = 1, -): - run_start = time.time() - client = fl.connect("grpc://0.0.0.0:8815") - input = QueryInput( - query_str=""" - SELECT bucket_day, to_artifact_id, from_artifact_id, event_source, event_type, SUM(amount) as amount - FROM metrics.events_daily_to_artifact - where bucket_day >= strptime(@start_ds, '%Y-%m-%d') and bucket_day <= strptime(@end_ds, '%Y-%m-%d') - group by - bucket_day, - to_artifact_id, - from_artifact_id, - event_source, - event_type - """, - start=datetime.strptime(start, "%Y-%m-%d"), - end=datetime.strptime(end, "%Y-%m-%d"), - dialect="duckdb", - columns=[ - ("bucket_day", "TIMESTAMP"), - ("to_artifact_id", "VARCHAR"), - ("from_artifact_id", "VARCHAR"), - ("event_source", "VARCHAR"), - ("event_type", "VARCHAR"), - ("amount", "NUMERIC"), - ], - ref=PeerMetricDependencyRef( - name="", entity_type="artifact", window=30, unit="day" - ), - locals={}, - dependent_tables_map={ - "metrics.events_daily_to_artifact": "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958" - }, - batch_size=batch_size, - ) - reader = client.do_get(input.to_ticket()) - r = reader.to_reader() - count = 0 - for batch in r: - count += 1 - print(f"[{count}] ROWS={batch.num_rows}") - run_end = time.time() - print(f"DURATION={run_end - run_start}s") - - -@click.command() -@click.option("--host", envvar="SQLMESH_TRINO_HOST", required=True) -@click.option("--port", default=8080, type=click.INT) -@click.option("--catalog", default="metrics") -@click.option("--user", default="sqlmesh") -@click.option("--gcs-bucket", envvar="METRICS_FLIGHT_SERVER_GCS_BUCKET", required=True) -@click.option("--gcs-key-id", envvar="METRICS_FLIGHT_SERVER_GCS_KEY_ID", required=True) -@click.option("--gcs-secret", envvar="METRICS_FLIGHT_SERVER_GCS_SECRET", required=True) -@click.option( - "--worker-duckdb-path", - envvar="METRICS_FLIGHT_SERVER_WORKER_DUCKDB_PATH", - required=True, -) -@click.option("--hive-uri", envvar="METRICS_FLIGHT_SERVER_HIVE_URI", required=True) -@click.option("--image-tag", required=True) -@click.option("--threads", type=click.INT, default=16) -@click.option("--worker-memory-limit", default="90000Mi") -@click.option("--worker-memory-request", default="75000Mi") -@click.option("--scheduler-memory-limit", default="90000Mi") -@click.option("--scheduler-memory-request", default="75000Mi") -@click.option("--cluster-only/--no-cluster-only", default=False) -@click.option("--cluster-name", default="sqlmesh-flight") -@click.option("--cluster-namespace", default="sqlmesh-manual") -def main( - host: str, - port: int, - catalog: str, - user: str, - gcs_bucket: str, - gcs_key_id: str, - gcs_secret: str, - worker_duckdb_path: str, - hive_uri: str, - image_tag: str, - threads: int, - scheduler_memory_limit: str, - scheduler_memory_request: str, - worker_memory_limit: str, - worker_memory_request: str, - cluster_only: bool, - cluster_name: str, - cluster_namespace: str, -): - cluster_spec = make_new_cluster( - f"ghcr.io/opensource-observer/dagster-dask:{image_tag}", - cluster_name, 
- cluster_namespace, - threads=threads, - scheduler_memory_limit=scheduler_memory_limit, - scheduler_memory_request=scheduler_memory_request, - worker_memory_limit=worker_memory_limit, - worker_memory_request=worker_memory_request, - ) - - if cluster_only: - # Start the cluster - cluster = start_duckdb_cluster( - cluster_namespace, - gcs_key_id, - gcs_secret, - worker_duckdb_path, - cluster_spec=cluster_spec, - ) - try: - while True: - time.sleep(1.0) - finally: - cluster.close() - else: - cluster = KubeCluster( - name=cluster_name, - namespace=cluster_namespace, - create_mode=CreateMode.CONNECT_ONLY, - shutdown_on_close=False, - ) - server = MetricsCalculatorFlightServer( - cluster, - TrinoEngine.create( - host, - port, - user, - catalog, - ), - gcs_bucket, - exported_map={ - "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958": "export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", - }, - ) - server.run_initialization(hive_uri, gcs_key_id, gcs_secret, worker_duckdb_path) - with server as s: - s.serve() - - -def make_new_cluster( - image: str, - cluster_id: str, - service_account_name: str, - threads: int, - scheduler_memory_request: str, - scheduler_memory_limit: str, - worker_memory_request: str, - worker_memory_limit: str, -): - spec = make_cluster_spec( - name=f"{cluster_id}", - resources={ - "requests": {"memory": scheduler_memory_request}, - "limits": {"memory": scheduler_memory_limit}, - }, - image=image, - ) - spec["spec"]["scheduler"]["spec"]["tolerations"] = [ - { - "key": "pool_type", - "effect": "NoSchedule", - "operator": "Equal", - "value": "sqlmesh-worker", - } - ] - spec["spec"]["scheduler"]["spec"]["nodeSelector"] = {"pool_type": "sqlmesh-worker"} - - spec["spec"]["worker"]["spec"]["tolerations"] = [ - { - "key": "pool_type", - "effect": "NoSchedule", - "operator": "Equal", - "value": "sqlmesh-worker", - } - ] - spec["spec"]["worker"]["spec"]["nodeSelector"] = {"pool_type": "sqlmesh-worker"} - - # Give the workers a different resource allocation - for container in spec["spec"]["worker"]["spec"]["containers"]: - container["resources"] = { - "limits": { - "memory": worker_memory_limit, - }, - "requests": { - "memory": worker_memory_request, - }, - } - volume_mounts = container.get("volumeMounts", []) - volume_mounts.append( - { - "mountPath": "/scratch", - "name": "scratch", - } - ) - if container["name"] == "worker": - args: t.List[str] = container["args"] - args.append("--nthreads") - args.append(f"{threads}") - args.append("--nworkers") - args.append("1") - args.append("--memory-limit") - args.append("0") - container["volumeMounts"] = volume_mounts - volumes = spec["spec"]["worker"]["spec"].get("volumes", []) - volumes.append( - { - "name": "scratch", - "emptyDir": {}, - } - ) - spec["spec"]["worker"]["spec"]["volumes"] = volumes - spec["spec"]["worker"]["spec"]["serviceAccountName"] = service_account_name - - return spec - - -if __name__ == "__main__": - main() diff --git a/warehouse/metrics_tools/compute/manual_testing_utils.py b/warehouse/metrics_tools/compute/manual_testing_utils.py new file mode 100644 index 000000000..bfba44453 --- /dev/null +++ b/warehouse/metrics_tools/compute/manual_testing_utils.py @@ -0,0 +1,133 @@ +"""Manual testing scripts for the metrics calculation service. 
+ +Eventually we should replace this with a larger end-to-end test +""" + +import logging +from datetime import datetime + +import click +import requests +from metrics_tools.compute.client import Client +from pydantic_core import to_jsonable_python + +from ..definition import PeerMetricDependencyRef +from .types import ( + ClusterStartRequest, + ExportedTableLoadRequest, + ExportReference, + ExportType, +) + +logger = logging.getLogger(__name__) + + +def run_start(url: str, min=6, max=10): + req = ClusterStartRequest(min_size=min, max_size=max) + response = requests.post(f"{url}/cluster/start", json=to_jsonable_python(req)) + print(response.json()) + + +def run_cache_load(url: str): + req = ExportedTableLoadRequest( + map={ + "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958": ExportReference( + table="export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", + type=ExportType.GCS, + payload={ + "gcs_path": "gs://oso-dataset-transfer-bucket/trino-export/export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958" + }, + ), + } + ) + response = requests.post(f"{url}/cache/manual", json=to_jsonable_python(req)) + print(response.json()) + + +def run_stop(url: str): + response = requests.post(f"{url}/cluster/stop") + print(response.json()) + + +def run_get_status(url: str, job_id: str): + response = requests.get(f"{url}/job/status/{job_id}") + print(response.json()) + + +def run_local_test( + url: str, start: str, end: str, batch_size: int, cluster_size: int = 6 +): + import sys + + logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) + + client = Client(url, log_override=logger) + + client.run_cache_manual_load( + { + "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958": ExportReference( + table="export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958", + type=ExportType.GCS, + payload={ + "gcs_path": "gs://oso-dataset-transfer-bucket/trino-export/export_metrics__events_daily_to_artifact__2357434958_5def5e890a984cf99f7364ce3c2bb958" + }, + ), + } + ) + + client.calculate_metrics( + query_str=""" + SELECT bucket_day, to_artifact_id, from_artifact_id, event_source, event_type, SUM(amount) as amount + FROM metrics.events_daily_to_artifact + where bucket_day >= strptime(@start_ds, '%Y-%m-%d') and bucket_day <= strptime(@end_ds, '%Y-%m-%d') + group by + bucket_day, + to_artifact_id, + from_artifact_id, + event_source, + event_type + """, + start=datetime.strptime(start, "%Y-%m-%d"), + end=datetime.strptime(end, "%Y-%m-%d"), + dialect="duckdb", + columns=[ + ("bucket_day", "TIMESTAMP"), + ("to_artifact_id", "VARCHAR"), + ("from_artifact_id", "VARCHAR"), + ("event_source", "VARCHAR"), + ("event_type", "VARCHAR"), + ("amount", "NUMERIC"), + ], + ref=PeerMetricDependencyRef( + name="", entity_type="artifact", window=30, unit="day" + ), + locals={}, + dependent_tables_map={ + "metrics.events_daily_to_artifact": "sqlmesh__metrics.metrics__events_daily_to_artifact__2357434958" + }, + batch_size=batch_size, + cluster_max_size=cluster_size, + cluster_min_size=cluster_size, + ) + + +@click.command() +@click.option("--url", default="http://localhost:8000") +@click.option("--batch-size", type=click.INT, default=1) +@click.option("--start", default="2024-01-01") +@click.option("--cluster-size", type=click.INT, default=6) +@click.option("--end") +def main(url: str, batch_size: int, start: str, end: str, cluster_size: int): + if not end: + end = datetime.now().strftime("%Y-%m-%d") + 
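# A rough companion to run_get_status above: a sketch of polling the
# /job/status/{job_id} endpoint (see server.py later in this diff) until a
# submitted job settles. The job_id is assumed to come from a prior
# /job/submit response, and the status strings mirror QueryJobStatus in
# types.py.

def wait_for_job(url: str, job_id: str, interval: float = 1.0) -> str:
    import time

    while True:
        status = requests.get(f"{url}/job/status/{job_id}").json()["status"]
        if status in ("completed", "failed"):
            return status
        time.sleep(interval)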
run_local_test( + url, + start, + end, + batch_size, + cluster_size=cluster_size, + ) + + +if __name__ == "__main__": + main() diff --git a/warehouse/metrics_tools/compute/run_get.py b/warehouse/metrics_tools/compute/run_get.py deleted file mode 100644 index 4c3f09525..000000000 --- a/warehouse/metrics_tools/compute/run_get.py +++ /dev/null @@ -1,18 +0,0 @@ -# Testing script -from metrics_tools.compute import flight -import click -from datetime import datetime - - -@click.command() -@click.option("--batch-size", type=click.INT, default=1) -@click.option("--start", default="2024-01-01") -@click.option("--end") -def main(batch_size: int, start, end): - if not end: - end = datetime.now().strftime("%Y-%m-%d") - flight.run_get(batch_size=batch_size, start=start, end=end) - - -if __name__ == "__main__": - main() diff --git a/warehouse/metrics_tools/compute/server.py b/warehouse/metrics_tools/compute/server.py new file mode 100644 index 000000000..ff9f10f71 --- /dev/null +++ b/warehouse/metrics_tools/compute/server.py @@ -0,0 +1,189 @@ +import logging +import typing as t +import uuid +from contextlib import asynccontextmanager + +import aiotrino +from dotenv import load_dotenv +from fastapi import FastAPI, Request +from metrics_tools.utils.logging import setup_module_logging + +from . import constants +from .cache import setup_fake_cache_export_manager, setup_trino_cache_export_manager +from .cluster import ( + ClusterManager, + KubeClusterFactory, + LocalClusterFactory, + make_new_cluster_with_defaults, +) +from .service import MetricsCalculationService +from .types import ( + ClusterStartRequest, + EmptyResponse, + ExportedTableLoadRequest, + QueryJobSubmitRequest, +) + +load_dotenv() +logger = logging.getLogger("uvicorn.error.application") + + +@asynccontextmanager +async def initialize_app(app: FastAPI): + # logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) + setup_module_logging("metrics_tools") + + logger.setLevel(logging.DEBUG) + + logger.info("Metrics calculation service is starting up") + if constants.debug_all: + logger.warning("Debugging all services") + + cache_export_manager = None + if not constants.debug_cache: + trino_connection = aiotrino.dbapi.connect( + host=constants.trino_host, + port=constants.trino_port, + user=constants.trino_user, + catalog=constants.trino_catalog, + ) + cache_export_manager = await setup_trino_cache_export_manager( + trino_connection, + constants.gcs_bucket, + constants.hive_catalog, + constants.hive_schema, + log_override=logger, + ) + else: + logger.warning("Loading fake cache export manager") + cache_export_manager = await setup_fake_cache_export_manager( + log_override=logger + ) + + cluster_manager = None + if not constants.debug_cluster: + cluster_spec = make_new_cluster_with_defaults() + cluster_factory = KubeClusterFactory( + constants.cluster_namespace, + cluster_spec=cluster_spec, + log_override=logger, + shutdown_on_close=not constants.debug_cluster_no_shutdown, + ) + cluster_manager = ClusterManager.with_metrics_plugin( + constants.gcs_bucket, + constants.gcs_key_id, + constants.gcs_secret, + constants.worker_duckdb_path, + cluster_factory, + log_override=logger, + ) + else: + logger.warning("Loading fake cluster manager") + cluster_factory = LocalClusterFactory() + cluster_manager = ClusterManager.with_dummy_metrics_plugin( + cluster_factory, + log_override=logger, + ) + + mcs = MetricsCalculationService.setup( + id=str(uuid.uuid4()), + gcs_bucket=constants.gcs_bucket, + result_path_prefix=constants.results_path_prefix, + 
cluster_manager=cluster_manager, + cache_manager=cache_export_manager, + log_override=logger, + ) + try: + yield { + "mca": mcs, + } + finally: + await mcs.close() + + +# Dependency to get the cluster manager +def get_mca(request: Request) -> MetricsCalculationService: + mca = request.state.mca + assert mca is not None + return t.cast(MetricsCalculationService, mca) + + +app = FastAPI(lifespan=initialize_app) + + +@app.get("/status") +async def get_status(): + """ + Liveness endpoint + """ + return {"status": "Service is running"} + + +@app.post("/cluster/start") +async def start_cluster( + request: Request, + start_request: ClusterStartRequest, +): + """ + Start a Dask cluster in an idempotent way + """ + state = get_mca(request) + manager = state.cluster_manager + return await manager.start_cluster(start_request.min_size, start_request.max_size) + + +@app.post("/cluster/stop") +async def stop_cluster(request: Request): + """ + Stop the Dask cluster + """ + state = get_mca(request) + manager = state.cluster_manager + return await manager.stop_cluster() + + +@app.get("/cluster/status") +async def get_cluster_status(request: Request): + """ + Get the current Dask cluster status + """ + state = get_mca(request) + manager = state.cluster_manager + return await manager.get_cluster_status() + + +@app.post("/job/submit") +async def submit_job( + request: Request, + input: QueryJobSubmitRequest, +): + """ + Submits a Dask job for calculation + """ + service = get_mca(request) + return await service.submit_job(input) + + +@app.get("/job/status/{job_id}") +async def get_job_status( + request: Request, + job_id: str, +): + """ + Get the status of a job + """ + include_stats = request.query_params.get("include_stats", "false").lower() == "true" + service = get_mca(request) + return await service.get_job_status(job_id, include_stats=include_stats) + + +@app.post("/cache/manual") +async def add_existing_exported_table_references( + request: Request, input: ExportedTableLoadRequest +): + """ + Add a table export to the cache + """ + service = get_mca(request) + await service.add_existing_exported_table_references(input.map) + return EmptyResponse() diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py new file mode 100644 index 000000000..fd6ace95b --- /dev/null +++ b/warehouse/metrics_tools/compute/service.py @@ -0,0 +1,366 @@ +"""Main interface for computing metrics""" + +import asyncio +import copy +import logging +import os +import typing as t +import uuid +from datetime import datetime + +from dask.distributed import CancelledError, Future +from metrics_tools.compute.worker import execute_duckdb_load +from metrics_tools.runner import FakeEngineAdapter, MetricsRunner + +from .cache import CacheExportManager +from .cluster import ClusterManager +from .types import ( + ClusterStartRequest, + ClusterStatus, + ExportReference, + ExportType, + QueryJobProgress, + QueryJobState, + QueryJobStatus, + QueryJobStatusResponse, + QueryJobSubmitRequest, + QueryJobSubmitResponse, + QueryJobUpdate, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class MetricsCalculationService: + id: str + gcs_bucket: str + cluster_manager: ClusterManager + cache_manager: CacheExportManager + job_state: t.Dict[str, QueryJobState] + job_tasks: t.Dict[str, asyncio.Task] + job_state_lock: asyncio.Lock + logger: logging.Logger + + @classmethod + def setup( + cls, + id: str, + gcs_bucket: str, + result_path_prefix: str, + cluster_manager: ClusterManager, + 
cache_manager: CacheExportManager, + log_override: t.Optional[logging.Logger] = None, + ): + service = cls( + id, + gcs_bucket, + result_path_prefix, + cluster_manager, + cache_manager, + log_override=log_override, + ) + # service.start_job_state_listener() + return service + + def __init__( + self, + id: str, + gcs_bucket: str, + result_path_prefix: str, + cluster_manager: ClusterManager, + cache_manager: CacheExportManager, + log_override: t.Optional[logging.Logger] = None, + ): + self.id = id + self.gcs_bucket = gcs_bucket + self.result_path_prefix = result_path_prefix + self.cluster_manager = cluster_manager + self.cache_manager = cache_manager + self.job_state = {} + self.job_tasks = {} + self.job_state_lock = asyncio.Lock() + self.logger = log_override or logger + + async def handle_query_job_submit_request( + self, job_id: str, result_path_base: str, input: QueryJobSubmitRequest + ): + try: + await self._handle_query_job_submit_request(job_id, result_path_base, input) + except Exception as e: + self.logger.error(f"job[{job_id}] failed with exception: {e}") + await self._notify_job_failed(job_id, 0, 0) + + async def _handle_query_job_submit_request( + self, + job_id: str, + result_path_base: str, + input: QueryJobSubmitRequest, + ): + self.logger.info(f"job[{job_id}] waiting for cluster to be ready") + await self.cluster_manager.wait_for_ready() + self.logger.info(f"job[{job_id}] cluster ready") + + client = await self.cluster_manager.client + self.logger.info(f"job[{job_id}] waiting for dependencies to be exported") + exported_dependent_tables_map = await self.resolve_dependent_tables(input) + self.logger.info(f"job[{job_id}] dependencies exported") + + tasks: t.List[Future] = [] + + async for batch_id, batch in self.generate_query_batches( + input, input.batch_size + ): + task_id = f"{job_id}-{batch_id}" + result_path = os.path.join(result_path_base, job_id, f"{batch_id}.parquet") + + self.logger.info(f"job[{job_id}]: Submitting task {task_id}") + + # dependencies = { + # table: to_jsonable_python(reference) + # for table, reference in exported_dependent_tables_map.items() + # } + + task = client.submit( + execute_duckdb_load, + job_id, + task_id, + result_path, + batch, + exported_dependent_tables_map, + retries=input.retries, + ) + + self.logger.info(f"job[{job_id}]: Submitted task {task_id}") + tasks.append(task) + + total = len(tasks) + completed = 0 + failures = 0 + exceptions = [] + + # In the future we should replace this with the python 3.13 version of + # this. + + for finished in asyncio.as_completed(tasks): + try: + task_id = await finished + completed += 1 + self.logger.info(f"job[{job_id}] progress: {completed}/{total}") + await self._notify_job_updated(job_id, completed, total) + self.logger.info( + f"job[{job_id}] finished notifying update: {completed}/{total}" + ) + except CancelledError as e: + failures += 1 + self.logger.error(f"job[{job_id}] task cancelled {e.args}") + continue + except Exception as e: + failures += 1 + exceptions.append(e) + self.logger.error(f"job[{job_id}] task failed with exception: {e}") + continue + self.logger.info(f"job[{job_id}] awaiting finished") + + await self._notify_job_updated(job_id, completed, total) + self.logger.info(f"job[{job_id}] task_id={task_id} finished") + if failures > 0: + self.logger.error( + f"job[{job_id}] {failures} tasks failed. 
received {len(exceptions)} exceptions" + ) + await self._notify_job_failed(job_id, completed, total) + if len(exceptions) > 0: + for e in exceptions: + self.logger.error(f"job[{job_id}] exception received: {e}") + else: + self.logger.info(f"job[{job_id}]: done") + await self._notify_job_completed(job_id, completed, total) + + async def close(self): + await self.cluster_manager.close() + await self.cache_manager.stop() + + async def start_cluster(self, start_request: ClusterStartRequest) -> ClusterStatus: + self.logger.debug("starting cluster") + return await self.cluster_manager.start_cluster( + start_request.min_size, start_request.max_size + ) + + async def get_cluster_status(self): + return self.cluster_manager.get_cluster_status() + + async def submit_job(self, input: QueryJobSubmitRequest): + """Submit a job to the cluster to compute the metrics""" + self.logger.debug("submitting job") + job_id = str(uuid.uuid4()) + + result_path_base = os.path.join(self.result_path_prefix, job_id) + result_path = os.path.join( + f"gs://{self.gcs_bucket}", result_path_base, "*.parquet" + ) + + await self._notify_job_pending(job_id, 1) + task = asyncio.create_task( + self.handle_query_job_submit_request(job_id, result_path_base, input) + ) + async with self.job_state_lock: + self.job_tasks[job_id] = task + + return QueryJobSubmitResponse( + job_id=job_id, + export_reference=ExportReference( + table=job_id, + type=ExportType.GCS, + payload={"gcs_path": result_path}, + ), + ) + + async def _notify_job_pending(self, job_id: str, total: int): + await self._set_job_state( + job_id, + QueryJobUpdate( + updated_at=datetime.now(), + status=QueryJobStatus.PENDING, + progress=QueryJobProgress(completed=0, total=total), + ), + ) + + async def _notify_job_updated(self, job_id: str, completed: int, total: int): + await self._set_job_state( + job_id, + QueryJobUpdate( + updated_at=datetime.now(), + status=QueryJobStatus.RUNNING, + progress=QueryJobProgress(completed=completed, total=total), + ), + ) + + async def _notify_job_completed(self, job_id: str, completed: int, total: int): + await self._set_job_state( + job_id, + QueryJobUpdate( + updated_at=datetime.now(), + status=QueryJobStatus.COMPLETED, + progress=QueryJobProgress(completed=completed, total=total), + ), + ) + + async def _notify_job_failed(self, job_id: str, completed: int, total: int): + await self._set_job_state( + job_id, + QueryJobUpdate( + updated_at=datetime.now(), + status=QueryJobStatus.FAILED, + progress=QueryJobProgress(completed=completed, total=total), + ), + ) + + async def _set_job_state( + self, + job_id: str, + update: QueryJobUpdate, + ): + self.logger.debug(f"job[{job_id}] status={update.status}") + async with self.job_state_lock: + if update.status == QueryJobStatus.PENDING: + self.job_state[job_id] = QueryJobState( + job_id=job_id, + created_at=update.updated_at, + updates=[update], + ) + else: + state = self.job_state.get(job_id) + if not state: + raise ValueError(f"Job {job_id} not found") + + state.updates.append(update) + self.job_state[job_id] = state + + if ( + update.status == QueryJobStatus.COMPLETED + or update.status == QueryJobStatus.FAILED + ): + del self.job_tasks[job_id] + + async def _get_job_state(self, job_id: str): + """Get the current state of a job as a deep copy (to prevent + mutation)""" + async with self.job_state_lock: + state = copy.deepcopy(self.job_state.get(job_id)) + return state + + async def generate_query_batches( + self, input: QueryJobSubmitRequest, batch_size: int + ): + runner = 
MetricsRunner.from_engine_adapter( + FakeEngineAdapter("duckdb"), + input.query_as("duckdb"), + input.ref, + input.locals, + ) + + batch: t.List[str] = [] + batch_num = 0 + + async for rendered_query in runner.render_rolling_queries_async( + input.start, input.end + ): + batch.append(rendered_query) + if len(batch) >= batch_size: + yield (batch_num, batch) + batch = [] + batch_num += 1 + if len(batch) > 0: + yield (batch_num, batch) + + async def resolve_dependent_tables(self, input: QueryJobSubmitRequest): + """Resolve the dependent tables for the given input and returns the + associate export references""" + + # Dependent tables come in the form: + # { reference_table_name: actual_table_table } + + # The reference_table_name is something like + # "metrics.events_daily_to_artifact". The actual_table_name is something + # like + # "sqlmesh__metrics.events_daily_to_artifact__some_system_generated_id" + dependent_tables_map = input.dependent_tables_map + tables_to_export = list(dependent_tables_map.values()) + + # The cache manager will generate random export references for each + # table you ask to cache for use in metrics calculations. However it is + # not aware of the `reference_table_name` that the user provides. We + # need to resolve the actual table names to the export references + reverse_dependent_tables_map = {v: k for k, v in dependent_tables_map.items()} + + # First use the cache manager to resolve the export references + references = await self.cache_manager.resolve_export_references( + tables_to_export + ) + self.logger.debug(f"resolved references: {references}") + + # Now map the reference_table_names to the export references + exported_dependent_tables_map = { + reverse_dependent_tables_map[actual_name]: reference + for actual_name, reference in references.items() + } + + return exported_dependent_tables_map + + async def get_job_status( + self, job_id: str, include_stats: bool = False + ) -> QueryJobStatusResponse: + state = await self._get_job_state(job_id) + if not state: + raise ValueError(f"Job {job_id} not found") + return state.as_response(include_stats=include_stats) + + async def add_existing_exported_table_references( + self, update: t.Dict[str, ExportReference] + ): + """This is mostly used for testing purposes, but allows us to load a + previously cached table's reference into the cache manager""" + await self.cache_manager.add_export_table_references(update) + + async def inspect_exported_table_references(self): + return await self.cache_manager.inspect_export_table_references() diff --git a/warehouse/metrics_tools/compute/test_cache.py b/warehouse/metrics_tools/compute/test_cache.py new file mode 100644 index 000000000..860b82b85 --- /dev/null +++ b/warehouse/metrics_tools/compute/test_cache.py @@ -0,0 +1,33 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest +from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter +from metrics_tools.compute.types import ExportReference, ExportType + + +@pytest.mark.asyncio +async def test_cache_export_manager(): + adapter_mock = AsyncMock(FakeExportAdapter) + adapter_mock.export_table.return_value = ExportReference( + table="test", + type=ExportType.GCS, + payload={}, + ) + cache = await CacheExportManager.setup(adapter_mock) + + export_table_0 = await asyncio.wait_for( + cache.resolve_export_references(["table1", "table2"]), timeout=1 + ) + + assert export_table_0.keys() == {"table1", "table2"} + + # Attempt to export tables again but this should be mostly cache hits except + # for 
table3 + export_table_1 = await asyncio.wait_for( + cache.resolve_export_references(["table1", "table2", "table1", "table3"]), + timeout=1, + ) + assert export_table_1.keys() == {"table1", "table2", "table3"} + + assert adapter_mock.export_table.call_count == 3 diff --git a/warehouse/metrics_tools/compute/test_cluster.py b/warehouse/metrics_tools/compute/test_cluster.py new file mode 100644 index 000000000..88594087c --- /dev/null +++ b/warehouse/metrics_tools/compute/test_cluster.py @@ -0,0 +1,68 @@ +import asyncio + +import pytest +from dask.distributed import Client +from metrics_tools.compute.cluster import ( + ClusterFactory, + ClusterManager, + ClusterProxy, + ClusterStatus, +) + + +class FakeClient(Client): + def __init__(self): + pass + + def close(self, *args, **kwargs): + pass + + def register_worker_plugin(self, *args, **kwargs): + pass + + +class FakeClusterProxy(ClusterProxy): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + async def client(self) -> Client: + return await FakeClient() + + async def status(self): + return ClusterStatus( + status="running", + is_ready=True, + dashboard_url="", + workers=1, + ) + + async def stop(self): + return + + @property + def dashboard_link(self): + return "http://fake-dashboard.com" + + @property + def workers(self): + return 1 + + +class FakeClusterFactory(ClusterFactory): + async def create_cluster(self, min_size: int, max_size: int): + return FakeClusterProxy(min_size, max_size) + + +@pytest.mark.asyncio +async def test_cluster_manager_reports_ready(): + cluster_manager = ClusterManager.with_dummy_metrics_plugin(FakeClusterFactory()) + + ready_future = cluster_manager.wait_for_ready() + + await cluster_manager.start_cluster(1, 1) + + try: + await asyncio.wait_for(ready_future, timeout=1) + except asyncio.TimeoutError: + pytest.fail("Cluster never reported ready") diff --git a/warehouse/metrics_tools/compute/test_service.py b/warehouse/metrics_tools/compute/test_service.py new file mode 100644 index 000000000..cea67002a --- /dev/null +++ b/warehouse/metrics_tools/compute/test_service.py @@ -0,0 +1,66 @@ +import asyncio +from datetime import datetime + +import pytest +from metrics_tools.compute.cache import CacheExportManager, FakeExportAdapter +from metrics_tools.compute.cluster import ClusterManager, LocalClusterFactory +from metrics_tools.compute.service import MetricsCalculationService +from metrics_tools.compute.types import ( + ClusterStartRequest, + ExportReference, + ExportType, + QueryJobStatus, + QueryJobSubmitRequest, +) +from metrics_tools.definition import PeerMetricDependencyRef + + +@pytest.mark.asyncio +async def test_metrics_calculation_service(): + service = MetricsCalculationService.setup( + "someid", + "bucket", + "result_path_prefix", + ClusterManager.with_dummy_metrics_plugin(LocalClusterFactory()), + await CacheExportManager.setup(FakeExportAdapter()), + ) + await service.start_cluster(ClusterStartRequest(min_size=1, max_size=1)) + await service.add_existing_exported_table_references( + { + "source.table123": ExportReference( + table="export_table123", + type=ExportType.GCS, + payload={"gcs_path": "gs://bucket/result_path_prefix/export_table123"}, + ), + } + ) + response = await service.submit_job( + QueryJobSubmitRequest( + query_str="SELECT * FROM ref.table123", + start=datetime(2021, 1, 1), + end=datetime(2021, 1, 3), + dialect="duckdb", + batch_size=1, + columns=[("col1", "int"), ("col2", "string")], + ref=PeerMetricDependencyRef( + name="test", + 
entity_type="artifact", + window=30, + unit="day", + ), + locals={}, + dependent_tables_map={"source.table123": "source.table123"}, + ) + ) + + async def wait_for_job_to_complete(): + status = await service.get_job_status(response.job_id) + while status.status in [QueryJobStatus.PENDING, QueryJobStatus.RUNNING]: + status = await service.get_job_status(response.job_id) + await asyncio.sleep(1) + + await asyncio.wait_for(asyncio.create_task(wait_for_job_to_complete()), timeout=60) + status = await service.get_job_status(response.job_id) + assert status.status == QueryJobStatus.COMPLETED + + await service.close() diff --git a/warehouse/metrics_tools/compute/types.py b/warehouse/metrics_tools/compute/types.py new file mode 100644 index 000000000..92f1e9259 --- /dev/null +++ b/warehouse/metrics_tools/compute/types.py @@ -0,0 +1,153 @@ +import logging +import typing as t +from datetime import datetime +from enum import Enum + +from metrics_tools.definition import PeerMetricDependencyRef +from pydantic import BaseModel, Field +from sqlmesh.core.dialect import parse_one + +logger = logging.getLogger(__name__) + + +class EmptyResponse(BaseModel): + pass + + +class ExportType(str, Enum): + ICEBERG = "iceberg" + GCS = "gcs" + + +class ExportReference(BaseModel): + table: str + type: ExportType + payload: t.Dict[str, t.Any] + + +class QueryJobStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class QueryJobProgress(BaseModel): + completed: int + total: int + + +class QueryJobUpdate(BaseModel): + updated_at: datetime + status: QueryJobStatus + progress: QueryJobProgress + + +class ClusterStatus(BaseModel): + status: str + is_ready: bool + dashboard_url: str + workers: int + + +class QueryJobSubmitRequest(BaseModel): + query_str: str + start: datetime + end: datetime + dialect: str + batch_size: int + columns: t.List[t.Tuple[str, str]] + ref: PeerMetricDependencyRef + locals: t.Dict[str, t.Any] + dependent_tables_map: t.Dict[str, str] + retries: t.Optional[int] = None + + def query_as(self, dialect: str) -> str: + return parse_one(self.query_str, self.dialect).sql(dialect=dialect) + + +class QueryJobSubmitResponse(BaseModel): + job_id: str + export_reference: ExportReference + + +class QueryJobStatusResponse(BaseModel): + job_id: str + created_at: datetime + updated_at: datetime + status: QueryJobStatus + progress: QueryJobProgress + stats: t.Dict[str, float] = Field(default_factory=dict) + + +class QueryJobState(BaseModel): + job_id: str + created_at: datetime + updates: t.List[QueryJobUpdate] + + def latest_update(self) -> QueryJobUpdate: + return self.updates[-1] + + def as_response(self, include_stats: bool = False) -> QueryJobStatusResponse: + # Turn update events into stats + stats = {} + if include_stats: + # Calculate the time between each status change + pending_to_running = None + running_to_completed = None + running_to_failed = None + + for update in self.updates: + if ( + update.status == QueryJobStatus.RUNNING + and pending_to_running is None + ): + pending_to_running = update.updated_at + elif ( + update.status == QueryJobStatus.COMPLETED + and running_to_completed is None + ): + running_to_completed = update.updated_at + elif ( + update.status == QueryJobStatus.FAILED and running_to_failed is None + ): + running_to_failed = update.updated_at + + if pending_to_running: + stats["pending_to_running_seconds"] = ( + pending_to_running - self.created_at + ).total_seconds() + if running_to_completed: + 
stats["running_to_completed_seconds"] = ( + (running_to_completed - pending_to_running).total_seconds() + if pending_to_running + else None + ) + if running_to_failed: + stats["running_to_failed_seconds"] = ( + (running_to_failed - pending_to_running).total_seconds() + if pending_to_running + else None + ) + + return QueryJobStatusResponse( + job_id=self.job_id, + created_at=self.created_at, + updated_at=self.latest_update().updated_at, + status=self.latest_update().status, + progress=self.latest_update().progress, + stats=stats, + ) + + +class ClusterStartRequest(BaseModel): + min_size: int + max_size: int + + +class ExportedTableLoadRequest(BaseModel): + map: t.Dict[str, ExportReference] + + +class InspectCacheResponse(BaseModel): + map: t.Dict[str, ExportReference] diff --git a/warehouse/metrics_tools/compute/worker.py b/warehouse/metrics_tools/compute/worker.py index e88e1006f..d974bc25d 100644 --- a/warehouse/metrics_tools/compute/worker.py +++ b/warehouse/metrics_tools/compute/worker.py @@ -1,44 +1,73 @@ # The worker initialization -import abc +import io +import logging import os -from metrics_tools.utils.logging import add_metrics_tools_to_existing_logger -import pandas as pd +import time import typing as t -import duckdb import uuid -from sqlglot import exp -from dask.distributed import WorkerPlugin, Worker -import logging -import sys -from threading import Lock -from google.cloud import storage from contextlib import contextmanager +from threading import Lock - -from pyiceberg.catalog import load_catalog -from pyiceberg.table import Table as IcebergTable +import duckdb +import pandas as pd +from dask.distributed import Worker, WorkerPlugin, get_worker +from google.cloud import storage +from metrics_tools.compute.types import ExportReference, ExportType +from metrics_tools.utils.logging import setup_module_logging +from sqlglot import exp logger = logging.getLogger(__name__) mutex = Lock() -class DuckDBWorkerInterface(abc.ABC): - def fetchdf(self, query: str) -> pd.DataFrame: - raise NotImplementedError("fetchdf not implemented") +class MetricsWorkerPlugin(WorkerPlugin): + logger: logging.Logger + + def setup(self, worker: Worker): + setup_module_logging("metrics_tools") + logger.info("setting up metrics worker plugin") + def teardown(self, worker: Worker): + logger.info("tearing down metrics worker plugin") -class MetricsWorkerPlugin(WorkerPlugin): + def handle_query( + self, + job_id: str, + task_id: str, + result_path: str, + queries: t.List[str], + dependencies: t.Dict[str, ExportReference], + ) -> t.Any: + """Execute a query on the worker""" + raise NotImplementedError() + + +class DummyMetricsWorkerPlugin(MetricsWorkerPlugin): + def handle_query( + self, + job_id: str, + task_id: str, + result_path: str, + queries: t.List[str], + dependencies: t.Dict[str, ExportReference], + ) -> t.Any: + logger.info(f"job[{job_id}][{task_id}]: dummy executing query {queries}") + logger.info(f"job[{job_id}][{task_id}]: deps received: {dependencies}") + logger.info(f"job[{job_id}][{task_id}]: result_path: {result_path}") + time.sleep(1) + return task_id + + +class DuckDBMetricsWorkerPlugin(MetricsWorkerPlugin): def __init__( self, gcs_bucket: str, - hive_uri: str, gcs_key_id: str, gcs_secret: str, duckdb_path: str, ): self._gcs_bucket = gcs_bucket - self._hive_uri = hive_uri self._gcs_key_id = gcs_key_id self._gcs_secret = gcs_secret self._duckdb_path = duckdb_path @@ -47,19 +76,16 @@ def __init__( self._catalog = None self._mode = "duckdb" self._uuid = uuid.uuid4().hex + self.logger = logger 
def setup(self, worker: Worker): - add_metrics_tools_to_existing_logger("distributed") - logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + setup_module_logging("metrics_tools") + logger.info("setting up metrics worker plugin") self._conn = duckdb.connect(self._duckdb_path) - # Connect to iceberg if this is a remote worker - worker.log_event("info", "what") + # Connect to gcs sql = f""" - INSTALL iceberg; - LOAD iceberg; - CREATE SECRET secret1 ( TYPE GCS, KEY_ID '{self._gcs_key_id}', @@ -67,14 +93,6 @@ def setup(self, worker: Worker): ); """ self._conn.sql(sql) - self._catalog = load_catalog( - "metrics", - **{ - "uri": self._hive_uri, - "gcs.project-id": "opensource-observer", - "gcs.access": "read_only", - }, - ) def teardown(self, worker: Worker): if self._conn: @@ -88,12 +106,17 @@ def connection(self): def get_for_cache( self, table_ref_name: str, - table_actual_name: str, + export_reference: ExportReference, ): """Checks if a table is cached in the local duckdb""" logger.info( - f"[{self._uuid}] got a cache request for {table_ref_name}:{table_actual_name}" + f"[{self._uuid}] got a cache request for {table_ref_name}:{export_reference.table}" ) + assert export_reference.type == ExportType.GCS, "Only GCS exports are supported" + assert ( + export_reference.payload.get("gcs_path") is not None + ), "A gcs_path is required" + if self._cache_status.get(table_ref_name): return with mutex: @@ -101,74 +124,26 @@ def get_for_cache( return destination_table = exp.to_table(table_ref_name) - # if self._mode == "duckdb": - # self.load_using_duckdb( - # table_ref_name, table_actual_name, destination_table, table - # ) - # else: - # self.load_using_pyiceberg( - # table_ref_name, table_actual_name, destination_table, table - # ) - self.load_using_gcs_parquet( - table_ref_name, table_actual_name, destination_table - ) - - self._cache_status[table_ref_name] = True - - def load_using_duckdb( - self, - table_ref_name: str, - table_actual_name: str, - destination_table: exp.Table, - ): - source_table = exp.to_table(table_actual_name) - assert self._catalog is not None - table = self._catalog.load_table((source_table.db, source_table.this.this)) + gcs_path = export_reference.payload["gcs_path"] - self.connection.execute(f"CREATE SCHEMA IF NOT EXISTS {destination_table.db}") - caching_sql = f""" - CREATE TABLE IF NOT EXISTS {destination_table.db}.{destination_table.this.this} AS - SELECT * FROM iceberg_scan('{table.metadata_location}') - """ - logger.info(f"CACHING TABLE {table_ref_name} WITH SQL: {caching_sql}") - self.connection.sql(caching_sql) - logger.info(f"CACHING TABLE {table_ref_name} COMPLETED") + self.load_using_gcs_parquet(table_ref_name, gcs_path, destination_table) - def load_using_pyiceberg( - self, - table_ref_name: str, - table_actual_name: str, - destination_table: exp.Table, - table: IcebergTable, - ): - source_table = exp.to_table(table_actual_name) - assert self._catalog is not None - table = self._catalog.load_table((source_table.db, source_table.this.this)) - batch_reader = table.scan().to_arrow_batch_reader() # noqa: F841 - self.connection.execute(f"CREATE SCHEMA IF NOT EXISTS {destination_table.db}") - logger.info(f"CACHING TABLE {table_ref_name} WITH ICEBERG") - self.connection.sql( - f""" - CREATE TABLE IF NOT EXISTS {destination_table.db}.{destination_table.this.this} AS - SELECT * FROM batch_reader - """ - ) - logger.info(f"CACHING TABLE {table_ref_name} COMPLETED") + self._cache_status[table_ref_name] = True def load_using_gcs_parquet( self, table_ref_name: str, - 
table_actual_name: str, + gcs_path: str, destination_table: exp.Table, ): self.connection.execute(f"CREATE SCHEMA IF NOT EXISTS {destination_table.db}") logger.info(f"CACHING TABLE {table_ref_name} WITH PARQUET") - self.connection.sql( - f""" - CREATE TABLE IF NOT EXISTS {destination_table.db}.{destination_table.this.this} AS - SELECT * FROM read_parquet('gs://{self._gcs_bucket}/trino-export/{table_actual_name}/*') + cache_sql = f""" + CREATE TABLE IF NOT EXISTS "{destination_table.db}"."{destination_table.this.this}" AS + SELECT * FROM read_parquet('{gcs_path}/*') """ - ) + logger.debug(f"executing: {cache_sql}") + self.connection.sql(cache_sql) logger.info(f"CACHING TABLE {table_ref_name} COMPLETED") @contextmanager @@ -185,3 +160,94 @@ def bucket(self): def bucket_path(self, *joins: str): return os.path.join(f"gs://{self.bucket}", *joins) + + def upload_to_gcs_bucket(self, blob_path: str, file: t.IO): + with self.gcs_client() as client: + bucket = client.bucket(self._gcs_bucket) + blob = bucket.blob(blob_path) + blob.upload_from_file(file) + + def handle_query( + self, + job_id: str, + task_id: str, + result_path: str, + queries: t.List[str], + dependencies: t.Dict[str, ExportReference], + ) -> t.Any: + """Execute a duckdb load on a worker. + + This executes the query with duckdb and writes the results to a gcs path. + """ + for ref, actual in dependencies.items(): + self.logger.info( + f"job[{job_id}][{task_id}] Loading cache for {ref}:{actual}" + ) + self.get_for_cache(ref, actual) + conn = self.connection + results: t.List[pd.DataFrame] = [] + for query in queries: + self.logger.info(f"job[{job_id}][{task_id}]: Executing query {query}") + result = conn.execute(query).df() + results.append(result) + # Concatenate the results + self.logger.info(f"job[{job_id}][{task_id}]: Concatenating results") + results_df = pd.concat(results) + + # Export the results to a parquet file in memory + self.logger.info(f"job[{job_id}][{task_id}]: Writing to in memory parquet") + inmem_file = io.BytesIO() + results_df.to_parquet(inmem_file) + inmem_file.seek(0) + + # Upload the parquet to gcs + self.logger.info(f"job[{job_id}][{task_id}]: Uploading to gcs {result_path}") + self.upload_to_gcs_bucket(result_path, inmem_file) + return task_id + + +def execute_duckdb_load( + job_id: str, + task_id: str, + result_path: str, + queries: t.List[str], + dependencies: t.Dict[str, ExportReference], +): + """Execute a duckdb load on a worker. + + This executes the query with duckdb and writes the results to a gcs path. + """ + worker = get_worker() + + # The metrics plugin keeps a record of the cached tables on the worker. + plugin = t.cast(MetricsWorkerPlugin, worker.plugins["metrics"]) + plugin.handle_query(job_id, task_id, result_path, queries, dependencies) + + return task_id + + +def bad_execute(*args, **kwargs): + """Intentionally throws an exception + + Used for testing error handling + """ + worker = get_worker() + + # The metrics plugin keeps a record of the cached tables on the worker. + plugin = t.cast(MetricsWorkerPlugin, worker.plugins["metrics"]) + plugin.logger.info("Intentionally throwing an exception") + + raise ValueError("Intentionally throwing an exception") + + +def noop_execute(job_id: str, task_id: str, *args, **kwargs): + """Does nothing + + Used for testing + """ + worker = get_worker() + + # The metrics plugin keeps a record of the cached tables on the worker. 
+ plugin = t.cast(MetricsWorkerPlugin, worker.plugins["metrics"]) + plugin.logger.info("Doing nothing") + return task_id diff --git a/warehouse/metrics_tools/compute/wrapper.py b/warehouse/metrics_tools/compute/wrapper.py deleted file mode 100644 index d8a10088d..000000000 --- a/warehouse/metrics_tools/compute/wrapper.py +++ /dev/null @@ -1,13 +0,0 @@ -# A very basic cli or function wrapper that starts a dask cluster and injects an -# environment variable for it that sqlmesh can use. - -import subprocess -import sys - - -def cli(): - subprocess.run(sys.argv[1:]) - - -if __name__ == "__main__": - cli() diff --git a/warehouse/metrics_tools/definition.py b/warehouse/metrics_tools/definition.py index 0f5b4dbfd..5e91091eb 100644 --- a/warehouse/metrics_tools/definition.py +++ b/warehouse/metrics_tools/definition.py @@ -6,17 +6,9 @@ import sqlglot from sqlglot import exp -from sqlglot.optimizer.qualify import qualify from sqlmesh.core.macros import MacroEvaluator from sqlmesh.utils.date import TimeLike -from .dialect.translate import ( - CustomFuncHandler, - CustomFuncRegistry, -) -from .evaluator import FunctionsTransformer -from .utils import exp_literal_to_py_literal - CURR_DIR = os.path.dirname(__file__) QUERIES_DIR = os.path.abspath(os.path.join(CURR_DIR, "../metrics_mesh/oso_metrics")) @@ -98,22 +90,6 @@ class MetricModelRef(t.TypedDict): time_aggregation: t.NotRequired[t.Optional[str]] -def model_meta_matches_peer_dependency(meta: MetricModelRef, dep: MetricModelRef): - if meta["name"] != dep["name"]: - return False - if isinstance(dep["entity_type"], list): - if not meta["entity_type"] not in dep["entity_type"]: - return False - else: - if meta["entity_type"] != dep["entity_type"]: - if dep["entity_type"] != "*": - return False - dep_window = dep.get("window") - if dep_window: - if isinstance(dep_window, list): - pass - - @dataclass(kw_only=True) class PeerMetricDependencyDataClass: name: str @@ -250,52 +226,6 @@ def resolve_table_name( return model_name -class PeerRefRelativeWindowHandler(CustomFuncHandler): - pass - - -class PeerRefHandler(CustomFuncHandler[PeerMetricDependencyRef]): - name = "metrics_peer_ref" - - def to_obj( - self, - name, - *, - entity_type: t.Optional[exp.Expression] = None, - window: t.Optional[exp.Expression] = None, - unit: t.Optional[exp.Expression] = None, - time_aggregation: t.Optional[exp.Expression] = None, - ) -> PeerMetricDependencyRef: - entity_type_val = ( - t.cast(str, exp_literal_to_py_literal(entity_type)) if entity_type else "" - ) - window_val = int(exp_literal_to_py_literal(window)) if window else None - unit_val = t.cast(str, exp_literal_to_py_literal(unit)) if unit else None - time_aggregation_val = ( - t.cast(str, exp_literal_to_py_literal(time_aggregation)) - if time_aggregation - else None - ) - return PeerMetricDependencyRef( - name=name, - entity_type=entity_type_val, - window=window_val, - unit=unit_val, - time_aggregation=time_aggregation_val, - ) - - def transform( - self, - evaluator: MacroEvaluator, - context: t.Dict[str, t.Any], - obj: PeerMetricDependencyRef, - ) -> exp.Expression: - if not obj.get("entity_type"): - obj["entity_type"] = context["entity_type"] - peer_table_map = context["peer_table_map"] - return sqlglot.to_table(f"metrics.{to_actual_table_name(obj, peer_table_map)}") - - class MetricQueryContext: def __init__(self, source: MetricQueryDef, expressions: t.List[exp.Expression]): self._expressions = expressions @@ -452,298 +382,6 @@ def metric_type(self): # This _shouldn't_ happen raise Exception("unknown metric type") 
- def generate_query_ref( - self, - ref: PeerMetricDependencyRef, - evaluator: MacroEvaluator, - peer_table_map: t.Dict[str, str], - ): - sources_database = evaluator.locals.get("oso_source") or "default" - if ref["entity_type"] == "artifact": - return self.generate_artifact_query( - evaluator, - ref["name"], - peer_table_map, - window=ref.get("window"), - unit=ref.get("unit"), - time_aggregation=ref.get("time_aggregation"), - ) - elif ref["entity_type"] == "project": - return self.generate_project_query( - evaluator, - ref["name"], - sources_database, - peer_table_map, - window=ref.get("window"), - unit=ref.get("unit"), - time_aggregation=ref.get("time_aggregation"), - ) - elif ref["entity_type"] == "collection": - return self.generate_collection_query( - evaluator, - ref["name"], - sources_database, - peer_table_map, - window=ref.get("window"), - unit=ref.get("unit"), - time_aggregation=ref.get("time_aggregation"), - ) - raise Exception(f"Invalid entity_type {ref["entity_type"]}") - - def generate_metrics_query( - self, - evaluator: MacroEvaluator, - name: str, - peer_table_map: t.Dict[str, str], - entity_type: str, - window: t.Optional[int] = None, - unit: t.Optional[str] = None, - time_aggregation: t.Optional[str] = None, - ): - """This takes the actual metrics query and tranforms it for a specific - window/aggregation setting.""" - context = self.expression_context() - - extra_vars: t.Dict[str, ExtraVarType] = { - "entity_type": entity_type, - } - if window: - extra_vars["rolling_window"] = window - if unit: - extra_vars["rolling_unit"] = unit - if time_aggregation: - extra_vars["time_aggregation"] = time_aggregation - - metrics_query = context.evaluate( - name, - evaluator, - extra_vars, - ) - # Rewrite all of the table peer references. We do this last so that all - # macros are resolved when rewriting the anonymous functions - peer_handler = PeerRefHandler() - registry = CustomFuncRegistry() - registry.register(peer_handler) - transformer = FunctionsTransformer( - registry, - evaluator, - context={ - "peer_table_map": peer_table_map, - "entity_type": entity_type, - }, - ) - return transformer.transform(metrics_query) - - def generate_artifact_query( - self, - evaluator: MacroEvaluator, - name: str, - peer_table_map: t.Dict[str, str], - window: t.Optional[int] = None, - unit: t.Optional[str] = None, - time_aggregation: t.Optional[str] = None, - ): - metrics_query = self.generate_metrics_query( - evaluator, - name, - peer_table_map, - "artifact", - window, - unit, - time_aggregation, - ) - - top_level_select = exp.select( - "metrics_sample_date as metrics_sample_date", - "to_artifact_id as to_artifact_id", - "from_artifact_id as from_artifact_id", - "event_source as event_source", - "metric as metric", - "CAST(amount AS Float64) as amount", - ).from_("metrics_query") - - top_level_select = top_level_select.with_("metrics_query", as_=metrics_query) - return top_level_select - - def artifact_to_upstream_entity_transform( - self, - entity_type: str, - sources_database: str, - ) -> t.Callable[[exp.Expression], exp.Expression | None]: - def _transform(node: exp.Expression): - if not isinstance(node, exp.Select): - return node - select = node - - # Check if this using the timeseries source tables as a join or the from - is_using_timeseries_source = False - for table in select.find_all(exp.Table): - if table.this.this in ["events_daily_to_artifact"]: - is_using_timeseries_source = True - if not is_using_timeseries_source: - return node - - for i in range(len(select.expressions)): - ex = 
select.expressions[i] - if not isinstance(ex, exp.Alias): - continue - - # If to_artifact_id is being aggregated then it's time to rewrite - if isinstance(ex.this, exp.Column) and isinstance( - ex.this.this, exp.Identifier - ): - if ex.this.this.this == "to_artifact_id": - updated_select = select.copy() - current_from = t.cast(exp.From, updated_select.args.get("from")) - assert isinstance(current_from.this, exp.Table) - current_table = current_from.this - current_alias = current_table.alias - - # Add a join to this select - updated_select = updated_select.join( - f"{sources_database}.artifacts_by_project_v1", - on=f"{current_alias}.to_artifact_id = artifacts_by_project_v1.artifact_id", - join_type="inner", - ) - - new_to_entity_id_col = exp.to_column( - "artifacts_by_project_v1.project_id", quoted=True - ) - new_to_entity_alias = exp.to_identifier( - "to_project_id", quoted=True - ) - - if entity_type == "collection": - updated_select = updated_select.join( - f"{sources_database}.projects_by_collection_v1", - on="artifacts_by_project_v1.project_id = projects_by_collection_v1.project_id", - join_type="inner", - ) - - new_to_entity_id_col = exp.to_column( - "projects_by_collection_v1.collection_id", quoted=True - ) - new_to_entity_alias = exp.to_identifier( - "to_collection_id", quoted=True - ) - - # replace the select and the grouping with the project id in the joined table - to_artifact_id_col_sel = t.cast( - exp.Alias, updated_select.expressions[i] - ) - current_to_artifact_id_col = t.cast( - exp.Column, to_artifact_id_col_sel.this - ) - - to_artifact_id_col_sel.replace( - exp.alias_( - new_to_entity_id_col, - alias=new_to_entity_alias, - ) - ) - - group = t.cast(exp.Group, updated_select.args.get("group")) - for group_idx in range(len(group.expressions)): - group_col = t.cast(exp.Column, group.expressions[group_idx]) - if group_col == current_to_artifact_id_col: - group_col.replace(new_to_entity_id_col) - - return updated_select - # If nothing happens in the for loop then we didn't find the kind of - # expected select statement - return node - - return _transform - - def transform_aggregating_selects( - self, - expression: exp.Expression, - cb: t.Callable[[exp.Expression], exp.Expression | None], - ): - return expression.transform(cb) - - def generate_project_query( - self, - evaluator: MacroEvaluator, - name: str, - sources_database: str, - peer_table_map: t.Dict[str, str], - window: t.Optional[int] = None, - unit: t.Optional[str] = None, - time_aggregation: t.Optional[str] = None, - ): - metrics_query = self.generate_metrics_query( - evaluator, - name, - peer_table_map, - "project", - window, - unit, - time_aggregation, - ) - - # We use qualify to ensure that references are properly aliased. This - # seems to make transforms more reliable/testable/predictable. 
- metrics_query = qualify(metrics_query) - - metrics_query = self.transform_aggregating_selects( - metrics_query, - self.artifact_to_upstream_entity_transform("project", sources_database), - ) - - top_level_select = exp.select( - "metrics_sample_date as metrics_sample_date", - "to_project_id as to_project_id", - "from_artifact_id as from_artifact_id", - "event_source as event_source", - "metric as metric", - "CAST(amount AS Float64) as amount", - ).from_("metrics_query") - - top_level_select = top_level_select.with_("metrics_query", as_=metrics_query) - return top_level_select - - def generate_collection_query( - self, - evaluator: MacroEvaluator, - name: str, - sources_database: str, - peer_table_map: t.Dict[str, str], - window: t.Optional[int] = None, - unit: t.Optional[str] = None, - time_aggregation: t.Optional[str] = None, - ): - metrics_query = self.generate_metrics_query( - evaluator, - name, - peer_table_map, - "collection", - window, - unit, - time_aggregation, - ) - - # We use qualify to ensure that references are properly aliased. This - # seems to make transforms more reliable/testable/predictable. - metrics_query = qualify(metrics_query) - - metrics_query = self.transform_aggregating_selects( - metrics_query, - self.artifact_to_upstream_entity_transform("collection", sources_database), - ) - - top_level_select = exp.select( - "metrics_sample_date as metrics_sample_date", - "to_collection_id as to_collection_id", - "from_artifact_id as from_artifact_id", - "event_source as event_source", - "metric as metric", - "CAST(amount AS Float64) as amount", - ).from_("metrics_query") - - top_level_select = top_level_select.with_("metrics_query", as_=metrics_query) - return top_level_select - def find_query_expressions(expressions: t.List[exp.Expression]): return list(filter(lambda a: isinstance(a, exp.Query), expressions)) @@ -756,46 +394,6 @@ class DailyTimeseriesRollingWindowOptions(t.TypedDict): model_options: t.NotRequired[t.Dict[str, t.Any]] -def join_all_of_entity_type( - evaluator: MacroEvaluator, *, db: str, tables: t.List[str], columns: t.List[str] -): - # A bit of a hack but we know we have a "metric" column. 
We want to - # transform this metric id to also include the event_source as a prefix to - # that metric id in the joined table - transformed_columns = [] - for column in columns: - if column == "event_source": - continue - if column == "metric": - transformed_columns.append( - exp.alias_( - exp.Concat( - expressions=[ - exp.to_column("event_source"), - exp.Literal(this="_", is_string=True), - exp.to_column(column), - ], - safe=False, - coalesce=False, - ), - alias="metric", - ) - ) - else: - transformed_columns.append(column) - - query = exp.select(*transformed_columns).from_( - sqlglot.to_table(f"{db}.{tables[0]}") - ) - for table in tables[1:]: - query = query.union( - exp.select(*transformed_columns).from_(sqlglot.to_table(f"{db}.{table}")), - distinct=False, - ) - # Calculate the correct metric_id for all of the entity types - return query - - class TimeseriesMetricsOptions(t.TypedDict): model_prefix: str catalog: str @@ -813,30 +411,3 @@ class GeneratedArtifactConfig(t.TypedDict): peer_table_tuples: t.List[t.Tuple[str, str]] ref: PeerMetricDependencyRef timeseries_sources: t.List[str] - - -def generated_entity( - evaluator: MacroEvaluator, - query_reference_name: str, - query_def_as_input: MetricQueryInput, - default_dialect: str, - peer_table_tuples: t.List[t.Tuple[str, str]], - ref: PeerMetricDependencyRef, - timeseries_sources: t.List[str], -): - query_def = MetricQueryDef.from_input(query_def_as_input) - query = MetricQuery.load( - name=query_reference_name, - default_dialect=default_dialect, - source=query_def, - queries_dir=QUERIES_DIR, - ) - peer_table_map = dict(peer_table_tuples) - e = query.generate_query_ref( - ref, - evaluator, - peer_table_map=peer_table_map, - ) - if not e: - raise Exception("failed to generate query ref") - return e diff --git a/warehouse/metrics_tools/docs/compute.mmd b/warehouse/metrics_tools/docs/compute.mmd new file mode 100644 index 000000000..c3aeb8157 --- /dev/null +++ b/warehouse/metrics_tools/docs/compute.mmd @@ -0,0 +1,49 @@ +sequenceDiagram + participant dagster as Dagster + participant sqlmesh as SQLMesh + participant m_api as Metrics Frontend API + participant m_cluster as Metrics Worker Cluster + participant trino as TrinoDB + participant iceberg as Iceberg Data Lake (SDK) + participant storage as Cloud Storage + + dagster->>m_api: Initialize (through k8s api) + m_api->>+m_cluster: Start worker cluster (through dask) + m_cluster-->>-m_api: Workers Ready + dagster->>+m_api: Poll for Ready + m_api-->>-dagster: Ready + + dagster->>+sqlmesh: Trigger Run + alt non-metrics models + sqlmesh->>+trino: Trigger query for model batch + trino->>+iceberg: Write model batch + iceberg->>+storage: Write to storage + storage-->>-iceberg: Completed + iceberg-->>-trino: Completed + trino-->>-sqlmesh: Committed + end + alt metrics models + sqlmesh->>+m_api: Send metrics query, total time range, and dependent tables + m_api->>+trino: Export dependent tables to "hive" directory for cache + trino->>+storage: Write table to parquet files in a "hive" directory + storage-->>-trino: Completed + trino-->>-m_api: Completed + + m_api->>+m_cluster: Send metrics query for an interval of the total time range and exported dependent table reference + m_cluster->>+storage: Download exported tables to a local duckdb + storage-->>-m_cluster: Completed + m_cluster->>m_cluster: Run metrics query for a time range + m_cluster->>+storage: Upload query results as parquet files + storage-->>-m_cluster: Completed + m_cluster-->>-m_api: Completed + m_api-->>-sqlmesh: Completed + 
sqlmesh->>+trino: Create an external table to the parquet files in cloud storage
+        trino-->>-sqlmesh: Completed
+        sqlmesh->>+trino: Trigger query to import the parquet files into an iceberg table
+        trino->>+iceberg: Write storage parquet files to iceberg
+        iceberg-->>-trino: Completed
+        trino-->>-sqlmesh: Completed
+    end
+    sqlmesh-->>-dagster: Completed
+
+
diff --git a/warehouse/metrics_tools/evaluator.py b/warehouse/metrics_tools/evaluator.py
deleted file mode 100644
index 354989dd3..000000000
--- a/warehouse/metrics_tools/evaluator.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import typing as t
-
-from .dialect.translate import (
-    CustomFuncRegistry,
-    send_anonymous_to_callable,
-)
-from sqlmesh.core.macros import MacroEvaluator
-from sqlglot import exp
-
-
-class FunctionsTransformer:
-    def __init__(
-        self,
-        registry: CustomFuncRegistry,
-        evaluator: MacroEvaluator,
-        context: t.Dict[str, t.Any],
-    ):
-        self._registry = registry
-        self._evaluator = evaluator
-        self._context = context
-
-    def transform(self, expression: exp.Expression):
-        expression = expression.copy()
-        for anon in expression.find_all(exp.Anonymous):
-            handler = self._registry.get(anon.this)
-            if handler:
-                obj = send_anonymous_to_callable(anon, handler.to_obj)
-                anon.replace(handler.transform(self._evaluator, self._context, obj))
-        return expression
diff --git a/warehouse/metrics_tools/runner.py b/warehouse/metrics_tools/runner.py
index acd76624a..16a4b2f56 100644
--- a/warehouse/metrics_tools/runner.py
+++ b/warehouse/metrics_tools/runner.py
@@ -3,6 +3,7 @@
 import duckdb
 import arrow
 import logging
+import asyncio
 from metrics_tools.utils.glot import str_or_expressions
 from sqlmesh import EngineAdapter
 from sqlmesh.core.context import ExecutionContext
@@ -215,6 +216,16 @@ def render_rolling_queries(self, start: datetime, end: datetime) -> t.Iterator[s
             rendered_query = self.render_query(day.datetime, day.datetime)
             yield rendered_query
 
+    async def render_rolling_queries_async(self, start: datetime, end: datetime):
+        logger.debug(
+            f"render_rolling_queries_async called with start={start} and end={end}"
+        )
+        for day in arrow.Arrow.range("day", arrow.get(start), arrow.get(end)):
+            rendered_query = await asyncio.to_thread(
+                self.render_query, day.datetime, day.datetime
+            )
+            yield rendered_query
+
     def commit(self, start: datetime, end: datetime, destination: str):
         """Like run but commits the result to the database"""
         try:
diff --git a/warehouse/metrics_tools/utils/env.py b/warehouse/metrics_tools/utils/env.py
new file mode 100644
index 000000000..1bbba1f30
--- /dev/null
+++ b/warehouse/metrics_tools/utils/env.py
@@ -0,0 +1,45 @@
+import os
+import typing as t
+
+
+def required_var[T](var: str, default: t.Optional[T] = None):
+    value = os.environ.get(var, default)
+    assert value, f"{var} is required"
+    return value
+
+
+def required_int(var: str, default: t.Optional[int] = None):
+    """Ensures an environment variable is an integer"""
+    return int(required_var(var, default))
+
+
+def required_str(var: str, default: t.Optional[str] = None):
+    return required_var(var, default)
+
+
+def required_bool(var: str, default: t.Optional[bool] = None):
+    resolved = required_var(var, str(default))
+    if isinstance(resolved, bool):
+        return resolved
+    else:
+        return resolved.lower() in ["true", "1"]
+
+
+def ensure_var[T](var: str, default: T, converter: t.Callable[[str], T]):
+    try:
+        value = os.environ[var]
+        return converter(value)
+    except KeyError:
+        return default
+
+
+def ensure_str(var: str, default: str):
+    return ensure_var(var, default, str)
+
+
+def ensure_int(var: str, default: int):
+    return ensure_var(var, default, int)
+
+
+def ensure_bool(var: str, default: bool = False):
+    return ensure_var(var, default, lambda a: a.lower() in ["true", "1"])
diff --git a/warehouse/metrics_tools/utils/logging.py b/warehouse/metrics_tools/utils/logging.py
index d3f4dcc92..2f37253ff 100644
--- a/warehouse/metrics_tools/utils/logging.py
+++ b/warehouse/metrics_tools/utils/logging.py
@@ -1,5 +1,7 @@
+import typing as t
 import logging
 import os
+import sys
 
 
 connected_to_sqlmesh_logs = False
@@ -31,3 +33,43 @@ def filter(self, record):
         app_logger = logging.getLogger(logger_name)
         app_logger.addFilter(MetricsToolsFilter())
+
+
+class ModuleFilter(logging.Filter):
+    """Allows logs only from the specified module."""
+
+    def __init__(self, module_name):
+        super().__init__()
+        self.module_name = module_name
+
+    def filter(self, record):
+        return record.name.startswith(self.module_name)
+
+
+def setup_multiple_modules_logging(module_names: t.List[str]):
+    for module_name in module_names:
+        setup_module_logging(module_name)
+
+
+# Configure logging
+def setup_module_logging(
+    module_name: str,
+    level: int = logging.DEBUG,
+    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+):
+    logger = logging.getLogger(module_name)
+    logger.setLevel(level)  # Adjust the level as needed
+
+    # Create a handler that logs to stdout
+    stdout_handler = logging.StreamHandler(sys.stdout)
+    stdout_handler.setLevel(level)  # Adjust the level as needed
+
+    # Add the filter to the handler
+    stdout_handler.addFilter(ModuleFilter(module_name))
+
+    # Set a formatter (optional)
+    formatter = logging.Formatter(format, datefmt="%Y-%m-%dT%H:%M:%S")
+    stdout_handler.setFormatter(formatter)
+
+    # Add the handler to the logger
+    logger.addHandler(stdout_handler)
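
Note: the snippet below is a minimal usage sketch of the helpers introduced in this patch (metrics_tools.utils.env and metrics_tools.utils.logging), not part of the diff itself. It assumes the warehouse package is importable; the environment variable names shown are hypothetical examples.

# Illustrative only; METRICS_FRONTEND_URL and METRICS_BATCH_SIZE are hypothetical names.
from metrics_tools.utils.env import ensure_int, required_str
from metrics_tools.utils.logging import setup_multiple_modules_logging

# Route DEBUG logs for these modules to stdout using the new ModuleFilter-based handler
setup_multiple_modules_logging(["metrics_tools", "sqlmesh"])

# Required value: raises AssertionError if the variable is unset or empty
frontend_url = required_str("METRICS_FRONTEND_URL")

# Optional value with a fallback: returns 8 when METRICS_BATCH_SIZE is unset
batch_size = ensure_int("METRICS_BATCH_SIZE", 8)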