Adds JAX IDAKLU solver integration #481
base: develop
…benchmarks. Adds JaxSumSquaredError and JaxLogNormalLikelihood.
# Conflicts: # pyproject.toml
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           develop     #481     +/-   ##
===========================================
+ Coverage    99.08%   99.11%   +0.02%
===========================================
  Files           52       56       +4
  Lines         3605     3838     +233
===========================================
+ Hits          3572     3804     +232
- Misses          33       34       +1
```

☔ View full report in Codecov by Sentry.
Here's the benchmark script for the solvers in this PR. A version of it is also in the PR.

```python
import time

import numpy as np
import pybamm
import pybop

n = 50  # Number of solves
solvers = [
    pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6),
    pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6),
]

# Parameter set and model definition
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solvers[0])

# Fitting parameters
parameters = pybop.Parameters(
    pybop.Parameter(
        "Negative electrode active material volume fraction", initial_value=0.55
    ),
    pybop.Parameter(
        "Positive electrode active material volume fraction", initial_value=0.55
    ),
)

# Define test protocol and generate data
t_eval = np.linspace(0, 100, 1000)
values = model.predict(
    initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)

# Form dataset
dataset = pybop.Dataset(
    {
        "Time [s]": values["Time [s]"].data,
        "Current function [A]": values["Current [A]"].data,
        "Voltage [V]": values["Voltage [V]"].data,
    }
)


# Create inputs function for benchmarking
def inputs():
    return {
        "Negative electrode active material volume fraction": 0.55
        + np.random.normal(0, 0.01),
        "Positive electrode active material volume fraction": 0.55
        + np.random.normal(0, 0.01),
    }


# Iterate over the solvers and print benchmarks
for solver in solvers:
    # Setup Fitting Problem
    model.solver = solver
    problem = pybop.FittingProblem(model, parameters, dataset)
    cost = pybop.SumSquaredError(problem)

    start_time = time.time()
    for _i in range(n):
        out = problem.model.simulate(inputs=inputs(), t_eval=t_eval)
    print(f"({solver.name}) Time model.simulate: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.model.simulateS1(inputs=inputs(), t_eval=t_eval)
    print(f"({solver.name}) Time model.SimulateS1: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.evaluate(inputs=inputs())
    print(f"({solver.name}) Time problem.evaluate: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.evaluateS1(inputs=inputs())
    print(f"({solver.name}) Time Problem.EvaluateS1: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = cost(inputs(), calculate_grad=False)
    print(f"({solver.name}) Time PyBOP Cost w/o grad: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = cost(inputs(), calculate_grad=True)
    print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}")

# Recreate for Jax IDAKLU solver
ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver, jax=True)
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.JaxSumSquaredError(problem)

start_time = time.time()
for _i in range(n):
    out = cost(inputs(), calculate_grad=False)
print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
    out = cost(inputs(), calculate_grad=True)
print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}")
```

which produces the following on my M3 Pro MacBook:
# Conflicts: # pyproject.toml
Nice speedups :) I'm a bit concerned that there are now three possible ways to calculate the cost functions: 1) the original cost functions in Python, 2) my proposal to use PyBaMM variables (#513), and now 3) JAX. I know from maintaining four different solvers in PyBaMM that this is no fun, and at the moment we're trying to reduce that down to just the IDAKLU solver. So I think we might need to choose the main mechanism for calculating cost functions in PyBOP rather than implementing and then maintaining all three of them.
Yes, this is something I've been thinking about, which is one of the reasons this adds an `experimental` subdirectory. Going forward, these would remain in …
OK, fair enough, happy with putting the new cost function methods into experimental for now. But I would suggest that we try and stabilise on a single method as soon as we can so there is less to maintain going forward.
Thanks @BradyPlanden. In general I worry that the cost functions are mutating a reference to a model, but they don't know what else the user is doing with that model. The user might want to compare the result of a JAX cost function with a standard cost function using the same model. It would simplify the code if models were copied by problems or cost functions; then they could mutate them as much as they wanted. And you wouldn't have to pass that `jax` arg to models, because the cost function can just set it. What do you think?
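To illustrate the copy-on-construction idea in the comment above, here is a minimal sketch. `ToyModel` and `ToyJaxCost` are hypothetical stand-ins invented for this example, not PyBOP classes, and the `jax` attribute is assumed only to mirror the `jax=True` argument discussed in this PR.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a model object (not the PyBOP model class)."""

    def __init__(self):
        self.jax = False  # assumed flag, mirroring the `jax` argument in this PR


class ToyJaxCost:
    """Hypothetical cost function that owns a private copy of the model."""

    def __init__(self, model):
        # Deep-copy the model so that later mutation never touches the
        # object the caller passed in.
        self.model = copy.deepcopy(model)
        # The cost sets the flag itself; the caller does not need to pass it.
        self.model.jax = True


user_model = ToyModel()
cost = ToyJaxCost(user_model)
print(user_model.jax, cost.model.jax)  # False True
```

With this pattern the same `user_model` could be handed to a standard cost function afterwards without seeing the change, at the price of one copy per cost or problem.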
# Conflicts: # CHANGELOG.md # examples/notebooks/1-single-pulse-circuit-model.ipynb
# Conflicts: # examples/notebooks/single_pulse_circuit_model.ipynb
Description
This PR adds the jaxified IDAKLU solver to enable autodiff for the cost and likelihood classes. At the moment the IDAKLU solver is limited to first-order sensitivity information, and as such we are limited to gradient information from the autodiff cost/likelihood classes.

As an example of how to use the jaxified IDAKLU, an `experimental` subdirectory is added with the `JaxSumSquaredError` and `JaxLogNormalLikelihood` classes. These classes only require the `evaluate` method to be defined, with JAX's `value_and_grad` method used to capture the gradient information. Currently, this solver matches the `casadi fast with events` solver in most cases, with greatly improved performance in computing sensitivities. This performance is expected to improve even more with the next PyBaMM release.

This also opens up future functionality for gradient-based optimisers in design optimisation of non-geometric parameters, as autodiff can provide gradients for any constructed cost/likelihood/design function.
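As a rough illustration of the `value_and_grad` pattern described above, the sketch below defines only an evaluation function and lets JAX supply the gradient. It does not use the IDAKLU solver: `toy_model` is a hypothetical closed-form stand-in for a solver-backed simulation, and `sum_squared_error` is an assumed free function rather than the `JaxSumSquaredError` class from this PR.

```python
import jax
import jax.numpy as jnp


def toy_model(params, t):
    # Hypothetical stand-in for a simulation: exponential decay with
    # params = [amplitude, rate].
    return params[0] * jnp.exp(-params[1] * t)


def sum_squared_error(params, t, data):
    # Only the cost evaluation is written by hand.
    residual = toy_model(params, t) - data
    return jnp.sum(residual**2)


t = jnp.linspace(0.0, 1.0, 50)
true_params = jnp.array([4.2, 0.5])
data = toy_model(true_params, t)  # synthetic, noise-free data

# value_and_grad returns a function that yields both the cost and its
# gradient with respect to the first argument (params).
value_and_grad = jax.value_and_grad(sum_squared_error)
cost, grad = value_and_grad(jnp.array([4.0, 0.6]), t, data)
print(cost, grad)
```

The same wrapper would apply to any differentiable cost, likelihood, or design function built from JAX operations, which is the flexibility the description refers to.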
To Do
Issue reference
Fixes # (issue-number)
Review
Before you mark your PR as ready for review, please ensure that you've considered the following:
Type of change
Key checklist:
- `$ pre-commit run` (or `$ nox -s pre-commit`) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
- `$ nox -s tests`
- `$ nox -s doctest`

You can run integration tests, unit tests, and doctests together at once, using `$ nox -s quick`.

Further checks:
Thank you for contributing to our project! Your efforts help us to deliver great software.