diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py
index 6bab27da7c..1235431983 100644
--- a/pyro/infer/inspect.py
+++ b/pyro/infer/inspect.py
@@ -50,7 +50,6 @@ def _pyro_post_sample(self, msg):
         msg["value"] = ProvenanceTensor(value, provenance)


-@torch.enable_grad()
 def get_dependencies(
     model: Callable,
     model_args: Optional[tuple] = None,
@@ -172,7 +171,7 @@ def model_3():

     # Collect sites with tracked provenance.
     with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
-        with TrackProvenance():
+        with poutine.block(), TrackProvenance():
             trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
     sample_sites = [msg for msg in trace.nodes.values() if is_sample_site(msg)]

@@ -272,7 +271,7 @@ def model(data):
         model_kwargs = {}

     with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
-        with TrackProvenance():
+        with poutine.block(), TrackProvenance():
             trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

     sample_sample = {}
diff --git a/pyro/infer/reparam/strategies.py b/pyro/infer/reparam/strategies.py
index 4a471caaed..c297a13104 100644
--- a/pyro/infer/reparam/strategies.py
+++ b/pyro/infer/reparam/strategies.py
@@ -8,14 +8,16 @@
 See :func:`~pyro.poutine.handlers.reparam` for usage.
 """

-from abc import ABC, abstractmethod
-from typing import Callable, Dict, Optional, Union
+from abc import abstractmethod
+from typing import Callable, Dict, Optional

 import torch
 from torch.distributions import constraints

 import pyro.distributions as dist
 import pyro.poutine as poutine
+from pyro.infer.inspect import get_dependencies
+from pyro.poutine.reparam_messenger import BaseStrategy

 from .loc_scale import LocScaleReparam
 from .projected_normal import ProjectedNormalReparam
@@ -25,7 +27,7 @@
 from .transform import TransformReparam


-class Strategy(ABC):
+class Strategy(BaseStrategy):
     """
     Abstract base class for reparametrizer configuration strategies.

@@ -60,24 +62,27 @@ def configure(self, msg: dict) -> Optional[Reparam]:
         """
         raise NotImplementedError

-    def __call__(self, msg_or_fn: Union[dict, Callable]):
+    def __call__(self, fn: Callable):
         """
         Strategies can be used as decorators to reparametrize a model.

-        :param msg_or_fn: Public use: a model to be decorated. (Internal use: a
-            site to be configured for reparametrization).
+        :param fn: a model to be decorated.
         """
-        if isinstance(msg_or_fn, dict):  # Internal use during configuration.
-            msg = msg_or_fn
-            name = msg["name"]
-            if name in self.config:
-                return self.config[name]
-            result = self.configure(msg)
-            self.config[name] = result
-            return result
-        else:  # Public use as a decorator or handler.
-            fn = msg_or_fn
-            return poutine.reparam(fn, self)
+        return poutine.reparam(fn, self)
+
+    def config_with_model(
+        self,
+        msg: dict,
+        model: Callable,
+        model_args: tuple,
+        model_kwargs: dict,
+    ) -> Optional[Reparam]:
+        name = msg["name"]
+        if name in self.config:
+            return self.config[name]
+        result = self.configure(msg)
+        self.config[name] = result
+        return result


 class MinimalReparam(Strategy):
@@ -163,6 +168,18 @@ def __init__(self, *, centered: Optional[float] = None):
         assert centered is None or isinstance(centered, float)
         super().__init__()
         self.centered = centered
+        self.dependencies = None
+
+    def config_with_model(
+        self,
+        msg: dict,
+        model: Callable,
+        model_args: tuple,
+        model_kwargs: dict,
+    ) -> Optional[Reparam]:
+        if self.dependencies is None:
+            self.dependencies = get_dependencies(model, model_args, model_kwargs)
+        return super().config_with_model(msg, model, model_args, model_kwargs)

     def configure(self, msg: dict) -> Optional[Reparam]:
         # Focus on tricks for latent sites.
@@ -178,10 +195,12 @@ def configure(self, msg: dict) -> Optional[Reparam]:
         if isinstance(fn, torch.distributions.RelaxedOneHotCategorical):
             return GumbelSoftmaxReparam()

-        # Apply a learnable LocScaleReparam.
-        result = _loc_scale_reparam(msg["name"], fn, self.centered)
-        if result is not None:
-            return result
+        # Check whether parameters depend on upstream latent variables.
+        if len(self.dependencies["prior_dependencies"][msg["name"]]) > 1:
+            # Apply a learnable LocScaleReparam.
+            result = _loc_scale_reparam(msg["name"], fn, self.centered)
+            if result is not None:
+                return result

         # Apply minimal reparametrizers.
         return _minimal_reparam(fn, msg["is_observed"])
@@ -200,10 +219,6 @@ def _loc_scale_reparam(name, fn, centered):
     if not _is_unconstrained(fn.support):
         return

-    # TODO reparametrize only if parameters are variable. We might guess
-    # based on whether parameters are differentiable, .requires_grad. See
-    # https://github.com/pyro-ppl/pyro/pull/2824
-
     # Create an elementwise-learnable reparametrizer.
     shape_params = sorted(params - {"loc", "scale"})
     return LocScaleReparam(centered=centered, shape_params=shape_params)
diff --git a/pyro/poutine/reparam_messenger.py b/pyro/poutine/reparam_messenger.py
index d2460b40ae..0c4edeecf1 100644
--- a/pyro/poutine/reparam_messenger.py
+++ b/pyro/poutine/reparam_messenger.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import warnings
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import Callable, Dict, Union

 import torch
@@ -15,6 +17,26 @@ def _get_init_messengers():
     return []


+class BaseStrategy(ABC):
+    """
+    Abstract base class for reparametrization strategies.
+    """
+
+    @abstractmethod
+    def config_with_model(
+        self,
+        msg: dict,
+        model: Callable,
+        model_args: tuple,
+        model_kwargs: dict,
+    ):
+        """
+        Returns a reparametrizer or None, based on ``msg``. The ``model``,
+        ``model_args``, and ``model_kwargs`` may be used for initialization.
+        """
+        raise NotImplementedError
+
+
 class ReparamMessenger(Messenger):
     """
     Reparametrizes each affected sample site into one or more auxiliary sample
@@ -33,17 +55,19 @@ class ReparamMessenger(Messenger):
         https://arxiv.org/pdf/1906.03028.pdf

     :param config: Configuration, either a dict mapping site name to
-        :class:`~pyro.infer.reparam.reparam.Reparameterizer` , or a function
+        :class:`~pyro.infer.reparam.reparam.Reparameterizer` , a function
         mapping site to :class:`~pyro.infer.reparam.reparam.Reparameterizer` or
-        None. See :mod:`pyro.infer.reparam.strategies` for built-in
-        configuration strategies.
+        None, or a :class:`BaseStrategy` . See
+        :mod:`pyro.infer.reparam.strategies` for built-in configuration
+        strategies.
     :type config: dict or callable
     """

-    def __init__(self, config: Union[Dict[str, object], Callable]):
+    def __init__(self, config: Union[Dict[str, object], Callable, BaseStrategy]):
         super().__init__()
         assert isinstance(config, dict) or callable(config)
         self.config = config
+        self._model = None
         self._args_kwargs = None

     def __call__(self, fn):
@@ -54,6 +78,8 @@ def _pyro_sample(self, msg):
             return
         if isinstance(self.config, dict):
             reparam = self.config.get(msg["name"])
+        elif isinstance(self.config, BaseStrategy):
+            reparam = self.config.config_with_model(msg, self._model, *self._args_kwargs)
         else:
             reparam = self.config(msg)
         if reparam is None:
@@ -120,6 +146,18 @@ def _pyro_sample(self, msg):
         msg["value"] = new_msg["value"]
         msg["is_observed"] = new_msg["is_observed"]

+    @contextmanager
+    def call(self, model, model_args, model_kwargs):
+        # Save model, args, kwargs for reparameterizers and apply this messenger.
+        self._model = model
+        self._args_kwargs = model_args, model_kwargs
+        try:
+            with self:
+                yield self
+        finally:
+            self._model = None
+            self._args_kwargs = None
+

 class ReparamHandler(object):
     """
@@ -132,10 +170,5 @@ def __init__(self, msngr, fn):
         super().__init__()

     def __call__(self, *args, **kwargs):
-        # This saves args,kwargs for optional use by reparameterizers.
-        self.msngr._args_kwargs = args, kwargs
-        try:
-            with self.msngr:
-                return self.fn(*args, **kwargs)
-        finally:
-            self.msngr._args_kwargs = None
+        with self.msngr.call(self.fn, args, kwargs):
+            return self.fn(*args, **kwargs)
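
A minimal usage sketch of the resulting API (assuming the strategy class with
the ``centered`` option above is ``pyro.infer.reparam.strategies.AutoReparam``;
the model and site names below are illustrative, not part of the patch):

    import torch

    import pyro
    import pyro.distributions as dist
    import pyro.poutine as poutine
    from pyro.infer.reparam.strategies import AutoReparam


    def model():
        # "z" has constant parameters, so its prior_dependencies entry contains
        # only itself and AutoReparam should leave its prior untouched.
        z = pyro.sample("z", dist.Normal(0.0, 1.0))
        # The loc of "x" depends on the upstream latent "z", so the new
        # dependency check should apply a learnable LocScaleReparam here.
        x = pyro.sample("x", dist.Normal(z, 1.0))
        pyro.sample("obs", dist.Normal(x, 1.0), obs=torch.tensor(0.5))


    # Public decorator use is unchanged and is equivalent to
    # poutine.reparam(model, AutoReparam()).
    reparam_model = AutoReparam()(model)

    # ReparamHandler.__call__ now routes through ReparamMessenger.call(), so
    # config_with_model() receives (model, args, kwargs) and can lazily run
    # get_dependencies() once, at the first sample site; the poutine.block()
    # added in inspect.py keeps that inner trace from re-triggering the
    # enclosing reparam messenger.
    trace = poutine.trace(reparam_model).get_trace()
    print(list(trace.nodes))  # expect an auxiliary site such as "x_decentered"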