Skip to content

Commit

Permalink
Merge pull request #153 from angelcaru/varargs
Browse files Browse the repository at this point in the history
Add variadic arguments from #152
  • Loading branch information
Almas-Ali authored May 31, 2024
2 parents a0ac8b8 + 67c3e1c commit 236dd90
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 42 deletions.
1 change: 1 addition & 0 deletions core/builtin_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class BuiltInFunction(BaseFunction):
def __init__(self, name: str, func: Optional[RadonCompatibleFunction] = None):
super().__init__(name, None)
self.func = func
self.va_name = None

def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]:
res = RTResult[Value]()
Expand Down
14 changes: 13 additions & 1 deletion core/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,7 @@ class BaseFunction(Value):
symbol_table: Optional[SymbolTable]
desc: str
arg_names: list[str]
va_name: Optional[str]

def __init__(self, name: Optional[str], symbol_table: Optional[SymbolTable]) -> None:
super().__init__()
Expand All @@ -1055,7 +1056,7 @@ def check_args(
res = RTResult[None]()

args_count = len(args) + len(kwargs)
if args_count > len(arg_names):
if self.va_name is None and args_count > len(arg_names):
return res.failure(
RTError(
self.pos_start,
Expand Down Expand Up @@ -1098,6 +1099,14 @@ def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx):
arg_value.set_context(exec_ctx)
exec_ctx.symbol_table.set(arg_name, arg_value)

if self.va_name is not None:
va_list = []
for i in range(len(arg_names), len(args)):
arg = args[i]
arg.set_context(exec_ctx)
va_list.append(arg)
exec_ctx.symbol_table.set(self.va_name, Array(va_list))

for kw, kwarg in kwargs.items():
kwarg.set_context(exec_ctx)
exec_ctx.symbol_table.set(kw, kwarg)
Expand Down Expand Up @@ -1372,13 +1381,15 @@ def __init__(
defaults: list[Optional[Value]],
should_auto_return: bool,
desc: str,
va_name: Optional[str],
) -> None:
super().__init__(name, symbol_table)
self.body_node = body_node
self.arg_names = arg_names
self.defaults = defaults
self.should_auto_return = should_auto_return
self.desc = desc
self.va_name = va_name

def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]:
from core.interpreter import Interpreter # Lazy import
Expand Down Expand Up @@ -1412,6 +1423,7 @@ def copy(self) -> Function:
self.defaults,
self.should_auto_return,
self.desc,
self.va_name,
)
copy.set_context(self.context)
copy.set_pos(self.pos_start, self.pos_end)
Expand Down
9 changes: 8 additions & 1 deletion core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,14 @@ def visit_FuncDefNode(self, node: FuncDefNode, context: Context) -> RTResult[Val

func_value = (
Function(
func_name, context.symbol_table, body_node, arg_names, defaults, node.should_auto_return, func_desc
func_name,
context.symbol_table,
body_node,
arg_names,
defaults,
node.should_auto_return,
func_desc,
va_name=node.va_name,
)
.set_context(context)
.set_pos(node.pos_start, node.pos_end)
Expand Down
15 changes: 13 additions & 2 deletions core/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def make_tokens(self) -> tuple[list[Token], Optional[Error]]:
tokens.append(Token(TT_COLON, pos_start=self.pos))
self.advance()
elif self.current_char == ".":
tokens.append(Token(TT_DOT, pos_start=self.pos))
self.advance()
tokens.append(self.make_dot())
elif self.current_char == "!":
token, error = self.make_not_equals()
if error is not None:
Expand Down Expand Up @@ -273,6 +272,18 @@ def make_power_equals(self) -> Token:

return Token(tok_type, pos_start=pos_start, pos_end=self.pos)

def make_dot(self) -> Token:
tok_type = TT_DOT
pos_start = self.pos.copy()
self.advance()

if self.text[self.pos.idx :].startswith(".."):
self.advance()
self.advance()
tok_type = TT_SPREAD

return Token(tok_type, pos_start=pos_start, pos_end=self.pos.copy())

def skip_comment(self) -> None:
multi_line = False
self.advance()
Expand Down
29 changes: 2 additions & 27 deletions core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def __init__(self, condition_node: Node, body_node: Node, should_return_null: bo
self.pos_end = self.body_node.pos_end


@dataclass
class FuncDefNode:
var_name_tok: Optional[Token]
arg_name_toks: list[Token]
Expand All @@ -237,37 +238,11 @@ class FuncDefNode:
should_auto_return: bool
static: bool
desc: str
va_name: Optional[str]

pos_start: Position
pos_end: Position

def __init__(
self,
var_name_tok: Optional[Token],
arg_name_toks: list[Token],
defaults: list[Optional[Node]],
body_node: Node,
should_auto_return: bool,
static: bool = False,
desc: str = "",
) -> None:
self.var_name_tok = var_name_tok
self.arg_name_toks = arg_name_toks
self.defaults = defaults
self.body_node = body_node
self.should_auto_return = should_auto_return
self.static = static
self.desc = desc

if self.var_name_tok:
self.pos_start = self.var_name_tok.pos_start
elif len(self.arg_name_toks) > 0:
self.pos_start = self.arg_name_toks[0].pos_start
else:
self.pos_start = self.body_node.pos_start

self.pos_end = self.body_node.pos_end


class CallNode:
node_to_call: Node
Expand Down
71 changes: 60 additions & 11 deletions core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,8 @@ def class_node(self) -> ParseResult[Node]:
def func_def(self) -> ParseResult[Node]:
res = ParseResult[Node]()

node_pos_start = self.current_tok.pos_start

static = False
if self.current_tok.matches(TT_KEYWORD, "static"):
self.advance(res)
Expand Down Expand Up @@ -1229,50 +1231,73 @@ def func_def(self) -> ParseResult[Node]:
self.advance(res)
arg_name_toks = []
defaults: list[Optional[Node]] = []
hasOptionals = False
has_optionals = False
is_va = False
va_name: Optional[str] = None

if self.current_tok.type == TT_SPREAD:
is_va = True
self.advance(res)

if self.current_tok.type == TT_IDENTIFIER:
pos_start = self.current_tok.pos_start.copy()
pos_end = self.current_tok.pos_end.copy()

arg_name_toks.append(self.current_tok)
arg_name_tok = self.current_tok
assert isinstance(arg_name_tok.value, str)
self.advance(res)
if not is_va:
arg_name_toks.append(arg_name_tok)

if self.current_tok.type == TT_EQ:
if is_va:
va_name = arg_name_tok.value
is_va = False
elif self.current_tok.type == TT_EQ:
self.advance(res)
default = res.register(self.expr())
if res.error:
return res
assert default is not None
defaults.append(default)
hasOptionals = True
elif hasOptionals:
has_optionals = True
elif has_optionals:
return res.failure(InvalidSyntaxError(pos_start, pos_end, "Expected optional parameter."))
else:
defaults.append(None)

while self.current_tok.type == TT_COMMA:
self.advance(res)

if self.current_tok.type == TT_SPREAD:
is_va = True
self.advance(res)

if self.current_tok.type != TT_IDENTIFIER:
return res.failure(
InvalidSyntaxError(self.current_tok.pos_start, self.current_tok.pos_end, "Expected identifier")
)

pos_start = self.current_tok.pos_start.copy()
pos_end = self.current_tok.pos_end.copy()
arg_name_toks.append(self.current_tok)

arg_name_tok = self.current_tok
assert isinstance(arg_name_tok.value, str)
if not is_va:
arg_name_toks.append(arg_name_tok)
self.advance(res)

if self.current_tok.type == TT_EQ:
if is_va:
va_name = arg_name_tok.value
is_va = False
elif self.current_tok.type == TT_EQ:
self.advance(res)
default = res.register(self.expr())
if res.error:
return res
assert default is not None
defaults.append(default)
hasOptionals = True
elif hasOptionals:
has_optionals = True
elif has_optionals:
return res.failure(InvalidSyntaxError(pos_start, pos_end, "Expected optional parameter."))
else:
defaults.append(None)
Expand Down Expand Up @@ -1301,7 +1326,18 @@ def func_def(self) -> ParseResult[Node]:
assert body is not None

return res.success(
FuncDefNode(var_name_tok, arg_name_toks, defaults, body, True, static=static, desc="[No Description]")
FuncDefNode(
var_name_tok,
arg_name_toks,
defaults,
body,
True,
static=static,
desc="[No Description]",
va_name=va_name,
pos_start=node_pos_start,
pos_end=self.current_tok.pos_end,
)
)

self.skip_newlines()
Expand Down Expand Up @@ -1329,7 +1365,20 @@ def func_def(self) -> ParseResult[Node]:

self.advance(res)

return res.success(FuncDefNode(var_name_tok, arg_name_toks, defaults, body, False, static=static, desc=desc))
return res.success(
FuncDefNode(
var_name_tok,
arg_name_toks,
defaults,
body,
False,
static=static,
desc=desc,
va_name=va_name,
pos_start=node_pos_start,
pos_end=self.current_tok.pos_end,
)
)

def switch_statement(self) -> ParseResult[Node]:
res = ParseResult[Node]()
Expand Down
1 change: 1 addition & 0 deletions core/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def copy(self) -> Position:
TT_SLICE = TokenType("SLICE") # x[1:2:3]
TT_PLUS_PLUS = TokenType("PLUS_PLUS") # ++
TT_MINUS_MINUS = TokenType("MINUS_MINUS") # --
TT_SPREAD = TokenType("SPREAD") # ...

KEYWORDS = [
"and",
Expand Down
20 changes: 20 additions & 0 deletions tests/varargs.rn
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

fun f(...args) {
print(args)
}

fun g(a, ...args) {
print(a)
print(args)
}

f()
f(1, 2, 3)
f("hello", "world", "!")
f("a", "b", "c")

g(1)
g(1, 2, 3)
g("hello", "world", "!")
g("a", "b", "c")

1 change: 1 addition & 0 deletions tests/varargs.rn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"code": 0, "stdout": "[]\n[1, 2, 3]\n[\"hello\", \"world\", \"!\"]\n[\"a\", \"b\", \"c\"]\n1\n[]\n1\n[2, 3]\nhello\n[\"world\", \"!\"]\na\n[\"b\", \"c\"]\n", "stderr": ""}

0 comments on commit 236dd90

Please sign in to comment.