diff --git a/ferminet/train.py b/ferminet/train.py index 63011d0..896f25a 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -127,7 +127,7 @@ def init_electrons( # pylint: disable=dangerous-default-value # All optimizer states (KFAC and optax-based). -OptimizerState = Union[optax.OptState, kfac_jax.optimizer.OptimizerState] +OptimizerState = Union[optax.OptState, kfac_jax.Optimizer.State] OptUpdateResults = Tuple[networks.ParamTree, Optional[OptimizerState], jnp.ndarray, Optional[qmc_loss_functions.AuxiliaryLossData]] @@ -331,7 +331,7 @@ def make_kfac_training_step( def step( data: networks.FermiNetData, params: networks.ParamTree, - state: kfac_jax.optimizer.OptimizerState, + state: kfac_jax.Optimizer.State, key: chex.PRNGKey, mcmc_width: jnp.ndarray, ) -> StepResults: