diff --git a/warehouse/metrics_tools/utils/tables.py b/warehouse/metrics_tools/utils/tables.py index 798a7734..1226f39d 100644 --- a/warehouse/metrics_tools/utils/tables.py +++ b/warehouse/metrics_tools/utils/tables.py @@ -1,3 +1,4 @@ +import copy import typing as t from sqlglot import exp @@ -14,16 +15,58 @@ def resolve_identifier_or_string(i: exp.Expression | str) -> t.Optional[str]: return None +def resolve_table_fqn(table: exp.Table) -> str: + table_name_parts = map(resolve_identifier_or_string, table.parts) + table_name_parts = filter(None, table_name_parts) + return ".".join(table_name_parts) + + def resolve_table_name( context: ExecutionContext, table: exp.Table ) -> t.Tuple[str, str]: - table_name_parts = map(resolve_identifier_or_string, table.parts) - table_name_parts = filter(None, table_name_parts) - table_fqn = ".".join(table_name_parts) + table_fqn = resolve_table_fqn(table) return (table_fqn, context.table(table_fqn)) +def list_query_table_dependencies( + query: exp.Expression, parent_ctes: t.Dict[str, exp.Expression] +) -> t.Set[str]: + tables: t.Set[str] = set() + cte_lookup = copy.deepcopy(parent_ctes) + + assert isinstance( + query, (exp.Select, exp.Union) + ), f"Unsupported query type {type(query)}" + + # Lookup ctes + for cte in query.ctes: + cte_lookup[cte.alias] = cte.this + + if isinstance(query, exp.Union): + tables = tables.union(list_query_table_dependencies(query.this, cte_lookup)) + tables = tables.union( + list_query_table_dependencies(query.expression, cte_lookup) + ) + + else: + table_sources: t.List[exp.Expression] = [query.args["from"]] + joins = query.args.get("joins", []) + table_sources.extend(joins) + for source in table_sources: + for table in source.find_all(exp.Table): + table_fqn = resolve_table_fqn(table) + if table_fqn in cte_lookup: + continue + tables.add(table_fqn) + + # Recurse into the ctes + for cte in query.ctes: + tables = tables.union(list_query_table_dependencies(cte.this, cte_lookup)) + + return tables + + def resolve_table_map_from_scope( context: ExecutionContext, scope: Scope ) -> t.Dict[str, str]: @@ -47,6 +90,8 @@ def create_dependent_tables_map( scope = build_scope(query) if not scope: raise ValueError("Failed to build scope") - tables_map = resolve_table_map_from_scope(context, scope) + # tables_map = resolve_table_map_from_scope(context, scope) + tables = list_query_table_dependencies(query, {}) + tables_map = {table: context.table(table) for table in tables} return tables_map diff --git a/warehouse/metrics_tools/utils/test_tables.py b/warehouse/metrics_tools/utils/test_tables.py index 7d4c0a74..4db0b6a8 100644 --- a/warehouse/metrics_tools/utils/test_tables.py +++ b/warehouse/metrics_tools/utils/test_tables.py @@ -46,6 +46,79 @@ def test_create_dependent_tables_map(): """, {"main.source": "test_table"}, ), + ( + """ + select * from foo + union all + select * from bar + union all + select * from baz + """, + {"foo": "test_table", "bar": "test_table", "baz": "test_table"}, + ), + ( + """ + with foo as ( + select * from bar + ) + select * from foo + union all + select * from baz + """, + {"bar": "test_table", "baz": "test_table"}, + ), + ( # nested ctes, but I don't think any sql engine supports this + """ + with foo as ( + with bar as ( + select * from baz + ) + select * from bar + ) + select * from foo + union all + select * from baz + """, + {"baz": "test_table"}, + ), + ( + """ + with foo as ( + select * from bar + ) + select * from foo + inner join baz on foo.id = baz.id + """, + {"bar": "test_table", "baz": "test_table"}, + ), + ( + """ + select * from ( + select * from foo + ) + """, + {"foo": "test_table"}, + ), + ( + """ + select * from ( + select * from foo + ) + union all + select * from ( + select * from bar + ) + """, + {"foo": "test_table", "bar": "test_table"}, + ), + ( + """ + select * from foo + inner join baz on foo.id = baz.id + left join bar on foo.id = bar.id + """, + {"foo": "test_table", "baz": "test_table", "bar": "test_table"}, + ), ], ) def test_create_dependent_tables_map_parameterized(