Problem with using predict with vector valued random variables #2239

Open
SamuelBrand1 opened this issue May 28, 2024 · 16 comments

@SamuelBrand1

Hi everyone,

Problem

There seems to be a problem with using predict in conjunction with models that use vectorisation.

Consider this fairly simple example:

$$\begin{aligned} \sigma &\sim \text{HalfNormal}(0.1) \\ \mu_i &\sim \mathcal{N}(0, 1), \qquad i = 1,\dots,n \\ \epsilon_i &\sim \mathcal{N}(0, \sigma^2), \qquad i = 1,\dots,n \\ x_i &= \mu_i + \epsilon_i, \qquad i = 1,\dots,n \end{aligned}$$

We can generate a dataset by sampling from this model for (say) $n = 10$. The forecasting problem is then sampling $x_i$ for $i = 11,\dots,20$ (the only information propagated forward is about the variance of the noise).

However, this fails as per below:

using Turing, StatsPlots, DynamicPPL, Random

Random.seed!(1234)

@model function mv_normal(n)
	σ ~ truncated(Normal(0., 0.1), lower = 0.)
	μ ~ MvNormal(n, 1.0) # Means
	x ~ MvNormal(μ, σ) # noise
	return x
end

mdl_10 = mv_normal(10)

# Sample data
x_data = mdl_10()

# infer means and obs noise
chn = sample(mdl_10 | (x = x_data,), NUTS(), 2_000)

# forecast
forecast_mdl = mv_normal(20)

forecast_chn = predict(forecast_mdl, chn; include_all = true)

let
	obs = generated_quantities(forecast_mdl, forecast_chn) |> X -> reduce(hcat, X)
	plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
	scatter!(plt, x_data, c = :red, lab = "observed", title = "BAD FORECAST", ms = 6)
end

The failure mode here seems to be that the underlying random draws for $\epsilon_i,~ i = 11,\dots,20$ are made once and then shared across all the samples from chn.

Fix 1: mapreduce across forecast calls

If you instead loop over the posterior samples and call predict for each one separately, this seems to work:

forecast_chn_mapreduce = mapreduce(vcat, 1:size(chn, 1)) do i
	c = predict(forecast_mdl, chn[i,:,1]; include_all = true)
        # Take care to set the range sequentially
	setrange(c, i:i)
end

let
	obs = generated_quantities(forecast_mdl, forecast_chn_mapreduce) |> X -> reduce(hcat, X)
	plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
	scatter!(plt, x_data, c = :black, lab = "observed", title = "OK FORECAST")
end

Fix 2: Non-vectorised sampling

Alternatively, you can modify the underlying model so that it does not use vectorised random variables (although IMO this is non-ideal).

@model function mv_normal_2(n)
	σ ~ truncated(Normal(0., 0.1), lower = 0.)
	μ = Vector{eltype(σ)}(undef, n)
	for i = 1:n
		μ[i] ~ Normal()
	end
	x ~ MvNormal(μ, σ) # noise
	return x
end

mdl2_10 = mv_normal_2(10)
x_data2 = mdl2_10()
chn2 = sample(mdl2_10 | (x = x_data2,), NUTS(), 2_000)


forecast_mdl2 = mv_normal_2(20)
forecast_chn2 = predict(forecast_mdl2, chn2; include_all = true)

let
	obs = generated_quantities(forecast_mdl2, forecast_chn2) |> X -> reduce(hcat, X)
	plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
	scatter!(plt, x_data2, c = :black, lab = "observed", title = "ALSO OK FORECAST?")
end

Ideal situation

Obviously, it would be ideal if predict "just worked" with vectorised random variables. Given the failure mode of naive usage of predict, I'm assuming this is a problem with how the random numbers are generated around here?

@torfjelde
Member

torfjelde commented Jun 3, 2024

Okay, so the fact that any of this works is not great 😅

A few immediate things:

  1. None of these scenarios are meant to be supported 😕 The fact that the fixed version works is just a happy accident for this particular model.
  2. One crucial aspect of Turing.jl is that variables are treated as they occur in the model. This means that if x is sampled from a multivariate distribution, then it is treated as a single multivariate random variable. Any attempt to treat it otherwise is generally not supported. In this sense, we rely on the user to tell us the correct "semantics" of / how to interpret a given variable. It's also the case that, in general, it's not possible to marginalize out components (as we would technically have to do in this scenario, since the desired behavior would be to fix mu[1:10] and only sample mu[11:20]).

The reason the 2nd scenario works at all is that the prior distribution for μ[11] is the same as the posterior predictive distribution for μ[11]. If μ[11] instead depended on some other variable, e.g. σ, the "fixed" version would result in μ[11] being sampled from an MvNormal with σ from the prior (not from the posterior / chain).

In short, Turing.jl executes the model once before running the main part of the predict code, and uses the resulting "trace" / dictionary-like structure as a template for subsequent predictions. This initial run to construct the "trace" samples from the prior. This is why the first version results in μ[11:20] being "frozen": one "trace" was sampled from the prior at the beginning, and those values are never resampled! Similarly, the reason the other version happens to work is that every time you call predict, a new sample from the prior is used to produce the "trace" before running the "actual" predict; hence the values for μ[11:20] are sampled from the prior in every call to predict, but never resampled once we've set μ[1:10] 😕
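
As an illustrative sketch (assuming the mv_normal and mv_normal_2 models defined above; the exact key representation may vary by DynamicPPL version), the difference can be inspected via the variable names recorded in a VarInfo:

using DynamicPPL

# In the vectorised model, μ is recorded as one multivariate random variable...
keys(DynamicPPL.VarInfo(mv_normal(10)))    # expected: σ, μ, x

# ...whereas in the loop-based model every μ[i] is its own univariate variable,
# which is what allows μ[1:10] to be fixed individually from the chain.
keys(DynamicPPL.VarInfo(mv_normal_2(10)))  # expected: σ, μ[1], ..., μ[10], x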

We have some checks in place to warn the user about these scenarios, but we clearly need more since this slipped through the cracks!

Effectively, if you want variables to be treated as i.i.d. rather than as a single multivariate, then you should use either .~ or a for loop (as you did in the second scenario). For example, the following works:

# Nevermind; this also doesn't work...
@model function mv_normal_3(n)
	σ ~ truncated(Normal(0., 0.1), lower = 0.)
	μ = Vector{eltype(σ)}(undef, n)
	μ .~ Normal()
	x ~ MvNormal(μ, σ) # noise
	return x
end

mdl3_10 = mv_normal_3(10)
x_data3 = mdl3_10()
chn3 = sample(mdl3_10 | (x = x_data3,), NUTS(), 2_000)


forecast_mdl3 = mv_normal_3(20)
forecast_chn3 = predict(forecast_mdl3, chn3; include_all = true)

EDIT: Never mind, I just realized that .~ also doesn't work, which is annoying because, semantically speaking, it should 😕 Hmm, we might want to do something about that.

@SamuelBrand1
Author

Thanks for the detailed explanation!

So it turns out that this only works for models that have a representation where the prior for the forecast variables is the same as their posterior predictive distribution.

TBF, this is actually a pretty large class of forecast models: discrete-time numerical solutions to SDEs and the finite dimensional distributions of a GP can be written this way (we were motivated by having a standard parameterisation of latent white noise).

@SamuelBrand1
Author

So to be clear, in the example where you fix part of an array here using the Dict form of fixing... that wouldn't work if the array had been declared by calling ~ MvNormal(...)?

@torfjelde
Member

TBF, this is actually a pretty large class of forecast models: discrete-time numerical solutions to SDEs and the finite dimensional distributions of a GP can be written this way (we were motivated by having a standard parameterisation of latent white noise).

Yeah, I definitely see the use for this! But it's somewhat non-trivial to support, so it really comes down to whether we want the maintenance burden of the functionality vs. having the user do some manual labour, i.e. use a for loop.

So to be clear, in the example where you fix part of an array here using the Dict form of fixing... that wouldn't work if the array had been declared by calling ~ MvNormal(...)?

Exactly.
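
As a concrete sketch (assuming the mv_normal and mv_normal_2 models from above; the exact fix API may vary by DynamicPPL version):

# With the loop-based model, each μ[i] is a separate variable, so a subset of
# them can be fixed by VarName:
fix(mv_normal_2(20), Dict(@varname(μ[1]) => 0.5, @varname(μ[2]) => -0.3))

# With the vectorised model, μ is a single multivariate variable, so there is
# no standalone μ[1] to fix -- only the whole vector can be fixed at once:
fix(mv_normal(20), Dict(@varname(μ) => zeros(20)))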

The most annoying part of all this (IMO) is the performance implication of using a vectorised statement vs. a for loop. It's technically possible to do something like

if @performing_inference
    x ~ MvNormal(...)
else
    x = Vector(undef, 10)
    for i in eachindex(x)
        x[i] ~ Normal(...)
    end
end

but we don't have that implemented (related: TuringLang/DynamicPPL.jl#510).

@torfjelde
Member

One simple approach that could also work is for the aforementioned code to be generated automatically through an iid macro or something, e.g.

@iid x ~ MvNormal(...)

and then this just converts to the above code block under the hood.
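
As a hypothetical usage sketch (neither @iid nor this behaviour exists; this is only what the proposal might look like in the example model):

using LinearAlgebra # for I

# Hypothetical syntax only: @iid is a proposal, not an existing macro.
@model function mv_normal_iid(n)
	σ ~ truncated(Normal(0., 0.1), lower = 0.)
	@iid μ ~ MvNormal(zeros(n), I) # treated as n i.i.d. Normal(0, 1) draws outside of inference
	x ~ MvNormal(μ, σ) # noise
	return x
end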

@SamuelBrand1
Author

Right. And this would avoid issues with (say) adtype = AutoReverseDiff(true), because gradient calls only occur in "inference mode", so the existence of the slower branch wouldn't be relevant to sampling?

@torfjelde
Member

Right. And this would avoid issues with (say) adtype = AutoReverseDiff(true), because gradient calls only occur in "inference mode", so the existence of the slower branch wouldn't be relevant to sampling?

Exactly:)

@SamuelBrand1
Author

TBF, this is actually a pretty large class of forecast models: discrete-time numerical solutions to SDEs and the finite dimensional distributions of a GP can be written this way (we were motivated by having a standard parameterisation of latent white noise).

Yeah, I definitely see the use for this! But it's somewhat non-trivial to support, so it really comes down to whether we want the maintenance burden of the functionality vs. having the user do some manual labour, i.e. use a for loop.

So to be clear, in the example where you fix part of an array here using the Dict form of fixing... that wouldn't work if the array had been declared by calling ~ MvNormal(...)?

Exactly.

The most annoying part of all this (IMO) is the performance implication of using a vectorised statement vs. a for loop. It's technically possible to do something like

if @performing_inference
    x ~ MvNormal(...)
else
    x = Vector(undef, 10)
    for i in eachindex(x)
        x[i] ~ Normal(...)
    end
end

but we don't have that implemented (related: TuringLang/DynamicPPL.jl#510).

Did anything happen on this? @seabbs and I are trying to make something where you can fully compose a fairly large set of models, defining the different probabilistic components one might want in a fully featured epi model, via @submodel.

Since prediction is quite important here, it would be a handy feature to have, but on the other hand I'm not inclined to go through every single model definition and put in a "forecast mode" boolean switch. Given the known behaviour of MvNormal under conditioning on a subset of indices, couldn't you support that as a special case?
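
For reference, the conditional behaviour alluded to here: partitioning $x \sim \mathcal{N}(\mu, \Sigma)$ into an observed block $x_a$ and an unobserved block $x_b$ gives

$$x_b \mid x_a = a \;\sim\; \mathcal{N}\left(\mu_b + \Sigma_{ba}\Sigma_{aa}^{-1}(a - \mu_a),\; \Sigma_{bb} - \Sigma_{ba}\Sigma_{aa}^{-1}\Sigma_{ab}\right),$$

which, for the diagonal covariances used in the models above, reduces to the unobserved components simply keeping their marginal distributions.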

@seabbs

seabbs commented Jul 12, 2024

Noting that this also appears to be an issue with filldist(Normal(), n), which is documented (I think) as i.i.d., so I assumed it would work initially.
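
For reference, a minimal sketch of the filldist variant being referred to (following the same shape as the example model above):

@model function mv_normal_filldist(n)
	σ ~ truncated(Normal(0., 0.1), lower = 0.)
	μ ~ filldist(Normal(), n) # documented as i.i.d., but hits the same predict issue
	x ~ MvNormal(μ, σ) # noise
	return x
end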

@seabbs

seabbs commented Jul 14, 2024

I've started adding (CDCgov/Rt-without-renewal#369 (comment)) a PredictContext (+ hacking on Turing.predict) and have this switch-based approach on the way to working, I think (some type conversion issues remain). I did the inverse of the suggestion, as I thought it would be easier for someone with no real understanding of the contexts system. Something that is very clear is that it is pretty clunky (even with passing the if/else block around in a submodel), and the @iid macro suggestion would be much preferred imo.
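
For illustration, the kind of per-model switch this avoids writing by hand might look like the following (a rough sketch with a plain keyword flag standing in for the context-based switch; the flag name is made up):

using LinearAlgebra # for I

@model function mv_normal_switch(n; predicting = false)
	σ ~ truncated(Normal(0., 0.1), lower = 0.)
	if predicting
		# prediction path: each μ[i] is its own variable
		μ = Vector{eltype(σ)}(undef, n)
		for i in eachindex(μ)
			μ[i] ~ Normal()
		end
	else
		# inference path: fast vectorised statement
		μ ~ MvNormal(zeros(n), I)
	end
	x ~ MvNormal(μ, σ) # noise
	return x
end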

Given the known behaviour of MvNormal under conditioning some indices, couldn't you support that as a special case?

This seems even more ideal than the macro approach, as the macro would be easy for a non-expert user to miss.

@torfjelde
Member

Tagging @yebai to get some thoughts. Specifically on something like:

One simple approach that could also work is for the aforementioned code to be generated automatically through an iid macro or something, e.g.

@iid x ~ MvNormal(...)

and then this just converts to the above code block under the hood.

@yebai
Member

yebai commented Jul 16, 2024

I am not sure about the additional macro, but I don’t have a good alternative yet. I am happy to brainstorm more here.

Cc @mhauru @sunxd3

@yebai
Member

yebai commented Jul 16, 2024

@seabbs @SamuelBrand1, we would be happy to hear more of your thoughts on syntax design. I think it has to be robust and intuitive for long-term maintenance.

EDIT:

Noting that this also appears to be an issue with filldist(Normal(), n), which is documented (I think) as i.i.d., so I assumed it would work initially.

If we fix filldist (and maybe the broadcasting syntax, e.g. μ .~ Normal()) would that be sufficient for this use case?

Of course, we should provide a warning message to the user in all other unsupported cases discussed here.

@seabbs

seabbs commented Jul 17, 2024

If we fix filldist (and maybe the broadcasting syntax, e.g. μ .~ Normal()) would that be sufficient for this use case?

Of course, we should provide a warning message to the user in all other unsupported cases discussed here.

Yes, both of these would be great and would cover our use case. It would also be nice to make it easier to switch modes (i.e. specify one mode for inference and handle everything else manually) in case users find other edge cases they want to fix, but we wouldn't need that for what we are doing (at least at the moment).

@SamuelBrand1
Author

@seabbs @SamuelBrand1, we would be happy to hear more of your thoughts on syntax design. I think it has to be robust and intuitive for long-term maintenance.

EDIT:

Noting that this also appears to be an issue with filldist(Normal(), n), which is documented (I think) as i.i.d., so I assumed it would work initially.

If we fix filldist (and maybe the broadcasting syntax, e.g. μ .~ Normal()) would that be sufficient for this use case?

Of course, we should provide a warning message to the user in all other unsupported cases discussed here.

This sounds good, but would there be a performance implication? I'm wondering about the upsides/downsides here.

@torfjelde
Member

Noting that this also appears to be an issue with filldist(Normal(), n), which is documented (I think) as i.i.d., so I assumed it would work initially.

IMO this is more of a docs issue, as there are definitely scenarios where it makes sense to use filldist rather than arraydist, as the former is more efficient (when applicable).
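
For example, both of the following give a length-n vector of independent components, but filldist repeats a single distribution while arraydist takes an arbitrary array of distributions (sketch only, model name hypothetical):

@model function iid_examples(n)
	μ ~ filldist(Normal(), n)                    # n i.i.d. draws from one distribution (cheaper)
	ν ~ arraydist([Normal(0.0, i) for i in 1:n]) # n independent draws from different distributions
	return (μ, ν)
end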

If we fix filldist (and maybe the broadcasting syntax, e.g. μ .~ Normal()) would that be sufficient for this use case?

It's somewhat unclear to me how this would work, but maybe this is something we should discuss in DynamicPPL.jl :)
