Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update call to psd_inv_cholesky to fix breaking change #71

Merged
merged 9 commits into from
Nov 24, 2023
6 changes: 3 additions & 3 deletions ferminet/curvature_tags_and_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Scalar = kfac_jax.utils.Scalar
Numeric = kfac_jax.utils.Numeric

vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0)
vmap_psd_inv = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0)
vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)


Expand Down Expand Up @@ -141,8 +141,8 @@ def _update_cache(
self,
state: kfac_jax.TwoKroneckerFactored.State,
identity_weight: kfac_jax.utils.Numeric,
exact_powers: set[kfac_jax.utils.Scalar],
approx_powers: set[kfac_jax.utils.Scalar],
exact_powers: Set[kfac_jax.utils.Scalar],
approx_powers: Set[kfac_jax.utils.Scalar],
gcassella marked this conversation as resolved.
Show resolved Hide resolved
eigenvalues: bool,
) -> kfac_jax.TwoKroneckerFactored.State:
del eigenvalues
Expand Down
Loading