diff --git a/README.md b/README.md index f944b915..b5a7d543 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,13 @@ icnf = construct( λ₁ = 1.0f-2, # regulate flow λ₂ = 1.0f-2, # regulate volume change λ₃ = 1.0f-2, # regulate augmented dimensions + sol_kwargs = (; + progress = true, + save_everystep = false, + reltol = sqrt(eps(one(Float32))), + abstol = eps(one(Float32)), + maxiters = typemax(Int32), + ), # pass to the solver ) # Data @@ -76,6 +83,9 @@ model = ICNFModel( # adtype = AutoZygote(), # use_batch = true, # batch_size = 32, + sol_kwargs = (; + progress = true, + ), # pass to the solver ) mach = machine(model, df) fit!(mach) diff --git a/src/base_icnf.jl b/src/base_icnf.jl index cbd1d58f..3c9b1d66 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -23,17 +23,7 @@ function construct( FillArrays.Zeros{data_type}(nvars + naugmented), FillArrays.Eye{data_type}(nvars + naugmented), ), - suggested_sol_kwargs::Bool = false, - sol_kwargs::NamedTuple = if suggested_sol_kwargs - (; - save_everystep = false, - reltol = sqrt(eps(one(Float32))), - abstol = eps(one(Float32)), - maxiters = typemax(Int32), - ) - else - (;) - end, + sol_kwargs::NamedTuple = (;), rng::Random.AbstractRNG = rng_AT(resource), λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} convert(data_type, 1.0e-2)