Skip to content

Commit

Permalink
Consider *all* Exprs a func uses, not just the RHS, in Li2018 (#8326)
Browse files Browse the repository at this point in the history
Fixes #8312
  • Loading branch information
abadams committed Jun 26, 2024
1 parent cab27d8 commit a4a7531
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions src/DerivativeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ map<string, Box> inference_bounds(const vector<Func> &funcs,
for (auto it = order.rbegin(); it != order.rend(); it++) {
Func func = Func(env[*it]);
// We should already have the bounds of this function
internal_assert(bounds.find(*it) != bounds.end());
internal_assert(bounds.find(*it) != bounds.end()) << *it << "\n";
const Box &current_bounds = bounds[*it];
internal_assert(func.args().size() == current_bounds.size());
// We know the range for each argument of this function
Expand All @@ -262,29 +262,33 @@ map<string, Box> inference_bounds(const vector<Func> &funcs,
scope.push(arg, current_bounds[i]);
}
// Propagate the bounds
for (int update_id = -1; update_id < func.num_update_definitions(); update_id++) {
// For each rhs expression
Tuple tuple = update_id == -1 ? func.values() : func.update_values(update_id);
for (const auto &expr : tuple.as_vector()) {
// For all the immediate dependencies of this expression,
// find the required ranges
map<string, Box> update_bounds =
boxes_required(expr, scope, func_value_bounds);
// Loop over the dependencies
for (const auto &it : update_bounds) {
if (it.first == func.name()) {
// Skip self reference
continue;
}
// Update the bounds, if not exists then create a new one
auto found = bounds.find(it.first);
if (found == bounds.end()) {
bounds[it.first] = it.second;
} else {
Box new_box = box_union(found->second, it.second);
bounds[it.first] = new_box;
}
}
class CollectExprs : public IRMutator {
public:
using IRMutator::mutate;
Expr mutate(const Expr &e) override {
exprs.push_back(e);
return e;
}
std::vector<Expr> exprs;
} expr_collector;
func.function().mutate(&expr_collector);

Expr bundle = Call::make(Int(32), Call::bundle, expr_collector.exprs, Call::PureIntrinsic);
map<string, Box> update_bounds =
boxes_required(bundle, scope, func_value_bounds);
// Loop over the dependencies
for (const auto &it : update_bounds) {
if (it.first == func.name()) {
// Skip self reference
continue;
}
// Update the bounds, if not exists then create a new one
auto found = bounds.find(it.first);
if (found == bounds.end()) {
bounds[it.first] = it.second;
} else {
Box new_box = box_union(found->second, it.second);
bounds[it.first] = new_box;
}
}
for (int i = 0; i < (int)current_bounds.size(); i++) {
Expand Down

0 comments on commit a4a7531

Please sign in to comment.