From 2c00882121c1d00c2c0ccf1693fcf1059e53fe87 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Thu, 21 Mar 2024 21:43:06 +0200
Subject: [PATCH 01/10] Make pyro.infer.predictive._predictive always return
 both the samples and trace.

---
 pyro/infer/predictive.py | 60 +++++++++++++++++-----------------
 1 file changed, 26 insertions(+), 34 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 9d8b1c7f76..582817060f 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -3,11 +3,13 @@
 
 import warnings
 from functools import reduce
+from typing import List, NamedTuple, Union
 
 import torch
 
 import pyro
 import pyro.poutine as poutine
+from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import prune_subsample_sites
 
 
@@ -31,16 +33,16 @@ def _guess_max_plate_nesting(model, args, kwargs):
     return max_plate_nesting
 
 
+class _predictiveResults(NamedTuple):
+    samples: dict
+    trace: Union[Trace, List[Trace]]
+
+
 def _predictive_sequential(
-    model,
-    posterior_samples,
-    model_args,
-    model_kwargs,
-    num_samples,
-    return_site_shapes,
-    return_trace=False,
+    model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes
 ):
-    collected = []
+    collected_samples = []
+    collected_trace = []
     samples = [
         {k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)
     ]
@@ -48,20 +50,18 @@ def _predictive_sequential(
         trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(
             *model_args, **model_kwargs
         )
-        if return_trace:
-            collected.append(trace)
-        else:
-            collected.append(
-                {site: trace.nodes[site]["value"] for site in return_site_shapes}
-            )
+        collected_trace.append(trace)
+        collected_samples.append(
+            {site: trace.nodes[site]["value"] for site in return_site_shapes}
+        )
 
-    if return_trace:
-        return collected
-    else:
-        return {
-            site: torch.stack([s[site] for s in collected]).reshape(shape)
+    return _predictiveResults(
+        trace=collected_trace,
+        samples={
+            site: torch.stack([s[site] for s in collected_samples]).reshape(shape)
             for site, shape in return_site_shapes.items()
-        }
+        },
+    )
 
 
 def _predictive(
@@ -69,7 +69,6 @@ def _predictive(
     posterior_samples,
     num_samples,
     return_sites=(),
-    return_trace=False,
     parallel=False,
     model_args=(),
     model_kwargs={},
@@ -93,12 +92,6 @@ def _predictive(
         )
         reshaped_samples[name] = sample
 
-    if return_trace:
-        trace = poutine.trace(
-            poutine.condition(vectorize(model), reshaped_samples)
-        ).get_trace(*model_args, **model_kwargs)
-        return trace
-
     return_site_shapes = {}
     for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
         append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
@@ -131,7 +124,6 @@ def _predictive(
             model_kwargs,
             num_samples,
             return_site_shapes,
-            return_trace=False,
         )
 
     trace = poutine.trace(
@@ -148,7 +140,7 @@ def _predictive(
         else:
             predictions[site] = value.reshape(shape)
 
-    return predictions
+    return _predictiveResults(trace=trace, samples=predictions)
 
 
 class Predictive(torch.nn.Module):
@@ -269,7 +261,7 @@ def forward(self, *args, **kwargs):
                 parallel=self.parallel,
                 model_args=args,
                 model_kwargs=kwargs,
-            )
+            ).samples
         return _predictive(
             self.model,
             posterior_samples,
@@ -278,7 +270,7 @@ def forward(self, *args, **kwargs):
             parallel=self.parallel,
             model_args=args,
             model_kwargs=kwargs,
-        )
+        ).samples
 
     def get_samples(self, *args, **kwargs):
         warnings.warn(
@@ -304,12 +296,12 @@ def get_vectorized_trace(self, *args, **kwargs):
                 parallel=self.parallel,
                 model_args=args,
                 model_kwargs=kwargs,
-            )
+            ).samples
         return _predictive(
             self.model,
             posterior_samples,
             self.num_samples,
-            return_trace=True,
+            parallel=self.parallel,
             model_args=args,
             model_kwargs=kwargs,
-        )
+        ).trace

From c6f953212491714a049f9004f6d39a67ab346be7 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 22 Mar 2024 11:39:13 +0200
Subject: [PATCH 02/10] Added pyro.infer.predictive.WeighedPredictive and some
 of its tests.

---
 pyro/infer/__init__.py         |   2 +-
 pyro/infer/predictive.py       | 100 ++++++++++++++++++++++++++++++++-
 pyro/infer/util.py             |  17 ++++++
 tests/infer/test_predictive.py |  50 ++++++++++++-----
 4 files changed, 152 insertions(+), 17 deletions(-)

diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index c0f3a26c3f..0c3efd8a12 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -12,7 +12,7 @@
 from pyro.infer.mcmc.hmc import HMC
 from pyro.infer.mcmc.nuts import NUTS
 from pyro.infer.mcmc.rwkernel import RandomWalkKernel
-from pyro.infer.predictive import Predictive
+from pyro.infer.predictive import Predictive, WeighedPredictive
 from pyro.infer.renyi_elbo import RenyiELBO
 from pyro.infer.rws import ReweightedWakeSleep
 from pyro.infer.smcfilter import SMCFilter
diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 582817060f..3e847ee324 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -9,6 +9,7 @@
 
 import pyro
 import pyro.poutine as poutine
+from pyro.infer.util import plate_log_prob_sum
 from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import prune_subsample_sites
 
@@ -64,6 +65,9 @@ def _predictive_sequential(
     )
 
 
+_predictive_vectorize_plate_name = "_num_predictive_samples"
+
+
 def _predictive(
     model,
     posterior_samples,
@@ -72,11 +76,12 @@ def _predictive(
     parallel=False,
     model_args=(),
     model_kwargs={},
+    mask=True
 ):
-    model = torch.no_grad()(poutine.mask(model, mask=False))
+    model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
     vectorize = pyro.plate(
-        "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
+        _predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
     )
     model_trace = prune_subsample_sites(
         poutine.trace(model).get_trace(*model_args, **model_kwargs)
@@ -305,3 +310,94 @@ def get_vectorized_trace(self, *args, **kwargs):
             model_args=args,
             model_kwargs=kwargs,
         ).trace
+
+
+def trace_log_prob(trace: Union[Trace, List[Trace]]) -> torch.Tensor:
+    if isinstance(trace, list):
+        return torch.Tensor([trace_element.log_prob_sum() for trace_element in trace])
+    else:
+        return plate_log_prob_sum(trace, _predictive_vectorize_plate_name)
+
+
+class WeighedPredictiveResults(NamedTuple):
+    samples: Union[dict, tuple]
+    log_weights: torch.Tensor
+    guide_prob: torch.Tensor
+    model_prob: torch.Tensor
+
+
+class WeighedPredictive(Predictive):
+    """
+    Class used to construct a weighed predictive distribution that is based
+    on the same initialization interface as :class:`Predictive`.
+
+    The methods `.forward` and `.call` must be called with an additional keyword argument
+    `model_guide` which is the model used to create and optimize the guide, and they return both samples and log_weights.
+
+    The weights are calculated as the per sample gap between the model_guide log-probability
+    and the guide log-probability (a guide must always be provided).
+    """
+
+    def call(self, *args, **kwargs):
+        """
+        Method `.call` that is backwards compatible with the same method found in :class:`Predictive`
+        but must be called with an additional keyword argument `model_guide`
+        which is the model used to create and optimize the guide.
+        """
+        result = self.forward(*args, **kwargs)
+        return WeighedPredictiveResults(
+            samples=tuple(v for _, v in sorted(result.samples.items())),
+            log_weights=result.log_weights,
+            guide_prob=result.guide_prob,
+            model_prob=result.model_prob
+        )
+
+    def forward(self, *args, **kwargs):
+        """
+        Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`.
+        but must be called with an additional keyword argument `model_guide`
+        which is the model used to create and optimize the guide.
+        """
+        model_guide = kwargs.pop('model_guide')
+        return_sites = self.return_sites
+        # return all sites by default if a guide is provided.
+        return_sites = None if not return_sites else return_sites
+        guide_predictive = _predictive(
+            self.guide,
+            self.posterior_samples,
+            self.num_samples,
+            return_sites=None,
+            parallel=self.parallel,
+            model_args=args,
+            model_kwargs=kwargs,
+            mask=False
+        )
+        posterior_samples = guide_predictive.samples
+        model_predictive = _predictive(
+            model_guide,
+            posterior_samples,
+            self.num_samples,
+            return_sites=return_sites,
+            parallel=self.parallel,
+            model_args=args,
+            model_kwargs=kwargs,
+            mask=False
+        )
+        if not isinstance(guide_predictive.trace, list):
+            guide_predictive.trace.compute_score_parts()
+            model_predictive.trace.compute_log_prob()
+            guide_predictive.trace.pack_tensors()
+            model_predictive.trace.pack_tensors(guide_predictive.trace.plate_to_symbol)
+        model_prob = trace_log_prob(model_predictive.trace)
+        guide_prob = trace_log_prob(guide_predictive.trace)
+        return WeighedPredictiveResults(
+            samples=_predictive(self.model,
+                                posterior_samples,
+                                self.num_samples,
+                                return_sites=return_sites,
+                                parallel=self.parallel,
+                                model_args=args,
+                                model_kwargs=kwargs).samples,
+            log_weights=model_prob - guide_prob,
+            guide_prob=guide_prob,
+            model_prob=model_prob)
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index 7ea460c1ec..fe2d8af6f7 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -14,6 +14,7 @@
 from pyro.ops import packed
 from pyro.ops.einsum.adjoint import require_backward
 from pyro.ops.rings import MarginalRing
+from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import site_is_subsample
 
 from .. import settings
@@ -342,3 +343,19 @@ def check_fully_reparametrized(guide_site):
         raise NotImplementedError(
             "All distributions in the guide must be fully reparameterized."
         )
+
+
+def plate_log_prob_sum(trace: Trace, plate_name: str) -> torch.Tensor:
+    """
+    Get log probability sum from trace while keeping indexing over the specified plate.
+    """
+    wd = trace.plate_to_symbol[plate_name]
+    log_prob_sum = 0.0
+    for site in trace.nodes.values():
+        if site["type"] != "sample" or wd not in site["packed"]["log_prob"]._pyro_dims:
+            continue
+        log_prob_sum += torch.einsum(
+            site["packed"]["log_prob"]._pyro_dims + "->" + wd,
+            [site["packed"]["log_prob"]],
+        )
+    return log_prob_sum
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index fc6f63fa37..8458cb390a 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -8,7 +8,7 @@
 import pyro.distributions as dist
 import pyro.optim as optim
 import pyro.poutine as poutine
-from pyro.infer import SVI, Predictive, Trace_ELBO
+from pyro.infer import SVI, Predictive, WeighedPredictive, Trace_ELBO
 from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal
 from tests.common import assert_close
 
@@ -39,29 +39,38 @@ def beta_guide(num_trials):
     pyro.sample("phi", phi_posterior)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("parallel", [False, True])
-def test_posterior_predictive_svi_manual_guide(parallel):
+def test_posterior_predictive_svi_manual_guide(parallel, predictive):
     true_probs = torch.ones(5) * 0.7
-    num_trials = torch.ones(5) * 1000
+    num_trials = torch.ones(5) * 400  # Reduced to 400 from 1000 in order for guide optimization to converge
     num_success = dist.Binomial(num_trials, true_probs).sample()
     conditioned_model = poutine.condition(model, data={"obs": num_success})
     elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
-    svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=1.0)), elbo)
-    for i in range(1000):
+    svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=3.0)), elbo)
+    for i in range(5000):  # Increased to 5000 from 1000 in order for guide optimization to converge
         svi.step(num_trials)
-    posterior_predictive = Predictive(
+    posterior_predictive = predictive(
         model,
         guide=beta_guide,
         num_samples=10000,
         parallel=parallel,
        return_sites=["_RETURN"],
     )
-    marginal_return_vals = posterior_predictive(num_trials)["_RETURN"]
-    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
+    if predictive is Predictive:
+        marginal_return_vals = posterior_predictive(num_trials)["_RETURN"]
+    else:
+        weighed_samples = posterior_predictive(num_trials, model_guide=conditioned_model)
+        marginal_return_vals = weighed_samples.samples["_RETURN"]
+        assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
+        # Weights should be uniform as the guide has the same distribution as the model
+        assert weighed_samples.log_weights.std() < 0.6
+    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("parallel", [False, True])
-def test_posterior_predictive_svi_auto_delta_guide(parallel):
+def test_posterior_predictive_svi_auto_delta_guide(parallel, predictive):
     true_probs = torch.ones(5) * 0.7
     num_trials = torch.ones(5) * 1000
     num_success = dist.Binomial(num_trials, true_probs).sample()
@@ -70,15 +79,21 @@ def test_posterior_predictive_svi_auto_delta_guide(parallel):
     svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=1.0)), Trace_ELBO())
     for i in range(1000):
         svi.step(num_trials)
-    posterior_predictive = Predictive(
+    posterior_predictive = predictive(
         model, guide=guide, num_samples=10000, parallel=parallel
     )
-    marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+    if predictive is Predictive:
+        marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+    else:
+        weighed_samples = posterior_predictive.get_samples(num_trials, model_guide=conditioned_model)
+        marginal_return_vals = weighed_samples.samples["obs"]
+        assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("return_trace", [False, True])
-def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
+def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace, predictive):
     true_probs = torch.ones(5) * 0.7
     num_trials = torch.ones(5) * 1000
     num_success = dist.Binomial(num_trials, true_probs).sample()
@@ -87,7 +102,7 @@ def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
     svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
     for i in range(1000):
         svi.step(num_trials)
-    posterior_predictive = Predictive(
+    posterior_predictive = predictive(
         model, guide=guide, num_samples=10000, parallel=True
     )
     if return_trace:
@@ -95,7 +110,12 @@ def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
             num_trials
         ).nodes["obs"]["value"]
     else:
-        marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+        if predictive is Predictive:
+            marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+        else:
+            weighed_samples = posterior_predictive.get_samples(num_trials, model_guide=conditioned_model)
+            marginal_return_vals = weighed_samples.samples["obs"]
+            assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
 
@@ -213,3 +233,5 @@ def guide():
     assert "guide-always" in called
     assert "model-sometimes" not in called
     assert "guide-sometimes" not in called
+
+test_posterior_predictive_svi_manual_guide(True, WeighedPredictive)

From 89955e018751550cfe4ab39b56a1b92c7b65fab2 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 22 Mar 2024 15:02:34 +0200
Subject: [PATCH 03/10] Make model_guide in call to WeighedPredictive optional.

---
 pyro/infer/predictive.py       | 13 +++++++------
 tests/infer/test_predictive.py | 16 +++++++++-------
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 3e847ee324..c22675e0ee 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -331,8 +332,9 @@ class WeighedPredictive(Predictive):
     Class used to construct a weighed predictive distribution that is based
     on the same initialization interface as :class:`Predictive`.
 
-    The methods `.forward` and `.call` must be called with an additional keyword argument
+    The methods `.forward` and `.call` can be called with an additional keyword argument
     `model_guide` which is the model used to create and optimize the guide, and they return both samples and log_weights.
+    If not provided `model_guide` defaults to `self.model`.
 
     The weights are calculated as the per sample gap between the model_guide log-probability
     and the guide log-probability (a guide must always be provided).
@@ -341,7 +342,7 @@ def call(self, *args, **kwargs):
         """
         Method `.call` that is backwards compatible with the same method found in :class:`Predictive`
-        but must be called with an additional keyword argument `model_guide`
+        but can be called with an additional keyword argument `model_guide`
         which is the model used to create and optimize the guide.
         """
         result = self.forward(*args, **kwargs)
@@ -354,11 +355,11 @@ def call(self, *args, **kwargs):
 
     def forward(self, *args, **kwargs):
         """
-        Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`.
-        but must be called with an additional keyword argument `model_guide`
+        Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`
+        but can be called with an additional keyword argument `model_guide`
         which is the model used to create and optimize the guide.
         """
-        model_guide = kwargs.pop('model_guide')
+        model_guide = kwargs.pop('model_guide', self.model)
         return_sites = self.return_sites
         # return all sites by default if a guide is provided.
         return_sites = None if not return_sites else return_sites
@@ -397,7 +398,7 @@ def forward(self, *args, **kwargs):
                                 return_sites=return_sites,
                                 parallel=self.parallel,
                                 model_args=args,
-                                model_kwargs=kwargs).samples,
+                                model_kwargs=kwargs).samples if model_guide is not self.model else model_predictive.samples,
             log_weights=model_prob - guide_prob,
             guide_prob=guide_prob,
             model_prob=model_prob)
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index 8458cb390a..90f9741cda 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -133,8 +133,9 @@ def test_posterior_predictive_svi_one_hot():
     assert_close(marginal_return_vals.mean(dim=0), true_probs.unsqueeze(0), rtol=0.1)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("parallel", [False, True])
-def test_shapes(parallel):
+def test_shapes(parallel, predictive):
     num_samples = 10
 
     def model():
@@ -152,14 +153,17 @@ def model():
     expected = poutine.replay(vectorize(model), trace)()
 
     # Use Predictive.
-    predictive = Predictive(
+    actual = predictive(
         model,
         guide=guide,
         return_sites=["x", "y"],
         num_samples=num_samples,
         parallel=parallel,
-    )
-    actual = predictive()
+    )()
+    if predictive is WeighedPredictive:
+        assert actual.samples["x"].shape[:1] == actual.log_weights.shape
+        assert actual.samples["y"].shape[:1] == actual.log_weights.shape
+        actual = actual.samples
     assert set(actual) == set(expected)
     assert actual["x"].shape == expected["x"].shape
     assert actual["y"].shape == expected["y"].shape
@@ -167,7 +171,7 @@ def model():
 
 @pytest.mark.parametrize("with_plate", [True, False])
 @pytest.mark.parametrize("event_shape", [(), (2,)])
-def test_deterministic(with_plate, event_shape):
+def test_deterministic(with_plate, event_shape, predictive):
     def model(y=None):
         with pyro.util.optional(pyro.plate("plate", 3), with_plate):
             x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
@@ -233,5 +237,3 @@ def guide():
     assert "guide-always" in called
     assert "model-sometimes" not in called
     assert "guide-sometimes" not in called
-
-test_posterior_predictive_svi_manual_guide(True, WeighedPredictive)

From 133bd742f2296c23f7d1cf53e179f52f4e66724d Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 22 Mar 2024 15:10:34 +0200
Subject: [PATCH 04/10] Add test for WeighedPredictive with plate and event
 shape.
---
 tests/infer/test_predictive.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index 90f9741cda..e4bb30f0cf 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -169,6 +169,7 @@ def model():
     assert actual["y"].shape == expected["y"].shape
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("with_plate", [True, False])
 @pytest.mark.parametrize("event_shape", [(), (2,)])
 def test_deterministic(with_plate, event_shape, predictive):
@@ -186,9 +187,13 @@ def model(y=None):
     for i in range(100):
         svi.step(y)
 
-    actual = Predictive(
+    actual = predictive(
         model, guide=guide, return_sites=["x2", "x3"], num_samples=1000
     )()
+    if predictive is WeighedPredictive:
+        assert actual.samples["x2"].shape[:1] == actual.log_weights.shape
+        assert actual.samples["x3"].shape[:1] == actual.log_weights.shape
+        actual = actual.samples
     x2_batch_shape = (3,) if with_plate else ()
     assert actual["x2"].shape == (1000,) + x2_batch_shape + event_shape
     # x3 shape is prepended 1 to match Pyro shape semantics

From a989e87d0c30a0aa2e0af9eefb4906c4c392f52b Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 22 Mar 2024 15:19:54 +0200
Subject: [PATCH 05/10] Linting and formatting updates associated with the
 introduction of WeighedPredictive.

---
 pyro/infer/__init__.py         |  1 +
 pyro/infer/predictive.py       | 41 ++++++++++++++++++++--------------
 tests/infer/test_predictive.py | 22 +++++++++++++-----
 3 files changed, 41 insertions(+), 23 deletions(-)

diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index 0c3efd8a12..6934bd29fe 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -62,4 +62,5 @@
     "TraceTailAdaptive_ELBO",
     "Trace_ELBO",
     "Trace_MMD",
+    "WeighedPredictive",
 ]
diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index c22675e0ee..ff0f77e378 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -76,7 +76,7 @@ def _predictive(
     parallel=False,
     model_args=(),
     model_kwargs={},
-    mask=True
+    mask=True,
 ):
     model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
@@ -330,10 +330,10 @@ class WeighedPredictive(Predictive):
     """
     Class used to construct a weighed predictive distribution that is based
     on the same initialization interface as :class:`Predictive`.
-
+
     The methods `.forward` and `.call` can be called with an additional keyword argument
-    `model_guide` which is the model used to create and optimize the guide, and they return both samples and log_weights.
-    If not provided `model_guide` defaults to `self.model`.
+    `model_guide` which is the model used to create and optimize the guide (if not
+    provided `model_guide` defaults to `self.model`), and they return both samples and log_weights.
 
     The weights are calculated as the per sample gap between the model_guide log-probability
     and the guide log-probability (a guide must always be provided).
@@ -350,7 +350,7 @@ def call(self, *args, **kwargs):
             samples=tuple(v for _, v in sorted(result.samples.items())),
             log_weights=result.log_weights,
             guide_prob=result.guide_prob,
-            model_prob=result.model_prob
+            model_prob=result.model_prob,
         )
 
     def forward(self, *args, **kwargs):
         """
         Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`
         but can be called with an additional keyword argument `model_guide`
         which is the model used to create and optimize the guide.
""" - model_guide = kwargs.pop('model_guide', self.model) + model_guide = kwargs.pop("model_guide", self.model) return_sites = self.return_sites # return all sites by default if a guide is provided. return_sites = None if not return_sites else return_sites @@ -371,7 +371,7 @@ def forward(self, *args, **kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, - mask=False + mask=False, ) posterior_samples = guide_predictive.samples model_predictive = _predictive( @@ -382,23 +382,30 @@ def forward(self, *args, **kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, - mask=False + mask=False, ) if not isinstance(guide_predictive.trace, list): guide_predictive.trace.compute_score_parts() model_predictive.trace.compute_log_prob() guide_predictive.trace.pack_tensors() model_predictive.trace.pack_tensors(guide_predictive.trace.plate_to_symbol) - model_prob = trace_log_prob(model_predictive.trace) + model_prob = trace_log_prob(model_predictive.trace) guide_prob = trace_log_prob(guide_predictive.trace) return WeighedPredictiveResults( - samples=_predictive(self.model, - posterior_samples, - self.num_samples, - return_sites=return_sites, - parallel=self.parallel, - model_args=args, - model_kwargs=kwargs).samples if model_guide is not self.model else model_predictive.samples, + samples=( + _predictive( + self.model, + posterior_samples, + self.num_samples, + return_sites=return_sites, + parallel=self.parallel, + model_args=args, + model_kwargs=kwargs, + ).samples + if model_guide is not self.model + else model_predictive.samples + ), log_weights=model_prob - guide_prob, guide_prob=guide_prob, - model_prob=model_prob) + model_prob=model_prob, + ) diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index e4bb30f0cf..1f28e1f05c 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -8,7 +8,7 @@ import pyro.distributions as dist import pyro.optim as optim import pyro.poutine as poutine -from pyro.infer import SVI, Predictive, WeighedPredictive, Trace_ELBO +from pyro.infer import SVI, Predictive, Trace_ELBO, WeighedPredictive from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal from tests.common import assert_close @@ -43,12 +43,16 @@ def beta_guide(num_trials): @pytest.mark.parametrize("parallel", [False, True]) def test_posterior_predictive_svi_manual_guide(parallel, predictive): true_probs = torch.ones(5) * 0.7 - num_trials = torch.ones(5) * 400 # Reduced to 400 from 1000 in order for guide optimization to converge + num_trials = ( + torch.ones(5) * 400 + ) # Reduced to 400 from 1000 in order for guide optimization to converge num_success = dist.Binomial(num_trials, true_probs).sample() conditioned_model = poutine.condition(model, data={"obs": num_success}) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=3.0)), elbo) - for i in range(5000): # Increased to 5000 from 1000 in order for guide optimization to converge + for i in range( + 5000 + ): # Increased to 5000 from 1000 in order for guide optimization to converge svi.step(num_trials) posterior_predictive = predictive( model, @@ -60,7 +64,9 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): if predictive is Predictive: marginal_return_vals = posterior_predictive(num_trials)["_RETURN"] else: - weighed_samples = posterior_predictive(num_trials, model_guide=conditioned_model) + weighed_samples = posterior_predictive( + num_trials, model_guide=conditioned_model + ) 
         marginal_return_vals = weighed_samples.samples["_RETURN"]
         assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
         # Weights should be uniform as the guide has the same distribution as the model
@@ -85,7 +91,9 @@ def test_posterior_predictive_svi_auto_delta_guide(parallel, predictive):
     if predictive is Predictive:
         marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
     else:
-        weighed_samples = posterior_predictive.get_samples(num_trials, model_guide=conditioned_model)
+        weighed_samples = posterior_predictive.get_samples(
+            num_trials, model_guide=conditioned_model
+        )
         marginal_return_vals = weighed_samples.samples["obs"]
         assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
@@ -113,7 +121,9 @@ def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace, predictiv
     if predictive is Predictive:
         marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
     else:
-        weighed_samples = posterior_predictive.get_samples(num_trials, model_guide=conditioned_model)
+        weighed_samples = posterior_predictive.get_samples(
+            num_trials, model_guide=conditioned_model
+        )
         marginal_return_vals = weighed_samples.samples["obs"]
         assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)

From 7426641cacacbbd6613129a24b0ee872107b2695 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 22 Mar 2024 16:41:33 +0200
Subject: [PATCH 06/10] Fix naming from probability to log-probability.

---
 pyro/infer/predictive.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index ff0f77e378..648fbd8855 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -322,8 +322,8 @@ class WeighedPredictiveResults(NamedTuple):
     samples: Union[dict, tuple]
     log_weights: torch.Tensor
-    guide_prob: torch.Tensor
-    model_prob: torch.Tensor
+    guide_log_prob: torch.Tensor
+    model_log_prob: torch.Tensor
 
 
 class WeighedPredictive(Predictive):
@@ -349,8 +349,8 @@ def call(self, *args, **kwargs):
         return WeighedPredictiveResults(
             samples=tuple(v for _, v in sorted(result.samples.items())),
             log_weights=result.log_weights,
-            guide_prob=result.guide_prob,
-            model_prob=result.model_prob,
+            guide_log_prob=result.guide_log_prob,
+            model_log_prob=result.model_log_prob,
         )
 
     def forward(self, *args, **kwargs):
@@ -389,8 +389,8 @@ def forward(self, *args, **kwargs):
             model_predictive.trace.compute_log_prob()
             guide_predictive.trace.pack_tensors()
             model_predictive.trace.pack_tensors(guide_predictive.trace.plate_to_symbol)
-        model_prob = trace_log_prob(model_predictive.trace)
-        guide_prob = trace_log_prob(guide_predictive.trace)
+        model_log_prob = trace_log_prob(model_predictive.trace)
+        guide_log_prob = trace_log_prob(guide_predictive.trace)
         return WeighedPredictiveResults(
             samples=(
                 _predictive(
@@ -405,7 +405,7 @@ def forward(self, *args, **kwargs):
                 if model_guide is not self.model
                 else model_predictive.samples
             ),
-            log_weights=model_prob - guide_prob,
-            guide_prob=guide_prob,
-            model_prob=model_prob,
+            log_weights=model_log_prob - guide_log_prob,
+            guide_log_prob=guide_log_prob,
+            model_log_prob=model_log_prob,
         )

From a4c4d8585e19c9736288db6326864609128d8a3f Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 22 Mar 2024 18:16:36 +0200
Subject: [PATCH 07/10] Fix backwards compatibility.

---
 pyro/infer/predictive.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 648fbd8855..c5df545921 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -306,7 +306,7 @@ def get_vectorized_trace(self, *args, **kwargs):
             self.model,
             posterior_samples,
             self.num_samples,
-            parallel=self.parallel,
+            parallel=True,
             model_args=args,
             model_kwargs=kwargs,
         ).trace

From 026cae27510a0be08bf01d3e416e380a2bddbf3e Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Mon, 25 Mar 2024 20:49:16 +0200
Subject: [PATCH 08/10] Create shared machinery between
 pyro.infer.WeighedPredictive and pyro.infer.Importance.

---
 pyro/infer/importance.py | 20 ++++----------------
 pyro/infer/predictive.py | 35 ++++++++++++++++++++-------------
 pyro/infer/util.py       |  7 +++----
 3 files changed, 29 insertions(+), 33 deletions(-)

diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py
index d7c25a843d..d25cf16680 100644
--- a/pyro/infer/importance.py
+++ b/pyro/infer/importance.py
@@ -12,6 +12,7 @@
 
 from .abstract_infer import TracePosterior
 from .enum import get_importance_trace
+from .util import plate_log_prob_sum
 
 
 class Importance(TracePosterior):
@@ -143,22 +144,9 @@ def _fn(*args, **kwargs):
             log_weights = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
         else:
             wd = guide_trace.plate_to_symbol["num_particles_vectorized"]
-            log_weights = 0.0
-            for site in model_trace.nodes.values():
-                if site["type"] != "sample":
-                    continue
-                log_weights += torch.einsum(
-                    site["packed"]["log_prob"]._pyro_dims + "->" + wd,
-                    [site["packed"]["log_prob"]],
-                )
-
-            for site in guide_trace.nodes.values():
-                if site["type"] != "sample":
-                    continue
-                log_weights -= torch.einsum(
-                    site["packed"]["log_prob"]._pyro_dims + "->" + wd,
-                    [site["packed"]["log_prob"]],
-                )
+            log_weights = plate_log_prob_sum(model_trace, wd) - plate_log_prob_sum(
+                guide_trace, wd
+            )
 
     if normalized:
         log_weights = log_weights - torch.logsumexp(log_weights)
diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index c5df545921..905c9ca4f0 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -312,13 +312,6 @@ def get_vectorized_trace(self, *args, **kwargs):
         ).trace
 
 
-def trace_log_prob(trace: Union[Trace, List[Trace]]) -> torch.Tensor:
-    if isinstance(trace, list):
-        return torch.Tensor([trace_element.log_prob_sum() for trace_element in trace])
-    else:
-        return plate_log_prob_sum(trace, _predictive_vectorize_plate_name)
-
-
 class WeighedPredictiveResults(NamedTuple):
     samples: Union[dict, tuple]
     log_weights: torch.Tensor
@@ -385,12 +378,28 @@ def forward(self, *args, **kwargs):
             mask=False,
         )
         if not isinstance(guide_predictive.trace, list):
-            guide_predictive.trace.compute_score_parts()
-            model_predictive.trace.compute_log_prob()
-            guide_predictive.trace.pack_tensors()
-            model_predictive.trace.pack_tensors(guide_predictive.trace.plate_to_symbol)
-            model_log_prob = trace_log_prob(model_predictive.trace)
-            guide_log_prob = trace_log_prob(guide_predictive.trace)
+            guide_trace = prune_subsample_sites(guide_predictive.trace)
+            model_trace = prune_subsample_sites(model_predictive.trace)
+            guide_trace.compute_score_parts()
+            model_trace.compute_log_prob()
+            guide_trace.pack_tensors()
+            model_trace.pack_tensors(guide_trace.plate_to_symbol)
+            plate_symbol = guide_trace.plate_to_symbol[_predictive_vectorize_plate_name]
+            guide_log_prob = plate_log_prob_sum(guide_trace, plate_symbol)
+            model_log_prob = plate_log_prob_sum(model_trace, plate_symbol)
+        else:
+            guide_log_prob = torch.Tensor(
+                [
+                    trace_element.log_prob_sum()
+                    for trace_element in guide_predictive.trace
+                ]
+            )
+            model_log_prob = torch.Tensor(
+                [
+                    trace_element.log_prob_sum()
+                    for trace_element in model_predictive.trace
+                ]
+            )
         return WeighedPredictiveResults(
             samples=(
                 _predictive(
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index fe2d8af6f7..13e1d9e12f 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -345,17 +345,16 @@ def check_fully_reparametrized(guide_site):
     )
 
 
-def plate_log_prob_sum(trace: Trace, plate_name: str) -> torch.Tensor:
+def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
     """
     Get log probability sum from trace while keeping indexing over the specified plate.
     """
-    wd = trace.plate_to_symbol[plate_name]
     log_prob_sum = 0.0
     for site in trace.nodes.values():
-        if site["type"] != "sample" or wd not in site["packed"]["log_prob"]._pyro_dims:
+        if site["type"] != "sample":
             continue
         log_prob_sum += torch.einsum(
-            site["packed"]["log_prob"]._pyro_dims + "->" + wd,
+            site["packed"]["log_prob"]._pyro_dims + "->" + plate_symbol,
             [site["packed"]["log_prob"]],
         )
     return log_prob_sum

From 260e52869b402c30559e569f37876f2aff3d1d5b Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Mon, 25 Mar 2024 23:20:13 +0200
Subject: [PATCH 09/10] Elaborate mathematical details of
 pyro.infer.predictive.WeighedPredictive.

---
 pyro/infer/predictive.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 905c9ca4f0..312e9295c4 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -325,11 +325,25 @@ class WeighedPredictive(Predictive):
     on the same initialization interface as :class:`Predictive`.
 
     The methods `.forward` and `.call` can be called with an additional keyword argument
-    `model_guide` which is the model used to create and optimize the guide (if not
-    provided `model_guide` defaults to `self.model`), and they return both samples and log_weights.
+    ``model_guide`` which is the model used to create and optimize the guide (if not
+    provided ``model_guide`` defaults to ``self.model``), and they return both samples and log_weights.
 
     The weights are calculated as the per sample gap between the model_guide log-probability
     and the guide log-probability (a guide must always be provided).
+
+    A typical use case would be based on a ``model`` :math:`p(x,z)=p(x|z)p(z)` and ``guide`` :math:`q(z)`
+    that has already been fitted to the model given observations :math:`p(X_{obs},z)`, both of which
+    are provided at initialization of :class:`WeighedPredictive` (same as you would do with :class:`Predictive`).
+    When calling an instance of :class:`WeighedPredictive` we provide the model given observations :math:`p(X_{obs},z)`
+    as the keyword argument ``model_guide``.
+    The resulting output would be the usual samples :math:`p(x|z)q(z)` returned by :class:`Predictive`,
+    along with per sample weights :math:`p(X_{obs},z)/q(z)`. The samples and weights can be fed into
+    :any:`weighed_quantile` in order to obtain the true quantiles of the resulting distribution.
+
+    Note that the ``model`` can be more elaborate with sample sites :math:`y` that are not observed
+    and are not part of the guide, if the sample sites :math:`y` are sampled after the observations
+    and the latent variables sampled by the guide, such that :math:`p(x,y,z)=p(y|x,z)p(x|z)p(z)` where
+    each element in the product represents a set of ``pyro.sample`` statements.
     """

From 5030d41ec2f9b2a2d0b82957e1ad456e551a63fb Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Tue, 26 Mar 2024 12:31:52 +0200
Subject: [PATCH 10/10] Update and fix docs.

---
 pyro/infer/predictive.py | 14 ++++++++++++++
 pyro/ops/stats.py        | 19 +++++++++++--------
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 312e9295c4..6be8b5cb5f 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -35,6 +35,10 @@ def _guess_max_plate_nesting(model, args, kwargs):
 
 
 class _predictiveResults(NamedTuple):
+    """
+    Return value of call to ``_predictive`` and ``_predictive_sequential``.
+    """
+
     samples: dict
     trace: Union[Trace, List[Trace]]
 
@@ -313,6 +317,10 @@ def get_vectorized_trace(self, *args, **kwargs):
 
 
 class WeighedPredictiveResults(NamedTuple):
+    """
+    Return value of call to instance of :class:`WeighedPredictive`.
+    """
+
     samples: Union[dict, tuple]
     log_weights: torch.Tensor
     guide_log_prob: torch.Tensor
@@ -351,6 +359,9 @@ def call(self, *args, **kwargs):
         Method `.call` that is backwards compatible with the same method found in :class:`Predictive`
         but can be called with an additional keyword argument `model_guide`
         which is the model used to create and optimize the guide.
+
+        Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
+        weights ``.log_weights``.
         """
         result = self.forward(*args, **kwargs)
         return WeighedPredictiveResults(
@@ -365,6 +376,9 @@ def forward(self, *args, **kwargs):
         Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`
         but can be called with an additional keyword argument `model_guide`
         which is the model used to create and optimize the guide.
+
+        Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
+        weights ``.log_weights``.
         """
         model_guide = kwargs.pop("model_guide", self.model)
         return_sites = self.return_sites
         # return all sites by default if a guide is provided.
diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py
index 2ec57d4784..8e0bd2631f 100644
--- a/pyro/ops/stats.py
+++ b/pyro/ops/stats.py
@@ -277,14 +277,17 @@ def weighed_quantile(
     :param int dim: dimension to take quantiles from ``input``.
     :returns torch.Tensor: quantiles of ``input`` at ``probs``.
 
-    Example:
-    >>> from pyro.ops.stats import weighed_quantile
-    >>> import torch
-    >>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
-    >>> probs = torch.Tensor([0.2, 0.8])
-    >>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
-    >>> result = weighed_quantile(input, probs, log_weights, -1)
-    >>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
+    **Example:**
+
+    .. doctest::
+
+       >>> from pyro.ops.stats import weighed_quantile
+       >>> import torch
+       >>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
+       >>> probs = torch.Tensor([0.2, 0.8])
+       >>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
+       >>> result = weighed_quantile(input, probs, log_weights, -1)
+       >>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
     """
     dim = dim if dim >= 0 else (len(input.shape) + dim)
     if isinstance(probs, (list, tuple)):
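Usage note: the sketch below is a minimal end-to-end illustration of the weighed predictive workflow that this patch series introduces, pairing the per sample log_weights returned by WeighedPredictive with pyro.ops.stats.weighed_quantile as the docstring of patch 09 suggests. It is not code from the patches: the toy Normal model, the synthetic data, the optimizer settings, and the quantile values are all assumptions made purely for demonstration.

# Illustrative sketch only; the model, guide, and data here are assumptions.
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, Trace_ELBO, WeighedPredictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.stats import weighed_quantile


def model():
    # Latent location with conditionally independent observations.
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", 100):
        return pyro.sample("obs", dist.Normal(loc, 1.0))


data = torch.randn(100) + 0.5
conditioned_model = poutine.condition(model, data={"obs": data})
guide = AutoDiagonalNormal(conditioned_model)
svi = SVI(conditioned_model, guide, pyro.optim.Adam({"lr": 0.05}), Trace_ELBO())
for _ in range(1000):
    svi.step()

# Same initialization interface as Predictive.
predictive = WeighedPredictive(
    model, guide=guide, num_samples=1000, return_sites=["_RETURN"], parallel=True
)
# Passing the model conditioned on the observations as model_guide yields per
# sample log weights log p(X_obs, z) - log q(z) alongside the usual samples.
result = predictive(model_guide=conditioned_model)
samples = result.samples["_RETURN"]
assert samples.shape[:1] == result.log_weights.shape
# Weighed quantiles of the predictive distribution over the sample dimension.
quantiles = weighed_quantile(samples, [0.05, 0.5, 0.95], result.log_weights, 0)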