From 276ccfba979e3b4eb19527d0e022c845eb1f5157 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 23 Jan 2024 16:53:41 -0500 Subject: [PATCH] chore: get rid of sqlparse --- .../src/SqlLab/actions/sqlLab.js | 1 + superset/connectors/sqla/models.py | 52 +--- superset/db_engine_specs/base.py | 4 +- superset/db_engine_specs/postgres.py | 7 +- superset/models/helpers.py | 22 +- superset/sql_parse.py | 245 +++++++++++++----- superset/sqllab/api.py | 4 +- superset/sqllab/schemas.py | 1 + tests/unit_tests/sql_parse_tests.py | 35 +++ 9 files changed, 239 insertions(+), 132 deletions(-) diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 567d3383d752d..e1d74f8b97c68 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -912,6 +912,7 @@ export function formatQuery(queryEditor) { const { sql } = getUpToDateQuery(getState(), queryEditor); return SupersetClient.post({ endpoint: `/api/v1/sqllab/format_sql/`, + // TODO (betodealmeida): pass engine as a parameter for better formatting body: JSON.stringify({ sql }), headers: { 'Content-Type': 'application/json' }, }).then(({ json }) => { diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 08dc923c21b27..42b3b579b5d40 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -33,7 +33,6 @@ import numpy as np import pandas as pd import sqlalchemy as sa -import sqlparse from flask import escape, Markup from flask_appbuilder import Model from flask_appbuilder.security.sqla.models import User @@ -1100,7 +1099,9 @@ def _process_sql_expression( class SqlaTable( - Model, BaseDatasource, ExploreMixin + Model, + BaseDatasource, + ExploreMixin, ): # pylint: disable=too-many-public-methods """An ORM object for SqlAlchemy table references""" @@ -1414,26 +1415,6 @@ def mutate_query_from_config(self, sql: str) -> str: def 
get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: return get_template_processor(table=self, database=self.database, **kwargs) - def get_query_str_extended( - self, - query_obj: QueryObjectDict, - mutate: bool = True, - ) -> QueryStringExtended: - sqlaq = self.get_sqla_query(**query_obj) - sql = self.database.compile_sqla_query(sqlaq.sqla_query) - sql = self._apply_cte(sql, sqlaq.cte) - sql = sqlparse.format(sql, reindent=True) - if mutate: - sql = self.mutate_query_from_config(sql) - return QueryStringExtended( - applied_template_filters=sqlaq.applied_template_filters, - applied_filter_columns=sqlaq.applied_filter_columns, - rejected_filter_columns=sqlaq.rejected_filter_columns, - labels_expected=sqlaq.labels_expected, - prequeries=sqlaq.prequeries, - sql=sql, - ) - def get_query_str(self, query_obj: QueryObjectDict) -> str: query_str_ext = self.get_query_str_extended(query_obj) all_queries = query_str_ext.prequeries + [query_str_ext.sql] @@ -1475,33 +1456,6 @@ def get_from_clause( return from_clause, cte - def get_rendered_sql( - self, template_processor: BaseTemplateProcessor | None = None - ) -> str: - """ - Render sql with template engine (Jinja). 
- """ - - sql = self.sql - if template_processor: - try: - sql = template_processor.process_template(sql) - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error while rendering virtual dataset query: %(msg)s", - msg=ex.message, - ) - ) from ex - sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True) - if not sql: - raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) - if len(sqlparse.split(sql)) > 1: - raise QueryObjectValidationError( - _("Virtual dataset query cannot consist of multiple statements") - ) - return sql - def adhoc_metric_to_sqla( self, metric: AdhocMetric, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3b8bb2bd33292..a5a52b5f143b3 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -59,7 +59,7 @@ from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import ParsedQuery, SQLQuery, Table from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType @@ -1448,7 +1448,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals qry = partition_query sql = database.compile_sqla_query(qry) if indent: - sql = sqlparse.format(sql, reindent=True) + sql = SQLQuery(sql).format() return sql @classmethod diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 07be634d0777e..dc19a7b6b9b7b 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -24,7 +24,6 @@ from re import Pattern from typing import Any, TYPE_CHECKING -import sqlparse from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import 
DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector @@ -37,6 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException from superset.models.sql_lab import Query +from superset.sql_parse import SQLQuery from superset.utils import core as utils from superset.utils.core import GenericDataType @@ -281,8 +281,9 @@ def get_default_schema_for_query( This method simply uses the parent method after checking that there are no malicious path setting in the query. """ - sql = sqlparse.format(query.sql, strip_comments=True) - if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE): + statement = SQLQuery(query.sql) + settings = statement.get_settings() + if "search_path" in settings: raise SupersetSecurityException( SupersetError( error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 9c8e83147e9ed..475fe79407fcc 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -73,6 +73,8 @@ insert_rls_in_predicate, ParsedQuery, sanitize_clause, + SQLQuery, + SQLStatement, ) from superset.superset_typing import ( AdhocMetric, @@ -901,12 +903,14 @@ def _apply_cte(sql: str, cte: Optional[str]) -> str: return sql def get_query_str_extended( - self, query_obj: QueryObjectDict, mutate: bool = True + self, + query_obj: QueryObjectDict, + mutate: bool = True, ) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) sql = self._apply_cte(sql, sqlaq.cte) - sql = sqlparse.format(sql, reindent=True) + sql = SQLStatement(sql).format() if mutate: sql = self.mutate_query_from_config(sql) return QueryStringExtended( @@ -1054,7 +1058,8 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: ) def get_rendered_sql( - self, template_processor: Optional[BaseTemplateProcessor] = None + 
self, + template_processor: Optional[BaseTemplateProcessor] = None, ) -> str: """ Render sql with template engine (Jinja). @@ -1071,13 +1076,16 @@ def get_rendered_sql( msg=ex.message, ) ) from ex - sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True) - if not sql: - raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) - if len(sqlparse.split(sql)) > 1: + + query = SQLQuery(sql.strip("\t\r\n; ")) + if len(query.statements) > 1: raise QueryObjectValidationError( _("Virtual dataset query cannot consist of multiple statements") ) + + sql = query.statements[0].format(comments=False) + if not sql: + raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) return sql def text(self, clause: str) -> TextClause: diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 07704171dee3d..1d1807c28ef97 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -22,12 +22,13 @@ import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast, Optional +from typing import Any, cast, Optional, Union +import sqlglot import sqlparse from sqlalchemy import and_ from sqlglot import exp, parse, parse_one -from sqlglot.dialects import Dialects +from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlparse import keywords @@ -252,6 +253,178 @@ def __eq__(self, __o: object) -> bool: return str(self) == str(__o) +def extract_tables_from_statement( + statement: exp.Expression, + dialect: Optional[Dialects], +) -> set[Table]: + """ + Extract all table references in a single statement. + + Please note that this is not trivial; consider the following queries: + + DESCRIBE some_table; + SHOW PARTITIONS FROM some_table; + WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; + + See the unit tests for other tricky cases.
+ """ + sources: Iterable[exp.Table] + + if isinstance(statement, exp.Describe): + # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly + # query for all tables. + sources = statement.find_all(exp.Table) + elif isinstance(statement, exp.Command): + # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a + # `SELECT` statetement in order to extract tables. + literal = statement.find(exp.Literal) + if not literal: + return set() + + pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect) + sources = pseudo_query.find_all(exp.Table) + else: + sources = [ + source + for scope in traverse_scope(statement) + for source in scope.sources.values() + if isinstance(source, exp.Table) and not is_cte(source, scope) + ] + + return { + Table( + source.name, + source.db if source.db != "" else None, + source.catalog if source.catalog != "" else None, + ) + for source in sources + } + + +def is_cte(source: exp.Table, scope: Scope) -> bool: + """ + Is the source a CTE? + + CTEs in the parent scope look like tables (and are represented by + exp.Table objects), but should not be considered as such; + otherwise a user with access to table `foo` could access any table + with a query like this: + + WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo + + """ + parent_sources = scope.parent.sources if scope.parent else {} + ctes_in_scope = { + name + for name, parent_scope in parent_sources.items() + if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE + } + + return source.name in ctes_in_scope + + +class SQLQuery: + """ + A SQL query, with 0+ statements. 
+ """ + + def __init__( + self, + query: str, + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + self.statements = [ + SQLStatement(statement, engine=engine) + for statement in parse(query, dialect=dialect) + if statement + ] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL query. + """ + return ";\n".join(statement.format() for statement in self.statements) + + def get_settings(self) -> dict[str, str]: + """ + Return the settings for the SQL query. + + >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + +class SQLStatement: + """ + A SQL statement. + + This class provides helper methods to manipulate and introspect SQL. + """ + + def __init__( + self, + statement: Union[str, exp.Expression], + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + self._parsed = ( + self._parse_statement(statement, dialect) + if isinstance(statement, str) + else statement + ) + self._dialect = dialect + self.tables = extract_tables_from_statement(self._parsed, dialect) + + @staticmethod + def _parse_statement( + sql_statement: str, + dialect: Optional[Dialects], + ) -> exp.Expression: + """ + Parse a single SQL statement. + """ + statements = [ + statement + for statement in sqlglot.parse(sql_statement, dialect=dialect) + if statement + ] + if len(statements) != 1: + raise ValueError("SQLStatement should have exactly one statement") + + return statements[0] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL statement. 
+ """ + write = Dialect.get_or_raise(self._dialect) + return write.generate(self._parsed, copy=False, comments=comments, pretty=True) + + def get_settings(self) -> dict[str, str]: + """ + Return the settings for the SQL statement. + + >>> statement = SQLStatement("SET foo = 'bar'") + >>> statement.get_settings() + {"foo": "'bar'"} + + """ + return { + eq.this.sql(): eq.expression.sql() + for set_item in self._parsed.find_all(exp.SetItem) + for eq in set_item.find_all(exp.EQ) + } + + class ParsedQuery: def __init__( self, @@ -294,76 +467,10 @@ def _extract_tables_from_sql(self) -> set[Table]: return { table for statement in statements - for table in self._extract_tables_from_statement(statement) + for table in extract_tables_from_statement(statement, self._dialect) if statement } - def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]: - """ - Extract all table references in a single statement. - - Please not that this is not trivial; consider the following queries: - - DESCRIBE some_table; - SHOW PARTITIONS FROM some_table; - WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; - - See the unit tests for other tricky cases. - """ - sources: Iterable[exp.Table] - - if isinstance(statement, exp.Describe): - # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly - # query for all tables. - sources = statement.find_all(exp.Table) - elif isinstance(statement, exp.Command): - # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a - # `SELECT` statetement in order to extract tables. 
- literal = statement.find(exp.Literal) - if not literal: - return set() - - pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect) - sources = pseudo_query.find_all(exp.Table) - else: - sources = [ - source - for scope in traverse_scope(statement) - for source in scope.sources.values() - if isinstance(source, exp.Table) and not self._is_cte(source, scope) - ] - - return { - Table( - source.name, - source.db if source.db != "" else None, - source.catalog if source.catalog != "" else None, - ) - for source in sources - } - - def _is_cte(self, source: exp.Table, scope: Scope) -> bool: - """ - Is the source a CTE? - - CTEs in the parent scope look like tables (and are represented by - exp.Table objects), but should not be considered as such; - otherwise a user with access to table `foo` could access any table - with a query like this: - - WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo - - """ - parent_sources = scope.parent.sources if scope.parent else {} - ctes_in_scope = { - name - for name, parent_scope in parent_sources.items() - if isinstance(parent_scope, Scope) - and parent_scope.scope_type == ScopeType.CTE - } - - return source.name in ctes_in_scope - @property def limit(self) -> Optional[int]: return self._limit diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index 6be378a9b5117..ce302570038bc 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -19,7 +19,6 @@ from urllib import parse import simplejson as json -import sqlparse from flask import request, Response from flask_appbuilder import permission_name from flask_appbuilder.api import expose, protect, rison, safe @@ -38,6 +37,7 @@ from superset.jinja_context import get_template_processor from superset.models.sql_lab import Query from superset.sql_lab import get_sql_results +from superset.sql_parse import SQLQuery from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.sqllab.exceptions import ( 
QueryIsForbiddenToAccessException, @@ -230,7 +230,7 @@ def format_sql(self) -> FlaskResponse: """ try: model = self.format_model_schema.load(request.json) - result = sqlparse.format(model["sql"], reindent=True, keyword_case="upper") + result = SQLQuery(model["sql"], model.get("engine")).format() return self.response(200, result=result) except ValidationError as error: return self.response_400(message=error.messages) diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index 66f90a6e920a0..7329e7edd90ea 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -44,6 +44,7 @@ class EstimateQueryCostSchema(Schema): class FormatQueryPayloadSchema(Schema): sql = fields.String(required=True) + engine = fields.String(required=False, allow_none=True) class ExecutePayloadSchema(Schema): diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index f650b77734f36..268e482919bdc 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -35,6 +35,8 @@ insert_rls_in_predicate, ParsedQuery, sanitize_clause, + SQLQuery, + SQLStatement, strip_comments_from_sql, Table, ) @@ -1849,3 +1851,36 @@ def test_is_select() -> None: ) SELECT * FROM t""" ).is_select() + + +def test_sqlquery() -> None: + """ + Test the `SQLQuery` class. + """ + query = SQLQuery("SELECT 1; SELECT 2;") + + assert len(query.statements) == 2 + assert query.format() == "SELECT\n 1;\nSELECT\n 2" + assert query.statements[0].format() == "SELECT\n 1" + + query = SQLQuery("SET a=1; SET a=2; SELECT 3;") + assert query.get_settings() == {"a": "2"} + + +def test_sqlstatement() -> None: + """ + Test the `SQLStatement` class. 
+ """ + statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2") + + assert statement.tables == { + Table(table="table1", schema=None, catalog=None), + Table(table="table2", schema=None, catalog=None), + } + assert ( + statement.format() + == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" + ) + + statement = SQLStatement("SET a=1") + assert statement.get_settings() == {"a": "1"}