Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Removing the synchronization across devices of reductions computed …
…in various spots. Reductions ops used to be nondeterministic on the GPU in XLA, but apparently not anymore. - Removing the pmap_axis_name argument in the multiply_* and _update_curvature_matrix_estimate methods in the curvature blocks since it's no longer used there. Also removing it from _update_cache methods since it's hasn't been used there for a while either. - Minor internal refactor of optimizer, and of pi_adjusted_kronecker_inverse. - Updating patches_moments_explicit to propagate dtypes properly in one spot. (It's unclear if this matters.) - Adding 'per_device_stats_to_log' option to example experiment class which cause the listed statistics to be logged on a per-device basis (which is useful for checking that the devices remain properly synchronized). PiperOrigin-RevId: 520611599 Change-Id: I427b786a45e4b5c1c9f0d092bb25f6359502354e
- Loading branch information