diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py
index bb9b88a5..9bb8b5e9 100644
--- a/sqlparse/__init__.py
+++ b/sqlparse/__init__.py
@@ -38,6 +38,7 @@ def parsestream(stream, encoding=None):
     :returns: A generator of :class:`~sqlparse.sql.Statement` instances.
     """
     stack = engine.FilterStack()
+    stack.stmtprocess.append(stack.grouping_filter)
     stack.enable_grouping()
     return stack.run(stream, encoding)
 
diff --git a/sqlparse/engine/filter_stack.py b/sqlparse/engine/filter_stack.py
index 9665a224..b60ff469 100644
--- a/sqlparse/engine/filter_stack.py
+++ b/sqlparse/engine/filter_stack.py
@@ -7,8 +7,8 @@
 
 """filter"""
 
+from sqlparse import filters
 from sqlparse import lexer
-from sqlparse.engine import grouping
 from sqlparse.engine.statement_splitter import StatementSplitter
 
 
@@ -17,10 +17,10 @@ def __init__(self):
         self.preprocess = []
         self.stmtprocess = []
         self.postprocess = []
-        self._grouping = False
+        self.grouping_filter = filters.GroupingFilter()
 
     def enable_grouping(self):
-        self._grouping = True
+        self.grouping_filter.enable()
 
     def run(self, sql, encoding=None):
         stream = lexer.tokenize(sql, encoding)
@@ -32,9 +32,6 @@ def run(self, sql, encoding=None):
 
         # Output: Stream processed Statements
         for stmt in stream:
-            if self._grouping:
-                stmt = grouping.group(stmt)
-
             for filter_ in self.stmtprocess:
                 filter_.process(stmt)
 
diff --git a/sqlparse/filters/__init__.py b/sqlparse/filters/__init__.py
index 5bd6b325..9126b57e 100644
--- a/sqlparse/filters/__init__.py
+++ b/sqlparse/filters/__init__.py
@@ -5,6 +5,7 @@
 # This module is part of python-sqlparse and is released under
 # the BSD License: https://opensource.org/licenses/BSD-3-Clause
 
+from sqlparse.filters.others import GroupingFilter
 from sqlparse.filters.others import SerializerUnicode
 from sqlparse.filters.others import StripCommentsFilter
 from sqlparse.filters.others import StripWhitespaceFilter
@@ -22,6 +23,7 @@
 from sqlparse.filters.aligned_indent import AlignedIndentFilter
 
 __all__ = [
+    'GroupingFilter',
     'SerializerUnicode',
     'StripCommentsFilter',
     'StripWhitespaceFilter',
diff --git a/sqlparse/filters/others.py b/sqlparse/filters/others.py
index e0e1ca19..778ab361 100644
--- a/sqlparse/filters/others.py
+++ b/sqlparse/filters/others.py
@@ -8,16 +8,30 @@
 import re
 
 from sqlparse import sql, tokens as T
+from sqlparse.engine import grouping
 from sqlparse.utils import split_unquoted_newlines
 
 
-class StripCommentsFilter:
+class GroupingFilter:
+    def __init__(self):
+        self._enabled = False
+
+    def enable(self):
+        self._enabled = True
+
+    def process(self, stmt):
+        if self._enabled:
+            return grouping.group(stmt)
+        else:
+            return stmt
+
+
+class StripCommentsFilter:
 
     @staticmethod
-    def _process(tlist):
+    def process(stmt):
         def get_next_comment():
             # TODO(andi) Comment types should be unified, see related issue38
-            return tlist.token_next_by(i=sql.Comment, t=T.Comment)
+            return stmt.token_next_by(i=sql.Comment, t=T.Comment)
 
         def _get_insert_token(token):
             """Returns either a whitespace or the line breaks from token."""
@@ -28,10 +42,12 @@ def _get_insert_token(token):
             else:
                 return sql.Token(T.Whitespace, ' ')
 
+        grouping.group_comments(stmt)
+
         tidx, token = get_next_comment()
         while token:
-            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
-            nidx, next_ = tlist.token_next(tidx, skip_ws=False)
+            pidx, prev_ = stmt.token_prev(tidx, skip_ws=False)
+            nidx, next_ = stmt.token_next(tidx, skip_ws=False)
             # Replace by whitespace if prev and next exist and if they're not
             # whitespaces. This doesn't apply if prev or next is a parenthesis.
             if (prev_ is None or next_ is None
@@ -40,16 +56,13 @@ def _get_insert_token(token):
                 # Insert a whitespace to ensure the following SQL produces
                 # a valid SQL (see #425).
                 if prev_ is not None and not prev_.match(T.Punctuation, '('):
-                    tlist.tokens.insert(tidx, _get_insert_token(token))
-                    tlist.tokens.remove(token)
+                    stmt.tokens.insert(tidx, _get_insert_token(token))
+                    stmt.tokens.remove(token)
                 else:
-                    tlist.tokens[tidx] = _get_insert_token(token)
+                    stmt.tokens[tidx] = _get_insert_token(token)
 
             tidx, token = get_next_comment()
 
-    def process(self, stmt):
-        [self.process(sgroup) for sgroup in stmt.get_sublists()]
-        StripCommentsFilter._process(stmt)
         return stmt
 
 
diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py
index 1d1871cf..203a983e 100644
--- a/sqlparse/formatter.py
+++ b/sqlparse/formatter.py
@@ -149,14 +149,17 @@ def build_filter_stack(stack, options):
         stack.preprocess.append(filters.TruncateStringFilter(
             width=options['truncate_strings'], char=options['truncate_char']))
 
-    if options.get('use_space_around_operators', False):
-        stack.enable_grouping()
-        stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter())
+    # Before grouping
+    if options.get('strip_comments'):
+        stack.stmtprocess.append(filters.StripCommentsFilter())
+
+    # Grouping
+    stack.stmtprocess.append(stack.grouping_filter)
 
     # After grouping
-    if options.get('strip_comments'):
+    if options.get('use_space_around_operators', False):
         stack.enable_grouping()
-        stack.stmtprocess.append(filters.StripCommentsFilter())
+        stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter())
 
     if options.get('strip_whitespace') or options.get('reindent'):
         stack.enable_grouping()
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 6a32c26a..812d12ba 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -36,26 +36,14 @@ def get_alias(self):
         return self._get_first_name(reverse=True)
 
 
-class Token:
+class TokenBase:
     """Base class for all other classes in this module.
