
Adds JAX IDAKLU solver integration #481

Open · wants to merge 22 commits into develop
Conversation

@BradyPlanden (Member) commented Sep 2, 2024

Description

This PR adds the jaxified IDAKLU solver to enable autodiff for the cost and likelihood classes. At the moment the IDAKLU solver is limited to first-order sensitivity information, so the autodiff cost/likelihood classes are likewise limited to gradient information.

As an example of how to use the jaxified IDAKLU solver, an experimental subdirectory is added with the JaxSumSquaredError and JaxLogNormalLikelihood classes. These classes only require the evaluate method to be defined, with JAX's value_and_grad capturing the gradient information. Currently, this solver matches the CasADi "fast with events" solver in most cases, with greatly improved performance when computing sensitivities. This performance is expected to improve further with the next PyBaMM release.
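For illustration only, here is a minimal, self-contained sketch of the pattern these classes follow (the function and variable names below are assumptions made for the example, not the PR's actual implementation):

import jax
import jax.numpy as jnp

# Hedged sketch: a sum-of-squared-errors cost whose gradient comes from
# jax.value_and_grad. `simulate` stands in for the jaxified IDAKLU forward
# solve and is a hypothetical placeholder here.
def sum_squared_error(params, simulate, data):
    prediction = simulate(params)  # jax-traceable forward model
    return jnp.sum((prediction - data) ** 2)

# value_and_grad returns both the cost value and its gradient w.r.t. params
cost_and_grad = jax.value_and_grad(sum_squared_error, argnums=0)

# Toy usage with a linear stand-in model
t = jnp.linspace(0.0, 1.0, 10)
data = 2.0 * t + 0.5
value, grad = cost_and_grad(jnp.array([1.9, 0.4]), lambda p: p[0] * t + p[1], data)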

This also opens up future functionality for gradient-based optimisers in design optimisation of non-geometric parameters, as autodiff can provide gradients for any constructed cost/likelihood/design function.

To Do

  • Add Tests

Issue reference

Fixes # (issue-number)

Review

Before you mark your PR as ready for review, please ensure that you've considered the following:

  • Updated the CHANGELOG.md in reverse chronological order (newest at the top) with a concise description of the changes, including the PR number.
  • Noted any breaking changes, including details on how it might impact existing functionality.

Type of change

  • New Feature: A non-breaking change that adds new functionality.
  • Optimization: A code change that improves performance.
  • Examples: A change to existing or additional examples.
  • Bug Fix: A non-breaking change that addresses an issue.
  • Documentation: Updates to documentation or new documentation for new features.
  • Refactoring: Non-functional changes that improve the codebase.
  • Style: Non-functional changes related to code style (formatting, naming, etc).
  • Testing: Additional tests to improve coverage or confirm functionality.
  • Other: (Insert description of change)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All unit tests pass: $ nox -s tests
  • The documentation builds: $ nox -s doctest

You can run integration tests, unit tests, and doctests together using $ nox -s quick.

Further checks:

  • Code is well-commented, especially in complex or unclear areas.
  • Added tests that prove my fix is effective or that my feature works.
  • Checked that coverage remains the same or improves, and added tests if necessary to maintain or increase coverage.

Thank you for contributing to our project! Your efforts help us to deliver great software.


codecov bot commented Sep 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.11%. Comparing base (c8b00e6) to head (44e3095).
Report is 8 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop     #481      +/-   ##
===========================================
+ Coverage    99.08%   99.11%   +0.02%     
===========================================
  Files           52       56       +4     
  Lines         3605     3838     +233     
===========================================
+ Hits          3572     3804     +232     
- Misses          33       34       +1     


@BradyPlanden (Member, Author)

Here's the benchmark script for the solvers in this PR; a version of it is also included in the PR.

import time
import numpy as np
import pybamm

import pybop

n = 50  # Number of solves
solvers = [
    pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6),
    pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6),
]

# Parameter set and model definition
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=solvers[0])

# Fitting parameters
parameters = pybop.Parameters(
    pybop.Parameter(
        "Negative electrode active material volume fraction", initial_value=0.55
    ),
    pybop.Parameter(
        "Positive electrode active material volume fraction", initial_value=0.55
    ),
)

# Define test protocol and generate data
t_eval = np.linspace(0, 100, 1000)
values = model.predict(
    initial_state={"Initial open-circuit voltage [V]": 4.2}, t_eval=t_eval
)

# Form dataset
dataset = pybop.Dataset(
    {
        "Time [s]": values["Time [s]"].data,
        "Current function [A]": values["Current [A]"].data,
        "Voltage [V]": values["Voltage [V]"].data,
    }
)


# Create inputs function for benchmarking
def inputs():
    return {
        "Negative electrode active material volume fraction": 0.55
        + np.random.normal(0, 0.01),
        "Positive electrode active material volume fraction": 0.55
        + np.random.normal(0, 0.01),
    }


# Iterate over the solvers and print benchmarks
for solver in solvers:
    # Setup Fitting Problem
    model.solver = solver
    problem = pybop.FittingProblem(model, parameters, dataset)
    cost = pybop.SumSquaredError(problem)

    start_time = time.time()
    for _i in range(n):
        out = problem.model.simulate(inputs=inputs(), t_eval=t_eval)
    print(f"({solver.name}) Time model.simulate: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.model.simulateS1(inputs=inputs(), t_eval=t_eval)
    print(f"({solver.name}) Time model.SimulateS1: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.evaluate(inputs=inputs())
    print(f"({solver.name}) Time problem.evaluate: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = problem.evaluateS1(inputs=inputs())
    print(f"({solver.name}) Time Problem.EvaluateS1: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = cost(inputs(), calculate_grad=False)
    print(f"({solver.name}) Time PyBOP Cost w/o grad: {time.time() - start_time:.4f}")

    start_time = time.time()
    for _i in range(n):
        out = cost(inputs(), calculate_grad=True)
    print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}")

# Recreate for Jax IDAKLU solver
ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver, jax=True)
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.JaxSumSquaredError(problem)

start_time = time.time()
for _i in range(n):
    out = cost(inputs(), calculate_grad=False)
print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
    out = cost(inputs(), calculate_grad=True)
print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}")

which produces the following on my M3 Pro MacBook:

(CasADi solver with 'fast with events' mode) Time model.simulate: 3.2579
(CasADi solver with 'fast with events' mode) Time model.SimulateS1: 13.5679
(CasADi solver with 'fast with events' mode) Time problem.evaluate: 6.8836
(CasADi solver with 'fast with events' mode) Time Problem.EvaluateS1: 152.7627
(CasADi solver with 'fast with events' mode) Time PyBOP Cost w/o grad: 7.1857
(CasADi solver with 'fast with events' mode) Time PyBOP Cost w/grad: 155.6699
(IDA KLU solver) Time model.simulate: 6.5524
(IDA KLU solver) Time model.SimulateS1: 17.9455
(IDA KLU solver) Time problem.evaluate: 6.6003
(IDA KLU solver) Time Problem.EvaluateS1: 18.0940
(IDA KLU solver) Time PyBOP Cost w/o grad: 6.5335
(IDA KLU solver) Time PyBOP Cost w/grad: 18.1650
Time Jax SumSquaredError w/o grad: 6.9650
Time Jax SumSquaredError w/ grad: 19.5255

@BradyPlanden BradyPlanden marked this pull request as ready for review September 25, 2024 13:23
@martinjrobins (Contributor)

Nice speedups :) I'm a bit concerned that there are now 3 possible ways to calculate the cost functions: 1) the original cost functions in Python, 2) my proposal to use PyBaMM variables #513, and now 3) JAX. I know from maintaining 4 different solvers in PyBaMM that this is no fun, and at the moment we're trying to reduce that down to just the IDAKLU solver, so I think we might need to choose one main mechanism for calculating cost functions in PyBOP rather than implementing and then maintaining all three of them.

@BradyPlanden (Member, Author) commented Sep 25, 2024

Yes, this is something I've been thinking about, and it's one of the reasons this PR adds an /experimental subdirectory to the repository. The idea is that we can provide methods that don't offer the same level of stability as the main classes but bring improvements in other ways. The Jax-based implementations in this PR don't offer the same robustness, but they are easier to support, and creating new gradient-based cost instances requires less overhead. Similarly, they open up the possibility of JITing the entire inference process.
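As a rough illustration of that last point (a hedged sketch, not code from this PR), once the forward solve is expressed as a pure, jax-traceable function, the whole cost evaluation and its gradient could be compiled with jax.jit:

import jax
import jax.numpy as jnp

# Sketch only: a stand-in cost where the "solve" is a pure jax function, so
# the evaluation and its gradient can each be compiled once and reused.
@jax.jit
def cost(params, data):
    prediction = params[0] * jnp.arange(10.0) + params[1]  # placeholder for the solve
    return jnp.sum((prediction - data) ** 2)

grad_fn = jax.jit(jax.grad(cost))
data = 2.0 * jnp.arange(10.0)
print(cost(jnp.array([1.0, 0.0]), data), grad_fn(jnp.array([1.0, 0.0]), data))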

Going forward, these would remain in /experimental until (and if) they are improved enough to justify better integration with PyBOP's canonical cost classes. After taking a quick look at your toy example in #513, I think the cost functions within this PR could easily be refactored to integrate fully with #513.

@martinjrobins (Contributor)

OK, fair enough, I'm happy with putting the new cost function methods into experimental for now. But I would suggest we try to stabilise on a single method as soon as we can, so there is less to maintain going forward.

@martinjrobins (Contributor) left a review comment

Thanks @BradyPlanden. In general I worry that the cost functions are mutating a reference to a model, but they don't know what else the user is doing with that model. The user might want to compare the result of a jax cost function with a standard cost function using the same model. It would simplify the code if models were copied by problems or cost functions; then they could mutate them as much as they wanted, and you wouldn't have to pass that jax arg to models, because the cost function can just set it. What do you think?
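One way to picture that suggestion (a rough sketch with hypothetical names, not PyBOP's actual API):

import copy

# Sketch of the suggestion: the problem keeps its own copy of the model, so
# mutations (such as switching to the jax-compatible solver output) never
# leak back to the model object the user still holds. Class and attribute
# names here are placeholders.
class FittingProblemSketch:
    def __init__(self, model, parameters, dataset, use_jax=False):
        self._model = copy.deepcopy(model)  # private copy, safe to mutate
        if use_jax:
            self._model.jax = True          # hypothetical flag set internally
        self.parameters = parameters
        self.dataset = dataset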

Review threads on pybop/models/base_model.py and pybop/problems/base_problem.py (resolved).
@BradyPlanden BradyPlanden added the ask This PR needs a review for merging label Oct 11, 2024