diff --git a/ferminet/train.py b/ferminet/train.py index 996b9b3..63011d0 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -325,7 +325,8 @@ def make_kfac_training_step( # Due to some KFAC cleverness related to donated buffers, need to do this # to make state resettable copy_tree = constants.pmap( - functools.partial(jax.tree_util.tree_map, lambda x: 1.0 * x)) + functools.partial(jax.tree_util.tree_map, + lambda x: (1.0 * x).astype(x.dtype))) def step( data: networks.FermiNetData,