diff --git a/warehouse/metrics_tools/factory/factory.py b/warehouse/metrics_tools/factory/factory.py index 82889c94..1f949e0d 100644 --- a/warehouse/metrics_tools/factory/factory.py +++ b/warehouse/metrics_tools/factory/factory.py @@ -26,7 +26,7 @@ metrics_start, relative_window_sample_date, ) -from metrics_tools.models import GeneratedModel, MacroOverridingModel +from metrics_tools.models import MacroOverridingModel from metrics_tools.transformer import ( IntermediateMacroEvaluatorTransform, SQLTransformer, @@ -296,17 +296,24 @@ def generate_models(self, calling_file: str): ) # Join all of the models of the same entity type into the same view model + override_path = Path(inspect.getfile(join_all_of_entity_type)) + override_module_path = Path( + os.path.dirname(inspect.getfile(join_all_of_entity_type)) + ) for entity_type, tables in self._marts_tables.items(): - GeneratedModel.create( - func=join_all_of_entity_type, - entrypoint_path=calling_file, - config={ - "db": self.catalog, - "tables": tables, - "columns": list( - constants.METRICS_COLUMNS_BY_ENTITY[entity_type].keys() - ), - }, + MacroOverridingModel( + additional_macros=[], + override_module_path=override_module_path, + override_path=override_path, + locals=dict( + config={ + "db": self.catalog, + "tables": tables, + "columns": list( + constants.METRICS_COLUMNS_BY_ENTITY[entity_type].keys() + ), + } + ), name=f"metrics.timeseries_metrics_to_{entity_type}", kind="VIEW", dialect="clickhouse", @@ -319,7 +326,7 @@ def generate_models(self, calling_file: str): ) }, enabled=self._raw_options.get("enabled", True), - ) + )(join_all_of_entity_type) logger.info("model generation complete") def generate_model_for_rendered_query( diff --git a/warehouse/metrics_tools/factory/proxy/proxies.py b/warehouse/metrics_tools/factory/proxy/proxies.py index fc71f9d7..9628129f 100644 --- a/warehouse/metrics_tools/factory/proxy/proxies.py +++ b/warehouse/metrics_tools/factory/proxy/proxies.py @@ -14,15 +14,12 @@ def generated_query( evaluator: MacroEvaluator, - *, - rendered_query_str: str, - ref: PeerMetricDependencyRef, - table_name: str, - vars: t.Dict[str, t.Any], ): """Simple generated query executor for metrics queries""" + rendered_query_str = t.cast(str, evaluator.var("rendered_query_str")) + ref = t.cast(PeerMetricDependencyRef, evaluator.var("ref")) - with metric_ref_evaluator_context(evaluator, ref, vars): + with metric_ref_evaluator_context(evaluator, ref): result = evaluator.transform(parse_one(rendered_query_str)) return result @@ -63,11 +60,16 @@ def generated_rolling_query_proxy( def join_all_of_entity_type( - evaluator: MacroEvaluator, *, db: str, tables: t.List[str], columns: t.List[str] + evaluator: MacroEvaluator, ): # A bit of a hack but we know we have a "metric" column. We want to # transform this metric id to also include the event_source as a prefix to # that metric id in the joined table + + db = t.cast(str, evaluator.var("db")) + tables: t.List[str] = t.cast(t.List[str], evaluator.var("tables")) + columns: t.List[str] = t.cast(t.List[str], evaluator.var("columns")) + transformed_columns = [] for column in columns: if column == "event_source": diff --git a/warehouse/metrics_tools/models.py b/warehouse/metrics_tools/models.py index b3538b74..f2be3e15 100644 --- a/warehouse/metrics_tools/models.py +++ b/warehouse/metrics_tools/models.py @@ -1,20 +1,14 @@ import inspect import json import logging -import re import textwrap import typing as t import uuid from pathlib import Path -from sqlglot import exp from sqlmesh.core import constants as c -from sqlmesh.core.audit import ModelAudit -from sqlmesh.core.dialect import MacroFunc from sqlmesh.core.macros import ExecutableOrMacro, MacroRegistry, macro from sqlmesh.core.model.decorator import model -from sqlmesh.core.model.definition import create_python_model, create_sql_model -from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.utils.metaprogramming import ( Executable, ExecutableKind, @@ -63,257 +57,6 @@ def model(self, *args, **kwargs): return super().model(*args, **kwargs) -class GeneratedPythonModel: - @classmethod - def create( - cls, - *, - name: str, - entrypoint_path: str, - func: t.Callable, - columns: t.Dict[str, exp.DataType], - additional_macros: t.Optional[CallableAliasList] = None, - variables: t.Optional[t.Dict[str, t.Any]] = None, - imports: t.Optional[t.Dict[str, t.Any]] = None, - **kwargs, - ): - instance = cls( - name=name, - entrypoint_path=entrypoint_path, - func=func, - additional_macros=additional_macros or [], - variables=variables or {}, - columns=columns, - imports=imports or {}, - **kwargs, - ) - registry = model.registry() - registry[name] = instance - return instance - - def __init__( - self, - *, - name: str, - entrypoint_path: str, - func: t.Callable, - additional_macros: CallableAliasList, - variables: t.Dict[str, t.Any], - columns: t.Dict[str, exp.DataType], - imports: t.Dict[str, t.Any], - **kwargs, - ): - self.name = name - self._func = func - self._entrypoint_path = entrypoint_path - self._additional_macros = additional_macros - self._variables = variables - self._kwargs = kwargs - self._columns = columns - self._imports = imports - - def model( - self, - module_path: Path, - path: Path, - defaults: t.Optional[t.Dict[str, t.Any]] = None, - macros: t.Optional[MacroRegistry] = None, - jinja_macros: t.Optional[JinjaMacroRegistry] = None, - audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, - dialect: t.Optional[str] = None, - time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, - physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, - project: str = "", - default_catalog: t.Optional[str] = None, - variables: t.Optional[t.Dict[str, t.Any]] = None, - infer_names: t.Optional[bool] = False, - ): - fake_module_path = Path(self._entrypoint_path) - macros = MacroRegistry(f"macros_for_{self.name}") - macros.update(macros or macro.get_registry()) - - if self._additional_macros: - macros.update(create_macro_registry_from_list(self._additional_macros)) - - all_vars = self._variables.copy() - global_variables = variables or {} - all_vars.update(global_variables) - all_vars["sqlmesh_vars"] = global_variables - - self._kwargs.setdefault("dialect", dialect) - - common_kwargs: t.Dict[str, t.Any] = dict( - defaults=defaults, - path=path, - time_column_format=time_column_format, - physical_schema_mapping=physical_schema_mapping, - project=project, - default_catalog=default_catalog, - variables=all_vars, - audit_definitions=audit_definitions, - **self._kwargs, - ) - - env = {} - python_env = create_basic_python_env( - env, - self._entrypoint_path, - module_path, - macros=macros, - callables=[self._func], - imports=self._imports, - ) - - return create_python_model( - self.name, - self._func.__name__, - python_env, - macros=macros, - module_path=fake_module_path, - jinja_macros=jinja_macros, - columns=self._columns, - **common_kwargs, - ) - - -class GeneratedModel: - @classmethod - def create( - cls, - *, - func: t.Callable[..., t.Any], - config: t.Mapping[str, t.Any], - name: str, - columns: t.Dict[str, exp.DataType], - entrypoint_path: str, - source: t.Optional[str] = None, - source_loader: t.Optional[t.Callable[[], str]] = None, - additional_macros: t.Optional[CallableAliasList] = None, - **kwargs, - ): - if not source and not source_loader: - try: - source = inspect.getsource(func) - except: # noqa: E722 - pass - - assert ( - source is not None or source_loader is not None - ), "Must have a way to load the source for state diffs" - - instance = cls( - func_name=func.__name__, - import_module=func.__module__, - config=config, - name=name, - columns=columns, - entrypoint_path=entrypoint_path, - source=source, - source_loader=source_loader, - additional_macros=additional_macros, - **kwargs, - ) - registry = model.registry() - registry[name] = instance - return instance - - def __init__( - self, - *, - func_name: str, - import_module: str, - config: t.Mapping[str, t.Any], - name: str, - columns: t.Dict[str, exp.DataType], - entrypoint_path: str, - source: t.Optional[str] = None, - source_loader: t.Optional[t.Callable[[], str]] = None, - additional_macros: t.Optional[CallableAliasList] = None, - **kwargs, - ): - self.kwargs = kwargs - self.func_name = func_name - self.import_module = import_module - self.config = config - self.name = name - self.columns = columns - self.entrypoint_path = entrypoint_path - self.source_loader = source_loader - self.source = source - self.additional_macros = additional_macros or [] - - def model( - self, - *, - module_path: Path, - path: Path, - defaults: t.Optional[t.Dict[str, t.Any]] = None, - macros: t.Optional[MacroRegistry] = None, - jinja_macros: t.Optional[JinjaMacroRegistry] = None, - audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, - dialect: t.Optional[str] = None, - time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, - physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, - project: str = "", - default_catalog: t.Optional[str] = None, - variables: t.Optional[t.Dict[str, t.Any]] = None, - infer_names: t.Optional[bool] = False, - ): - macros = macros or macro.get_registry() - fake_module_path = Path(self.entrypoint_path) - - if self.additional_macros: - macros = t.cast(MacroRegistry, macros.copy()) - macros.update(create_macro_registry_from_list(self.additional_macros)) - - self.kwargs.setdefault("dialect", dialect) - - common_kwargs: t.Dict[str, t.Any] = dict( - defaults=defaults, - path=fake_module_path, - time_column_format=time_column_format, - physical_schema_mapping=physical_schema_mapping, - project=project, - default_catalog=default_catalog, - variables=variables, - audit_definitions=audit_definitions, - **self.kwargs, - ) - - source = self.source - if not source: - if self.source_loader: - source = self.source_loader() - assert source is not None, "source cannot be empty" - - env = {} - - entrypoint_name, env = create_import_call_env( - self.func_name, - self.import_module, - self.config, - source, - env, - fake_module_path, - project_path=module_path, - macros=macros, - variables=variables, - ) - common_kwargs["python_env"] = env - - query = MacroFunc(this=exp.Anonymous(this=entrypoint_name)) - if self.columns: - common_kwargs["columns"] = self.columns - - return create_sql_model( - self.name, - query, - module_path=fake_module_path, - jinja_macros=jinja_macros, - **common_kwargs, - ) - - def escape_triple_quotes(input_string: str) -> str: escaped_string = input_string.replace("'''", "\\'\\'\\'") escaped_string = escaped_string.replace('"""', '\\"\\"\\"')