From 73eb3c2ae366ffdedf8032c88c6e6615219ab2a0 Mon Sep 17 00:00:00 2001 From: Arran Gabriel Date: Thu, 10 Oct 2024 22:11:43 +0200 Subject: [PATCH] Skip equations with induction variables and add a basic test. --- .../BaseModelica/Transforms/CallCSE.cpp | 36 ++++++------- .../Transforms/CallCSE/basic-cse.mlir | 51 +++++-------------- 2 files changed, 30 insertions(+), 57 deletions(-) diff --git a/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp b/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp index b9af1e3d1..4e8f2abe3 100644 --- a/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp +++ b/lib/Dialect/BaseModelica/Transforms/CallCSE.cpp @@ -80,8 +80,12 @@ void CallCSEPass::collectCallOps(ModelOp modelOp, templateOps.insert(equationOp.getTemplate()); } - for (const auto templateOp : templateOps) { - templateOp->walk([&](const CallOp callOp) { callOps.push_back(callOp); }); + for (auto templateOp : templateOps) { + // Skip templates with induction variables + if (!templateOp.getInductionVariables().empty()) { + continue; + } + templateOp->walk([&](CallOp callOp) { callOps.push_back(callOp); }); } } @@ -96,7 +100,7 @@ void CallCSEPass::buildCallEquivalenceGroups( auto *equivalenceGroup = find_if(callEquivalenceGroups, [&](llvm::SmallVector &group) { // front() is safe as there are no empty groups - const auto representative = mlir::cast( + auto representative = mlir::cast( group.front().getOperation()); return callExpression.isEquivalent(representative, symbolTableCollection); @@ -118,6 +122,7 @@ mlir::Operation *CallCSEPass::cloneDefUseChain(mlir::Operation *op, std::vector toClone; std::vector worklist({op}); + // DFS through the def-use chain of `op` while (!worklist.empty()) { auto *current = worklist.back(); worklist.pop_back(); @@ -134,13 +139,6 @@ mlir::Operation *CallCSEPass::cloneDefUseChain(mlir::Operation *op, root = rewriter.clone(*opToClone, mapping); } return root; - // for (auto operand : op->getOperands()) { - // if (auto *defOp = operand.getDefiningOp()) { - // cloneDefUseChain(defOp, mapping, rewriter); - // } - // } - - // return rewriter.clone(*op, mapping); } EquationTemplateOp @@ -149,29 +147,27 @@ CallCSEPass::emitCse(ModelOp modelOp, const int emittedCSEs, mlir::RewriterBase &rewriter) { assert(!equivalenceGroup.empty() && "equivalenceGroup cannot be empty"); auto representative = equivalenceGroup.front(); + const auto loc = representative.getLoc(); // Emit CSE variable rewriter.setInsertionPointToStart(modelOp.getBody()); auto cseVariable = rewriter.create( - rewriter.getUnknownLoc(), "_cse" + std::to_string(emittedCSEs), + loc, "_cse" + std::to_string(emittedCSEs), VariableType::wrap(representative.getResult(0).getType())); // Create CSE variable driver equation rewriter.setInsertionPointToEnd(modelOp.getBody()); - auto equationTemplateOp = - rewriter.create(rewriter.getUnknownLoc()); + auto equationTemplateOp = rewriter.create(loc); rewriter.setInsertionPointToStart(equationTemplateOp.createBody(0)); mlir::IRMapping mapping; auto *clonedRepresentative = cloneDefUseChain(representative, mapping, rewriter); - const auto cseGetOp = - rewriter.create(rewriter.getUnknownLoc(), cseVariable); - auto lhsOp = rewriter.create(rewriter.getUnknownLoc(), - cseGetOp->getResults()); - auto rhsOp = rewriter.create( - rewriter.getUnknownLoc(), clonedRepresentative->getResults()); - rewriter.create(rewriter.getUnknownLoc(), lhsOp, rhsOp); + auto cseGetOp = rewriter.create(loc, cseVariable); + auto lhsOp = rewriter.create(loc, cseGetOp->getResults()); + auto rhsOp = + rewriter.create(loc, clonedRepresentative->getResults()); + rewriter.create(loc, lhsOp, rhsOp); // Replace calls with get to CSE variable for (auto &callOp : equivalenceGroup) { diff --git a/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir b/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir index 77499bc32..ba1c41fbb 100644 --- a/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir +++ b/test/Dialect/BaseModelica/Transforms/CallCSE/basic-cse.mlir @@ -14,54 +14,27 @@ module @Test { } bmodelica.model @M { + // CHECK: bmodelica.variable @[[CSE:_cse0]] bmodelica.variable @x : !bmodelica.variable bmodelica.variable @y : !bmodelica.variable - bmodelica.variable @z : !bmodelica.variable - - //%t0 = bmodelica.equation_template inductions = [] { - // %0 = bmodelica.variable_get @x : f64 - // %lhs = bmodelica.equation_side %0 : tuple - // %1 = bmodelica.constant 23.0 : f64 - // %2 = bmodelica.call @foo(%1) : (f64) -> f64 - // %rhs = bmodelica.equation_side %2 : tuple - // bmodelica.equation_sides %lhs, %rhs : tuple, tuple - //} - - //%t1 = bmodelica.equation_template inductions = [] { - // %0 = bmodelica.variable_get @y : f64 - // %lhs = bmodelica.equation_side %0 : tuple - // %1 = bmodelica.constant 23.0 : f64 - // %2 = bmodelica.call @foo(%1) : (f64) -> f64 - // %rhs = bmodelica.equation_side %2 : tuple - // bmodelica.equation_sides %lhs, %rhs : tuple, tuple - //} %t0 = bmodelica.equation_template inductions = [] { %0 = bmodelica.variable_get @x : f64 %lhs = bmodelica.equation_side %0 : tuple - %1 = bmodelica.constant 23.0 : f64 - %2 = bmodelica.constant 25.0 : f64 - %3 = bmodelica.add %1, %2 : (f64, f64) -> f64 - %4 = bmodelica.call @foo(%3) : (f64) -> f64 - %rhs = bmodelica.equation_side %4 : tuple + %1 = bmodelica.constant 1.0 : f64 + // CHECK: %[[RES0:.*]] = bmodelica.variable_get @_cse0 + // CHECK-NEXT: bmodelica.equation_side %[[RES0]] + %2 = bmodelica.call @foo(%1) : (f64) -> f64 + %rhs = bmodelica.equation_side %2 : tuple bmodelica.equation_sides %lhs, %rhs : tuple, tuple } %t1 = bmodelica.equation_template inductions = [] { %0 = bmodelica.variable_get @y : f64 %lhs = bmodelica.equation_side %0 : tuple - %1 = bmodelica.constant 23.0 : f64 - %2 = bmodelica.constant 25.0 : f64 - %3 = bmodelica.add %1, %2 : (f64, f64) -> f64 - %4 = bmodelica.call @foo(%3) : (f64) -> f64 - %rhs = bmodelica.equation_side %4 : tuple - bmodelica.equation_sides %lhs, %rhs : tuple, tuple - } - - %t2 = bmodelica.equation_template inductions = [] { - %0 = bmodelica.variable_get @z : f64 - %lhs = bmodelica.equation_side %0 : tuple - %1 = bmodelica.constant 57.0 : f64 + %1 = bmodelica.constant 1.0 : f64 + // CHECK: %[[RES1:.*]] = bmodelica.variable_get @_cse0 + // CHECK-NEXT: bmodelica.equation_side %[[RES1]] %2 = bmodelica.call @foo(%1) : (f64) -> f64 %rhs = bmodelica.equation_side %2 : tuple bmodelica.equation_sides %lhs, %rhs : tuple, tuple @@ -70,7 +43,11 @@ module @Test { bmodelica.dynamic { bmodelica.equation_instance %t0 : !bmodelica.equation bmodelica.equation_instance %t1 : !bmodelica.equation - bmodelica.equation_instance %t2 : !bmodelica.equation } + // CHECK: %[[TEMPLATE:.*]] = bmodelica.equation_template + // CHECK: bmodelica.variable_get @[[CSE]] + + // CHECK: bmodelica.dynamic + // CHECK-NEXT: bmodelica.equation_instance %[[TEMPLATE]] } }