Multithreaded sampling #131
Not sure about this one. Maybe @torfjelde or @willtebbutt have more insight?
Yeah, I would be very surprised if Zygote.jl worked with threads like this. You should probably look into something like Transducers.jl, or another package that defines a parallel way to perform a [...]. I would be surprised if something like this doesn't already exist in a package, but I'm not 100% up to date on this. Maybe @devmotion knows of one?
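As a hedged illustration of a parallel map over samples, the following sketch uses only `Base.Threads` rather than Transducers.jl (whose API is not reproduced here). `tmap_columns` is a hypothetical helper that applies `f` to each column of a sample matrix, one task per chunk of columns, and returns the results in column order. On its own this does not make the map differentiable with Zygote; that would still require a custom rule.

```julia
using Base.Threads: @spawn, nthreads

# Hypothetical helper: apply `f` to every column of `X` in parallel and return the
# results in column order. Columns are split into at most `ntasks` contiguous chunks,
# and each chunk is processed serially on its own task.
function tmap_columns(f, X::AbstractMatrix; ntasks::Int = nthreads())
    n = size(X, 2)
    n > 0 || throw(ArgumentError("X must have at least one column"))
    chunk_size = cld(n, max(ntasks, 1))
    chunks = Iterators.partition(1:n, chunk_size)
    tasks = [@spawn(map(j -> f(view(X, :, j)), idx)) for idx in chunks]
    # fetch waits for each task; vcat joins the per-chunk results in order
    return reduce(vcat, fetch.(tasks))
end

X = [1.0 3.0 5.0; 2.0 4.0 6.0]    # three 2-dimensional samples, one per column
tmap_columns(sum, X; ntasks = 2)  # [3.0, 7.0, 11.0]
```

Chunking keeps the number of tasks near the number of threads instead of spawning one task per sample, which matters little for a handful of samples but reduces scheduling overhead for many.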
Thanks @torfjelde for your suggestion. I have tried implementing a custom rule like this:

```julia
using AdvancedVI, ChainRulesCore, LogDensityProblems, Statistics, Zygote

function ChainRulesCore.rrule(
    ::typeof(AdvancedVI.estimate_energy_with_samples), prob, samples
)
    # Log density of `prob` as a function of a single sample
    fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    # Compute (value, pullback) for each sample on its own task
    fn_samples =
        fetch.([
            Threads.@spawn Zygote.pullback(fn, sample) for
            sample in AdvancedVI.eachsample(samples)
        ])
    values = [sample[1] for sample in fn_samples]
    pullbacks = [sample[2] for sample in fn_samples]
    function estimate_energy_with_samples_pullback(ȳ)
        # Pair each pullback with an element of ȳ and average the resulting gradients
        grads = [pullback(ȳ_i)[1] for (ȳ_i, pullback) in zip(ȳ, pullbacks)]
        ret = mean(grads)
        return (NoTangent(), NoTangent(), ret)
    end
    return mean(values), estimate_energy_with_samples_pullback
end
```

This works pretty well, but somehow the variance in the ELBO seems to be a bit lower with ForwardDiff, so I am wondering whether I'm doing something wrong in the custom rule. Thanks for your help!
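One detail worth checking in the rule above: its primal output `mean(values)` is a scalar, so the cotangent `ȳ` should be a scalar too. A Julia `Number` iterates as a single element, so `zip(ȳ, pullbacks)` stops after the first pullback, and the gradient would then come from one sample rather than from all of them. The sketch below shows the intended arithmetic without any AD packages: each value enters the mean with weight `1/n`, so every per-sample pullback receives `ȳ / n`. It assumes `samples` is a `d × n` matrix with one sample per column (as with `eachcol`); under that assumption the tangent for `samples` should have the same shape as `samples`, which a single averaged vector would not. `logdensity`, `logdensity_pullback`, and `mean_logdensity_and_pullback` are hypothetical stand-ins for `LogDensityProblems.logdensity`, the closure from `Zygote.pullback`, and the rule itself.

```julia
using Statistics
using Base.Threads: @spawn

# Hypothetical stand-ins: a standard normal log density up to a constant, and a
# hand-written pullback playing the role of the closure from Zygote.pullback.
logdensity(x) = -sum(abs2, x) / 2
logdensity_pullback(x) = ȳ -> (-ȳ .* x,)   # ∇logdensity(x) = -x

# Hypothetical primal + pullback for mean(logdensity, eachcol(samples)).
function mean_logdensity_and_pullback(samples::AbstractMatrix)
    n = size(samples, 2)
    # One task per sample column; each task returns (value, pullback).
    tasks = [@spawn((logdensity(x), logdensity_pullback(x))) for x in eachcol(samples)]
    results = fetch.(tasks)
    value = mean(first, results)
    pullbacks = map(last, results)
    function pullback(ȳ::Real)
        # Every per-sample pullback receives ȳ / n; column j of the
        # tangent is the contribution of sample j.
        tangent = similar(samples, float(eltype(samples)))
        for (j, pb) in enumerate(pullbacks)
            tangent[:, j] = pb(ȳ / n)[1]
        end
        return tangent   # same d × n shape as `samples`
    end
    return value, pullback
end

samples = [1.0 2.0; 3.0 4.0]   # two 2-dimensional samples, one per column
value, back = mean_logdensity_and_pullback(samples)
value      # -7.5
back(1.0)  # [-0.5 -1.0; -1.5 -2.0]
```

The pullbacks are applied serially here for clarity; they could be spawned like the forward pass. In the actual rule, `Zygote.pullback(fn, sample)` would take the place of the hand-written pair, with the result returned from `ChainRulesCore.rrule` as in the snippet above.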
From the snippet you shared, it doesn't look like you're using the same RNG for both runs? If so, that alone could explain it. A second possibility, though unlikely IMO, is numerical differences between the rules used by the two approaches. But yeah, the RNG should be ruled out first.
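To rule out the RNG, each run can build its own generator from a fixed seed so that both AD backends see identical draws. A minimal sketch using only the Random and Statistics stdlibs; `mc_estimate` is a hypothetical stand-in for the ELBO estimator, and `Xoshiro` needs Julia 1.7 or later. In the real setup the same seeded `rng` would be passed to whatever draws the variational samples.

```julia
using Random, Statistics

# Hypothetical stand-in for an ELBO estimate: the Monte Carlo mean of g(z)
# over n standard-normal draws taken from `rng`.
mc_estimate(g, rng::AbstractRNG, n::Int) = mean(g, randn(rng, n))

square(z) = z^2
seed = 42

# A fresh generator built from the same seed for each run
# (e.g. once per AD backend) reproduces the same draws exactly.
a = mc_estimate(square, Xoshiro(seed), 1_000)
b = mc_estimate(square, Xoshiro(seed), 1_000)
a == b   # true

# A different seed gives different draws, hence a different estimate.
c = mc_estimate(square, Xoshiro(seed + 1), 1_000)
```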
I have tried to implement multithreaded sampling by changing:

However, while this works when using the `AutoForwardDiff()` AD backend, it fails (silently) when using `Zygote`. I am guessing that this is due to Zygote not being thread-safe here?

Code:
1. Zygote no threading
2. Zygote with threading
3. ForwardDiff with threading
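Since the failure with Zygote is silent, one way to detect it is to compare each backend's gradient against a finite-difference reference at the same input. A minimal, dependency-free sketch; `fd_gradient` and `gradient_matches` are hypothetical helpers, and in practice `g` would be the gradient returned by Zygote or ForwardDiff for the log density in question. For checking a custom `rrule` itself, ChainRulesTestUtils.jl provides `test_rrule`.

```julia
# Hypothetical helper: central finite-difference gradient of a scalar function f at x.
function fd_gradient(f, x::AbstractVector{<:Real}; h::Real = 1e-6)
    xf = float.(x)
    g = similar(xf)
    for i in eachindex(xf)
        xp = copy(xf); xp[i] += h   # step forward along coordinate i
        xm = copy(xf); xm[i] -= h   # step backward along coordinate i
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Hypothetical check: does gradient `g` (e.g. from an AD backend) agree with the reference?
gradient_matches(f, x, g; rtol = 1e-5) = isapprox(g, fd_gradient(f, x); rtol = rtol)

logp(x) = -sum(abs2, x) / 2   # standard normal log density up to a constant; ∇logp(x) = -x
x = [0.3, -1.2, 2.0]
gradient_matches(logp, x, -x)        # true: correct gradient
gradient_matches(logp, x, zero(x))   # false: e.g. a gradient that silently came back as zeros
```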