Skip to content

Commit

Permalink
Introducing pyro.infer.predictive.WeighedPredictive which reports weights along with predicted samples (#3345)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel authored Mar 26, 2024
1 parent a79ba3a commit 0dc635f
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 80 deletions.
3 changes: 2 additions & 1 deletion pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from pyro.infer.predictive import Predictive
from pyro.infer.predictive import Predictive, WeighedPredictive
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.rws import ReweightedWakeSleep
from pyro.infer.smcfilter import SMCFilter
Expand Down Expand Up @@ -62,4 +62,5 @@
"TraceTailAdaptive_ELBO",
"Trace_ELBO",
"Trace_MMD",
"WeighedPredictive",
]
20 changes: 4 additions & 16 deletions pyro/infer/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .abstract_infer import TracePosterior
from .enum import get_importance_trace
from .util import plate_log_prob_sum


class Importance(TracePosterior):
Expand Down Expand Up @@ -143,22 +144,9 @@ def _fn(*args, **kwargs):
log_weights = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
else:
wd = guide_trace.plate_to_symbol["num_particles_vectorized"]
log_weights = 0.0
for site in model_trace.nodes.values():
if site["type"] != "sample":
continue
log_weights += torch.einsum(
site["packed"]["log_prob"]._pyro_dims + "->" + wd,
[site["packed"]["log_prob"]],
)

for site in guide_trace.nodes.values():
if site["type"] != "sample":
continue
log_weights -= torch.einsum(
site["packed"]["log_prob"]._pyro_dims + "->" + wd,
[site["packed"]["log_prob"]],
)
log_weights = plate_log_prob_sum(model_trace, wd) - plate_log_prob_sum(
guide_trace, wd
)

if normalized:
log_weights = log_weights - torch.logsumexp(log_weights)
Expand Down
203 changes: 168 additions & 35 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

import warnings
from functools import reduce
from typing import List, NamedTuple, Union

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.util import plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites


Expand All @@ -31,53 +34,58 @@ def _guess_max_plate_nesting(model, args, kwargs):
return max_plate_nesting


class _predictiveResults(NamedTuple):
    """
    Return value of call to ``_predictive`` and ``_predictive_sequential``.
    """

    # Mapping from site name to a tensor of predicted values for that site.
    samples: dict
    # A single vectorized trace (parallel path) or a list with one trace per
    # drawn sample (sequential path).
    trace: Union[Trace, List[Trace]]


def _predictive_sequential(
model,
posterior_samples,
model_args,
model_kwargs,
num_samples,
return_site_shapes,
return_trace=False,
model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes
):
collected = []
collected_samples = []
collected_trace = []
samples = [
{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)
]
for i in range(num_samples):
trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(
*model_args, **model_kwargs
)
if return_trace:
collected.append(trace)
else:
collected.append(
{site: trace.nodes[site]["value"] for site in return_site_shapes}
)
collected_trace.append(trace)
collected_samples.append(
{site: trace.nodes[site]["value"] for site in return_site_shapes}
)

if return_trace:
return collected
else:
return {
site: torch.stack([s[site] for s in collected]).reshape(shape)
return _predictiveResults(
trace=collected_trace,
samples={
site: torch.stack([s[site] for s in collected_samples]).reshape(shape)
for site, shape in return_site_shapes.items()
}
},
)


_predictive_vectorize_plate_name = "_num_predictive_samples"


