Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.add_argument() should also add kwargs to callsites #95

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion bowler/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import List, Optional, Sequence, Union
from inspect import Signature
from typing import TYPE_CHECKING, List, Optional, Sequence, Union

import click
from fissix.pgen2.token import tok_name
from fissix.pytree import Leaf, Node, type_repr

from .types import LN, SYMBOL, TOKEN, Capture, Filename, FilenameMatcher

if TYPE_CHECKING:
from .imr import FunctionSpec

log = logging.getLogger(__name__)

INDENT_STR = ". "
Expand Down Expand Up @@ -144,6 +148,38 @@ def is_call_to(node: LN, func_name: str) -> bool:
)


def spec_contains_parameter_name(param_name: str, spec: "FunctionSpec") -> bool:
return param_name in (arg.name for arg in spec.arguments)


def definition_contains_parameter(param_name: str, spec: "FunctionSpec") -> bool:
return spec_contains_parameter_name(param_name, spec)


def callsite_contains_parameter(
param_name: str, spec: "FunctionSpec", function_sig: Signature
) -> bool:
# Handle kwargs
if spec_contains_parameter_name(param_name, spec):
return True

# Handle positional args
if param_name not in function_sig.parameters:
raise ValueError(
f"Function signature must contain parameter we are looking for"
)

names = list(function_sig.parameters)
position = names.index(param_name)
if names[0] in ("self", "cls", "meta"):
position -= 1

if len(spec.arguments) <= position:
return False

return all(not arg.name and not arg.star for arg in spec.arguments[: position + 1])


def find_first(node: LN, target: int, recursive: bool = False) -> Optional[LN]:
queue: List[LN] = [node]
queue.extend(node.children)
Expand Down
73 changes: 72 additions & 1 deletion bowler/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import pathlib
import re
from functools import wraps
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, cast

from attr import Factory, dataclass
Expand All @@ -19,6 +19,8 @@

