diff --git a/src/base_cond_icnf.jl b/src/base_cond_icnf.jl index a797bab2..08454e12 100644 --- a/src/base_cond_icnf.jl +++ b/src/base_cond_icnf.jl @@ -15,10 +15,10 @@ export inference, generate, loss prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ys = ys, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ys, ϵ, st) - end; - u0 = cat(xs, zrs; dims = 1), - tspan = steer_tspan(icnf, mode), - p = ps, + end, + cat(xs, zrs; dims = 1), + steer_tspan(icnf, mode), + ps; icnf.sol_kwargs..., ) prob @@ -62,10 +62,10 @@ end prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ys = ys, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ys, ϵ, st) - end; - u0 = cat(xs, zrs; dims = 1), - tspan = steer_tspan(icnf, mode), - p = ps, + end, + cat(xs, zrs; dims = 1), + steer_tspan(icnf, mode), + ps; icnf.sol_kwargs..., ) prob @@ -109,10 +109,10 @@ end prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ys = ys, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ys, ϵ, st) - end; - u0 = cat(new_xs, zrs; dims = 1), - tspan = reverse(steer_tspan(icnf, mode)), - p = ps, + end, + cat(new_xs, zrs; dims = 1), + reverse(steer_tspan(icnf, mode)), + ps; icnf.sol_kwargs..., ) prob @@ -150,10 +150,10 @@ end prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ys = ys, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ys, ϵ, st) - end; - u0 = cat(new_xs, zrs; dims = 1), - tspan = reverse(steer_tspan(icnf, mode)), - p = ps, + end, + cat(new_xs, zrs; dims = 1), + reverse(steer_tspan(icnf, mode)), + ps; icnf.sol_kwargs..., ) prob diff --git a/src/base_icnf.jl b/src/base_icnf.jl index 201c8f9a..772d273c 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -14,10 +14,10 @@ export inference, generate, loss prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ϵ, st) - end; - u0 = cat(xs, zrs; dims = 1), - tspan = steer_tspan(icnf, mode), - p = ps, + end, + cat(xs, zrs; dims = 1), + steer_tspan(icnf, mode), + ps; icnf.sol_kwargs..., ) prob @@ -59,10 +59,10 @@ end prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ϵ, st) - end; - u0 = cat(xs, zrs; dims = 1), - tspan = steer_tspan(icnf, mode), - p = ps, + end, + cat(xs, zrs; dims = 1), + steer_tspan(icnf, mode), + ps; icnf.sol_kwargs..., ) prob @@ -104,10 +104,10 @@ end prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ϵ, st) - end; - u0 = cat(new_xs, zrs; dims = 1), - tspan = reverse(steer_tspan(icnf, mode)), - p = ps, + end, + cat(new_xs, zrs; dims = 1), + reverse(steer_tspan(icnf, mode)), + ps; icnf.sol_kwargs..., ) prob @@ -142,10 +142,10 @@ end prob = ODEProblem{false, SciMLBase.FullSpecialize}( let icnf = icnf, mode = mode, ϵ = ϵ, st = st (u, p, t) -> augmented_f(u, p, t, icnf, mode, ϵ, st) - end; - u0 = cat(new_xs, zrs; dims = 1), - tspan = reverse(steer_tspan(icnf, mode)), - p = ps, + end, + cat(new_xs, zrs; dims = 1), + reverse(steer_tspan(icnf, mode)), + ps; icnf.sol_kwargs..., ) prob