From 18d1a1ce1b156ffc72889a0627649add46797fc4 Mon Sep 17 00:00:00 2001
From: David Pfau
Date: Fri, 29 Sep 2023 15:11:14 +0100
Subject: [PATCH] Add ability to reset state if NaN encountered during optimization

PiperOrigin-RevId: 569483520
Change-Id: Icc14a54c8a82588e991388859ca9c6053658bcd0
---
 ferminet/base_config.py |  4 +++
 ferminet/train.py       | 61 ++++++++++++++++++++++++++++++++++-------
 2 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/ferminet/base_config.py b/ferminet/base_config.py
index abc003f..391e0b0 100644
--- a/ferminet/base_config.py
+++ b/ferminet/base_config.py
@@ -72,6 +72,10 @@ def default() -> ml_collections.ConfigDict:
           # average clipped energy rather than average energy, guaranteeing that
           # the average energy difference will be zero in each batch.
           'center_at_clip': True,
+          # If true, keep the parameters and optimizer state from the previous
+          # step and revert them if they become NaN after an update. Mainly
+          # useful for excited states.
+          'reset_if_nan': False,
           # KFAC hyperparameters. See KFAC documentation for details.
           'kfac': {
               'invert_every': 1,
diff --git a/ferminet/train.py b/ferminet/train.py
index ce498c5..e93109b 100644
--- a/ferminet/train.py
+++ b/ferminet/train.py
@@ -243,6 +243,7 @@ def loss_eval(
 def make_training_step(
     mcmc_step,
     optimizer_step: OptUpdate,
+    reset_if_nan: bool = False,
 ) -> Step:
   """Factory to create training step for non-KFAC optimizers.
 
@@ -251,6 +252,8 @@ def make_training_step(
       for creating the callable.
     optimizer_step: OptUpdate callable which evaluates the forward and backward
       passes and updates the parameters and optimizer state, as required.
+    reset_if_nan: If true, reset the params and opt state to the state at the
+      previous step when the loss is NaN.
 
   Returns:
     step, a callable which performs a set of MCMC steps and then an optimization
@@ -270,15 +273,27 @@ def step(
     data, pmove = mcmc_step(params, data, mcmc_key, mcmc_width)
 
     # Optimization step
-    new_params, state, loss, aux_data = optimizer_step(params, data, state,
-                                                       loss_key)
-    return data, new_params, state, loss, aux_data, pmove
+    new_params, new_state, loss, aux_data = optimizer_step(params,
+                                                           data,
+                                                           state,
+                                                           loss_key)
+    if reset_if_nan:
+      new_params = jax.lax.cond(jnp.isnan(loss),
+                                lambda: params,
+                                lambda: new_params)
+      new_state = jax.lax.cond(jnp.isnan(loss),
+                               lambda: state,
+                               lambda: new_state)
+    return data, new_params, new_state, loss, aux_data, pmove
 
   return step
 
 
-def make_kfac_training_step(mcmc_step, damping: float,
-                            optimizer: kfac_jax.Optimizer) -> Step:
+def make_kfac_training_step(
+    mcmc_step,
+    damping: float,
+    optimizer: kfac_jax.Optimizer,
+    reset_if_nan: bool = False) -> Step:
   """Factory to create training step for KFAC optimizers.
 
   Args:
@@ -286,6 +301,8 @@ def make_kfac_training_step(mcmc_step, damping: float,
       for creating the callable.
     damping: value of damping to use for each KFAC update step.
     optimizer: KFAC optimizer instance.
+    reset_if_nan: If true, reset the params and opt state to the state at the
+      previous step when the loss is NaN.
 
   Returns:
     step, a callable which performs a set of MCMC steps and then an optimization
@@ -295,6 +312,10 @@ def make_kfac_training_step(mcmc_step, damping: float,
   shared_mom = kfac_jax.utils.replicate_all_local_devices(jnp.zeros([]))
   shared_damping = kfac_jax.utils.replicate_all_local_devices(
       jnp.asarray(damping))
+  # Due to some KFAC cleverness related to donated buffers, need to do this
+  # to make state resettable
+  copy_tree = constants.pmap(
+      functools.partial(jax.tree_util.tree_map, lambda x: 1.0 * x))
 
   def step(
       data: networks.FermiNetData,
@@ -311,8 +332,12 @@ def step(
     mcmc_keys, loss_keys = kfac_jax.utils.p_split(key)
     data, pmove = mcmc_step(params, data, mcmc_keys, mcmc_width)
 
+    if reset_if_nan:
+      old_params = copy_tree(params)
+      old_state = copy_tree(state)
+
     # Optimization step
-    new_params, state, stats = optimizer.step(
+    new_params, new_state, stats = optimizer.step(
         params=params,
         state=state,
         rng=loss_keys,
@@ -320,7 +345,11 @@ def step(
         momentum=shared_mom,
         damping=shared_damping,
     )
-    return data, new_params, state, stats['loss'], stats['aux'], pmove
+
+    if reset_if_nan and jnp.isnan(stats['loss']):
+      new_params = old_params
+      new_state = old_state
+    return data, new_params, new_state, stats['loss'], stats['aux'], pmove
 
   return step
 
@@ -652,12 +681,14 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
     opt_state = opt_state_ckpt or opt_state  # avoid overwriting ckpted state
     step = make_training_step(
         mcmc_step=mcmc_step,
-        optimizer_step=make_opt_update_step(evaluate_loss, optimizer))
+        optimizer_step=make_opt_update_step(evaluate_loss, optimizer),
+        reset_if_nan=cfg.optim.reset_if_nan)
   elif isinstance(optimizer, kfac_jax.Optimizer):
     step = make_kfac_training_step(
         mcmc_step=mcmc_step,
         damping=cfg.optim.kfac.damping,
-        optimizer=optimizer)
+        optimizer=optimizer,
+        reset_if_nan=cfg.optim.reset_if_nan)
   else:
     raise ValueError(f'Unknown optimizer: {optimizer}')
 
@@ -709,6 +740,7 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
       log=False)
   with writer_manager as writer:
     # Main training loop
+    num_resets = 0  # used if reset_if_nan is true
     for t in range(t_init, cfg.optim.iterations):
       sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
       data, params, opt_state, loss, unused_aux_data, pmove = step(
@@ -740,7 +772,16 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
         tree = {'params': params, 'loss': loss}
         if cfg.optim.optimizer != 'none':
           tree['optim'] = opt_state
-        chex.assert_tree_all_finite(tree)
+        try:
+          chex.assert_tree_all_finite(tree)
+          num_resets = 0  # Reset counter if check passes
+        except AssertionError as e:
+          if cfg.optim.reset_if_nan:  # Allow a certain number of NaNs
+            num_resets += 1
+            if num_resets > 100:
+              raise e
+          else:
+            raise e
 
         # Logging
         if t % cfg.log.stats_frequency == 0:
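
Usage sketch (illustration only, not part of the patch): the snippet below reduces the
revert-on-NaN pattern used in make_training_step above to a self-contained toy example.
toy_update and its inputs are hypothetical stand-ins for optimizer_step; only the
jax.lax.cond / jnp.isnan pattern and the cfg.optim.reset_if_nan flag come from the
change itself.

import jax
import jax.numpy as jnp

def toy_update(params, x):
  # Hypothetical update that produces a NaN loss when x is negative.
  loss = jnp.sqrt(x)                 # sqrt of a negative number -> NaN
  new_params = params - 0.1 * loss   # the NaN propagates into the parameters
  return new_params, loss

def step(params, x, reset_if_nan=True):
  new_params, loss = toy_update(params, x)
  if reset_if_nan:
    # Same idea as the patch: keep the previous parameters if the loss is NaN.
    new_params = jax.lax.cond(jnp.isnan(loss),
                              lambda: params,
                              lambda: new_params)
  return new_params, loss

params = jnp.array([0.5, 1.0])
new_params, loss = step(params, x=jnp.array(-1.0))
print(loss)        # nan
print(new_params)  # [0.5 1.0] -- the bad update was reverted

In FermiNet itself the behaviour is switched on via the new config option:

  cfg = base_config.default()
  cfg.optim.reset_if_nan = True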