- Removing the synchronization across devices of reductions computed in various spots. Reduction ops used to be nondeterministic on the GPU in XLA, but apparently aren't anymore.

- Removing the pmap_axis_name argument from the multiply_* and _update_curvature_matrix_estimate methods in the curvature blocks, since it's no longer used there. Also removing it from the _update_cache methods, since it hasn't been used there for a while either.

- Minor internal refactor of the optimizer and of pi_adjusted_kronecker_inverse.

- Updating patches_moments_explicit to propagate dtypes properly in one spot. (It's unclear if this matters.)

- Adding a 'per_device_stats_to_log' option to the example experiment class, which causes the listed statistics to be logged on a per-device basis (useful for checking that the devices remain properly synchronized; see the hedged sketch after the commit trailers below).

PiperOrigin-RevId: 520611599
Change-Id: I427b786a45e4b5c1c9f0d092bb25f6359502354e
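
As a rough illustration of the idea behind per-device stats logging: a statistic that is averaged across devices inside pmap should come out identical on every replica, so logging it per device exposes any drift. Below is a minimal, hypothetical sketch in JAX; the step function, axis name, and stat names are invented for illustration and are not the ferminet experiment API.

import functools
import jax
import jax.numpy as jnp

per_device_stats_to_log = ("loss",)  # illustrative only; not the ferminet config schema

@functools.partial(jax.pmap, axis_name="devices")
def step(x):
  loss = jnp.mean(x ** 2)                          # per-device reduction
  loss = jax.lax.pmean(loss, axis_name="devices")  # replicas should now agree
  return {"loss": loss}

n = jax.local_device_count()
xs = jnp.stack([jnp.arange(4.0) + i for i in range(n)])
stats = step(xs)
for name in per_device_stats_to_log:
  for d, v in enumerate(jax.device_get(stats[name])):
    # If the devices are properly synchronized, every device logs the same value.
    print(f"{name}/device_{d}: {v}")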
james-martens authored and jsspencer committed Apr 25, 2023
1 parent 314a657 commit f12a3ed
Showing 1 changed file with 9 additions and 8 deletions.
ferminet/curvature_tags_and_blocks.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 """Curvature blocks for FermiNet."""
-import functools
 from typing import Any, Mapping, Optional, Sequence, Set, Tuple
 import chex
 import jax
@@ -97,8 +96,13 @@ def update_curvature_matrix_estimate(
     inputs_cov = jnp.einsum("bijk,bijl->jkl", x, x) / normalizer
     dy = jnp.reshape(dy, dy.shape[:-2] + (-1,))
     outputs_cov = jnp.einsum("bijk,bijl->jkl", dy, dy) / normalizer
+
     state.inputs_factor.update(inputs_cov, ema_old, ema_new)
     state.outputs_factor.update(outputs_cov, ema_old, ema_new)
+
+    state.inputs_factor.sync(pmap_axis_name)
+    state.outputs_factor.sync(pmap_axis_name)
+
     return state
 
   def _init(
@@ -136,11 +140,9 @@ def _update_cache(
       exact_powers: chex.Numeric,
       approx_powers: chex.Numeric,
       eigenvalues: bool,
-      pmap_axis_name: Optional[str],
   ) -> kfac_jax.TwoKroneckerFactored.State:
     del eigenvalues
-    state.inputs_factor.sync(pmap_axis_name)
-    state.outputs_factor.sync(pmap_axis_name)
+
     if exact_powers:
       raise NotImplementedError(
           "Caching of exact powers is not yet implemented for QmcBlockedDense.")
@@ -150,14 +152,13 @@
               f"yet implemented.")
     cache = state.cache[str(power)]
     pi_adjusted_inverse = jax.vmap(
-        functools.partial(kfac_jax.utils.pi_adjusted_kronecker_inverse,
-                          pmap_axis_name=pmap_axis_name),
-        (0, 0, None), (0, 0)
+        kfac_jax.utils.pi_adjusted_kronecker_inverse,
+        (0, None), (0, 0)
     )
     cache["inputs_factor"], cache["outputs_factor"] = pi_adjusted_inverse(
         state.inputs_factor.value,
         state.outputs_factor.value,
-        identity_weight
+        damping=identity_weight,
     )
     return state
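
The jax.vmap call above maps the pi-adjusted inverse over a leading block dimension of the factors. Below is a self-contained toy sketch of that vmap-with-a-broadcast-argument pattern; the function is a stand-in, not kfac_jax's actual pi_adjusted_kronecker_inverse, and the axis choices are illustrative rather than a copy of the in_axes in the diff.

import jax
import jax.numpy as jnp

def damped_inverse_pair(a, b, damping):
  # Stand-in for a pi-adjusted Kronecker-factored inverse of the pair (a, b).
  a_inv = jnp.linalg.inv(a + damping * jnp.eye(a.shape[-1]))
  b_inv = jnp.linalg.inv(b + damping * jnp.eye(b.shape[-1]))
  return a_inv, b_inv

# Map over the leading block dimension of both factors; broadcast the damping.
batched = jax.vmap(damped_inverse_pair, in_axes=(0, 0, None), out_axes=(0, 0))

a = jnp.stack([jnp.eye(3)] * 4)    # [4, 3, 3] batch of input factors
b = jnp.stack([jnp.eye(2)] * 4)    # [4, 2, 2] batch of output factors
a_inv, b_inv = batched(a, b, 0.1)  # one damped inverse pair per block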

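For context on the sync calls added to update_curvature_matrix_estimate: they average each factor estimate across devices. Assuming sync boils down to a mean over the named pmap axis (an assumption about kfac_jax internals, not a verified quote of them), a minimal sketch looks like this:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name="kfac_axis")
def sync_factor(cov):
  # Per-device covariance estimates can differ (each device sees its own data
  # shard), so average them so that every replica holds the same factor.
  # Assumption: this mirrors what factor.sync(pmap_axis_name) does.
  return jax.lax.pmean(cov, axis_name="kfac_axis")

n = jax.local_device_count()
covs = jnp.stack([jnp.eye(2) * (i + 1.0) for i in range(n)])
synced = sync_factor(covs)  # identical on every device afterwards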
