Skip to content

Commit

Permalink
add only to pullback & pushforward
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Oct 3, 2024
1 parent ffe1684 commit 912a1d3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 8 additions & 0 deletions src/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ function augmented_f(
z = u[begin:(end - n_aug - 1)]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
ϵJ = only(ϵJ)
= -LinearAlgebra.dot(ϵJ, ϵ)
= if NORM_Z
LinearAlgebra.norm(ż)
Expand Down Expand Up @@ -228,6 +229,7 @@ function augmented_f(
z = u[begin:(end - n_aug - 1)]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
ϵJ = only(ϵJ)
du[begin:(end - n_aug - 1)] .=
du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
du[(end - n_aug + 1)] = if NORM_Z
Expand Down Expand Up @@ -262,6 +264,7 @@ function augmented_f(
z,
(ϵ,),
)
= only(Jϵ)
= -LinearAlgebra.dot(ϵ, Jϵ)
= if NORM_Z
LinearAlgebra.norm(ż)
Expand Down Expand Up @@ -296,6 +299,7 @@ function augmented_f(
z,
(ϵ,),
)
= only(Jϵ)
du[begin:(end - n_aug - 1)] .=
du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
du[(end - n_aug + 1)] = if NORM_Z
Expand Down Expand Up @@ -326,6 +330,7 @@ function augmented_f(
z = u[begin:(end - n_aug - 1), :]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
ϵJ = only(ϵJ)
= -sum(ϵJ .* ϵ; dims = 1)
= transpose(if NORM_Z
LinearAlgebra.norm.(eachcol(ż))
Expand Down Expand Up @@ -360,6 +365,7 @@ function augmented_f(
z = u[begin:(end - n_aug - 1), :]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
ϵJ = only(ϵJ)
du[begin:(end - n_aug - 1), :] .=
du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
du[(end - n_aug + 1), :] .= if NORM_Z
Expand Down Expand Up @@ -394,6 +400,7 @@ function augmented_f(
z,
(ϵ,),
)
= only(Jϵ)
= -sum.* Jϵ; dims = 1)
= transpose(if NORM_Z
LinearAlgebra.norm.(eachcol(ż))
Expand Down Expand Up @@ -432,6 +439,7 @@ function augmented_f(
z,
(ϵ,),
)
= only(Jϵ)
du[begin:(end - n_aug - 1), :] .=
du[(end - n_aug), :] .= -vec(sum.* Jϵ; dims = 1))
du[(end - n_aug + 1), :] .= if NORM_Z
Expand Down
7 changes: 4 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[i, :, :] =
DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,))
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,)))
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(copy(res); dims = 3)
Expand All @@ -27,8 +27,9 @@ end
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[:, i, :] =
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,))
res[:, i, :] = only(
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,)),
)
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(copy(res); dims = 3)
Expand Down

0 comments on commit 912a1d3

Please sign in to comment.