From 26d8077e97d541d2d5cc11753d68c7ea11657470 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 13 Mar 2024 18:27:01 -0400 Subject: [PATCH] chore: improve SQL parsing (#26767) --- .../cypress/e2e/explore/AdhocMetrics.test.ts | 2 +- .../e2e/explore/visualizations/table.test.ts | 4 +- .../src/SqlLab/actions/sqlLab.js | 1 + superset/connectors/sqla/models.py | 53 +---- superset/db_engine_specs/base.py | 4 +- superset/db_engine_specs/postgres.py | 7 +- superset/errors.py | 2 + superset/exceptions.py | 17 ++ superset/models/helpers.py | 27 ++- superset/sql_parse.py | 190 +++++++++++++++++- superset/sqllab/api.py | 4 +- superset/sqllab/schemas.py | 1 + tests/integration_tests/celery_tests.py | 8 +- .../charts/data/api_tests.py | 8 +- tests/integration_tests/core_tests.py | 2 +- tests/integration_tests/datasource_tests.py | 2 +- .../db_engine_specs/base_engine_spec_tests.py | 5 +- .../db_engine_specs/bigquery_tests.py | 2 +- tests/integration_tests/model_tests.py | 57 +----- .../integration_tests/query_context_tests.py | 67 +++--- .../security/row_level_security_tests.py | 4 +- tests/integration_tests/sql_lab/api_tests.py | 2 +- tests/integration_tests/sqla_models_tests.py | 21 +- tests/unit_tests/db_engine_specs/test_base.py | 6 +- .../db_engine_specs/test_bigquery.py | 7 +- tests/unit_tests/jinja_context_test.py | 52 +++-- tests/unit_tests/sql_parse_tests.py | 35 ++++ 27 files changed, 394 insertions(+), 196 deletions(-) diff --git a/superset-frontend/cypress-base/cypress/e2e/explore/AdhocMetrics.test.ts b/superset-frontend/cypress-base/cypress/e2e/explore/AdhocMetrics.test.ts index e97ac74c3f2a2..b1c0fd56cf6fa 100644 --- a/superset-frontend/cypress-base/cypress/e2e/explore/AdhocMetrics.test.ts +++ b/superset-frontend/cypress-base/cypress/e2e/explore/AdhocMetrics.test.ts @@ -25,7 +25,7 @@ describe('AdhocMetrics', () => { }); it('Clear metric and set simple adhoc metric', () => { - const metric = 'sum(num_girls)'; + const metric = 'SUM(num_girls)'; const metricName = 'Sum Girls'; cy.get('[data-test=metrics]') .find('[data-test="remove-control-button"]') diff --git a/superset-frontend/cypress-base/cypress/e2e/explore/visualizations/table.test.ts b/superset-frontend/cypress-base/cypress/e2e/explore/visualizations/table.test.ts index 7db1dbe8ef95a..425e5e694b489 100644 --- a/superset-frontend/cypress-base/cypress/e2e/explore/visualizations/table.test.ts +++ b/superset-frontend/cypress-base/cypress/e2e/explore/visualizations/table.test.ts @@ -100,7 +100,7 @@ describe('Visualization > Table', () => { }); cy.verifySliceSuccess({ waitAlias: '@chartData', - querySubstring: /group by.*name/i, + querySubstring: /group by\n.*name/i, chartSelector: 'table', }); }); @@ -246,7 +246,7 @@ describe('Visualization > Table', () => { cy.visitChartByParams(formData); cy.verifySliceSuccess({ waitAlias: '@chartData', - querySubstring: /group by.*state/i, + querySubstring: /group by\n.*state/i, chartSelector: 'table', }); cy.get('td').contains(/\d*%/); diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 16bf3f530f139..e96198a0eab6f 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -921,6 +921,7 @@ export function formatQuery(queryEditor) { const { sql } = getUpToDateQuery(getState(), queryEditor); return SupersetClient.post({ endpoint: `/api/v1/sqllab/format_sql/`, + // TODO (betodealmeida): pass engine as a parameter for better formatting body: JSON.stringify({ sql }), headers: { 
'Content-Type': 'application/json' }, }).then(({ json }) => { diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 089b9c2f28960..dd9334d9d06ec 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -33,7 +33,6 @@ import numpy as np import pandas as pd import sqlalchemy as sa -import sqlparse from flask import escape, Markup from flask_appbuilder import Model from flask_appbuilder.security.sqla.models import User @@ -104,7 +103,6 @@ ExploreMixin, ImportExportMixin, QueryResult, - QueryStringExtended, validate_adhoc_subquery, ) from superset.models.slice import Slice @@ -1099,7 +1097,9 @@ def _process_sql_expression( class SqlaTable( - Model, BaseDatasource, ExploreMixin + Model, + BaseDatasource, + ExploreMixin, ): # pylint: disable=too-many-public-methods """An ORM object for SqlAlchemy table references""" @@ -1413,26 +1413,6 @@ def mutate_query_from_config(self, sql: str) -> str: def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: return get_template_processor(table=self, database=self.database, **kwargs) - def get_query_str_extended( - self, - query_obj: QueryObjectDict, - mutate: bool = True, - ) -> QueryStringExtended: - sqlaq = self.get_sqla_query(**query_obj) - sql = self.database.compile_sqla_query(sqlaq.sqla_query) - sql = self._apply_cte(sql, sqlaq.cte) - sql = sqlparse.format(sql, reindent=True) - if mutate: - sql = self.mutate_query_from_config(sql) - return QueryStringExtended( - applied_template_filters=sqlaq.applied_template_filters, - applied_filter_columns=sqlaq.applied_filter_columns, - rejected_filter_columns=sqlaq.rejected_filter_columns, - labels_expected=sqlaq.labels_expected, - prequeries=sqlaq.prequeries, - sql=sql, - ) - def get_query_str(self, query_obj: QueryObjectDict) -> str: query_str_ext = self.get_query_str_extended(query_obj) all_queries = query_str_ext.prequeries + [query_str_ext.sql] @@ -1474,33 +1454,6 @@ def get_from_clause( return from_clause, cte - def get_rendered_sql( - self, template_processor: BaseTemplateProcessor | None = None - ) -> str: - """ - Render sql with template engine (Jinja). 
- """ - - sql = self.sql - if template_processor: - try: - sql = template_processor.process_template(sql) - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error while rendering virtual dataset query: %(msg)s", - msg=ex.message, - ) - ) from ex - sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True) - if not sql: - raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) - if len(sqlparse.split(sql)) > 1: - raise QueryObjectValidationError( - _("Virtual dataset query cannot consist of multiple statements") - ) - return sql - def adhoc_metric_to_sqla( self, metric: AdhocMetric, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 5b850913ed778..e8790bdcd4f77 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -59,7 +59,7 @@ from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.sql_parse import ParsedQuery, Table +from superset.sql_parse import ParsedQuery, SQLScript, Table from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType @@ -1448,7 +1448,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals qry = partition_query sql = database.compile_sqla_query(qry) if indent: - sql = sqlparse.format(sql, reindent=True) + sql = SQLScript(sql, engine=cls.engine).format() return sql @classmethod diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 07be634d0777e..b4755c6cd7671 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -24,7 +24,6 @@ from re import Pattern from typing import Any, TYPE_CHECKING -import sqlparse from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector @@ -37,6 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException from superset.models.sql_lab import Query +from superset.sql_parse import SQLScript from superset.utils import core as utils from superset.utils.core import GenericDataType @@ -281,8 +281,9 @@ def get_default_schema_for_query( This method simply uses the parent method after checking that there are no malicious path setting in the query. 
""" - sql = sqlparse.format(query.sql, strip_comments=True) - if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE): + script = SQLScript(query.sql, engine=cls.engine) + settings = script.get_settings() + if "search_path" in settings: raise SupersetSecurityException( SupersetError( error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, diff --git a/superset/errors.py b/superset/errors.py index 6be4f966ce403..7c383891676db 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -83,6 +83,7 @@ class SupersetErrorType(StrEnum): RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR" ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR" ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR" + INVALID_SQL_ERROR = "INVALID_SQL_ERROR" # Generic errors GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR" @@ -176,6 +177,7 @@ class SupersetErrorType(StrEnum): SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR: [1020], SupersetErrorType.INVALID_CTAS_QUERY_ERROR: [1023], SupersetErrorType.INVALID_CVAS_QUERY_ERROR: [1024, 1025], + SupersetErrorType.INVALID_SQL_ERROR: [1003], SupersetErrorType.SQLLAB_TIMEOUT_ERROR: [1026, 1027], SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR: [1029], SupersetErrorType.SYNTAX_ERROR: [1030], diff --git a/superset/exceptions.py b/superset/exceptions.py index 3642a9279ec23..0ce72e2e1a6f1 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -295,3 +295,20 @@ def __init__(self, exc: ValidationError, payload: dict[str, Any]): extra={"messages": exc.messages, "payload": payload}, ) super().__init__(error) + + +class SupersetParseError(SupersetErrorException): + """ + Exception to be raised when we fail to parse SQL. + """ + + status = 422 + + def __init__(self, sql: str, engine: Optional[str] = None): + error = SupersetError( + message=_("The SQL is invalid and cannot be parsed."), + error_type=SupersetErrorType.INVALID_SQL_ERROR, + level=ErrorLevel.ERROR, + extra={"sql": sql, "engine": engine}, + ) + super().__init__(error) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 8d3ed36c465f7..684ef51efa799 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -64,6 +64,7 @@ ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, + SupersetParseError, SupersetSecurityException, ) from superset.extensions import feature_flag_manager @@ -73,6 +74,8 @@ insert_rls_in_predicate, ParsedQuery, sanitize_clause, + SQLScript, + SQLStatement, ) from superset.superset_typing import ( AdhocMetric, @@ -901,12 +904,18 @@ def _apply_cte(sql: str, cte: Optional[str]) -> str: return sql def get_query_str_extended( - self, query_obj: QueryObjectDict, mutate: bool = True + self, + query_obj: QueryObjectDict, + mutate: bool = True, ) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) sql = self._apply_cte(sql, sqlaq.cte) - sql = sqlparse.format(sql, reindent=True) + try: + sql = SQLStatement(sql, engine=self.db_engine_spec.engine).format() + except SupersetParseError: + logger.warning("Unable to parse SQL to format it, passing it as-is") + if mutate: sql = self.mutate_query_from_config(sql) return QueryStringExtended( @@ -1054,7 +1063,8 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: ) def get_rendered_sql( - self, template_processor: Optional[BaseTemplateProcessor] = None + self, + template_processor: Optional[BaseTemplateProcessor] = None, ) -> str: """ Render sql with template engine (Jinja). 
@@ -1071,13 +1081,16 @@ def get_rendered_sql(
                         msg=ex.message,
                     )
                 ) from ex
-        sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
-        if not sql:
-            raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
-        if len(sqlparse.split(sql)) > 1:
+
+        script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine)
+        if len(script.statements) > 1:
             raise QueryObjectValidationError(
                 _("Virtual dataset query cannot consist of multiple statements")
             )
+
+        sql = script.statements[0].format(comments=False)
+        if not sql:
+            raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
         return sql
 
     def text(self, clause: str) -> TextClause:
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index c85afc9460f12..58dc210e2b604 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -22,13 +22,14 @@
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional
+from typing import Any, cast, Optional, Union
 
+import sqlglot
 import sqlparse
 from sqlalchemy import and_
 from sqlglot import exp, parse, parse_one
-from sqlglot.dialects import Dialects
-from sqlglot.errors import SqlglotError
+from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.errors import ParseError, SqlglotError
 from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
 from sqlparse import keywords
 from sqlparse.lexer import Lexer
@@ -55,7 +56,7 @@
 )
 from sqlparse.utils import imt
 
-from superset.exceptions import QueryClauseValidationException
+from superset.exceptions import QueryClauseValidationException, SupersetParseError
 from superset.utils.backports import StrEnum
 
 try:
@@ -252,6 +253,185 @@ def __eq__(self, __o: object) -> bool:
         return str(self) == str(__o)
 
 
+def extract_tables_from_statement(
+    statement: exp.Expression,
+    dialect: Optional[Dialects],
+) -> set[Table]:
+    """
+    Extract all table references in a single statement.
+
+    Please note that this is not trivial; consider the following queries:
+
+        DESCRIBE some_table;
+        SHOW PARTITIONS FROM some_table;
+        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
+
+    See the unit tests for other tricky cases.
+    """
+    sources: Iterable[exp.Table]
+
+    if isinstance(statement, exp.Describe):
+        # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
+        # query for all tables.
+        sources = statement.find_all(exp.Table)
+    elif isinstance(statement, exp.Command):
+        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+        # `SELECT` statement in order to extract tables.
+        literal = statement.find(exp.Literal)
+        if not literal:
+            return set()
+
+        try:
+            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+        except ParseError:
+            return set()
+        sources = pseudo_query.find_all(exp.Table)
+    else:
+        sources = [
+            source
+            for scope in traverse_scope(statement)
+            for source in scope.sources.values()
+            if isinstance(source, exp.Table) and not is_cte(source, scope)
+        ]
+
+    return {
+        Table(
+            source.name,
+            source.db if source.db != "" else None,
+            source.catalog if source.catalog != "" else None,
+        )
+        for source in sources
+    }
+
+
+def is_cte(source: exp.Table, scope: Scope) -> bool:
+    """
+    Is the source a CTE?
+ + CTEs in the parent scope look like tables (and are represented by + exp.Table objects), but should not be considered as such; + otherwise a user with access to table `foo` could access any table + with a query like this: + + WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo + + """ + parent_sources = scope.parent.sources if scope.parent else {} + ctes_in_scope = { + name + for name, parent_scope in parent_sources.items() + if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE + } + + return source.name in ctes_in_scope + + +class SQLScript: + """ + A SQL script, with 0+ statements. + """ + + def __init__( + self, + query: str, + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + self.statements = [ + SQLStatement(statement, engine=engine) + for statement in parse(query, dialect=dialect) + if statement + ] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL query. + """ + return ";\n".join(statement.format(comments) for statement in self.statements) + + def get_settings(self) -> dict[str, str]: + """ + Return the settings for the SQL query. + + >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + +class SQLStatement: + """ + A SQL statement. + + This class provides helper methods to manipulate and introspect SQL. + """ + + def __init__( + self, + statement: Union[str, exp.Expression], + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + if isinstance(statement, str): + try: + self._parsed = self._parse_statement(statement, dialect) + except ParseError as ex: + raise SupersetParseError(statement, engine) from ex + else: + self._parsed = statement + + self._dialect = dialect + self.tables = extract_tables_from_statement(self._parsed, dialect) + + @staticmethod + def _parse_statement( + sql_statement: str, + dialect: Optional[Dialects], + ) -> exp.Expression: + """ + Parse a single SQL statement. + """ + statements = [ + statement + for statement in sqlglot.parse(sql_statement, dialect=dialect) + if statement + ] + if len(statements) != 1: + raise ValueError("SQLStatement should have exactly one statement") + + return statements[0] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL statement. + """ + write = Dialect.get_or_raise(self._dialect) + return write.generate(self._parsed, copy=False, comments=comments, pretty=True) + + def get_settings(self) -> dict[str, str]: + """ + Return the settings for the SQL statement. 
+ + >>> statement = SQLStatement("SET foo = 'bar'") + >>> statement.get_settings() + {"foo": "'bar'"} + + """ + return { + eq.this.sql(): eq.expression.sql() + for set_item in self._parsed.find_all(exp.SetItem) + for eq in set_item.find_all(exp.EQ) + } + + class ParsedQuery: def __init__( self, @@ -294,7 +474,7 @@ def _extract_tables_from_sql(self) -> set[Table]: return { table for statement in statements - for table in self._extract_tables_from_statement(statement) + for table in extract_tables_from_statement(statement, self._dialect) if statement } diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index eeb95e6aadb4e..24c04dfe5e92e 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -19,7 +19,6 @@ from urllib import parse import simplejson as json -import sqlparse from flask import request, Response from flask_appbuilder import permission_name from flask_appbuilder.api import expose, protect, rison, safe @@ -38,6 +37,7 @@ from superset.jinja_context import get_template_processor from superset.models.sql_lab import Query from superset.sql_lab import get_sql_results +from superset.sql_parse import SQLScript from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.sqllab.exceptions import ( QueryIsForbiddenToAccessException, @@ -230,7 +230,7 @@ def format_sql(self) -> FlaskResponse: """ try: model = self.format_model_schema.load(request.json) - result = sqlparse.format(model["sql"], reindent=True, keyword_case="upper") + result = SQLScript(model["sql"], model.get("engine")).format() return self.response(200, result=result) except ValidationError as error: return self.response_400(message=error.messages) diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index dba54cd3b52b7..0864420c9077d 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -44,6 +44,7 @@ class EstimateQueryCostSchema(Schema): class FormatQueryPayloadSchema(Schema): sql = fields.String(required=True) + engine = fields.String(required=False, allow_none=True) class ExecutePayloadSchema(Schema): diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index b0ac6ac6c7ed4..53e7b217ee0ed 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -138,8 +138,8 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None): schema = quote_f(schema) table = quote_f(table) if schema: - return f"SELECT *\nFROM {schema}.{table}\nLIMIT {limit}" - return f"SELECT *\nFROM {table}\nLIMIT {limit}" + return f"SELECT\n *\nFROM {schema}.{table}\nLIMIT {limit}" + return f"SELECT\n *\nFROM {table}\nLIMIT {limit}" @pytest.mark.usefixtures("login_as_admin") @@ -333,9 +333,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): query = wait_for_success(result) assert QueryStatus.SUCCESS == query.status - sqllite_select_sql = f"SELECT *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0" + sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0" assert query.select_sql == ( - sqllite_select_sql + sqlite_select_sql if backend() == "sqlite" else get_select_star(tmp_table, query.limit) ) diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 652e20ba87c24..6efa0deed4951 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -694,7 +694,7 @@ def 
test_when_where_parameter_is_template_and_query_result_type__query_is_templa rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") result = rv.json["result"][0]["query"] if get_example_database().backend != "presto": - assert "('boy' = 'boy')" in result + assert "(\n 'boy' = 'boy'\n )" in result @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -1319,13 +1319,13 @@ def test_time_filter_with_grain(test_client, login_as_admin, physical_query_cont backend = get_example_database().backend if backend == "sqlite": assert ( - "DATETIME(col5, 'start of day', -strftime('%w', col5) || ' days') >=" + "DATETIME(col5, 'start of day', -STRFTIME('%w', col5) || ' days') >=" in query ) elif backend == "mysql": - assert "DATE(DATE_SUB(col5, INTERVAL DAYOFWEEK(col5) - 1 DAY)) >=" in query + assert "DATE(DATE_SUB(col5, INTERVAL (DAYOFWEEK(col5) - 1) DAY)) >=" in query elif backend == "postgresql": - assert "DATE_TRUNC('week', col5) >=" in query + assert "DATE_TRUNC('WEEK', col5) >=" in query elif backend == "presto": assert "date_trunc('week', CAST(col5 AS TIMESTAMP)) >=" in query diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 573b096fc7038..ceb2bf778290b 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -531,7 +531,7 @@ def test_mssql_engine_spec_pymssql(self): ) def test_comments_in_sqlatable_query(self): - clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl" + clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl" commented_query = "/* comment 1 */" + clean_query + "-- comment 2" table = SqlaTable( table_name="test_comments_in_sqlatable_query_table", diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 91e843fc3f883..a38218769d6c1 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -674,7 +674,7 @@ def test_get_samples_with_multiple_filters( assert "2000-01-02" in rv.json["result"]["query"] assert "2000-01-04" in rv.json["result"]["query"] assert "col3 = 1.2" in rv.json["result"]["query"] - assert "col4 is null" in rv.json["result"]["query"] + assert "col4 IS NULL" in rv.json["result"]["query"] assert "col2 = 'c'" in rv.json["result"]["query"] diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 15465d8a79c79..ababce38e5c69 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -308,10 +308,7 @@ def test_calculated_column_in_order_by_base_engine_spec(self): } sql = table.get_query_str(query_obj) assert ( - """ORDER BY case - when gender='boy' then 'male' - else 'female' - end ASC;""" + "ORDER BY\n CASE WHEN gender = 'boy' THEN 'male' ELSE 'female' END ASC" in sql ) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 3523e71b0b022..3eaaa3d1c7ddc 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -381,4 +381,4 @@ def test_calculated_column_in_order_by(self): "orderby": [["gender_cc", True]], } sql = table.get_query_str(query_obj) - assert "ORDER BY `gender_cc` ASC" in sql + assert "ORDER BY\n `gender_cc` 
ASC" in sql diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 5222c1cb34ef1..2a4c33a281315 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -17,7 +17,6 @@ # isort:skip_file import json from superset.utils.core import DatasourceType -import textwrap import unittest from unittest import mock @@ -298,58 +297,25 @@ def test_select_star(self): sql = db.select_star(table_name, show_cols=False, latest_partition=False) with db.get_sqla_engine_with_context() as engine: quote = engine.dialect.identifier_preparer.quote_identifier - expected = ( - textwrap.dedent( - f"""\ - SELECT * - FROM {quote(table_name)} - LIMIT 100""" - ) - if db.backend in {"presto", "hive"} - else textwrap.dedent( - f"""\ - SELECT * - FROM {table_name} - LIMIT 100""" - ) - ) + + source = quote(table_name) if db.backend in {"presto", "hive"} else table_name + expected = f"SELECT\n *\nFROM {source}\nLIMIT 100" assert expected in sql sql = db.select_star(table_name, show_cols=True, latest_partition=False) # TODO(bkyryliuk): unify sql generation if db.backend == "presto": assert ( - textwrap.dedent( - """\ - SELECT "source" AS "source", - "target" AS "target", - "value" AS "value" - FROM "energy_usage" - LIMIT 100""" - ) - == sql + 'SELECT\n "source" AS "source",\n "target" AS "target",\n "value" AS "value"\nFROM "energy_usage"\nLIMIT 100' + in sql ) elif db.backend == "hive": assert ( - textwrap.dedent( - """\ - SELECT `source`, - `target`, - `value` - FROM `energy_usage` - LIMIT 100""" - ) - == sql + "SELECT\n `source`,\n `target`,\n `value`\nFROM `energy_usage`\nLIMIT 100" + in sql ) else: assert ( - textwrap.dedent( - """\ - SELECT source, - target, - value - FROM energy_usage - LIMIT 100""" - ) + "SELECT\n source,\n target,\n value\nFROM energy_usage\nLIMIT 100" in sql ) @@ -367,12 +333,7 @@ def test_select_star_fully_qualified_names(self): } fully_qualified_name = fully_qualified_names.get(db.db_engine_spec.engine) if fully_qualified_name: - expected = textwrap.dedent( - f"""\ - SELECT * - FROM {fully_qualified_name} - LIMIT 100""" - ) + expected = f"SELECT\n *\nFROM {fully_qualified_name}\nLIMIT 100" assert sql.startswith(expected) def test_single_statement(self): diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 0d6d69e4ce9e7..94b69152040b8 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -373,11 +373,12 @@ def test_query_response_type(self): self.login(username="admin") payload = get_query_context("birth_names") sql_text = get_sql_text(payload) + assert "SELECT" in sql_text - assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text) + assert re.search(r'NOT [`"\[]?num[`"\]]? IS NULL', sql_text) assert re.search( - r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """ - r"""OR [`"\[]?name[`"\]]? IN \('"abc"'\)\)""", + r"""NOT \([\s\n]*[`"\[]?name[`"\]]? IS NULL[\s\n]* """ + r"""OR [`"\[]?name[`"\]]? IN \('"abc"'\)[\s\n]*\)""", sql_text, ) @@ -396,7 +397,7 @@ def test_handle_sort_by_metrics(self): # the alias should be in ORDER BY assert "ORDER BY `sum__num` DESC" in sql_text else: - assert re.search(r'ORDER BY [`"\[]?sum__num[`"\]]? DESC', sql_text) + assert re.search(r'ORDER BY[\s\n]* [`"\[]?sum__num[`"\]]? 
DESC', sql_text)
 
         sql_text = get_sql_text(
             get_query_context("birth_names:only_orderby_has_metric")
@@ -407,7 +408,9 @@ def test_handle_sort_by_metrics(self):
             assert "ORDER BY `sum__num` DESC" in sql_text
         else:
             assert re.search(
-                r'ORDER BY SUM\([`"\[]?num[`"\]]?\) DESC', sql_text, re.IGNORECASE
+                r'ORDER BY[\s\n]* SUM\([`"\[]?num[`"\]]?\) DESC',
+                sql_text,
+                re.IGNORECASE,
             )
 
         sql_text = get_sql_text(get_query_context("birth_names:orderby_dup_alias"))
@@ -438,7 +441,7 @@ def test_handle_sort_by_metrics(self):
             assert "sum(`num_girls`) AS `SUM(num_girls)`" not in sql_text
 
             # Should reference all ORDER BY columns by aliases
-            assert "ORDER BY `num_girls` DESC," in sql_text
+            assert re.search(r"ORDER BY[\s\n]* `num_girls` DESC,", sql_text)
             assert "`AVG(num_boys)` DESC," in sql_text
             assert "`MAX(CASE WHEN...` ASC" in sql_text
         else:
             if backend() == "presto":
                 # presto cannot have ambiguous alias in order by, so selected column
                 # alias is renamed.
                 # since the selected `num_boys` is renamed to `num_boys__`
                 # it must be referenced as an expression
                 assert re.search(
-                    r'ORDER BY SUM\([`"\[]?num_girls[`"\]]?\) DESC',
+                    r'ORDER BY[\s\n]* SUM\([`"\[]?num_girls[`"\]]?\) DESC',
                     sql_text,
                     re.IGNORECASE,
                 )
             else:
                 # Should reference the adhoc metric by alias when possible
                 assert re.search(
-                    r'ORDER BY [`"\[]?num_girls[`"\]]? DESC',
+                    r'ORDER BY[\s\n]* [`"\[]?num_girls[`"\]]? DESC',
                     sql_text,
                     re.IGNORECASE,
                 )
@@ -1075,27 +1078,41 @@ def test_time_offset_with_temporal_range_filter(app_context, physical_dataset):
     sqls = query_payload["query"].split(";")
     """
-    SELECT DATE_TRUNC('quarter', col6) AS col6,
-           SUM(col1) AS "SUM(col1)"
-    FROM physical_dataset
-    WHERE col6 >= TO_TIMESTAMP('2002-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
-      AND col6 < TO_TIMESTAMP('2003-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
-    GROUP BY DATE_TRUNC('quarter', col6)
-    LIMIT 10000;
-
-    SELECT DATE_TRUNC('quarter', col6) AS col6,
-           SUM(col1) AS "SUM(col1)"
-    FROM physical_dataset
-    WHERE col6 >= TO_TIMESTAMP('2001-10-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
-      AND col6 < TO_TIMESTAMP('2002-10-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
-    GROUP BY DATE_TRUNC('quarter', col6)
-    LIMIT 10000;
+    SELECT
+      DATETIME(col6, 'start of month', PRINTF('-%d month', (
+        STRFTIME('%m', col6) - 1
+      ) % 3)) AS col6,
+      SUM(col1) AS "SUM(col1)"
+    FROM physical_dataset
+    WHERE
+      col6 >= '2002-01-01 00:00:00' AND col6 < '2003-01-01 00:00:00'
+    GROUP BY
+      DATETIME(col6, 'start of month', PRINTF('-%d month', (
+        STRFTIME('%m', col6) - 1
+      ) % 3))
+    LIMIT 10000
+    OFFSET 0
+
+    SELECT
+      DATETIME(col6, 'start of month', PRINTF('-%d month', (
+        STRFTIME('%m', col6) - 1
+      ) % 3)) AS col6,
+      SUM(col1) AS "SUM(col1)"
+    FROM physical_dataset
+    WHERE
+      col6 >= '2001-10-01 00:00:00' AND col6 < '2002-10-01 00:00:00'
+    GROUP BY
+      DATETIME(col6, 'start of month', PRINTF('-%d month', (
+        STRFTIME('%m', col6) - 1
+      ) % 3))
+    LIMIT 10000
+    OFFSET 0
     """
     assert (
-        re.search(r"WHERE col6 >= .*2002-01-01", sqls[0])
+        re.search(r"WHERE\n col6 >= .*2002-01-01", sqls[0])
         and re.search(r"AND col6 < .*2003-01-01", sqls[0])
     ) is not None
     assert (
-        re.search(r"WHERE col6 >= .*2001-10-01", sqls[1])
+        re.search(r"WHERE\n col6 >= .*2001-10-01", sqls[1])
         and re.search(r"AND col6 < .*2002-10-01", sqls[1])
     ) is not None
diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py
index 7518621ddd6d6..916871e538ec4 100644
--- a/tests/integration_tests/security/row_level_security_tests.py
+++ b/tests/integration_tests/security/row_level_security_tests.py
@@ -273,7 
+273,7 @@ def test_rls_filter_alters_gamma_birth_names_query(self):
         # establish that the filters are grouped together correctly with
         # ANDs, ORs and parens in the correct place
         assert (
-            "WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');"
+            "WHERE\n (\n (\n name LIKE 'A%' OR name LIKE 'B%'\n ) OR (\n name LIKE 'Q%'\n )\n )\n AND (\n gender = 'boy'\n )"
             in sql
         )
@@ -619,7 +619,7 @@ def _base_filter(query):
 
 
 RLS_ALICE_REGEX = re.compile(r"name = 'Alice'")
-RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)")
+RLS_GENDER_REGEX = re.compile(r"AND \([\s\n]*gender = 'girl'[\s\n]*\)")
 
 
 @mock.patch.dict(
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 597f961346abb..d47bff92f72cd 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -290,7 +290,7 @@ def test_format_sql_request(self):
             "/api/v1/sqllab/format_sql/",
             json=data,
         )
-        success_resp = {"result": "SELECT 1\nFROM my_table"}
+        success_resp = {"result": "SELECT\n 1\nFROM my_table"}
         resp_data = json.loads(rv.data.decode("utf-8"))
         self.assertDictEqual(resp_data, success_resp)
         self.assertEqual(rv.status_code, 200)
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 0e4f57967b8a7..1a66903da79e4 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -150,9 +150,10 @@ def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id):
         table1 = SqlaTable(
             table_name="test_has_extra_cache_keys_table",
             sql="""
-            SELECT '{{ current_user_id() }}' as id,
-            SELECT '{{ current_username() }}' as username,
-            SELECT '{{ current_user_email() }}' as email,
+            SELECT
+              '{{ current_user_id() }}' as id,
+              '{{ current_username() }}' as username,
+              '{{ current_user_email() }}' as email
             """,
             database=get_example_database(),
         )
@@ -166,9 +167,10 @@ def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id):
         table2 = SqlaTable(
             table_name="test_has_extra_cache_keys_disabled_table",
             sql="""
-            SELECT '{{ current_user_id(False) }}' as id,
-            SELECT '{{ current_username(False) }}' as username,
-            SELECT '{{ current_user_email(False) }}' as email,
+            SELECT
+              '{{ current_user_id(False) }}' as id,
+              '{{ current_username(False) }}' as username,
+              '{{ current_user_email(False) }}' as email
             """,
             database=get_example_database(),
         )
@@ -250,10 +252,11 @@ def test_jinja_metrics_and_calc_columns(self, mock_username):
         sqla_query = table.get_sqla_query(**base_query_obj)
         query = table.database.compile_sqla_query(sqla_query.sqla_query)
+
         # assert virtual dataset
-        assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query
+        assert "SELECT\n 'user_abc' AS user,\n 'xyz_P1D' AS time_grain" in query
         # assert dataset calculated column
-        assert "case when 'abc' = 'abc' then 'yes' else 'no' end AS expr" in query
+        assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
         # assert adhoc column
         assert "'foo_P1D'" in query
         # assert dataset saved metric
@@ -746,7 +749,7 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
         {
             "operator": FilterOperator.NOT_EQUALS.value,
             "count": 0,
-            "sql_should_contain": "COL4 IS NOT NULL",
+            "sql_should_contain": "NOT COL4 IS NULL",
         },
     ]
    for expected in expected_results:
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index e4cf4dfc1f9c0..86eb37183f25d 100644
--- 
a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -219,7 +219,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec): ) assert ( sql - == """SELECT a + == """SELECT + a FROM my_table LIMIT ? OFFSET ?""" @@ -238,6 +239,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec): ) assert ( sql - == """SELECT a + == """SELECT + a FROM my_table""" ) diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 37d04defc8e51..3870297db8b70 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -148,7 +148,7 @@ def test_select_star(mocker: MockFixture) -> None: # mock the database so we can compile the query database = mocker.MagicMock() database.compile_sqla_query = lambda query: str( - query.compile(dialect=BigQueryDialect()) + query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True}) ) engine = mocker.MagicMock() @@ -167,9 +167,10 @@ def test_select_star(mocker: MockFixture) -> None: ) assert ( sql - == """SELECT `trailer` AS `trailer` + == """SELECT + `trailer` AS `trailer` FROM `my_table` -LIMIT :param_1""" +LIMIT 100""" ) diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index e2a5e8cd49280..c8c700f6a4b1b 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -47,6 +47,15 @@ def test_dataset_macro(mocker: MockFixture) -> None: from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.models.core import Database + mocker.patch( + "superset.connectors.sqla.models.security_manager.get_guest_rls_filters", + return_value=[], + ) + mocker.patch( + "superset.models.helpers.security_manager.get_rls_filters", + return_value=[], + ) + columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), TableColumn(column_name="num_boys", type="INTEGER"), @@ -94,11 +103,12 @@ def test_dataset_macro(mocker: MockFixture) -> None: assert ( dataset_macro(1) == """( -SELECT ds AS ds, - num_boys AS num_boys, - revenue AS revenue, - expenses AS expenses, - revenue-expenses AS profit +SELECT + ds AS ds, + num_boys AS num_boys, + revenue AS revenue, + expenses AS expenses, + revenue - expenses AS profit FROM my_schema.old_dataset ) AS dataset_1""" ) @@ -106,28 +116,32 @@ def test_dataset_macro(mocker: MockFixture) -> None: assert ( dataset_macro(1, include_metrics=True) == """( -SELECT ds AS ds, - num_boys AS num_boys, - revenue AS revenue, - expenses AS expenses, - revenue-expenses AS profit, - COUNT(*) AS cnt +SELECT + ds AS ds, + num_boys AS num_boys, + revenue AS revenue, + expenses AS expenses, + revenue - expenses AS profit, + COUNT(*) AS cnt FROM my_schema.old_dataset -GROUP BY ds, - num_boys, - revenue, - expenses, - revenue-expenses +GROUP BY + ds, + num_boys, + revenue, + expenses, + revenue - expenses ) AS dataset_1""" ) assert ( dataset_macro(1, include_metrics=True, columns=["ds"]) == """( -SELECT ds AS ds, - COUNT(*) AS cnt +SELECT + ds AS ds, + COUNT(*) AS cnt FROM my_schema.old_dataset -GROUP BY ds +GROUP BY + ds ) AS dataset_1""" ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 2fd23f7e8e4f2..f097fd1df365b 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -35,6 +35,8 @@ insert_rls_in_predicate, ParsedQuery, sanitize_clause, + SQLScript, + SQLStatement, strip_comments_from_sql, Table, ) @@ -1850,3 
+1852,36 @@ def test_is_select() -> None:
     )
     SELECT * FROM t"""
     ).is_select()
+
+
+def test_sqlscript() -> None:
+    """
+    Test the `SQLScript` class.
+    """
+    script = SQLScript("SELECT 1; SELECT 2;")
+
+    assert len(script.statements) == 2
+    assert script.format() == "SELECT\n 1;\nSELECT\n 2"
+    assert script.statements[0].format() == "SELECT\n 1"
+
+    script = SQLScript("SET a=1; SET a=2; SELECT 3;")
+    assert script.get_settings() == {"a": "2"}
+
+
+def test_sqlstatement() -> None:
+    """
+    Test the `SQLStatement` class.
+    """
+    statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
+
+    assert statement.tables == {
+        Table(table="table1", schema=None, catalog=None),
+        Table(table="table2", schema=None, catalog=None),
+    }
+    assert (
+        statement.format()
+        == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
+    )
+
+    statement = SQLStatement("SET a=1")
+    assert statement.get_settings() == {"a": "1"}
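
A minimal usage sketch of the new API, mirroring the unit tests and docstrings
above. The expected values assume sqlglot's default dialect, since no engine is
passed; the table and column names are illustrative only.

    from superset.sql_parse import SQLScript, SQLStatement, Table

    # A script holds zero or more statements, each of which can be formatted
    # and introspected on its own.
    script = SQLScript("SET search_path = foo; SELECT * FROM some_table")
    assert len(script.statements) == 2

    # Later SETs win; PostgresEngineSpec.get_default_schema_for_query now uses
    # this to reject queries that try to change search_path.
    assert "search_path" in script.get_settings()

    # Table extraction is CTE-aware: `foo` below is a CTE, not a table, so only
    # `target_table` is reported (see is_cte above).
    statement = SQLStatement(
        "WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo"
    )
    assert statement.tables == {
        Table(table="target_table", schema=None, catalog=None)
    }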
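
Parse failures now surface as a structured error instead of leaking sqlglot's
ParseError. A small sketch of the behavior wired up above; the unbalanced
parenthesis is just an illustrative way to force a parse failure.

    from superset.errors import SupersetErrorType
    from superset.exceptions import SupersetParseError
    from superset.sql_parse import SQLStatement

    try:
        SQLStatement("SELECT * FROM (")
    except SupersetParseError as ex:
        # The raised error carries the offending SQL (and engine, if any) in
        # its `extra` payload, and maps to HTTP 422 / issue code 1003.
        assert ex.error.error_type == SupersetErrorType.INVALID_SQL_ERROR
        assert ex.status == 422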