From dda7c8bc00d890d6b8455e4ed66174c3504a3104 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Mon, 7 Oct 2024 11:37:31 +0200 Subject: [PATCH 1/4] Cancel modifications if cycles solving by substitution failed --- .../Transforms/SCCSolvingBySubstitution.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp index 5e88da95a..3832fb140 100644 --- a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp +++ b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp @@ -544,7 +544,8 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( llvm::SmallVector allNewEquations; llvm::DenseSet nonExplicitableEquations; - auto createSCCsFn = llvm::make_scope_exit([&]() { + auto createSCCsOnSuccessFn = llvm::make_scope_exit([&]() { + // Erase the equations that have been discarded. for (MatchedEquationInstanceOp equation : toBeErased) { rewriter.eraseOp(equation); } @@ -567,17 +568,10 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( return mlir::failure(); } - if (cycles.empty()) { - return mlir::success(); - } - - bool atLeastOneChanged; int64_t currentIteration = 0; while (!cycles.empty() && currentIteration++ < maxIterations) { // Try to solve one cycle. - atLeastOneChanged = false; - for (const auto& cycle : llvm::enumerate(cycles)) { currentEquations.clear(); @@ -644,16 +638,10 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( allNewEquations.push_back(newEquation); } - atLeastOneChanged = true; break; } } - if (!atLeastOneChanged) { - // The IR can't be modified more. - return mlir::LogicalResult::success(); - } - // Search for the remaining cycles. cycles.clear(); @@ -675,6 +663,16 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( } } + if (!cycles.empty()) { + // Could not solve the cycle. + // Cancel the modifications and keep the original SCC. + createSCCsOnSuccessFn.release(); + + for (MatchedEquationInstanceOp equationOp : allNewEquations) { + equationOp.erase(); + } + } + return mlir::success(); } From 17e47b0fa3b10285bf86d38bdd56375fd2c0ef26 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Mon, 7 Oct 2024 11:37:53 +0200 Subject: [PATCH 2/4] Add comments --- .../Transforms/SCCSolvingBySubstitution.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp index 3832fb140..bbce34ddb 100644 --- a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp +++ b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp @@ -91,6 +91,8 @@ namespace ModelOp modelOp, SCCOp scc); + /// Detect the SCCs among a set of equations and create the SCC + /// operations containing them. void createSCCs( mlir::RewriterBase& rewriter, mlir::SymbolTableCollection& symbolTableCollection, @@ -534,14 +536,22 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( ModelOp modelOp, SCCOp scc) { + // The equations that initially compose the SCC. llvm::SmallVector originalEquations; scc.collectEquations(originalEquations); + // The equations to be considered during an iteration. + // Initially, they are the equations within the SCC. llvm::SmallVector currentEquations( originalEquations.begin(), originalEquations.end()); + // The equations to be erased after having solved the SCC. llvm::DenseSet toBeErased; + + // The newly inserted equations. llvm::SmallVector allNewEquations; + + // The set of equations that have deemed to be non-explicitable. llvm::DenseSet nonExplicitableEquations; auto createSCCsOnSuccessFn = llvm::make_scope_exit([&]() { @@ -550,6 +560,7 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( rewriter.eraseOp(equation); } + // Collect the remaining equations. llvm::SmallVector resultEquations; for (MatchedEquationInstanceOp equation : @@ -557,10 +568,12 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( resultEquations.push_back(equation); } + // Compute the new SCCs and erase the original one. createSCCs(rewriter, symbolTableCollection, modelOp, scc, resultEquations); rewriter.eraseOp(scc); }); + // Compute the cyclic dependencies within the SCC. llvm::SmallVector cycles; if (mlir::failed(getCycles( From 922f6e456d2292b0d00b5ddb7aad649238001e9b Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Mon, 7 Oct 2024 11:38:01 +0200 Subject: [PATCH 3/4] Reformat code --- .../Transforms/SCCSolvingBySubstitution.cpp | 443 ++++++++---------- 1 file changed, 203 insertions(+), 240 deletions(-) diff --git a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp index bbce34ddb..ada21095b 100644 --- a/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp +++ b/lib/Dialect/BaseModelica/Transforms/SCCSolvingBySubstitution.cpp @@ -9,39 +9,35 @@ #define DEBUG_TYPE "scc-solving-by-substitution" -namespace mlir::bmodelica -{ +namespace mlir::bmodelica { #define GEN_PASS_DEF_SCCSOLVINGBYSUBSTITUTIONPASS #include "marco/Dialect/BaseModelica/Transforms/Passes.h.inc" -} +} // namespace mlir::bmodelica using namespace ::mlir::bmodelica; using namespace ::mlir::bmodelica::bridge; -namespace -{ - struct CyclicEquation - { - MatchedEquationInstanceOp equation; - IndexSet equationIndices; - VariableAccess writeAccess; - IndexSet writtenVariableIndices; - VariableAccess readAccess; - IndexSet readVariableIndices; - }; -} +namespace { +struct CyclicEquation { + MatchedEquationInstanceOp equation; + IndexSet equationIndices; + VariableAccess writeAccess; + IndexSet writtenVariableIndices; + VariableAccess readAccess; + IndexSet readVariableIndices; +}; +} // namespace using Cycle = llvm::SmallVector; -static void printCycle(llvm::raw_ostream& os, const Cycle& cycle) -{ - for (const CyclicEquation& cyclicEquation : cycle) { +static void printCycle(llvm::raw_ostream &os, const Cycle &cycle) { + for (const CyclicEquation &cyclicEquation : cycle) { os << cyclicEquation.writeAccess.getVariable() << " -> "; } os << cycle.back().readAccess.getVariable() << "\n"; - for (const CyclicEquation& cyclicEquation : cycle) { + for (const CyclicEquation &cyclicEquation : cycle) { MatchedEquationInstanceOp equationOp = cyclicEquation.equation; os << "[writing " << cyclicEquation.writeAccess.getVariable() << "] "; equationOp.printInline(llvm::dbgs()); @@ -49,72 +45,63 @@ static void printCycle(llvm::raw_ostream& os, const Cycle& cycle) } } -namespace -{ - class SCCSolvingBySubstitutionPass - : public mlir::bmodelica::impl::SCCSolvingBySubstitutionPassBase< - SCCSolvingBySubstitutionPass>, - public VariableAccessAnalysis::AnalysisProvider - { - public: - using SCCSolvingBySubstitutionPassBase - ::SCCSolvingBySubstitutionPassBase; - - void runOnOperation() override; - - std::optional> - getCachedVariableAccessAnalysis(EquationTemplateOp op) override; - - private: - std::optional> - getVariableAccessAnalysis( - EquationTemplateOp equationTemplate, - mlir::SymbolTableCollection& symbolTableCollection); - - mlir::LogicalResult processModelOp(ModelOp modelOp); - - mlir::LogicalResult getCycles( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - llvm::ArrayRef equations); - - mlir::LogicalResult solveCycles( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - llvm::ArrayRef SCCs); - - mlir::LogicalResult solveCycle( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - SCCOp scc); - - /// Detect the SCCs among a set of equations and create the SCC - /// operations containing them. - void createSCCs( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - SCCOp originalSCC, - llvm::ArrayRef equations); - - mlir::LogicalResult cleanModelOp(ModelOp modelOp); - }; -} - -void SCCSolvingBySubstitutionPass::runOnOperation() -{ +namespace { +class SCCSolvingBySubstitutionPass + : public mlir::bmodelica::impl::SCCSolvingBySubstitutionPassBase< + SCCSolvingBySubstitutionPass>, + public VariableAccessAnalysis::AnalysisProvider { +public: + using SCCSolvingBySubstitutionPassBase< + SCCSolvingBySubstitutionPass>::SCCSolvingBySubstitutionPassBase; + + void runOnOperation() override; + + std::optional> + getCachedVariableAccessAnalysis(EquationTemplateOp op) override; + +private: + std::optional> + getVariableAccessAnalysis(EquationTemplateOp equationTemplate, + mlir::SymbolTableCollection &symbolTableCollection); + + mlir::LogicalResult processModelOp(ModelOp modelOp); + + mlir::LogicalResult + getCycles(llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, ModelOp modelOp, + llvm::ArrayRef equations); + + mlir::LogicalResult + solveCycles(mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + ModelOp modelOp, llvm::ArrayRef SCCs); + + mlir::LogicalResult + solveCycle(mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + ModelOp modelOp, SCCOp scc); + + /// Detect the SCCs among a set of equations and create the SCC + /// operations containing them. + void createSCCs(mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, + ModelOp modelOp, SCCOp originalSCC, + llvm::ArrayRef equations); + + mlir::LogicalResult cleanModelOp(ModelOp modelOp); +}; +} // namespace + +void SCCSolvingBySubstitutionPass::runOnOperation() { llvm::SmallVector modelOps; - walkClasses(getOperation(), [&](mlir::Operation* op) { + walkClasses(getOperation(), [&](mlir::Operation *op) { if (auto modelOp = mlir::dyn_cast(op)) { modelOps.push_back(modelOp); } }); - auto runFn = [&](mlir::Operation* op) { + auto runFn = [&](mlir::Operation *op) { auto modelOp = mlir::cast(op); LLVM_DEBUG(llvm::dbgs() << "Input model:\n" << modelOp << "\n"); @@ -130,18 +117,18 @@ void SCCSolvingBySubstitutionPass::runOnOperation() return mlir::success(); }; - if (mlir::failed(mlir::failableParallelForEach( - &getContext(), modelOps, runFn))) { + if (mlir::failed( + mlir::failableParallelForEach(&getContext(), modelOps, runFn))) { return signalPassFailure(); } } std::optional> -SCCSolvingBySubstitutionPass::getCachedVariableAccessAnalysis(EquationTemplateOp op) -{ +SCCSolvingBySubstitutionPass::getCachedVariableAccessAnalysis( + EquationTemplateOp op) { mlir::ModuleOp moduleOp = getOperation(); - mlir::Operation* parentOp = op->getParentOp(); - llvm::SmallVector parentOps; + mlir::Operation *parentOp = op->getParentOp(); + llvm::SmallVector parentOps; while (parentOp != moduleOp) { parentOps.push_back(parentOp); @@ -150,7 +137,7 @@ SCCSolvingBySubstitutionPass::getCachedVariableAccessAnalysis(EquationTemplateOp mlir::AnalysisManager analysisManager = getAnalysisManager(); - for (mlir::Operation* currentParentOp : llvm::reverse(parentOps)) { + for (mlir::Operation *currentParentOp : llvm::reverse(parentOps)) { analysisManager = analysisManager.nest(currentParentOp); } @@ -160,11 +147,10 @@ SCCSolvingBySubstitutionPass::getCachedVariableAccessAnalysis(EquationTemplateOp std::optional> SCCSolvingBySubstitutionPass::getVariableAccessAnalysis( EquationTemplateOp equationTemplate, - mlir::SymbolTableCollection& symbolTableCollection) -{ + mlir::SymbolTableCollection &symbolTableCollection) { mlir::ModuleOp moduleOp = getOperation(); - mlir::Operation* parentOp = equationTemplate->getParentOp(); - llvm::SmallVector parentOps; + mlir::Operation *parentOp = equationTemplate->getParentOp(); + llvm::SmallVector parentOps; while (parentOp != moduleOp) { parentOps.push_back(parentOp); @@ -173,7 +159,7 @@ SCCSolvingBySubstitutionPass::getVariableAccessAnalysis( mlir::AnalysisManager analysisManager = getAnalysisManager(); - for (mlir::Operation* op : llvm::reverse(parentOps)) { + for (mlir::Operation *op : llvm::reverse(parentOps)) { analysisManager = analysisManager.nest(op); } @@ -183,7 +169,7 @@ SCCSolvingBySubstitutionPass::getVariableAccessAnalysis( return *analysis; } - auto& analysis = analysisManager.getChildAnalysis( + auto &analysis = analysisManager.getChildAnalysis( equationTemplate); if (mlir::failed(analysis.initialize(symbolTableCollection))) { @@ -193,9 +179,8 @@ SCCSolvingBySubstitutionPass::getVariableAccessAnalysis( return std::reference_wrapper(analysis); } -mlir::LogicalResult SCCSolvingBySubstitutionPass::processModelOp( - ModelOp modelOp) -{ +mlir::LogicalResult +SCCSolvingBySubstitutionPass::processModelOp(ModelOp modelOp) { mlir::IRRewriter rewriter(&getContext()); // Collect the equations. @@ -209,10 +194,10 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::processModelOp( // Perform the solving process on the 'initial conditions' model. if (!initialSCCs.empty()) { - if (mlir::failed(solveCycles( - rewriter, symbolTableCollection, modelOp, initialSCCs))) { + if (mlir::failed(solveCycles(rewriter, symbolTableCollection, modelOp, + initialSCCs))) { modelOp.emitError() - << "Cycles solving failed for the 'initial conditions' model"; + << "Cycles solving failed for the 'initial conditions' model"; return mlir::failure(); } @@ -220,8 +205,8 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::processModelOp( // Perform the solving process on the 'main' model. if (!mainSCCs.empty()) { - if (mlir::failed(solveCycles( - rewriter, symbolTableCollection, modelOp, mainSCCs))) { + if (mlir::failed( + solveCycles(rewriter, symbolTableCollection, modelOp, mainSCCs))) { modelOp.emitError() << "Cycles solving failed for the 'main' model"; return mlir::failure(); } @@ -231,11 +216,9 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::processModelOp( } mlir::LogicalResult SCCSolvingBySubstitutionPass::getCycles( - llvm::SmallVectorImpl& result, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - llvm::ArrayRef equations) -{ + llvm::SmallVectorImpl &result, + mlir::SymbolTableCollection &symbolTableCollection, ModelOp modelOp, + llvm::ArrayRef equations) { LLVM_DEBUG({ llvm::dbgs() << "Searching cycles among the following equations:\n"; @@ -256,13 +239,13 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::getCycles( }); llvm::SmallVector> variableBridges; - llvm::DenseMap variablesMap; + llvm::DenseMap variablesMap; llvm::SmallVector> equationBridges; - llvm::SmallVector equationPtrs; + llvm::SmallVector equationPtrs; for (VariableOp variableOp : modelOp.getVariables()) { - auto& bridge = variableBridges.emplace_back( - VariableBridge::build(variableOp)); + auto &bridge = + variableBridges.emplace_back(VariableBridge::build(variableOp)); auto symbolRefAttr = mlir::SymbolRefAttr::get(variableOp.getSymNameAttr()); variablesMap[symbolRefAttr] = bridge.get(); @@ -272,38 +255,37 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::getCycles( auto variableAccessAnalysis = getVariableAccessAnalysis( equation.getTemplate(), symbolTableCollection); - auto& bridge = equationBridges.emplace_back( - MatchedEquationBridge::build( - equation, symbolTableCollection, *variableAccessAnalysis, - variablesMap)); + auto &bridge = equationBridges.emplace_back( + MatchedEquationBridge::build(equation, symbolTableCollection, + *variableAccessAnalysis, variablesMap)); equationPtrs.push_back(bridge.get()); } - using DependencyGraph = marco::modeling::DependencyGraph< - VariableBridge*, MatchedEquationBridge*>; + using DependencyGraph = + marco::modeling::DependencyGraph; DependencyGraph dependencyGraph(&getContext()); dependencyGraph.addEquations(equationPtrs); auto cycles = dependencyGraph.getEquationsCycles(); - for (auto& cycle : cycles) { - auto& resultCycle = result.emplace_back(); - - for (auto& cyclicEquation : cycle) { - resultCycle.emplace_back(CyclicEquation{ - dependencyGraph[cyclicEquation.equation]->op, - std::move(cyclicEquation.equationIndices), - std::move(cyclicEquation.writeAccess).getProperty(), - std::move(cyclicEquation.writtenVariableIndices), - std::move(cyclicEquation.readAccess).getProperty(), - std::move(cyclicEquation.readVariableIndices) - }); + for (auto &cycle : cycles) { + auto &resultCycle = result.emplace_back(); + + for (auto &cyclicEquation : cycle) { + resultCycle.emplace_back( + CyclicEquation{dependencyGraph[cyclicEquation.equation]->op, + std::move(cyclicEquation.equationIndices), + std::move(cyclicEquation.writeAccess).getProperty(), + std::move(cyclicEquation.writtenVariableIndices), + std::move(cyclicEquation.readAccess).getProperty(), + std::move(cyclicEquation.readVariableIndices)}); } } - llvm::sort(result, [](const Cycle& first, const Cycle& second) { + llvm::sort(result, [](const Cycle &first, const Cycle &second) { return first.size() > second.size(); }); @@ -312,13 +294,11 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::getCycles( } static mlir::LogicalResult solveCycle( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - const Cycle& cycle, + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, const Cycle &cycle, size_t index, - llvm::SmallVectorImpl& newEquations, - llvm::DenseSet& nonExplicitableEquations) -{ + llvm::SmallVectorImpl &newEquations, + llvm::DenseSet &nonExplicitableEquations) { if (index + 1 == cycle.size()) { MatchedEquationInstanceOp equationOp = cycle[index].equation; rewriter.setInsertionPoint(equationOp); @@ -337,42 +317,41 @@ static mlir::LogicalResult solveCycle( } }); - if (mlir::failed(solveCycle( - rewriter, symbolTableCollection, cycle, index + 1, - writingEquations, nonExplicitableEquations))) { + if (mlir::failed(solveCycle(rewriter, symbolTableCollection, cycle, index + 1, + writingEquations, nonExplicitableEquations))) { return mlir::failure(); } - const CyclicEquation& readingEquation = cycle[index]; + const CyclicEquation &readingEquation = cycle[index]; MatchedEquationInstanceOp readingEquationOp = readingEquation.equation; LLVM_DEBUG(llvm::dbgs() << "Cycle index: " << index << "\n"); LLVM_DEBUG({ - llvm::dbgs() << "Reading equation:\n"; - readingEquationOp.printInline(llvm::dbgs()); + llvm::dbgs() << "Reading equation:\n"; + readingEquationOp.printInline(llvm::dbgs()); - llvm::dbgs() << "\n" - << "Read variable: " - << readingEquation.readAccess.getVariable() - << "\n"; + llvm::dbgs() << "\n" + << "Read variable: " + << readingEquation.readAccess.getVariable() << "\n"; }); - const AccessFunction& readAccessFunction = + const AccessFunction &readAccessFunction = readingEquation.readAccess.getAccessFunction(); for (MatchedEquationInstanceOp writingEquationOp : writingEquations) { LLVM_DEBUG({ - llvm::dbgs() << "Writing equation:\n"; - writingEquationOp.printInline(llvm::dbgs()); - llvm::dbgs() << "\n"; + llvm::dbgs() << "Writing equation:\n"; + writingEquationOp.printInline(llvm::dbgs()); + llvm::dbgs() << "\n"; }); MatchedEquationInstanceOp explicitWritingEquationOp = writingEquationOp.cloneAndExplicitate(rewriter, symbolTableCollection); if (!explicitWritingEquationOp) { - LLVM_DEBUG(llvm::dbgs() << "The writing equation can't be made explicit\n"); + LLVM_DEBUG(llvm::dbgs() + << "The writing equation can't be made explicit\n"); nonExplicitableEquations.insert(writingEquationOp); return mlir::failure(); } @@ -383,9 +362,8 @@ static mlir::LogicalResult solveCycle( llvm::dbgs() << "\n"; }); - auto removeExplicitEquation = llvm::make_scope_exit([&]() { - rewriter.eraseOp(explicitWritingEquationOp); - }); + auto removeExplicitEquation = llvm::make_scope_exit( + [&]() { rewriter.eraseOp(explicitWritingEquationOp); }); auto explicitWriteAccess = explicitWritingEquationOp.getMatchedAccess(symbolTableCollection); @@ -394,7 +372,7 @@ static mlir::LogicalResult solveCycle( return mlir::failure(); } - const AccessFunction& writeAccessFunction = + const AccessFunction &writeAccessFunction = explicitWriteAccess->getAccessFunction(); IndexSet writingEquationIndices = @@ -411,12 +389,12 @@ static mlir::LogicalResult solveCycle( readingEquationIndices = readAccessFunction.inverseMap( writtenVariableIndices, readingEquation.equationIndices); - readingEquationIndices = readingEquationIndices.intersect( - readingEquation.equationIndices); + readingEquationIndices = + readingEquationIndices.intersect(readingEquation.equationIndices); } std::optional> - optionalReadingEquationIndices = std::nullopt; + optionalReadingEquationIndices = std::nullopt; if (!readingEquationIndices.empty()) { optionalReadingEquationIndices = @@ -424,12 +402,9 @@ static mlir::LogicalResult solveCycle( } if (mlir::failed(readingEquationOp.cloneWithReplacedAccess( - rewriter, - optionalReadingEquationIndices, - readingEquation.readAccess, - explicitWritingEquationOp.getTemplate(), - *explicitWriteAccess, - newEquations))) { + rewriter, optionalReadingEquationIndices, + readingEquation.readAccess, explicitWritingEquationOp.getTemplate(), + *explicitWriteAccess, newEquations))) { return mlir::failure(); } } @@ -447,38 +422,35 @@ static mlir::LogicalResult solveCycle( } static mlir::LogicalResult solveCycle( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - const Cycle& cycle, - llvm::SmallVectorImpl& newEquations, - llvm::DenseSet& nonExplicitableEquations) -{ + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, const Cycle &cycle, + llvm::SmallVectorImpl &newEquations, + llvm::DenseSet &nonExplicitableEquations) { LLVM_DEBUG({ llvm::dbgs() << "Solving cycle composed by the following equations:\n"; - for (const CyclicEquation& cyclicEquation : cycle) { + for (const CyclicEquation &cyclicEquation : cycle) { llvm::dbgs() << cyclicEquation.writeAccess.getVariable() << " -> "; } llvm::dbgs() << cycle.back().readAccess.getVariable() << "\n"; - for (const CyclicEquation& cyclicEquation : cycle) { + for (const CyclicEquation &cyclicEquation : cycle) { MatchedEquationInstanceOp equationOp = cyclicEquation.equation; - llvm::dbgs() << "[writing " << cyclicEquation.writeAccess.getVariable() << "] "; + llvm::dbgs() << "[writing " << cyclicEquation.writeAccess.getVariable() + << "] "; equationOp.printInline(llvm::dbgs()); llvm::dbgs() << "\n"; } }); - return ::solveCycle( - rewriter, symbolTableCollection, cycle, 0, - newEquations, nonExplicitableEquations); + return ::solveCycle(rewriter, symbolTableCollection, cycle, 0, newEquations, + nonExplicitableEquations); } -static bool isContainedInBiggerCycle( - llvm::ArrayRef cycles, size_t cycleIndex) -{ - const Cycle& cycle = cycles[cycleIndex]; +static bool isContainedInBiggerCycle(llvm::ArrayRef cycles, + size_t cycleIndex) { + const Cycle &cycle = cycles[cycleIndex]; llvm::DenseSet involvedEquations; for (size_t i = 1, e = cycle.size(); i < e; ++i) { @@ -489,7 +461,7 @@ static bool isContainedInBiggerCycle( // currently analyzed cycle. for (size_t otherCycleIndex = 0; otherCycleIndex < cycleIndex; ++otherCycleIndex) { - const Cycle& otherCycle = cycles[otherCycleIndex]; + const Cycle &otherCycle = cycles[otherCycleIndex]; if (otherCycle.size() <= cycle.size()) { break; @@ -515,14 +487,12 @@ static bool isContainedInBiggerCycle( } mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycles( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - llvm::ArrayRef SCCs) -{ + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, ModelOp modelOp, + llvm::ArrayRef SCCs) { for (SCCOp scc : SCCs) { - if (mlir::failed(solveCycle( - rewriter, symbolTableCollection, modelOp, scc))) { + if (mlir::failed( + solveCycle(rewriter, symbolTableCollection, modelOp, scc))) { return mlir::failure(); } } @@ -531,11 +501,9 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycles( } mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - SCCOp scc) -{ + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, ModelOp modelOp, + SCCOp scc) { // The equations that initially compose the SCC. llvm::SmallVector originalEquations; scc.collectEquations(originalEquations); @@ -576,8 +544,8 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( // Compute the cyclic dependencies within the SCC. llvm::SmallVector cycles; - if (mlir::failed(getCycles( - cycles, symbolTableCollection, modelOp, currentEquations))) { + if (mlir::failed(getCycles(cycles, symbolTableCollection, modelOp, + currentEquations))) { return mlir::failure(); } @@ -585,12 +553,13 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( while (!cycles.empty() && currentIteration++ < maxIterations) { // Try to solve one cycle. - for (const auto& cycle : llvm::enumerate(cycles)) { + for (const auto &cycle : llvm::enumerate(cycles)) { currentEquations.clear(); if (isContainedInBiggerCycle(cycles, cycle.index())) { LLVM_DEBUG({ - llvm::dbgs() << "The following cycle is skipped for being part of a bigger SCC\n"; + llvm::dbgs() << "The following cycle is skipped for being part of a " + "bigger SCC\n"; printCycle(llvm::dbgs(), cycle.value()); }); @@ -601,7 +570,8 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( for (size_t i = 1, e = cycle.value().size(); i < e; ++i) { if (nonExplicitableEquations.contains(cycle.value()[i].equation)) { LLVM_DEBUG({ - llvm::dbgs() << "The following cycle is skipped for having a non-explicitable equation\n"; + llvm::dbgs() << "The following cycle is skipped for having a " + "non-explicitable equation\n"; printCycle(llvm::dbgs(), cycle.value()); }); @@ -611,9 +581,9 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( llvm::SmallVector newEquations; - if (mlir::succeeded(::solveCycle( - rewriter, symbolTableCollection, cycle.value(), - newEquations, nonExplicitableEquations))) { + if (mlir::succeeded(::solveCycle(rewriter, symbolTableCollection, + cycle.value(), newEquations, + nonExplicitableEquations))) { MatchedEquationInstanceOp firstEquation = cycle.value()[0].equation; IndexSet originalIterationSpace = firstEquation.getIterationSpace(); @@ -626,9 +596,9 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( rewriter.setInsertionPoint(firstEquation); - for (const MultidimensionalRange& range : llvm::make_range( - remainingIndices.rangesBegin(), - remainingIndices.rangesEnd())) { + for (const MultidimensionalRange &range : + llvm::make_range(remainingIndices.rangesBegin(), + remainingIndices.rangesEnd())) { auto clonedOp = mlir::cast( rewriter.clone(*firstEquation.getOperation())); @@ -636,9 +606,8 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( MultidimensionalRange explicitRange = range.takeFirstDimensions(indices->getValue().rank()); - clonedOp.setIndicesAttr( - MultidimensionalRangeAttr::get( - rewriter.getContext(), std::move(explicitRange))); + clonedOp.setIndicesAttr(MultidimensionalRangeAttr::get( + rewriter.getContext(), std::move(explicitRange))); } currentEquations.push_back(clonedOp); @@ -670,8 +639,8 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( } } - if (mlir::failed(getCycles( - cycles, symbolTableCollection, modelOp, currentEquations))) { + if (mlir::failed(getCycles(cycles, symbolTableCollection, modelOp, + currentEquations))) { return mlir::failure(); } } @@ -690,22 +659,19 @@ mlir::LogicalResult SCCSolvingBySubstitutionPass::solveCycle( } void SCCSolvingBySubstitutionPass::createSCCs( - mlir::RewriterBase& rewriter, - mlir::SymbolTableCollection& symbolTableCollection, - ModelOp modelOp, - SCCOp originalSCC, - llvm::ArrayRef equations) -{ + mlir::RewriterBase &rewriter, + mlir::SymbolTableCollection &symbolTableCollection, ModelOp modelOp, + SCCOp originalSCC, llvm::ArrayRef equations) { mlir::OpBuilder::InsertionGuard guard(rewriter); llvm::SmallVector> variableBridges; - llvm::DenseMap variablesMap; + llvm::DenseMap variablesMap; llvm::SmallVector> equationBridges; - llvm::SmallVector equationPtrs; + llvm::SmallVector equationPtrs; for (VariableOp variableOp : modelOp.getVariables()) { - auto& bridge = variableBridges.emplace_back( - VariableBridge::build(variableOp)); + auto &bridge = + variableBridges.emplace_back(VariableBridge::build(variableOp)); auto symbolRefAttr = mlir::SymbolRefAttr::get(variableOp.getSymNameAttr()); variablesMap[symbolRefAttr] = bridge.get(); @@ -715,16 +681,16 @@ void SCCSolvingBySubstitutionPass::createSCCs( auto variableAccessAnalysis = getVariableAccessAnalysis( equation.getTemplate(), symbolTableCollection); - auto& bridge = equationBridges.emplace_back( - MatchedEquationBridge::build( - equation, symbolTableCollection, *variableAccessAnalysis, - variablesMap)); + auto &bridge = equationBridges.emplace_back( + MatchedEquationBridge::build(equation, symbolTableCollection, + *variableAccessAnalysis, variablesMap)); equationPtrs.push_back(bridge.get()); } - using DependencyGraph = marco::modeling::DependencyGraph< - VariableBridge*, MatchedEquationBridge*>; + using DependencyGraph = + marco::modeling::DependencyGraph; DependencyGraph dependencyGraph(&getContext()); dependencyGraph.addEquations(equationPtrs); @@ -734,21 +700,21 @@ void SCCSolvingBySubstitutionPass::createSCCs( rewriter.setInsertionPointAfter(originalSCC); - for (const DependencyGraph::SCC& scc : SCCs) { + for (const DependencyGraph::SCC &scc : SCCs) { auto sccOp = rewriter.create(modelOp.getLoc()); mlir::OpBuilder::InsertionGuard sccGuard(rewriter); rewriter.setInsertionPointToStart( rewriter.createBlock(&sccOp.getBodyRegion())); - for (const auto& sccElement : scc) { - const auto& equation = dependencyGraph[*sccElement]; - const IndexSet& indices = sccElement.getIndices(); + for (const auto &sccElement : scc) { + const auto &equation = dependencyGraph[*sccElement]; + const IndexSet &indices = sccElement.getIndices(); size_t numOfInductions = equation->op.getInductionVariables().size(); bool isScalarEquation = numOfInductions == 0; - for (const MultidimensionalRange& matchedEquationRange : + for (const MultidimensionalRange &matchedEquationRange : llvm::make_range(indices.rangesBegin(), indices.rangesEnd())) { auto clonedOp = mlir::cast( rewriter.clone(*equation->op.getOperation())); @@ -765,23 +731,20 @@ void SCCSolvingBySubstitutionPass::createSCCs( } } -mlir::LogicalResult SCCSolvingBySubstitutionPass::cleanModelOp(ModelOp modelOp) -{ +mlir::LogicalResult +SCCSolvingBySubstitutionPass::cleanModelOp(ModelOp modelOp) { mlir::RewritePatternSet patterns(&getContext()); ModelOp::getCleaningPatterns(patterns, &getContext()); return mlir::applyPatternsAndFoldGreedily(modelOp, std::move(patterns)); } -namespace mlir::bmodelica -{ - std::unique_ptr createSCCSolvingBySubstitutionPass() - { - return std::make_unique(); - } +namespace mlir::bmodelica { +std::unique_ptr createSCCSolvingBySubstitutionPass() { + return std::make_unique(); +} - std::unique_ptr createSCCSolvingBySubstitutionPass( - const SCCSolvingBySubstitutionPassOptions& options) - { - return std::make_unique(options); - } +std::unique_ptr createSCCSolvingBySubstitutionPass( + const SCCSolvingBySubstitutionPassOptions &options) { + return std::make_unique(options); } +} // namespace mlir::bmodelica From 09ad3c885ba2a62a0f9b01afbc5aff50e8bd38a9 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Mon, 7 Oct 2024 12:19:47 +0200 Subject: [PATCH 4/4] Fix raw variables being ignored during heap to stack promotion --- .../public/marco/Frontend/FrontendActions.h | 1 + lib/Frontend/FrontendActions.cpp | 29 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/include/public/marco/Frontend/FrontendActions.h b/include/public/marco/Frontend/FrontendActions.h index 7cc21d219..a21fef480 100644 --- a/include/public/marco/Frontend/FrontendActions.h +++ b/include/public/marco/Frontend/FrontendActions.h @@ -178,6 +178,7 @@ class CodeGenAction : public ASTAction { std::unique_ptr createMLIROneShotBufferizePass(); void buildMLIRBufferDeallocationPipeline(mlir::OpPassManager &pm); std::unique_ptr createMLIRLoopTilingPass(); + std::unique_ptr createMLIRPromoteBuffersToStackPass(); /// } /// @name LLVM-IR diff --git a/lib/Frontend/FrontendActions.cpp b/lib/Frontend/FrontendActions.cpp index 48b353471..c9ebb95c3 100644 --- a/lib/Frontend/FrontendActions.cpp +++ b/lib/Frontend/FrontendActions.cpp @@ -961,8 +961,7 @@ void CodeGenAction::buildMLIRLoweringPipeline(mlir::PassManager &pm) { } if (ci.getCodeGenOptions().heapToStackPromotion) { - pm.addNestedPass( - mlir::bufferization::createPromoteBuffersToStackPass()); + pm.addNestedPass(createMLIRPromoteBuffersToStackPass()); } // Buffer deallocations placements must be performed after loop @@ -1112,6 +1111,32 @@ std::unique_ptr CodeGenAction::createMLIRLoopTilingPass() { return mlir::affine::createLoopTilingPass(); } +std::unique_ptr +CodeGenAction::createMLIRPromoteBuffersToStackPass() { + // TODO: control with CLI + unsigned int maxAllocSizeInBytes = 1024; + + auto isSmallAllocFn = [=](mlir::Value alloc) -> bool { + auto type = mlir::dyn_cast(alloc.getType()); + + if (!type || + !alloc.getDefiningOp()) { + return false; + } + + if (!type.hasStaticShape()) { + return false; + } + + unsigned int bitwidth = mlir::DataLayout::closest(alloc.getDefiningOp()) + .getTypeSizeInBits(type.getElementType()); + + return type.getNumElements() * bitwidth <= maxAllocSizeInBytes * 8; + }; + + return mlir::bufferization::createPromoteBuffersToStackPass(isSmallAllocFn); +} + void CodeGenAction::registerMLIRToLLVMIRTranslations() { mlir::registerBuiltinDialectTranslation(getMLIRContext()); mlir::registerLLVMDialectTranslation(getMLIRContext());