Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly stagger JIT until first call #505

Merged
merged 1 commit into from
Jun 16, 2023

Conversation

Lnaden
Copy link
Contributor

@Lnaden Lnaden commented Jun 15, 2023

Follow up to #504 and #496. The initial implementation didn't actually stagger setting the 64-bit jax until first call. This implementation staggers the JIT call until the function is actually used. This shouldn't break JAX's cache since the function object in question never changes so its hash wont change and we still get all the accelerated code.

@codecov
Copy link

codecov bot commented Jun 15, 2023

Codecov Report

Merging #505 (14a16b6) into master (a5fa114) will decrease coverage by 0.05%.
The diff coverage is 100.00%.

@Lnaden Lnaden merged commit 63084ad into choderalab:master Jun 16, 2023
@Lnaden Lnaden deleted the jax_stagger_jit branch June 16, 2023 13:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant