
[bug] Incorrect discrete inference with sequential enumeration #3080

Open
gcskoenig opened this issue May 3, 2022 · 1 comment

gcskoenig commented May 3, 2022

The discrete inference results I get with parallel enumeration are accurate, but the results with sequential enumeration are not, even though in theory both should return the same result (source).
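For reference, the quantity both strategies target can be written out for the model below. With `temperature=1`, `infer_discrete` samples the enumerated site from its posterior, so the mean of the returned values of y should approach P(y = 1 | x). Here p stands for `x_pa_obs` and x for `x_ch_obs`; this notation is introduced only for illustration:

$$
P(y = 1 \mid x) = \frac{p\,\mathcal{N}(x \mid 1, 1)}{p\,\mathcal{N}(x \mid 1, 1) + (1 - p)\,\mathcal{N}(x \mid 0, 1)}
$$

Neither term depends on whether the two values of y are enumerated one at a time or in a single batch, which is why sequential and parallel enumeration should agree.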

I created a minimal working example to demonstrate the problem. With parallel enumeration the mean of the sampled values of y is 0.62, while with sequential enumeration it is about 0.5.

```python
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate
from pyro.infer import infer_discrete

@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    p = x_pa_obs
    # Binary parent y ~ Binomial(total_count=1, p).
    # Per the report: "parallel" here gives a mean of ~0.62, "sequential" ~0.5.
    y = pyro.sample('y_pre', dist.Binomial(probs=p, total_count=1),
                    infer={"enumerate": "sequential"},
                    obs=y_obs)

    # Observed child x_ch ~ Normal(y, 1)
    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)

    return y

data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
# temperature=1: sample y from its posterior instead of taking the MAP value
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))

smpl = torch.stack(y_posts)
print(f"mean: {smpl.mean()}")
```
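As an independent check that does not use Pyro, the sketch below plugs the example's values (x_pa_obs = 0.5, x_ch_obs = 1.0) into the exact posterior of the binary parent and then mirrors the 10**4-iteration loop by drawing y from that posterior with Python's standard `random` module. The helpers `normal_pdf`, `exact_posterior_y1` and `simulate_mean` are hypothetical names for illustration, not Pyro functions.

```python
import math
import random


def normal_pdf(x: float, mean: float, scale: float) -> float:
    """Density of Normal(mean, scale) evaluated at x."""
    z = (x - mean) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


def exact_posterior_y1(p: float, x_ch: float) -> float:
    """P(y = 1 | x_ch) for y ~ Bernoulli(p) and x_ch ~ Normal(y, 1)."""
    w1 = p * normal_pdf(x_ch, 1.0, 1.0)          # prior * likelihood for y = 1
    w0 = (1.0 - p) * normal_pdf(x_ch, 0.0, 1.0)  # prior * likelihood for y = 0
    return w1 / (w0 + w1)                        # normalize over both values of y


def simulate_mean(p: float, x_ch: float, n: int = 10**4, seed: int = 0) -> float:
    """Draw n samples of y from the exact posterior and return their mean."""
    rng = random.Random(seed)
    q = exact_posterior_y1(p, x_ch)
    draws = [1 if rng.random() < q else 0 for _ in range(n)]
    return sum(draws) / n


exact = exact_posterior_y1(0.5, 1.0)
print(f"exact: {exact:.4f}")  # prints "exact: 0.6225", i.e. 1 / (1 + exp(-0.5))
print(f"mean:  {simulate_mean(0.5, 1.0):.2f}")
```

The exact value, 1 / (1 + exp(-0.5)) ≈ 0.6225, matches the roughly 0.62 reported for parallel enumeration. The sequential result of about 0.5 coincides with the prior probability x_pa_obs, which is consistent with the x_ch_obs observation not being taken into account, though this sketch alone cannot show where that happens.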

qinqian commented Jul 13, 2023

This pull request (#3238) should fix the bug. @eb8680
