Skip to content

Commit

Permalink
Reorganize update definitions to mirror each other
Browse files Browse the repository at this point in the history
  • Loading branch information
alexreinking committed Nov 25, 2024
1 parent 576654d commit fd04843
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,15 +830,14 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
dim_ordering.emplace(definition.schedule().dims()[i].var, i);
}

using PreservedData = tuple<RVar, Var, Dim>;
vector<PreservedData> preserved_with_dims;
vector<tuple<RVar, Var, Dim>> preserved_with_dims;
for (const auto &[rv, v] : preserved) {
const optional<Dim> rdim = find_dim(definition.schedule().dims(), rv);
internal_assert(rdim);
preserved_with_dims.emplace_back(rv, v, *rdim);
}

std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const PreservedData &lhs, const PreservedData &rhs) {
std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) {
return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var);
});

Expand Down Expand Up @@ -901,9 +900,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
intm.function().define_update(args, values, intermediate_rdom);

// Intermediate schedule
intm.function().update(0).schedule() = definition.schedule().get_copy();

auto &intm_dims = intm.function().update(0).schedule().dims();
vector<Dim> intm_dims = definition.schedule().dims();

// Replace rvar dims IN the preserved list with their Vars in the INTERMEDIATE Func
for (auto &dim : intm_dims) {
Expand All @@ -924,14 +921,16 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
for (const auto &dim : intm_dims) {
dims.insert(dim.var);
}
for (const Var &dim_v : preserved_vars) {
const optional<Dim> &dim = find_dim(intm.function().definition().schedule().dims(), dim_v);
internal_assert(dim) << "Failed to find " << dim_v.name() << " in list of pure dims";
for (const Var &var : preserved_vars) {
const optional<Dim> &dim = find_dim(intm.function().definition().schedule().dims(), var);
internal_assert(dim) << "Failed to find " << var.name() << " in list of pure dims";
if (!dims.count(dim->var)) {
intm_dims.insert(intm_dims.end() - 1, *dim);
}
}

intm.function().update(0).schedule() = definition.schedule().get_copy();
intm.function().update(0).schedule().dims() = std::move(intm_dims);
intm.function().update(0).schedule().rvars() = intermediate_rdom.domain();
}

Expand Down Expand Up @@ -980,10 +979,10 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
}

definition.args() = dim_vars_exprs;
definition.values() = substitute(preserved_map, prover_result.pattern.ops);
definition.predicate() = preserved_rdom.predicate();
definition.schedule().dims() = std::move(reducing_dims);
definition.schedule().rvars() = preserved_rdom.domain();
definition.values() = substitute(preserved_map, prover_result.pattern.ops);
}

// Clean up the splits lists
Expand Down

0 comments on commit fd04843

Please sign in to comment.