diff --git a/ferminet/utils/multi_host.py b/ferminet/utils/multi_host.py index 110518b..dda8549 100644 --- a/ferminet/utils/multi_host.py +++ b/ferminet/utils/multi_host.py @@ -30,8 +30,8 @@ def check_synced(obj, name): True if object is in sync across all devices and False otherwise. """ for i in range(1, jax.local_device_count()): - norms = jax.tree_map(lambda x: jnp.linalg.norm(x[0] - x[i]), obj) # pylint: disable=cell-var-from-loop - total_norms = sum(jax.tree_leaves(norms)) + norms = jax.tree.map(lambda x: jnp.linalg.norm(x[0] - x[i]), obj) # pylint: disable=cell-var-from-loop + total_norms = sum(jax.tree.leaves(norms)) if total_norms != 0.0: logging.info( '%s object is not synced across device 0 and %d. The total norm'