From dc15bc46748b75e704cf989feae5f102c0ef52f5 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Wed, 16 Oct 2024 01:48:11 +0330 Subject: [PATCH] add `sol_kwargs` to mlj models --- Project.toml | 2 +- benchmark/Project.toml | 2 +- docs/Project.toml | 2 +- src/exts/mlj_ext/core_cond_icnf.jl | 6 ++++-- src/exts/mlj_ext/core_icnf.jl | 6 ++++-- test/Project.toml | 2 +- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 244b7595..51c0f159 100644 --- a/Project.toml +++ b/Project.toml @@ -69,4 +69,4 @@ SciMLSensitivity = "7" ScientificTypesBase = "3" Statistics = "1" Zygote = "0.6" -julia = "1.9" +julia = "1.10" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index e95bd580..41c42fad 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -19,4 +19,4 @@ Lux = "1" PkgBenchmark = "0.2" StableRNGs = "1" Zygote = "0.6" -julia = "1.9" +julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 2a70d53d..14da2a40 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,4 +3,4 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] Documenter = "1" -julia = "1.9" +julia = "1.10" diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 3197a846..cb82d654 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -8,6 +8,7 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} use_batch::Bool batch_size::Int + sol_kwargs::NamedTuple end function CondICNFModel( @@ -18,8 +19,9 @@ function CondICNFModel( adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), use_batch::Bool = true, batch_size::Int = 32, + sol_kwargs::NamedTuple = (;), ) - CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size) + CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size, sol_kwargs) end function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) @@ -64,8 +66,8 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) tst_epochs = @timed res = SciMLBase.solve( optprob_re, opt; - progress = true, epochs = model.n_epochs, + model.sol_kwargs..., ) ps .= res.u @info( diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 8878c6c7..451a3cbc 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -8,6 +8,7 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} use_batch::Bool batch_size::Int + sol_kwargs::NamedTuple end function ICNFModel( @@ -18,8 +19,9 @@ function ICNFModel( adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), use_batch::Bool = true, batch_size::Int = 32, + sol_kwargs::NamedTuple = (;), ) - ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size) + ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size, sol_kwargs) end function MLJModelInterface.fit(model::ICNFModel, verbosity, X) @@ -62,8 +64,8 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X) tst_epochs = @timed res = SciMLBase.solve( optprob_re, opt; - progress = true, epochs = model.n_epochs, + model.sol_kwargs..., ) ps .= res.u @info( diff --git a/test/Project.toml b/test/Project.toml index e7679324..5dbd58dd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,4 +43,4 @@ StableRNGs = "1" TerminalLoggers = "0.1" Zygote = "0.6" cuDNN = "1" -julia = "1.9" +julia = "1.10"