Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add JSON serializer for ASTs and store them upon node creation #699

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Add query ast to noderevision

Revision ID: 789f91d2d69e
Revises: ccc77abcf899
Create Date: 2023-08-07 14:32:54.290688+00:00

"""
# pylint: disable=no-member, invalid-name, missing-function-docstring, unused-import, no-name-in-module

import sqlalchemy as sa
import sqlmodel

from alembic import op

# revision identifiers, used by Alembic.
revision = "789f91d2d69e"
down_revision = "ccc77abcf899"
branch_labels = None
depends_on = None


def upgrade():
    """Add the nullable JSON ``query_ast`` column to the ``noderevision`` table."""
    op.add_column("noderevision", sa.Column("query_ast", sa.JSON(), nullable=True))


def downgrade():
    """Drop the ``query_ast`` column added by this revision."""
    op.drop_column("noderevision", "query_ast")
1 change: 1 addition & 0 deletions datajunction-server/datajunction_server/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def client_code_for_creating_node(
"node_id",
"updated_at",
"query" if node.type == NodeType.CUBE else "",
"query_ast",
},
exclude_none=True,
)
Expand Down
3 changes: 2 additions & 1 deletion datajunction-server/datajunction_server/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from datajunction_server.service_clients import QueryServiceClient
from datajunction_server.sql.dag import get_nodes_with_dimension
from datajunction_server.sql.parsing import ast
from datajunction_server.sql.parsing.ast_json_encoder import ASTEncoder
from datajunction_server.sql.parsing.backends.antlr4 import SqlSyntaxError, parse
from datajunction_server.sql.parsing.backends.exceptions import DJParseException
from datajunction_server.typing import END_JOB_STATES, UTCDatetime
Expand Down Expand Up @@ -415,7 +416,7 @@ def validate_node_data( # pylint: disable=too-many-locals
dependencies_map,
)
validated_node.required_dimensions = matched_bound_columns

validated_node.query_ast = json.loads(json.dumps(query_ast, cls=ASTEncoder))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blown-away-maxwell

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this just handles serializing and storing the ast. As I mentioned below, I might try some basic deserialization to make sure it works for query building, but I'll put the actual implementation in a separate PR :)

errors = []
if missing_parents_map or type_inference_failures or invalid_required_dimensions:
# update status (if needed)
Expand Down
13 changes: 12 additions & 1 deletion datajunction-server/datajunction_server/construction/build.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Functions to add to an ast DJ node queries"""
import collections
import json
import logging
import time

Expand All @@ -16,6 +17,7 @@
from datajunction_server.models.node import BuildCriteria, Node, NodeRevision, NodeType
from datajunction_server.sql.dag import get_shared_dimensions
from datajunction_server.sql.parsing.ast import CompileContext
from datajunction_server.sql.parsing.ast_json_encoder import ast_decoder
from datajunction_server.sql.parsing.backends.antlr4 import ast, parse
from datajunction_server.sql.parsing.types import ColumnType
from datajunction_server.utils import amenable_name
Expand Down Expand Up @@ -432,6 +434,8 @@ def add_filters_dimensions_orderby_limit_to_query_ast(
projection_update += list(projection_addition.values())

query.select.projection = projection_update
query.select._is_compiled = False # pylint: disable=protected-access
query._is_compiled = False # pylint: disable=protected-access

if limit is not None:
query.select.limit = ast.Number(limit)
Expand Down Expand Up @@ -516,7 +520,12 @@ def build_node( # pylint: disable=too-many-arguments
):
return ast.Query(select=select) # pragma: no cover

if node.query:
if node.query_ast:
query = json.loads(
json.dumps(node.query_ast),
object_hook=lambda _dict: ast_decoder(session, _dict),
)
elif node.query:
query = parse(node.query)
else:
query = build_source_node_query(node)
Expand Down Expand Up @@ -824,6 +833,8 @@ def build_ast( # pylint: disable=too-many-arguments
context = CompileContext(session=session, exception=DJException())
if hash(query) in memoized_queries:
query = memoized_queries[hash(query)] # pragma: no cover
elif query.is_compiled():
memoized_queries[hash(query)] = query # pragma: no cover
else:
query.compile(context)
memoized_queries[hash(query)] = query
Expand Down
17 changes: 16 additions & 1 deletion datajunction-server/datajunction_server/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timezone
from functools import partial
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Extra
from pydantic import Field as PydanticField
Expand Down Expand Up @@ -678,6 +678,11 @@ class NodeRevision(NodeRevisionBase, table=True): # type: ignore
},
)

