Skip to content

Commit

Permalink
Remove jax.jit decorators on MCMC functions.
Browse files Browse the repository at this point in the history
These are compiled as part of a larger program anyway and in jitting at the
function definition triggers incorrect jit cache hits.

PiperOrigin-RevId: 646519975
Change-Id: Ie4d4dcb57c088fd6017186ebcc3c988eade52037
  • Loading branch information
jsspencer authored and dpfau committed Aug 22, 2024
1 parent 286e3f7 commit 0d92cc6
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 5 deletions.
1 change: 0 additions & 1 deletion ferminet/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def make_mcmc_step(batch_network,
"""
inner_fun = mh_block_update if blocks > 1 else mh_update

@jax.jit
def mcmc_step(params, data, key, width):
"""Performs a set of MCMC steps.
Expand Down
4 changes: 0 additions & 4 deletions ferminet/tests/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ferminet import train
from ferminet.configs import atom
from ferminet.configs import diatomic
import jax
import pyscf

FLAGS = flags.FLAGS
Expand Down Expand Up @@ -95,9 +94,6 @@ def setUp(self):
# Test calculations are small enough to fit in RAM and we don't need
# checkpoint files.
pyscf.lib.param.TMPDIR = None
# Prevents issues related to the mcmc step in pretraining if multiple
# training runs are executed in the same session.
jax.clear_caches()

@parameterized.parameters(_config_params())
def test_training_step(
Expand Down

0 comments on commit 0d92cc6

Please sign in to comment.