Skip to content

Commit

Permalink
merge Planar and FFJORD (#360)
Browse files Browse the repository at this point in the history
* merge Planar and FFJORD

* rename vars

* move the sign to l̇

* Format .jl files (#361)

Co-authored-by: prbzrg <[email protected]>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: prbzrg <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2024
1 parent 787c35a commit 442cdad
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 243 deletions.
48 changes: 24 additions & 24 deletions src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ end
u::Any,
p::Any,
t::Any,
icnf::Union{CondRNODE{T, <:ADVecJacVectorMode}, CondFFJORD{T, <:ADVecJacVectorMode}},
icnf::AbstractCondICNF{T, <:ADVecJacVectorMode},
mode::TrainMode,
ys::AbstractVector{<:Real},
ϵ::AbstractVector{T},
Expand All @@ -153,21 +153,21 @@ end
z,
)
ϵJ = only(VJ(ϵ))
= ϵJ ϵ
= -(ϵJ ϵ)
if icnf isa CondRNODE
= norm(ż)
= norm(ϵJ)
vcat(ż, -l̇, Ė, ṅ)
vcat(ż, l̇, Ė, ṅ)
else
vcat(ż, -l̇)
vcat(ż, l̇)
end
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::Union{CondRNODE{T, <:ADJacVecVectorMode}, CondFFJORD{T, <:ADJacVecVectorMode}},
icnf::AbstractCondICNF{T, <:ADJacVecVectorMode},
mode::TrainMode,
ys::AbstractVector{<:Real},
ϵ::AbstractVector{T},
Expand All @@ -184,21 +184,21 @@ end
)
ż, Jϵ = ż_JV(ϵ)
= only(Jϵ)
= ϵ
= -(ϵ )
if icnf isa CondRNODE
= norm(ż)
= norm(Jϵ)
vcat(ż, -l̇, Ė, ṅ)
vcat(ż, l̇, Ė, ṅ)
else
vcat(ż, -l̇)
vcat(ż, l̇)
end
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::Union{CondRNODE{T, <:ZygoteVectorMode}, CondFFJORD{T, <:ZygoteVectorMode}},
icnf::AbstractCondICNF{T, <:ZygoteVectorMode},
mode::TrainMode,
ys::AbstractVector{<:Real},
ϵ::AbstractVector{T},
Expand All @@ -210,21 +210,21 @@ end
x -> first(icnf.nn(vcat(x, ys), p, st))
end, z)
ϵJ = only(VJ(ϵ))
= ϵJ ϵ
= -(ϵJ ϵ)
if icnf isa CondRNODE
= norm(ż)
= norm(ϵJ)
vcat(ż, -l̇, Ė, ṅ)
vcat(ż, l̇, Ė, ṅ)
else
vcat(ż, -l̇)
vcat(ż, l̇)
end
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::Union{CondRNODE{T, <:SDVecJacMatrixMode}, CondFFJORD{T, <:SDVecJacMatrixMode}},
icnf::AbstractCondICNF{T, <:SDVecJacMatrixMode},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
Expand All @@ -241,21 +241,21 @@ end
autodiff = icnf.autodiff_backend,
)
ϵJ = reshape(Jf * ϵ, size(z))
= sum(ϵJ .* ϵ; dims = 1)
= -sum(ϵJ .* ϵ; dims = 1)
if icnf isa CondRNODE
= transpose(norm.(eachcol(ż)))
= transpose(norm.(eachcol(ϵJ)))
vcat(ż, -l̇, Ė, ṅ)
vcat(ż, l̇, Ė, ṅ)
else
vcat(ż, -l̇)
vcat(ż, l̇)
end
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::Union{CondRNODE{T, <:SDJacVecMatrixMode}, CondFFJORD{T, <:SDJacVecMatrixMode}},
icnf::AbstractCondICNF{T, <:SDJacVecMatrixMode},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
Expand All @@ -272,21 +272,21 @@ end
autodiff = icnf.autodiff_backend,
)
= reshape(Jf * ϵ, size(z))
= sum.* Jϵ; dims = 1)
= -sum.* Jϵ; dims = 1)
if icnf isa CondRNODE
= transpose(norm.(eachcol(ż)))
= transpose(norm.(eachcol(Jϵ)))
vcat(ż, -l̇, Ė, ṅ)
vcat(ż, l̇, Ė, ṅ)
else
vcat(ż, -l̇)
vcat(ż, l̇)
end
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::Union{CondRNODE{T, <:ZygoteMatrixMode}, CondFFJORD{T, <:ZygoteMatrixMode}},
icnf::AbstractCondICNF{T, <:ZygoteMatrixMode},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
Expand All @@ -298,13 +298,13 @@ end
x -> first(icnf.nn(vcat(x, ys), p, st))
end, z)
ϵJ = only(VJ(ϵ))
= sum(ϵJ .* ϵ; dims = 1)
= -sum(ϵJ .* ϵ; dims = 1)
if icnf isa CondRNODE
= transpose(norm.(eachcol(ż)))
= transpose(norm.(eachcol(ϵJ)))
vcat(ż, -l̇, Ė, ṅ)
vcat(ż, l̇, Ė, ṅ)
else
vcat(ż, -l̇)
vcat(ż, l̇)
end
end