-
-    It represents a single token and has two instance attributes:
-    ``value`` is the unchanged value of the token and ``ttype`` is
-    the type of the token.
     """
 
-    __slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword',
-                 'is_group', 'is_whitespace')
+    __slots__ = ('parent',)
 
-    def __init__(self, ttype, value):
-        value = str(value)
-        self.value = value
-        self.ttype = ttype
+    def __init__(self):
         self.parent = None
-        self.is_group = False
-        self.is_keyword = ttype in T.Keyword
-        self.is_whitespace = self.ttype in T.Whitespace
-        self.normalized = value.upper() if self.is_keyword else value
 
     def __str__(self):
         return self.value
@@ -72,19 +60,12 @@ def __repr__(self):
         return "<{cls} {q}{value}{q} at 0x{id:2X}>".format(
             id=id(self), **locals())
 
-    def _get_repr_name(self):
-        return str(self.ttype).split('.')[-1]
-
     def _get_repr_value(self):
         raw = str(self)
         if len(raw) > 7:
             raw = raw[:6] + '...'
         return re.sub(r'\s+', ' ', raw)
 
-    def flatten(self):
-        """Resolve subgroups."""
-        yield self
-
     def match(self, ttype, values, regex=False):
         """Checks whether the token matches the given arguments.
@@ -146,24 +127,61 @@ def has_ancestor(self, other):
         return False
 
 
-class TokenList(Token):
+class Token(TokenBase):
+    """A single token.
+
+    It has two instance attributes:
+    ``value`` is the unchanged value of the token and ``ttype`` is
+    the type of the token.
+ """ + is_group = False + + __slots__ = ('value', 'ttype', 'normalized', 'is_keyword', 'is_whitespace') + + def __init__(self, ttype, value): + super().__init__() + value = str(value) + self.value = value + self.ttype = ttype + self.is_keyword = ttype in T.Keyword + self.is_whitespace = ttype in T.Whitespace + self.normalized = value.upper() if self.is_keyword else value + + def _get_repr_name(self): + return str(self.ttype).split('.')[-1] + + def flatten(self): + """Resolve subgroups.""" + yield self + + +class TokenList(TokenBase): """A group of tokens. - It has an additional instance attribute ``tokens`` which holds a - list of child-tokens. + It has one instance attribute ``tokens`` which holds a list of + child-tokens. """ __slots__ = 'tokens' + is_group = True + ttype = None + is_keyword = False + is_whitespace = False + def __init__(self, tokens=None): + super().__init__() self.tokens = tokens or [] [setattr(token, 'parent', self) for token in self.tokens] - super().__init__(None, str(self)) - self.is_group = True - def __str__(self): + @property + def value(self): return ''.join(token.value for token in self.flatten()) + @property + def normalized(self): + return self.value + # weird bug # def __len__(self): # return len(self.tokens) @@ -322,7 +340,6 @@ def group_tokens(self, grp_cls, start, end, include_end=True, grp = start grp.tokens.extend(subtokens) del self.tokens[start_idx + 1:end_idx] - grp.value = str(start) else: subtokens = self.tokens[start_idx:end_idx] grp = grp_cls(subtokens)