diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 123c61e..75c5767 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -24,7 +24,7 @@ Scalar = kfac_jax.utils.Scalar Numeric = kfac_jax.utils.Numeric -vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0) +vmap_psd_inv = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0) vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)