Skip to content

Commit

Permalink
chore: improve SQL parsing (#26767)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Mar 13, 2024
1 parent a75bb76 commit 26d8077
Show file tree
Hide file tree
Showing 27 changed files with 394 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ describe('AdhocMetrics', () => {
});

it('Clear metric and set simple adhoc metric', () => {
const metric = 'sum(num_girls)';
const metric = 'SUM(num_girls)';
const metricName = 'Sum Girls';
cy.get('[data-test=metrics]')
.find('[data-test="remove-control-button"]')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ describe('Visualization > Table', () => {
});
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*name/i,
querySubstring: /group by\n.*name/i,
chartSelector: 'table',
});
});
Expand Down Expand Up @@ -246,7 +246,7 @@ describe('Visualization > Table', () => {
cy.visitChartByParams(formData);
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*state/i,
querySubstring: /group by\n.*state/i,
chartSelector: 'table',
});
cy.get('td').contains(/\d*%/);
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,7 @@ export function formatQuery(queryEditor) {
const { sql } = getUpToDateQuery(getState(), queryEditor);
return SupersetClient.post({
endpoint: `/api/v1/sqllab/format_sql/`,
// TODO (betodealmeida): pass engine as a parameter for better formatting
body: JSON.stringify({ sql }),
headers: { 'Content-Type': 'application/json' },
}).then(({ json }) => {
Expand Down
53 changes: 3 additions & 50 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlparse
from flask import escape, Markup
from flask_appbuilder import Model
from flask_appbuilder.security.sqla.models import User
Expand Down Expand Up @@ -104,7 +103,6 @@
ExploreMixin,
ImportExportMixin,
QueryResult,
QueryStringExtended,
validate_adhoc_subquery,
)
from superset.models.slice import Slice
Expand Down Expand Up @@ -1099,7 +1097,9 @@ def _process_sql_expression(


class SqlaTable(
Model, BaseDatasource, ExploreMixin
Model,
BaseDatasource,
ExploreMixin,
): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""

Expand Down Expand Up @@ -1413,26 +1413,6 @@ def mutate_query_from_config(self, sql: str) -> str:
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)

def get_query_str_extended(
    self,
    query_obj: QueryObjectDict,
    mutate: bool = True,
) -> QueryStringExtended:
    """
    Build the SQL text for ``query_obj`` together with its query metadata.

    The SQLAlchemy query produced by ``get_sqla_query`` is compiled by the
    database, combined with its CTE through ``_apply_cte``, and re-indented
    with ``sqlparse``.

    :param query_obj: keyword arguments forwarded to ``get_sqla_query``
    :param mutate: when true, the formatted SQL is additionally passed
        through ``mutate_query_from_config``
    :returns: a ``QueryStringExtended`` carrying the final SQL and the
        filter/label/prequery information taken from the SQLAlchemy query
    """
    sqlaq = self.get_sqla_query(**query_obj)

    compiled = self.database.compile_sqla_query(sqlaq.sqla_query)
    with_cte = self._apply_cte(compiled, sqlaq.cte)
    final_sql = sqlparse.format(with_cte, reindent=True)

    # The config-driven mutation runs last, on the already formatted SQL.
    if mutate:
        final_sql = self.mutate_query_from_config(final_sql)

    return QueryStringExtended(
        applied_template_filters=sqlaq.applied_template_filters,
        applied_filter_columns=sqlaq.applied_filter_columns,
        rejected_filter_columns=sqlaq.rejected_filter_columns,
        labels_expected=sqlaq.labels_expected,
        prequeries=sqlaq.prequeries,
        sql=final_sql,
    )

def get_query_str(self, query_obj: QueryObjectDict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
Expand Down Expand Up @@ -1474,33 +1454,6 @@ def get_from_clause(

return from_clause, cte

def get_rendered_sql(
    self, template_processor: BaseTemplateProcessor | None = None
) -> str:
    """
    Return this dataset's SQL, rendered with the template engine (Jinja).

    :param template_processor: optional processor used to render
        ``self.sql``; when omitted the SQL is used without rendering
    :returns: the SQL with surrounding whitespace/semicolons trimmed and
        comments stripped
    :raises QueryObjectValidationError: if template rendering fails, if the
        resulting SQL is empty, or if it contains more than one statement
    """
    raw_sql = self.sql
    if template_processor:
        try:
            raw_sql = template_processor.process_template(raw_sql)
        except TemplateError as ex:
            message = _(
                "Error while rendering virtual dataset query: %(msg)s",
                msg=ex.message,
            )
            raise QueryObjectValidationError(message) from ex

    # Trim whitespace and stray semicolons before dropping comments, so a
    # trailing ";" does not count as a second statement below.
    trimmed = raw_sql.strip("\t\r\n; ")
    rendered = sqlparse.format(trimmed, strip_comments=True)

    if not rendered:
        raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))

    statement_count = len(sqlparse.split(rendered))
    if statement_count > 1:
        raise QueryObjectValidationError(
            _("Virtual dataset query cannot consist of multiple statements")
        )

    return rendered

def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
Expand Down Expand Up @@ -1448,7 +1448,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
qry = partition_query
sql = database.compile_sqla_query(qry)
if indent:
sql = sqlparse.format(sql, reindent=True)
sql = SQLScript(sql, engine=cls.engine).format()
return sql

@classmethod
Expand Down
7 changes: 4 additions & 3 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from re import Pattern
from typing import Any, TYPE_CHECKING

import sqlparse
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
Expand All @@ -37,6 +36,7 @@
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.sql_parse import SQLScript
from superset.utils import core as utils
from superset.utils.core import GenericDataType

Expand Down Expand Up @@ -281,8 +281,9 @@ def get_default_schema_for_query(
This method simply uses the parent method after checking that there are no
malicious path setting in the query.
"""
sql = sqlparse.format(query.sql, strip_comments=True)
if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE):
script = SQLScript(query.sql, engine=cls.engine)
settings = script.get_settings()
if "search_path" in settings:
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
Expand Down
2 changes: 2 additions & 0 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SupersetErrorType(StrEnum):
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"
INVALID_SQL_ERROR = "INVALID_SQL_ERROR"

# Generic errors
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
Expand Down Expand Up @@ -176,6 +177,7 @@ class SupersetErrorType(StrEnum):
SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR: [1020],
SupersetErrorType.INVALID_CTAS_QUERY_ERROR: [1023],
SupersetErrorType.INVALID_CVAS_QUERY_ERROR: [1024, 1025],
SupersetErrorType.INVALID_SQL_ERROR: [1003],
SupersetErrorType.SQLLAB_TIMEOUT_ERROR: [1026, 1027],
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR: [1029],
SupersetErrorType.SYNTAX_ERROR: [1030],
Expand Down
17 changes: 17 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,20 @@ def __init__(self, exc: ValidationError, payload: dict[str, Any]):
extra={"messages": exc.messages, "payload": payload},
)
super().__init__(error)


class SupersetParseError(SupersetErrorException):
    """
    Exception raised when a SQL string cannot be parsed.

    The wrapped ``SupersetError`` has type ``INVALID_SQL_ERROR`` and level
    ``ERROR``, and records the offending SQL and the engine in ``extra``.
    """

    # NOTE(review): presumably the HTTP status code returned to clients
    # (422 Unprocessable Entity) -- confirm against the base class.
    status = 422

    def __init__(self, sql: str, engine: Optional[str] = None):
        """
        :param sql: the SQL text that failed to parse
        :param engine: name of the engine the SQL was parsed for, if known
        """
        details = {"sql": sql, "engine": engine}
        super().__init__(
            SupersetError(
                message=_("The SQL is invalid and cannot be parsed."),
                error_type=SupersetErrorType.INVALID_SQL_ERROR,
                level=ErrorLevel.ERROR,
                extra=details,
            )
        )
27 changes: 20 additions & 7 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
Expand All @@ -73,6 +74,8 @@
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
SQLScript,
SQLStatement,
)
from superset.superset_typing import (
AdhocMetric,
Expand Down Expand Up @@ -901,12 +904,18 @@ def _apply_cte(sql: str, cte: Optional[str]) -> str:
return sql

def get_query_str_extended(
self, query_obj: QueryObjectDict, mutate: bool = True
self,
query_obj: QueryObjectDict,
mutate: bool = True,
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
try:
sql = SQLStatement(sql, engine=self.db_engine_spec.engine).format()
except SupersetParseError:
logger.warning("Unable to parse SQL to format it, passing it as-is")

if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand Down Expand Up @@ -1054,7 +1063,8 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
)

def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
self,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> str:
"""
Render sql with template engine (Jinja).
Expand All @@ -1071,13 +1081,16 @@ def get_rendered_sql(
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:

script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine)
if len(script.statements) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)

sql = script.statements[0].format(comments=False)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
return sql

def text(self, clause: str) -> TextClause:
Expand Down
Loading

0 comments on commit 26d8077

Please sign in to comment.