from .helpers import (
Once,
callsite_contains_parameter,
definition_contains_parameter,
dotted_parts,
find_first,
find_last,
Expand Down Expand Up @@ -438,6 +440,42 @@ def filter_is_call(node: LN, capture: Capture, filename: Filename) -> bool:
self.current.filters.append(filter_is_call)
return self

def filter_contains_parameter(
self, param_name: str, node: LN, capture: Capture, filename: Filename
) -> bool:
transform = self.current
if transform.selector not in ("function", "method"):
raise ValueError(
"parameter presence filter must follow select_function/select_method"
)
if "source" not in transform.kwargs:
raise ValueError(
"Determining if parameter is present requires passing original function"
)

if "function_def" in capture or "class_def" in capture:
spec = FunctionSpec.build(node, capture)
return definition_contains_parameter(param_name, spec)
if "function_call" in capture or "class_call" in capture:
spec = FunctionSpec.build(node, capture)
sig = inspect.signature(transform.kwargs["source"])
return callsite_contains_parameter(param_name, spec, sig)

return False

def contains_parameter(self, name: str) -> "Query":
self.current.filters.append(partial(self.filter_contains_parameter, name))
return self

def missing_parameter(self, name: str) -> "Query":
def filter_missing_parameter(
node: LN, capture: Capture, filename: Filename
) -> bool:
return not self.filter_contains_parameter(name, node, capture, filename)

self.current.filters.append(filter_missing_parameter)
return self

def is_def(self) -> "Query":
def filter_is_def(node: LN, capture: Capture, filename: Filename) -> bool:
return bool("function_def" in capture or "class_def" in capture)
Expand Down Expand Up @@ -748,6 +786,10 @@ def add_argument_transform(
value_leaf = Name(value)

if spec.is_def:
if definition_contains_parameter(name, spec):
raise ValueError(
f"{name} is already present in the definition of {spec.name}"
)
new_arg = FunctionArgument(
name,
value_leaf if keyword else None,
Expand All @@ -772,6 +814,12 @@ def add_argument_transform(
spec.arguments.append(new_arg)

elif positional:
if after not in (SENTINEL, START):
sig = inspect.signature(transform.kwargs["source"])
if callsite_contains_parameter(name, spec, sig):
raise ValueError(
f"{name} is already present in the callsite of {spec.name}"
)
new_arg = FunctionArgument(value=value_leaf)
for index, argument in enumerate(spec.arguments):
if argument.star and argument.star.type == TOKEN.STAR:
Expand All @@ -792,6 +840,29 @@ def add_argument_transform(
if not done:
spec.arguments.append(new_arg)

else:
if "source" not in transform.kwargs:
raise ValueError(
"Adding a positional arg to a callsite requires passing "
"original function"
)

sig = inspect.signature(transform.kwargs["source"])
if callsite_contains_parameter(name, spec, sig):
raise ValueError(
f"{name} is already present in the callsite of {spec.name}"
)

# Drop kwarg at end of call to avoid breaking positional args
new_arg = FunctionArgument(name=name, value=value_leaf)

# Double star should only ever be present at the end of a call
star_arg = spec.arguments[-1].star if spec.arguments else None
if star_arg and star_arg.type == TOKEN.DOUBLESTAR:
spec.arguments.insert(-1, new_arg)
else:
spec.arguments.append(new_arg)

spec.explode()

transform.callbacks.append(add_argument_transform)
Expand Down
132 changes: 128 additions & 4 deletions bowler/tests/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,57 @@ def query_func(x):
query_func=query_func,
)

def test_filter_contains_parameter(self):
def f(x, y):
pass

def query_func(x):
return Query(x).select_function(f).contains_parameter("y").rename("g")

self.run_bowler_modifiers(
[
("def f(x, y):\n pass", "def g(x, y):\n pass"),
("def f(x, y, z):\n pass", "def g(x, y, z):\n pass"),
("def f(x):\n pass", "def f(x):\n pass"),
("def x(y):\n pass", "def x(y):\n pass"),
("def f(*_):\n pass", "def f(*_):\n pass"),
("def f(**_):\n pass", "def f(**_):\n pass"),
("f(1, y=2)", "g(1, y=2)"),
("f(1, 2)", "g(1, 2)"),
("f(1, 2, z=3)", "g(1, 2, z=3)"),
("f(y=2)", "g(y=2)"),
("f(x=2, **a)", "f(x=2, **a)"),
("f(*a, **a)", "f(*a, **a)"),
("f(*_)", "f(*_)"),
("f(**_)", "f(**_)"),
("f(x=1)", "f(x=1)"),
("f(1)", "f(1)"),
("f(1, x=2)", "f(1, x=2)"),
],
query_func=query_func,
)

def test_filter_missing_parameter(self):
def f(x, y):
pass

def query_func(x):
return Query(x).select_function(f).missing_parameter("y").rename("g")

self.run_bowler_modifiers(
[
("def f(x, z):\n pass", "def g(x, z):\n pass"),
("def f(x, y, z):\n pass", "def f(x, y, z):\n pass"),
("f(1, z=2)", "g(1, z=2)"),
("f(1)", "g(1)"),
("f(x=2, **a)", "g(x=2, **a)"),
("f(*a, **a)", "g(*a, **a)"),
("f(1, 2, z=3)", "f(1, 2, z=3)"),
("f(y=2)", "f(y=2)"),
],
query_func=query_func,
)

def test_filter_in_class(self):
def query_func_bar(x):
return Query(x).select_function("f").in_class("Bar", False).rename("g")
Expand Down Expand Up @@ -172,19 +223,92 @@ def query_func_foo_subclasses(x):
[("def f(): pass", "def f(): pass")], query_func=query_func_bar
)

def test_add_argument(self):
def test_add_keyword_argument(self):
def f(z, x):
pass

def def_query_func(x):
return Query(x).select_function(f).is_def().add_argument("y", "5")

def call_query_func(x):
return Query(x).select_function(f).is_call().add_argument("x", "5")

def conditional_call_query_func(x):
return (
Query(x)
.select_function(f)
.is_call()
.missing_parameter("x")
.add_argument("x", "5")
)

# Definition kwarg tests
self.run_bowler_modifiers(
[
("def f(z, x): pass", "def f(z, x, y=5): pass"),
("def g(x): pass", "def g(x): pass"),
],
query_func=def_query_func,
)
with self.assertRaises(AssertionError):
self.run_bowler_modifier("def f(z, x, y): pass", query_func=def_query_func)

# Callsite kwarg tests
self.run_bowler_modifiers(
[
("f(1)", "f(1, x=5)"),
("f(z=1)", "f(z=1, x=5)"),
("f(z=1, **a)", "f(z=1, x=5, **a)"),
("g()", "g()"),
],
query_func=call_query_func,
)
with self.assertRaises(AssertionError):
self.run_bowler_modifier("f(1, 2)", query_func=call_query_func)

# Conditional callsite kwarg tests
self.run_bowler_modifiers(
[
("f(1)", "f(1, x=5)"),
("f(1, 2)", "f(1, 2)"),
("f(z=1)", "f(z=1, x=5)"),
("f(z=1, x=2)", "f(z=1, x=2)"),
("f(z=1, **a)", "f(z=1, x=5, **a)"),
("f(z=1, x=2)", "f(z=1, x=2)"),
("f(z=1, **a)", "f(z=1, x=5, **a)"),
("g()", "g()"),
],
query_func=conditional_call_query_func,
)

def test_add_positional_argument(self):
def query_func(x):
return Query(x).select_function("f").add_argument("y", "5")
return Query(x).select_function("f").add_argument("y", "5", True)

self.run_bowler_modifiers(
[
("def f(x): pass", "def f(x, y=5): pass"),
("def f(x): pass", "def f(x, y): pass"),
("def g(x): pass", "def g(x): pass"),
# ("f()", "???"),
("f()", "f(5)"),
("g()", "g()"),
],
query_func=conditional_call_query_func,
)

def test_add_positional_argument(self):
def query_func(x):
return Query(x).select_function("f").add_argument("y", "5", True)

self.run_bowler_modifiers(
[
("def f(x): pass", "def f(x, y): pass"),
("def g(x): pass", "def g(x): pass"),
("f()", "f(5)"),
],
query_func=query_func,
)
with self.assertRaises(AssertionError):
self.run_bowler_modifier("def f(x, y): pass", query_func=query_func)

def test_modifier_return_value(self):
input = "a+b"
Expand Down
4 changes: 1 addition & 3 deletions docs/api-modifiers.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ new_name | New name for element.

### `.add_argument()`

Add an argument to a function or method, as well as callers. For positional arguments,
the default value will be used to update all callers; for keyword arguments, it will
be used in the function definition.
Add an argument to a function or method, as well as callers.

Requires use of [`.select_function`](api-selectors#select-function) or
[`.select_method`](api-selectors#select-method).
Expand Down