* Adding damping to the curvature update methods.
* Adding a few useful properties to the moving averages.
* Allowing `None` to be passed in the kf products, where it represents the identity.

PiperOrigin-RevId: 635511629
Change-Id: I018fd274f48a2a1c227c96c7983b49e9f5cfa812
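
As a rough illustration of what damping in a curvature update can look like, the sketch below applies an exponential moving average to a small matrix and then adds a multiple of the identity. The update form `ema_old * old + ema_new * new + identity_weight * I` is an assumption made for illustration and is not taken from this commit. `damped_ema_update` is a hypothetical helper, and plain Python lists stand in for arrays.

```python
from typing import List

Matrix = List[List[float]]


def damped_ema_update(
    old: Matrix,
    new: Matrix,
    ema_old: float,
    ema_new: float,
    identity_weight: float,
) -> Matrix:
  """Returns ema_old * old + ema_new * new + identity_weight * I (assumed form)."""
  n = len(old)
  return [
      [
          ema_old * old[i][j]
          + ema_new * new[i][j]
          # The identity only contributes on the diagonal.
          + (identity_weight if i == j else 0.0)
          for j in range(n)
      ]
      for i in range(n)
  ]


old_factor = [[2.0, 1.0], [1.0, 2.0]]
new_factor = [[4.0, 0.0], [0.0, 4.0]]
updated = damped_ema_update(old_factor, new_factor, 0.5, 0.5, 0.25)
undamped = damped_ema_update(old_factor, new_factor, 0.5, 0.5, 0.0)
```

Under this assumed form, setting `identity_weight` to zero recovers the plain moving average, so callers that do not want damping would not need a separate code path.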
FermiNet Contributor authored and jsspencer committed Jun 4, 2024
1 parent 3d453fa commit 1815983
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ferminet/curvature_tags_and_blocks.py
@@ -55,6 +55,7 @@ def update_curvature_matrix_estimate(
       estimation_data: Mapping[str, Sequence[Array]],
       ema_old: Numeric,
       ema_new: Numeric,
+      identity_weight: Numeric,
       batch_size: int,
   ) -> kfac_jax.TwoKroneckerFactored.State:
     estimation_data = dict(**estimation_data)
@@ -69,6 +70,7 @@ def update_curvature_matrix_estimate(
         estimation_data=estimation_data,
         ema_old=ema_old,
         ema_new=ema_new,
+        identity_weight=identity_weight,
         batch_size=batch_size,
     )

@@ -91,8 +93,11 @@ def update_curvature_matrix_estimate(
       estimation_data: Mapping[str, Sequence[Array]],
       ema_old: Numeric,
       ema_new: Numeric,
+      identity_weight: Numeric,
       batch_size: int,
   ) -> kfac_jax.TwoKroneckerFactored.State:
+    del identity_weight
+
     x, = estimation_data["inputs"]
     dy, = estimation_data["outputs_tangent"]
     assert batch_size == x.shape[0]
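
The last hunk accepts `identity_weight` and deletes it straight away, so that method keeps the same signature as the other blocks while not using the argument in its own body. A minimal sketch of that pattern follows. The class names `BaseBlock` and `UndampedBlock` are hypothetical, scalars stand in for Kronecker factors, and the damped formula in the base class is an assumption rather than the library's implementation.

```python
class BaseBlock:
  """Hypothetical base class whose update adds the identity term."""

  def update(self, state, new, ema_old, ema_new, identity_weight):
    # Assumed damped form: for scalars the "identity" is just 1.
    return ema_old * state + ema_new * new + identity_weight


class UndampedBlock(BaseBlock):
  """Hypothetical subclass that keeps the signature but ignores damping."""

  def update(self, state, new, ema_old, ema_new, identity_weight):
    del identity_weight  # Accepted for interface compatibility, unused here.
    return ema_old * state + ema_new * new


blocks = [BaseBlock(), UndampedBlock()]
# Every block can be called with the same keyword arguments.
results = [
    b.update(2.0, 4.0, ema_old=0.5, ema_new=0.5, identity_weight=0.25)
    for b in blocks
]
```

Deleting the argument explicitly makes it clear to readers and linters that it is unused on purpose, and any later accidental use raises an error instead of silently applying damping.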