Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add automatic dependency tracking to AutoStructured guide #2824

Merged
merged 13 commits into from
Sep 7, 2021
13 changes: 13 additions & 0 deletions docs/source/infer.util.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Inference utilities
===================

.. autofunction:: pyro.infer.util.enable_validation
.. autofunction:: pyro.infer.util.is_validation_enabled
.. autofunction:: pyro.infer.util.validation_enabled

Model inspection
----------------

.. automodule:: pyro.infer.inspect
:members:
:member-order: bysource
1 change: 1 addition & 0 deletions docs/source/inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ See `Intro II <http://pyro.ai/examples/intro_part_ii.html>`_ for a discussion of
mcmc
infer.autoguide
infer.reparam
infer.util
117 changes: 80 additions & 37 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def model():
import operator
import warnings
import weakref
from collections import defaultdict
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from types import SimpleNamespace
from typing import Callable, Dict, Union
from typing import Callable, Dict, Optional, Union

import torch
from torch import nn
Expand All @@ -40,6 +40,7 @@ def model():
init_to_median,
)
from pyro.infer.enum import config_enumerate
from pyro.infer.inspect import get_dependencies
from pyro.nn import PyroModule, PyroParam
from pyro.ops.hessian import hessian
from pyro.ops.tensor_utils import periodic_repeat
Expand Down Expand Up @@ -1310,9 +1311,18 @@ def model(data):
y = pyro.sample("y", dist.Normal(0, 1))
pyro.sample("z", dist.Normal(y, x), obs=data)

# Either fully automatic...
guide = AutoStructured(model)

# ...or with specified conditional and dependency types...
guide = AutoStructured(
model, conditionals="normal", dependencies="linear"
)

# ...or with custom dependency structure and distribution types.
guide = AutoStructured(
model=model,
conditionals={"x": "normal", "y": "normal"},
conditionals={"x": "normal", "y": "delta"},
dependencies={"x": {"y": "linear"}},
)

Expand All @@ -1333,19 +1343,20 @@ def optim_config(param_name):
adam = pyro.optim.Adam(optim_config)

:param callable model: A Pyro model.
:param conditionals: Family of distribution with which to model each latent
variable's conditional posterior. This should be a dict mapping each
latent variable name to either a string in ("delta", "normal", or
"mvn") or to a callable that returns a sample from a zero mean (or
approximately centered) noise distribution (such callables typically
call ``pyro.param()`` and ``pyro.sample()`` internally).
:param dependencies: Dict mapping each site name to a dict of its upstream
dependencies; each inner dict maps upstream site name to either the
string "linear" or a callable that maps a *flattened* upstream
perturbation to *flattened* downstream perturbation. The string
"linear" is equivalent to ``nn.Linear(upstream.numel(),
downstream.numel(), bias=False)``. Dependencies must not contain
cycles or self-loops.
:param conditionals: Either a single distribution type or a dict mapping
each latent variable name to a distribution type. A distribution type
is either a string in {"delta", "normal", "mvn"} or a callable that
returns a sample from a zero mean (or approximately centered) noise
distribution (such callables typically call ``pyro.param()`` and
``pyro.sample()`` internally).
:param dependencies: Dependency type, or a dict mapping each site name to a
dict mapping its upstream dependencies to dependency types. If only a
dependency type is provided, dependency structure will be inferred. A
dependency type is either the string "linear" or a callable that maps a
*flattened* upstream perturbation to *flattened* downstream
perturbation. The string "linear" is equivalent to
``nn.Linear(upstream.numel(), downstream.numel(), bias=False)``.
Dependencies must not contain cycles or self-loops.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
Expand All @@ -1363,33 +1374,62 @@ def __init__(
self,
model,
*,
conditionals: Dict[str, Union[str, Callable]] = "normal",
dependencies: Dict[str, Dict[str, Union[str, Callable]]] = "linear",
init_loc_fn=init_to_feasible,
init_scale=0.1,
create_plates=None,
conditionals: Union[str, Dict[str, Union[str, Callable]]] = "mvn",
dependencies: Union[str, Dict[str, Dict[str, Union[str, Callable]]]] = "linear",
init_loc_fn: Callable = init_to_feasible,
init_scale: float = 0.1,
create_plates: Optional[Callable] = None,
):
assert isinstance(conditionals, dict)
for name, fn in conditionals.items():
assert isinstance(name, str)
assert isinstance(fn, str) or callable(fn)
assert isinstance(dependencies, dict)
for downstream, deps in dependencies.items():
assert downstream in conditionals
assert isinstance(deps, dict)
for upstream, dep in deps.items():
assert upstream in conditionals
assert upstream != downstream
assert isinstance(dep, str) or callable(dep)
assert isinstance(conditionals, (dict, str))
if isinstance(conditionals, dict):
for name, fn in conditionals.items():
assert isinstance(name, str)
assert isinstance(fn, str) or callable(fn)
assert isinstance(dependencies, (dict, str))
if isinstance(dependencies, dict):
for downstream, deps in dependencies.items():
assert downstream in conditionals
assert isinstance(deps, dict)
for upstream, dep in deps.items():
assert upstream in conditionals
assert upstream != downstream
assert isinstance(dep, str) or callable(dep)
self.conditionals = conditionals
self.dependencies = dependencies

