Skip to content

Commit

Permalink
direct usage of StatefulLuxLayer (#413)
Browse files Browse the repository at this point in the history
* direct usage of `StatefulLuxLayer`

* fix
  • Loading branch information
prbzrg authored May 6, 2024
1 parent d6199f8 commit 0e989bb
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 122 deletions.
47 changes: 20 additions & 27 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -259,9 +259,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -281,9 +281,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -304,9 +304,9 @@ function inference_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(xs, zrs),
steer_tspan(icnf, mode),
ps,
Expand All @@ -327,9 +327,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand All @@ -351,9 +351,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand All @@ -375,9 +375,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand All @@ -400,9 +400,9 @@ function generate_prob(
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
make_ode_func(icnf, mode, nn, ϵ),
make_ode_func(icnf, mode, nn, st, ϵ),
vcat(new_xs, zrs),
reverse(steer_tspan(icnf, mode)),
ps,
Expand Down Expand Up @@ -515,28 +515,21 @@ end
@inline function make_ode_func(
icnf::AbstractICNF{T, CM, INPLACE},
mode::Mode,
nn::Lux.StatefulLuxLayer,
nn::LuxCore.AbstractExplicitLayer,
st::NamedTuple,
ϵ::AbstractVecOrMat{T},
) where {T <: AbstractFloat, CM, INPLACE}
function ode_func_op(u, p, t)
augmented_f(u, p, t, icnf, mode, nn, ϵ)
augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
end

function ode_func_ip(du, u, p, t)
augmented_f(du, u, p, t, icnf, mode, nn, ϵ)
augmented_f(du, u, p, t, icnf, mode, nn, st, ϵ)
end

ifelse(INPLACE, ode_func_ip, ode_func_op)
end

@inline function make_dyn_func(nn::Lux.StatefulLuxLayer, ps::Any)
function dyn_func(x)
LuxCore.apply(nn, x, ps)
end

dyn_func
end

@inline function (icnf::AbstractICNF{T, CM, INPLACE, false})(
xs::AbstractVecOrMat,
ps::Any,
Expand Down
Loading

0 comments on commit 0e989bb

Please sign in to comment.