Skip to content

Commit

Permalink
kfac_jax:
Browse files Browse the repository at this point in the history
- Adding capability to switch between using Cholesky-based inversions and regular ones. The former are deterministic on GPUs and somewhat faster to compute, but tend to perform a bit worse and not tolerate very low damping values.
- Changing internal implementation of Kronecker factored blocks to prune out singleton dimensions instead of using special case logic for them pi_adjusted_kronecker_factors.
- Going back to using exp-mean-log formula for c_k since this will better avoid over/underflows for blocks with many factors
PiperOrigin-RevId: 578240057
Change-Id: I20ab5268599214a31f4206a9358b76d99434d6c8
  • Loading branch information
james-martens authored and jsspencer committed Nov 24, 2023
1 parent e9f8c64 commit 51fad8f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ferminet/curvature_tags_and_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Scalar = kfac_jax.utils.Scalar
Numeric = kfac_jax.utils.Numeric

vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0)
vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0)
vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)


Expand Down

0 comments on commit 51fad8f

Please sign in to comment.