diff --git a/bowler/helpers.py b/bowler/helpers.py index a31d2d0..9deb1e8 100644 --- a/bowler/helpers.py +++ b/bowler/helpers.py @@ -6,7 +6,8 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import List, Optional, Sequence, Union +from inspect import Signature +from typing import TYPE_CHECKING, List, Optional, Sequence, Union import click from fissix.pgen2.token import tok_name @@ -14,6 +15,9 @@ from .types import LN, SYMBOL, TOKEN, Capture, Filename, FilenameMatcher +if TYPE_CHECKING: + from .imr import FunctionSpec + log = logging.getLogger(__name__) INDENT_STR = ". " @@ -144,6 +148,38 @@ def is_call_to(node: LN, func_name: str) -> bool: ) +def spec_contains_parameter_name(param_name: str, spec: "FunctionSpec") -> bool: + return param_name in (arg.name for arg in spec.arguments) + + +def definition_contains_parameter(param_name: str, spec: "FunctionSpec") -> bool: + return spec_contains_parameter_name(param_name, spec) + + +def callsite_contains_parameter( + param_name: str, spec: "FunctionSpec", function_sig: Signature +) -> bool: + # Handle kwargs + if spec_contains_parameter_name(param_name, spec): + return True + + # Handle positional args + if param_name not in function_sig.parameters: + raise ValueError( + f"Function signature must contain parameter we are looking for" + ) + + names = list(function_sig.parameters) + position = names.index(param_name) + if names[0] in ("self", "cls", "meta"): + position -= 1 + + if len(spec.arguments) <= position: + return False + + return all(not arg.name and not arg.star for arg in spec.arguments[: position + 1]) + + def find_first(node: LN, target: int, recursive: bool = False) -> Optional[LN]: queue: List[LN] = [node] queue.extend(node.children) diff --git a/bowler/query.py b/bowler/query.py index 16ae466..c3b670f 100644 --- a/bowler/query.py +++ b/bowler/query.py @@ -9,7 +9,7 @@ import logging import pathlib import re -from functools import wraps +from functools import partial, wraps from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, cast from attr import Factory, dataclass @@ -19,6 +19,8 @@ from .helpers import ( Once, + callsite_contains_parameter, + definition_contains_parameter, dotted_parts, find_first, find_last, @@ -438,6 +440,42 @@ def filter_is_call(node: LN, capture: Capture, filename: Filename) -> bool: self.current.filters.append(filter_is_call) return self + def filter_contains_parameter( + self, param_name: str, node: LN, capture: Capture, filename: Filename + ) -> bool: + transform = self.current + if transform.selector not in ("function", "method"): + raise ValueError( + "parameter presence filter must follow select_function/select_method" + ) + if "source" not in transform.kwargs: + raise ValueError( + "Determining if parameter is present requires passing original function" + ) + + if "function_def" in capture or "class_def" in capture: + spec = FunctionSpec.build(node, capture) + return definition_contains_parameter(param_name, spec) + if "function_call" in capture or "class_call" in capture: + spec = FunctionSpec.build(node, capture) + sig = inspect.signature(transform.kwargs["source"]) + return callsite_contains_parameter(param_name, spec, sig) + + return False + + def contains_parameter(self, name: str) -> "Query": + self.current.filters.append(partial(self.filter_contains_parameter, name)) + return self + + def missing_parameter(self, name: str) -> "Query": + def filter_missing_parameter( + node: LN, capture: Capture, filename: Filename + ) -> bool: + return not self.filter_contains_parameter(name, node, capture, filename) + + self.current.filters.append(filter_missing_parameter) + return self + def is_def(self) -> "Query": def filter_is_def(node: LN, capture: Capture, filename: Filename) -> bool: return bool("function_def" in capture or "class_def" in capture) @@ -748,6 +786,10 @@ def add_argument_transform( value_leaf = Name(value) if spec.is_def: + if definition_contains_parameter(name, spec): + raise ValueError( + f"{name} is already present in the definition of {spec.name}" + ) new_arg = FunctionArgument( name, value_leaf if keyword else None, @@ -772,6 +814,12 @@ def add_argument_transform( spec.arguments.append(new_arg) elif positional: + if after not in (SENTINEL, START): + sig = inspect.signature(transform.kwargs["source"]) + if callsite_contains_parameter(name, spec, sig): + raise ValueError( + f"{name} is already present in the callsite of {spec.name}" + ) new_arg = FunctionArgument(value=value_leaf) for index, argument in enumerate(spec.arguments): if argument.star and argument.star.type == TOKEN.STAR: @@ -792,6 +840,29 @@ def add_argument_transform( if not done: spec.arguments.append(new_arg) + else: + if "source" not in transform.kwargs: + raise ValueError( + "Adding a positional arg to a callsite requires passing " + "original function" + ) + + sig = inspect.signature(transform.kwargs["source"]) + if callsite_contains_parameter(name, spec, sig): + raise ValueError( + f"{name} is already present in the callsite of {spec.name}" + ) + + # Drop kwarg at end of call to avoid breaking positional args + new_arg = FunctionArgument(name=name, value=value_leaf) + + # Double star should only ever be present at the end of a call + star_arg = spec.arguments[-1].star if spec.arguments else None + if star_arg and star_arg.type == TOKEN.DOUBLESTAR: + spec.arguments.insert(-1, new_arg) + else: + spec.arguments.append(new_arg) + spec.explode() transform.callbacks.append(add_argument_transform) diff --git a/bowler/tests/query.py b/bowler/tests/query.py index 94df5d1..05d320b 100644 --- a/bowler/tests/query.py +++ b/bowler/tests/query.py @@ -122,6 +122,57 @@ def query_func(x): query_func=query_func, ) + def test_filter_contains_parameter(self): + def f(x, y): + pass + + def query_func(x): + return Query(x).select_function(f).contains_parameter("y").rename("g") + + self.run_bowler_modifiers( + [ + ("def f(x, y):\n pass", "def g(x, y):\n pass"), + ("def f(x, y, z):\n pass", "def g(x, y, z):\n pass"), + ("def f(x):\n pass", "def f(x):\n pass"), + ("def x(y):\n pass", "def x(y):\n pass"), + ("def f(*_):\n pass", "def f(*_):\n pass"), + ("def f(**_):\n pass", "def f(**_):\n pass"), + ("f(1, y=2)", "g(1, y=2)"), + ("f(1, 2)", "g(1, 2)"), + ("f(1, 2, z=3)", "g(1, 2, z=3)"), + ("f(y=2)", "g(y=2)"), + ("f(x=2, **a)", "f(x=2, **a)"), + ("f(*a, **a)", "f(*a, **a)"), + ("f(*_)", "f(*_)"), + ("f(**_)", "f(**_)"), + ("f(x=1)", "f(x=1)"), + ("f(1)", "f(1)"), + ("f(1, x=2)", "f(1, x=2)"), + ], + query_func=query_func, + ) + + def test_filter_missing_parameter(self): + def f(x, y): + pass + + def query_func(x): + return Query(x).select_function(f).missing_parameter("y").rename("g") + + self.run_bowler_modifiers( + [ + ("def f(x, z):\n pass", "def g(x, z):\n pass"), + ("def f(x, y, z):\n pass", "def f(x, y, z):\n pass"), + ("f(1, z=2)", "g(1, z=2)"), + ("f(1)", "g(1)"), + ("f(x=2, **a)", "g(x=2, **a)"), + ("f(*a, **a)", "g(*a, **a)"), + ("f(1, 2, z=3)", "f(1, 2, z=3)"), + ("f(y=2)", "f(y=2)"), + ], + query_func=query_func, + ) + def test_filter_in_class(self): def query_func_bar(x): return Query(x).select_function("f").in_class("Bar", False).rename("g") @@ -172,19 +223,92 @@ def query_func_foo_subclasses(x): [("def f(): pass", "def f(): pass")], query_func=query_func_bar ) - def test_add_argument(self): + def test_add_keyword_argument(self): + def f(z, x): + pass + + def def_query_func(x): + return Query(x).select_function(f).is_def().add_argument("y", "5") + + def call_query_func(x): + return Query(x).select_function(f).is_call().add_argument("x", "5") + + def conditional_call_query_func(x): + return ( + Query(x) + .select_function(f) + .is_call() + .missing_parameter("x") + .add_argument("x", "5") + ) + + # Definition kwarg tests + self.run_bowler_modifiers( + [ + ("def f(z, x): pass", "def f(z, x, y=5): pass"), + ("def g(x): pass", "def g(x): pass"), + ], + query_func=def_query_func, + ) + with self.assertRaises(AssertionError): + self.run_bowler_modifier("def f(z, x, y): pass", query_func=def_query_func) + + # Callsite kwarg tests + self.run_bowler_modifiers( + [ + ("f(1)", "f(1, x=5)"), + ("f(z=1)", "f(z=1, x=5)"), + ("f(z=1, **a)", "f(z=1, x=5, **a)"), + ("g()", "g()"), + ], + query_func=call_query_func, + ) + with self.assertRaises(AssertionError): + self.run_bowler_modifier("f(1, 2)", query_func=call_query_func) + + # Conditional callsite kwarg tests + self.run_bowler_modifiers( + [ + ("f(1)", "f(1, x=5)"), + ("f(1, 2)", "f(1, 2)"), + ("f(z=1)", "f(z=1, x=5)"), + ("f(z=1, x=2)", "f(z=1, x=2)"), + ("f(z=1, **a)", "f(z=1, x=5, **a)"), + ("f(z=1, x=2)", "f(z=1, x=2)"), + ("f(z=1, **a)", "f(z=1, x=5, **a)"), + ("g()", "g()"), + ], + query_func=conditional_call_query_func, + ) + + def test_add_positional_argument(self): def query_func(x): - return Query(x).select_function("f").add_argument("y", "5") + return Query(x).select_function("f").add_argument("y", "5", True) self.run_bowler_modifiers( [ - ("def f(x): pass", "def f(x, y=5): pass"), + ("def f(x): pass", "def f(x, y): pass"), ("def g(x): pass", "def g(x): pass"), - # ("f()", "???"), + ("f()", "f(5)"), ("g()", "g()"), ], + query_func=conditional_call_query_func, + ) + + def test_add_positional_argument(self): + def query_func(x): + return Query(x).select_function("f").add_argument("y", "5", True) + + self.run_bowler_modifiers( + [ + ("def f(x): pass", "def f(x, y): pass"), + ("def g(x): pass", "def g(x): pass"), + ("f()", "f(5)"), + ], query_func=query_func, ) + with self.assertRaises(AssertionError): + self.run_bowler_modifier("def f(x, y): pass", query_func=query_func) def test_modifier_return_value(self): input = "a+b" diff --git a/docs/api-modifiers.md b/docs/api-modifiers.md index 63ad96e..7d497d4 100644 --- a/docs/api-modifiers.md +++ b/docs/api-modifiers.md @@ -70,9 +70,7 @@ new_name | New name for element. ### `.add_argument()` -Add an argument to a function or method, as well as callers. For positional arguments, -the default value will be used to update all callers; for keyword arguments, it will -be used in the function definition. +Add an argument to a function or method, as well as callers. Requires use of [`.select_function`](api-selectors#select-function) or [`.select_method`](api-selectors#select-method).