-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
avoids setting jax tracer as lazy property attribute #1843
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this, @sbidari! Two thoughts:
- I think some aspects of the second test function have been unnecessarily carried over from the first test function and can be removed. See below.
- I wonder whether it might be worth annotating the test functions (via their docstrings) to explain why they should run without error if things are working and why/how they fail prior without the patch in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending reviews from @dylanhmorris. Thanks @sbidari!
Thanks @fehiepsi! @dylanhmorris I added docstring with links to where issues are described and combined the two test functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Co-authored-by: Dylan H. Morris <[email protected]>
I am not sure why the test is failing in CI with |
Noting here that when this lands it will also provide a more general fix for #1651 |
Move test from test_distributions.py to test_distributions_util.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @sbidari!
@fehiepsi Is there a release planned for this fix? We would appreciate having it for cdcgov/multisignal-epi-inference as soon as possible. |
I can make a patch release in upcoming days. In the mean time, you can patch the utility. |
@fehiepsi Do you have a date in mind? |
This commit applies the changes Subekshya Bidari authored in [Numpyro PR number 1843][pr] to fix using a TruncatedNormal distribution (or any other TwoSidedTruncatedDistribution) when running multiple chains in parallel using the NUTs MCMC sampler. [pr]: pyro-ppl#1843
Add a conditional if
numpyro.util.not_jax_tracer(value)
before settinglazy_property
as an attribute ofTruncatedDistribution
This resolves errors described in #1836 and CDCgov/PyRenew#282
tests added:
check valid sampling of
TruncatedDistribution
in parallel (failed previously as described in #1836 )check predictive methods (prior predictive and inference) can be run multiple times on the same model built using
TruncatedDistribution
(error encountered here CDCgov/PyRenew#282)