Skip to content

Commit

Permalink
Fix bug in copy_tree
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595951242
Change-Id: I7913d9037a5295494ca62bf314ca2e3dd4e8a008
  • Loading branch information
dpfau authored and jsspencer committed Apr 15, 2024
1 parent e94435b commit 8a09e34
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ferminet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def make_kfac_training_step(
# Due to some KFAC cleverness related to donated buffers, need to do this
# to make state resettable
copy_tree = constants.pmap(
functools.partial(jax.tree_util.tree_map, lambda x: 1.0 * x))
functools.partial(jax.tree_util.tree_map,
lambda x: (1.0 * x).astype(x.dtype)))

def step(
data: networks.FermiNetData,
Expand Down

0 comments on commit 8a09e34

Please sign in to comment.