Progress reporting in parallel sampling #2264

Open
SamuelBrand1 opened this issue Jun 12, 2024 · 3 comments

Comments

@SamuelBrand1

Hi everyone,

One thing I've noticed is that progress reporting when running chains in parallel (for example using MCMCThreads()) is not informative: the progress meter only updates when a chain finishes, rather than reporting within-chain progress as serial sampling does.

Is there any movement towards chain-by-chain progress reporting, as in Stan?

@torfjelde
Member

This is an issue that has come up fairly often but AFAIK no perfect solution exists.

Ref: TuringLang/AbstractMCMC.jl#82 TuringLang/AbstractMCMC.jl#105

There's a Discourse thread where someone seems to have come up with a "solution" (https://discourse.julialang.org/t/displaying-parallel-progress-bars/4148/8), but that was ages ago and I'm not sure whether it still works.

Note that you can pass an arbitrary callback to sample, which is executed after every step and where you could do custom progress-keeping, but atm there's no good built-in solution unfortunately 😕
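
As a rough illustration of the callback idea, here is a minimal sketch of a counter shared across chains. `SharedStepCounter` is a hypothetical helper, not part of Turing or AbstractMCMC, and the positional arguments are assumed to match the callback signature used in the example further down this thread. The loop at the end only simulates parallel chains with dummy arguments so the snippet runs without Turing.

```julia
using Base.Threads: Atomic, atomic_add!, @threads

# Hypothetical helper (not part of Turing or AbstractMCMC): counts steps
# across all chains and prints a line every `every` steps.
struct SharedStepCounter
    total::Int            # expected number of steps over all chains
    every::Int            # report interval, in steps
    count::Atomic{Int}    # thread-safe counter shared by all chains
end

SharedStepCounter(total::Int; every::Int=1_000) =
    SharedStepCounter(total, every, Atomic{Int}(0))

# Assumed callback signature: (rng, model, sampler, sample, state, iteration; kwargs...)
function (cb::SharedStepCounter)(rng, model, sampler, sample, state, iteration; kwargs...)
    done = atomic_add!(cb.count, 1) + 1   # atomic_add! returns the old value
    if done % cb.every == 0 || done == cb.total
        println(stderr, "sampled ", done, " / ", cb.total, " steps")
    end
    return nothing
end

# Stand-in for parallel sampling: 4 "chains" of 500 steps each, dummy arguments.
cb = SharedStepCounter(4 * 500; every=500)
@threads for chain in 1:4
    for i in 1:500
        cb(nothing, nothing, nothing, nothing, nothing, i)
    end
end
```

With Turing, the same object could presumably be passed as `callback=cb` together with `progress=false` in the `sample` call. This only gives a total step count, not per-chain bars, and it assumes one callback instance is shared by all chains.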

@SamuelBrand1
Author

Thanks for flagging this up @torfjelde ! I guess this will keep circling around :-(.

@torfjelde
Member

Might be possible to do something with this: timholy/ProgressMeter.jl#157

In fact, if I use that branch + some minor changes to AbstractMCMC.jl, the following

using ProgressMeter
using Turing

struct ProgressCallback{P}
    p::P
    index::Int
end

function (callback::ProgressCallback)(rng, model, sampler, sample, state, iteration; kwargs...)
    # Can do more stuff here if you want.
    next!(callback.p[callback.index])
end

@model demo() = x ~ Normal()
model = demo()

num_samples = 100_000
num_chains = 10
p = MultipleProgress(
    [Progress(num_samples; desc="Chain $i ") for i in 1:num_chains],
    Progress(num_samples * num_chains; desc="Total ")
)
callbacks = map(1:num_chains) do i
    ProgressCallback(p, i)
end
chain = sample(
    model,
    HMC(0.1, 32),
    MCMCThreads(),
    num_samples,
    num_chains,
    callback=callbacks,
    progress=false,
    thinning=10
)

results in

[image: screenshot of the progress bars]

It's small so not sure if you can see it, but it creates one bar for each thread + a global progress bar.

(Note that this relies on minor changes to AbstractMCMC.jl plus that experimental branch of ProgressMeter.jl, which only supports the REPL, not, say, IJulia.)

Miiight be worth adopting this in TuringCallbacks.jl as a bridge until there's a good solution.
