Skip to content

Commit

Permalink
Merge pull request #159 from angelcaru/va_bug
Browse files Browse the repository at this point in the history
Fix variadic arguments bug #156
  • Loading branch information
Almas-Ali authored Jun 1, 2024
2 parents e7287e9 + 8faf837 commit fa4ffc0
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 13 deletions.
6 changes: 5 additions & 1 deletion core/builtin_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value
else:
method = self.func

res.register(self.check_and_populate_args(method.arg_names, args, kwargs, method.defaults, exec_ctx))
res.register(
self.check_and_populate_args(
method.arg_names, args, kwargs, method.defaults, len(method.arg_names), exec_ctx
)
)
if res.should_return():
return res

Expand Down
37 changes: 27 additions & 10 deletions core/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,12 +1051,17 @@ def generate_new_context(self) -> Context:
return new_context

def check_args(
self, arg_names: list[str], args: list[Value], kwargs: dict[str, Value], defaults: list[Optional[Value]]
self,
arg_names: list[str],
args: list[Value],
kwargs: dict[str, Value],
defaults: list[Optional[Value]],
max_pos_args: int,
) -> RTResult[None]:
res = RTResult[None]()

args_count = len(args) + len(kwargs)
if self.va_name is None and args_count > len(arg_names):
if self.va_name is None and (args_count > len(arg_names) or len(args) > max_pos_args):
return res.failure(
RTError(
self.pos_start,
Expand Down Expand Up @@ -1090,18 +1095,24 @@ def check_args(

return res.success(None)

def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
def populate_args(self, arg_names, args, kwargs, defaults, max_pos_args, exec_ctx):
for i, (arg_name, default) in enumerate(zip(arg_names, defaults)):
if default is not None:
exec_ctx.symbol_table.set(arg_name, default)

populated = 0
for i in range(len(arg_names)):
arg_name = arg_names[i]
if arg_name in kwargs:
if i >= max_pos_args or arg_name in kwargs or i >= len(args):
continue
arg_value = defaults[i] if i >= len(args) else args[i]
arg_value = args[i]
arg_value.set_context(exec_ctx)
exec_ctx.symbol_table.set(arg_name, arg_value)
populated += 1

if self.va_name is not None:
va_list = []
for i in range(len(arg_names), len(args)):
for i in range(populated, len(args)):
arg = args[i]
arg.set_context(exec_ctx)
va_list.append(arg)
Expand All @@ -1111,12 +1122,12 @@ def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
kwarg.set_context(exec_ctx)
exec_ctx.symbol_table.set(kw, kwarg)

def check_and_populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
def check_and_populate_args(self, arg_names, args, kwargs, defaults, max_pos_args, exec_ctx):
res = RTResult()
res.register(self.check_args(arg_names, args, kwargs, defaults))
res.register(self.check_args(arg_names, args, kwargs, defaults, max_pos_args))
if res.should_return():
return res
self.populate_args(arg_names, args, kwargs, defaults, exec_ctx)
self.populate_args(arg_names, args, kwargs, defaults, max_pos_args, exec_ctx)
return res.success(None)


Expand Down Expand Up @@ -1359,6 +1370,7 @@ class Function(BaseFunction):
arg_names: list[str]
defaults: list[Optional[Value]]
should_auto_return: bool
max_pos_args: int

def __help_repr__(self) -> str:
return f"Help on function {self.name}:\n\n{self.__help_repr_method__()}"
Expand All @@ -1382,6 +1394,7 @@ def __init__(
should_auto_return: bool,
desc: str,
va_name: Optional[str],
max_pos_args: int,
) -> None:
super().__init__(name, symbol_table)
self.body_node = body_node
Expand All @@ -1390,6 +1403,7 @@ def __init__(
self.should_auto_return = should_auto_return
self.desc = desc
self.va_name = va_name
self.max_pos_args = max_pos_args

def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]:
from core.interpreter import Interpreter # Lazy import
Expand All @@ -1398,7 +1412,9 @@ def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value
interpreter = Interpreter()
exec_ctx = self.generate_new_context()

res.register(self.check_and_populate_args(self.arg_names, args, kwargs, self.defaults, exec_ctx))
res.register(
self.check_and_populate_args(self.arg_names, args, kwargs, self.defaults, self.max_pos_args, exec_ctx)
)
if res.should_return():
return res

Expand All @@ -1424,6 +1440,7 @@ def copy(self) -> Function:
self.should_auto_return,
self.desc,
self.va_name,
self.max_pos_args,
)
copy.set_context(self.context)
copy.set_pos(self.pos_start, self.pos_end)
Expand Down
1 change: 1 addition & 0 deletions core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def visit_FuncDefNode(self, node: FuncDefNode, context: Context) -> RTResult[Val
node.should_auto_return,
func_desc,
va_name=node.va_name,
max_pos_args=node.max_pos_args,
)
.set_context(context)
.set_pos(node.pos_start, node.pos_end)
Expand Down
1 change: 1 addition & 0 deletions core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class FuncDefNode:
static: bool
desc: str
va_name: Optional[str]
max_pos_args: int

pos_start: Position
pos_end: Position
Expand Down
16 changes: 15 additions & 1 deletion core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,14 @@ def call(self) -> ParseResult[Node]:
return res
assert pair is not None
kw, val = pair
if kw is None:
if kw is None and len(kwarg_nodes) == 0:
arg_nodes.append(val)
elif kw is None:
return res.failure(
InvalidSyntaxError(
val.pos_start, val.pos_end, "Positional arguments may not come after keyword arguments"
)
)
else:
kwarg_nodes[kw] = val

Expand Down Expand Up @@ -1233,6 +1239,7 @@ def func_def(self) -> ParseResult[Node]:
defaults: list[Optional[Node]] = []
has_optionals = False
is_va = False
max_pos_args = 0
va_name: Optional[str] = None

if self.current_tok.type == TT_SPREAD:
Expand All @@ -1248,6 +1255,8 @@ def func_def(self) -> ParseResult[Node]:
self.advance(res)
if not is_va:
arg_name_toks.append(arg_name_tok)
if va_name is None:
max_pos_args += 1

if is_va:
va_name = arg_name_tok.value
Expand Down Expand Up @@ -1284,6 +1293,9 @@ def func_def(self) -> ParseResult[Node]:
assert isinstance(arg_name_tok.value, str)
if not is_va:
arg_name_toks.append(arg_name_tok)
if va_name is None:
max_pos_args += 1

self.advance(res)

if is_va:
Expand Down Expand Up @@ -1335,6 +1347,7 @@ def func_def(self) -> ParseResult[Node]:
static=static,
desc="[No Description]",
va_name=va_name,
max_pos_args=max_pos_args,
pos_start=node_pos_start,
pos_end=self.current_tok.pos_end,
)
Expand Down Expand Up @@ -1375,6 +1388,7 @@ def func_def(self) -> ParseResult[Node]:
static=static,
desc=desc,
va_name=va_name,
max_pos_args=max_pos_args,
pos_start=node_pos_start,
pos_end=self.current_tok.pos_end,
)
Expand Down
2 changes: 1 addition & 1 deletion core/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Position:
ftxt: str

def __str__(self) -> str:
return f"{self.fn}:{self.ln}:{self.col}"
return f"{self.fn}:{self.ln+1}:{self.col+1}"

def advance(self, current_char: Optional[str] = None) -> Position:
self.idx += 1
Expand Down
5 changes: 5 additions & 0 deletions tests/va_bug.rn
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fun f(...args, kw=null) -> print(kw)

f(1, 2, 3)
f(1, 2, 3, kw=4)

1 change: 1 addition & 0 deletions tests/va_bug.rn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"code": 0, "stdout": "null\n4\n", "stderr": ""}

0 comments on commit fa4ffc0

Please sign in to comment.