From 0d92cc67581befc261b731c97e07c15ec9ebe299 Mon Sep 17 00:00:00 2001 From: James Spencer Date: Tue, 25 Jun 2024 18:12:56 +0100 Subject: [PATCH] Remove jax.jit decorators on MCMC functions. These are compiled as part of a larger program anyway, and jitting at the function definition triggers incorrect jit cache hits. PiperOrigin-RevId: 646519975 Change-Id: Ie4d4dcb57c088fd6017186ebcc3c988eade52037 --- ferminet/mcmc.py | 1 - ferminet/tests/train_test.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/ferminet/mcmc.py b/ferminet/mcmc.py index 08da246..e002428 100644 --- a/ferminet/mcmc.py +++ b/ferminet/mcmc.py @@ -244,7 +244,6 @@ def make_mcmc_step(batch_network, """ inner_fun = mh_block_update if blocks > 1 else mh_update - @jax.jit def mcmc_step(params, data, key, width): """Performs a set of MCMC steps. diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 352c0aa..d5c07f0 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -25,7 +25,6 @@ from ferminet import train from ferminet.configs import atom from ferminet.configs import diatomic -import jax import pyscf FLAGS = flags.FLAGS @@ -95,9 +94,6 @@ def setUp(self): # Test calculations are small enough to fit in RAM and we don't need # checkpoint files. pyscf.lib.param.TMPDIR = None - # Prevents issues related to the mcmc step in pretraining if multiple # training runs are executed in the same session. - jax.clear_caches() @parameterized.parameters(_config_params()) def test_training_step(