diff --git a/src/base_icnf.jl b/src/base_icnf.jl
index 81311119..cac6e1cf 100644
--- a/src/base_icnf.jl
+++ b/src/base_icnf.jl
@@ -236,9 +236,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -259,9 +259,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -281,9 +281,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -304,9 +304,9 @@ function inference_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(xs, zrs),
         steer_tspan(icnf, mode),
         ps,
@@ -327,9 +327,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -351,9 +351,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -375,9 +375,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(icnf.nn, ps, st)
+    nn = icnf.nn
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -400,9 +400,9 @@ function generate_prob(
     ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
     ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
     Random.rand!(icnf.rng, icnf.epsdist, ϵ)
-    nn = Lux.StatefulLuxLayer(CondLayer(icnf.nn, ys), ps, st)
+    nn = CondLayer(icnf.nn, ys)
     SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
-        make_ode_func(icnf, mode, nn, ϵ),
+        make_ode_func(icnf, mode, nn, st, ϵ),
         vcat(new_xs, zrs),
         reverse(steer_tspan(icnf, mode)),
         ps,
@@ -515,28 +515,21 @@ end
 @inline function make_ode_func(
     icnf::AbstractICNF{T, CM, INPLACE},
     mode::Mode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVecOrMat{T},
 ) where {T <: AbstractFloat, CM, INPLACE}
     function ode_func_op(u, p, t)
-        augmented_f(u, p, t, icnf, mode, nn, ϵ)
+        augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
     end
 
     function ode_func_ip(du, u, p, t)
-        augmented_f(du, u, p, t, icnf, mode, nn, ϵ)
+        augmented_f(du, u, p, t, icnf, mode, nn, st, ϵ)
     end
 
     ifelse(INPLACE, ode_func_ip, ode_func_op)
 end
 
-@inline function make_dyn_func(nn::Lux.StatefulLuxLayer, ps::Any)
-    function dyn_func(x)
-        LuxCore.apply(nn, x, ps)
-    end
-
-    dyn_func
-end
-
 @inline function (icnf::AbstractICNF{T, CM, INPLACE, false})(
     xs::AbstractVecOrMat,
     ps::Any,
diff --git a/src/icnf.jl b/src/icnf.jl
index 908169a2..c3cd6936 100644
--- a/src/icnf.jl
+++ b/src/icnf.jl
@@ -116,16 +116,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVectorMode, false},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = AbstractDifferentiation.value_and_jacobian(
-        icnf.differentiation_backend,
-        make_dyn_func(nn, p),
-        z,
-    )
+    ż, J = AbstractDifferentiation.value_and_jacobian(icnf.differentiation_backend, snn, z)
     l̇ = -LinearAlgebra.tr(only(J))
     vcat(ż, l̇)
 end
@@ -137,16 +135,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVectorMode, true},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = AbstractDifferentiation.value_and_jacobian(
-        icnf.differentiation_backend,
-        make_dyn_func(nn, p),
-        z,
-    )
+    ż, J = AbstractDifferentiation.value_and_jacobian(icnf.differentiation_backend, snn, z)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.tr(only(J))
     nothing
@@ -158,16 +154,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVectorMode, false},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = DifferentiationInterface.value_and_jacobian(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-    )
+    ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.autodiff_backend, z)
     l̇ = -LinearAlgebra.tr(J)
     vcat(ż, l̇)
 end
@@ -179,16 +173,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVectorMode, true},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = DifferentiationInterface.value_and_jacobian(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-    )
+    ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.autodiff_backend, z)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.tr(J)
     nothing
@@ -200,12 +192,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:MatrixMode, false},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, J = jacobian_batched(icnf, make_dyn_func(nn, p), z)
+    ż, J = jacobian_batched(icnf, snn, z)
     l̇ = -transpose(LinearAlgebra.tr.(J))
     vcat(ż, l̇)
 end
@@ -217,12 +211,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:MatrixMode, true},
     mode::TestMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, J = jacobian_batched(icnf, make_dyn_func(nn, p), z)
+    ż, J = jacobian_batched(icnf, snn, z)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -(LinearAlgebra.tr.(J))
     nothing
@@ -234,14 +230,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, VJ = AbstractDifferentiation.value_and_pullback_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ϵJ = only(VJ(ϵ))
@@ -266,14 +264,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, VJ = AbstractDifferentiation.value_and_pullback_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ϵJ = only(VJ(ϵ))
@@ -298,14 +298,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż_JV = AbstractDifferentiation.value_and_pushforward_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ż, Jϵ = ż_JV(ϵ)
@@ -331,14 +333,16 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:ADJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż_JV = AbstractDifferentiation.value_and_pushforward_function(
         icnf.differentiation_backend,
-        make_dyn_func(nn, p),
+        snn,
         z,
     )
     ż, Jϵ = ż_JV(ϵ)
@@ -364,17 +368,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵJ, ϵ)
     Ė = if NORM_Z
         LinearAlgebra.norm(ż)
@@ -396,17 +397,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
     du[(end - n_aug + 1)] = if NORM_Z
@@ -428,17 +426,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵ, Jϵ)
     Ė = if NORM_Z
         LinearAlgebra.norm(ż)
@@ -460,17 +456,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractVector{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
     du[(end - n_aug + 1)] = if NORM_Z
@@ -492,17 +486,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -sum(ϵJ .* ϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -528,17 +519,14 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, ϵJ = DifferentiationInterface.value_and_pullback(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, ϵJ = DifferentiationInterface.value_and_pullback(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
@@ -560,17 +548,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     l̇ = -sum(ϵ .* Jϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -596,17 +582,15 @@ function augmented_f(
     ::Any,
     icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
     mode::TrainMode,
-    nn::Lux.StatefulLuxLayer,
+    nn::LuxCore.AbstractExplicitLayer,
+    st::NamedTuple,
     ϵ::AbstractMatrix{T},
 ) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
     n_aug = n_augment(icnf, mode)
+    snn = Lux.StatefulLuxLayer(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        make_dyn_func(nn, p),
-        icnf.autodiff_backend,
-        z,
-        ϵ,
-    )
+    ż, Jϵ =
+        DifferentiationInterface.value_and_pushforward(snn, icnf.autodiff_backend, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
diff --git a/src/utils.jl b/src/utils.jl
index 33559c85..08108daf 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,6 +1,6 @@
 @inline function jacobian_batched(
     icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
-    f::Function,
+    f::Lux.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y, VJ = DifferentiationInterface.value_and_pullback_split(f, icnf.autodiff_backend, xs)
@@ -17,7 +17,7 @@ end
 
 @inline function jacobian_batched(
     icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
-    f::Function,
+    f::Lux.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y = f(xs)
@@ -34,7 +34,7 @@ end
 
 @inline function jacobian_batched(
     icnf::AbstractICNF{T, <:DIMatrixMode},
-    f::Function,
+    f::Lux.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
 ) where {T}
     y, J = DifferentiationInterface.value_and_jacobian(f, icnf.autodiff_backend, xs)
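
Note (commentary, not part of the patch): the refactor stops baking `ps` into a `Lux.StatefulLuxLayer` at problem-construction time. Instead, the raw layer `nn` and its state `st` are threaded through `make_ode_func` into `augmented_f`, which rebuilds the stateful wrapper from the solver-supplied parameters `p` on every dynamics call; the wrapper itself is then the callable handed to the differentiation backends (hence `f::Lux.StatefulLuxLayer` in `jacobian_batched`), making the `make_dyn_func` closure unnecessary. A minimal sketch of the pattern, using a hypothetical `Lux.Dense` layer in place of `icnf.nn` and plain layer output in place of the full augmented dynamics:

```julia
using Lux, LuxCore, Random

rng = Random.default_rng()
nn = Lux.Dense(2 => 2, tanh)    # hypothetical stand-in for `icnf.nn`
ps, st = LuxCore.setup(rng, nn) # parameters and state, kept separate

# Out-of-place dynamics in the style of the patched `make_ode_func`:
# `nn` and `st` are captured, while `p` arrives from the ODE solver.
function ode_func_op(u, p, t)
    snn = Lux.StatefulLuxLayer(nn, p, st) # rebuilt per call, as in `augmented_f`
    snn(u)                                # the wrapper is directly callable
end

u0 = randn(rng, Float32, 2)
ode_func_op(u0, ps, 0.0f0) # evaluates the layer with the solver's `p`
```

Because the wrapper is reconstructed from `p` inside the dynamics, the parameters the solver integrates (and differentiates) are always the ones the network sees, and the AD backends receive a concrete `Lux.StatefulLuxLayer` to dispatch on rather than an anonymous closure.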