Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function call CSE pass #45

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def EquationExpressionOpInterface
"void", "printExpression",
(ins "::llvm::raw_ostream&":$os,
"const ::llvm::DenseMap<::mlir::Value, int64_t>&":$inductions)>,
InterfaceMethod<
"Check if two expressions are equivalent",
"bool", "isEquivalent",
(ins "mlir::Operation*":$other,
"mlir::SymbolTableCollection&":$symbolTableCollection), "", [{
// Safely assume that the two expressions are different.
return false;
}]>,
InterfaceMethod<
"Get the number of elements.",
"uint64_t", "getNumOfExpressionElements",
Expand Down
13 changes: 13 additions & 0 deletions include/public/marco/Dialect/BaseModelica/Transforms/CallCSE.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H
#define MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H

#include "mlir/Pass/Pass.h"

namespace mlir::bmodelica {
#define GEN_PASS_DECL_CALLCSEPASS
#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createCallCSEPass();
} // namespace mlir::bmodelica

#endif // MARCO_DIALECT_BASEMODELICA_TRANSFORMS_CALLCSE_H
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "marco/Dialect/BaseModelica/Transforms/AccessReplacementTest.h"
#include "marco/Dialect/BaseModelica/Transforms/AutomaticDifferentiation.h"
#include "marco/Dialect/BaseModelica/Transforms/BindingEquationConversion.h"
#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h"
#include "marco/Dialect/BaseModelica/Transforms/DerivativeChainRule.h"
#include "marco/Dialect/BaseModelica/Transforms/DerivativesMaterialization.h"
#include "marco/Dialect/BaseModelica/Transforms/EquationAccessSplit.h"
Expand Down
16 changes: 16 additions & 0 deletions include/public/marco/Dialect/BaseModelica/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,22 @@ def EquationFunctionLoopHoistingPass
let constructor = "mlir::bmodelica::createEquationFunctionLoopHoistingPass()";
}

def CallCSEPass
: Pass<"call-cse", "mlir::ModuleOp">
{
let summary = "Move equal function calls to dedicated equation.";

let description = [{
Move equal function calls to dedicated equation.
}];

let dependentDialects = [
"mlir::bmodelica::BaseModelicaDialect"
];

let constructor = "mlir::bmodelica::createCallCSEPass()";
}

def ReadOnlyVariablesPropagationPass
: Pass<"propagate-read-only-variables", "mlir::ModuleOp">
{
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/BaseModelica/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRBaseModelicaTransforms
AllocationOpInterfaceImpl.cpp
BindingEquationConversion.cpp
BufferizableOpInterfaceImpl.cpp
CallCSE.cpp
ConstantMaterializableTypeInterfaceImpl.cpp
DerivableOpInterfaceImpl.cpp
DerivableTypeInterfaceImpl.cpp
Expand Down
220 changes: 220 additions & 0 deletions lib/Dialect/BaseModelica/Transforms/CallCSE.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
#include "marco/Dialect/BaseModelica/Transforms/CallCSE.h"
#include "marco/Dialect/BaseModelica/IR/BaseModelica.h"

#include <marco/AST/Node/Operation.h>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not clear why is the AST library is being imported here. At this point in the pipeline the AST is already gone.


namespace mlir::bmodelica {
#define GEN_PASS_DEF_CALLCSEPASS
#include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc"
} // namespace mlir::bmodelica

using namespace ::mlir::bmodelica;

namespace {
class CallCSEPass final : public impl::CallCSEPassBase<CallCSEPass> {
public:
using CallCSEPassBase::CallCSEPassBase;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep the template argument explicit so to be resilient to any possible change in the code automatically generated with Tablegen.


void runOnOperation() override;

private:
static mlir::LogicalResult processModelOp(ModelOp modelOp);
Copy link
Member

@mscuttari mscuttari Oct 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to declare everything static. It would also prevent you to update the pass statistics.


/// Get all callOps in the model.
static void collectCallOps(ModelOp modelOp,
llvm::SmallVector<CallOp> &callOps);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use llvm::SmallVectorImpl so not to force a particular amount of elements on the stack.


/// Partition the list of call operations into groups given by
/// EquationExpressionOpInterface::isEquivalent
static void buildCallEquivalenceGroups(
llvm::SmallVector<CallOp> &callOps,
mlir::SymbolTableCollection &symbolTableCollection,
llvm::SmallVector<llvm::SmallVector<CallOp>> &callEquivalenceGroups);

static mlir::Operation *cloneDefUseChain(mlir::Operation *op,
mlir::IRMapping &mapping,
mlir::RewriterBase &rewriter);

/// Replace all calls in the equivalence group with gets to a generated
/// variable. The variable will be driven by an equation derived from the
/// first call in the group.
static EquationTemplateOp emitCse(ModelOp modelOp, int emittedCSEs,
llvm::SmallVector<CallOp> &equivalenceGroup,
mlir::RewriterBase &rewriter);
};
} // namespace

void CallCSEPass::runOnOperation() {
llvm::SmallVector<ModelOp, 1> modelOps;

walkClasses(getOperation(), [&](mlir::Operation *op) {
if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
modelOps.push_back(modelOp);
}
});

if (failed(failableParallelForEach(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep the mlir namespace explicit

&getContext(), modelOps, [&](mlir::Operation *op) {
return processModelOp(mlir::cast<ModelOp>(op));
}))) {
return signalPassFailure();
}
}