query_ast: Optional[Dict[str, Any]] = Field(
sa_column=SqlaColumn("query_ast", JSON),
default={},
)

def __hash__(self) -> int:
return hash(self.id)

Expand Down Expand Up @@ -830,6 +835,15 @@ def has_available_materialization(self, build_criteria: BuildCriteria) -> bool:
)
)

def __json_encode__(self):
"""
JSON encoder for node revision
"""
return {
"name": self.name,
"type": self.type,
}


class ImmutableNodeFields(BaseSQLModel):
"""
Expand Down Expand Up @@ -1101,6 +1115,7 @@ class NodeRevisionOutput(SQLModel):
table: Optional[str]
description: str = ""
query: Optional[str] = None
query_ast: Optional[Dict] = {}
availability: Optional[AvailabilityState] = None
columns: List[ColumnOutput]
updated_at: UTCDatetime
Expand Down
49 changes: 49 additions & 0 deletions datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ class Node(ABC):

_is_compiled: bool = False

@property
def json_ignore_keys(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this pattern 🙂

return ["parent", "parent_key", "_is_compiled"]

def __json_encode__(self):
return {
key: self.__dict__[key]
for key in self.__dict__
if key not in self.json_ignore_keys
}

def __post_init__(self):
self.add_self_as_parent()

Expand Down Expand Up @@ -624,6 +635,10 @@ def identifier(self, quotes: bool = True) -> str:
f"{namespace}{quote_style}{self.name}{quote_style}" # pylint: disable=C0301
)

@property
def json_ignore_keys(self):
    """Attribute names skipped when JSON-encoding a named node."""
    skipped = ["names", "parent", "parent_key"]
    return skipped


TNamed = TypeVar("TNamed", bound="Named") # pylint: disable=C0103

Expand Down Expand Up @@ -705,6 +720,10 @@ class Column(Aliasable, Named, Expression):
_expression: Optional[Expression] = field(repr=False, default=None)
_is_compiled: bool = False

@property
def json_ignore_keys(self):
    """Attribute names skipped when JSON-encoding a column."""
    skipped = ["parent", "parent_key", "columns"]
    return skipped

@property
def type(self):
if self._type:
Expand Down Expand Up @@ -985,6 +1004,18 @@ class TableExpression(Aliasable, Expression):
# ref (referenced) columns are columns used elsewhere from this table
_ref_columns: List[Column] = field(init=False, repr=False, default_factory=list)

@property
def json_ignore_keys(self):
    """
    Attribute names skipped when JSON-encoding a table expression.

    NOTE(review): ``_is_compiled``, ``column_list`` and ``columns`` were
    considered for exclusion in the original but are currently kept in
    the payload — confirm which set is intended before release.
    """
    skipped = [
        "parent",
        "parent_key",
        "_columns",
        "_ref_columns",
    ]
    return skipped

@property
def columns(self) -> List[Expression]:
"""
Expand Down Expand Up @@ -1229,6 +1260,11 @@ class BinaryOpKind(DJEnum):
Minus = "-"
Modulo = "%"

def __json_encode__(self):
return {
"value": self.value,
}


@dataclass(eq=False)
class BinaryOp(Operation):
Expand Down Expand Up @@ -2003,6 +2039,19 @@ class FunctionTable(FunctionTableExpression):
Represents a table-valued function used in a statement
"""

@property
def json_ignore_keys(self):
    """Attribute names skipped when JSON-encoding a table-valued function."""
    skipped = [
        "parent",
        "parent_key",
        "_is_compiled",
        "_table",
        "_columns",
        "column_list",
        "_ref_columns",
        "columns",
    ]
    return skipped

def __str__(self) -> str:
alias = f" {self.alias}" if self.alias else ""
as_ = " AS " if self.as_ else ""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
JSON encoder for AST objects
"""
from json import JSONEncoder

from sqlmodel import select

from datajunction_server.models import Node
from datajunction_server.sql.parsing import ast


