Skip to content

Commit

Permalink
Rudimentary progress tracking for simulate (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor authored Sep 24, 2024
1 parent b185ddd commit 458854a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/SimulationService.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ end
function publish_to_rabbitmq(content)
if !RABBITMQ_ENABLED[]
# stop printing content for now, getting to be too much
@warn "RabbitMQ disabled - `publish_to_rabbitmq`" # with content $(JSON3.write(content))"
@warn "RabbitMQ disabled - `publish_to_rabbitmq`" #with content $(JSON3.write(content))"
return content
end
json = Vector{UInt8}(codeunits(JSON3.write(content)))
Expand Down
30 changes: 16 additions & 14 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,6 @@ mutable struct IntermediateResults
new(0, every, id, 0)
end
end

function (o::IntermediateResults)(integrator)
(; iter, f, t, u, p) = integrator
if o.last_callback + o.every == iter
o.last_callback = iter
state_dict = Dict(unknowns(f.sys) .=> u)
param_dict = Dict(parameters(f.sys) .=> p)
publish_to_rabbitmq(; iter=iter, state=state_dict, params = param_dict, id=o.id,
retcode=SciMLBase.check_error(integrator))
end
EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false)
end

# Intermediate results functor for calibrate
function (o::IntermediateResults)(state,loss_val, ode_sol, ts)
if o.last_callback + o.every == o.iter
Expand Down Expand Up @@ -234,10 +221,25 @@ function get_callback(o::OperationRequest, ::Type{Simulate})
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = 10))
end

function (o::IntermediateResults)(integrator)
(; iter, f, t, u, p, sol) = integrator
t_end = sol.prob.tspan[2]
percent = round((t/t_end)*100.0, digits = 2)
if o.last_callback + o.every == iter
o.last_callback = iter
#state_dict = Dict(states(f.sys) .=> u)
#param_dict = Dict(parameters(f.sys) .=> p)
publish_to_rabbitmq(;id=o.id,
retcode=SciMLBase.check_error(integrator), percent = percent)
end
EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false)
end


# callback for Simulate requests
function solve(op::Simulate; callback)
prob = ODEProblem(op.sys, [], op.timespan)
sol = solve(prob; progress = true, progress_steps = 1, saveat=1, callback = nothing)
sol = solve(prob; saveat=1, callback = callback)
@info "Timesteps returned are: $(sol.t)"
dataframe_with_observables(sol)
end
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ end
obj = SimulationService.get_json(json_url).configuration
sys = SimulationService.amr_get(obj, ODESystem)
op = Simulate(sys, (0.0, 99.0))
df = solve(op; callback = nothing)
call_op = OperationRequest() # to test callback
call_op.id = "1"
df = solve(op; callback = SimulationService.get_callback(call_op,SimulationService.Simulate))
@test df isa DataFrame
@test extrema(df.timestamp) == (0.0, 99.0)
end
Expand Down

0 comments on commit 458854a

Please sign in to comment.