Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing pyro.infer.predictive.WeighedPredictive which reports weights along with predicted samples #3345

Merged
merged 10 commits into from
Mar 26, 2024
3 changes: 2 additions & 1 deletion pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from pyro.infer.predictive import Predictive
from pyro.infer.predictive import Predictive, WeighedPredictive
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.rws import ReweightedWakeSleep
from pyro.infer.smcfilter import SMCFilter
Expand Down Expand Up @@ -62,4 +62,5 @@
"TraceTailAdaptive_ELBO",
"Trace_ELBO",
"Trace_MMD",
"WeighedPredictive",
]
20 changes: 4 additions & 16 deletions pyro/infer/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .abstract_infer import TracePosterior
from .enum import get_importance_trace
from .util import plate_log_prob_sum


class Importance(TracePosterior):
Expand Down Expand Up @@ -143,22 +144,9 @@ def _fn(*args, **kwargs):
log_weights = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
else:
wd = guide_trace.plate_to_symbol["num_particles_vectorized"]
log_weights = 0.0
for site in model_trace.nodes.values():
if site["type"] != "sample":
continue
log_weights += torch.einsum(
site["packed"]["log_prob"]._pyro_dims + "->" + wd,
[site["packed"]["log_prob"]],
)

for site in guide_trace.nodes.values():
if site["type"] != "sample":
continue
log_weights -= torch.einsum(
site["packed"]["log_prob"]._pyro_dims + "->" + wd,
[site["packed"]["log_prob"]],
)
log_weights = plate_log_prob_sum(model_trace, wd) - plate_log_prob_sum(
guide_trace, wd
)

if normalized:
log_weights = log_weights - torch.logsumexp(log_weights)
Expand Down
203 changes: 168 additions & 35 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

import warnings
from functools import reduce
from typing import List, NamedTuple, Union

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.util import plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites


Expand All @@ -31,53 +34,58 @@ def _guess_max_plate_nesting(model, args, kwargs):
return max_plate_nesting


class _predictiveResults(NamedTuple):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the cleanup! This will also help us with #2550

"""
Return value of call to ``_predictive`` and ``_predictive_sequential``.
"""

samples: dict
trace: Union[Trace, List[Trace]]


def _predictive_sequential(
model,
posterior_samples,
model_args,
model_kwargs,
num_samples,
return_site_shapes,
return_trace=False,
model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes
):
collected = []
collected_samples = []
collected_trace = []
samples = [
{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)
]
for i in range(num_samples):
trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(
*model_args, **model_kwargs
)
if return_trace:
collected.append(trace)
else:
collected.append(
{site: trace.nodes[site]["value"] for site in return_site_shapes}
)
collected_trace.append(trace)
collected_samples.append(
{site: trace.nodes[site]["value"] for site in return_site_shapes}
)

if return_trace:
return collected
else:
return {
site: torch.stack([s[site] for s in collected]).reshape(shape)
return _predictiveResults(
trace=collected_trace,
samples={
site: torch.stack([s[site] for s in collected_samples]).reshape(shape)
for site, shape in return_site_shapes.items()
}
},
)


_predictive_vectorize_plate_name = "_num_predictive_samples"


