From 0233831104278798999f51eca299fa5473285bf3 Mon Sep 17 00:00:00 2001 From: James Spencer Date: Fri, 12 Apr 2024 10:48:22 +0100 Subject: [PATCH] Update type for kfac_jax Optimizer.State. The forward declaration was recently removed; instead use the public API. PiperOrigin-RevId: 624106473 Change-Id: Ia2043491c70b31c89d5025401b0f2d7ad36a972e --- ferminet/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ferminet/train.py b/ferminet/train.py index 63011d0..896f25a 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -127,7 +127,7 @@ def init_electrons( # pylint: disable=dangerous-default-value # All optimizer states (KFAC and optax-based). -OptimizerState = Union[optax.OptState, kfac_jax.optimizer.OptimizerState] +OptimizerState = Union[optax.OptState, kfac_jax.Optimizer.State] OptUpdateResults = Tuple[networks.ParamTree, Optional[OptimizerState], jnp.ndarray, Optional[qmc_loss_functions.AuxiliaryLossData]] @@ -331,7 +331,7 @@ def make_kfac_training_step( def step( data: networks.FermiNetData, params: networks.ParamTree, - state: kfac_jax.optimizer.OptimizerState, + state: kfac_jax.Optimizer.State, key: chex.PRNGKey, mcmc_width: jnp.ndarray, ) -> StepResults: