Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO #2893

Draft
wants to merge 62 commits into
base: dev
Choose a base branch
from
Draft
Changes from 2 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
b00948c
trace_elbo
Jul 3, 2021
da2f887
lint
Jul 3, 2021
f6c95e4
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Jul 5, 2021
3ec076d
test_gradient
Jul 5, 2021
a22ff4e
copy traceenum_elbo and add test model with poisson dist
Jul 16, 2021
d551fa2
lint
Jul 16, 2021
b68bb3f
use constant funsor
Jul 21, 2021
bfb13bf
working version
Jul 28, 2021
ca1a1fe
pass second test
Jul 28, 2021
6d6a9ed
clean up trace_elbo
Jul 29, 2021
0f23b42
add another test
Aug 8, 2021
91384ed
lazy eval
Aug 20, 2021
c18a8bd
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Sep 18, 2021
34d9a3c
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Sep 30, 2021
b0182c0
vectorize particles; update tests
Sep 30, 2021
dc31767
minor fixes; pin to funsor@normalize-logaddexp
Sep 30, 2021
5c0fe75
update docs/requirements
Sep 30, 2021
2b15fe1
combine Trace_ELBO and TraceEnum_ELBO
Sep 30, 2021
351090b
eager evaluation
Oct 1, 2021
7d029c7
rm file
Oct 1, 2021
1bb7380
lazy
Oct 1, 2021
42ad4fa
remove memoize
Oct 1, 2021
5b6afdb
merge TraceEnum_ELBO
Oct 10, 2021
33628aa
skip test
Oct 11, 2021
18a973b
fixes
Oct 12, 2021
2c3ead3
convert Tensor to Categorical
Oct 12, 2021
5fb1522
restore docs/requirements.txt
Oct 12, 2021
f907f93
pin funsor in docs/requirements
Oct 12, 2021
902e445
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Oct 12, 2021
0042f85
use funsor.optimizer.apply_optimizer; higher precision in the test
Oct 12, 2021
ee5a5ad
pin funsor to the latest commit
Oct 12, 2021
e4c6760
optimize logzq
Oct 12, 2021
aba300a
optimize logzq
Oct 13, 2021
d823153
restore TraceEnum_ELBO
Oct 13, 2021
c06e9e4
revert hmm changes
Oct 13, 2021
eee297d
_tensor_to_categorical helper function
Oct 13, 2021
d748efa
lazy to_funsor
Oct 13, 2021
a1970d6
reduce over particle_var
Oct 13, 2021
4c1ee9e
address comment in tests
Oct 13, 2021
5df30c8
import pyroapi
Oct 13, 2021
46ff6f4
compute expected grads using dice factors
Oct 14, 2021
d7ee7ee
add test with guide enumeration
Oct 15, 2021
49553c3
add two more tests
Oct 15, 2021
835f815
pin funsor
Oct 15, 2021
760eeb0
lint
Oct 15, 2021
ab3831c
remove breakpoint
Oct 15, 2021
0b46f3a
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Oct 29, 2021
b6ff8e0
Approximate(ops.sample, ...) based approach
Nov 3, 2021
b5bece7
Importance funsor based approach
Nov 4, 2021
d6e246e
fixes
Nov 4, 2021
6582d7d
Merge branch 'dev' into fix-funsor-traceelbo
Apr 6, 2022
714fd62
fix funsor model enumeration
Apr 9, 2022
2d2210e
Merge branch 'fix-model-enumeration-funsor' into fix-funsor-traceelbo
Apr 9, 2022
29bad7a
use Sampled funsor
Apr 11, 2022
9144be1
fixes
Apr 11, 2022
e4c8a47
git fixes
Apr 11, 2022
c147ad9
Merge branch 'dev' into fix-funsor-traceelbo
Apr 11, 2022
703a2fa
use Provenance funsor
Apr 11, 2022
3137b1b
clean up
Apr 12, 2022
88713f6
fixes
May 5, 2022
99a0647
Merge branch 'dev' into fix-funsor-traceelbo
May 5, 2022
14131ad
use provenance
Jun 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions pyro/contrib/funsor/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,12 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
model_terms = terms_from_trace(model_tr)
guide_terms = terms_from_trace(guide_tr)

log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
log_factors = model_terms["log_factors"] + [
-f for f in guide_terms["log_factors"]
]
costs = model_terms["log_factors"] + [-f for f in guide_terms["log_factors"]]
plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]

elbo = funsor.Integrate(
sum(log_measures, to_funsor(0.0)),
sum(log_factors, to_funsor(0.0)),
measure_vars,
)
elbo = elbo.reduce(funsor.ops.add, plate_vars)

elbo = to_funsor(0.0)
for cost in costs:
elbo += cost.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this missing Dice factors included in log_measures? IIRC that was the reason for using Integrate.

Copy link
Member Author

@ordabayevy ordabayevy Jul 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the test from #2894 which has a simple model/guide pair. When running that model (Elbo=Trace_ELBO, backend=contrib.funsor, reparam-False) both guide_terms["log_measures"] and model_terms["log_measures"] are empty. I can't find Dice factors anywhere in model_terms or guide_terms.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess they're not included because Funsor.sample isn't used in the evaluation of Trace_ELBO. I don't think contrib.funsor.infer.Trace_ELBO is tested extensively outside the pyro-api tests in tests/contrib/funsor/test_pyroapi_funsor.py, which is why this wasn't noticed before.

A more general Funsor-based implementation of Trace_ELBO is certainly possible and would look very similar to the guide-side enumeration handling logic in TraceEnum_ELBO. We might even be able to write a custom "enumeration" strategy that just called Funsor.sample and reuse TraceEnum_ELBO as the Trace_ELBO implementation.

I believe a completely general version might require variable elimination logic beyond what's currently in funsor.sum_product handling cases where the guide had plate structure incompatible with the restrictions there, although I can't immediately think of existing tests or examples where that would be the case.


return -to_data(elbo)

Expand Down