void CallCSEPass::collectCallOps(ModelOp modelOp,
llvm::SmallVector<CallOp> &callOps) {
llvm::SmallVector<EquationInstanceOp> initialEquationOps;
llvm::SmallVector<EquationInstanceOp> dynamicEquationOps;

modelOp.collectInitialEquations(initialEquationOps);
modelOp.collectMainEquations(dynamicEquationOps);

llvm::DenseSet<EquationTemplateOp> templateOps;

// TODO: Figure out if these should be included
// for (auto equationOp : initialEquationOps) {
// templateOps.insert(equationOp.getTemplate());
// }

for (auto equationOp : dynamicEquationOps) {
templateOps.insert(equationOp.getTemplate());
}

for (auto templateOp : templateOps) {
// Skip templates with induction variables
if (!templateOp.getInductionVariables().empty()) {
continue;
}
templateOp->walk([&](CallOp callOp) { callOps.push_back(callOp); });
}
}

void CallCSEPass::buildCallEquivalenceGroups(
llvm::SmallVector<CallOp> &callOps,
mlir::SymbolTableCollection &symbolTableCollection,
llvm::SmallVector<llvm::SmallVector<CallOp>> &callEquivalenceGroups) {
for (auto callOp : callOps) {
auto callExpression =
mlir::cast<EquationExpressionOpInterface>(callOp.getOperation());

auto *equivalenceGroup =
find_if(callEquivalenceGroups, [&](llvm::SmallVector<CallOp> &group) {
// front() is safe as there are no empty groups
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still, an assert would be nice to have.

auto representative = mlir::cast<EquationExpressionOpInterface>(
group.front().getOperation());
return callExpression.isEquivalent(representative,
symbolTableCollection);
});

if (equivalenceGroup != callEquivalenceGroups.end()) {
// Add equivalent call to existing group
equivalenceGroup->push_back(callOp);
} else {
// Create new equivalence group
callEquivalenceGroups.push_back({callOp});
}
}
}

mlir::Operation *CallCSEPass::cloneDefUseChain(mlir::Operation *op,
mlir::IRMapping &mapping,
mlir::RewriterBase &rewriter) {
std::vector<mlir::Operation *> toClone;
std::vector worklist({op});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use llvm::SmallVector.


// DFS through the def-use chain of `op`
while (!worklist.empty()) {
auto *current = worklist.back();
worklist.pop_back();
toClone.push_back(current);
for (auto operand : current->getOperands()) {
if (auto *defOp = operand.getDefiningOp()) {
worklist.push_back(defOp);
}
}
}

mlir::Operation *root = nullptr;
for (auto *opToClone : llvm::reverse(toClone)) {
root = rewriter.clone(*opToClone, mapping);
}
return root;
}

EquationTemplateOp
CallCSEPass::emitCse(ModelOp modelOp, const int emittedCSEs,
llvm::SmallVector<CallOp> &equivalenceGroup,
mlir::RewriterBase &rewriter) {
assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty");
auto representative = equivalenceGroup.front();
const auto loc = representative.getLoc();
// Emit CSE variable
rewriter.setInsertionPointToStart(modelOp.getBody());
auto cseVariable = rewriter.create<VariableOp>(
loc, "_cse" + std::to_string(emittedCSEs),
VariableType::wrap(representative.getResult(0).getType()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the function has multiple results?


// Create CSE variable driver equation
rewriter.setInsertionPointToEnd(modelOp.getBody());
auto equationTemplateOp = rewriter.create<EquationTemplateOp>(loc);
rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0));

mlir::IRMapping mapping;
auto *clonedRepresentative =
cloneDefUseChain(representative, mapping, rewriter);

auto cseGetOp = rewriter.create<VariableGetOp>(loc, cseVariable);
auto lhsOp = rewriter.create<EquationSideOp>(loc, cseGetOp->getResults());
auto rhsOp =
rewriter.create<EquationSideOp>(loc, clonedRepresentative->getResults());
rewriter.create<EquationSidesOp>(loc, lhsOp, rhsOp);

// Replace calls with get to CSE variable
for (auto &callOp : equivalenceGroup) {
rewriter.setInsertionPoint(callOp);
rewriter.replaceOpWithNewOp<VariableGetOp>(callOp, cseVariable);
}

return equationTemplateOp;
}

mlir::LogicalResult CallCSEPass::processModelOp(ModelOp modelOp) {
mlir::IRRewriter rewriter(modelOp);
mlir::SymbolTableCollection symbolTableCollection;

llvm::SmallVector<CallOp> callOps;
collectCallOps(modelOp, callOps);

llvm::SmallVector<llvm::SmallVector<CallOp>> callEquivalenceGroups;
buildCallEquivalenceGroups(callOps, symbolTableCollection,
callEquivalenceGroups);

int emittedCSEs = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible - and would be nice - to add this measurement to the pass statistics.

llvm::SmallVector<EquationTemplateOp> cseEquationTemplateOps;

for (auto &equivalenceGroup :
make_filter_range(callEquivalenceGroups,
[](auto &group) { return group.size() > 1; })) {
cseEquationTemplateOps.push_back(
emitCse(modelOp, emittedCSEs++, equivalenceGroup, rewriter));
}

if (!cseEquationTemplateOps.empty()) {
rewriter.setInsertionPointToEnd(modelOp.getBody());
auto dynamicOp = rewriter.create<DynamicOp>(rewriter.getUnknownLoc());
rewriter.setInsertionPointToStart(
rewriter.createBlock(&dynamicOp.getRegion()));
for (auto &equationTemplateOp : cseEquationTemplateOps) {
rewriter.create<EquationInstanceOp>(rewriter.getUnknownLoc(),
equationTemplateOp);
}
}

return mlir::success();
}

namespace mlir::bmodelica {
std::unique_ptr<Pass> createCallCSEPass() {
return std::make_unique<CallCSEPass>();
}
} // namespace mlir::bmodelica
Loading