Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634281277
Change-Id: I9a440f4625a239ede5982116a059a496a9a6b150
  • Loading branch information
Jake VanderPlas authored and jsspencer committed Jun 4, 2024
1 parent b02ea69 commit 3d453fa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ferminet/utils/multi_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def check_synced(obj, name):
True if object is in sync across all devices and False otherwise.
"""
for i in range(1, jax.local_device_count()):
norms = jax.tree_map(lambda x: jnp.linalg.norm(x[0] - x[i]), obj) # pylint: disable=cell-var-from-loop
total_norms = sum(jax.tree_leaves(norms))
norms = jax.tree.map(lambda x: jnp.linalg.norm(x[0] - x[i]), obj) # pylint: disable=cell-var-from-loop
total_norms = sum(jax.tree.leaves(norms))
if total_norms != 0.0:
logging.info(
'%s object is not synced across device 0 and %d. The total norm'
Expand Down

0 comments on commit 3d453fa

Please sign in to comment.