From db5a27aa435d39c67b5895ad4057241635f43b32 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Fri, 16 Jun 2023 09:56:07 -0400 Subject: [PATCH 1/3] Wrap the preconditioned jit Use functools' wraps decorator inside the preconditioned/staggered JIT decorator to ensure doc and inspection of the decorated functions yeilds the correct function instead of the stagger wrapper. --- pymbar/mbar_solvers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index dd45f622..bd85861e 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -1,5 +1,6 @@ import logging import warnings +from functools import wraps import numpy as np @@ -135,6 +136,7 @@ def jit_or_pass_after_bitsize(jitable_fn): A function which can be jit'd """ + @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection def staggered_jit(*args, **kwargs): # This will only trigger if JAX is set if use_jit and not config.x64_enabled: From bc1a23cfc40f69703af2954c405c9e2648c8d92f Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Fri, 16 Jun 2023 10:51:30 -0400 Subject: [PATCH 2/3] black --- pymbar/mbar_solvers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index bd85861e..6c36e654 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -136,7 +136,9 @@ def jit_or_pass_after_bitsize(jitable_fn): A function which can be jit'd """ - @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection def staggered_jit(*args, **kwargs): # This will only trigger if JAX is set if use_jit and not config.x64_enabled: From 3598b011a1f23927726800b99674c39d09b1dfe4 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Fri, 16 Jun 2023 11:36:36 -0400 Subject: [PATCH 3/3] Mark the TNC test flaky as it is. --- pymbar/tests/test_mbar_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymbar/tests/test_mbar_solvers.py b/pymbar/tests/test_mbar_solvers.py index 0ec9aee9..121eea43 100644 --- a/pymbar/tests/test_mbar_solvers.py +++ b/pymbar/tests/test_mbar_solvers.py @@ -64,7 +64,7 @@ def run_mbar_protocol(oscillator_bundle, protocol): "CG", "BFGS", "Newton-CG", - "TNC", + pytest.param("TNC", marks=pytest.mark.flaky(max_runs=2)), # This one is flaky "trust-ncg", "trust-krylov", "trust-exact",