def _predictive(
model,
posterior_samples,
num_samples,
return_sites=(),
return_trace=False,
parallel=False,
model_args=(),
model_kwargs={},
mask=True,
):
model = torch.no_grad()(poutine.mask(model, mask=False))
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
vectorize = pyro.plate(
"_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
_predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
)
model_trace = prune_subsample_sites(
poutine.trace(model).get_trace(*model_args, **model_kwargs)
Expand All @@ -93,12 +101,6 @@ def _predictive(
)
reshaped_samples[name] = sample

if return_trace:
trace = poutine.trace(
poutine.condition(vectorize(model), reshaped_samples)
).get_trace(*model_args, **model_kwargs)
return trace

return_site_shapes = {}
for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
Expand Down Expand Up @@ -131,7 +133,6 @@ def _predictive(
model_kwargs,
num_samples,
return_site_shapes,
return_trace=False,
)

trace = poutine.trace(
Expand All @@ -148,7 +149,7 @@ def _predictive(
else:
predictions[site] = value.reshape(shape)

return predictions
return _predictiveResults(trace=trace, samples=predictions)


class Predictive(torch.nn.Module):
Expand Down Expand Up @@ -269,7 +270,7 @@ def forward(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
).samples
return _predictive(
self.model,
posterior_samples,
Expand All @@ -278,7 +279,7 @@ def forward(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
).samples

def get_samples(self, *args, **kwargs):
warnings.warn(
Expand All @@ -304,12 +305,144 @@ def get_vectorized_trace(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
).samples
return _predictive(
self.model,
posterior_samples,
self.num_samples,
return_trace=True,
parallel=True,
model_args=args,
model_kwargs=kwargs,
).trace


class WeighedPredictiveResults(NamedTuple):
    """
    Return value of call to instance of :class:`WeighedPredictive`.

    Bundles the predicted samples with the per-sample log-probabilities
    evaluated under the guide and under the model, and their difference.
    """

    # Predicted values: a dict keyed by site name when produced by
    # ``WeighedPredictive.forward``, a tuple of values when produced by
    # ``WeighedPredictive.call``.
    samples: Union[dict, tuple]
    # Per-sample log importance weights, ``model_log_prob - guide_log_prob``.
    log_weights: torch.Tensor
    # Per-sample log-probability of the guide trace.
    guide_log_prob: torch.Tensor
    # Per-sample log-probability of the trace of the model that the guide
    # samples were replayed through (``model_guide``).
    model_log_prob: torch.Tensor


class WeighedPredictive(Predictive):
    """
    Class used to construct a weighed predictive distribution that is based
    on the same initialization interface as :class:`Predictive`.

    The methods `.forward` and `.call` can be called with an additional keyword argument
    ``model_guide`` which is the model used to create and optimize the guide (if not
    provided ``model_guide`` defaults to ``self.model``), and they return both samples and log_weights.

    The weights are calculated as the per sample gap between the model_guide log-probability
    and the guide log-probability (a guide must always be provided).

    A typical use case would be based on a ``model`` :math:`p(x,z)=p(x|z)p(z)` and ``guide`` :math:`q(z)`
    that has already been fitted to the model given observations :math:`p(X_{obs},z)`, both of which
    are provided at initialization of :class:`WeighedPredictive` (same as you would do with :class:`Predictive`).
    When calling an instance of :class:`WeighedPredictive` we provide the model given observations :math:`p(X_{obs},z)`
    as the keyword argument ``model_guide``.
    The resulting output would be the usual samples :math:`p(x|z)q(z)` returned by :class:`Predictive`,
    along with per sample weights :math:`p(X_{obs},z)/q(z)`. The samples and weights can be fed into
    :any:`weighed_quantile` in order to obtain the true quantiles of the resulting distribution.

    Note that the ``model`` can be more elaborate with sample sites :math:`y` that are not observed
    and are not part of the guide, if the samples sites :math:`y` are sampled after the observations
    and the latent variables sampled by the guide, such that :math:`p(x,y,z)=p(y|x,z)p(x|z)p(z)` where
    each element in the product represents a set of ``pyro.sample`` statements.
    """

    def call(self, *args, **kwargs):
        """
        Method `.call` that is backwards compatible with the same method found in :class:`Predictive`
        but can be called with an additional keyword argument `model_guide`
        which is the model used to create and optimize the guide.

        Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
        weights ``.log_weights``. Here ``.samples`` is a tuple of the sample values
        ordered by site name.
        """
        result = self.forward(*args, **kwargs)
        return WeighedPredictiveResults(
            # ``forward`` returns a NamedTuple, which has no ``.items()``;
            # the site-name -> value dict lives in its ``samples`` field.
            samples=tuple(v for _, v in sorted(result.samples.items())),
            log_weights=result.log_weights,
            guide_log_prob=result.guide_log_prob,
            model_log_prob=result.model_log_prob,
        )

    def forward(self, *args, **kwargs):
        """
        Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`
        but can be called with an additional keyword argument `model_guide`
        which is the model used to create and optimize the guide.

        Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
        weights ``.log_weights``. Here ``.samples`` is a dict keyed by site name.
        """
        # ``model_guide`` is consumed here so it is not forwarded to the
        # model / guide as a regular keyword argument.
        model_guide = kwargs.pop("model_guide", self.model)
        return_sites = self.return_sites
        # return all sites by default if a guide is provided.
        return_sites = None if not return_sites else return_sites
        # Step 1: draw latent samples from the guide and keep its trace so the
        # guide log-probability can be scored.
        guide_predictive = _predictive(
            self.guide,
            self.posterior_samples,
            self.num_samples,
            return_sites=None,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
            mask=False,
        )
        posterior_samples = guide_predictive.samples
        # Step 2: replay the guide samples through ``model_guide`` to obtain
        # the matching model trace.
        model_predictive = _predictive(
            model_guide,
            posterior_samples,
            self.num_samples,
            return_sites=return_sites,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
            mask=False,
        )
        # Step 3: per-sample log-probabilities. A single trace is reduced
        # along the predictive-samples plate; a list of traces (as returned by
        # ``_predictive_sequential``) is scored one trace at a time.
        if not isinstance(guide_predictive.trace, list):
            guide_trace = prune_subsample_sites(guide_predictive.trace)
            model_trace = prune_subsample_sites(model_predictive.trace)
            guide_trace.compute_score_parts()
            model_trace.compute_log_prob()
            guide_trace.pack_tensors()
            # Reuse the guide's plate symbols so both traces label the
            # predictive-samples plate with the same symbol.
            model_trace.pack_tensors(guide_trace.plate_to_symbol)
            plate_symbol = guide_trace.plate_to_symbol[_predictive_vectorize_plate_name]
            guide_log_prob = plate_log_prob_sum(guide_trace, plate_symbol)
            model_log_prob = plate_log_prob_sum(model_trace, plate_symbol)
        else:
            # NOTE(review): ``torch.Tensor(...)`` builds a default-dtype tensor
            # from the per-trace scalars — confirm this is acceptable when the
            # traces live on a non-default device or dtype.
            guide_log_prob = torch.Tensor(
                [
                    trace_element.log_prob_sum()
                    for trace_element in guide_predictive.trace
                ]
            )
            model_log_prob = torch.Tensor(
                [
                    trace_element.log_prob_sum()
                    for trace_element in model_predictive.trace
                ]
            )
        return WeighedPredictiveResults(
            # When a separate ``model_guide`` was used only for scoring, the
            # returned samples are still drawn from ``self.model``.
            samples=(
                _predictive(
                    self.model,
                    posterior_samples,
                    self.num_samples,
                    return_sites=return_sites,
                    parallel=self.parallel,
                    model_args=args,
                    model_kwargs=kwargs,
                ).samples
                if model_guide is not self.model
                else model_predictive.samples
            ),
            log_weights=model_log_prob - guide_log_prob,
            guide_log_prob=guide_log_prob,
            model_log_prob=model_log_prob,
        )
16 changes: 16 additions & 0 deletions pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pyro.ops import packed
from pyro.ops.einsum.adjoint import require_backward
from pyro.ops.rings import MarginalRing
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample

from .. import settings
Expand Down Expand Up @@ -342,3 +343,18 @@ def check_fully_reparametrized(guide_site):
raise NotImplementedError(
"All distributions in the guide must be fully reparameterized."
)


def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
    """
    Sum the packed log-probabilities of all sample sites in ``trace`` while
    keeping the dimension labelled ``plate_symbol``.

    :param trace: trace whose sample sites carry a ``"packed"`` entry holding
        a ``"log_prob"`` tensor annotated with ``_pyro_dims``.
    :param str plate_symbol: einsum symbol of the plate to keep; every other
        dimension of each site's log-probability is summed out.
    :returns: the accumulated log-probability indexed by the plate, or the
        float ``0.0`` when the trace contains no sample sites.
    """
    total = 0.0
    for node in trace.nodes.values():
        if node["type"] == "sample":
            packed_log_prob = node["packed"]["log_prob"]
            # "<site dims>-><plate>" reduces everything except the plate dim.
            equation = packed_log_prob._pyro_dims + "->" + plate_symbol
            total += torch.einsum(equation, [packed_log_prob])
    return total
19 changes: 11 additions & 8 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,17 @@ def weighed_quantile(
:param int dim: dimension to take quantiles from ``input``.
:returns torch.Tensor: quantiles of ``input`` at ``probs``.

Example:
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
**Example:**

.. doctest::

>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
"""
dim = dim if dim >= 0 else (len(input.shape) + dim)
if isinstance(probs, (list, tuple)):
Expand Down
Loading
Loading