A compiler pass is used to analyze or transform RAF IR. Because RAF IR extends Relay IR, most RAF passes are built on Relay's pass infrastructure. Please read TVM's developer guide first: *Adding a Compiler Pass to Relay*.
The following sections first describe how to add a pass to RAF, then cover the differences between RAF and TVM passes that you should be aware of, and finally explain how to avoid stack overflow.
- Put your code under `src/pass`; e.g., the `FoldConstant` pass is in `src/pass/fold_const.cc`.
- Register the pass as in TVM, but use RAF's APIs (`CreateRAFFunctionPass` and `RAF_REGISTER_GLOBAL`):
// Create the FoldConstant pass.
// The pass function runs fold_const::ConstantFolder on each function and returns the
// folded function; the IRModule and PassContext arguments are not used.
// NOTE(review): the CreateRAFFunctionPass arguments presumably mirror TVM's
// CreateFunctionPass (pass_func, opt_level, name, required passes) — confirm in
// include/raf/pass.h.
Pass FoldConstant() {
  TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = [=](Function f, IRModule m,
                                                                             PassContext pc) {
    return Downcast<Function>(fold_const::ConstantFolder().Mutate(f));
  };
  return CreateRAFFunctionPass(pass_func, 1, "FoldConstant", {});
}
// Register the pass under the global name "raf.pass_.FoldConstant" so the generated
// Python FFI can call it.
RAF_REGISTER_GLOBAL("raf.pass_.FoldConstant").set_body_typed(FoldConstant);
- Also add the pass declaration `Pass FoldConstant();` to `include/raf/pass.h` so that the pass can be used in other source files.
- Run codegen with `scripts/src_codegen/run_all.sh`. You will see the auto-generated Python FFI, and then you can use the pass from Python.
Relay's pass infrastructure traverses `RelayConstantNode`, but sometimes we need to process RAF's `ConstantNode`, which holds a `raf::Value`. In that case, cast the node first:
void VisitExpr_(const RelayConstantNode* op) override {
  // Relay's infrastructure dispatches constants as RelayConstantNode; cast to RAF's
  // ConstantNode to reach the raf::Value it holds.
  // NOTE(review): static_cast is unchecked — this assumes every constant reaching the
  // pass is a RAF ConstantNode; confirm for IR that may contain plain Relay constants.
  const ConstantNode* node = static_cast<const ConstantNode*>(op);
  // do something
}
RAF uses `ExtendedVar` instead of Relay's `Var`. Whenever you need a new var, use `MakeVar`.
For example, to create a new var with an empty type annotation:
// Create a var named "name_of_var"; `{}` is the empty type annotation.
Var new_var = MakeVar("name_of_var", {});
This part mainly explains how to avoid stack overflow for ANF IR. For GNF IR, simply use `MixedModeVisitor`/`MixedModeMutator` instead of `ExprVisitor`/`ExprMutator`, or refactor the pass to use `ExprRewriter` and `PostOrderRewrite`.
When processing a `LetNode`, use either a loop or the utility function `ExpandANormalForm`; choose whichever approach makes the code more readable.
For a visitor:
void VisitExpr_(const LetNode* ln) final {
  Expr expr = GetRef<Let>(ln);
  // Iteratively visit let nodes to avoid stack overflow.
  // Each iteration handles one binding and then steps into its body, so a long chain
  // of nested lets is walked in a loop instead of through recursion.
  while (expr->IsInstance<LetNode>()) {
    Let let = Downcast<Let>(expr);
    // do something
    expr = let->body;
  }
  // Visit the last body
  // (`expr` is now the first non-let expression at the end of the chain.)
  MixedModeVisitor::VisitExpr(expr);
}
For a mutator:
Expr VisitExpr_(const LetNode* node) {
  // Open a new LetList scope that collects the rewritten bindings of this let chain.
  // NOTE(review): if scopes_ holds std::unique_ptr<LetList>, std::make_unique<LetList>()
  // would avoid the raw `new`.
  scopes_.emplace_back(new LetList);
  auto scope = scopes_.back().get();
  Expr body;
  // Walk the let chain iteratively: `node` advances to the next nested let each round.
  do {
    // do something, then push to scope
    // (new_var / new_value are placeholders for the rewritten binding.)
    scope->Push(new_var, new_value);
    body = node->body;
    node = body.as<LetNode>();
  } while (node);
  // `body` is the first non-let expression; mutate it once, outside the loop.
  auto new_body = VisitExpr(body);
  // Presumably LetList::Get wraps the collected bindings around new_body; then close the scope.
  auto ret = scopes_.back()->Get(new_body);
  scopes_.pop_back();
  return ret;
}
When using `ExpandANormalForm`, a cache (`memo_` for mutators) or a counter (`visit_counter_` for visitors) is needed to avoid nested calls. `ExpandANormalForm` calls `pre_visit` on each let of the chain from the outermost inward. It then calls `post_visit` in reverse order, innermost first. Each `post_visit` records its let in the counter/cache. So when the enclosing let's `post_visit` visits its body (the inner let), it hits the counter/cache instead of visiting that let again.
For `ExprVisitor`/`MixedModeVisitor`, the default implementation should be:
// ANF-safe default LetNode handler for ExprVisitor/MixedModeVisitor.
// Note: this visits op->var before op->value, whereas the recursive version in this
// guide visits op->value first; keep the order consistent if your pass depends on it.
void VisitExpr_(const LetNode* op) final {
  // Work done for each binding before any body is visited.
  auto pre_visit = [this](const LetNode* op) {
    this->VisitExpr(op->var);
    this->VisitExpr(op->value);
  };
  // Work done after the bindings; visits the body and marks this let as visited.
  auto post_visit = [this](const LetNode* op) {
    this->VisitExpr(op->body);
    // Mark this let as visited so the enclosing let's post_visit, which visits this let
    // as its body, does not recurse into it again.
    this->visit_counter_[op] += 1; // avoid call nestedly
  };
  ExpandANormalForm(op, pre_visit, post_visit);
}
It is equivalent to the following recursive version:
// Recursive equivalent: visiting op->body recurses into each nested let, so a long
// ANF let chain adds one stack frame per binding and can overflow the stack.
void VisitExpr_(const LetNode* op) {
  this->VisitExpr(op->value);
  this->VisitExpr(op->var);
  this->VisitExpr(op->body);
}
For `ExprMutator`/`MixedModeMutator`, the default implementation should be:
// ANF-safe default LetNode handler for ExprMutator/MixedModeMutator.
// pre_visit mutates each binding's var and value; the memoizer caches the results.
// post_visit re-reads those cached results, mutates the body, and stores the rebuilt
// let in memo_, so the enclosing let finds its (already mutated) body in the cache
// instead of recursing into it.
Expr VisitExpr_(const LetNode* op) {
  auto pre_visit = [this](const LetNode* op) {
    // Rely on the Memoizer to cache pre-visit values
    this->VisitExpr(op->var);
    this->VisitExpr(op->value);
  };
  auto post_visit = [this](const LetNode* op) {
    // Rely on the Memoizer to cache pre-visit values
    Var var = Downcast<Var>(this->VisitExpr(op->var));
    Expr value = this->VisitExpr(op->value);
    // Visit body and cache the op
    Expr body = this->VisitExpr(op->body);
    auto expr = GetRef<Expr>(op);
    if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
      // Nothing changed: reuse the original node.
      this->memo_[expr] = expr;
    } else {
      // Preserve the original span, exactly like the recursive version does.
      this->memo_[expr] = Let(var, value, body, op->span);
    }
  };
  ExpandANormalForm(op, pre_visit, post_visit);
  // The result for the outermost let was stored in memo_ by its post_visit.
  return memo_[GetRef<Expr>(op)];
}
It is equivalent to the following recursive version:
// Recursive equivalent of the ExpandANormalForm-based mutator: mutating op->body
// recurses into each nested let, so long let chains can overflow the stack.
Expr VisitExpr_(const LetNode* op) {
  Var var = Downcast<Var>(this->Mutate(op->var));
  auto value = this->Mutate(op->value);
  auto body = this->Mutate(op->body);
  if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
    // Nothing changed: return the original node.
    return GetRef<Expr>(op);
  } else {
    return Let(var, value, body, op->span);
  }
}
If the var/value has already been mutated in `pre_visit`, calling `Mutate`/`VisitExpr` on it again returns the cached result from `memo_`.
- If the pass does not override `VisitExpr_(const LetNode*)`, copy the default implementation above into it. Otherwise:
- Find the code segment that visits `op->body`: `VisitExpr(op->body)`/`Mutate(op->body)`.
- Move the logic before that call into `pre_visit`, and the call itself plus the logic after it into `post_visit`.
- In `post_visit`, if it is a visitor, add 1 to the counter.
- In `post_visit`, if it is a mutator, store the return value in the cache.
Examples (see https://github.com/awslabs/raf/commit/1dbc22b904a26d9bc0a5306c2d4a0c70530cbc4c):
- A simple example. Note that `ExprVisitor::VisitExpr_(op)` is equivalent to visiting `op->var`, `op->value`, and `op->body`:
// before:
void VisitExpr_(const LetNode* op) final {
  this->Update(op->var, nullptr, kOpaque);
  this->Update(op->value, nullptr, kOpaque);
  this->Update(op->body, nullptr, kOpaque);
  let_binding_.emplace(op->var, op->value);
  // Visits op->var, op->value and op->body; the body visit recurses into any nested
  // let, so long let chains grow the stack.
  ExprVisitor::VisitExpr_(op);
}
// after:
void VisitExpr_(const LetNode* op) final {
  // Everything the original did before the body visit, including the var/value visits
  // that ExprVisitor::VisitExpr_ used to perform.
  auto pre_visit = [this](const LetNode* op) {
    this->Update(op->var, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    this->Update(op->body, nullptr, kOpaque);
    let_binding_.emplace(op->var, op->value);
    this->VisitExpr(op->var);
    this->VisitExpr(op->value);
  };
  auto post_visit = [this](const LetNode* op) {
    this->VisitExpr(op->body);
    // Mark this let as visited so the enclosing let does not recurse into it again.
    this->visit_counter_[op] += 1;
  };
  ExpandANormalForm(op, pre_visit, post_visit);
}
- Because the body is visited inside an `if` structure, the if-else logic is kept in both `pre_visit` and `post_visit`:
// before
// Recursive version: both branches recurse into op->body.
Expr VisitExpr_(const LetNode* op) override {
  Expr ovalue = op->value;
  Var var = op->var;
  Expr value = VisitExpr(ovalue);
  if (value.as<ConstantNode>()) {
    // The value is a constant: make later visits of `var` yield it via memo_, and drop
    // this binding by returning the mutated body directly.
    memo_[var] = value;
    return VisitExpr(op->body); // visit body
  }
  // Record what `var` is bound to, forwarding through an entry already recorded for a
  // var-valued binding.
  const VarNode* v = value.as<VarNode>();
  if (v && var_value_map_.count(v)) {
    var_value_map_[op->var.get()] = var_value_map_[v];
  } else {
    var_value_map_[op->var.get()] = value;
  }
  var->checked_type_ = value->checked_type();
  Expr body = VisitExpr(op->body); // visit body
  Let let(var, value, body);
  let->checked_type_ = body->checked_type();
  return let;
}
// after:
Expr VisitExpr_(const LetNode* op) override {
  // Everything the original did before visiting the body. The early `return` of the
  // constant branch cannot live here, so the non-constant logic moves into `else`.
  auto pre_visit = [this](const LetNode* op) {
    Expr ovalue = op->value;
    Var var = op->var;
    Expr value = VisitExpr(ovalue);
    if (value.as<ConstantNode>()) {
      memo_[var] = value;
    } else {
      const VarNode* v = value.as<VarNode>();
      if (v && var_value_map_.count(v)) {
        var_value_map_[op->var.get()] = var_value_map_[v];
      } else {
        var_value_map_[op->var.get()] = value;
      }
      var->checked_type_ = value->checked_type();
    }
  };
  // Visits the body and stores this let's result in memo_, repeating the same
  // constant/non-constant branch as pre_visit.
  auto post_visit = [this](const LetNode* op) {
    auto expr = GetRef<Expr>(op);
    Expr value = this->VisitExpr(op->value); // get the cached value
    Expr body = this->VisitExpr(op->body);
    if (value.as<ConstantNode>()) {
      // Constant binding is dropped: this let is replaced by its body.
      this->memo_[expr] = body;
    } else {
      Let let(op->var, value, body);
      let->checked_type_ = body->checked_type();
      this->memo_[expr] = let;
    }
  };
  ExpandANormalForm(op, pre_visit, post_visit);
  return memo_[GetRef<Expr>(op)];
}
- Use a map to store local variables (`alias` here) that are used in both `pre_visit` and `post_visit`:
// before:
Expr VisitExpr_(const LetNode* let) {
  // Remember tuple-valued bindings.
  if (let->value->IsInstance<TupleNode>()) {
    tuple_map_.emplace(let->var, Downcast<Tuple>(let->value));
  }
  auto new_value = VisitExpr(let->value);
  // `alias` is a local that is set before the body visit and read after it.
  bool alias = false;
  if (new_value->IsInstance<VarNode>()) {
    // The value is just another var: record let->var as an alias of it.
    auto alias_var = Downcast<Var>(new_value);
    alias_map_.emplace(let->var.get(), alias_var);
    alias = true;
  }
  auto new_body = VisitExpr(let->body);
  if (alias) {
    // Alias bindings are dropped; only the mutated body is kept.
    return new_body;
  }
  return Let(let->var, new_value, new_body);
}
// after:
Expr VisitExpr_(const LetNode* let) {
  // Replaces the per-call local `alias`: maps each let of the chain to whether its
  // binding is an alias. It is written in pre_visit and read in post_visit. A single
  // bool would presumably be overwritten, because pre_visit runs over the whole chain
  // before post_visit starts.
  std::unordered_map<Expr, bool, ObjectPtrHash, ObjectPtrEqual> let_alias_map;
  auto pre_visit = [this, &let_alias_map](const LetNode* op) {
    Expr expr = GetRef<Expr>(op);
    if (op->value->IsInstance<TupleNode>()) {
      tuple_map_.emplace(op->var, Downcast<Tuple>(op->value));
    }
    auto new_value = this->VisitExpr(op->value);
    let_alias_map[expr] = false;
    if (new_value->IsInstance<VarNode>()) {
      auto alias_var = Downcast<Var>(new_value);
      alias_map_.emplace(op->var.get(), alias_var);
      let_alias_map[expr] = true;
    }
  };
  auto post_visit = [this, &let_alias_map](const LetNode* op) {
    auto expr = GetRef<Expr>(op);
    auto new_body = VisitExpr(op->body);
    if (let_alias_map[expr]) {
      // Alias bindings are dropped; this let is replaced by its body.
      this->memo_[expr] = new_body;
    } else {
      // Re-visiting op->value returns the value cached during pre_visit.
      auto new_value = this->VisitExpr(op->value);
      this->memo_[expr] = Let(op->var, new_value, new_body);
    }
  };
  ExpandANormalForm(let, pre_visit, post_visit);
  return memo_[GetRef<Expr>(let)];
}