if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError(f"Expected init_scale > 0. but got {init_scale}")
self._init_scale = init_scale
self._original_model = (model,)
model = InitMessenger(init_loc_fn)(model)
super().__init__(model, create_plates=create_plates)

def _auto_config(self, sample_sites, args, kwargs):
# Instantiate conditionals as dictionaries.
if not isinstance(self.conditionals, dict):
self.conditionals = {
name: self.conditionals for name, site in sample_sites.items()
}

# Instantiate dependencies as dictionaries.
if not isinstance(self.dependencies, dict):
model = self._original_model[0]
meta = poutine.block(get_dependencies)(model, args, kwargs)
# Use posterior dependency edges but with prior ordering. This
# allows sampling of globals before locals on which they depend.
prior_order = {name: i for i, name in enumerate(sample_sites)}
dependencies = defaultdict(dict)
for d, upstreams in meta["posterior_dependencies"].items():
assert d in sample_sites
for u, plates in upstreams.items():
# TODO use plates to reduce dimension of dependency.
if u in sample_sites:
if prior_order[u] > prior_order[d]:
dependencies[u][d] = self.dependencies
elif prior_order[d] > prior_order[u]:
dependencies[d][u] = self.dependencies
self.dependencies = dict(dependencies)

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)

Expand All @@ -1400,11 +1440,13 @@ def _setup_prototype(self, *args, **kwargs):
self.deps = PyroModule()
self._batch_shapes = {}
self._unconstrained_event_shapes = {}
sample_sites = OrderedDict(self.prototype_trace.iter_stochastic_nodes())
self._auto_config(sample_sites, args, kwargs)

# Collect unconstrained shapes.
init_locs = {}
numel = {}
for name, site in self.prototype_trace.iter_stochastic_nodes():
for name, site in sample_sites.items():
with helpful_support_errors(site):
init_loc = (
biject_to(site["fn"].support).inv(site["value"].detach()).detach()
Expand All @@ -1419,7 +1461,7 @@ def _setup_prototype(self, *args, **kwargs):
# Initialize guide params.
children = defaultdict(list)
num_pending = {}
for name, site in self.prototype_trace.iter_stochastic_nodes():
for name, site in sample_sites.items():
# Initialize location parameters.
init_loc = init_locs[name]
_deep_setattr(self.locs, name, PyroParam(init_loc))
Expand Down Expand Up @@ -1449,7 +1491,7 @@ def _setup_prototype(self, *args, **kwargs):
deps = PyroModule()
_deep_setattr(self.deps, name, deps)
for upstream, dep in self.dependencies.get(name, {}).items():
assert upstream in self.prototype_trace.nodes
assert upstream in sample_sites
children[upstream].append(name)
num_pending[name] += 1
if isinstance(dep, str) and dep == "linear":
Expand All @@ -1462,14 +1504,15 @@ def _setup_prototype(self, *args, **kwargs):
_deep_setattr(deps, upstream, dep)

# Topologically sort sites.
# TODO should we choose a more optimal structure?
self._sorted_sites = []
while num_pending:
name, count = min(num_pending.items(), key=lambda kv: (kv[1], kv[0]))
assert count == 0, f"cyclic dependency: {name}"
del num_pending[name]
for child in children[name]:
num_pending[child] -= 1
site = self._compress_site(self.prototype_trace.nodes[name])
site = self._compress_site(sample_sites[name])
self._sorted_sites.append((name, site))

# Prune non-essential parts of the trace to save memory.
Expand Down
Loading