Skip to content

Commit

Permalink
changes for new ODE bpinn solver from PR SciML#842
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed May 8, 2024
1 parent a1b3125 commit 5648819
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
12 changes: 9 additions & 3 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
Expand All @@ -112,6 +114,8 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
Expand Down Expand Up @@ -186,7 +190,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
Expand All @@ -211,7 +216,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)
verbose = verbose,
estim_collocate = estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
Expand All @@ -220,7 +226,7 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
if chain isa Lux.AbstractExplicitLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in (draw_samples - numensemble):draw_samples]
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]
luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
Expand Down
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")
include("collocated_estim.jl")

export NNODE, NNDAE,
PhysicsInformedNN, discretize,
Expand Down
24 changes: 15 additions & 9 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
init_params::AbstractVector, estim_collocate)
new{
typeof(chain),
Nothing,
Expand All @@ -39,7 +39,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
Expand Down Expand Up @@ -83,7 +84,12 @@ end
vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
if Tar.estim_collocate
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) +
L2loss2(Tar, θ)
else
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
end
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
Expand Down Expand Up @@ -247,7 +253,7 @@ function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector,

vals = nnsol .- physsol

# N dimensional vector if N outputs for NN(each row has logpdf of i[i] where u is vector of dependant variables)
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependant variables)
return [logpdf(
MvNormal(vals[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .*
Expand Down Expand Up @@ -442,7 +448,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
progress = false, verbose = false, estim_collocate = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
Expand All @@ -467,7 +473,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
# Lux-Named Tuple
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractExplicitLayer neural networks are supported")
error("Only Lux.AbstractExplicitLayer Neural Networks are supported")
end

if nchains > Threads.nthreads()
Expand Down Expand Up @@ -500,7 +506,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
t0 = prob.tspan[1]
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
Expand Down Expand Up @@ -569,8 +575,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
L2LossData(ℓπ, samples[end]))

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
mcmc_chain = MCMCChains.Chains(matrix_samples)
return mcmc_chain, samples, stats
end
end
Empty file added src/collocated_estim.jl
Empty file.

0 comments on commit 5648819

Please sign in to comment.