Use dependencies in the AutoReparam strategy #2966

Draft · wants to merge 1 commit into base: dev
5 changes: 2 additions & 3 deletions pyro/infer/inspect.py

@@ -50,7 +50,6 @@ def _pyro_post_sample(self, msg):
             msg["value"] = ProvenanceTensor(value, provenance)


-@torch.enable_grad()
 def get_dependencies(
     model: Callable,
     model_args: Optional[tuple] = None,
@@ -172,7 +171,7 @@ def model_3():

     # Collect sites with tracked provenance.
     with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
-        with TrackProvenance():
+        with poutine.block(), TrackProvenance():
             trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
     sample_sites = [msg for msg in trace.nodes.values() if is_sample_site(msg)]
@@ -272,7 +271,7 @@ def model(data):
     model_kwargs = {}

     with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
-        with TrackProvenance():
+        with poutine.block(), TrackProvenance():
             trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

     sample_sample = {}
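
For context (not part of this diff): ``get_dependencies`` returns a dict with ``prior_dependencies`` and ``posterior_dependencies`` entries, and every site lists itself among its own prior dependencies. A sketch based on the documented behavior, which the strategy change below relies on:

import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_dependencies

def model():
    a = pyro.sample("a", dist.Normal(0.0, 1.0))
    pyro.sample("b", dist.Normal(a, 1.0))  # b's loc is the upstream latent a

deps = get_dependencies(model)
# Each site depends on itself plus any upstream latents it reads:
# deps["prior_dependencies"] == {"a": {"a": set()},
#                                "b": {"a": set(), "b": set()}}
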
65 changes: 40 additions & 25 deletions pyro/infer/reparam/strategies.py
@@ -8,14 +8,16 @@
 See :func:`~pyro.poutine.handlers.reparam` for usage.
 """

-from abc import ABC, abstractmethod
-from typing import Callable, Dict, Optional, Union
+from abc import abstractmethod
+from typing import Callable, Dict, Optional

 import torch
 from torch.distributions import constraints

 import pyro.distributions as dist
 import pyro.poutine as poutine
+from pyro.infer.inspect import get_dependencies
+from pyro.poutine.reparam_messenger import BaseStrategy

 from .loc_scale import LocScaleReparam
 from .projected_normal import ProjectedNormalReparam
@@ -25,7 +27,7 @@
 from .transform import TransformReparam


-class Strategy(ABC):
+class Strategy(BaseStrategy):
     """
     Abstract base class for reparametrizer configuration strategies.
@@ -60,24 +62,27 @@ def configure(self, msg: dict) -> Optional[Reparam]:
         """
         raise NotImplementedError

-    def __call__(self, msg_or_fn: Union[dict, Callable]):
+    def __call__(self, fn: Callable):
         """
         Strategies can be used as decorators to reparametrize a model.

-        :param msg_or_fn: Public use: a model to be decorated. (Internal use: a
-            site to be configured for reparametrization).
+        :param fn: A model to be decorated.
         """
-        if isinstance(msg_or_fn, dict):  # Internal use during configuration.
-            msg = msg_or_fn
-            name = msg["name"]
-            if name in self.config:
-                return self.config[name]
-            result = self.configure(msg)
-            self.config[name] = result
-            return result
-        else:  # Public use as a decorator or handler.
-            fn = msg_or_fn
-            return poutine.reparam(fn, self)
+        return poutine.reparam(fn, self)
+
+    def config_with_model(
+        self,
+        msg: dict,
+        model: Callable,
+        model_args: tuple,
+        model_kwargs: dict,
+    ) -> Optional[Reparam]:
+        name = msg["name"]
+        if name in self.config:
+            return self.config[name]
+        result = self.configure(msg)
+        self.config[name] = result
+        return result


 class MinimalReparam(Strategy):
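
Note that the base ``config_with_model`` memoizes per-site decisions in ``self.config``, so ``configure`` runs at most once per site. A small sketch, using a hypothetical ``msg`` dict with only the fields consulted here:

import pyro.distributions as dist
from pyro.infer.reparam.strategies import MinimalReparam

strategy = MinimalReparam()
msg = {"name": "x", "fn": dist.Normal(0.0, 1.0), "is_observed": False}
first = strategy.config_with_model(msg, model=None, model_args=(), model_kwargs={})
# The decision is cached; a second call returns it without re-running configure().
assert strategy.config["x"] is first
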
@@ -163,6 +168,18 @@ def __init__(self, *, centered: Optional[float] = None):
         assert centered is None or isinstance(centered, float)
         super().__init__()
         self.centered = centered
+        self.dependencies = None
+
+    def config_with_model(
+        self,
+        msg: dict,
+        model: Callable,
+        model_args: tuple,
+        model_kwargs: dict,
+    ) -> Optional[Reparam]:
+        if self.dependencies is None:
+            self.dependencies = get_dependencies(model, model_args, model_kwargs)
+        return super().config_with_model(msg, model, model_args, model_kwargs)

     def configure(self, msg: dict) -> Optional[Reparam]:
         # Focus on tricks for latent sites.
@@ -178,10 +195,12 @@ def configure(self, msg: dict) -> Optional[Reparam]:
         if isinstance(fn, torch.distributions.RelaxedOneHotCategorical):
             return GumbelSoftmaxReparam()

-        # Apply a learnable LocScaleReparam.
-        result = _loc_scale_reparam(msg["name"], fn, self.centered)
-        if result is not None:
-            return result
+        # Check whether parameters depend on upstream latent variables.
+        if len(self.dependencies["prior_dependencies"][msg["name"]]) > 1:
+            # Apply a learnable LocScaleReparam.
+            result = _loc_scale_reparam(msg["name"], fn, self.centered)
+            if result is not None:
+                return result

         # Apply minimal reparametrizers.
         return _minimal_reparam(fn, msg["is_observed"])
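
The ``> 1`` test works because a site always appears in its own ``prior_dependencies`` entry: a root site has exactly one entry (itself), and anything larger means at least one upstream latent parent. A hypothetical entry for illustration:

# Hypothetical deps["prior_dependencies"] for a model where b reads latent a:
prior_dependencies = {
    "a": {"a": set()},              # root site: depends only on itself
    "b": {"b": set(), "a": set()},  # b also depends on the upstream latent a
}
assert len(prior_dependencies["b"]) > 1   # b is eligible for LocScaleReparam
assert len(prior_dependencies["a"]) == 1  # a falls through to _minimal_reparam
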
@@ -200,10 +219,6 @@ def _loc_scale_reparam(name, fn, centered):
     if not _is_unconstrained(fn.support):
         return

-    # TODO reparametrize only if parameters are variable. We might guess
-    # based on whether parameters are differentiable, .requires_grad. See
-    # https://github.com/pyro-ppl/pyro/pull/2824
-
Comment on lines -203 to -206 (Member Author): This PR resolves this TODO.

     # Create an elementwise-learnable reparametrizer.
     shape_params = sorted(params - {"loc", "scale"})
     return LocScaleReparam(centered=centered, shape_params=shape_params)
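
Taken together, ``AutoReparam`` now applies ``LocScaleReparam`` only to sites whose parameters actually depend on upstream latents. A sketch of intended decorator usage, assuming the interface in this diff:

import pyro
import pyro.distributions as dist
from pyro.infer.reparam.strategies import AutoReparam

@AutoReparam()
def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("x", dist.Normal(loc, 1.0))  # depends on loc: decorrelated
    pyro.sample("y", dist.Normal(0.0, 1.0))  # fixed params: left alone

model()  # dependencies are computed once, on first invocation, then cached
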
54 changes: 43 additions & 11 deletions pyro/poutine/reparam_messenger.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import warnings
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import Callable, Dict, Union

 import torch
@@ -15,6 +17,26 @@ def _get_init_messengers():
     return []


+class BaseStrategy(ABC):
+    """
+    Abstract base class for reparametrization strategies.
+    """
+
+    @abstractmethod
+    def config_with_model(
+        self,
+        msg: dict,
+        model: Callable,
+        model_args: tuple,
+        model_kwargs: dict,
+    ):
+        """
+        Returns a reparametrizer or None, based on ``msg``. The ``model``,
+        ``model_args``, and ``model_kwargs`` may be used for initialization.
+        """
+        raise NotImplementedError
+
+
 class ReparamMessenger(Messenger):
     """
     Reparametrizes each affected sample site into one or more auxiliary sample
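
Any object implementing this interface can be passed as ``config``. A hypothetical custom strategy, for illustration only (``DecenterNormals`` is not part of this diff):

import pyro.distributions as dist
from pyro.infer.reparam.loc_scale import LocScaleReparam
from pyro.poutine.reparam_messenger import BaseStrategy

class DecenterNormals(BaseStrategy):
    def config_with_model(self, msg, model, model_args, model_kwargs):
        # model, model_args, model_kwargs are available for one-time analysis;
        # this toy strategy ignores them and decenters every latent Normal.
        if isinstance(msg["fn"], dist.Normal) and not msg["is_observed"]:
            return LocScaleReparam(centered=0.0)
        return None
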
@@ -33,17 +55,19 @@ class ReparamMessenger(Messenger):
     https://arxiv.org/pdf/1906.03028.pdf

     :param config: Configuration, either a dict mapping site name to
-        :class:`~pyro.infer.reparam.reparam.Reparameterizer` , or a function
+        :class:`~pyro.infer.reparam.reparam.Reparameterizer` , a function
         mapping site to :class:`~pyro.infer.reparam.reparam.Reparameterizer` or
-        None. See :mod:`pyro.infer.reparam.strategies` for built-in
-        configuration strategies.
+        None, or a :class:`BaseStrategy` . See
+        :mod:`pyro.infer.reparam.strategies` for built-in configuration
+        strategies.
     :type config: dict or callable
     """

-    def __init__(self, config: Union[Dict[str, object], Callable]):
+    def __init__(self, config: Union[Dict[str, object], Callable, BaseStrategy]):
         super().__init__()
-        assert isinstance(config, dict) or callable(config)
+        assert isinstance(config, (dict, BaseStrategy)) or callable(config)
         self.config = config
+        self._model = None
+        self._args_kwargs = None

     def __call__(self, fn):
@@ -54,6 +78,8 @@ def _pyro_sample(self, msg):
             return
         if isinstance(self.config, dict):
             reparam = self.config.get(msg["name"])
+        elif isinstance(self.config, BaseStrategy):
+            reparam = self.config.config_with_model(msg, self._model, *self._args_kwargs)
         else:
             reparam = self.config(msg)
         if reparam is None:
@@ -120,6 +146,17 @@ def _pyro_sample(self, msg):
         msg["value"] = new_msg["value"]
         msg["is_observed"] = new_msg["is_observed"]

+    @contextmanager
+    def call(self, model, model_args, model_kwargs):
+        # This saves model, args, kwargs for optional use by reparameterizers
+        # and applies this messenger while the wrapped model executes.
+        self._model = model
+        self._args_kwargs = model_args, model_kwargs
+        try:
+            with self:
+                yield self
+        finally:
+            self._model = None
+            self._args_kwargs = None


 class ReparamHandler(object):
     """
@@ -132,10 +169,5 @@ def __init__(self, msngr, fn):
         super().__init__()

     def __call__(self, *args, **kwargs):
-        # This saves args,kwargs for optional use by reparameterizers.
-        self.msngr._args_kwargs = args, kwargs
-        try:
-            with self.msngr:
-                return self.fn(*args, **kwargs)
-        finally:
-            self.msngr._args_kwargs = None
+        with self.msngr.call(self.fn, args, kwargs):
+            return self.fn(*args, **kwargs)
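
End to end, the handler now funnels the model through ``call()``, which is what lets a ``BaseStrategy`` see the model on first use. A sketch, assuming ``LocScaleReparam`` keeps its ``{name}_decentered`` auxiliary-site naming:

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.reparam.strategies import AutoReparam

def model():
    a = pyro.sample("a", dist.Normal(0.0, 1.0))
    pyro.sample("b", dist.Normal(a, 1.0))

# Equivalent to decorating: call() stashes the model on the messenger so
# config_with_model can run get_dependencies on first use.
reparam_model = poutine.reparam(model, config=AutoReparam())
trace = poutine.trace(reparam_model).get_trace()
assert "b_decentered" in trace.nodes  # auxiliary site added by LocScaleReparam
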