Skip to content

Commit

Permalink
add sol_kwargs to mlj models
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Oct 15, 2024
1 parent 315881d commit dc15bc4
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ SciMLSensitivity = "7"
ScientificTypesBase = "3"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
julia = "1.10"
2 changes: 1 addition & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ Lux = "1"
PkgBenchmark = "0.2"
StableRNGs = "1"
Zygote = "0.6"
julia = "1.9"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
Documenter = "1"
julia = "1.9"
julia = "1.10"
6 changes: 4 additions & 2 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF}

use_batch::Bool
batch_size::Int
sol_kwargs::NamedTuple
end

function CondICNFModel(
Expand All @@ -18,8 +19,9 @@ function CondICNFModel(
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
use_batch::Bool = true,
batch_size::Int = 32,
sol_kwargs::NamedTuple = (;),
)
CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size)
CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size, sol_kwargs)
end

function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
Expand Down Expand Up @@ -64,8 +66,8 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
tst_epochs = @timed res = SciMLBase.solve(
optprob_re,
opt;
progress = true,
epochs = model.n_epochs,
model.sol_kwargs...,
)
ps .= res.u
@info(
Expand Down
6 changes: 4 additions & 2 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF}

use_batch::Bool
batch_size::Int
sol_kwargs::NamedTuple
end

function ICNFModel(
Expand All @@ -18,8 +19,9 @@ function ICNFModel(
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
use_batch::Bool = true,
batch_size::Int = 32,
sol_kwargs::NamedTuple = (;),
)
ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size)
ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size, sol_kwargs)
end

function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
Expand Down Expand Up @@ -62,8 +64,8 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
tst_epochs = @timed res = SciMLBase.solve(
optprob_re,
opt;
progress = true,
epochs = model.n_epochs,
model.sol_kwargs...,
)
ps .= res.u
@info(
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ StableRNGs = "1"
TerminalLoggers = "0.1"
Zygote = "0.6"
cuDNN = "1"
julia = "1.9"
julia = "1.10"

0 comments on commit dc15bc4

Please sign in to comment.