def _predictive(
model,
posterior_samples,
num_samples,
return_sites=(),
return_trace=False,
parallel=False,
model_args=(),
model_kwargs={},
mask=True,
):
model = torch.no_grad()(poutine.mask(model, mask=False))
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
vectorize = pyro.plate(
"_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
_predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
)
model_trace = prune_subsample_sites(
poutine.trace(model).get_trace(*model_args, **model_kwargs)
Expand All @@ -93,12 +101,6 @@ def _predictive(
)
reshaped_samples[name] = sample

if return_trace:
trace = poutine.trace(
poutine.condition(vectorize(model), reshaped_samples)
).get_trace(*model_args, **model_kwargs)
return trace

return_site_shapes = {}
for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
Expand Down Expand Up @@ -131,7 +133,6 @@ def _predictive(
model_kwargs,
num_samples,
return_site_shapes,
return_trace=False,
)

trace = poutine.trace(
Expand All @@ -148,7 +149,7 @@ def _predictive(
else:
predictions[site] = value.reshape(shape)

return predictions
return _predictiveResults(trace=trace, samples=predictions)


class Predictive(torch.nn.Module):
Expand Down Expand Up @@ -269,7 +270,7 @@ def forward(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
).samples
return _predictive(
self.model,
posterior_samples,
Expand All @@ -278,7 +279,7 @@ def forward(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
).samples

def get_samples(self, *args, **kwargs):
warnings.warn(
Expand All @@ -304,12 +305,144 @@ def get_vectorized_trace(self, *args, **kwargs):
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
)
).samples
return _predictive(
self.model,
posterior_samples,
self.num_samples,
return_trace=True,
parallel=True,
model_args=args,
model_kwargs=kwargs,
).trace


class WeighedPredictiveResults(NamedTuple):
    """
    Return value of call to instance of :class:`WeighedPredictive`.
    """

    # Predicted samples: a dict of site name to tensor when returned by
    # ``forward``, or a tuple of tensors when returned by ``call``.
    samples: Union[dict, tuple]
    # Per-sample log importance weights, computed as
    # ``model_log_prob - guide_log_prob``.
    log_weights: torch.Tensor
    # Per-sample log-probability of the guide trace.
    guide_log_prob: torch.Tensor
    # Per-sample log-probability of the model trace.
    model_log_prob: torch.Tensor


class WeighedPredictive(Predictive):
    """
    Class used to construct a weighed predictive distribution that is based
    on the same initialization interface as :class:`Predictive`.

    The methods `.forward` and `.call` can be called with an additional keyword argument
    ``model_guide`` which is the model used to create and optimize the guide (if not
    provided ``model_guide`` defaults to ``self.model``), and they return both samples and log_weights.

    The weights are calculated as the per sample gap between the model_guide log-probability
    and the guide log-probability (a guide must always be provided).

    A typical use case would be based on a ``model`` :math:`p(x,z)=p(x|z)p(z)` and ``guide`` :math:`q(z)`
    that has already been fitted to the model given observations :math:`p(X_{obs},z)`, both of which
    are provided at initialization of :class:`WeighedPredictive` (same as you would do with :class:`Predictive`).
    When calling an instance of :class:`WeighedPredictive` we provide the model given observations :math:`p(X_{obs},z)`
    as the keyword argument ``model_guide``.
    The resulting output would be the usual samples :math:`p(x|z)q(z)` returned by :class:`Predictive`,
    along with per sample weights :math:`p(X_{obs},z)/q(z)`. The samples and weights can be fed into
    :any:`weighed_quantile` in order to obtain the true quantiles of the resulting distribution.

    Note that the ``model`` can be more elaborate with sample sites :math:`y` that are not observed
    and are not part of the guide, if the sample sites :math:`y` are sampled after the observations
    and the latent variables sampled by the guide, such that :math:`p(x,y,z)=p(y|x,z)p(x|z)p(z)` where
    each element in the product represents a set of ``pyro.sample`` statements.
    """

    def call(self, *args, **kwargs):
        """
        Method `.call` that is backwards compatible with the same method found in :class:`Predictive`
        but can be called with an additional keyword argument `model_guide`
        which is the model used to create and optimize the guide.

        Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` (a tuple of
        sample tensors ordered by site name, mirroring :meth:`Predictive.call`) and per sample
        weights ``.log_weights``.
        """
        result = self.forward(*args, **kwargs)
        # ``forward`` returns a WeighedPredictiveResults NamedTuple (not a dict),
        # so the site-name-sorted tuple must be built from ``result.samples``.
        return WeighedPredictiveResults(
            samples=tuple(v for _, v in sorted(result.samples.items())),
            log_weights=result.log_weights,
            guide_log_prob=result.guide_log_prob,
            model_log_prob=result.model_log_prob,
        )

    def forward(self, *args, **kwargs):
        """
        Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`
        but can be called with an additional keyword argument `model_guide`
        which is the model used to create and optimize the guide.

        Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
        weights ``.log_weights``.
        """
        model_guide = kwargs.pop("model_guide", self.model)
        return_sites = self.return_sites
        # return all sites by default if a guide is provided.
        return_sites = None if not return_sites else return_sites
        # Draw latent samples from the guide, keeping the full trace so that
        # the guide log-probability can be computed per sample.
        guide_predictive = _predictive(
            self.guide,
            self.posterior_samples,
            self.num_samples,
            return_sites=None,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
            mask=False,
        )
        posterior_samples = guide_predictive.samples
        # Replay the guide samples through ``model_guide`` to obtain the model
        # trace used for weighting.
        model_predictive = _predictive(
            model_guide,
            posterior_samples,
            self.num_samples,
            return_sites=return_sites,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
            mask=False,
        )
        if not isinstance(guide_predictive.trace, list):
            # Parallel (vectorized) path: a single trace indexed by the
            # predictive plate; sum log-probs per plate slot via einsum packing.
            guide_trace = prune_subsample_sites(guide_predictive.trace)
            model_trace = prune_subsample_sites(model_predictive.trace)
            guide_trace.compute_score_parts()
            model_trace.compute_log_prob()
            guide_trace.pack_tensors()
            # Pack the model trace with the guide's plate-to-symbol mapping so
            # both traces share the same einsum dimension symbols.
            model_trace.pack_tensors(guide_trace.plate_to_symbol)
            plate_symbol = guide_trace.plate_to_symbol[_predictive_vectorize_plate_name]
            guide_log_prob = plate_log_prob_sum(guide_trace, plate_symbol)
            model_log_prob = plate_log_prob_sum(model_trace, plate_symbol)
        else:
            # Sequential path: one trace per sample; collect scalar log-prob
            # sums into a vector of per-sample log-probabilities.
            guide_log_prob = torch.Tensor(
                [
                    trace_element.log_prob_sum()
                    for trace_element in guide_predictive.trace
                ]
            )
            model_log_prob = torch.Tensor(
                [
                    trace_element.log_prob_sum()
                    for trace_element in model_predictive.trace
                ]
            )
        return WeighedPredictiveResults(
            # If a distinct ``model_guide`` was supplied, the returned samples
            # must come from ``self.model``; otherwise reuse the model
            # predictive samples already computed above.
            samples=(
                _predictive(
                    self.model,
                    posterior_samples,
                    self.num_samples,
                    return_sites=return_sites,
                    parallel=self.parallel,
                    model_args=args,
                    model_kwargs=kwargs,
                ).samples
                if model_guide is not self.model
                else model_predictive.samples
            ),
            log_weights=model_log_prob - guide_log_prob,
            guide_log_prob=guide_log_prob,
            model_log_prob=model_log_prob,
        )
16 changes: 16 additions & 0 deletions pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pyro.ops import packed
from pyro.ops.einsum.adjoint import require_backward
from pyro.ops.rings import MarginalRing
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample

from .. import settings
Expand Down Expand Up @@ -342,3 +343,18 @@ def check_fully_reparametrized(guide_site):
raise NotImplementedError(
"All distributions in the guide must be fully reparameterized."
)


def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
    """
    Sum the log probabilities of every sample site in ``trace`` while keeping
    the dimension indexed by ``plate_symbol``, yielding one log-probability
    per plate slot.
    """
    per_site_sums = (
        torch.einsum(
            node["packed"]["log_prob"]._pyro_dims + "->" + plate_symbol,
            [node["packed"]["log_prob"]],
        )
        for node in trace.nodes.values()
        if node["type"] == "sample"
    )
    # Start from 0.0 so an empty trace yields the same value as the
    # original accumulator-based implementation.
    return sum(per_site_sums, 0.0)
19 changes: 11 additions & 8 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,17 @@ def weighed_quantile(
:param int dim: dimension to take quantiles from ``input``.
:returns torch.Tensor: quantiles of ``input`` at ``probs``.
Example:
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
**Example:**
.. doctest::
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
"""
dim = dim if dim >= 0 else (len(input.shape) + dim)
if isinstance(probs, (list, tuple)):
Expand Down
Loading

0 comments on commit 0dc635f

Please sign in to comment.