Skip to content

Commit

Permalink
Refactor IDA Jacobian one sweep optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
mscuttari committed Oct 7, 2024
1 parent c5dd6d6 commit e8f507a
Showing 1 changed file with 31 additions and 51 deletions.
82 changes: 31 additions & 51 deletions lib/Dialect/BaseModelica/Transforms/IDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,38 +1602,40 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
}
};

// Determine the positions of the seeds within the seeds list.
std::optional<size_t> oneSeedPosition = std::nullopt;
std::optional<size_t> derSeedPosition = std::nullopt;

if (independentVariablesPos.contains(independentVariable)) {
oneSeedPosition = independentVariablesPos.lookup(independentVariable);
}

if (auto derivative = getDerivative(
mlir::SymbolRefAttr::get(independentVariable.getSymNameAttr()))) {
auto derVariableOp =
symbolTableCollection->lookupSymbolIn<VariableOp>(modelOp, *derivative);

if (independentVariablesPos.contains(derVariableOp)) {
derSeedPosition = independentVariablesPos.lookup(derVariableOp);
}
}

if (jacobianOneSweep) {
// Perform just one call to the template function.
std::optional<size_t> oneSeedPosition = std::nullopt;

if (independentVariablesPos.contains(independentVariable)) {
if (oneSeedPosition) {
// Set the seed of the variable to one.
oneSeedPosition = independentVariablesPos.lookup(independentVariable);

setGlobalADSeed(builder, loc, varSeeds[*oneSeedPosition],
jacobianFunction.getVariableIndices(), one);
}

// Set the seed of the derivative to alpha.
std::optional<size_t> alphaSeedPosition = std::nullopt;

if (auto derivative = getDerivative(
mlir::SymbolRefAttr::get(independentVariable.getSymNameAttr()))) {
auto derVariableOp = symbolTableCollection->lookupSymbolIn<VariableOp>(
modelOp, *derivative);

if (independentVariablesPos.contains(derVariableOp)) {
alphaSeedPosition = independentVariablesPos.lookup(derVariableOp);
}
}

if (alphaSeedPosition) {
if (derSeedPosition) {
// Set the seed of the derivative to alpha.
mlir::Value alpha = jacobianFunction.getAlpha();

alpha = builder.create<CastOp>(
alpha.getLoc(), RealType::get(builder.getContext()), alpha);

setGlobalADSeed(builder, loc, varSeeds[*alphaSeedPosition],
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), alpha);
}

Expand All @@ -1654,8 +1656,8 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
jacobianFunction.getVariableIndices(), zero);
}

if (alphaSeedPosition) {
setGlobalADSeed(builder, loc, varSeeds[*alphaSeedPosition],
if (derSeedPosition) {
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), zero);
}

Expand All @@ -1664,20 +1666,13 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
builder.create<mlir::ida::ReturnOp>(loc, result);
} else {
llvm::SmallVector<mlir::Value> args;
std::optional<size_t> oneSeedPosition = std::nullopt;

// Perform the first call to the template function.
if (independentVariablesPos.contains(independentVariable)) {
// Set the seed of the variable to one.
oneSeedPosition = independentVariablesPos.lookup(independentVariable);
}

if (oneSeedPosition) {
setGlobalADSeed(builder, loc, varSeeds[*oneSeedPosition],
jacobianFunction.getVariableIndices(), one);
}

// Call the template function.
args.clear();
collectArgsFn(args);

Expand All @@ -1694,23 +1689,10 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
jacobianFunction.getVariableIndices(), zero);
}

if (auto derivative = getDerivative(
mlir::SymbolRefAttr::get(independentVariable.getSymNameAttr()))) {
auto globalDerivativeOp =
symbolTableCollection->lookupSymbolIn<VariableOp>(modelOp,
*derivative);

std::optional<size_t> derSeedPosition = std::nullopt;

if (independentVariablesPos.contains(globalDerivativeOp)) {
derSeedPosition = independentVariablesPos.lookup(globalDerivativeOp);
}

if (derSeedPosition) {
// Set the seed of the derivative to one.
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), one);
}
if (derSeedPosition) {
// Set the seed of the derivative to one.
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), one);

// Call the template function.
args.clear();
Expand All @@ -1722,11 +1704,9 @@ mlir::LogicalResult IDAInstance::createJacobianFunction(
partialDerTemplateName),
RealType::get(builder.getContext()), args);

if (derSeedPosition) {
// Reset the seed of the variable.
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), zero);
}
// Reset the seed of the variable.
setGlobalADSeed(builder, loc, varSeeds[*derSeedPosition],
jacobianFunction.getVariableIndices(), zero);

mlir::Value secondResult = secondTemplateCall.getResult(0);

Expand Down

0 comments on commit e8f507a

Please sign in to comment.