Skip to content

Commit

Permalink
Ignore read-only variables when applying using the IDA solver
Browse files Browse the repository at this point in the history
  • Loading branch information
mscuttari committed Oct 7, 2024
1 parent 5ed1b6a commit 63871fc
Showing 1 changed file with 80 additions and 34 deletions.
114 changes: 80 additions & 34 deletions lib/Dialect/BaseModelica/Transforms/IDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,11 @@ mlir::LogicalResult IDAInstance::getIndependentVariablesForAD(
symbolTableCollection->lookupSymbolIn<VariableOp>(
modelOp, access.getVariable());

if (variableOp.isReadOnly()) {
// Treat read-only variables as if they were just numbers.
continue;
}

result.insert(variableOp);

if (auto derivative = getDerivative(access.getVariable())) {
Expand Down Expand Up @@ -1445,8 +1450,11 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(
// Start the body of the function.
rewriter.setInsertionPointToStart(functionOp.getBody());

// Keep track of all the variables that have been declared for creating the body of the function to be derived.
llvm::DenseSet<llvm::StringRef> allLocalVariables;

// Replicate the original independent variables inside the function.
llvm::StringMap<VariableOp> localVariableOps;
llvm::StringMap<VariableOp> mappedVariableOps;
size_t independentVariableIndex = 0;

for (VariableOp variableOp : variableOps) {
Expand All @@ -1460,7 +1468,8 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(
auto clonedVariableOp = rewriter.create<VariableOp>(
variableOp.getLoc(), variableOp.getSymName(), variableType);

localVariableOps[variableOp.getSymName()] = clonedVariableOp;
allLocalVariables.insert(clonedVariableOp.getSymName());
mappedVariableOps[variableOp.getSymName()] = clonedVariableOp;
independentVariablesPos[variableOp] = independentVariableIndex++;
}

Expand All @@ -1472,7 +1481,7 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(
continue;
}

auto mappedVariable = localVariableOps[variableOp.getSymName()];
auto mappedVariable = mappedVariableOps[variableOp.getSymName()];

llvm::dbgs() << variableOp.getSymName() << " -> "
<< mappedVariable.getSymName() << "\n";
Expand All @@ -1495,6 +1504,7 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(
auto variableOp = rewriter.create<VariableOp>(
loc, variableName, variableType);

allLocalVariables.insert(variableOp.getSymName());
inductionVariablesOps.push_back(variableOp);
}

Expand All @@ -1515,6 +1525,8 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(
VariabilityProperty::none,
IOProperty::output));

allLocalVariables.insert(outputVariableOp.getSymName());

// Create the body of the function.
auto algorithmOp = rewriter.create<AlgorithmOp>(loc);

Expand Down Expand Up @@ -1577,7 +1589,7 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(
}

if (auto globalGetOp = mlir::dyn_cast<GlobalVariableGetOp>(op)) {
VariableOp variableOp = localVariableOps[globalGetOp.getVariable()];
VariableOp variableOp = mappedVariableOps[globalGetOp.getVariable()];

auto getOp = rewriter.create<VariableGetOp>(
globalGetOp.getLoc(), variableOp);
Expand All @@ -1598,6 +1610,19 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(

rewriter.create<VariableSetOp>(loc, outputVariableOp, result);

// Use the qualified accesses for the non-independent variables.
llvm::SmallVector<VariableGetOp> getOpsToQualify;

functionOp.walk([&](VariableGetOp getOp) {
if (!allLocalVariables.contains(getOp.getVariable())) {
getOpsToQualify.push_back(getOp);
}
});

if (mlir::failed(replaceVariableGetOps(rewriter, modelOp, getOpsToQualify))) {
return nullptr;
}

// Create the derivative template function.
LLVM_DEBUG({
llvm::dbgs() << "Function being derived:\n" << functionOp << "\n";
Expand All @@ -1614,19 +1639,19 @@ FunctionOp IDAInstance::createPartialDerTemplateFromEquation(

rewriter.eraseOp(functionOp);

// Replace the local variables with the global ones.
// Replace the mapped variables with qualified accesses.
llvm::DenseSet<VariableOp> variablesToBeReplaced;

for (VariableOp variableOp : derTemplate->getVariables()) {
if (localVariableOps.count(variableOp.getSymName()) != 0) {
if (mappedVariableOps.contains(variableOp.getSymName())) {
variablesToBeReplaced.insert(variableOp);
}
}

llvm::SmallVector<VariableGetOp> variableGetOps;

derTemplate->walk([&](VariableGetOp getOp) {
if (localVariableOps.count(getOp.getVariable()) != 0) {
if (mappedVariableOps.contains(getOp.getVariable())) {
variableGetOps.push_back(getOp);
}
});
Expand Down Expand Up @@ -1747,14 +1772,16 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(

if (jacobianOneSweep) {
// Perform just one call to the template function.
assert(independentVariablesPos.count(independentVariable) != 0);
std::optional<size_t> oneSeedPosition = std::nullopt;

// Set the seed of the variable to one.
size_t oneSeedPosition =
independentVariablesPos.lookup(independentVariable);
if (independentVariablesPos.contains(independentVariable)) {
// Set the seed of the variable to one.
oneSeedPosition =
independentVariablesPos.lookup(independentVariable);

setGlobalADSeed(builder, loc, varSeeds[oneSeedPosition],
jacobianFunction.getVariableIndices(), one);
setGlobalADSeed(builder, loc, varSeeds[*oneSeedPosition],
jacobianFunction.getVariableIndices(), one);
}

// Set the seed of the derivative to alpha.
std::optional<size_t> alphaSeedPosition = std::nullopt;
Expand All @@ -1764,8 +1791,9 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
auto derVariableOp =
symbolTableCollection->lookupSymbolIn<VariableOp>(modelOp, *derivative);

assert(independentVariablesPos.count(derVariableOp) != 0);
alphaSeedPosition = independentVariablesPos.lookup(derVariableOp);
if (independentVariablesPos.contains(derVariableOp)) {
alphaSeedPosition = independentVariablesPos.lookup(derVariableOp);
}
}

if (alphaSeedPosition) {
Expand All @@ -1791,8 +1819,10 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
mlir::Value result = templateCall.getResult(0);

// Reset the seeds.
setGlobalADSeed(builder, loc, varSeeds[oneSeedPosition],
jacobianFunction.getVariableIndices(), zero);
if (oneSeedPosition) {
setGlobalADSeed(builder, loc, varSeeds[*oneSeedPosition],
jacobianFunction.getVariableIndices(), zero);
}

if (alphaSeedPosition) {
setGlobalADSeed(builder, loc, varSeeds[*alphaSeedPosition],
Expand All @@ -1804,16 +1834,19 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
builder.create<mlir::ida::ReturnOp>(loc, result);
} else {
llvm::SmallVector<mlir::Value> args;
std::optional<size_t> oneSeedPosition = std::nullopt;

// Perform the first call to the template function.
assert(independentVariablesPos.count(independentVariable) != 0);

// Set the seed of the variable to one.
size_t oneSeedPosition =
independentVariablesPos.lookup(independentVariable);
if (independentVariablesPos.contains(independentVariable)) {
// Set the seed of the variable to one.
oneSeedPosition =
independentVariablesPos.lookup(independentVariable);
}

setGlobalADSeed(builder, loc, varSeeds[oneSeedPosition],
jacobianFunction.getVariableIndices(), one);
if (oneSeedPosition) {
setGlobalADSeed(builder, loc, varSeeds[*oneSeedPosition],
jacobianFunction.getVariableIndices(), one);
}

// Call the template function.
args.clear();
Expand All @@ -1827,23 +1860,28 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(

mlir::Value result = firstTemplateCall.getResult(0);

// Reset the seed of the variable.
setGlobalADSeed(builder, loc, varSeeds[oneSeedPosition],
jacobianFunction.getVariableIndices(), zero);
if (oneSeedPosition) {
// Reset the seed of the variable.
setGlobalADSeed(builder, loc, varSeeds[*oneSeedPosition],
jacobianFunction.getVariableIndices(), zero);
}

if (auto derivative = getDerivative(mlir::SymbolRefAttr::get(
independentVariable.getSymNameAttr()))) {
auto globalDerivativeOp =
symbolTableCollection->lookupSymbolIn<VariableOp>(modelOp, *derivative);

assert(independentVariablesPos.count(globalDerivativeOp) != 0);
std::optional<size_t> derSeedPosition = std::nullopt;

size_t derSeedPosition =
independentVariablesPos.lookup(globalDerivativeOp);
if (independentVariablesPos.contains(globalDerivativeOp)) {
derSeedPosition = independentVariablesPos.lookup(globalDerivativeOp);
}

// Set the seed of the derivative to one.
setGlobalADSeed(builder, loc, varSeeds[derSeedPosition],
jacobianFunction.getVariableIndices(), one);
if (derSeedPosition) {
// Set the seed of the derivative to one.
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), one);
}

// Call the template function.
args.clear();
Expand All @@ -1855,6 +1893,12 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
RealType::get(builder.getContext()),
args);

if (derSeedPosition) {
// Reset the seed of the variable.
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), zero);
}

mlir::Value secondResult = secondTemplateCall.getResult(0);

mlir::Value secondResultTimesAlpha = builder.create<MulOp>(
Expand Down Expand Up @@ -2267,7 +2311,7 @@ mlir::LogicalResult IDAPass::solveMainModel(
} else {
LLVM_DEBUG(llvm::dbgs() << "Reduced system feature disabled\n");

// Add all the variables to IDA.
// Add all the non-read-only variables to IDA.
for (VariableOp variable : variables) {
auto variableName =
mlir::SymbolRefAttr::get(variable.getSymNameAttr());
Expand Down Expand Up @@ -2394,6 +2438,8 @@ mlir::LogicalResult IDAPass::addMainModelEquation(
symbolTableCollection.lookupSymbolIn<VariableOp>(
modelOp, writtenVariable);

assert(!writtenVariableOp.isReadOnly());

if (derivativesMap.getDerivedVariable(writtenVariable)) {
LLVM_DEBUG(llvm::dbgs() << "Add derivative variable: "
<< writtenVariableOp.getSymName() << "\n");
Expand Down

0 comments on commit 63871fc

Please sign in to comment.