diff --git a/CHANGELOG b/CHANGELOG index 6d9d71f8..150f3b3c 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -14,6 +14,9 @@ Bug Fixes previously used just `strip_comments=True`. `strip_comments` did some of the work that `strip_whitespace` should do. +* Fix error when splitting statements that contain multiple CASE clauses + within a BEGIN block (issue784). + Release 0.5.0 (Apr 13, 2024) ---------------------------- diff --git a/sqlparse/engine/statement_splitter.py b/sqlparse/engine/statement_splitter.py index 5b3a0d9b..6c69d303 100644 --- a/sqlparse/engine/statement_splitter.py +++ b/sqlparse/engine/statement_splitter.py @@ -17,6 +17,7 @@ def __init__(self): def _reset(self): """Set the filter attributes to its default values""" self._in_declare = False + self._in_case = False self._is_create = False self._begin_depth = 0 @@ -58,16 +59,18 @@ def _change_splitlevel(self, ttype, value): return 1 return 0 - # Should this respect a preceding BEGIN? - # In CASE ... WHEN ... END this results in a split level -1. - # Would having multiple CASE WHEN END and a Assignment Operator - # cause the statement to cut off prematurely? + # BEGIN and CASE/WHEN both end with END if unified == 'END': - self._begin_depth = max(0, self._begin_depth - 1) + if not self._in_case: + self._begin_depth = max(0, self._begin_depth - 1) + else: + self._in_case = False return -1 if (unified in ('IF', 'FOR', 'WHILE', 'CASE') and self._is_create and self._begin_depth > 0): + if unified == 'CASE': + self._in_case = True return 1 if unified in ('END IF', 'END FOR', 'END WHILE'): diff --git a/tests/files/multiple_case_in_begin.sql b/tests/files/multiple_case_in_begin.sql new file mode 100644 index 00000000..6cbb3864 --- /dev/null +++ b/tests/files/multiple_case_in_begin.sql @@ -0,0 +1,8 @@ +CREATE TRIGGER mytrig +AFTER UPDATE OF vvv ON mytable +BEGIN + UPDATE aa + SET mycola = (CASE WHEN (A=1) THEN 2 END); + UPDATE bb + SET mycolb = (CASE WHEN (B=1) THEN 5 END); +END; \ No newline at end of file diff --git a/tests/test_split.py b/tests/test_split.py index 90d2eaff..e2f1429e 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -203,3 +203,8 @@ def test_split_strip_semicolon_procedure(load_file): def test_split_go(sql, num): # issue762 stmts = sqlparse.split(sql) assert len(stmts) == num + + +def test_split_multiple_case_in_begin(load_file): # issue784 + stmts = sqlparse.split(load_file('multiple_case_in_begin.sql')) + assert len(stmts) == 1