Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix variadic arguments bug #159

Merged
merged 2 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion core/builtin_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value
else:
method = self.func

res.register(self.check_and_populate_args(method.arg_names, args, kwargs, method.defaults, exec_ctx))
res.register(
self.check_and_populate_args(
method.arg_names, args, kwargs, method.defaults, len(method.arg_names), exec_ctx
)
)
if res.should_return():
return res

Expand Down
37 changes: 27 additions & 10 deletions core/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,12 +1051,17 @@ def generate_new_context(self) -> Context:
return new_context

def check_args(
self, arg_names: list[str], args: list[Value], kwargs: dict[str, Value], defaults: list[Optional[Value]]
self,
arg_names: list[str],
args: list[Value],
kwargs: dict[str, Value],
defaults: list[Optional[Value]],
max_pos_args: int,
) -> RTResult[None]:
res = RTResult[None]()

args_count = len(args) + len(kwargs)
if self.va_name is None and args_count > len(arg_names):
if self.va_name is None and (args_count > len(arg_names) or len(args) > max_pos_args):
return res.failure(
RTError(
self.pos_start,
Expand Down Expand Up @@ -1090,18 +1095,24 @@ def check_args(

return res.success(None)

def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
def populate_args(self, arg_names, args, kwargs, defaults, max_pos_args, exec_ctx):
for i, (arg_name, default) in enumerate(zip(arg_names, defaults)):
if default is not None:
exec_ctx.symbol_table.set(arg_name, default)

populated = 0
for i in range(len(arg_names)):
arg_name = arg_names[i]
if arg_name in kwargs:
if i >= max_pos_args or arg_name in kwargs or i >= len(args):
continue
arg_value = defaults[i] if i >= len(args) else args[i]
arg_value = args[i]
arg_value.set_context(exec_ctx)
exec_ctx.symbol_table.set(arg_name, arg_value)
populated += 1

if self.va_name is not None:
va_list = []
for i in range(len(arg_names), len(args)):
for i in range(populated, len(args)):
arg = args[i]
arg.set_context(exec_ctx)
va_list.append(arg)
Expand All @@ -1111,12 +1122,12 @@ def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
kwarg.set_context(exec_ctx)
exec_ctx.symbol_table.set(kw, kwarg)

def check_and_populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
def check_and_populate_args(self, arg_names, args, kwargs, defaults, max_pos_args, exec_ctx):
res = RTResult()
res.register(self.check_args(arg_names, args, kwargs, defaults))
res.register(self.check_args(arg_names, args, kwargs, defaults, max_pos_args))
if res.should_return():
return res
self.populate_args(arg_names, args, kwargs, defaults, exec_ctx)
self.populate_args(arg_names, args, kwargs, defaults, max_pos_args, exec_ctx)
return res.success(None)


Expand Down Expand Up @@ -1359,6 +1370,7 @@ class Function(BaseFunction):
arg_names: list[str]
defaults: list[Optional[Value]]
should_auto_return: bool
max_pos_args: int

def __help_repr__(self) -> str:
return f"Help on function {self.name}:\n\n{self.__help_repr_method__()}"
Expand All @@ -1382,6 +1394,7 @@ def __init__(
should_auto_return: bool,
desc: str,
va_name: Optional[str],
max_pos_args: int,
) -> None:
super().__init__(name, symbol_table)
self.body_node = body_node
Expand All @@ -1390,6 +1403,7 @@ def __init__(
self.should_auto_return = should_auto_return
self.desc = desc
self.va_name = va_name
self.max_pos_args = max_pos_args

def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]:
from core.interpreter import Interpreter # Lazy import
Expand All @@ -1398,7 +1412,9 @@ def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value
interpreter = Interpreter()
exec_ctx = self.generate_new_context()

res.register(self.check_and_populate_args(self.arg_names, args, kwargs, self.defaults, exec_ctx))
res.register(
self.check_and_populate_args(self.arg_names, args, kwargs, self.defaults, self.max_pos_args, exec_ctx)
)
if res.should_return():
return res

Expand All @@ -1424,6 +1440,7 @@ def copy(self) -> Function:
self.should_auto_return,
self.desc,
self.va_name,
self.max_pos_args,
)
copy.set_context(self.context)
copy.set_pos(self.pos_start, self.pos_end)
Expand Down
1 change: 1 addition & 0 deletions core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def visit_FuncDefNode(self, node: FuncDefNode, context: Context) -> RTResult[Val
node.should_auto_return,
func_desc,
va_name=node.va_name,
max_pos_args=node.max_pos_args,
)
.set_context(context)
.set_pos(node.pos_start, node.pos_end)
Expand Down
1 change: 1 addition & 0 deletions core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class FuncDefNode:
static: bool
desc: str
va_name: Optional[str]
max_pos_args: int

pos_start: Position
pos_end: Position
Expand Down
8 changes: 8 additions & 0 deletions core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,7 @@ def func_def(self) -> ParseResult[Node]:
defaults: list[Optional[Node]] = []
has_optionals = False
is_va = False
max_pos_args = 0
va_name: Optional[str] = None

if self.current_tok.type == TT_SPREAD:
Expand All @@ -1248,6 +1249,8 @@ def func_def(self) -> ParseResult[Node]:
self.advance(res)
if not is_va:
arg_name_toks.append(arg_name_tok)
if va_name is None:
max_pos_args += 1

if is_va:
va_name = arg_name_tok.value
Expand Down Expand Up @@ -1284,6 +1287,9 @@ def func_def(self) -> ParseResult[Node]:
assert isinstance(arg_name_tok.value, str)
if not is_va:
arg_name_toks.append(arg_name_tok)
if va_name is None:
max_pos_args += 1

self.advance(res)

if is_va:
Expand Down Expand Up @@ -1335,6 +1341,7 @@ def func_def(self) -> ParseResult[Node]:
static=static,
desc="[No Description]",
va_name=va_name,
max_pos_args=max_pos_args,
pos_start=node_pos_start,
pos_end=self.current_tok.pos_end,
)
Expand Down Expand Up @@ -1375,6 +1382,7 @@ def func_def(self) -> ParseResult[Node]:
static=static,
desc=desc,
va_name=va_name,
max_pos_args=max_pos_args,
pos_start=node_pos_start,
pos_end=self.current_tok.pos_end,
)
Expand Down
2 changes: 1 addition & 1 deletion core/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Position:
ftxt: str

def __str__(self) -> str:
return f"{self.fn}:{self.ln}:{self.col}"
return f"{self.fn}:{self.ln+1}:{self.col+1}"

def advance(self, current_char: Optional[str] = None) -> Position:
self.idx += 1
Expand Down
5 changes: 5 additions & 0 deletions tests/va_bug.rn
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fun f(...args, kw=null) -> print(kw)

f(1, 2, 3)
f(1, 2, 3, kw=4, 5, 6)
angelcaru marked this conversation as resolved.
Show resolved Hide resolved

1 change: 1 addition & 0 deletions tests/va_bug.rn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"code": 0, "stdout": "null\n4\n", "stderr": ""}