Merge pull request #249 from SciML/docsoptv4
Update docs: remove extra returns from loss and extra args from callback
ChrisRackauckas authored Oct 29, 2024
2 parents 0be2ba4 + 109fdac commit 210ee41
Showing 9 changed files with 31 additions and 36 deletions.
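The common thread in these diffs is the Optimization.jl v3 → v4 calling convention: a loss function now returns only the scalar loss, and a callback receives an optimization state (whose `u` field holds the current parameters) plus the current loss value. A minimal sketch of the new convention (not part of the diff itself; the Rosenbrock objective and Adam settings are stand-ins):

```julia
using Optimization, OptimizationOptimisers
using ForwardDiff  # backend package needed for AutoForwardDiff()

# Stand-in objective; under v4 the loss returns only the scalar value.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# v4 callback signature: (state, loss) instead of (p, loss, extra_returns...).
# state.u holds the current parameters; returning true halts the solve.
callback = function (state, l)
    println("loss = $l at u = $(state.u)")
    return false
end

sol = solve(optprob, OptimizationOptimisers.Adam(0.05); callback, maxiters = 100)
```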
14 changes: 7 additions & 7 deletions docs/Project.toml
@@ -69,7 +69,7 @@ IncompleteLU = "0.2"
Integrals = "4"
LineSearches = "7"
LinearSolve = "2"
Lux = "0.5"
Lux = "1"
LuxCUDA = "0.3"
MCMCChains = "6"
Measurements = "2"
@@ -78,12 +78,12 @@ ModelingToolkit = "9.9"
MultiDocumenter = "0.7"
NeuralPDE = "5.15"
NonlinearSolve = "3"
Optimization = "3"
OptimizationMOI = "0.4"
OptimizationNLopt = "0.2"
OptimizationOptimJL = "0.2, 0.3"
OptimizationOptimisers = "0.2"
OptimizationPolyalgorithms = "0.2"
Optimization = "4"
OptimizationMOI = "0.5"
OptimizationNLopt = "0.3"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OptimizationPolyalgorithms = "0.3"
OrdinaryDiffEq = "6"
Plots = "1"
SciMLExpectations = "2"
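For anyone mirroring this bump in a local environment, the same compat bounds can be raised programmatically; a sketch using the stock `Pkg.compat` API (this is not something the commit itself does):

```julia
using Pkg

# Raise the compat entries changed above, then re-resolve the environment.
Pkg.compat("Lux", "1")
Pkg.compat("Optimization", "4")
Pkg.compat("OptimizationMOI", "0.5")
Pkg.compat("OptimizationNLopt", "0.3")
Pkg.compat("OptimizationOptimJL", "0.4")
Pkg.compat("OptimizationOptimisers", "0.3")
Pkg.compat("OptimizationPolyalgorithms", "0.3")
Pkg.update()
```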
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -25,7 +25,8 @@ makedocs(sitename = "Overview of Julia's SciML",
"https://epubs.siam.org/doi/10.1137/0903023",
"https://bkamins.github.io/julialang/2020/12/24/minilanguage.html",
"https://arxiv.org/abs/2109.06786",
"https://arxiv.org/abs/2001.04385"],
"https://arxiv.org/abs/2001.04385",
"https://code.visualstudio.com/"],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/stable/",
mathengine = mathengine),
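The vector being extended here is evidently the `linkcheck_ignore` list passed to `makedocs`; that is an inference from the surrounding URLs, which all look like linkcheck exclusions. The shape of the call is roughly:

```julia
using Documenter

# Sketch of the makedocs call being edited; only the relevant keywords shown.
makedocs(sitename = "Overview of Julia's SciML",
    linkcheck_ignore = [
        "https://arxiv.org/abs/2001.04385",
        "https://code.visualstudio.com/"  # blocks automated link checks
    ],
    format = Documenter.HTML(assets = ["assets/favicon.ico"],
        canonical = "https://docs.sciml.ai/stable/"))
```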
31 changes: 13 additions & 18 deletions docs/src/getting_started/fit_simulation.md
@@ -99,12 +99,14 @@ function loss(newp)
    newprob = remake(prob, p = newp)
    sol = solve(newprob, saveat = 1)
    loss = sum(abs2, sol .- xy_data)
-    return loss, sol
+    return loss
end

# Define a callback function to monitor optimization progress
-function callback(p, l, sol)
+function callback(state, l)
    display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
    plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
    scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
    display(plt)
@@ -278,37 +280,28 @@ function loss(newp)
    newprob = remake(prob, p = newp)
    sol = solve(newprob, saveat = 1)
    l = sum(abs2, sol .- xy_data)
-    return l, sol
+    return l
end
```

-Notice that our loss function returns the loss value as the first return,
-but returns extra information (the ODE solution with the new parameters)
-as an extra return argument.
-We will explain why this extra return information is helpful in the next section.

### Step 5: Solve the Optimization Problem

This step will look very similar to [the first optimization tutorial](@ref first_opt),
-except now we have a new loss function `loss` which returns both the loss value
-and the associated ODE solution.
-(In the previous tutorial, `L` only returned the loss value.)
The `Optimization.solve` function can accept an optional callback function
-to monitor the optimization process using extra arguments returned from `loss`.

The callback syntax is always:

```
callback(
-    optimization variables,
+    state,
    the current loss value,
-    other arguments returned from the loss function, ...
)
```

-In this case, we will provide the callback the arguments `(p, l, sol)`,
-since it always takes the current state of the optimization first (`p`)
-then the returns from the loss function (`l, sol`).
+In this case, we will provide the callback the arguments `(state, l)`,
+since it always takes the current state of the optimization first (`state`)
+then the current loss value (`l`).
The return value of the callback function should default to `false`.
`Optimization.solve` will halt if/when the callback function returns `true` instead.
Typically the `return` statement would monitor the loss value
@@ -318,16 +311,18 @@ More details about callbacks in Optimization.jl can be found
[here](https://docs.sciml.ai/Optimization/stable/API/solve/).

```@example odefit
-function callback(p, l, sol)
+function callback(state, l)
    display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
    plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
    scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
    display(plt)
    return false
end
```

-With this callback function, every step of the optimization will display both the loss value and a plot of how the solution compares to the training data.
+With this callback function, every step of the optimization will display both the loss value and a plot of how the solution compares to the training data. Since we want to track the fit visually, we plot the simulation at each iteration and compare it to the data. This is expensive, since it requires an extra `solve` call and a plotting step at every iteration.
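Since the v4 loss can no longer hand its solution to the callback, the tutorial pays for that extra `solve` from `state.u` at each step. One alternative pattern (illustrative only, not part of this commit; `last_sol` is a hypothetical cache, while `prob`, `t_data`, and `xy_data` are the tutorial's objects) is to stash the loss's most recent solution:

```julia
# Cache the most recent solution computed by the loss so the callback can
# reuse it instead of re-solving at every iteration.
last_sol = Ref{Any}(nothing)

function loss(newp)
    newprob = remake(prob, p = newp)
    sol = solve(newprob, saveat = 1)
    last_sol[] = sol  # stash for the callback
    return sum(abs2, sol .- xy_data)
end

function callback(state, l)
    sol = last_sol[]  # caution: may hold AD-dual values depending on the backend
    plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
    scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
    display(plt)
    return false
end
```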

Now, just like [the first optimization tutorial](@ref first_opt),
we set up our `OptimizationFunction` and `OptimizationProblem`,
2 changes: 1 addition & 1 deletion docs/src/highlevels/modeling_languages.md
@@ -50,7 +50,7 @@ doing standard molecular dynamics approximations.

## DiffEqFinancial.jl: Financial models for use in the DifferentialEquations ecosystem

-The goal of [DiffEqFinancial.jl](https://github.com/SciML/DiffEqFinancial.jl/commits/master) is to be a feature-complete set
+The goal of [DiffEqFinancial.jl](https://github.com/SciML/DiffEqFinancial.jl/) is to be a feature-complete set
of solvers for the types of problems found in libraries like QuantLib, such as the Heston process or the
Black-Scholes model.

9 changes: 4 additions & 5 deletions docs/src/showcase/blackhole.md
@@ -490,10 +490,8 @@ function loss(NN_params)
        prob_nn, RK4(), u0 = u0, p = NN_params, saveat = tsteps, dt = dt, adaptive = false))
    pred_waveform = compute_waveform(dt_data, pred, mass_ratio, model_params)[1]
-    loss = (sum(abs2,
-        view(waveform, obs_to_use_for_training) .-
-        view(pred_waveform, obs_to_use_for_training)))
-    return loss, pred_waveform
+    loss = sum(abs2, view(waveform, obs_to_use_for_training) .- view(pred_waveform, obs_to_use_for_training))
+    return loss
end
```

@@ -508,10 +506,11 @@ We'll use the following callback to save the history of the loss values.
```@example ude
losses = []
-callback(θ, l, pred_waveform; doplot = true) = begin
+callback(state, l; doplot = true) = begin
    push!(losses, l)
    #= Disable plotting as it trains since in docs
    display(l)
    waveform = compute_waveform(dt_data, soln, mass_ratio, model_params)[1]
    # plot current prediction against data
    plt = plot(tsteps, waveform,
        markershape=:circle, markeralpha = 0.25,
```
2 changes: 1 addition & 1 deletion docs/src/showcase/gpu_spde.md
@@ -302,7 +302,7 @@ These last two ways enclose the pointer to our cache arrays locally but still pr
function f(du,u,p,t) to the ODE solver.

Now, since PDEs are large, many times we don't care about getting the whole timeseries. Using
-the [output controls from DifferentialEquations.jl](https://diffeq.sciml.ai/latest/basics/common_solver_opts.html#Output-Control-1), we can make it only output the final timepoint.
+the [output controls from DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/), we can make it only output the final timepoint.

```julia
prob = ODEProblem(f, u0, (0.0, 100.0))
```
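The output controls mentioned above are standard common solver options; keeping only the final timepoint looks roughly like this (the `ROCK2` solver choice is a stand-in):

```julia
# Save no intermediate steps and drop the initial state, so only the
# final timepoint is kept in the solution object.
sol = solve(prob, ROCK2(); save_everystep = false, save_start = false)
sol.u[end]  # the final state
```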
2 changes: 1 addition & 1 deletion docs/src/showcase/missing_physics.md
@@ -222,7 +222,7 @@ current loss:
```@example ude
losses = Float64[]
-callback = function (p, l)
+callback = function (state, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
```
2 changes: 1 addition & 1 deletion docs/src/showcase/pinngpu.md
@@ -148,7 +148,7 @@ prob = discretize(pde_system, discretization)
## Step 6: Solve the Optimization Problem

```@example pinn
-callback = function (p, l)
+callback = function (state, l)
    println("Current loss is: $l")
    return false
end
```
2 changes: 1 addition & 1 deletion docs/src/showcase/symbolic_analysis.md
@@ -118,7 +118,7 @@ Did you implement the DAE incorrectly? No. Is the solver broken? No.

It turns out that this is a property of the DAE that we are attempting to solve.
This kind of DAE is known as an index-3 DAE. For a complete discussion of DAE
-index, see [this article](https://www.scholarpedia.org/article/Differential-algebraic_equations).
+index, see [this article](http://www.scholarpedia.org/article/Differential-algebraic_equations).
Essentially, the issue here is that we have 4 differential variables (``x``, ``v_x``, ``y``, ``v_y``)
and one algebraic variable ``T`` (which we can know because there is no `D(T)`
term in the equations). An index-1 DAE always satisfies that the Jacobian of
