diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index 75c5767..617ae53 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -55,6 +55,7 @@ def update_curvature_matrix_estimate( estimation_data: Mapping[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, + identity_weight: Numeric, batch_size: int, ) -> kfac_jax.TwoKroneckerFactored.State: estimation_data = dict(**estimation_data) @@ -69,6 +70,7 @@ def update_curvature_matrix_estimate( estimation_data=estimation_data, ema_old=ema_old, ema_new=ema_new, + identity_weight=identity_weight, batch_size=batch_size, ) @@ -91,8 +93,11 @@ def update_curvature_matrix_estimate( estimation_data: Mapping[str, Sequence[Array]], ema_old: Numeric, ema_new: Numeric, + identity_weight: Numeric, batch_size: int, ) -> kfac_jax.TwoKroneckerFactored.State: + del identity_weight + x, = estimation_data["inputs"] dy, = estimation_data["outputs_tangent"] assert batch_size == x.shape[0]