diff --git a/core/builtin_funcs.py b/core/builtin_funcs.py index 27385cb..afbb345 100755 --- a/core/builtin_funcs.py +++ b/core/builtin_funcs.py @@ -66,7 +66,11 @@ def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value else: method = self.func - res.register(self.check_and_populate_args(method.arg_names, args, kwargs, method.defaults, exec_ctx)) + res.register( + self.check_and_populate_args( + method.arg_names, args, kwargs, method.defaults, len(method.arg_names), exec_ctx + ) + ) if res.should_return(): return res diff --git a/core/datatypes.py b/core/datatypes.py index d266f0e..749c7a4 100755 --- a/core/datatypes.py +++ b/core/datatypes.py @@ -1051,12 +1051,17 @@ def generate_new_context(self) -> Context: return new_context def check_args( - self, arg_names: list[str], args: list[Value], kwargs: dict[str, Value], defaults: list[Optional[Value]] + self, + arg_names: list[str], + args: list[Value], + kwargs: dict[str, Value], + defaults: list[Optional[Value]], + max_pos_args: int, ) -> RTResult[None]: res = RTResult[None]() args_count = len(args) + len(kwargs) - if self.va_name is None and args_count > len(arg_names): + if self.va_name is None and (args_count > len(arg_names) or len(args) > max_pos_args): return res.failure( RTError( self.pos_start, @@ -1090,18 +1095,24 @@ def check_args( return res.success(None) - def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx): + def populate_args(self, arg_names, args, kwargs, defaults, max_pos_args, exec_ctx): + for i, (arg_name, default) in enumerate(zip(arg_names, defaults)): + if default is not None: + exec_ctx.symbol_table.set(arg_name, default) + + populated = 0 for i in range(len(arg_names)): arg_name = arg_names[i] - if arg_name in kwargs: + if i >= max_pos_args or arg_name in kwargs or i >= len(args): continue - arg_value = defaults[i] if i >= len(args) else args[i] + arg_value = args[i] arg_value.set_context(exec_ctx) exec_ctx.symbol_table.set(arg_name, arg_value) + populated += 1 if self.va_name is not None: va_list = [] - for i in range(len(arg_names), len(args)): + for i in range(populated, len(args)): arg = args[i] arg.set_context(exec_ctx) va_list.append(arg) @@ -1111,12 +1122,12 @@ def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx): kwarg.set_context(exec_ctx) exec_ctx.symbol_table.set(kw, kwarg) - def check_and_populate_args(self, arg_names, args, kwargs, defaults, exec_ctx): + def check_and_populate_args(self, arg_names, args, kwargs, defaults, max_pos_args, exec_ctx): res = RTResult() - res.register(self.check_args(arg_names, args, kwargs, defaults)) + res.register(self.check_args(arg_names, args, kwargs, defaults, max_pos_args)) if res.should_return(): return res - self.populate_args(arg_names, args, kwargs, defaults, exec_ctx) + self.populate_args(arg_names, args, kwargs, defaults, max_pos_args, exec_ctx) return res.success(None) @@ -1359,6 +1370,7 @@ class Function(BaseFunction): arg_names: list[str] defaults: list[Optional[Value]] should_auto_return: bool + max_pos_args: int def __help_repr__(self) -> str: return f"Help on function {self.name}:\n\n{self.__help_repr_method__()}" @@ -1382,6 +1394,7 @@ def __init__( should_auto_return: bool, desc: str, va_name: Optional[str], + max_pos_args: int, ) -> None: super().__init__(name, symbol_table) self.body_node = body_node @@ -1390,6 +1403,7 @@ def __init__( self.should_auto_return = should_auto_return self.desc = desc self.va_name = va_name + self.max_pos_args = max_pos_args def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]: from core.interpreter import Interpreter # Lazy import @@ -1398,7 +1412,9 @@ def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value interpreter = Interpreter() exec_ctx = self.generate_new_context() - res.register(self.check_and_populate_args(self.arg_names, args, kwargs, self.defaults, exec_ctx)) + res.register( + self.check_and_populate_args(self.arg_names, args, kwargs, self.defaults, self.max_pos_args, exec_ctx) + ) if res.should_return(): return res @@ -1424,6 +1440,7 @@ def copy(self) -> Function: self.should_auto_return, self.desc, self.va_name, + self.max_pos_args, ) copy.set_context(self.context) copy.set_pos(self.pos_start, self.pos_end) diff --git a/core/interpreter.py b/core/interpreter.py index f09f157..94b6c52 100755 --- a/core/interpreter.py +++ b/core/interpreter.py @@ -517,6 +517,7 @@ def visit_FuncDefNode(self, node: FuncDefNode, context: Context) -> RTResult[Val node.should_auto_return, func_desc, va_name=node.va_name, + max_pos_args=node.max_pos_args, ) .set_context(context) .set_pos(node.pos_start, node.pos_end) diff --git a/core/nodes.py b/core/nodes.py index 7bf6959..ebb37b3 100755 --- a/core/nodes.py +++ b/core/nodes.py @@ -239,6 +239,7 @@ class FuncDefNode: static: bool desc: str va_name: Optional[str] + max_pos_args: int pos_start: Position pos_end: Position diff --git a/core/parser.py b/core/parser.py index e41b290..6c36edd 100755 --- a/core/parser.py +++ b/core/parser.py @@ -568,8 +568,14 @@ def call(self) -> ParseResult[Node]: return res assert pair is not None kw, val = pair - if kw is None: + if kw is None and len(kwarg_nodes) == 0: arg_nodes.append(val) + elif kw is None: + return res.failure( + InvalidSyntaxError( + val.pos_start, val.pos_end, "Positional arguments may not come after keyword arguments" + ) + ) else: kwarg_nodes[kw] = val @@ -1233,6 +1239,7 @@ def func_def(self) -> ParseResult[Node]: defaults: list[Optional[Node]] = [] has_optionals = False is_va = False + max_pos_args = 0 va_name: Optional[str] = None if self.current_tok.type == TT_SPREAD: @@ -1248,6 +1255,8 @@ def func_def(self) -> ParseResult[Node]: self.advance(res) if not is_va: arg_name_toks.append(arg_name_tok) + if va_name is None: + max_pos_args += 1 if is_va: va_name = arg_name_tok.value @@ -1284,6 +1293,9 @@ def func_def(self) -> ParseResult[Node]: assert isinstance(arg_name_tok.value, str) if not is_va: arg_name_toks.append(arg_name_tok) + if va_name is None: + max_pos_args += 1 + self.advance(res) if is_va: @@ -1335,6 +1347,7 @@ def func_def(self) -> ParseResult[Node]: static=static, desc="[No Description]", va_name=va_name, + max_pos_args=max_pos_args, pos_start=node_pos_start, pos_end=self.current_tok.pos_end, ) @@ -1375,6 +1388,7 @@ def func_def(self) -> ParseResult[Node]: static=static, desc=desc, va_name=va_name, + max_pos_args=max_pos_args, pos_start=node_pos_start, pos_end=self.current_tok.pos_end, ) diff --git a/core/tokens.py b/core/tokens.py index 5d89863..27fcde9 100755 --- a/core/tokens.py +++ b/core/tokens.py @@ -41,7 +41,7 @@ class Position: ftxt: str def __str__(self) -> str: - return f"{self.fn}:{self.ln}:{self.col}" + return f"{self.fn}:{self.ln+1}:{self.col+1}" def advance(self, current_char: Optional[str] = None) -> Position: self.idx += 1 diff --git a/tests/va_bug.rn b/tests/va_bug.rn new file mode 100644 index 0000000..6f25f17 --- /dev/null +++ b/tests/va_bug.rn @@ -0,0 +1,5 @@ +fun f(...args, kw=null) -> print(kw) + +f(1, 2, 3) +f(1, 2, 3, kw=4) + diff --git a/tests/va_bug.rn.json b/tests/va_bug.rn.json new file mode 100644 index 0000000..cd5f481 --- /dev/null +++ b/tests/va_bug.rn.json @@ -0,0 +1 @@ +{"code": 0, "stdout": "null\n4\n", "stderr": ""} \ No newline at end of file