Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace deprecated
jax.tree_*
functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634281277 Change-Id: I9a440f4625a239ede5982116a059a496a9a6b150
- Loading branch information