diff --git a/Project.toml b/Project.toml index b779519e..fd340a21 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -53,7 +52,6 @@ ChainRulesCore = "1" ComponentArrays = "0.14" ComputationalResources = "0.3" DataFrames = "1" -DifferentialEquations = "7" Distributions = "0.25" DistributionsAD = "0.6" FillArrays = "1" diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index bfbf23c8..d6a096b1 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -8,7 +8,6 @@ using AbstractDifferentiation, ComputationalResources, DataFrames, Dates, - DifferentialEquations, Distributions, DistributionsAD, FillArrays, diff --git a/src/base.jl b/src/base.jl index 96764d4c..22a82da1 100644 --- a/src/base.jl +++ b/src/base.jl @@ -18,8 +18,12 @@ function construct( autodiff_backend::ADTypes.AbstractADType = AutoZygote(), sol_args::Tuple = (), sol_kwargs::Dict = Dict( - :alg_hints => [:nonstiff, :memorybound], - :reltol => 1e-2 + eps(1e-2), + :alg_hints => [:nonstiff], + :alg => VCABM(), + :sensealg => InterpolatingAdjoint(; autodiff = true, autojacvec = ZygoteVJP()), + :reltol => sqrt(eps(one(Float32))), + :abstol => eps(one(Float32)), + :maxiters => typemax(Int32), ), rng::AbstractRNG = Random.default_rng(), ) diff --git a/src/cond_rnode.jl b/src/cond_rnode.jl index 8cca4c25..05e4bf17 100644 --- a/src/cond_rnode.jl +++ b/src/cond_rnode.jl @@ -57,8 +57,12 @@ function construct( autodiff_backend::ADTypes.AbstractADType = AutoZygote(), sol_args::Tuple = (), sol_kwargs::Dict = Dict( - :alg_hints => [:nonstiff, :memorybound], - :reltol => 1e-2 + eps(1e-2), + :alg_hints => [:nonstiff], + :alg => VCABM(), + :sensealg => InterpolatingAdjoint(; autodiff = true, autojacvec = ZygoteVJP()), + :reltol => sqrt(eps(one(Float32))), + :abstol => eps(one(Float32)), + :maxiters => typemax(Int32), ), rng::AbstractRNG = Random.default_rng(), λ₁::AbstractFloat = convert(data_type, 1e-2), diff --git a/src/rnode.jl b/src/rnode.jl index f0ea3bf3..61b25fc2 100644 --- a/src/rnode.jl +++ b/src/rnode.jl @@ -59,8 +59,12 @@ function construct( autodiff_backend::ADTypes.AbstractADType = AutoZygote(), sol_args::Tuple = (), sol_kwargs::Dict = Dict( - :alg_hints => [:nonstiff, :memorybound], - :reltol => 1e-2 + eps(1e-2), + :alg_hints => [:nonstiff], + :alg => VCABM(), + :sensealg => InterpolatingAdjoint(; autodiff = true, autojacvec = ZygoteVJP()), + :reltol => sqrt(eps(one(Float32))), + :abstol => eps(one(Float32)), + :maxiters => typemax(Int32), ), rng::AbstractRNG = Random.default_rng(), λ₁::AbstractFloat = convert(data_type, 1e-2),