fix sequential enumeration #3238

Open: wants to merge 5 commits into base: dev

@qinqian qinqian commented Jul 13, 2023

This pull request fixes the bug reported in the GitHub issue.

With the same code, sequential enumeration yields a mean of 0.6269999742507935 over 10000 infer_discrete calls with temperature=1. Changing the enum variable to "parallel" yields a mean of 0.6294000148773193. Using temperature=0 for MAP estimation of y_pre yields a mean of 1 for both parallel and sequential enumeration. These tests were run on a GCP VM with an Ubuntu Docker image.

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate
from pyro.infer import infer_discrete

enum = "sequential" 

@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    p = x_pa_obs
    y = pyro.sample('y_pre', dist.Binomial(probs=p, total_count=1),
                    infer={"enumerate": enum},
                    obs=y_obs)

    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)

    return y


data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))

smpl = torch.stack(y_posts)
print(smpl.shape)
print(f"mean: {smpl.mean()}")
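
As a sanity check on the means reported above, the exact posterior for this two-value model can be worked out by hand. The sketch below uses only the standard library and assumes the model exactly as written (prior probability 0.5 for y = 1, a unit-variance Normal likelihood centered at y, and x_ch_obs = 1.0). The helper names `normal_pdf` and `posterior_y1` are illustrative and are not part of Pyro.

```python
import math


def normal_pdf(x, mean, std=1.0):
    """Density of Normal(mean, std) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def posterior_y1(p=0.5, x_obs=1.0):
    """P(y = 1 | x_ch = x_obs) for y ~ Bernoulli(p), x_ch ~ Normal(y, 1)."""
    w1 = p * normal_pdf(x_obs, 1.0)          # joint weight for y = 1
    w0 = (1.0 - p) * normal_pdf(x_obs, 0.0)  # joint weight for y = 0
    return w1 / (w0 + w1)


post = posterior_y1()
# temperature=0 corresponds to picking the most probable value (MAP).
map_estimate = 1.0 if post > 0.5 else 0.0
print(f"exact posterior mean: {post:.4f}")
print(f"MAP estimate: {map_estimate}")
```

Under these assumptions the exact posterior mean is about 0.6225, so sample means near 0.627 and 0.629 over 10**4 draws differ from it by roughly one to one and a half standard errors, and a MAP estimate of 1 agrees with the temperature=0 result.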

@eb8680
Member

eb8680 commented Jul 14, 2023

@qinqian thanks for looking at this. Can you turn the example above into a unit test in tests/infer/test_discrete.py that fails without this fix and passes with it?

@qinqian
Author

qinqian commented Jul 17, 2023

Yes @eb8680, I turned this example into a unit test in tests/infer/test_discrete.py. I added 2 tests that pass with the fix and four cases that fail without it. To simulate the example without the fix, the enum value was set to other.
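
One possible shape for the statistical check in such a test is sketched below. It is a standard-library illustration and not the code added in this PR. It assumes the example model from the first comment, for which Bayes' rule with equal prior odds gives P(y = 1 | x_ch = 1) = 1 / (1 + exp(-0.5)). The helper name `mean_within_tolerance` and the 4-sigma threshold are assumptions chosen for illustration.

```python
import math


def mean_within_tolerance(sample_mean, expected, n_samples, n_sigma=4.0):
    """True if a Bernoulli sample mean lies within n_sigma standard errors of expected."""
    std_err = math.sqrt(expected * (1.0 - expected) / n_samples)
    return abs(sample_mean - expected) <= n_sigma * std_err


# Exact P(y = 1 | x_ch = 1) for the example model (assumed as written above).
EXPECTED = 1.0 / (1.0 + math.exp(-0.5))

# Means reported in this thread for 10**4 draws at temperature=1.
assert mean_within_tolerance(0.6270, EXPECTED, 10**4)
assert mean_within_tolerance(0.6294, EXPECTED, 10**4)

# A sampler that ignored the observation and returned the prior mean would fail.
assert not mean_within_tolerance(0.5, EXPECTED, 10**4)
```

Comparing against the analytic value, rather than against the other enumeration mode, lets such a test catch a case where both modes are wrong in the same way.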

@ordabayevy
Member

Does anyone know why GitHub Actions are not running after new commits?

@qinqian
Author

qinqian commented Oct 5, 2023

I think we may need permission to kick off GitHub Actions. Thanks for your interest in the pull request. I tested it locally, and it should work this time.

Comment on lines +169 to +172
if msg["is_observed"] or msg["infer"].get("enumerate") not in [
    "parallel",
    "sequential",
]:
Member

This looks suspicious. Wouldn't this change merely treat "sequential" enumeration like "parallel" enumeration? If so, of course the results would then agree.

@fritzo
Member

fritzo commented Oct 5, 2023

Sorry about the GitHub bug; sometimes I've needed to close a PR and open another.

Aside from tests, can you explain your diagnosis of the problem and your proposed solution? From what I can tell, this PR amounts to "if the user says 'sequential' pretend they said 'parallel'", which seems wrong. But maybe I'm missing something.
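
To make the concern concrete, the sketch below evaluates the condition quoted on lines +169 to +172 against a few hand-built site messages. The "before" predicate is hypothetical: it assumes the earlier check compared against "parallel" only, which is what the reading above implies but which this thread does not show. The `msg` dictionaries are simplified stand-ins, not real Pyro trace messages.

```python
def skips_site_before(msg):
    # Hypothetical pre-change check (assumption): only "parallel" sites are handled.
    return msg["is_observed"] or msg["infer"].get("enumerate") != "parallel"


def skips_site_after(msg):
    # Condition as quoted in the review comment on lines +169 to +172.
    return msg["is_observed"] or msg["infer"].get("enumerate") not in [
        "parallel",
        "sequential",
    ]


for mode in ("parallel", "sequential", None):
    infer = {} if mode is None else {"enumerate": mode}
    msg = {"is_observed": False, "infer": infer}
    print(mode, skips_site_before(msg), skips_site_after(msg))
```

Under that assumption, the only case whose outcome changes is an unobserved site marked "sequential", which moves from skipped to handled on the same code path as "parallel".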

@qinqian
Author

qinqian commented Oct 5, 2023

Yes @fritzo. The problem is that sequential enumeration generates different results from parallel enumeration. I diagnosed the problem in two ways.

First, I added a breakpoint to inspect the function here with the simple example above. The key difference between the two enumeration modes comes from enum_terms, which is always empty for sequential enumeration. That means no ELBO loss is added for this term, so the result tends toward a random 0.5 for the simple example above. I then traced the function to the EnumMessenger class (class EnumMessenger(Messenger):) and found that it did add loss for the sequential option.

Second, I compared the Pyro EnumMessenger with the one in funsor, which takes an approach similar to the one I proposed above.

Please let me know if I have misunderstood the problem.
