diff --git a/ferminet/mcmc.py b/ferminet/mcmc.py index 08da246..e002428 100644 --- a/ferminet/mcmc.py +++ b/ferminet/mcmc.py @@ -244,7 +244,6 @@ def make_mcmc_step(batch_network, """ inner_fun = mh_block_update if blocks > 1 else mh_update - @jax.jit def mcmc_step(params, data, key, width): """Performs a set of MCMC steps. diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 352c0aa..d5c07f0 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -25,7 +25,6 @@ from ferminet import train from ferminet.configs import atom from ferminet.configs import diatomic -import jax import pyscf FLAGS = flags.FLAGS @@ -95,9 +94,6 @@ def setUp(self): # Test calculations are small enough to fit in RAM and we don't need # checkpoint files. pyscf.lib.param.TMPDIR = None - # Prevents issues related to the mcmc step in pretraining if multiple - # training runs are executed in the same session. - jax.clear_caches() @parameterized.parameters(_config_params()) def test_training_step(