def remove_circular_refs(obj, _seen=None):
    """
    Break circular references in an AST before JSON encoding.

    Recursively walks the attributes of ``ast.Node`` instances (skipping
    each node's ``json_ignore_keys``) and replaces any object that is
    already on the current traversal path with ``None``.  The tree is
    mutated in place and ``obj`` is returned.

    NOTE: the original annotated ``_seen: set = None``, which is an
    invalid implicit-Optional annotation (the default is ``None``); the
    annotation is dropped here.

    NOTE(review): non-Node containers (lists, dicts of children) are
    returned unchanged, so cycles reachable only through them are not
    broken — confirm that is intended.
    """
    if _seen is None:
        _seen = set()
    # An object already on the path back to the root closes a cycle.
    if id(obj) in _seen:
        return None
    _seen.add(id(obj))
    if issubclass(obj.__class__, ast.Node):
        serializable_keys = [
            key for key in obj.__dict__.keys() if key not in obj.json_ignore_keys
        ]
        for key in serializable_keys:
            setattr(obj, key, remove_circular_refs(getattr(obj, key), _seen))
    # Remove from the path on the way out: siblings may legitimately
    # share this object; only the current path counts as a cycle.
    _seen.remove(id(obj))
    return obj


class ASTEncoder(JSONEncoder):
    """
    JSON encoder for AST objects.

    The stock circular-reference detection is disabled; cycles are broken
    up front by ``remove_circular_refs`` so the traversal can handle them
    explicitly instead of raising.
    """

    def __init__(self, *args, **kwargs):
        # Force-disable the built-in check regardless of caller arguments.
        kwargs["check_circular"] = False
        self.markers = set()
        super().__init__(*args, **kwargs)

    def default(self, o):
        # Break cycles first, then emit the class name plus whatever the
        # object's own __json_encode__ hook (if any) contributes.
        o = remove_circular_refs(o)
        encoded = {"__class__": o.__class__.__name__}
        if hasattr(o, "__json_encode__"):  # pragma: no cover
            encoded.update(o.__json_encode__())
        return encoded


def ast_decoder(session, json_dict):
    """
    Decode a JSON dict (produced by ``ASTEncoder``) back into an AST entity.

    ``json_dict`` must carry a ``__class__`` key naming a class on the
    ``ast`` module; the remaining keys are treated as constructor
    arguments and/or attributes.  ``session`` is the database session
    used to re-resolve ``NodeRevision`` payloads.

    NOTE(review): assumes ``__class__`` is always present and resolves on
    the ``ast`` module — a malformed payload raises KeyError /
    AttributeError here; confirm upstream guarantees.
    """
    class_name = json_dict["__class__"]
    clazz = getattr(ast, class_name)

    # Instantiate the class
    # (bookkeeping keys that are not constructor parameters are filtered out)
    instance = clazz(
        **{
            k: v
            for k, v in json_dict.items()
            if k not in {"__class__", "_type", "laterals", "_is_compiled"}
        },
    )

    # Set attributes where possible
    # (BinaryOpKind is excluded — presumably because its members are
    # read-only enum values; verify)
    for key, value in json_dict.items():
        if key not in {"__class__", "_is_compiled"}:
            if hasattr(instance, key) and class_name not in {"BinaryOpKind"}:
                setattr(instance, key, value)

    if class_name == "NodeRevision":
        # Overwrite with DB object if it's a node revision: the encoded
        # form only stores the node name, so look up the live current
        # revision from the database instead of trusting the payload.
        instance = (
            session.exec(select(Node).where(Node.name == json_dict["name"]))
            .one()
            .current
        )
    elif class_name == "Column":
        # Add in a reference to the table from the column, restoring the
        # parent link the encoder dropped to break cycles.
        # NOTE(review): assumes _table was populated — a column decoded
        # without a table raises AttributeError here; confirm.
        instance._table.parent = instance  # pylint: disable=protected-access
        instance._table.parent_key = "_table"  # pylint: disable=protected-access
    return instance
5 changes: 5 additions & 0 deletions datajunction-server/datajunction_server/sql/parsing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def __str__(self):
def __deepcopy__(self, memo):
return self

def __json_encode__(self):
return {
"__class__": self.__class__.__name__,
}

@classmethod
def __get_validators__(cls) -> Generator[AnyCallable, None, None]:
"""
Expand Down
Loading
Loading