
Fix #3255 (draft) #3265

Open · wants to merge 2 commits into base: dev

Conversation

gui11aume

No description provided.

gui11aume (Author) commented Aug 30, 2023

The purpose of this pull request is to harmonize the masking behavior across the different ways of estimating the gradient of the ELBO (most notably when has_rsample is False). See the discussion on issue #3255.
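
For concreteness, here is a minimal sketch of the situation being harmonized (hypothetical site and parameter names, not code from this pull request): a latent model site with no matching site in the guide.

import torch
import pyro
import pyro.distributions as dist

def model():
    pyro.sample("z1", dist.Normal(0.0, 1.0))
    pyro.sample("z2", dist.Normal(0.0, 1.0))  # no matching site in the guide

def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    # Only "z1" is guided. Before this change, the reparameterized path
    # (has_rsample=True) only warned about the missing "z2", while the
    # score-function path (has_rsample=False) failed outright.
    pyro.sample("z1", dist.Normal(loc, 1.0))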

I added some tests inspired by those that were already present. Simply ignoring the model sites that are not present in the guide gave correct results right away. The main difficulties were:

  1. Having missing sites in the guide triggered a warning that made the tests fail.
  2. Memory blew up with 50,000 samples in one particular case (exceeding 256 GB).

I had to disable user warnings for test_mask_gradient and break the sampling down into 10 batches of 5,000 samples each.
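
As an illustration of the batching workaround (shapes and names here are invented; the actual test differs), per-batch statistics can be accumulated instead of materializing all 50,000 samples at once:

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(1000), 1.0)  # a hypothetical high-dimensional site
running_mean = torch.zeros(1000)
for _ in range(10):  # 10 batches of 5,000 instead of one draw of 50,000
    running_mean += d.sample((5000,)).mean(0) / 10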

The changes to Pyro itself are otherwise as minimal as possible.

gui11aume marked this pull request as ready for review on September 13, 2023 at 16:42
ordabayevy (Member) commented Oct 4, 2023

@gui11aume can you please fix the lint issues reported by make lint and push the changes? You can first run make format to fix formatting automatically, then manually fix any remaining lint issues.

gui11aume (Author)

Hi @ordabayevy! Thanks for taking the time to help me with this. There were issues in tests/infer/test_gradient.py (mostly in the code for the new tests), and running make format seemed to fix them all. Let me know if any issues remain.

@@ -23,7 +23,7 @@ def _compute_log_r(model_trace, guide_trace):
     for name, model_site in model_trace.nodes.items():
         if model_site["type"] == "sample":
             log_r_term = model_site["log_prob"]
-            if not model_site["is_observed"]:
+            if not model_site["is_observed"] and name in guide_trace.nodes:
fritzo (Member) commented Oct 24, 2023

I may be forgetting something, but I thought Trace_ELBO requires the guide to include all model sites that are not observed. If that's the case, wouldn't we want to keep the old version, where Trace_ELBO errors? Can you explain when a model site would be neither observed nor in the guide?

gui11aume (Author) commented Oct 24, 2023

I agree with you, and I cannot think of useful cases for this. The point I make in the issue is that this triggers a warning when has_rsample is True and an error when it is False. I think they should both trigger a warning or both trigger an error. The suggested changes make the behavior consistent with the has_rsample = True case, but maybe it makes more sense to trigger an error everywhere?

fritzo (Member)

Ah, thanks for explaining, I think I better understand now.

I'd like to hear what other folks think (@martinjankowiak @eb8680 @fehiepsi). One argument for erroring more often is that there is a lot of code in Pyro that tacitly assumes all sites are either observed or guided. I'm not sure exactly what that code is, since we've only tacitly made the assumption, but it's worth thinking about: reparametrizers, Predictive, AutoGuideMessenger.

One argument for allowing "partial" guides is that it's simply more general. But if we decide to support "partial" guides throughout Pyro, I think we'll need to adopt importance sampling semantics, so we'd need to replace Pyro's basic Trace data structure with a weighted trace, and replace sample sets with weighted sets of samples in a number of places, e.g. Predictive. This seems like a perfectly fine design choice for a PPL, but it is different from much of Pyro's existing usage, and I think we would need to make many small changes throughout the codebase, including tutorials. 🤔

gui11aume (Author) commented Oct 24, 2023

Right. For context, this happened to me with sites that are implicitly created by Pyro in the model (and that are therefore not in the guide), which subsequently caused a failure because they fell into the has_rsample = False case. Figuring out why the code fails in this situation is quite challenging.
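
As a hedged illustration (an assumed example, not code from this pull request): passing obs_mask to pyro.sample makes Pyro implicitly create an auxiliary "<name>_unobserved" latent site in the model, which a hand-written guide typically does not cover.

import torch
import pyro
import pyro.distributions as dist

def model(data, mask):
    loc = pyro.param("loc", torch.tensor(0.0))
    # This implicitly creates a latent "x_unobserved" sample site for the
    # masked-out entries, with no counterpart in a hand-written guide.
    pyro.sample("x", dist.Normal(loc, 1.0), obs=data, obs_mask=mask)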

@@ -25,6 +25,6 @@ def scale_and_mask(self, scale=1.0, mask=None):
     :type mask: torch.BoolTensor or None
     """
     log_prob = scale_and_mask(self.log_prob, scale, mask)
-    score_function = self.score_function  # not scaled
+    score_function = scale_and_mask(self.score_function, 1.0, mask)  # not scaled
fritzo (Member)

Can you point out exactly the scenario that is fixed by this one-line change? IIRC, score_function is always multiplied by another tensor that is masked, so the mask here would be redundant.

gui11aume (Author) commented Oct 24, 2023

Here it's again for consistency. When a variable is partially observed, parameters are created for the missing observations, and dummy parameters are created at the positions of the non-missing ones. When has_rsample is True, the dummy parameters are not updated: they retain their initial values because they do not contribute to the gradient. When has_rsample is False, the gradient "leaks" through this line and the dummy parameters are updated during learning (though I found that inference on the non-dummy parameters was correct in the cases I checked). As above, this line does not really fix a bug; it just makes the behavior consistent between has_rsample = True and has_rsample = False.
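
A toy sketch of the "leak" with plain tensors (stand-ins for Pyro internals, not the actual ScoreParts code):

import torch

# stand-ins for the per-element score function and the REINFORCE weight
score_function = torch.randn(4, requires_grad=True)
log_r = torch.randn(4)  # detached in the surrogate objective
mask = torch.tensor([True, True, False, False])

# unmasked: the masked-out (dummy) entries still receive gradient
(log_r * score_function).sum().backward()
print(score_function.grad)  # nonzero in all four positions

# masked, as in this change: dummy entries receive zero gradient
score_function.grad = None
(log_r * score_function.masked_fill(~mask, 0.0)).sum().backward()
print(score_function.grad)  # zero in the last two positions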

fritzo (Member)

Thanks for explaining, I think I better understand now.

Comment on lines +221 to +222:

+    if node not in guide_trace.nodes:
+        continue
fritzo (Member)

ditto: this should never happen

@@ -108,7 +108,7 @@ def _differentiable_loss_particle(self, model_trace, guide_trace):
         if model_site["type"] == "sample":
             if model_site["is_observed"]:
                 elbo_particle = elbo_particle + model_site["log_prob_sum"]
-            else:
+            elif name in guide_trace.nodes:
fritzo (Member)

ditto: this should never happen

@@ -214,6 +215,102 @@ def guide(subsample):
     assert_equal(actual_grads, expected_grads, prec=precision)


+# Not including the unobserved site in the guide triggers a warning
+# that can make the test fail if we do not deactivate UserWarning.
+@pytest.mark.filterwarnings("ignore::UserWarning")
fritzo (Member)

ditto: this should never happen
