Skip to content

Commit

Permalink
better itr_n
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Aug 4, 2023
1 parent cd716a3 commit 6e5a154
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
10 changes: 7 additions & 3 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# SciML interface

function callback_f(ps, l, icnf::AbstractFlows, prgr::Progress, itr_n::AbstractVector)
function callback_f(ps, l, icnf::AbstractFlows, prgr::Progress, itr_n::AbstractArray)
ProgressMeter.next!(
prgr;
showvalues = [(:loss_value, l), (:iteration, itr_n[]), (:last_update, Dates.now())],
showvalues = [
(:loss_value, l),
(:iteration, only(itr_n)),
(:last_update, Dates.now()),
],
)
itr_n[] += one(itr_n[])
itr_n[] += one(only(itr_n))
false
end

Expand Down
2 changes: 1 addition & 1 deletion src/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
desc = "Fitting (epoch: $ep of $(model.n_epochs)): ",
showspeed = true,
)
itr_n = [1]
itr_n = ones(Int)
tst_one = @timed res = solve(
optprob_re,
opt,
Expand Down
2 changes: 1 addition & 1 deletion src/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
desc = "Fitting (epoch: $ep of $(model.n_epochs)): ",
showspeed = true,
)
itr_n = [1]
itr_n = ones(Int)
tst_one = @timed res = solve(
optprob_re,
opt,
Expand Down

0 comments on commit 6e5a154

Please sign in to comment.