Expand Down
118 changes: 15 additions & 103 deletions src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = first(icnf.nn(vcat(z, ys), p, st))
trace_J =
= first(icnf.nn(vcat(z, ys), p, st))
= -(
p.u transpose(
only(
AbstractDifferentiation.jacobian(
Expand All @@ -59,7 +59,8 @@ end
),
),
)
vcat(mz, -trace_J)
)
vcat(ż, l̇)
end

@views function augmented_f(
Expand All @@ -74,8 +75,8 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = first(icnf.nn(vcat(z, ys), p, st))
trace_J =
= first(icnf.nn(vcat(z, ys), p, st))
= -(
p.u transpose(
only(
AbstractDifferentiation.jacobian(
Expand All @@ -87,7 +88,8 @@ end
),
),
)
vcat(mz, -trace_J)
)
vcat(ż, l̇)
end

@views function augmented_f(
Expand All @@ -102,14 +104,15 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = first(icnf.nn(vcat(z, ys), p, st))
trace_J =
= first(icnf.nn(vcat(z, ys), p, st))
= -(
p.u transpose(
only(Zygote.jacobian(let ys = ys, p = p, st = st
x -> first(pl_h(icnf.nn, vcat(x, ys), p, st))
end, z)),
)
vcat(mz, -trace_J)
)
vcat(ż, l̇)
end

@views function augmented_f(
Expand All @@ -124,104 +127,13 @@ end
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1)]
mz = first(icnf.nn(vcat(z, ys), p, st))
trace_J =
= first(icnf.nn(vcat(z, ys), p, st))
= -(
p.u transpose(
only(Zygote.jacobian(let ys = ys, p = p, st = st
x -> first(pl_h(icnf.nn, vcat(x, ys), p, st))
end, z)),
)
vcat(mz, -trace_J)
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::CondPlanar{T, <:SDVecJacMatrixMode},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
st::Any,
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = first(icnf.nn(vcat(z, ys), p, st))
Jf = VecJac(
let ys = ys, p = p, st = st
x -> first(icnf.nn(vcat(x, ys), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
)
ϵJ = reshape(Jf * ϵ, size(z))
trace_J = sum(ϵJ .* ϵ; dims = 1)
vcat(mz, -trace_J)
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::CondPlanar{T, <:SDJacVecMatrixMode},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
st::Any,
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz = first(icnf.nn(vcat(z, ys), p, st))
Jf = JacVec(
let ys = ys, p = p, st = st
x -> first(icnf.nn(vcat(x, ys), p, st))
end,
z;
autodiff = icnf.autodiff_backend,
)
= reshape(Jf * ϵ, size(z))
trace_J = sum.* Jϵ; dims = 1)
vcat(mz, -trace_J)
end

@views function augmented_f(
u::Any,
p::Any,
t::Any,
icnf::CondPlanar{T, <:ZygoteMatrixMode},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
st::Any,
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, VJ = Zygote.pullback(let ys = ys, p = p, st = st
x -> first(icnf.nn(vcat(x, ys), p, st))
end, z)
ϵJ = only(VJ(ϵ))
trace_J = sum(ϵJ .* ϵ; dims = 1)
vcat(mz, -trace_J)
end

@views function augmented_f(
du::Any,
u::Any,
p::Any,
t::Any,
icnf::CondPlanar{T, <:ZygoteMatrixModeInplace, true},
mode::TrainMode,
ys::AbstractMatrix{<:Real},
ϵ::AbstractMatrix{T},
st::Any,
) where {T <: AbstractFloat}
n_aug = n_augment(icnf, mode)
z = u[begin:(end - n_aug - 1), :]
mz, VJ = Zygote.pullback(let ys = ys, p = p, st = st
x -> first(icnf.nn(vcat(x, ys), p, st))
end, z)
ϵJ = only(VJ(ϵ))
du[begin:(end - n_aug - 1), :] .= mz
du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
nothing
vcat(ż, l̇)
end
Loading

0 comments on commit 442cdad

Please sign in to comment.