Multithreaded sampling #131

Open · arnauqb opened this issue Oct 12, 2024 · 4 comments
@arnauqb (Contributor) commented Oct 12, 2024

I have tried to implement multithreaded sampling by changing:

function estimate_energy_with_samples(prob, samples)
    # Original (serial) implementation:
    # return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
    logdensity_fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    return mean(fetch.([Threads.@spawn logdensity_fn(sample) for sample in eachsample(samples)]))
end

However, while this works when using the AutoForwardDiff() AD backend, it fails silently when using Zygote. I am guessing this is due to Zygote not being thread-safe here?

Code:

using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using ForwardDiff
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 5)
model = normal_model(data)

##

d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)

ℓπ = DynamicPPL.LogDensityFunction(model)
elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy())

q, _, stats, _ = AdvancedVI.optimize(
	ℓπ,
	elbo,
	q,
	500;
	adtype = AutoZygote(),
	optimizer = optimizer,
)

##
using PyPlot
fig, ax = PyPlot.subplots()
elbo = [s.elbo for s in stats]
ax.plot(elbo)
fig

1. Zygote, no threading: [plot_3]

2. Zygote, with threading: [plot_5]

3. ForwardDiff, with threading: [plot_6]

arnauqb changed the title from "Multithreading sampling" to "Multithreaded sampling" on Oct 12, 2024
@Red-Portal (Member) commented Oct 12, 2024

Not sure about this one. Maybe @torfjelde or @willtebbutt have more insight?

@torfjelde (Member) commented:

Yeah I would be very surprised if Zygote.jl worked with threads like this.

You should probably look into something like Transducers.jl, or anything else that defines a parallel way to perform a reduce (or just define your own threaded_sum(f, args...)). Once you have this, you can define a custom adjoint for it, thus hiding the threading from Zygote.jl.

I would be surprised if something like this doesn't already exist in a package, but I'm not 100% up to date on this. Maybe @devmotion knows of one?
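The suggestion above can be sketched as follows. This is a minimal, self-contained illustration, not existing API: `threaded_mean` and its `rrule` are hypothetical names. The forward passes (and their per-element pullbacks) run inside `Threads.@spawn` tasks, while the registered `rrule` means Zygote never differentiates through the threading itself; since the output is a mean over `n` elements, each element's pullback receives the cotangent scaled by `1/n`.

```julia
using ChainRulesCore
using Statistics
using Zygote

# Hypothetical helper: a threaded mean of f over a collection.
function threaded_mean(f, xs)
    return mean(fetch.([Threads.@spawn f(x) for x in xs]))
end

# Custom rrule so Zygote sees threaded_mean as a primitive and never
# traces through the threads. Assumes f captures no differentiable state,
# so its tangent is NoTangent().
function ChainRulesCore.rrule(::typeof(threaded_mean), f, xs)
    # Run forward passes in parallel, capturing a pullback per element.
    results = fetch.([Threads.@spawn Zygote.pullback(f, x) for x in xs])
    ys  = first.(results)
    pbs = last.(results)
    n   = length(xs)
    function threaded_mean_pullback(ȳ)
        # d/dxᵢ mean_j f(xⱼ) = f'(xᵢ)/n, so each pullback gets ȳ/n.
        x̄s = [pb(ȳ / n)[1] for pb in pbs]
        return (NoTangent(), NoTangent(), x̄s)
    end
    return mean(ys), threaded_mean_pullback
end
```

A quick correctness check is to compare against the serial gradient, e.g. `Zygote.gradient(v -> threaded_mean(sum, v), xs)` versus `Zygote.gradient(v -> mean(sum.(v)), xs)` for `xs = [randn(3) for _ in 1:8]`, which should agree elementwise.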

@arnauqb (Contributor, Author) commented Oct 21, 2024

Thanks @torfjelde for your suggestion. I have tried implementing a custom rule like this:

function ChainRulesCore.rrule(
    ::typeof(AdvancedVI.estimate_energy_with_samples), prob, samples
)
    fn = Base.Fix1(LogDensityProblems.logdensity, prob)
    fn_samples =
        fetch.([
            Threads.@spawn Zygote.pullback(fn, sample) for
            sample in AdvancedVI.eachsample(samples)
        ])
    values = [sample[1] for sample in fn_samples]
    pullbacks = [sample[2] for sample in fn_samples]
    function estimate_energy_with_samples_pullback(ȳ)
        grads = [pullback(ȳ_i)[1] for (ȳ_i, pullback) in zip(ȳ, pullbacks)]
        ret = mean(grads)
        return (NoTangent(), NoTangent(), ret)
    end
    return mean(values), estimate_energy_with_samples_pullback
end

This works pretty well, but somehow the variance in the ELBO seems to be a bit lower with ForwardDiff:

[plot_34: ELBO traces comparing the two backends]

so I am wondering if I'm doing something wrong in the custom rule. Thanks for your help!

@torfjelde (Member) commented:

From the snippet you shared, it doesn't seem like you're using the same RNG for the two runs? If so, that alone could be the cause.

A second thought, though this seems unlikely IMO: there might be numerical differences in the rules used by the two approaches. But yeah, the RNG should be ruled out first.
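On the RNG point: seeding two RNGs identically makes the Monte Carlo draws, and hence the ELBO traces, directly comparable across backends. A sketch of that idea; note that passing the RNG as the first positional argument to AdvancedVI.optimize is an assumption about the API here, not something confirmed in this thread:

```julia
using Random

# Identically seeded RNGs produce identical streams, so two optimization
# runs driven by them see exactly the same Monte Carlo samples.
rng_fd = Xoshiro(2024)
rng_zy = Xoshiro(2024)
@assert randn(rng_fd, 5) == randn(rng_zy, 5)

# Hypothetical usage, assuming optimize accepts an RNG up front:
# AdvancedVI.optimize(rng_fd, ℓπ, elbo, q, 500; adtype = AutoForwardDiff(), optimizer = optimizer)
# AdvancedVI.optimize(rng_zy, ℓπ, elbo, q, 500; adtype = AutoZygote(), optimizer = optimizer)
```

With the RNGs pinned this way, any remaining gap between the ELBO traces can be attributed to the AD backends (or the custom rule) rather than to sampling noise.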
