Skip to content

Commit

Permalink
Skip equations with induction variables and add a basic test.
Browse files Browse the repository at this point in the history
  • Loading branch information
arrangabriel committed Oct 10, 2024
1 parent e228977 commit 73eb3c2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 57 deletions.
36 changes: 16 additions & 20 deletions lib/Dialect/BaseModelica/Transforms/CallCSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,12 @@ void CallCSEPass::collectCallOps(ModelOp modelOp,
templateOps.insert(equationOp.getTemplate());
}

for (const auto templateOp : templateOps) {
templateOp->walk([&](const CallOp callOp) { callOps.push_back(callOp); });
for (auto templateOp : templateOps) {
// Skip templates with induction variables
if (!templateOp.getInductionVariables().empty()) {
continue;
}
templateOp->walk([&](CallOp callOp) { callOps.push_back(callOp); });
}
}

Expand All @@ -96,7 +100,7 @@ void CallCSEPass::buildCallEquivalenceGroups(
auto *equivalenceGroup =
find_if(callEquivalenceGroups, [&](llvm::SmallVector<CallOp> &group) {
// front() is safe as there are no empty groups
const auto representative = mlir::cast<EquationExpressionOpInterface>(
auto representative = mlir::cast<EquationExpressionOpInterface>(
group.front().getOperation());
return callExpression.isEquivalent(representative,
symbolTableCollection);
Expand All @@ -118,6 +122,7 @@ mlir::Operation *CallCSEPass::cloneDefUseChain(mlir::Operation *op,
std::vector<mlir::Operation *> toClone;
std::vector worklist({op});

// DFS through the def-use chain of `op`
while (!worklist.empty()) {
auto *current = worklist.back();
worklist.pop_back();
Expand All @@ -134,13 +139,6 @@ mlir::Operation *CallCSEPass::cloneDefUseChain(mlir::Operation *op,
root = rewriter.clone(*opToClone, mapping);
}
return root;
// for (auto operand : op->getOperands()) {
// if (auto *defOp = operand.getDefiningOp()) {
// cloneDefUseChain(defOp, mapping, rewriter);
// }
// }

// return rewriter.clone(*op, mapping);
}

EquationTemplateOp
Expand All @@ -149,29 +147,27 @@ CallCSEPass::emitCse(ModelOp modelOp, const int emittedCSEs,
mlir::RewriterBase &rewriter) {
assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty");
auto representative = equivalenceGroup.front();
const auto loc = representative.getLoc();
// Emit CSE variable
rewriter.setInsertionPointToStart(modelOp.getBody());
auto cseVariable = rewriter.create<VariableOp>(
rewriter.getUnknownLoc(), "_cse" + std::to_string(emittedCSEs),
loc, "_cse" + std::to_string(emittedCSEs),
VariableType::wrap(representative.getResult(0).getType()));

// Create CSE variable driver equation
rewriter.setInsertionPointToEnd(modelOp.getBody());
auto equationTemplateOp =
rewriter.create<EquationTemplateOp>(rewriter.getUnknownLoc());
auto equationTemplateOp = rewriter.create<EquationTemplateOp>(loc);
rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0));

mlir::IRMapping mapping;
auto *clonedRepresentative =
cloneDefUseChain(representative, mapping, rewriter);

const auto cseGetOp =
rewriter.create<VariableGetOp>(rewriter.getUnknownLoc(), cseVariable);
auto lhsOp = rewriter.create<EquationSideOp>(rewriter.getUnknownLoc(),
cseGetOp->getResults());
auto rhsOp = rewriter.create<EquationSideOp>(
rewriter.getUnknownLoc(), clonedRepresentative->getResults());
rewriter.create<EquationSidesOp>(rewriter.getUnknownLoc(), lhsOp, rhsOp);
auto cseGetOp = rewriter.create<VariableGetOp>(loc, cseVariable);
auto lhsOp = rewriter.create<EquationSideOp>(loc, cseGetOp->getResults());
auto rhsOp =
rewriter.create<EquationSideOp>(loc, clonedRepresentative->getResults());
rewriter.create<EquationSidesOp>(loc, lhsOp, rhsOp);

// Replace calls with get to CSE variable
for (auto &callOp : equivalenceGroup) {
Expand Down
51 changes: 14 additions & 37 deletions test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,54 +14,27 @@ module @Test {
}

bmodelica.model @M {
// CHECK: bmodelica.variable @[[CSE:_cse0]]
bmodelica.variable @x : !bmodelica.variable<f64>
bmodelica.variable @y : !bmodelica.variable<f64>
bmodelica.variable @z : !bmodelica.variable<f64>

//%t0 = bmodelica.equation_template inductions = [] {
// %0 = bmodelica.variable_get @x : f64
// %lhs = bmodelica.equation_side %0 : tuple<f64>
// %1 = bmodelica.constant 23.0 : f64
// %2 = bmodelica.call @foo(%1) : (f64) -> f64
// %rhs = bmodelica.equation_side %2 : tuple<f64>
// bmodelica.equation_sides %lhs, %rhs : tuple<f64>, tuple<f64>
//}

//%t1 = bmodelica.equation_template inductions = [] {
// %0 = bmodelica.variable_get @y : f64
// %lhs = bmodelica.equation_side %0 : tuple<f64>
// %1 = bmodelica.constant 23.0 : f64
// %2 = bmodelica.call @foo(%1) : (f64) -> f64
// %rhs = bmodelica.equation_side %2 : tuple<f64>
// bmodelica.equation_sides %lhs, %rhs : tuple<f64>, tuple<f64>
//}

%t0 = bmodelica.equation_template inductions = [] {
%0 = bmodelica.variable_get @x : f64
%lhs = bmodelica.equation_side %0 : tuple<f64>
%1 = bmodelica.constant 23.0 : f64
%2 = bmodelica.constant 25.0 : f64
%3 = bmodelica.add %1, %2 : (f64, f64) -> f64
%4 = bmodelica.call @foo(%3) : (f64) -> f64
%rhs = bmodelica.equation_side %4 : tuple<f64>
%1 = bmodelica.constant 1.0 : f64
// CHECK: %[[RES0:.*]] = bmodelica.variable_get @_cse0
// CHECK-NEXT: bmodelica.equation_side %[[RES0]]
%2 = bmodelica.call @foo(%1) : (f64) -> f64
%rhs = bmodelica.equation_side %2 : tuple<f64>
bmodelica.equation_sides %lhs, %rhs : tuple<f64>, tuple<f64>
}

%t1 = bmodelica.equation_template inductions = [] {
%0 = bmodelica.variable_get @y : f64
%lhs = bmodelica.equation_side %0 : tuple<f64>
%1 = bmodelica.constant 23.0 : f64
%2 = bmodelica.constant 25.0 : f64
%3 = bmodelica.add %1, %2 : (f64, f64) -> f64
%4 = bmodelica.call @foo(%3) : (f64) -> f64
%rhs = bmodelica.equation_side %4 : tuple<f64>
bmodelica.equation_sides %lhs, %rhs : tuple<f64>, tuple<f64>
}

%t2 = bmodelica.equation_template inductions = [] {
%0 = bmodelica.variable_get @z : f64
%lhs = bmodelica.equation_side %0 : tuple<f64>
%1 = bmodelica.constant 57.0 : f64
%1 = bmodelica.constant 1.0 : f64
// CHECK: %[[RES1:.*]] = bmodelica.variable_get @_cse0
// CHECK-NEXT: bmodelica.equation_side %[[RES1]]
%2 = bmodelica.call @foo(%1) : (f64) -> f64
%rhs = bmodelica.equation_side %2 : tuple<f64>
bmodelica.equation_sides %lhs, %rhs : tuple<f64>, tuple<f64>
Expand All @@ -70,7 +43,11 @@ module @Test {
bmodelica.dynamic {
bmodelica.equation_instance %t0 : !bmodelica.equation
bmodelica.equation_instance %t1 : !bmodelica.equation
bmodelica.equation_instance %t2 : !bmodelica.equation
}
// CHECK: %[[TEMPLATE:.*]] = bmodelica.equation_template
// CHECK: bmodelica.variable_get @[[CSE]]

// CHECK: bmodelica.dynamic
// CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE]]
}
}

0 comments on commit 73eb3c2

Please sign in to comment.