Skip to content

Commit

Permalink
add augmented node & steer
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Aug 1, 2023
1 parent c3e1e29 commit 9a439b7
Show file tree
Hide file tree
Showing 12 changed files with 873 additions and 51 deletions.
50 changes: 47 additions & 3 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,36 @@ export construct
function construct(
aicnf::Type{<:AbstractFlows},
nn,
nvars::Integer;
nvars::Integer,
naugmented::Integer = 0;
data_type::Type{<:AbstractFloat} = Float32,
array_type::Type{<:AbstractArray} = Array,
compute_mode::Type{<:ComputeMode} = ADVectorMode,
basedist::Distribution = MvNormal(Zeros{data_type}(nvars), Eye{data_type}(nvars)),
augmented::Bool = false,
steer::Bool = false,
basedist::Distribution = MvNormal(
Zeros{data_type}(nvars + naugmented),
Eye{data_type}(nvars + naugmented),
),
tspan::NTuple{2} = (zero(data_type), one(data_type)),
steer_rate::AbstractFloat = zero(data_type),
differentiation_backend::AbstractDifferentiation.AbstractBackend = AbstractDifferentiation.ZygoteBackend(),
sol_args::Tuple = (),
sol_kwargs::Dict = Dict(
:alg_hints => [:nonstiff, :memorybound],
:reltol => 1e-2 + eps(1e-2),
),
)
aicnf{data_type, array_type, compute_mode}(
!augmented && !iszero(naugmented) && error("'naugmented' > 0: 'augmented' must be true")
!steer && !iszero(steer_rate) && error("'steer_rate' > 0: 'steer' must be true")

aicnf{data_type, array_type, compute_mode, augmented, steer}(
nn,
nvars,
naugmented,
basedist,
tspan,
steer_rate,
differentiation_backend,
sol_args,
sol_kwargs,
Expand All @@ -44,6 +56,38 @@ function Base.show(io::IO, icnf::AbstractFlows)
)
end

@inline function n_augment_input(
icnf::AbstractFlows{<:AbstractFloat, <:AbstractArray, <:ComputeMode, true},
)
icnf.naugmented
end

@inline function n_augment_input(icnf::AbstractFlows)
0
end

@inline function steer_tspan(
icnf::AbstractFlows{T, <:AbstractArray, <:ComputeMode, AUGMENTED, true},
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
rng::AbstractRNG = Random.default_rng(),
) where {T <: AbstractFloat, AUGMENTED}
t₀, t₁ = tspan
steer_b = steer_rate * t₁
d_s = Uniform{T}(t₁ - steer_b, t₁ + steer_b)
t₁_new = convert(T, rand(rng, d_s))
(t₀, t₁_new)
end

@inline function steer_tspan(
icnf::AbstractFlows,
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
rng::AbstractRNG = Random.default_rng(),
) where {T <: AbstractFloat, AUGMENTED}
tspan
end

@inline function zeros_T_AT(
::AbstractFlows{T, <:CuArray},
dims...,
Expand Down
60 changes: 48 additions & 12 deletions src/base_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@ function inference_prob(
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
sol_args::Tuple = icnf.sol_args,
sol_kwargs::Dict = icnf.sol_kwargs,
)
n_aug = n_augment(icnf, mode)
zrs = zeros_T_AT(icnf, n_aug + 1)
n_aug_input = n_augment_input(icnf)
zrs = zeros_T_AT(icnf, n_aug_input + n_aug + 1)
f_aug = augmented_f(icnf, mode, ys, st; differentiation_backend, rng)
func = ODEFunction{false, SciMLBase.FullSpecialize}(f_aug)
prob = ODEProblem{false, SciMLBase.FullSpecialize}(func, vcat(xs, zrs), tspan, ps)
prob = ODEProblem{false, SciMLBase.FullSpecialize}(
func,
vcat(xs, zrs),
steer_tspan(icnf, tspan, steer_rate, rng),
ps,
sol_args...;
sol_kwargs...,
)
prob
end

Expand All @@ -30,6 +39,7 @@ function inference(
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
Expand All @@ -44,6 +54,7 @@ function inference(
ps,
st;
tspan,
steer_rate,
basedist,
differentiation_backend,
rng,
Expand All @@ -68,17 +79,26 @@ function inference_prob(
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
sol_args::Tuple = icnf.sol_args,
sol_kwargs::Dict = icnf.sol_kwargs,
)
n_aug = n_augment(icnf, mode)
zrs = zeros_T_AT(icnf, n_aug + 1, size(xs, 2))
n_aug_input = n_augment_input(icnf)
zrs = zeros_T_AT(icnf, n_aug_input + n_aug + 1, size(xs, 2))
f_aug = augmented_f(icnf, mode, ys, st, size(xs, 2); differentiation_backend, rng)
func = ODEFunction{false, SciMLBase.FullSpecialize}(f_aug)
prob = ODEProblem{false, SciMLBase.FullSpecialize}(func, vcat(xs, zrs), tspan, ps)
prob = ODEProblem{false, SciMLBase.FullSpecialize}(
func,
vcat(xs, zrs),
steer_tspan(icnf, tspan, steer_rate, rng),
ps,
sol_args...;
sol_kwargs...,
)
prob
end

Expand All @@ -90,6 +110,7 @@ function inference(
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
Expand All @@ -104,6 +125,7 @@ function inference(
ps,
st;
tspan,
steer_rate,
basedist,
differentiation_backend,
rng,
Expand All @@ -121,28 +143,31 @@ function inference(
end

function generate_prob(
icnf::AbstractCondICNF{<:AbstractFloat, AT, <:VectorMode},
icnf::AbstractCondICNF{T, AT, <:VectorMode},
mode::Mode,
ys::AbstractVector{<:Real},
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
sol_args::Tuple = icnf.sol_args,
sol_kwargs::Dict = icnf.sol_kwargs,
) where {AT <: AbstractArray}
) where {T <: AbstractFloat, AT <: AbstractArray}
n_aug = n_augment(icnf, mode)
new_xs = convert(AT, rand(rng, basedist))
new_xs = convert(AT{T}, rand(rng, basedist))
zrs = zeros_T_AT(icnf, n_aug + 1)
f_aug = augmented_f(icnf, mode, ys, st; differentiation_backend, rng)
func = ODEFunction{false, SciMLBase.FullSpecialize}(f_aug)
prob = ODEProblem{false, SciMLBase.FullSpecialize}(
func,
vcat(new_xs, zrs),
reverse(tspan),
reverse(steer_tspan(icnf, tspan, steer_rate, rng)),
ps,
sol_args...;
sol_kwargs...,
)
prob
end
Expand All @@ -154,6 +179,7 @@ function generate(
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
Expand All @@ -167,16 +193,18 @@ function generate(
ps,
st;
tspan,
steer_rate,
basedist,
differentiation_backend,
rng,
sol_args,
sol_kwargs,
)
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
sol = solve(prob, sol_args...; sol_kwargs...)
fsol = @view sol[:, end]
z = @view fsol[begin:(end - n_aug - 1)]
z = @view fsol[begin:(end - n_aug_input - n_aug - 1)]
z
end

Expand All @@ -188,22 +216,25 @@ function generate_prob(
st::Any,
n::Integer;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
sol_args::Tuple = icnf.sol_args,
sol_kwargs::Dict = icnf.sol_kwargs,
) where {AT <: AbstractArray}
n_aug = n_augment(icnf, mode)
new_xs = convert(AT, rand(rng, basedist, n))
new_xs = convert(AT{T}, rand(rng, basedist, n))
zrs = zeros_T_AT(icnf, n_aug + 1, size(new_xs, 2))
f_aug = augmented_f(icnf, mode, ys, st, size(new_xs, 2); differentiation_backend, rng)
func = ODEFunction{false, SciMLBase.FullSpecialize}(f_aug)
prob = ODEProblem{false, SciMLBase.FullSpecialize}(
func,
vcat(new_xs, zrs),
reverse(tspan),
reverse(steer_tspan(icnf, tspan, steer_rate, rng)),
ps,
sol_args...;
sol_kwargs...,
)
prob
end
Expand All @@ -216,6 +247,7 @@ function generate(
st::Any,
n::Integer;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
Expand All @@ -230,16 +262,18 @@ function generate(
st,
n;
tspan,
steer_rate,
basedist,
differentiation_backend,
rng,
sol_args,
sol_kwargs,
)
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
sol = solve(prob, sol_args...; sol_kwargs...)
fsol = @view sol[:, :, end]
z = @view fsol[begin:(end - n_aug - 1), :]
z = @view fsol[begin:(end - n_aug_input - n_aug - 1), :]
z
end

Expand All @@ -251,6 +285,7 @@ end
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
Expand Down Expand Up @@ -283,6 +318,7 @@ end
ps::Any,
st::Any;
tspan::NTuple{2} = icnf.tspan,
steer_rate::AbstractFloat = icnf.steer_rate,
basedist::Distribution = icnf.basedist,
differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
rng::AbstractRNG = Random.default_rng(),
Expand Down
Loading

0 comments on commit 9a439b7

Please sign in to comment.