Skip to content

Commit

Permalink
Table dependencies resolution now fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
ravenac95 committed Dec 17, 2024
1 parent 51ab1e1 commit 42a749d
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 4 deletions.
53 changes: 49 additions & 4 deletions warehouse/metrics_tools/utils/tables.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import typing as t

from sqlglot import exp
Expand All @@ -14,16 +15,58 @@ def resolve_identifier_or_string(i: exp.Expression | str) -> t.Optional[str]:
return None


def resolve_table_fqn(table: exp.Table) -> str:
table_name_parts = map(resolve_identifier_or_string, table.parts)
table_name_parts = filter(None, table_name_parts)
return ".".join(table_name_parts)


def resolve_table_name(
context: ExecutionContext, table: exp.Table
) -> t.Tuple[str, str]:
table_name_parts = map(resolve_identifier_or_string, table.parts)
table_name_parts = filter(None, table_name_parts)
table_fqn = ".".join(table_name_parts)
table_fqn = resolve_table_fqn(table)

return (table_fqn, context.table(table_fqn))


def list_query_table_dependencies(
query: exp.Expression, parent_ctes: t.Dict[str, exp.Expression]
) -> t.Set[str]:
tables: t.Set[str] = set()
cte_lookup = copy.deepcopy(parent_ctes)

assert isinstance(
query, (exp.Select, exp.Union)
), f"Unsupported query type {type(query)}"

# Lookup ctes
for cte in query.ctes:
cte_lookup[cte.alias] = cte.this

if isinstance(query, exp.Union):
tables = tables.union(list_query_table_dependencies(query.this, cte_lookup))
tables = tables.union(
list_query_table_dependencies(query.expression, cte_lookup)
)

else:
table_sources: t.List[exp.Expression] = [query.args["from"]]
joins = query.args.get("joins", [])
table_sources.extend(joins)
for source in table_sources:
for table in source.find_all(exp.Table):
table_fqn = resolve_table_fqn(table)
if table_fqn in cte_lookup:
continue
tables.add(table_fqn)

# Recurse into the ctes
for cte in query.ctes:
tables = tables.union(list_query_table_dependencies(cte.this, cte_lookup))

return tables


def resolve_table_map_from_scope(
context: ExecutionContext, scope: Scope
) -> t.Dict[str, str]:
Expand All @@ -47,6 +90,8 @@ def create_dependent_tables_map(
scope = build_scope(query)
if not scope:
raise ValueError("Failed to build scope")
tables_map = resolve_table_map_from_scope(context, scope)
# tables_map = resolve_table_map_from_scope(context, scope)
tables = list_query_table_dependencies(query, {})
tables_map = {table: context.table(table) for table in tables}

return tables_map
73 changes: 73 additions & 0 deletions warehouse/metrics_tools/utils/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,79 @@ def test_create_dependent_tables_map():
""",
{"main.source": "test_table"},
),
(
"""
select * from foo
union all
select * from bar
union all
select * from baz
""",
{"foo": "test_table", "bar": "test_table", "baz": "test_table"},
),
(
"""
with foo as (
select * from bar
)
select * from foo
union all
select * from baz
""",
{"bar": "test_table", "baz": "test_table"},
),
( # nested ctes, but I don't think any sql engine supports this
"""
with foo as (
with bar as (
select * from baz
)
select * from bar
)
select * from foo
union all
select * from baz
""",
{"baz": "test_table"},
),
(
"""
with foo as (
select * from bar
)
select * from foo
inner join baz on foo.id = baz.id
""",
{"bar": "test_table", "baz": "test_table"},
),
(
"""
select * from (
select * from foo
)
""",
{"foo": "test_table"},
),
(
"""
select * from (
select * from foo
)
union all
select * from (
select * from bar
)
""",
{"foo": "test_table", "bar": "test_table"},
),
(
"""
select * from foo
inner join baz on foo.id = baz.id
left join bar on foo.id = bar.id
""",
{"foo": "test_table", "baz": "test_table", "bar": "test_table"},
),
],
)
def test_create_dependent_tables_map_parameterized(
Expand Down

0 comments on commit 42a749d

Please sign in to comment.