Partially type annotate distributions #3367

Open · wants to merge 8 commits into dev
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -30,6 +30,7 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
+         pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install ruff black mypy nbstripout nbformat
      - name: Lint
        run: |
2 changes: 1 addition & 1 deletion pyro/distributions/coalescent.py
@@ -249,7 +249,7 @@ class CoalescentRateLikelihood:
    def __init__(self, leaf_times, coal_times, duration, *, validate_args=None):
        assert leaf_times.size(-1) == 1 + coal_times.size(-1)
        assert isinstance(duration, int) and duration >= 2
-        if validate_args is True or validate_args is None and is_validation_enabled:
+        if validate_args is True or validate_args is None and is_validation_enabled():
            constraint = CoalescentTimesConstraint(leaf_times, ordered=False)
            if not constraint.check(coal_times).all():
                raise ValueError("Invalid (leaf_times, coal_times)")
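
Note that this is a behavioral fix surfaced by the typing pass, not a pure annotation: a bare function object is always truthy, so the old guard ran validation even when it was globally disabled. A minimal sketch of the failure mode, using a stand-in for Pyro's `is_validation_enabled`:

    def is_validation_enabled() -> bool:
        return False  # stand-in for Pyro's global validation flag

    validate_args = None

    # Old guard: the bare function object is truthy, so this always runs.
    if validate_args is True or validate_args is None and is_validation_enabled:
        print("validates even though validation is disabled")

    # Fixed guard: calling the function consults the actual flag.
    if validate_args is True or validate_args is None and is_validation_enabled():
        print("validates only when enabled")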
4 changes: 2 additions & 2 deletions pyro/distributions/conditional.py
@@ -43,12 +43,12 @@ def inv(self) -> "ConditionalTransformModule":


class _ConditionalInverseTransformModule(ConditionalTransformModule):
-    def __init__(self, transform: ConditionalTransform):
+    def __init__(self, transform: ConditionalTransformModule):
        super().__init__()
        self._transform = transform

    @property
-    def inv(self) -> ConditionalTransform:
+    def inv(self) -> ConditionalTransformModule:
        return self._transform

    def condition(self, context: torch.Tensor):
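
Narrowing both the parameter and the return type from `ConditionalTransform` to `ConditionalTransformModule` is useful to callers, since only the latter is a `torch.nn.Module`. A simplified sketch of the distinction (stand-in classes, not the Pyro originals):

    import torch

    class ConditionalTransform:
        def condition(self, context: torch.Tensor) -> "ConditionalTransform":
            return self

    class ConditionalTransformModule(ConditionalTransform, torch.nn.Module):
        pass

    def move_to_cpu(t: ConditionalTransformModule) -> None:
        # Type-checks only because the annotation names the nn.Module
        # subtype; under the broader ConditionalTransform, mypy would
        # reject the .to() call.
        t.to(torch.device("cpu"))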
6 changes: 4 additions & 2 deletions pyro/distributions/constraints.py
@@ -8,6 +8,8 @@
try:
    from torch.distributions.constraints import (
        Constraint,
+        _GreaterThan,
+        _LowerCholesky,
        boolean,
        cat,
        corr_cholesky,
@@ -122,12 +124,12 @@ def check(self, value):
        return ordered_vector.check(value) & independent(positive, 1).check(value)


-class _SoftplusPositive(type(positive)):
+class _SoftplusPositive(_GreaterThan):
    def __init__(self):
        super().__init__(lower_bound=0.0)


-class _SoftplusLowerCholesky(type(lower_cholesky)):
+class _SoftplusLowerCholesky(_LowerCholesky):
    pass
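
At runtime `type(positive)` and `type(lower_cholesky)` evaluate to exactly these private classes, but mypy rejects dynamically computed base classes, hence the explicit imports (placed inside the existing try/except so older torch versions degrade gracefully). The difference in miniature:

    from torch.distributions.constraints import _GreaterThan, positive

    # Flagged by mypy as an unsupported dynamic base class, though it
    # works at runtime:
    class SoftplusPositiveDynamic(type(positive)):  # type: ignore[misc]
        def __init__(self):
            super().__init__(lower_bound=0.0)

    # Statically analyzable, identical runtime behavior:
    class SoftplusPositiveStatic(_GreaterThan):
        def __init__(self):
            super().__init__(lower_bound=0.0)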
8 changes: 5 additions & 3 deletions pyro/distributions/distribution.py
@@ -4,12 +4,13 @@
import functools
import inspect
from abc import ABCMeta, abstractmethod
+from typing import Any, Callable, List

import torch

from pyro.distributions.score_parts import ScoreParts

-COERCIONS = []
+COERCIONS: List = []


class DistributionMeta(ABCMeta):
@@ -51,6 +52,7 @@ class Distribution(metaclass=DistributionMeta):

    has_rsample = False
    has_enumerate_support = False
+    rsample: Callable[..., torch.Tensor]

    def __call__(self, *args, **kwargs):
        """
@@ -65,7 +67,7 @@ def __call__(self, *args, **kwargs):
        return self.sample(*args, **kwargs)

    @abstractmethod
-    def sample(self, *args, **kwargs):
+    def sample(self, *args, **kwargs) -> torch.Tensor:
        """
        Samples a random value.

@@ -82,7 +84,7 @@ def sample(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
-    def log_prob(self, x, *args, **kwargs):
+    def log_prob(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Evaluates log probability densities for each of a batch of samples.
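
Two patterns are worth noting here. `COERCIONS: List = []` satisfies mypy's rule that an empty literal needs an explicit type. And `rsample: Callable[..., torch.Tensor]` is an annotation-only declaration: it promises that concrete subclasses provide the attribute, without defining anything at runtime that could shadow a subclass method. A minimal sketch of the declaration pattern:

    from typing import Callable

    import torch

    class Base:
        has_rsample = False
        # Declared for the type checker only; no runtime attribute exists.
        rsample: Callable[..., torch.Tensor]

    class Concrete(Base):
        has_rsample = True

        def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
            return torch.randn(sample_shape)

    d: Base = Concrete()
    d.rsample(torch.Size([3]))  # accepted by mypy via the Base annotation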
1 change: 0 additions & 1 deletion pyro/distributions/hmm.py
@@ -1005,7 +1005,6 @@ class LinearHMM(HiddenMarkovModel):
    """

    arg_constraints = {}
-    support = constraints.independent(constraints.real, 2)
    has_rsample = True

    def __init__(
3 changes: 2 additions & 1 deletion pyro/distributions/kl.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
+from typing import List

from torch.distributions import (
    Independent,
@@ -53,4 +54,4 @@ def _kl_independent_mvn(p, q):
    raise NotImplementedError


-__all__ = []
+__all__: List[str] = []
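
This is the empty-container rule again: mypy cannot infer an element type for a bare `[]`, so `__all__ = []` triggers `Need type annotation for "__all__"`. The same fix appears in `torch_patch.py` below. In miniature:

    from typing import List

    names = []                 # mypy: Need type annotation for "names"
    names_ok: List[str] = []   # accepted: element type is explicit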
4 changes: 2 additions & 2 deletions pyro/distributions/nanmasked.py
@@ -24,7 +24,7 @@ class NanMaskedNormal(Normal):
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        ok = value.isfinite()
        if ok.all():
-            return super().log_prob(value)
+            return super().log_prob(value)  # type: ignore[no-any-return]

        # Broadcast all tensors.
        value, ok, loc, scale = torch.broadcast_tensors(value, ok, self.loc, self.scale)
@@ -65,7 +65,7 @@ class NanMaskedMultivariateNormal(MultivariateNormal):
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        ok = value.isfinite()
        if ok.all():
-            return super().log_prob(value)
+            return super().log_prob(value)  # type: ignore[no-any-return]

        # Broadcast all tensors. This might waste some computation by eagerly
        # broadcasting, but the optimal implementation is quite complex.
5 changes: 4 additions & 1 deletion pyro/distributions/projected_normal.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
+from typing import Callable, Dict

import torch

@@ -51,7 +52,9 @@ def model():
    arg_constraints = {"concentration": constraints.real_vector}
    support = constraints.sphere
    has_rsample = True
-    _log_prob_impls = {}  # maps dim -> function(concentration, value)
+    _log_prob_impls: Dict[int, Callable] = (
+        {}
+    )  # maps dim -> function(concentration, value)

    def __init__(self, concentration, *, validate_args=None):
        assert concentration.dim() >= 1
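
The parenthesized `= ({})` layout is just black reflowing the now-longer annotated assignment so the trailing comment fits. Functionally `_log_prob_impls` is a dimension-keyed registry of log-prob implementations; a hedged sketch of how such a registry is typically populated (the `_register` decorator is illustrative, not Pyro's actual helper):

    from typing import Callable, Dict

    import torch

    _log_prob_impls: Dict[int, Callable] = {}  # maps dim -> function(concentration, value)

    def _register(dim: int):
        def decorator(fn: Callable) -> Callable:
            _log_prob_impls[dim] = fn
            return fn
        return decorator

    @_register(3)
    def _log_prob_3(concentration: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        return torch.zeros(value.shape[:-1])  # placeholder body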
36 changes: 34 additions & 2 deletions pyro/distributions/torch.py
@@ -14,6 +14,38 @@
from .. import settings
from . import constraints

+# Additionally try to import explicitly to help mypy static analysis.
+try:
+    from torch.distributions import (
+        Bernoulli,
+        Cauchy,
+        Chi2,
+        ContinuousBernoulli,
+        Exponential,
+        ExponentialFamily,
+        FisherSnedecor,
+        Gumbel,
+        HalfCauchy,
+        HalfNormal,
+        Kumaraswamy,
+        Laplace,
+        LKJCholesky,
+        LogisticNormal,
+        MixtureSameFamily,
+        NegativeBinomial,
+        OneHotCategoricalStraightThrough,
+        Pareto,
+        RelaxedBernoulli,
+        RelaxedOneHotCategorical,
+        StudentT,
+        TransformedDistribution,
+        VonMises,
+        Weibull,
+        Wishart,
+    )
+except ImportError:
+    pass
+

def _clamp_by_zero(x):
    # works like clamp(x, min=0) but has grad at 0 is 0.5
@@ -202,7 +234,7 @@ def log_prob(self, value):
        return (-value - 1) * torch.nn.functional.softplus(self.logits) + self.logits


-class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin):
+class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin):  # type: ignore
    def __init__(self, loc, scale, validate_args=None):
        base_dist = Normal(loc, scale)
        # This differs from torch.distributions.LogNormal only in that base_dist is
@@ -294,7 +326,7 @@ def log_prob(self, value):
        )


-class Independent(torch.distributions.Independent, TorchDistributionMixin):
+class Independent(torch.distributions.Independent, TorchDistributionMixin):  # type: ignore
    @staticmethod
    def infer_shapes(**kwargs):
        raise NotImplementedError
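
The explicit import block helps because `pyro.distributions.torch` otherwise mints many of its wrapper classes programmatically, which static analysis cannot follow; the try/except keeps older torch builds that lack some of these names importable. A simplified sketch of the dynamic pattern that defeats mypy (not Pyro's exact loop):

    import torch

    class Mixin:
        """Stand-in for TorchDistributionMixin."""

    for _name in ["Bernoulli", "Cauchy"]:
        _Dist = getattr(torch.distributions, _name)
        # Classes minted via type(...) at runtime are invisible to mypy, so
        # `from pyro.distributions.torch import Bernoulli` cannot be resolved
        # statically without explicit imports like the block above.
        globals()[_name] = type(_name, (_Dist, Mixin), {})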
2 changes: 1 addition & 1 deletion pyro/distributions/torch_distribution.py
@@ -16,7 +16,7 @@
from .util import broadcast_shape, scale_and_mask


-class TorchDistributionMixin(Distribution, Callable):
+class TorchDistributionMixin(Distribution, Callable):  # type: ignore[misc]
    """
    Mixin to provide Pyro compatibility for PyTorch distributions.
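
The `[misc]` ignore covers subclassing `Callable`, which works at runtime (class creation defers to `collections.abc.Callable`) but which mypy rejects as a base class. In miniature:

    from typing import Callable

    class MyCallable(Callable):  # type: ignore[misc]  # mypy: invalid base class
        def __call__(self) -> int:
            return 1

    print(MyCallable()())  # runs fine and prints 1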
3 changes: 2 additions & 1 deletion pyro/distributions/torch_patch.py
@@ -5,6 +5,7 @@
import math
import warnings
import weakref
+from typing import List

import torch

@@ -92,4 +93,4 @@ def _lazy_property__call__(self):
    raise NotImplementedError


-__all__ = []
+__all__: List[str] = []
2 changes: 1 addition & 1 deletion pyro/distributions/torch_transform.py
@@ -34,7 +34,7 @@ def __init__(self, parts, cache_size=0):
    def __hash__(self):
        return super(torch.nn.Module, self).__hash__()

-    def with_cache(self, cache_size=1):
+    def with_cache(self, cache_size=1) -> "ComposeTransformModule":
        if cache_size == self._cache_size:
            return self
        return ComposeTransformModule(self.parts, cache_size=cache_size)
2 changes: 1 addition & 1 deletion pyro/distributions/transforms/__init__.py
@@ -92,7 +92,7 @@
from .power import PositivePowerTransform
from .radial import ConditionalRadial, Radial, conditional_radial, radial
from .simplex_to_ordered import SimplexToOrderedTransform
-from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform
+from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform  # type: ignore[assignment]
from .spline import ConditionalSpline, Spline, conditional_spline, spline
from .spline_autoregressive import (
    ConditionalSplineAutoregressive,
7 changes: 3 additions & 4 deletions pyro/distributions/transforms/cholesky.py
@@ -26,8 +26,8 @@ class CholeskyTransform(Transform):
    """

    bijective = True
-    domain = constraints.positive_definite
-    codomain = constraints.lower_cholesky
+    domain: constraints.Constraint = constraints.positive_definite
+    codomain: constraints.Constraint = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, CholeskyTransform)
@@ -55,8 +55,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform):

    bijective = True
    domain = constraints.corr_matrix
-    # TODO: change corr_cholesky_constraint to corr_cholesky when the latter is availabler
-    codomain = constraints.corr_cholesky_constraint
+    codomain = constraints.corr_cholesky

    def __eq__(self, other):
        return isinstance(other, CorrMatrixCholeskyTransform)
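
Annotating `domain` and `codomain` with the broad `constraints.Constraint` is what lets `CorrMatrixCholeskyTransform` override them with different concrete constraints; an unannotated base assignment would pin the narrow inferred type and make the override a mypy error. A sketch using torch's constraints (stand-in class names):

    from torch.distributions import constraints
    from torch.distributions.constraints import Constraint

    class BaseTransform:
        # Annotated with the common supertype, so subclasses may substitute
        # any Constraint. An unannotated assignment would pin the inferred
        # private type (_PositiveDefinite), making the override an error.
        domain: Constraint = constraints.positive_definite

    class SubTransform(BaseTransform):
        domain = constraints.lower_cholesky  # accepted thanks to the annotation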
4 changes: 0 additions & 4 deletions setup.cfg
@@ -40,10 +40,6 @@ ignore_errors = True
[mypy-pyro.contrib.*]
ignore_errors = True

-[mypy-pyro.distributions.*]
-ignore_errors = True
-warn_unused_ignores = True
-
[mypy-pyro.generic.*]
ignore_errors = True
warn_unused_ignores = True