Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jan 24, 2024
1 parent a6cc819 commit 8237798
Show file tree
Hide file tree
Showing 22 changed files with 140 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ describe('AdhocMetrics', () => {
});

it('Clear metric and set simple adhoc metric', () => {
const metric = 'sum(num_girls)';
const metric = 'SUM(num_girls)';
const metricName = 'Sum Girls';
cy.get('[data-test=metrics]')
.find('[data-test="remove-control-button"]')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ describe('Visualization > Table', () => {
});
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*name/i,
querySubstring: /group by\n.*name/i,
chartSelector: 'table',
});
});
Expand Down Expand Up @@ -246,7 +246,7 @@ describe('Visualization > Table', () => {
cy.visitChartByParams(formData);
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*state/i,
querySubstring: /group by\n.*state/i,
chartSelector: 'table',
});
cy.get('td').contains(/\d*%/);
Expand Down
1 change: 0 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
ExploreMixin,
ImportExportMixin,
QueryResult,
QueryStringExtended,
validate_adhoc_subquery,
)
from superset.models.slice import Slice
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
qry = partition_query
sql = database.compile_sqla_query(qry)
if indent:
sql = SQLQuery(sql).format()
sql = SQLQuery(sql, engine=cls.engine).format()
return sql

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def get_default_schema_for_query(
This method simply uses the parent method after checking that there are no
malicious path setting in the query.
"""
statement = SQLQuery(query.sql)
statement = SQLQuery(query.sql, engine=cls.engine)
settings = statement.get_settings()
if "search_path" in settings:
raise SupersetSecurityException(
Expand Down
2 changes: 2 additions & 0 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class SupersetErrorType(StrEnum):
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"
INVALID_SQL_ERROR = "INVALID_SQL_ERROR"

# Generic errors
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
Expand Down Expand Up @@ -175,6 +176,7 @@ class SupersetErrorType(StrEnum):
SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR: [1020],
SupersetErrorType.INVALID_CTAS_QUERY_ERROR: [1023],
SupersetErrorType.INVALID_CVAS_QUERY_ERROR: [1024, 1025],
SupersetErrorType.INVALID_SQL_ERROR: [1003],
SupersetErrorType.SQLLAB_TIMEOUT_ERROR: [1026, 1027],
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR: [1029],
SupersetErrorType.SYNTAX_ERROR: [1030],
Expand Down
17 changes: 17 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,20 @@ def __init__(self, exc: ValidationError, payload: dict[str, Any]):
extra={"messages": exc.messages, "payload": payload},
)
super().__init__(error)


class SupersetParseError(SupersetErrorException):
"""
Exception to be raised when we fail to parse SQL.
"""

status = 422

def __init__(self, sql: str, engine: Optional[str] = None):
error = SupersetError(
message=_("The SQL is invalid and cannot be parsed."),
error_type=SupersetErrorType.INVALID_SQL_ERROR,
level=ErrorLevel.ERROR,
extra={"sql": sql, "engine": engine},
)
super().__init__(error)
9 changes: 7 additions & 2 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
Expand Down Expand Up @@ -910,7 +911,11 @@ def get_query_str_extended(
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = SQLStatement(sql).format()
try:
sql = SQLStatement(sql, engine=self.db_engine_spec.engine).format()
except SupersetParseError:
logger.warning("Unable to parse SQL to format it, passing it as-is")

if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand Down Expand Up @@ -1077,7 +1082,7 @@ def get_rendered_sql(
)
) from ex

query = SQLQuery(sql.strip("\t\r\n; "))
query = SQLQuery(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine)
if len(query.statements) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
Expand Down
18 changes: 11 additions & 7 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
)
from sqlparse.utils import imt

from superset.exceptions import QueryClauseValidationException
from superset.exceptions import QueryClauseValidationException, SupersetParseError
from superset.utils.backports import StrEnum

try:
Expand Down Expand Up @@ -345,7 +345,7 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL query.
"""
return ";\n".join(statement.format() for statement in self.statements)
return ";\n".join(statement.format(comments) for statement in self.statements)

def get_settings(self) -> dict[str, str]:
"""
Expand Down Expand Up @@ -376,11 +376,15 @@ def __init__(
engine: Optional[str] = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._parsed = (
self._parse_statement(statement, dialect)
if isinstance(statement, str)
else statement
)

if isinstance(statement, str):
try:
self._parsed = self._parse_statement(statement, dialect)
except ParseError as ex:
raise SupersetParseError(statement, engine) from ex
else:
self._parsed = statement

self._dialect = dialect
self.tables = extract_tables_from_statement(self._parsed, dialect)

Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None):
schema = quote_f(schema)
table = quote_f(table)
if schema:
return f"SELECT *\nFROM {schema}.{table}\nLIMIT {limit}"
return f"SELECT *\nFROM {table}\nLIMIT {limit}"
return f"SELECT\n *\nFROM {schema}.{table}\nLIMIT {limit}"
return f"SELECT\n *\nFROM {table}\nLIMIT {limit}"


@pytest.mark.usefixtures("login_as_admin")
Expand Down Expand Up @@ -333,9 +333,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
query = wait_for_success(result)
assert QueryStatus.SUCCESS == query.status

sqllite_select_sql = f"SELECT *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
sqlite_select_sql = f"SELECT\n *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
assert query.select_sql == (
sqllite_select_sql
sqlite_select_sql
if backend() == "sqlite"
else get_select_star(tmp_table, query.limit)
)
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def test_when_where_parameter_is_template_and_query_result_type__query_is_templa
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
result = rv.json["result"][0]["query"]
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result
assert "(\n 'boy' = 'boy'\n )" in result

@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down Expand Up @@ -1244,13 +1244,13 @@ def test_time_filter_with_grain(test_client, login_as_admin, physical_query_cont
backend = get_example_database().backend
if backend == "sqlite":
assert (
"DATETIME(col5, 'start of day', -strftime('%w', col5) || ' days') >="
"DATETIME(col5, 'start of day', -STRFTIME('%w', col5) || ' days') >="
in query
)
elif backend == "mysql":
assert "DATE(DATE_SUB(col5, INTERVAL DAYOFWEEK(col5) - 1 DAY)) >=" in query
assert "DATE(DATE_SUB(col5, INTERVAL (DAYOFWEEK(col5) - 1) DAY)) >=" in query
elif backend == "postgresql":
assert "DATE_TRUNC('week', col5) >=" in query
assert "DATE_TRUNC('WEEK', col5) >=" in query
elif backend == "presto":
assert "date_trunc('week', CAST(col5 AS TIMESTAMP)) >=" in query

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def test_mssql_engine_spec_pymssql(self):
)

def test_comments_in_sqlatable_query(self):
clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl"
commented_query = "/* comment 1 */" + clean_query + "-- comment 2"
table = SqlaTable(
table_name="test_comments_in_sqlatable_query_table",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def test_get_samples_with_multiple_filters(
assert "2000-01-02" in rv.json["result"]["query"]
assert "2000-01-04" in rv.json["result"]["query"]
assert "col3 = 1.2" in rv.json["result"]["query"]
assert "col4 is null" in rv.json["result"]["query"]
assert "col4 IS NULL" in rv.json["result"]["query"]
assert "col2 = 'c'" in rv.json["result"]["query"]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,7 @@ def test_calculated_column_in_order_by_base_engine_spec(self):
}
sql = table.get_query_str(query_obj)
assert (
"""ORDER BY case
when gender='boy' then 'male'
else 'female'
end ASC;"""
"ORDER BY\n CASE WHEN gender = 'boy' THEN 'male' ELSE 'female' END ASC"
in sql
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/db_engine_specs/bigquery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,4 @@ def test_calculated_column_in_order_by(self):
"orderby": [["gender_cc", True]],
}
sql = table.get_query_str(query_obj)
assert "ORDER BY `gender_cc` ASC" in sql
assert "ORDER BY\n `gender_cc` ASC" in sql
57 changes: 9 additions & 48 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# isort:skip_file
import json
from superset.utils.core import DatasourceType
import textwrap
import unittest
from unittest import mock

Expand Down Expand Up @@ -298,58 +297,25 @@ def test_select_star(self):
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
with db.get_sqla_engine_with_context() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
expected = (
textwrap.dedent(
f"""\
SELECT *
FROM {quote(table_name)}
LIMIT 100"""
)
if db.backend in {"presto", "hive"}
else textwrap.dedent(
f"""\
SELECT *
FROM {table_name}
LIMIT 100"""
)
)

source = quote(table_name) if db.backend in {"presto", "hive"} else table_name
expected = f"SELECT\n *\nFROM {source}\nLIMIT 100"
assert expected in sql
sql = db.select_star(table_name, show_cols=True, latest_partition=False)
# TODO(bkyryliuk): unify sql generation
if db.backend == "presto":
assert (
textwrap.dedent(
"""\
SELECT "source" AS "source",
"target" AS "target",
"value" AS "value"
FROM "energy_usage"
LIMIT 100"""
)
== sql
'SELECT\n "source" AS "source",\n "target" AS "target",\n "value" AS "value"\nFROM "energy_usage"\nLIMIT 100'
in sql
)
elif db.backend == "hive":
assert (
textwrap.dedent(
"""\
SELECT `source`,
`target`,
`value`
FROM `energy_usage`
LIMIT 100"""
)
== sql
"SELECT\n `source`,\n `target`,\n `value`\nFROM `energy_usage`\nLIMIT 100"
in sql
)
else:
assert (
textwrap.dedent(
"""\
SELECT source,
target,
value
FROM energy_usage
LIMIT 100"""
)
"SELECT\n source,\n target,\n value\nFROM energy_usage\nLIMIT 100"
in sql
)

Expand All @@ -367,12 +333,7 @@ def test_select_star_fully_qualified_names(self):
}
fully_qualified_name = fully_qualified_names.get(db.db_engine_spec.engine)
if fully_qualified_name:
expected = textwrap.dedent(
f"""\
SELECT *
FROM {fully_qualified_name}
LIMIT 100"""
)
expected = f"SELECT\n *\nFROM {fully_qualified_name}\nLIMIT 100"
assert sql.startswith(expected)

def test_single_statement(self):
Expand Down
Loading

0 comments on commit 8237798

Please sign in to comment.