Skip to content

Commit

Permalink
Explicitly import distributions from torch (#3333)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Feb 27, 2024
1 parent 3a1bd6a commit e8af7cd
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
all: docs test

install: FORCE
pip install -e .[dev,profile]
pip install -e .[dev,profile] --config-settings editable_mode=strict

uninstall: FORCE
pip uninstall pyro-ppl
Expand Down
90 changes: 90 additions & 0 deletions pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,57 @@
# SPDX-License-Identifier: Apache-2.0

import pyro.distributions.torch_patch # noqa F403

# Import * to get the latest upstream distributions.
from pyro.distributions.torch import * # noqa F403

# Additionally try to import explicitly to help mypy static analysis.
try:
from pyro.distributions.torch import (
Bernoulli,
Beta,
Binomial,
Categorical,
Cauchy,
Chi2,
ContinuousBernoulli,
Dirichlet,
Exponential,
ExponentialFamily,
FisherSnedecor,
Gamma,
Geometric,
Gumbel,
HalfCauchy,
HalfNormal,
Independent,
Kumaraswamy,
Laplace,
LKJCholesky,
LogisticNormal,
LogNormal,
LowRankMultivariateNormal,
MixtureSameFamily,
Multinomial,
MultivariateNormal,
NegativeBinomial,
Normal,
OneHotCategorical,
OneHotCategoricalStraightThrough,
Pareto,
Poisson,
RelaxedBernoulli,
RelaxedOneHotCategorical,
StudentT,
TransformedDistribution,
Uniform,
VonMises,
Weibull,
Wishart,
)
except ImportError:
pass

# isort: split

from pyro.distributions.affine_beta import AffineBeta
Expand Down Expand Up @@ -99,7 +148,13 @@
"AVFMultivariateNormal",
"AffineBeta",
"AsymmetricLaplace",
"Bernoulli",
"Beta",
"BetaBinomial",
"Binomial",
"Categorical",
"Cauchy",
"Chi2",
"CoalescentRateLikelihood",
"CoalescentTimes",
"CoalescentTimesWithRate",
Expand All @@ -108,43 +163,71 @@
"ConditionalTransform",
"ConditionalTransformModule",
"ConditionalTransformedDistribution",
"ContinuousBernoulli",
"Delta",
"Dirichlet",
"DirichletMultinomial",
"DiscreteHMM",
"Distribution",
"Empirical",
"ExpandedDistribution",
"Exponential",
"ExponentialFamily",
"ExtendedBetaBinomial",
"ExtendedBinomial",
"FisherSnedecor",
"FoldedDistribution",
"Gamma",
"GammaGaussianHMM",
"GammaPoisson",
"GaussianHMM",
"GaussianMRF",
"GaussianScaleMixture",
"Geometric",
"GroupedNormalNormal",
"Gumbel",
"HalfCauchy",
"HalfNormal",
"ImproperUniform",
"Independent",
"IndependentHMM",
"InverseGamma",
"Kumaraswamy",
"LKJ",
"LKJCholesky",
"LKJCorrCholesky",
"Laplace",
"LinearHMM",
"LogNormal",
"LogNormalNegativeBinomial",
"Logistic",
"LogisticNormal",
"LowRankMultivariateNormal",
"MaskedDistribution",
"MaskedMixture",
"MixtureOfDiagNormals",
"MixtureOfDiagNormalsSharedCovariance",
"MixtureSameFamily",
"Multinomial",
"MultivariateNormal",
"MultivariateStudentT",
"NanMaskedMultivariateNormal",
"NanMaskedNormal",
"NegativeBinomial",
"Normal",
"OMTMultivariateNormal",
"OneHotCategorical",
"OneHotCategoricalStraightThrough",
"OneOneMatching",
"OneTwoMatching",
"OrderedLogistic",
"Pareto",
"Poisson",
"ProjectedNormal",
"Rejector",
"RelaxedBernoulli",
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategorical",
"RelaxedOneHotCategoricalStraightThrough",
"SineBivariateVonMises",
"SineSkewed",
Expand All @@ -153,11 +236,17 @@
"SoftLaplace",
"SpanningTree",
"Stable",
"StudentT",
"TorchDistribution",
"TransformModule",
"TransformedDistribution",
"TruncatedPolyaGamma",
"Uniform",
"Unit",
"VonMises",
"VonMises3D",
"Weibull",
"Wishart",
"ZeroInflatedDistribution",
"ZeroInflatedNegativeBinomial",
"ZeroInflatedPoisson",
Expand All @@ -171,4 +260,5 @@

# Import all torch distributions from `pyro.distributions.torch_distribution`
__all__.extend(torch_dists)
__all__[:] = sorted(set(__all__))
del torch_dists
79 changes: 71 additions & 8 deletions pyro/distributions/constraints.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,50 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

# Import * to get the latest upstream constraints.
from torch.distributions.constraints import * # noqa F403

# Additionally try to import explicitly to help mypy static analysis.
try:
from torch.distributions.constraints import (
Constraint,
boolean,
cat,
corr_cholesky,
dependent,
dependent_property,
greater_than,
greater_than_eq,
half_open_interval,
independent,
integer_interval,
interval,
is_dependent,
less_than,
lower_cholesky,
lower_triangular,
multinomial,
nonnegative,
nonnegative_integer,
one_hot,
positive,
positive_definite,
positive_integer,
positive_semidefinite,
real,
real_vector,
simplex,
square,
stack,
symmetric,
unit_interval,
)
except ImportError:
pass

# isort: split

import torch
from torch.distributions.constraints import (
Constraint,
independent,
lower_cholesky,
positive,
positive_definite,
)
from torch.distributions.constraints import __all__ as torch_constraints


Expand Down Expand Up @@ -129,19 +161,50 @@ def check(self, value):
corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED

__all__ = [
"Constraint",
"boolean",
"cat",
"corr_cholesky",
"corr_cholesky_constraint",
"corr_matrix",
"dependent",
"dependent_property",
"greater_than",
"greater_than_eq",
"half_open_interval",
"independent",
"integer",
"integer_interval",
"interval",
"is_dependent",
"less_than",
"lower_cholesky",
"lower_triangular",
"multinomial",
"nonnegative",
"nonnegative_integer",
"one_hot",
"ordered_vector",
"positive",
"positive_definite",
"positive_integer",
"positive_ordered_vector",
"positive_semidefinite",
"real",
"real_vector",
"simplex",
"softplus_lower_cholesky",
"softplus_positive",
"sphere",
"square",
"stack",
"symmetric",
"unit_interval",
"unit_lower_cholesky",
]

__all__.extend(torch_constraints)
__all__ = sorted(set(__all__))
__all__[:] = sorted(set(__all__))
del torch_constraints


Expand Down
49 changes: 47 additions & 2 deletions pyro/distributions/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,52 @@ def _cat_docstrings(*docstrings):
return result


# Programmatically load all distributions from PyTorch.
__all__ = []
# Add static imports to help mypy.
__all__ = [ # noqa: F822
"Bernoulli",
"Beta",
"Binomial",
"Categorical",
"Cauchy",
"Chi2",
"ContinuousBernoulli",
"Dirichlet",
"ExponentialFamily",
"Exponential",
"FisherSnedecor",
"Gamma",
"Geometric",
"Gumbel",
"HalfCauchy",
"HalfNormal",
"Independent",
"Kumaraswamy",
"Laplace",
"LKJCholesky",
"LogNormal",
"LogisticNormal",
"LowRankMultivariateNormal",
"MixtureSameFamily",
"Multinomial",
"MultivariateNormal",
"NegativeBinomial",
"Normal",
"OneHotCategorical",
"OneHotCategoricalStraightThrough",
"Pareto",
"Poisson",
"RelaxedBernoulli",
"RelaxedOneHotCategorical",
"StudentT",
"TransformedDistribution",
"Uniform",
"VonMises",
"Weibull",
"Wishart",
]

# Programmatically load all distributions from PyTorch,
# updating __all__ to include any new distributions.
for _name, _Dist in torch.distributions.__dict__.items():
if not isinstance(_Dist, type):
continue
Expand All @@ -372,6 +416,7 @@ def _cat_docstrings(*docstrings):
)
_PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__)
__all__.append(_name)
__all__ = sorted(set(__all__))


# Create sphinx documentation.
Expand Down
Loading

0 comments on commit e8af7cd

Please sign in to comment.