From fcd72d0b46a14b7f86048d78f2e4f2176b2fe396 Mon Sep 17 00:00:00 2001 From: Mark Molinaro <16494982+markjm@users.noreply.github.com> Date: Tue, 12 Mar 2024 06:25:43 +0000 Subject: [PATCH] fix: #543 more properly identify CREATE TABLE ... LIKE ... statements --- .vscode/settings.json | 14 +++++++++ sqlparse/engine/grouping.py | 60 ++++++++++++++++++++++++------------- sqlparse/keywords.py | 3 +- tests/test_grouping.py | 13 ++++---- tests/test_regressions.py | 7 +++++ tests/test_tokenize.py | 4 +-- 6 files changed, 72 insertions(+), 29 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..1b213ba6 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,14 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, + "python.testing.pytestArgs": [ + "tests" + ] +} \ No newline at end of file diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index c486318a..223db379 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -77,7 +77,7 @@ def group_typecasts(tlist): def match(token): return token.match(T.Punctuation, '::') - def valid(token): + def valid(token, idx): return token is not None def post(tlist, pidx, tidx, nidx): @@ -91,10 +91,10 @@ def group_tzcasts(tlist): def match(token): return token.ttype == T.Keyword.TZCast - def valid_prev(token): + def valid_prev(token, idx): return token is not None - def valid_next(token): + def valid_next(token, idx): return token is not None and ( token.is_whitespace or token.match(T.Keyword, 'AS') @@ -119,13 +119,13 @@ def match(token): def match_to_extend(token): return isinstance(token, sql.TypedLiteral) - def valid_prev(token): + def valid_prev(token, idx): return token is not None - def valid_next(token): + def valid_next(token, idx): return token is not None and 
token.match(*sql.TypedLiteral.M_CLOSE) - def valid_final(token): + def valid_final(token, idx): return token is not None and token.match(*sql.TypedLiteral.M_EXTEND) def post(tlist, pidx, tidx, nidx): @@ -141,12 +141,12 @@ def group_period(tlist): def match(token): return token.match(T.Punctuation, '.') - def valid_prev(token): + def valid_prev(token, idx): sqlcls = sql.SquareBrackets, sql.Identifier ttypes = T.Name, T.String.Symbol return imt(token, i=sqlcls, t=ttypes) - def valid_next(token): + def valid_next(token, idx): # issue261, allow invalid next token return True @@ -166,10 +166,10 @@ def group_as(tlist): def match(token): return token.is_keyword and token.normalized == 'AS' - def valid_prev(token): + def valid_prev(token, idx): return token.normalized == 'NULL' or not token.is_keyword - def valid_next(token): + def valid_next(token, idx): ttypes = T.DML, T.DDL, T.CTE return not imt(token, t=ttypes) and token is not None @@ -183,7 +183,7 @@ def group_assignment(tlist): def match(token): return token.match(T.Assignment, ':=') - def valid(token): + def valid(token, idx): return token is not None and token.ttype not in (T.Keyword,) def post(tlist, pidx, tidx, nidx): @@ -202,9 +202,12 @@ def group_comparison(tlist): ttypes = T_NUMERICAL + T_STRING + T_NAME def match(token): - return token.ttype == T.Operator.Comparison + return imt(token, + t=(T.Operator.Comparison), + m=(T.Keyword, 'LIKE') + ) - def valid(token): + def valid(token, idx): if imt(token, t=ttypes, i=sqlcls): return True elif token and token.is_keyword and token.normalized == 'NULL': @@ -214,8 +217,23 @@ def valid(token): def post(tlist, pidx, tidx, nidx): return pidx, nidx + + def valid_next(token, idx): + return valid(token, idx) + + def valid_prev(token, idx): + # https://dev.mysql.com/doc/refman/8.0/en/create-table-like.html + # LIKE is usually a comparator, except when used in `CREATE TABLE x LIKE y` statements + # Check if we are constructing a table - otherwise assume it is indeed a 
comparator + two_tokens_back_idx = idx - 3 + if two_tokens_back_idx >= 0: + _, two_tokens_back = tlist.token_next(two_tokens_back_idx) + if imt(two_tokens_back, m=(T.Keyword, 'TABLE')): + return False + + return valid(token, idx) + - valid_prev = valid_next = valid _group(tlist, sql.Comparison, match, valid_prev, valid_next, post, extend=False) @@ -237,10 +255,10 @@ def group_arrays(tlist): def match(token): return isinstance(token, sql.SquareBrackets) - def valid_prev(token): + def valid_prev(token, idx): return imt(token, i=sqlcls, t=ttypes) - def valid_next(token): + def valid_next(token, idx): return True def post(tlist, pidx, tidx, nidx): @@ -258,7 +276,7 @@ def group_operator(tlist): def match(token): return imt(token, t=(T.Operator, T.Wildcard)) - def valid(token): + def valid(token, idx): return imt(token, i=sqlcls, t=ttypes) \ or (token and token.match( T.Keyword, @@ -283,7 +301,7 @@ def group_identifier_list(tlist): def match(token): return token.match(T.Punctuation, ',') - def valid(token): + def valid(token, idx): return imt(token, i=sqlcls, m=m_role, t=ttypes) def post(tlist, pidx, tidx, nidx): @@ -431,8 +449,8 @@ def group(stmt): def _group(tlist, cls, match, - valid_prev=lambda t: True, - valid_next=lambda t: True, + valid_prev=lambda t, idx: True, + valid_next=lambda t, idx: True, post=None, extend=True, recurse=True @@ -454,7 +472,7 @@ def _group(tlist, cls, match, if match(token): nidx, next_ = tlist.token_next(tidx) - if prev_ and valid_prev(prev_) and valid_next(next_): + if prev_ and valid_prev(prev_, pidx) and valid_next(next_, nidx): from_idx, to_idx = post(tlist, pidx, tidx, nidx) grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend) diff --git a/sqlparse/keywords.py b/sqlparse/keywords.py index d3794fd3..fa8b6e65 100644 --- a/sqlparse/keywords.py +++ b/sqlparse/keywords.py @@ -82,7 +82,8 @@ r'(EXPLODE|INLINE|PARSE_URL_TUPLE|POSEXPLODE|STACK)\b', tokens.Keyword), (r"(AT|WITH')\s+TIME\s+ZONE\s+'[^']+'", tokens.Keyword.TZCast), - 
(r'(NOT\s+)?(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison), + (r'(NOT\s+)(LIKE|ILIKE|RLIKE)\b', tokens.Operator.Comparison), + (r'(ILIKE|RLIKE)\b', tokens.Operator.Comparison), (r'(NOT\s+)?(REGEXP)\b', tokens.Operator.Comparison), # Check for keywords, also returns tokens.Name if regex matches # but the match isn't a keyword. diff --git a/tests/test_grouping.py b/tests/test_grouping.py index e90243b5..5e71dca5 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -498,7 +498,10 @@ def test_comparison_with_strings(operator): assert p.tokens[0].right.ttype == T.String.Single -def test_like_and_ilike_comparison(): +@pytest.mark.parametrize('operator', ( + 'LIKE', 'NOT LIKE', 'ILIKE', 'NOT ILIKE', 'RLIKE', 'NOT RLIKE' +)) +def test_like_and_ilike_comparison(operator): def validate_where_clause(where_clause, expected_tokens): assert len(where_clause.tokens) == len(expected_tokens) for where_token, expected_token in zip(where_clause, expected_tokens): @@ -513,22 +516,22 @@ def validate_where_clause(where_clause, expected_tokens): assert (isinstance(where_token, expected_ttype) and re.match(expected_value, where_token.value)) - [p1] = sqlparse.parse("select * from mytable where mytable.mycolumn LIKE 'expr%' limit 5;") + [p1] = sqlparse.parse(f"select * from mytable where mytable.mycolumn {operator} 'expr%' limit 5;") [p1_where] = [token for token in p1 if isinstance(token, sql.Where)] validate_where_clause(p1_where, [ (T.Keyword, "where"), (T.Whitespace, None), - (sql.Comparison, r"mytable.mycolumn LIKE.*"), + (sql.Comparison, f"mytable.mycolumn {operator}.*"), (T.Whitespace, None), ]) [p2] = sqlparse.parse( - "select * from mytable where mycolumn NOT ILIKE '-expr' group by othercolumn;") + f"select * from mytable where mycolumn {operator} '-expr' group by othercolumn;") [p2_where] = [token for token in p2 if isinstance(token, sql.Where)] validate_where_clause(p2_where, [ (T.Keyword, "where"), (T.Whitespace, None), - (sql.Comparison, r"mycolumn NOT 
ILIKE.*"), + (sql.Comparison, f"mycolumn {operator}.*"), (T.Whitespace, None), ]) diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 961adc17..65bac271 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -444,3 +444,10 @@ def test_copy_issue672(): p = sqlparse.parse('select * from foo')[0] copied = copy.deepcopy(p) assert str(p) == str(copied) + +def test_copy_issue543(): + tokens = sqlparse.parse('create table tab1.b like tab2')[0].tokens + assert [(t.ttype, t.value) for t in tokens if t.ttype != T.Whitespace] == [(T.DDL, 'create'), (T.Keyword, 'table'), (None, 'tab1.b'),(T.Keyword, 'like'), (None, 'tab2')] + + comparison = sqlparse.parse('a LIKE "b"')[0].tokens[0] + assert isinstance(comparison, sql.Comparison) \ No newline at end of file diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index af0ba163..ce1b4af0 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -209,10 +209,10 @@ def test_parse_window_as(): @pytest.mark.parametrize('s', ( - "LIKE", "ILIKE", "NOT LIKE", "NOT ILIKE", + "ILIKE", "NOT LIKE", "NOT ILIKE", "NOT LIKE", "NOT ILIKE", )) -def test_like_and_ilike_parsed_as_comparisons(s): +def test_likeish_but_not_like_parsed_as_comparisons(s): p = sqlparse.parse(s)[0] assert len(p.tokens) == 1 assert p.tokens[0].ttype == T.Operator.Comparison