Skip to content

Commit

Permalink
build out CallCSE pass
Browse files Browse the repository at this point in the history
  • Loading branch information
arrangabriel committed Oct 10, 2024
1 parent f802714 commit e228977
Show file tree
Hide file tree
Showing 4 changed files with 648 additions and 541 deletions.
180 changes: 165 additions & 15 deletions lib/Dialect/BaseModelica/Transforms/CallCSE.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h"
#include "marco/Dialect/BaseModelica/IR/BaseModelica.h"

#include <marco/AST/Node/Operation.h>

namespace mlir::bmodelica {
#define GEN_PASS_DEF_CALLCSEPASS
#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc"
Expand All @@ -9,14 +11,36 @@ namespace mlir::bmodelica {
using namespace ::mlir::bmodelica;

namespace {
class CallCSEPass : public impl::CallCSEPassBase<CallCSEPass> {
class CallCSEPass final : public impl::CallCSEPassBase<CallCSEPass> {
public:
using CallCSEPassBase<CallCSEPass>::CallCSEPassBase;
using CallCSEPassBase::CallCSEPassBase;

void runOnOperation() override;

private:
mlir::LogicalResult processModelOp(ModelOp modelOp);
static mlir::LogicalResult processModelOp(ModelOp modelOp);

/// Get all callOps in the model.
static void collectCallOps(ModelOp modelOp,
llvm::SmallVector<CallOp> &callOps);

/// Partition the list of call operations into groups given by
/// EquationExpressionOpInterface::isEquivalent
static void buildCallEquivalenceGroups(
llvm::SmallVector<CallOp> &callOps,
mlir::SymbolTableCollection &symbolTableCollection,
llvm::SmallVector<llvm::SmallVector<CallOp>> &callEquivalenceGroups);

static mlir::Operation *cloneDefUseChain(mlir::Operation *op,
mlir::IRMapping &mapping,
mlir::RewriterBase &rewriter);

/// Replace all calls in the equivalence group with gets to a generated
/// variable. The variable will be driven by an equation derived from the
/// first call in the group.
static EquationTemplateOp emitCse(ModelOp modelOp, int emittedCSEs,
llvm::SmallVector<CallOp> &equivalenceGroup,
mlir::RewriterBase &rewriter);
};
} // namespace

Expand All @@ -29,16 +53,16 @@ void CallCSEPass::runOnOperation() {
}
});

if (mlir::failed(mlir::failableParallelForEach(
if (failed(failableParallelForEach(
&getContext(), modelOps, [&](mlir::Operation *op) {
return processModelOp(mlir::cast<ModelOp>(op));
}))) {
return signalPassFailure();
}
}

mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {
mlir::IRRewriter rewriter(modelOp);
void CallCSEPass::collectCallOps(ModelOp modelOp,
llvm::SmallVector<CallOp> &callOps) {
llvm::SmallVector<EquationInstanceOp> initialEquationOps;
llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;

Expand All @@ -47,29 +71,155 @@ mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {

llvm::DenseSet<EquationTemplateOp> templateOps;

for (EquationInstanceOp equationOp : initialEquationOps) {
// TODO: Figure out if these should be included
// for (auto equationOp : initialEquationOps) {
// templateOps.insert(equationOp.getTemplate());
// }

for (auto equationOp : dynamicEquationOps) {
templateOps.insert(equationOp.getTemplate());
}

for (EquationInstanceOp equationOp : dynamicEquationOps) {
templateOps.insert(equationOp.getTemplate());
for (const auto templateOp : templateOps) {
templateOp->walk([&](const CallOp callOp) { callOps.push_back(callOp); });
}
}

void CallCSEPass::buildCallEquivalenceGroups(
llvm::SmallVector<CallOp> &callOps,
mlir::SymbolTableCollection &symbolTableCollection,
llvm::SmallVector<llvm::SmallVector<CallOp>> &callEquivalenceGroups) {
for (auto callOp : callOps) {
auto callExpression =
mlir::cast<EquationExpressionOpInterface>(callOp.getOperation());

auto *equivalenceGroup =
find_if(callEquivalenceGroups, [&](llvm::SmallVector<CallOp> &group) {
// front() is safe as there are no empty groups
const auto representative = mlir::cast<EquationExpressionOpInterface>(
group.front().getOperation());
return callExpression.isEquivalent(representative,
symbolTableCollection);
});

if (equivalenceGroup != callEquivalenceGroups.end()) {
// Add equivalent call to existing group
equivalenceGroup->push_back(callOp);
} else {
// Create new equivalence group
callEquivalenceGroups.push_back({callOp});
}
}
}

mlir::Operation *CallCSEPass::cloneDefUseChain(mlir::Operation *op,
mlir::IRMapping &mapping,
mlir::RewriterBase &rewriter) {
std::vector<mlir::Operation *> toClone;
std::vector worklist({op});

while (!worklist.empty()) {
auto *current = worklist.back();
worklist.pop_back();
toClone.push_back(current);
for (auto operand : current->getOperands()) {
if (auto *defOp = operand.getDefiningOp()) {
worklist.push_back(defOp);
}
}
}

mlir::Operation *root = nullptr;
for (auto *opToClone : llvm::reverse(toClone)) {
root = rewriter.clone(*opToClone, mapping);
}
return root;
// for (auto operand : op->getOperands()) {
// if (auto *defOp = operand.getDefiningOp()) {
// cloneDefUseChain(defOp, mapping, rewriter);
// }
// }

// return rewriter.clone(*op, mapping);
}

EquationTemplateOp
CallCSEPass::emitCse(ModelOp modelOp, const int emittedCSEs,
llvm::SmallVector<CallOp> &equivalenceGroup,
mlir::RewriterBase &rewriter) {
assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty");
auto representative = equivalenceGroup.front();
// Emit CSE variable
rewriter.setInsertionPointToStart(modelOp.getBody());
auto cseVariable = rewriter.create<VariableOp>(
rewriter.getUnknownLoc(), "_cse" + std::to_string(emittedCSEs),
VariableType::wrap(representative.getResult(0).getType()));

// Create CSE variable driver equation
rewriter.setInsertionPointToEnd(modelOp.getBody());
auto equationTemplateOp =
rewriter.create<EquationTemplateOp>(rewriter.getUnknownLoc());
rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0));

mlir::IRMapping mapping;
auto *clonedRepresentative =
cloneDefUseChain(representative, mapping, rewriter);

const auto cseGetOp =
rewriter.create<VariableGetOp>(rewriter.getUnknownLoc(), cseVariable);
auto lhsOp = rewriter.create<EquationSideOp>(rewriter.getUnknownLoc(),
cseGetOp->getResults());
auto rhsOp = rewriter.create<EquationSideOp>(
rewriter.getUnknownLoc(), clonedRepresentative->getResults());
rewriter.create<EquationSidesOp>(rewriter.getUnknownLoc(), lhsOp, rhsOp);

// Replace calls with get to CSE variable
for (auto &callOp : equivalenceGroup) {
rewriter.setInsertionPoint(callOp);
rewriter.replaceOpWithNewOp<VariableGetOp>(callOp, cseVariable);
}

return equationTemplateOp;
}

mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {
mlir::IRRewriter rewriter(modelOp);
mlir::SymbolTableCollection symbolTableCollection;
symbolTableCollection.getSymbolTable(modelOp);

llvm::SmallVector<CallOp> callOps;
collectCallOps(modelOp, callOps);

llvm::SmallVector<llvm::SmallVector<CallOp>> callEquivalenceGroups;
buildCallEquivalenceGroups(callOps, symbolTableCollection,
callEquivalenceGroups);

for (EquationTemplateOp templateOp : templateOps) {
templateOp->walk([&](CallOp callOp) {
callOps.push_back(callOp);
});
int emittedCSEs = 0;
llvm::SmallVector<EquationTemplateOp> cseEquationTemplateOps;

for (auto &equivalenceGroup :
make_filter_range(callEquivalenceGroups,
[](auto &group) { return group.size() > 1; })) {
cseEquationTemplateOps.push_back(
emitCse(modelOp, emittedCSEs++, equivalenceGroup, rewriter));
}


if (!cseEquationTemplateOps.empty()) {
rewriter.setInsertionPointToEnd(modelOp.getBody());
auto dynamicOp = rewriter.create<DynamicOp>(rewriter.getUnknownLoc());
rewriter.setInsertionPointToStart(
rewriter.createBlock(&dynamicOp.getRegion()));
for (auto &equationTemplateOp : cseEquationTemplateOps) {
rewriter.create<EquationInstanceOp>(rewriter.getUnknownLoc(),
equationTemplateOp);
}
}

return mlir::success();
}

namespace mlir::bmodelica {
std::unique_ptr<mlir::Pass> createCallCSEPass() {
std::unique_ptr<Pass> createCallCSEPass() {
return std::make_unique<CallCSEPass>();
}
} // namespace mlir::bmodelica
Loading

0 comments on commit e228977

Please sign in to comment.