From 8b0342730b2e53dcc7887e5b4563d479444ec0ba Mon Sep 17 00:00:00 2001 From: Andi Albrecht Date: Mon, 13 May 2024 09:26:26 +0200 Subject: [PATCH] Fix grouping of comments (fixes #772). The grouping of comments was a bit too greedy by also consuming whitespaces at the end. --- CHANGELOG | 8 +++++++- sqlparse/engine/grouping.py | 2 +- sqlparse/sql.py | 3 ++- tests/test_format.py | 11 ++++++++--- tests/test_grouping.py | 7 ------- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 91febb7e..a5a1ba9e 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,7 +1,13 @@ Development Version ------------------- -Nothing yet. +Bug Fixes + +* The strip comments filter was a bit greedy and removed too much + whitespace (issue772). + Note: In some cases you might want to add `strip_whitespace=True` where you + previously used just `strip_comments=True`. `strip_comments` did some of the + work that `strip_whitespace` should do. Release 0.5.0 (Apr 13, 2024) diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 926a3c1b..a63f4da2 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -314,7 +314,7 @@ def group_comments(tlist): tidx, token = tlist.token_next_by(t=T.Comment) while token: eidx, end = tlist.token_not_matching( - lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace, idx=tidx) + lambda tk: imt(tk, t=T.Comment) or tk.is_newline, idx=tidx) if end is not None: eidx, end = tlist.token_prev(eidx, skip_ws=False) tlist.group_tokens(sql.Comment, tidx, eidx) diff --git a/sqlparse/sql.py b/sqlparse/sql.py index bd5f35b1..10373751 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -46,7 +46,7 @@ class Token: """ __slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword', - 'is_group', 'is_whitespace') + 'is_group', 'is_whitespace', 'is_newline') def __init__(self, ttype, value): value = str(value) @@ -56,6 +56,7 @@ def __init__(self, ttype, value): self.is_group = False self.is_keyword = ttype in T.Keyword self.is_whitespace = self.ttype in T.Whitespace + self.is_newline = self.ttype in T.Newline self.normalized = value.upper() if self.is_keyword else value def __str__(self): diff --git a/tests/test_format.py b/tests/test_format.py index a616f360..6a4b6f16 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -73,15 +73,15 @@ def test_strip_comments_multi(self): assert res == 'select' sql = '/* sql starts here */ select' res = sqlparse.format(sql, strip_comments=True) - assert res == 'select' + assert res == ' select' # note whitespace is preserved, see issue 772 sql = '/*\n * sql starts here\n */\nselect' res = sqlparse.format(sql, strip_comments=True) assert res == 'select' sql = 'select (/* sql starts here */ select 2)' - res = sqlparse.format(sql, strip_comments=True) + res = sqlparse.format(sql, strip_comments=True, strip_whitespace=True) assert res == 'select (select 2)' sql = 'select (/* sql /* starts here */ select 2)' - res = sqlparse.format(sql, strip_comments=True) + res = sqlparse.format(sql, strip_comments=True, strip_whitespace=True) assert res == 'select (select 2)' def test_strip_comments_preserves_linebreak(self): @@ -100,6 +100,11 @@ def test_strip_comments_preserves_linebreak(self): sql = 'select * -- a comment\n\nfrom foo' res = sqlparse.format(sql, strip_comments=True) assert res == 'select *\n\nfrom foo' + + def test_strip_comments_preserves_whitespace(self): + sql = 'SELECT 1/*bar*/ AS foo' # see issue772 + res = sqlparse.format(sql, strip_comments=True) + assert res == 'SELECT 1 AS foo' def test_strip_ws(self): f = lambda sql: sqlparse.format(sql, strip_whitespace=True) diff --git a/tests/test_grouping.py b/tests/test_grouping.py index b39ff270..88b762cd 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -17,13 +17,6 @@ def test_grouping_parenthesis(): assert len(parsed.tokens[2].tokens[3].tokens) == 3 -def test_grouping_comments(): - s = '/*\n * foo\n */ \n bar' - parsed = sqlparse.parse(s)[0] - assert str(parsed) == s - assert len(parsed.tokens) == 2 - - @pytest.mark.parametrize('s', ['foo := 1;', 'foo := 1']) def test_grouping_assignment(s): parsed = sqlparse.parse(s)[0]