Skip to content

Commit

Permalink
Support pymc3 traces with no draws or no tuning (#1211)
Browse files Browse the repository at this point in the history
* test edge cases w.r.t number of chains/tune/draws
+ is a regression test for #1210

* make pylint allow protected access in tests

* repair slicing of the trace for ndraws=0
closes #1210

* apply black formatting

* replace the negative slicing in all places

* accelerate tests

* add split_trace function

* update tests

* fix

* fix docstrings

* update changelog

Co-authored-by: Oriol (Prodesk) <[email protected]>
  • Loading branch information
michaelosthege and OriolAbril authored May 28, 2020
1 parent 6b4de81 commit ec33b4c
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 37 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@

### Documentation

## v0.8.3 (2020 May 28)
### Maintenance and fixes
* Restructured internals of `from_pymc3` to handle old pymc3 releases and
sliced traces and to provide useful warnings (#1211)

## v0.8.2 (2020 May 25)
### Maintenance and fixes
* Fixed bug in `from_pymc3` for sliced `pymc3.MultiTrace` input (#1209)

## v0.8.1 (2020 May 24)

### Maintenance and fixes
Expand Down
2 changes: 1 addition & 1 deletion arviz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
"""ArviZ is a library for exploratory analysis of Bayesian models."""
__version__ = "0.8.2"
__version__ = "0.8.3"

import os
import logging
Expand Down
67 changes: 48 additions & 19 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""PyMC3-specific conversion code."""
import logging
import warnings
from typing import Dict, List, Any, Optional, Iterable, Union, TYPE_CHECKING
from typing import Dict, List, Any, Optional, Iterable, Union, TYPE_CHECKING, Tuple
from types import ModuleType

import numpy as np
Expand Down Expand Up @@ -113,6 +113,8 @@ def __init__(
" Please consider using PyMC3>=3.9 and do not slice the trace manually.",
UserWarning,
)
self.ntune = len(self.trace) - self.ndraws
self.posterior_trace, self.warmup_trace = self.split_trace()
else:
self.nchains = self.ndraws = 0

Expand Down Expand Up @@ -160,6 +162,26 @@ def find_observations(self) -> Optional[Dict[str, Var]]:
return {obs.name: obs.observations for obs in self.model.observed_RVs}
return None

def split_trace(self) -> Tuple[Union[None, MultiTrace], Union[None, MultiTrace]]:
"""Split MultiTrace object into posterior and warmup.
Returns
-------
trace_posterior: pymc3.MultiTrace or None
The slice of the trace corresponding to the posterior. If the posterior
trace is empty, None is returned
trace_warmup: pymc3.MultiTrace or None
The slice of the trace corresponding to the warmup. If the warmup trace is
empty or ``save_warmup=False``, None is returned
"""
trace_posterior = None
trace_warmup = None
if self.save_warmup and self.ntune > 0:
trace_warmup = self.trace[: self.ntune]
if self.ndraws > 0:
trace_posterior = self.trace[self.ntune :]
return trace_posterior, trace_warmup

def log_likelihood_vals_point(self, point, var, log_like_fun):
"""Compute log likelihood for each observed point."""
log_like_val = utils.one_de(log_like_fun(point))
Expand Down Expand Up @@ -208,13 +230,14 @@ def posterior_to_xarray(self):
data = {}
data_warmup = {}
for var_name in var_names:
if self.save_warmup:
if self.warmup_trace:
data_warmup[var_name] = np.array(
self.trace[: -self.ndraws].get_values(var_name, combine=False, squeeze=False)
self.warmup_trace.get_values(var_name, combine=False, squeeze=False)
)
if self.posterior_trace:
data[var_name] = np.array(
self.posterior_trace.get_values(var_name, combine=False, squeeze=False)
)
data[var_name] = np.array(
self.trace[-self.ndraws :].get_values(var_name, combine=False, squeeze=False)
)
return (
dict_to_dataset(
data, library=self.pymc3, coords=self.coords, dims=self.dims, attrs=self.attrs
Expand All @@ -239,11 +262,12 @@ def sample_stats_to_xarray(self):
name = rename_key.get(stat, stat)
if name == "tune":
continue
if self.save_warmup:
if self.warmup_trace:
data_warmup[name] = np.array(
self.trace[: -self.ndraws].get_sampler_stats(stat, combine=False)
self.warmup_trace.get_sampler_stats(stat, combine=False)
)
data[name] = np.array(self.trace[-self.ndraws :].get_sampler_stats(stat, combine=False))
if self.posterior_trace:
data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))

return (
dict_to_dataset(
Expand All @@ -260,17 +284,22 @@ def log_likelihood_to_xarray(self):
"""Extract log likelihood and log_p data from PyMC3 trace."""
if self.predictions or not self.log_likelihood:
return None
try:
data = self._extract_log_likelihood(self.trace[-self.ndraws :])
except TypeError:
warnings.warn(
"""Could not compute log_likelihood, it will be omitted.
Check your model object or set log_likelihood=False"""
)
return None
data_warmup = {}
if self.save_warmup:
data_warmup = self._extract_log_likelihood(self.trace[: -self.ndraws])
data = {}
warn_msg = (
"Could not compute log_likelihood, it will be omitted. "
"Check your model object or set log_likelihood=False"
)
if self.posterior_trace:
try:
data = self._extract_log_likelihood(self.posterior_trace)
except TypeError:
warnings.warn(warn_msg)
if self.warmup_trace:
try:
data_warmup = self._extract_log_likelihood(self.warmup_trace)
except TypeError:
warnings.warn(warn_msg)
return (
dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords),
dict_to_dataset(data_warmup, library=self.pymc3, dims=self.dims, coords=self.coords),
Expand Down
62 changes: 45 additions & 17 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
# pylint: disable=no-member, invalid-name, redefined-outer-name, protected-access
from sys import version_info
from typing import Dict, Tuple

Expand Down Expand Up @@ -405,36 +405,40 @@ class TestPyMC3WarmupHandling:
reason="requires pymc3 3.9 or higher",
)
@pytest.mark.parametrize("save_warmup", [False, True])
def test_save_warmup(self, save_warmup):
@pytest.mark.parametrize("chains", [1, 2])
@pytest.mark.parametrize("tune,draws", [(0, 50), (10, 40), (30, 0)])
def test_save_warmup(self, save_warmup, chains, tune, draws):
with pm.Model():
pm.Uniform("u1")
pm.Normal("n1")
trace = pm.sample(
tune=100,
draws=200,
chains=2,
tune=tune,
draws=draws,
chains=chains,
cores=1,
step=pm.Metropolis(),
discard_tuned_samples=False,
)
assert isinstance(trace, pm.backends.base.MultiTrace)
idata = from_pymc3(trace, save_warmup=save_warmup)
prefix = "" if save_warmup else "~"
warmup_prefix = "" if save_warmup and (tune > 0) else "~"
post_prefix = "" if draws > 0 else "~"
test_dict = {
"posterior": ["u1", "n1"],
"sample_stats": ["~tune", "accept"],
f"{prefix}warmup_posterior": ["u1", "n1"],
f"{prefix}warmup_sample_stats": ["~tune"],
f"{post_prefix}posterior": ["u1", "n1"],
f"{post_prefix}sample_stats": ["~tune", "accept"],
f"{warmup_prefix}warmup_posterior": ["u1", "n1"],
f"{warmup_prefix}warmup_sample_stats": ["~tune"],
"~warmup_log_likelihood": [],
"~log_likelihood": [],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == 200
if save_warmup:
assert idata.warmup_posterior.dims["chain"] == 2
assert idata.warmup_posterior.dims["draw"] == 100
if hasattr(idata, "posterior"):
assert idata.posterior.dims["chain"] == chains
assert idata.posterior.dims["draw"] == draws
if hasattr(idata, "warmup_posterior"):
assert idata.warmup_posterior.dims["chain"] == chains
assert idata.warmup_posterior.dims["draw"] == tune

@pytest.mark.skipif(
hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.8 or lower",
Expand All @@ -458,8 +462,16 @@ def test_save_warmup_issue_1208_before_3_9(self):
# making from_pymc3 fall back to len(trace) and triggering a warning
with pytest.warns(UserWarning, match="Warmup samples"):
idata = from_pymc3(trace, save_warmup=True)
assert idata.posterior.dims["draw"] == 300
assert idata.posterior.dims["chain"] == 2
test_dict = {
"posterior": ["u1", "n1"],
"sample_stats": ["~tune", "accept"],
"~warmup_posterior": [],
"~warmup_sample_stats": [],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["draw"] == 300
assert idata.posterior.dims["chain"] == 2

@pytest.mark.skipif(
not hasattr(pm.backends.base.SamplerReport, "n_draws"),
Expand All @@ -482,11 +494,27 @@ def test_save_warmup_issue_1208_after_3_9(self):

# from original trace, warmup draws should be separated out
idata = from_pymc3(trace, save_warmup=True)
test_dict = {
"posterior": ["u1", "n1"],
"sample_stats": ["~tune", "accept"],
"warmup_posterior": ["u1", "n1"],
"warmup_sample_stats": ["~tune", "accept"],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == 200

# manually sliced trace triggers the same warning as <=3.8
with pytest.warns(UserWarning, match="Warmup samples"):
idata = from_pymc3(trace[-30:], save_warmup=True)
test_dict = {
"posterior": ["u1", "n1"],
"sample_stats": ["~tune", "accept"],
"~warmup_posterior": [],
"~warmup_sample_stats": [],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == 30

0 comments on commit ec33b4c

Please sign in to comment.