diff --git a/docs/source/infer.util.rst b/docs/source/infer.util.rst new file mode 100644 index 0000000000..f022d744e6 --- /dev/null +++ b/docs/source/infer.util.rst @@ -0,0 +1,13 @@ +Inference utilities +=================== + +.. autofunction:: pyro.infer.util.enable_validation +.. autofunction:: pyro.infer.util.is_validation_enabled +.. autofunction:: pyro.infer.util.validation_enabled + +Model inspection +---------------- + +.. automodule:: pyro.infer.inspect + :members: + :member-order: bysource diff --git a/docs/source/inference.rst b/docs/source/inference.rst index 65ceffd791..200f3ac51f 100644 --- a/docs/source/inference.rst +++ b/docs/source/inference.rst @@ -19,3 +19,4 @@ See `Intro II `_ for a discussion of mcmc infer.autoguide infer.reparam + infer.util diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 2d9d4a85f0..036632ec16 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -19,10 +19,10 @@ def model(): import operator import warnings import weakref -from collections import defaultdict +from collections import OrderedDict, defaultdict from contextlib import ExitStack from types import SimpleNamespace -from typing import Callable, Dict, Union +from typing import Callable, Dict, Optional, Union import torch from torch import nn @@ -40,6 +40,7 @@ def model(): init_to_median, ) from pyro.infer.enum import config_enumerate +from pyro.infer.inspect import get_dependencies from pyro.nn import PyroModule, PyroParam from pyro.ops.hessian import hessian from pyro.ops.tensor_utils import periodic_repeat @@ -1310,9 +1311,18 @@ def model(data): y = pyro.sample("y", dist.Normal(0, 1)) pyro.sample("z", dist.Normal(y, x), obs=data) + # Either fully automatic... + guide = AutoStructured(model) + + # ...or with specified conditional and dependency types... + guide = AutoStructured( + model, conditionals="normal", dependencies="linear" + ) + + # ...or with custom dependency structure and distribution types. guide = AutoStructured( model=model, - conditionals={"x": "normal", "y": "normal"}, + conditionals={"x": "normal", "y": "delta"}, dependencies={"x": {"y": "linear"}}, ) @@ -1333,19 +1343,20 @@ def optim_config(param_name): adam = pyro.optim.Adam(optim_config) :param callable model: A Pyro model. - :param conditionals: Family of distribution with which to model each latent - variable's conditional posterior. This should be a dict mapping each - latent variable name to either a string in ("delta", "normal", or - "mvn") or to a callable that returns a sample from a zero mean (or - approximately centered) noise distribution (such callables typically - call ``pyro.param()`` and ``pyro.sample()`` internally). - :param dependencies: Dict mapping each site name to a dict of its upstream - dependencies; each inner dict maps upstream site name to either the - string "linear" or a callable that maps a *flattened* upstream - perturbation to *flattened* downstream perturbation. The string - "linear" is equivalent to ``nn.Linear(upstream.numel(), - downstream.numel(), bias=False)``. Dependencies must not contain - cycles or self-loops. + :param conditionals: Either a single distribution type or a dict mapping + each latent variable name to a distribution type. A distribution type + is either a string in {"delta", "normal", "mvn"} or a callable that + returns a sample from a zero mean (or approximately centered) noise + distribution (such callables typically call ``pyro.param()`` and + ``pyro.sample()`` internally). 
+ :param dependencies: Dependency type, or a dict mapping each site name to a + dict mapping its upstream dependencies to dependency types. If only a + dependency type is provided, dependency structure will be inferred. A + dependency type is either the string "linear" or a callable that maps a + *flattened* upstream perturbation to a *flattened* downstream + perturbation. The string "linear" is equivalent to + ``nn.Linear(upstream.numel(), downstream.numel(), bias=False)``. + Dependencies must not contain cycles or self-loops. :param callable init_loc_fn: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. :param float init_scale: Initial scale for the standard deviation of each @@ -1363,33 +1374,62 @@ def __init__( self, model, *, - conditionals: Dict[str, Union[str, Callable]] = "normal", - dependencies: Dict[str, Dict[str, Union[str, Callable]]] = "linear", - init_loc_fn=init_to_feasible, - init_scale=0.1, - create_plates=None, + conditionals: Union[str, Dict[str, Union[str, Callable]]] = "mvn", + dependencies: Union[str, Dict[str, Dict[str, Union[str, Callable]]]] = "linear", + init_loc_fn: Callable = init_to_feasible, + init_scale: float = 0.1, + create_plates: Optional[Callable] = None, ): - assert isinstance(conditionals, dict) - for name, fn in conditionals.items(): - assert isinstance(name, str) - assert isinstance(fn, str) or callable(fn) - assert isinstance(dependencies, dict) - for downstream, deps in dependencies.items(): - assert downstream in conditionals - assert isinstance(deps, dict) - for upstream, dep in deps.items(): - assert upstream in conditionals - assert upstream != downstream - assert isinstance(dep, str) or callable(dep) + assert isinstance(conditionals, (dict, str)) + if isinstance(conditionals, dict): + for name, fn in conditionals.items(): + assert isinstance(name, str) + assert isinstance(fn, str) or callable(fn) + assert isinstance(dependencies, (dict, str)) + if isinstance(dependencies, dict): + for downstream, deps in dependencies.items(): + assert downstream in conditionals + assert isinstance(deps, dict) + for upstream, dep in deps.items(): + assert upstream in conditionals + assert upstream != downstream + assert isinstance(dep, str) or callable(dep) self.conditionals = conditionals self.dependencies = dependencies if not isinstance(init_scale, float) or not (init_scale > 0): raise ValueError(f"Expected init_scale > 0. but got {init_scale}") self._init_scale = init_scale + self._original_model = (model,) model = InitMessenger(init_loc_fn)(model) super().__init__(model, create_plates=create_plates) + def _auto_config(self, sample_sites, args, kwargs): + # Instantiate conditionals as dictionaries. + if not isinstance(self.conditionals, dict): + self.conditionals = { + name: self.conditionals for name, site in sample_sites.items() + } + + # Instantiate dependencies as dictionaries. + if not isinstance(self.dependencies, dict): + model = self._original_model[0] + meta = poutine.block(get_dependencies)(model, args, kwargs) + # Use posterior dependency edges but with prior ordering. This + # allows sampling of globals before locals on which they depend. + prior_order = {name: i for i, name in enumerate(sample_sites)} + dependencies = defaultdict(dict) + for d, upstreams in meta["posterior_dependencies"].items(): + assert d in sample_sites + for u, plates in upstreams.items(): + # TODO use plates to reduce dimension of dependency.
+ if u in sample_sites: + if prior_order[u] > prior_order[d]: + dependencies[u][d] = self.dependencies + elif prior_order[d] > prior_order[u]: + dependencies[d][u] = self.dependencies + self.dependencies = dict(dependencies) + def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) @@ -1400,11 +1440,13 @@ def _setup_prototype(self, *args, **kwargs): self.deps = PyroModule() self._batch_shapes = {} self._unconstrained_event_shapes = {} + sample_sites = OrderedDict(self.prototype_trace.iter_stochastic_nodes()) + self._auto_config(sample_sites, args, kwargs) # Collect unconstrained shapes. init_locs = {} numel = {} - for name, site in self.prototype_trace.iter_stochastic_nodes(): + for name, site in sample_sites.items(): with helpful_support_errors(site): init_loc = ( biject_to(site["fn"].support).inv(site["value"].detach()).detach() @@ -1419,7 +1461,7 @@ def _setup_prototype(self, *args, **kwargs): # Initialize guide params. children = defaultdict(list) num_pending = {} - for name, site in self.prototype_trace.iter_stochastic_nodes(): + for name, site in sample_sites.items(): # Initialize location parameters. init_loc = init_locs[name] _deep_setattr(self.locs, name, PyroParam(init_loc)) @@ -1449,7 +1491,7 @@ def _setup_prototype(self, *args, **kwargs): deps = PyroModule() _deep_setattr(self.deps, name, deps) for upstream, dep in self.dependencies.get(name, {}).items(): - assert upstream in self.prototype_trace.nodes + assert upstream in sample_sites children[upstream].append(name) num_pending[name] += 1 if isinstance(dep, str) and dep == "linear": @@ -1462,6 +1504,7 @@ def _setup_prototype(self, *args, **kwargs): _deep_setattr(deps, upstream, dep) # Topologically sort sites. + # TODO should we choose a more optimal structure? self._sorted_sites = [] while num_pending: name, count = min(num_pending.items(), key=lambda kv: (kv[1], kv[0])) @@ -1469,7 +1512,7 @@ def _setup_prototype(self, *args, **kwargs): del num_pending[name] for child in children[name]: num_pending[child] -= 1 - site = self._compress_site(self.prototype_trace.nodes[name]) + site = self._compress_site(sample_sites[name]) self._sorted_sites.append((name, site)) # Prune non-essential parts of the trace to save memory. diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py new file mode 100644 index 0000000000..ccb642ebe2 --- /dev/null +++ b/pyro/infer/inspect.py @@ -0,0 +1,253 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Dict, Optional + +import torch + +import pyro +import pyro.poutine as poutine +from pyro.poutine.messenger import Messenger +from pyro.poutine.util import site_is_subsample + + +def is_sample_site(msg): + if msg["type"] != "sample": + return False + if site_is_subsample(msg): + return False + + # Ignore masked observations. + if msg["is_observed"] and msg["mask"] is False: + return False + + # Exclude deterministic sites. 
+ fn = msg["fn"] + while hasattr(fn, "base_dist"): + fn = fn.base_dist + if type(fn).__name__ == "Delta": + return False + + return True + + +class RequiresGradMessenger(Messenger): + def __init__(self, predicate=lambda msg: True): + self.predicate = predicate + super().__init__() + + def _pyro_post_sample(self, msg): + if is_sample_site(msg): + if self.predicate(msg): + msg["value"].requires_grad_() + elif not msg["is_observed"] and msg["value"].requires_grad: + msg["value"] = msg["value"].detach() + + +def get_dependencies( + model: Callable, + model_args: Optional[tuple] = None, + model_kwargs: Optional[dict] = None, +) -> Dict[str, object]: + r""" + EXPERIMENTAL Infers dependency structure about a conditioned model. + + This returns a nested dictionary with structure like:: + + { + "prior_dependencies": { + "variable1": {"variable1": set()}, + "variable2": {"variable1": set(), "variable2": set()}, + ... + }, + "posterior_dependencies": { + "variable1": {"variable1": {"plate1"}, "variable2": set()}, + ... + }, + } + + where + + - `prior_dependencies` is a dict mapping downstream latent and observed + variables to dictionaries mapping upstream latent variables on which + they depend to sets of plates inducing full dependencies. + That is, included plates introduce quadratically many dependencies as + in complete-bipartite graphs, whereas excluded plates introduce only + linearly many dependencies as in independent sets of parallel edges. + Prior dependencies follow the original model order. + - `posterior_dependencies` is a similar dict, but mapping latent + variables to the latent or observed sites on which they depend in the + posterior. Posterior dependencies are reversed from the model order. + + Dependencies elide ``pyro.deterministic`` sites and ``pyro.sample(..., + Delta(...))`` sites. + + **Examples** + + Here is a simple example with no plates. We see every node depends on + itself, and only the latent variables appear in the posterior:: + + def model_1(): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0)) + + assert get_dependencies(model_1) == { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"a": set(), "b": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "b": set()}, + }, + } + + Here is an example where two variables ``a`` and ``b`` start out + conditionally independent in the prior, but become conditionally dependent + in the posterior due to the so-called collider variable ``c`` on which they + both depend. This is called "moralization" in the graphical model + literature:: + + def model_2(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.LogNormal(0, 1)) + c = pyro.sample("c", dist.Normal(a, b)) + pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.)) + + assert get_dependencies(model_2) == { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"b": set()}, + "c": {"a": set(), "b": set(), "c": set()}, + "d": {"c": set(), "d": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "b": set(), "c": set()}, + "b": {"b": set(), "c": set()}, + "c": {"c": set(), "d": set()}, + }, + } + + Dependencies can be more complex in the presence of plates. So far all the + dict values have been empty sets of plates, but in the following posterior + we see that ``a`` depends on itself across the plate ``p``. This means + that, among the elements of ``a``, e.g. 
``a[0]`` depends on ``a[1]`` (this + is why we explicitly allow variables to depend on themselves):: + + def model_3(): + with pyro.plate("p", 5): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0)) + + assert get_dependencies(model_3) == { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"a": set(), "b": set()}, + }, + "posterior_dependencies": { + "a": {"a": {"p"}, "b": set()}, + }, + } + + .. warning:: This currently relies on autograd and therefore works only for + continuous latent variables with differentiable dependencies. Discrete + latent variables will raise errors. Gradient blocking may silently drop + dependencies. + + **References** + + [1] S.Webb, A.GoliƄski, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018) + "Faithful inversion of generative models for effective amortized inference" + https://dl.acm.org/doi/10.5555/3327144.3327229 + + :param callable model: A model. + :param tuple model_args: Optional tuple of model args. + :param dict model_kwargs: Optional dict of model kwargs. + :returns: A dictionary of metadata (see above). + :rtype: dict + """ + if model_args is None: + model_args = () + if model_kwargs is None: + model_kwargs = {} + + def get_sample_sites(predicate=lambda msg: True): + with torch.enable_grad(), torch.random.fork_rng(): + with pyro.validation_enabled(False), RequiresGradMessenger(predicate): + trace = poutine.trace(model).get_trace(*model_args, **model_kwargs) + return [msg for msg in trace.nodes.values() if is_sample_site(msg)] + + # Collect observations. + sample_sites = get_sample_sites() + observed = {msg["name"] for msg in sample_sites if msg["is_observed"]} + plates = { + msg["name"]: {f.name for f in msg["cond_indep_stack"] if f.vectorized} + for msg in sample_sites + } + + # First find transitive dependencies among latent and observed sites + prior_dependencies = {n: {n: set()} for n in plates} # no deps yet + for i, downstream in enumerate(sample_sites): + upstreams = [u for u in sample_sites[:i] if not u["is_observed"]] + if not upstreams: + continue + grads = torch.autograd.grad( + downstream["fn"].log_prob(downstream["value"]).sum(), + [u["value"] for u in upstreams], + allow_unused=True, + retain_graph=True, + ) + for upstream, grad in zip(upstreams, grads): + if grad is not None: + d = downstream["name"] + u = upstream["name"] + prior_dependencies[d][u] = set() + + # Then refine to direct dependencies among latent and observed sites. + for i, downstream in enumerate(sample_sites): + for j, upstream in enumerate(sample_sites[: max(0, i - 1)]): + if upstream["name"] not in prior_dependencies[downstream["name"]]: + continue + names = {upstream["name"], downstream["name"]} + sample_sites_ij = get_sample_sites(lambda msg: msg["name"] in names) + d = sample_sites_ij[i] + u = sample_sites_ij[j] + grad = torch.autograd.grad( + d["fn"].log_prob(d["value"]).sum(), + [u["value"]], + allow_unused=True, + retain_graph=True, + )[0] + if grad is None: + prior_dependencies[d["name"]].pop(u["name"]) + + # Next reverse dependencies and restrict downstream nodes to latent sites. + posterior_dependencies = {n: {} for n in plates if n not in observed} + for d, upstreams in prior_dependencies.items(): + for u, p in upstreams.items(): + if u not in observed: + # Note the following reverses: + # u is henceforth downstream and d is henceforth upstream. + posterior_dependencies[u][d] = p.copy() + + # Moralize: add dependencies among latent variables in each Markov blanket. 
+ # This assumes all latents are eventually observed, at least indirectly. + order = {msg["name"]: i for i, msg in enumerate(reversed(sample_sites))} + for d, upstreams in prior_dependencies.items(): + upstreams = {u: p for u, p in upstreams.items() if u not in observed} + for u1, p1 in upstreams.items(): + for u2, p2 in upstreams.items(): + if order[u1] <= order[u2]: + p12 = posterior_dependencies[u2].setdefault(u1, set()) + p12 |= plates[u1] & plates[u2] - plates[d] + p12 |= plates[u2] & p1 + p12 |= plates[u1] & p2 + + return { + "prior_dependencies": prior_dependencies, + "posterior_dependencies": posterior_dependencies, + } + + +__all__ = [ + "get_dependencies", +] diff --git a/pyro/poutine/messenger.py b/pyro/poutine/messenger.py index 7b9a259c0a..bd610946a7 100644 --- a/pyro/poutine/messenger.py +++ b/pyro/poutine/messenger.py @@ -39,9 +39,6 @@ class Messenger: Most inference operations are implemented in subclasses of this. """ - def __init__(self): - pass - def __call__(self, fn): if not callable(fn): raise ValueError( diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index ca359f9db5..8e1f23b6f6 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -181,6 +181,7 @@ def dependency_z6_z5(z5): AutoLowRankMultivariateNormal, AutoIAFNormal, AutoLaplaceApproximation, + AutoStructured, AutoStructured_shapes, ], ) @@ -329,6 +330,7 @@ def __init__(self, model): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), + AutoStructured, AutoStructured_median, ], ) @@ -376,6 +378,7 @@ def model(): functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_sample), + AutoStructured, AutoStructured_median, ], ) @@ -834,6 +837,7 @@ def __init__(self, model): AutoLaplaceApproximation, functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_mean), functools.partial(AutoDiagonalNormal, init_loc_fn=init_to_median), + AutoStructured, AutoStructured_predictive, ], ) diff --git a/tests/infer/test_inspect.py b/tests/infer/test_inspect.py new file mode 100644 index 0000000000..a0916ef3ac --- /dev/null +++ b/tests/infer/test_inspect.py @@ -0,0 +1,344 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch + +import pyro +import pyro.distributions as dist +from pyro.distributions.testing.fakes import NonreparameterizedNormal +from pyro.infer.inspect import get_dependencies + + +def test_get_dependencies(): + def model(data): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", NonreparameterizedNormal(a, 0)) + c = pyro.sample("c", dist.Normal(b, 1)) + d = pyro.sample("d", dist.Normal(a, c.exp())) + + e = pyro.sample("e", dist.Normal(0, 1)) + f = pyro.sample("f", dist.Normal(0, 1)) + g = pyro.sample("g", dist.Bernoulli(logits=e + f), obs=torch.tensor(0.0)) + + with pyro.plate("p", len(data)): + d_ = d.detach() # this results in a known failure + h = pyro.sample("h", dist.Normal(c, d_.exp())) + i = pyro.deterministic("i", h + 1) + j = pyro.sample("j", dist.Delta(h + 1), obs=h + 1) + k = pyro.sample("k", dist.Normal(a, j.exp()), obs=data) + + return [a, b, c, d, e, f, g, h, i, j, k] + + data = torch.randn(3) + actual = get_dependencies(model, (data,)) + _ = set() + expected = { + "prior_dependencies": { + "a": {"a": _}, + "b": {"b": _, "a": _}, + "c": {"c": _, "b": _}, + "d": {"d": _, "c": _, "a": _}, + "e": {"e": _}, + "f": {"f": _}, + "g": {"g": _, "e": _, "f": _}, + "h": {"h": _, "c": _}, # [sic] + "k": {"k": _, "a": _, "h": _}, + }, + "posterior_dependencies": { + "a": {"a": _, "b": _, "c": _, "d": _, "h": _, "k": _}, + "b": {"b": _, "c": _}, + "c": {"c": _, "d": _, "h": _}, # [sic] + "d": {"d": _}, + "e": {"e": _, "g": _, "f": _}, + "f": {"f": _, "g": _}, + "h": {"h": _, "k": _}, + }, + } + assert actual == expected + + +def test_docstring_example_1(): + def model_1(): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0)) + + actual = get_dependencies(model_1) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"a": set(), "b": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "b": set()}, + }, + } + assert actual == expected + + +def test_docstring_example_2(): + def model_2(): + a = pyro.sample("a", dist.Normal(0, 1)) + b = pyro.sample("b", dist.LogNormal(0, 1)) + c = pyro.sample("c", dist.Normal(a, b)) + pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0)) + + actual = get_dependencies(model_2) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"b": set()}, + "c": {"a": set(), "b": set(), "c": set()}, + "d": {"c": set(), "d": set()}, + }, + "posterior_dependencies": { + "a": {"a": set(), "b": set(), "c": set()}, + "b": {"b": set(), "c": set()}, + "c": {"c": set(), "d": set()}, + }, + } + assert actual == expected + + +def test_docstring_example_3(): + def model_3(): + with pyro.plate("p", 5): + a = pyro.sample("a", dist.Normal(0, 1)) + pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0)) + + actual = get_dependencies(model_3) + expected = { + "prior_dependencies": { + "a": {"a": set()}, + "b": {"a": set(), "b": set()}, + }, + "posterior_dependencies": { + "a": {"a": {"p"}, "b": set()}, + }, + } + assert actual == expected + + +def test_plate_coupling(): + # x x + # || + # y + # + # This results in posterior dependency structure: + # + # x x y + # x ? ? ? + # x ? ? ? 
+ + def model(data): + with pyro.plate("p", len(data)): + x = pyro.sample("x", dist.Normal(0, 1)) + pyro.sample("y", dist.Normal(x.sum(), 1), obs=data.sum()) + + data = torch.randn(2) + actual = get_dependencies(model, (data,)) + expected = { + "prior_dependencies": { + "x": {"x": set()}, + "y": {"y": set(), "x": set()}, + }, + "posterior_dependencies": { + "x": {"x": {"p"}, "y": set()}, + }, + } + assert actual == expected + + +def test_plate_coupling_2(): + # x x + # \\ y y + # \\ // + # z + # + # This results in posterior dependency structure: + # + # x x y y z + # x ? ? ? ? ? + # x ? ? ? ? ? + # y ? ? ? + # y ? ? ? + + def model(data): + with pyro.plate("p", len(data)): + x = pyro.sample("x", dist.Normal(0, 1)) + y = pyro.sample("y", dist.Normal(0, 1)) + pyro.sample("z", dist.Normal(x.sum(), y.sum().exp()), obs=data.sum()) + + data = torch.randn(2) + actual = get_dependencies(model, (data,)) + expected = { + "prior_dependencies": { + "x": {"x": set()}, + "y": {"y": set()}, + "z": {"z": set(), "x": set(), "y": set()}, + }, + "posterior_dependencies": { + "x": {"x": {"p"}, "y": {"p"}, "z": set()}, + "y": {"y": {"p"}, "z": set()}, + }, + } + assert actual == expected + + +def test_plate_coupling_3(): + # x x x x + # // \\ + # y y z z + # + # This results in posterior dependency structure: + # + # x x y y z + # x ? ? ? ? ? + # x ? ? ? ? ? + # y ? ? ? + # y ? ? ? + + def model(data): + i_plate = pyro.plate("i", data.shape[0], dim=-2) + j_plate = pyro.plate("j", data.shape[1], dim=-1) + with i_plate, j_plate: + x = pyro.sample("x", dist.Normal(0, 1)) + with i_plate: + pyro.sample("y", dist.Normal(x.sum(-1, True), 1), obs=data.sum(-1, True)) + with j_plate: + pyro.sample("z", dist.Normal(x.sum(-2, True), 1), obs=data.sum(-2, True)) + + data = torch.randn(3, 2) + actual = get_dependencies(model, (data,)) + expected = { + "prior_dependencies": { + "x": {"x": set()}, + "y": {"y": set(), "x": set()}, + "z": {"z": set(), "x": set()}, + }, + "posterior_dependencies": { + "x": {"x": {"i", "j"}, "y": set(), "z": set()}, + }, + } + assert actual == expected + + +def test_plate_collider(): + # x x y y + # \\ // + # zzzz + # + # This results in posterior dependency structure: + # + # x x y y z z z z + # x ? ? ? ? ? + # x ? ? ? ? ? + # y ? ? ? + # y ? ? ? + + def model(data): + i_plate = pyro.plate("i", data.shape[0], dim=-2) + j_plate = pyro.plate("j", data.shape[1], dim=-1) + + with i_plate: + x = pyro.sample("x", dist.Normal(0, 1)) + with j_plate: + y = pyro.sample("y", dist.Normal(0, 1)) + with i_plate, j_plate: + pyro.sample("z", dist.Normal(x, y.exp()), obs=data) + + data = torch.randn(3, 2) + actual = get_dependencies(model, (data,)) + _ = set() + expected = { + "prior_dependencies": { + "x": {"x": _}, + "y": {"y": _}, + "z": {"x": _, "y": _, "z": _}, + }, + "posterior_dependencies": { + "x": {"x": _, "y": _, "z": _}, + "y": {"y": _, "z": _}, + }, + } + assert actual == expected + + +def test_plate_dependency(): + # w w + # \ x1 x2 unroll x1 / \ x2 + # \ || y1 y2 =====> y1 | / \ | y2 + # \ || // \|/ \|/ + # z1 z2 z1 z2 + # + # This allows posterior dependency structure: + # + # w x x y y z z + # w ? ? ? ? ? ? ? + # x ? ? ? + # x ? ? ? + # y ? ? + # y ? ? 
+ + def model(data): + w = pyro.sample("w", dist.Normal(0, 1)) + with pyro.plate("p", len(data)): + x = pyro.sample("x", dist.Normal(0, 1)) + y = pyro.sample("y", dist.Normal(0, 1)) + pyro.sample("z", dist.Normal(w + x + y, 1), obs=data) + + data = torch.rand(2) + actual = get_dependencies(model, (data,)) + _ = set() + expected = { + "prior_dependencies": { + "w": {"w": _}, + "x": {"x": _}, + "y": {"y": _}, + "z": {"w": _, "x": _, "y": _, "z": _}, + }, + "posterior_dependencies": { + "w": {"w": _, "x": _, "y": _, "z": _}, + "x": {"x": _, "y": _, "z": _}, + "y": {"y": _, "z": _}, + }, + } + assert actual == expected + + +def test_nested_plate_collider(): + # a a b b + # a a b b + # \\ // + # c c + # | + # d + + def model(): + plate_i = pyro.plate("i", 2, dim=-1) + plate_j = pyro.plate("j", 3, dim=-2) + plate_k = pyro.plate("k", 3, dim=-2) + + with plate_i: + with plate_j: + a = pyro.sample("a", dist.Normal(0, 1)) + with plate_k: + b = pyro.sample("b", dist.Normal(0, 1)) + c = pyro.sample("c", dist.Normal(a.sum(0) + b.sum([0, 1]), 1)) + pyro.sample("d", dist.Normal(c.sum(), 1), obs=torch.zeros(())) + + actual = get_dependencies(model) + _ = set() + expected = { + "prior_dependencies": { + "a": {"a": _}, + "b": {"b": _}, + "c": {"c": _, "a": _, "b": _}, + "d": {"d": _, "c": _}, + }, + "posterior_dependencies": { + "a": {"a": {"j"}, "b": _, "c": _}, + "b": {"b": {"k"}, "c": _}, + "c": {"c": {"i"}, "d": _}, + }, + } + assert actual == expected
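
Usage sketch: the following is a minimal end-to-end example (not part of this diff) showing how the pieces added here fit together, first inspecting a conditioned model with ``get_dependencies`` and then letting ``AutoStructured`` infer its dependency structure automatically. The toy model, data, learning rate, and step count are illustrative assumptions; only ``get_dependencies``, ``AutoStructured``, and standard Pyro SVI APIs are used::

    import torch

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoStructured
    from pyro.infer.inspect import get_dependencies


    def model(data):
        # Toy model, chosen for illustration only.
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        y = pyro.sample("y", dist.Normal(x, 1.0))
        pyro.sample("z", dist.Normal(y, 1.0), obs=data)


    data = torch.tensor(0.5)

    # Inspect the conditioned model. With no plates, every plate set is empty.
    deps = get_dependencies(model, (data,))
    print(deps["posterior_dependencies"])
    # expected: {"x": {"x": set(), "y": set()}, "y": {"y": set(), "z": set()}}

    # Let AutoStructured infer the dependency structure from the model,
    # using normal conditionals and linear dependencies throughout.
    guide = AutoStructured(model, conditionals="normal", dependencies="linear")
    svi = SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), loss=Trace_ELBO())
    for step in range(100):
        svi.step(data)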