diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index dd45f622..6c36e654 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -1,5 +1,6 @@ import logging import warnings +from functools import wraps import numpy as np @@ -135,6 +136,9 @@ def jit_or_pass_after_bitsize(jitable_fn): A function which can be jit'd """ + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection def staggered_jit(*args, **kwargs): # This will only trigger if JAX is set if use_jit and not config.x64_enabled: diff --git a/pymbar/tests/test_mbar_solvers.py b/pymbar/tests/test_mbar_solvers.py index 0ec9aee9..121eea43 100644 --- a/pymbar/tests/test_mbar_solvers.py +++ b/pymbar/tests/test_mbar_solvers.py @@ -64,7 +64,7 @@ def run_mbar_protocol(oscillator_bundle, protocol): "CG", "BFGS", "Newton-CG", - "TNC", + pytest.param("TNC", marks=pytest.mark.flaky(max_runs=2)), # This one is flaky "trust-ncg", "trust-krylov", "trust-exact",