From 8a09e34124ddda8135a95689d5893439cbc863c9 Mon Sep 17 00:00:00 2001 From: David Pfau Date: Fri, 5 Jan 2024 11:37:11 +0000 Subject: [PATCH] Fix bug in copy_tree PiperOrigin-RevId: 595951242 Change-Id: I7913d9037a5295494ca62bf314ca2e3dd4e8a008 --- ferminet/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ferminet/train.py b/ferminet/train.py index 996b9b3..63011d0 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -325,7 +325,8 @@ def make_kfac_training_step( # Due to some KFAC cleverness related to donated buffers, need to do this # to make state resettable copy_tree = constants.pmap( - functools.partial(jax.tree_util.tree_map, lambda x: 1.0 * x)) + functools.partial(jax.tree_util.tree_map, + lambda x: (1.0 * x).astype(x.dtype))) def step( data: networks.FermiNetData,