From 7be91c3bf865975af75919fcb0afa9356bdcc570 Mon Sep 17 00:00:00 2001 From: Reuven Gonzales Date: Wed, 4 Sep 2024 19:55:30 -0700 Subject: [PATCH] Force casting during the union for clickhouse (#2061) --- .../metrics_mesh/lib/factories/factory.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/warehouse/metrics_mesh/lib/factories/factory.py b/warehouse/metrics_mesh/lib/factories/factory.py index 1c6757450..aca54ea9c 100644 --- a/warehouse/metrics_mesh/lib/factories/factory.py +++ b/warehouse/metrics_mesh/lib/factories/factory.py @@ -48,7 +48,9 @@ class MetricQuery: name: t.Optional[str] = None - def load_exp(self) -> t.List[exp.Expression]: + dialect: t.Optional[str] = None + + def load_exp(self, default_dialect: str) -> t.List[exp.Expression]: """Loads the queries sql file as a sqlglot expression""" raw_sql = open(os.path.join(QUERIES_DIR, self.ref)).read() return t.cast( @@ -89,8 +91,8 @@ def to_input(self) -> MetricQueryInput: class Subquery: @classmethod - def load(cls, *, name: str, source: MetricQuery): - subquery = cls(name, source, source.load_exp()) + def load(cls, *, name: str, default_dialect: str, source: MetricQuery): + subquery = cls(name, source, source.load_exp(default_dialect)) subquery.validate() return subquery @@ -216,12 +218,22 @@ def generated_model(evaluator: MacroEvaluator): for query_name, query_input in metric_queries.items(): query = MetricQuery.from_input(query_input) subquery = subqueries[query_name] = Subquery.load( - name=query_name, source=query + name=query_name, default_dialect="clickhouse", source=query ) union_cte: t.Optional[exp.Query] = None - top_level_select = exp.select( + + cte_column_select = [ "metrics_bucket_date as bucket_day", + "to_artifact_id as to_artifact_id", + "from_artifact_id as from_artifact_id", + "event_source as event_source", + "metric as metric", + "CAST(amount AS Int64) as amount", + ] + + top_level_select = exp.select( + "bucket_day", "to_artifact_id", "from_artifact_id", "event_source", @@ -248,7 +260,7 @@ def generated_model(evaluator: MacroEvaluator): evaluator, extra_vars=dict(trailing_days=trailing_days) ) top_level_select = top_level_select.with_(cte_name, as_=evaluated) - unionable_select = sqlglot.select("*").from_(cte_name) + unionable_select = sqlglot.select(*cte_column_select).from_(cte_name) if not union_cte: union_cte = unionable_select else: