Skip to content

Commit

Permalink
Merge pull request #506 from Lnaden/jax_stagger_jit
Browse files Browse the repository at this point in the history
Wrap the preconditioned jit
  • Loading branch information
Lnaden authored Jun 16, 2023
2 parents 1aaa16d + 3598b01 commit cfe49fc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import warnings
from functools import wraps

import numpy as np

Expand Down Expand Up @@ -135,6 +136,9 @@ def jit_or_pass_after_bitsize(jitable_fn):
A function which can be jit'd
"""

@wraps(
jitable_fn
) # Helper to ensure the decorated function still registers for docs and inspection
def staggered_jit(*args, **kwargs):
# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
Expand Down
2 changes: 1 addition & 1 deletion pymbar/tests/test_mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def run_mbar_protocol(oscillator_bundle, protocol):
"CG",
"BFGS",
"Newton-CG",
"TNC",
pytest.param("TNC", marks=pytest.mark.flaky(max_runs=2)), # This one is flaky
"trust-ncg",
"trust-krylov",
"trust-exact",
Expand Down

0 comments on commit cfe49fc

Please sign in to comment.