Skip to content

Commit

Permalink
AST - Refactor if let and while let to allow multiple conditions …
Browse files Browse the repository at this point in the history
…(and be a de-sugar)
  • Loading branch information
thepowersgang committed Jan 5, 2024
1 parent 15b827f commit 8745fba
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 76 deletions.
45 changes: 31 additions & 14 deletions src/ast/dump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,6 @@ class RustPrinter:
m_os << "while ";
AST::NodeVisitor::visit(n.m_cond);
break;
case AST::ExprNode_Loop::WHILELET:
m_os << "while let ";
print_pattern(n.m_pattern, true);
m_os << " = ";
AST::NodeVisitor::visit(n.m_cond);
break;
case AST::ExprNode_Loop::FOR:
m_os << "while for ";
print_pattern(n.m_pattern, true);
Expand All @@ -287,6 +281,36 @@ class RustPrinter:

AST::NodeVisitor::visit(n.m_code);
}
void visit_iflet_conditions(std::vector<AST::IfLet_Condition>& conds) {
for(size_t i = 0; i < conds.size(); i ++) {
if(i != 0) m_os << " && ";
if(conds[i].opt_pat) {
print_pattern(*conds[i].opt_pat, true);
m_os << " = ";
}
m_os << "(";
AST::NodeVisitor::visit(conds[i].value);
m_os << ")";
}
}
void visit(AST::ExprNode_WhileLet& n) override {
bool expr_root = m_expr_root;
m_expr_root = false;

m_os << "while let ";
visit_iflet_conditions(n.m_conditions);
if( expr_root )
{
m_os << "\n";
m_os << indent();
}
else
{
m_os << " ";
}

AST::NodeVisitor::visit(n.m_code);
}
virtual void visit(AST::ExprNode_Match& n) override {
bool expr_root = m_expr_root;
m_expr_root = false;
Expand Down Expand Up @@ -361,14 +385,7 @@ class RustPrinter:
bool expr_root = m_expr_root;
m_expr_root = false;
m_os << "if let ";
for(const auto& pat : n.m_patterns)
{
if(&pat != &n.m_patterns.front())
m_os << " | ";
print_pattern(pat, /*is_refutable=*/true);
}
m_os << " = ";
AST::NodeVisitor::visit(n.m_value);
visit_iflet_conditions(n.m_conditions);

visit_if_common(expr_root, n.m_true, n.m_false);
}
Expand Down
71 changes: 59 additions & 12 deletions src/ast/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,35 @@ NODE(ExprNode_Loop, {
os << " in/= " << *m_cond;
os << " " << *m_code;
},{
return NEWNODE(ExprNode_Loop, m_label, m_type, m_pattern.clone(), OPT_CLONE(m_cond), m_code->clone());
return NEWNODE(ExprNode_Loop, m_type, m_label, m_pattern.clone(), OPT_CLONE(m_cond), m_code->clone());
})
NODE(ExprNode_WhileLet, {
if(m_label != "") {
os << "'" << m_label << ": ";
}
os << "while let ";
for(size_t i = 0; i < m_conditions.size(); i ++) {
if( i != 0 ) {
os << " && ";
}
if( m_conditions[i].opt_pat ) {
os << *m_conditions[i].opt_pat << " = ";
}
os << "(" << *m_conditions[i].value << ")";
}
os << " { " << *m_code << " }";
},{
decltype(m_conditions) new_conds;
for(const auto& cond : m_conditions) {
AST::IfLet_Condition new_cond;
if( cond.opt_pat ) {
new_cond.opt_pat = std::make_unique<AST::Pattern>( cond.opt_pat->clone() );
}
new_cond.value = cond.value->clone();
new_conds.push_back(std::move(new_cond));
}
return NEWNODE(ExprNode_WhileLet, m_label, mv$(new_conds), m_code->clone());
})

MatchGuard MatchGuard::clone() const
{
Expand Down Expand Up @@ -382,19 +409,28 @@ NODE(ExprNode_If, {
})
NODE(ExprNode_IfLet, {
os << "if let ";
for(const auto& pat : m_patterns)
{
if(&pat != &m_patterns.front())
os << " | ";
os << pat;
for(size_t i = 0; i < m_conditions.size(); i ++) {
if( i != 0 ) {
os << " && ";
}
if( m_conditions[i].opt_pat ) {
os << *m_conditions[i].opt_pat << " = ";
}
os << "(" << *m_conditions[i].value << ")";
}
os << " = (" << *m_value << ") { " << *m_true << " }";
os << " { " << *m_true << " }";
if(m_false) os << " else { " << *m_false << " }";
},{
decltype(m_patterns) new_pats;
for(const auto& pat : m_patterns)
new_pats.push_back(pat.clone());
return NEWNODE(ExprNode_IfLet, mv$(new_pats), m_value->clone(), m_true->clone(), OPT_CLONE(m_false));
decltype(m_conditions) new_conds;
for(const auto& cond : m_conditions) {
AST::IfLet_Condition new_cond;
if( cond.opt_pat ) {
new_cond.opt_pat = std::make_unique<AST::Pattern>( cond.opt_pat->clone() );
}
new_cond.value = cond.value->clone();
new_conds.push_back(std::move(new_cond));
}
return NEWNODE(ExprNode_IfLet, mv$(new_conds), m_true->clone(), OPT_CLONE(m_false));
})

NODE(ExprNode_Integer, {
Expand Down Expand Up @@ -707,6 +743,15 @@ NV(ExprNode_Loop,
visit(node.m_code);
UNINDENT();
})
NV(ExprNode_WhileLet,
{
INDENT();
for(auto& c : node.m_conditions) {
visit(c.value);
}
visit(node.m_code);
UNINDENT();
})
NV(ExprNode_Match,
{
INDENT();
Expand All @@ -733,7 +778,9 @@ NV(ExprNode_If,
NV(ExprNode_IfLet,
{
INDENT();
visit(node.m_value);
for(auto& c : node.m_conditions) {
visit(c.value);
}
visit(node.m_true);
visit(node.m_false);
UNINDENT();
Expand Down
47 changes: 37 additions & 10 deletions src/ast/expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

namespace AST {

class Pattern;
class NodeVisitor;

class ExprNode
Expand Down Expand Up @@ -292,7 +293,6 @@ struct ExprNode_Loop:
enum Type {
LOOP,
WHILE,
WHILELET,
FOR,
} m_type;
Ident m_label;
Expand All @@ -316,14 +316,41 @@ struct ExprNode_Loop:
m_cond( ::std::move(cond) ),
m_code( ::std::move(code) )
{}
ExprNode_Loop(Ident label, Type type, AST::Pattern pattern, ExprNodeP val, ExprNodeP code):
m_type(type),
ExprNode_Loop(Ident label, AST::Pattern pattern, ExprNodeP val, ExprNodeP code):
m_type(FOR),
m_label( ::std::move(label) ),
m_pattern( ::std::move(pattern) ),
m_cond( ::std::move(val) ),
m_code( ::std::move(code) )
{}
NODE_METHODS();
private:
ExprNode_Loop(Type type, Ident label, AST::Pattern pattern, ExprNodeP val, ExprNodeP code)
: m_type( type )
, m_label( ::std::move(label) )
, m_pattern( ::std::move(pattern) )
, m_cond( ::std::move(val) )
, m_code( ::std::move(code) )
{}
};
struct IfLet_Condition
{
::std::unique_ptr<AST::Pattern> opt_pat;
ExprNodeP value;
};
struct ExprNode_WhileLet:
public ExprNode
{
Ident m_label;
std::vector<IfLet_Condition> m_conditions;
ExprNodeP m_code;

ExprNode_WhileLet(Ident label, std::vector<IfLet_Condition> conditions, ExprNodeP code)
: m_label( ::std::move(label) )
, m_conditions( ::std::move(conditions) )
, m_code( ::std::move(code) )
{}
NODE_METHODS();
};

TAGGED_UNION_EX(MatchGuard, (), None, (
Expand Down Expand Up @@ -387,16 +414,14 @@ struct ExprNode_If:
struct ExprNode_IfLet:
public ExprNode
{
std::vector<AST::Pattern> m_patterns;
ExprNodeP m_value;
std::vector<IfLet_Condition> m_conditions;
ExprNodeP m_true;
ExprNodeP m_false;

ExprNode_IfLet(std::vector<AST::Pattern> patterns, ExprNodeP cond, ExprNodeP true_code, ExprNodeP false_code):
m_patterns( ::std::move(patterns) ),
m_value( ::std::move(cond) ),
m_true( ::std::move(true_code) ),
m_false( ::std::move(false_code) )
ExprNode_IfLet(std::vector<IfLet_Condition> conditions, ExprNodeP true_code, ExprNodeP false_code)
: m_conditions( ::std::move(conditions) )
, m_true( ::std::move(true_code) )
, m_false( ::std::move(false_code) )
{
}
NODE_METHODS();
Expand Down Expand Up @@ -727,6 +752,7 @@ class NodeVisitor
NT(ExprNode_CallMethod);
NT(ExprNode_CallObject);
NT(ExprNode_Loop);
NT(ExprNode_WhileLet);
NT(ExprNode_Match);
NT(ExprNode_If);
NT(ExprNode_IfLet);
Expand Down Expand Up @@ -776,6 +802,7 @@ class NodeVisitorDef:
NT(ExprNode_CallMethod);
NT(ExprNode_CallObject);
NT(ExprNode_Loop);
NT(ExprNode_WhileLet);
NT(ExprNode_Match);
NT(ExprNode_If);
NT(ExprNode_IfLet);
Expand Down
Loading

0 comments on commit 8745fba

Please sign in to comment.