diff --git a/.github/workflows/ci-default.yml b/.github/workflows/ci-default.yml index 3e48cc624..31584c0f0 100644 --- a/.github/workflows/ci-default.yml +++ b/.github/workflows/ci-default.yml @@ -60,7 +60,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - component: ["node", "python", "dbt"] + component: ["node", "python", "dbt", "sqlmesh"] steps: - name: Checkout code uses: actions/checkout@v4 @@ -176,5 +176,10 @@ jobs: run: | poetry run pytest if: ${{ always() && matrix.component == 'python' }} + + - name: Test SQLMesh + run: | + poetry install && cd warehouse/metrics_mesh && poetry run -C ../../ sqlmesh test + if: ${{ always() && matrix.component == 'sqlmesh' }} diff --git a/pyproject.toml b/pyproject.toml index 7de9b585b..4ec39807c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ packages = [ { include = "bq2cloudsql", from = "warehouse/" }, { include = "common", from = "warehouse/" }, { include = "oso_dagster", from = "warehouse/" }, + { include = "metrics_tools", from = "warehouse/" }, ] [tool.poetry.dependencies] @@ -54,7 +55,7 @@ sqlalchemy = "^2.0.25" textual = "^0.52.1" redis = "^5.0.7" githubkit = "^0.11.6" -sqlmesh = {extras = ["trino"], version = "^0.129.0"} +sqlmesh = { extras = ["trino"], version = "^0.129.0" } dagster-duckdb = "^0.24.0" dagster-duckdb-polars = "^0.24.0" google-cloud-bigquery-storage = "^2.25.0" diff --git a/warehouse/metrics_mesh/models/metrics_factories.py b/warehouse/metrics_mesh/models/metrics_factories.py index 2c045606e..16ef7f201 100644 --- a/warehouse/metrics_mesh/models/metrics_factories.py +++ b/warehouse/metrics_mesh/models/metrics_factories.py @@ -6,6 +6,7 @@ timeseries_metrics( start="2015-01-01", + catalog="metrics", model_prefix="timeseries", metric_queries={ # This will automatically generate star counts for the given roll up periods. diff --git a/warehouse/metrics_mesh/oso_metrics/developer_activity_classification.sql b/warehouse/metrics_mesh/oso_metrics/developer_activity_classification.sql index 9dcae9bf9..d5b2762c0 100644 --- a/warehouse/metrics_mesh/oso_metrics/developer_activity_classification.sql +++ b/warehouse/metrics_mesh/oso_metrics/developer_activity_classification.sql @@ -62,6 +62,7 @@ from @metrics_peer_ref( window := @rolling_window, unit := @rolling_unit ) as active +where active.metrics_sample_date = @metrics_end('DATE') group by metric, from_artifact_id, @metrics_entity_type_col('to_{entity_type}_id', table_alias := active), diff --git a/warehouse/metrics_mesh/tests/test_events_daily_to_artifact.yml b/warehouse/metrics_mesh/tests/test_events_daily_to_artifact.yml new file mode 100644 index 000000000..2b6dd0a39 --- /dev/null +++ b/warehouse/metrics_mesh/tests/test_events_daily_to_artifact.yml @@ -0,0 +1,84 @@ + +test_events_daily_to_artifact: + model: metrics.events_daily_to_artifact + vars: + start: 2024-01-01 + end: 2024-02-01 + inputs: + sources.timeseries_events_by_artifact_v0: + rows: + - to_artifact_id: contract_0 + from_artifact_id: user_0 + event_source: BLOCKCHAIN + event_type: CONTRACT_INVOCATION_SUCCESS_DAILY_COUNT + time: 2024-01-01T00:00:00Z + amount: 10 + - to_artifact_id: contract_1 + from_artifact_id: user_0 + event_source: BLOCKCHAIN + event_type: CONTRACT_INVOCATION_SUCCESS_DAILY_COUNT + time: 2024-01-01T00:00:00Z + amount: 10 + - to_artifact_id: repo_0 + from_artifact_id: dev_0 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + time: 2024-01-01T02:00:00Z + amount: 1 + - to_artifact_id: repo_0 + from_artifact_id: dev_0 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + time: 2024-01-01T03:00:00Z + amount: 1 + - to_artifact_id: repo_0 + from_artifact_id: dev_0 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + time: 2024-01-01T04:00:00Z + amount: 1 + - to_artifact_id: repo_0 + from_artifact_id: dev_1 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + time: 2024-01-01T04:00:00Z + amount: 1 + - to_artifact_id: repo_0 + from_artifact_id: dev_1 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + time: 2024-01-02T04:00:00Z + amount: 1 + outputs: + query: + rows: + - to_artifact_id: contract_0 + from_artifact_id: user_0 + event_source: BLOCKCHAIN + event_type: CONTRACT_INVOCATION_SUCCESS_DAILY_COUNT + bucket_day: 2024-01-01 + amount: 10 + - to_artifact_id: contract_1 + from_artifact_id: user_0 + event_source: BLOCKCHAIN + event_type: CONTRACT_INVOCATION_SUCCESS_DAILY_COUNT + bucket_day: 2024-01-01 + amount: 10 + - to_artifact_id: repo_0 + from_artifact_id: dev_0 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + bucket_day: 2024-01-01 + amount: 3 + - to_artifact_id: repo_0 + from_artifact_id: dev_1 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + bucket_day: 2024-01-01 + amount: 1 + - to_artifact_id: repo_0 + from_artifact_id: dev_1 + event_source: SOURCE_PROVIDER + event_type: COMMIT_CODE + bucket_day: 2024-01-02 + amount: 1 \ No newline at end of file diff --git a/warehouse/metrics_tools/definition.py b/warehouse/metrics_tools/definition.py index 76893822b..0f5b4dbfd 100644 --- a/warehouse/metrics_tools/definition.py +++ b/warehouse/metrics_tools/definition.py @@ -178,6 +178,8 @@ class MetricQueryDef: enabled: bool = True + use_python_model: bool = True + def raw_sql(self, queries_dir: str): return open(os.path.join(queries_dir, self.ref)).read() @@ -367,6 +369,10 @@ def validate(self): f"There must only be a single query expression in metrics query {self._source.ref}" ) + @property + def use_python_model(self): + return self._source.use_python_model + @property def query_expression(self) -> exp.Query: return t.cast(exp.Query, find_query_expressions(self._expressions)[0]) @@ -792,12 +798,12 @@ def join_all_of_entity_type( class TimeseriesMetricsOptions(t.TypedDict): model_prefix: str + catalog: str metric_queries: t.Dict[str, MetricQueryDef] default_dialect: t.NotRequired[str] - model_options: t.NotRequired[t.Dict[str, t.Any]] start: TimeLike - timeseries_sources: t.NotRequired[t.Optional[t.List[str]]] - queries_dir: t.NotRequired[t.Optional[str]] + timeseries_sources: t.NotRequired[t.List[str]] + queries_dir: t.NotRequired[str] class GeneratedArtifactConfig(t.TypedDict): diff --git a/warehouse/metrics_tools/factory/factory.py b/warehouse/metrics_tools/factory/factory.py index a9c4227f9..b259e4601 100644 --- a/warehouse/metrics_tools/factory/factory.py +++ b/warehouse/metrics_tools/factory/factory.py @@ -1,14 +1,18 @@ import contextlib +from datetime import datetime import inspect import logging import os from queue import PriorityQueue import typing as t import textwrap +from metrics_tools.runner import MetricsRunner +from metrics_tools.transformer.tables import ExecutionContextTableTransform +import pandas as pd from dataclasses import dataclass, field +from sqlmesh import ExecutionContext from sqlmesh.core.macros import MacroEvaluator -from sqlmesh.utils.date import TimeLike from sqlmesh.core.model import ModelKindName import sqlglot as sql from sqlglot import exp @@ -25,7 +29,10 @@ TimeseriesMetricsOptions, reference_to_str, ) -from metrics_tools.models import GeneratedModel +from metrics_tools.models import ( + GeneratedModel, + GeneratedPythonModel, +) from metrics_tools.factory.macros import ( metrics_end, metrics_entity_type_col, @@ -78,166 +85,6 @@ } -def generate_metric_models( - calling_file: str, - query: MetricQuery, - default_dialect: str, - peer_table_map: t.Dict[str, str], - start: TimeLike, - timeseries_sources: t.List[str], -): - # Turn the source into a dict so it can be used in the sqlmesh context - refs = query.provided_dependency_refs - - all_tables: t.Dict[str, t.List[str]] = { - "artifact": [], - "project": [], - "collection": [], - } - - for ref in refs: - cron = "@daily" - time_aggregation = ref.get("time_aggregation") - window = ref.get("window") - if time_aggregation: - cron = TIME_AGGREGATION_TO_CRON[time_aggregation] - else: - if not window: - raise Exception("window or time_aggregation must be set") - assert query._source.rolling - cron = query._source.rolling["cron"] - - table_name = query.table_name(ref) - all_tables[ref["entity_type"]].append(table_name) - columns = METRICS_COLUMNS_BY_ENTITY[ref["entity_type"]] - additional_macros = [ - metrics_peer_ref, - metrics_entity_type_col, - metrics_entity_type_alias, - relative_window_sample_date, - (metrics_name, ["metric_name"]), - ] - - kind_common = {"batch_size": 1} - partitioned_by = ("day(metrics_sample_date)",) - - # Due to how the schedulers work for sqlmesh we actually can't batch if - # we're using a weekly cron for a time aggregation. In order to have - # this work we just adjust the start/end time for the - # metrics_start/metrics_end and also give a large enough batch time to - # fit a few weeks. This ensures there's on missing data - if time_aggregation == "weekly": - kind_common = {"batch_size": 182, "lookback": 7} - if time_aggregation == "monthly": - kind_common = {"batch_size": 6} - partitioned_by = ("month(metrics_sample_date)",) - if time_aggregation == "daily": - kind_common = {"batch_size": 180} - - evaluator_variables: t.Dict[str, t.Any] = { - "entity_type": ref["entity_type"], - "time_aggregation": ref.get("time_aggregation", None), - "rolling_window": ref.get("window", None), - "rolling_unit": ref.get("unit", None), - } - evaluator_variables.update(query.vars) - - transformer = SQLTransformer( - disable_qualify=True, - transforms=[ - IntermediateMacroEvaluatorTransform( - additional_macros, - variables=evaluator_variables, - ), - JoinerTransform( - ref["entity_type"], - ), - ], - ) - - rendered_query = transformer.transform([query.query_expression]) - logger.debug(rendered_query) - - if ref["entity_type"] == "artifact": - GeneratedModel.create( - func=generated_query, - source="", - entrypoint_path=calling_file, - config={}, - name=f"metrics.{table_name}", - kind={ - "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, - "time_column": "metrics_sample_date", - **kind_common, - }, - dialect="clickhouse", - columns=columns, - grain=[ - "metric", - "to_artifact_id", - "from_artifact_id", - "metrics_sample_date", - ], - cron=cron, - start=start, - additional_macros=additional_macros, - partitioned_by=partitioned_by, - ) - - if ref["entity_type"] == "project": - GeneratedModel.create( - func=generated_query, - source="", - entrypoint_path=calling_file, - config={}, - name=f"metrics.{table_name}", - kind={ - "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, - "time_column": "metrics_sample_date", - **kind_common, - }, - dialect="clickhouse", - columns=columns, - grain=[ - "metric", - "to_project_id", - "from_artifact_id", - "metrics_sample_date", - ], - cron=cron, - start=start, - additional_macros=additional_macros, - partitioned_by=partitioned_by, - ) - if ref["entity_type"] == "collection": - GeneratedModel.create( - func=generated_query, - source="", - entrypoint_path=calling_file, - config={}, - name=f"metrics.{table_name}", - kind={ - "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, - "time_column": "metrics_sample_date", - **kind_common, - }, - dialect="clickhouse", - columns=columns, - grain=[ - "metric", - "to_collection_id", - "from_artifact_id", - "metrics_sample_date", - ], - cron=cron, - start=start, - additional_macros=additional_macros, - partitioned_by=partitioned_by, - ) - - return all_tables - - @contextlib.contextmanager def metric_ref_evaluator_context( evaluator: MacroEvaluator, @@ -260,21 +107,6 @@ def metric_ref_evaluator_context( evaluator.locals = before -def generated_query( - evaluator: MacroEvaluator, - *, - rendered_query_str: str, - ref: PeerMetricDependencyRef, - table_name: str, - vars: t.Dict[str, t.Any], -): - from sqlmesh.core.dialect import parse_one - - with metric_ref_evaluator_context(evaluator, ref, vars): - result = evaluator.transform(parse_one(rendered_query_str)) - return result - - class MetricQueryConfig(t.TypedDict): table_name: str ref: PeerMetricDependencyRef @@ -340,6 +172,11 @@ def __init__( self._rendered = False self._rendered_queries: t.Dict[str, MetricQueryConfig] = {} + @property + def catalog(self): + """The catalog (sometimes db name) to use for rendered queries""" + return self._raw_options["catalog"] + def generate_queries(self): if self._rendered: return self._rendered_queries @@ -435,6 +272,7 @@ class MetricQueryConfigQueueItem: visited: t.Dict[str, int] = {} cycle_lock: t.Dict[str, bool] = {} + dependencies: t.Dict[str, t.Set[str]] = {} def queue_query(name: str): if name in cycle_lock: @@ -448,6 +286,7 @@ def queue_query(name: str): rendered_query = query_config["rendered_query"] depth = 0 tables = rendered_query.find_all(exp.Table) + parents = set() for table in tables: db_name = table.db if isinstance(table.db, exp.Identifier): @@ -457,6 +296,7 @@ def queue_query(name: str): if db_name != "metrics": continue + parents.add(table_name) if table_name in sources: continue @@ -480,6 +320,7 @@ def queue_query(name: str): depth = parent_depth + 1 queue.put(MetricQueryConfigQueueItem(depth, query_config)) visited[name] = depth + dependencies[name] = parents del cycle_lock[name] return depth @@ -491,14 +332,16 @@ def queue_query(name: str): item = t.cast(MetricQueryConfigQueueItem, queue.get()) depth = item.depth query_config = item.config - yield (depth, query_config) + yield (depth, query_config, dependencies[query_config["table_name"]]) def generate_models(self, calling_file: str): """Generates sqlmesh models for all the configured metrics definitions""" # Generate the models - for _, query_config in self.generate_ordered_queries(): - self.generate_model_for_rendered_query(calling_file, query_config) + for _, query_config, dependencies in self.generate_ordered_queries(): + self.generate_model_for_rendered_query( + calling_file, query_config, dependencies + ) # Join all of the models of the same entity type into the same view model for entity_type, tables in self._marts_tables.items(): @@ -506,7 +349,7 @@ def generate_models(self, calling_file: str): func=join_all_of_entity_type, entrypoint_path=calling_file, config={ - "db": "metrics", + "db": self.catalog, "tables": tables, "columns": list(METRICS_COLUMNS_BY_ENTITY[entity_type].keys()), }, @@ -522,26 +365,86 @@ def generate_models(self, calling_file: str): ) }, ) - print("model generation complete") + logger.debug("model generation complete") def generate_model_for_rendered_query( - self, calling_file: str, query_config: MetricQueryConfig + self, + calling_file: str, + query_config: MetricQueryConfig, + dependencies: t.Set[str], ): query = query_config["query"] match query.metric_type: case "rolling": - self.generate_rolling_model_for_rendered_query( - calling_file, query_config - ) + if query.use_python_model: + self.generate_rolling_python_model_for_rendered_query( + calling_file, query_config, dependencies + ) + else: + self.generate_rolling_model_for_rendered_query( + calling_file, query_config, dependencies + ) case "time_aggregation": self.generate_time_aggregation_model_for_rendered_query( - calling_file, query_config + calling_file, query_config, dependencies ) + def generate_rolling_python_model_for_rendered_query( + self, + calling_file: str, + query_config: MetricQueryConfig, + dependencies: t.Set[str], + ): + depends_on = set() + for dep in dependencies: + depends_on.add(f"{self.catalog}.{dep}") + + ref = query_config["ref"] + query = query_config["query"] + + columns = METRICS_COLUMNS_BY_ENTITY[ref["entity_type"]] + + kind_common = {"batch_size": 90} + partitioned_by = ("day(metrics_sample_date)",) + window = ref.get("window") + assert window is not None + assert query._source.rolling + cron = query._source.rolling["cron"] + + grain = [ + "metric", + f"to_{ref['entity_type']}_id", + "event_source", + "from_artifact_id", + "metrics_sample_date", + ] + + return GeneratedPythonModel.create( + name=f"{self.catalog}.{query_config['table_name']}", + func=generated_rolling_query_proxy, + entrypoint_path=calling_file, + additional_macros=self.generated_model_additional_macros, + variables=self.serializable_config(query_config), + depends_on=depends_on, + columns=columns, + kind={ + "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, + "time_column": "metrics_sample_date", + **kind_common, + }, + partitioned_by=partitioned_by, + cron=cron, + start=self._raw_options["start"], + grain=grain, + imports={"pd": pd, "generated_rolling_query": generated_rolling_query}, + ) + def generate_rolling_model_for_rendered_query( - self, calling_file: str, query_config: MetricQueryConfig + self, + calling_file: str, + query_config: MetricQueryConfig, + dependencies: t.Set[str], ): - """TODO change this to a python model""" config = self.serializable_config(query_config) ref = query_config["ref"] @@ -560,6 +463,7 @@ def generate_rolling_model_for_rendered_query( "metric", f"to_{ref['entity_type']}_id", "from_artifact_id", + "event_source", "metrics_sample_date", ] @@ -567,7 +471,7 @@ def generate_rolling_model_for_rendered_query( func=generated_query, entrypoint_path=calling_file, config=config, - name=f"metrics.{query_config['table_name']}", + name=f"{self.catalog}.{query_config['table_name']}", kind={ "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, "time_column": "metrics_sample_date", @@ -583,7 +487,10 @@ def generate_rolling_model_for_rendered_query( ) def generate_time_aggregation_model_for_rendered_query( - self, calling_file: str, query_config: MetricQueryConfig + self, + calling_file: str, + query_config: MetricQueryConfig, + dependencies: t.Set[str], ): """Generate model for time aggregation models""" # Use a simple python sql model to generate the time_aggregation model @@ -609,6 +516,7 @@ def generate_time_aggregation_model_for_rendered_query( "metric", f"to_{ref['entity_type']}_id", "from_artifact_id", + "event_source", "metrics_sample_date", ] cron = TIME_AGGREGATION_TO_CRON[time_aggregation] @@ -617,7 +525,7 @@ def generate_time_aggregation_model_for_rendered_query( func=generated_query, entrypoint_path=calling_file, config=config, - name=f"metrics.{query_config['table_name']}", + name=f"{self.catalog}.{query_config['table_name']}", kind={ "name": ModelKindName.INCREMENTAL_BY_TIME_RANGE, "time_column": "metrics_sample_date", @@ -646,7 +554,9 @@ def serializable_config(self, query_config: MetricQueryConfig): return config @property - def generated_model_additional_macros(self): + def generated_model_additional_macros( + self, + ) -> t.List[t.Callable | t.Tuple[t.Callable, t.List[str]]]: return [metrics_end, metrics_start, metrics_sample_date] @@ -694,3 +604,71 @@ def join_all_of_entity_type( ) # Calculate the correct metric_id for all of the entity types return query + + +def generated_query( + evaluator: MacroEvaluator, + *, + rendered_query_str: str, + ref: PeerMetricDependencyRef, + table_name: str, + vars: t.Dict[str, t.Any], +): + """Simple generated query executor for metrics queries""" + from sqlmesh.core.dialect import parse_one + + with metric_ref_evaluator_context(evaluator, ref, vars): + result = evaluator.transform(parse_one(rendered_query_str)) + return result + + +def generated_rolling_query( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + ref: PeerMetricDependencyRef, + vars: t.Dict[str, t.Any], + rendered_query_str: str, + table_name: str, + sqlmesh_vars: t.Dict[str, t.Any], + *_ignored, +): + # Transform the query for the current context + transformer = SQLTransformer(transforms=[ExecutionContextTableTransform(context)]) + query = transformer.transform(rendered_query_str) + locals = vars.copy() + locals.update(sqlmesh_vars) + + runner = MetricsRunner.from_sqlmesh_context(context, query, ref, locals) + yield runner.run_rolling(start, end) + + +def generated_rolling_query_proxy( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + ref: PeerMetricDependencyRef, + vars: t.Dict[str, t.Any], + rendered_query_str: str, + table_name: str, + sqlmesh_vars: t.Dict[str, t.Any], + **kwargs, +) -> t.Iterator[pd.DataFrame]: + """This acts as the proxy to the actual function that we'd call for + the metrics model.""" + + yield from generated_rolling_query( + context, + start, + end, + execution_time, + ref, + vars, + rendered_query_str, + table_name, + sqlmesh_vars, + # Change the following variable to force reevaluation. Hack for now. + "version=v4", + ) diff --git a/warehouse/metrics_tools/factory/test_factory.py b/warehouse/metrics_tools/factory/test_factory.py index d4bd06a0e..384118017 100644 --- a/warehouse/metrics_tools/factory/test_factory.py +++ b/warehouse/metrics_tools/factory/test_factory.py @@ -2,9 +2,9 @@ from metrics_tools.utils.testing import duckdb_df_context import os import arrow -from metrics_tools.factory.gen_data import MetricsDBFixture import pytest +from metrics_tools.utils.fixtures.gen_data import MetricsDBFixture from metrics_tools.runner import MetricsRunner from metrics_tools.definition import MetricQueryDef, RollingConfig from .factory import TimeseriesMetrics @@ -29,32 +29,24 @@ def timeseries_duckdb(): "c_3": ["p_0", "p_1", "p_2"], }, ) + start = "2023-12-01" + end = "2024-02-01" - fixture.generate_daily_events( - "2023-01-01", "2024-12-31", "VISIT", "user_0", "service_0" - ) - fixture.generate_daily_events( - "2023-01-01", "2024-12-31", "VISIT", "user_0", "service_1" - ) - fixture.generate_daily_events( - "2023-01-01", "2024-12-31", "VISIT", "user_1", "service_1" - ) - fixture.generate_daily_events( - "2023-01-01", "2024-12-31", "VISIT", "user_2", "service_2" - ) + fixture.generate_daily_events(start, end, "VISIT", "user_0", "service_0") + fixture.generate_daily_events(start, end, "VISIT", "user_0", "service_1") + fixture.generate_daily_events(start, end, "VISIT", "user_1", "service_1") + fixture.generate_daily_events(start, end, "VISIT", "user_2", "service_2") for ft_dev_index in range(5): dev_name = f"ft_dev_{ft_dev_index}" - fixture.generate_daily_events( - "2023-01-01", "2024-12-31", "COMMIT_CODE", dev_name, "repo_0" - ) + fixture.generate_daily_events(start, end, "COMMIT_CODE", dev_name, "repo_0") # Change in developers for ft_dev_index in range(5, 10): dev_name = f"ft_dev_{ft_dev_index}" fixture.generate_daily_events( - "2023-01-01", - "2024-12-31", + start, + end, "COMMIT_CODE", dev_name, "repo_0", @@ -71,8 +63,8 @@ def timeseries_duckdb(): for pt_dev_index in range(10): dev_name = f"pt_dev_{pt_dev_index}" fixture.generate_daily_events( - "2023-01-01", - "2024-12-31", + start, + end, "COMMIT_CODE", dev_name, "repo_0", @@ -86,6 +78,7 @@ def timeseries_duckdb(): def timeseries_metrics_to_test(): return TimeseriesMetrics.from_raw_options( start="2024-01-01", + catalog="metrics", model_prefix="timeseries", metric_queries={ "visits": MetricQueryDef( @@ -179,7 +172,7 @@ def test_runner( base_locals = {"oso_source": "sources"} connection = timeseries_duckdb._conn - for _, query_config in timeseries_metrics_to_test.generate_ordered_queries(): + for _, query_config, _ in timeseries_metrics_to_test.generate_ordered_queries(): ref = query_config["ref"] locals = query_config["vars"].copy() locals.update(base_locals) diff --git a/warehouse/metrics_tools/models.py b/warehouse/metrics_tools/models.py index 65f889468..2c33ff7c6 100644 --- a/warehouse/metrics_tools/models.py +++ b/warehouse/metrics_tools/models.py @@ -11,7 +11,7 @@ from sqlmesh.core.dialect import MacroFunc from sqlmesh.core.macros import ExecutableOrMacro, MacroRegistry, macro from sqlmesh.core.model.decorator import model -from sqlmesh.core.model.definition import create_sql_model +from sqlmesh.core.model.definition import create_sql_model, create_python_model from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.utils.metaprogramming import ( Executable, @@ -23,6 +23,117 @@ logger = logging.getLogger(__name__) +CallableAliasList = t.List[t.Callable | t.Tuple[t.Callable, t.List[str]]] + + +class GeneratedPythonModel: + @classmethod + def create( + cls, + *, + name: str, + entrypoint_path: str, + func: t.Callable, + columns: t.Dict[str, exp.DataType], + additional_macros: t.Optional[CallableAliasList] = None, + variables: t.Optional[t.Dict[str, t.Any]] = None, + imports: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs, + ): + instance = cls( + name=name, + entrypoint_path=entrypoint_path, + func=func, + additional_macros=additional_macros or [], + variables=variables or {}, + columns=columns, + imports=imports or {}, + **kwargs, + ) + registry = model.registry() + registry[name] = instance + return instance + + def __init__( + self, + *, + name: str, + entrypoint_path: str, + func: t.Callable, + additional_macros: CallableAliasList, + variables: t.Dict[str, t.Any], + columns: t.Dict[str, exp.DataType], + imports: t.Dict[str, t.Any], + **kwargs, + ): + self.name = name + self._func = func + self._entrypoint_path = entrypoint_path + self._additional_macros = additional_macros + self._variables = variables + self._kwargs = kwargs + self._columns = columns + self._imports = imports + + def model( + self, + module_path: Path, + path: Path, + defaults: t.Optional[t.Dict[str, t.Any]] = None, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, + dialect: t.Optional[str] = None, + time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, + physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, + project: str = "", + default_catalog: t.Optional[str] = None, + variables: t.Optional[t.Dict[str, t.Any]] = None, + infer_names: t.Optional[bool] = False, + ): + fake_module_path = Path(self._entrypoint_path) + macros = MacroRegistry(f"macros_for_{self.name}") + macros.update(macros or macro.get_registry()) + + if self._additional_macros: + macros.update(create_macro_registry_from_list(self._additional_macros)) + + all_vars = self._variables.copy() + global_variables = variables or {} + all_vars["sqlmesh_vars"] = global_variables + + common_kwargs: t.Dict[str, t.Any] = dict( + defaults=defaults, + path=path, + time_column_format=time_column_format, + physical_schema_mapping=physical_schema_mapping, + project=project, + default_catalog=default_catalog, + variables=all_vars, + **self._kwargs, + ) + + env = {} + python_env = create_basic_python_env( + env, + self._entrypoint_path, + module_path, + macros=macros, + callables=[self._func], + imports=self._imports, + ) + + return create_python_model( + self.name, + self._func.__name__, + python_env, + macros=macros, + module_path=fake_module_path, + jinja_macros=jinja_macros, + columns=self._columns, + dialect=dialect, + **common_kwargs, + ) + class GeneratedModel: @classmethod @@ -36,6 +147,7 @@ def create( entrypoint_path: str, source: t.Optional[str] = None, source_loader: t.Optional[t.Callable[[], str]] = None, + additional_macros: t.Optional[CallableAliasList] = None, **kwargs, ): if not source and not source_loader: @@ -57,6 +169,7 @@ def create( entrypoint_path=entrypoint_path, source=source, source_loader=source_loader, + additional_macros=additional_macros, **kwargs, ) registry = model.registry() @@ -74,9 +187,7 @@ def __init__( entrypoint_path: str, source: t.Optional[str] = None, source_loader: t.Optional[t.Callable[[], str]] = None, - additional_macros: t.Optional[ - t.List[t.Callable | t.Tuple[t.Callable, t.List[str]]] - ] = None, + additional_macros: t.Optional[CallableAliasList] = None, **kwargs, ): self.kwargs = kwargs @@ -111,11 +222,7 @@ def model( if self.additional_macros: macros = t.cast(MacroRegistry, macros.copy()) - for additional_macro in self.additional_macros: - if isinstance(additional_macro, tuple): - macros.update(create_unregistered_wrapped_macro(*additional_macro)) - else: - macros.update(create_unregistered_wrapped_macro(additional_macro)) + macros.update(create_macro_registry_from_list(self.additional_macros)) common_kwargs: t.Dict[str, t.Any] = dict( defaults=defaults, @@ -267,6 +374,8 @@ def create_basic_python_env( macros: t.Optional[MacroRegistry] = None, additional_macros: t.Optional[MacroRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + callables: t.Optional[t.List[t.Callable]] = None, + imports: t.Optional[t.Dict[str, t.Any]] = None, ): if isinstance(path, str): path = Path(path) @@ -291,8 +400,29 @@ def create_basic_python_env( for name, value in variables.items(): serialized[name] = Executable.value(value) - serialized.update(serialize_env(python_env, project_path)) + imports = imports or {} + for name, imp in imports.items(): + python_env[name] = imp + # serialized[func.__name__] = Executable( + # payload=f"from {func.__module__} import {func.__name__}", + # kind=ExecutableKind.IMPORT, + # ) + + callables = callables or [] + for func in callables: + # FIXME: this is not ideal right now, we should generalize + # create_import_call_env to support this. + + serialized[func.__name__] = Executable( + name=func.__name__, + payload=normalize_source(func), + kind=ExecutableKind.DEFINITION, + path="", + alias=None, + is_metadata=False, + ) + serialized.update(serialize_env(python_env, project_path)) return serialized @@ -352,3 +482,13 @@ def {entrypoint_name}(evaluator): ) return (entrypoint_name, serialized) + + +def create_macro_registry_from_list(macro_list: CallableAliasList): + registry = MacroRegistry("macros") + for additional_macro in macro_list: + if isinstance(additional_macro, tuple): + registry.update(create_unregistered_wrapped_macro(*additional_macro)) + else: + registry.update(create_unregistered_wrapped_macro(additional_macro)) + return registry diff --git a/warehouse/metrics_tools/runner.py b/warehouse/metrics_tools/runner.py index 161e004aa..fc320a7cb 100644 --- a/warehouse/metrics_tools/runner.py +++ b/warehouse/metrics_tools/runner.py @@ -3,6 +3,7 @@ import duckdb import arrow import logging +from metrics_tools.utils.glot import str_or_expressions from sqlmesh.core.context import ExecutionContext from sqlmesh.core.config import DuckDBConnectionConfig from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter @@ -82,7 +83,7 @@ class MetricsRunner: def create_duckdb_execution_context( cls, conn: duckdb.DuckDBPyConnection, - query: t.List[exp.Expression], + query: str | t.List[exp.Expression], ref: PeerMetricDependencyRef, locals: t.Optional[t.Dict[str, t.Any]], ): @@ -91,7 +92,17 @@ def connection_factory(): engine_adapter = DuckDBEngineAdapter(connection_factory) context = ExecutionContext(engine_adapter, {}) - return cls(context, query, ref, locals) + return cls(context, str_or_expressions(query), ref, locals) + + @classmethod + def from_sqlmesh_context( + cls, + context: ExecutionContext, + query: str | t.List[exp.Expression], + ref: PeerMetricDependencyRef, + locals: t.Optional[t.Dict[str, t.Any]] = None, + ): + return cls(context, str_or_expressions(query), ref, locals) def __init__( self, @@ -120,7 +131,9 @@ def run_time_aggregation(self, start: datetime, end: datetime): def run_rolling(self, start: datetime, end: datetime): df: pd.DataFrame = pd.DataFrame() logger.debug(f"run_rolling called with start={start} and end={end}") + count = 0 for day in arrow.Arrow.range("day", arrow.get(start), arrow.get(end)): + count += 1 rendered_query = self.render_query(day.datetime, day.datetime) logger.debug( f"executing rolling window: {rendered_query}", @@ -128,6 +141,7 @@ def run_rolling(self, start: datetime, end: datetime): ) day_result = self._context.engine_adapter.fetchdf(rendered_query) df = pd.concat([df, day_result]) + return df def render_query(self, start: datetime, end: datetime) -> str: diff --git a/warehouse/metrics_tools/transformer/tables.py b/warehouse/metrics_tools/transformer/tables.py new file mode 100644 index 000000000..50ce8c9c2 --- /dev/null +++ b/warehouse/metrics_tools/transformer/tables.py @@ -0,0 +1,38 @@ +"""Transforms table references from an execution context +""" + +import typing as t + +from sqlglot import exp +from sqlmesh.core.context import ExecutionContext +from .base import Transform + + +class ExecutionContextTableTransform(Transform): + def __init__( + self, + context: ExecutionContext, + ): + self._context = context + + def __call__(self, query: t.List[exp.Expression]) -> t.List[exp.Expression]: + context = self._context + + def transform_tables(node: exp.Expression): + if not isinstance(node, exp.Table): + return node + table_name = f"{node.db}.{node.this.this}" + try: + actual_table_name = context.table(table_name) + except KeyError: + return node + table_kwargs = {} + if node.alias: + table_kwargs["alias"] = node.alias + return exp.to_table(actual_table_name, **table_kwargs) + + transformed_expressions = [] + for expression in query: + transformed = expression.transform(transform_tables) + transformed_expressions.append(transformed) + return transformed_expressions diff --git a/warehouse/metrics_tools/utils/fixtures/__init__.py b/warehouse/metrics_tools/utils/fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/warehouse/metrics_tools/factory/gen_data.py b/warehouse/metrics_tools/utils/fixtures/gen_data.py similarity index 100% rename from warehouse/metrics_tools/factory/gen_data.py rename to warehouse/metrics_tools/utils/fixtures/gen_data.py diff --git a/warehouse/metrics_tools/utils/glot.py b/warehouse/metrics_tools/utils/glot.py index 7bb86f3d1..ade90cd93 100644 --- a/warehouse/metrics_tools/utils/glot.py +++ b/warehouse/metrics_tools/utils/glot.py @@ -1,6 +1,7 @@ import typing as t from sqlglot import exp +from sqlmesh.core.dialect import parse def exp_literal_to_py_literal(glot_literal: exp.Expression) -> t.Any: @@ -8,3 +9,9 @@ def exp_literal_to_py_literal(glot_literal: exp.Expression) -> t.Any: if not isinstance(glot_literal, exp.Literal): return glot_literal return glot_literal.this + + +def str_or_expressions(query: str | t.List[exp.Expression]) -> t.List[exp.Expression]: + if not isinstance(query, list): + return parse(query) + return query