Add ability to reset state if NaN encountered during optimization
PiperOrigin-RevId: 569483520
Change-Id: Icc14a54c8a82588e991388859ca9c6053658bcd0
dpfau authored and jsspencer committed Nov 24, 2023
1 parent 66c92e0 · commit 18d1a1c
Showing 2 changed files with 55 additions and 10 deletions.
4 changes: 4 additions & 0 deletions ferminet/base_config.py
@@ -72,6 +72,10 @@ def default() -> ml_collections.ConfigDict:
           # average clipped energy rather than average energy, guaranteeing that
           # the average energy difference will be zero in each batch.
           'center_at_clip': True,
+          # If true, keep the parameters and optimizer state from the previous
+          # step and revert them if they become NaN after an update. Mainly
+          # useful for excited states.
+          'reset_if_nan': False,
           # KFAC hyperparameters. See KFAC documentation for details.
           'kfac': {
               'invert_every': 1,
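The new flag defaults to False. As a usage sketch (not part of the commit), it can be toggled on the config returned by base_config.default(), matching the cfg.optim.reset_if_nan lookups in train.py below:

```python
# Hedged usage sketch: turn on the NaN-reset behaviour for a run.
from ferminet import base_config

cfg = base_config.default()
cfg.optim.reset_if_nan = True  # revert params/opt state after a NaN update
```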
61 changes: 51 additions & 10 deletions ferminet/train.py
@@ -243,6 +243,7 @@ def loss_eval(
 def make_training_step(
     mcmc_step,
     optimizer_step: OptUpdate,
+    reset_if_nan: bool = False,
 ) -> Step:
   """Factory to create training step for non-KFAC optimizers.
@@ -251,6 +252,8 @@ def make_training_step(
       for creating the callable.
     optimizer_step: OptUpdate callable which evaluates the forward and backward
       passes and updates the parameters and optimizer state, as required.
+    reset_if_nan: If true, reset the params and opt state to the state at the
+      previous step when the loss is NaN.
 
   Returns:
     step, a callable which performs a set of MCMC steps and then an optimization
@@ -270,22 +273,36 @@ def step(
     data, pmove = mcmc_step(params, data, mcmc_key, mcmc_width)
 
     # Optimization step
-    new_params, state, loss, aux_data = optimizer_step(params, data, state,
-                                                       loss_key)
-    return data, new_params, state, loss, aux_data, pmove
+    new_params, new_state, loss, aux_data = optimizer_step(params,
+                                                           data,
+                                                           state,
+                                                           loss_key)
+    if reset_if_nan:
+      new_params = jax.lax.cond(jnp.isnan(loss),
+                                lambda: params,
+                                lambda: new_params)
+      new_state = jax.lax.cond(jnp.isnan(loss),
+                               lambda: state,
+                               lambda: new_state)
+    return data, new_params, new_state, loss, aux_data, pmove
 
   return step
 
 
-def make_kfac_training_step(mcmc_step, damping: float,
-                            optimizer: kfac_jax.Optimizer) -> Step:
+def make_kfac_training_step(
+    mcmc_step,
+    damping: float,
+    optimizer: kfac_jax.Optimizer,
+    reset_if_nan: bool = False) -> Step:
   """Factory to create training step for KFAC optimizers.
 
   Args:
     mcmc_step: Callable which performs the set of MCMC steps. See make_mcmc_step
       for creating the callable.
     damping: value of damping to use for each KFAC update step.
     optimizer: KFAC optimizer instance.
+    reset_if_nan: If true, reset the params and opt state to the state at the
+      previous step when the loss is NaN.
 
   Returns:
     step, a callable which performs a set of MCMC steps and then an optimization
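A note on the pattern above: because the non-KFAC step is traced (jit/pmap), the revert must be expressed as jax.lax.cond rather than a Python if on the traced loss. A minimal self-contained sketch with a toy SGD update (assumed names, not FermiNet code):

```python
import jax
import jax.numpy as jnp

def sgd_step_with_reset(params, grads, loss, lr=1e-2):
  # Apply an SGD update, but fall back to the old params when the loss is NaN.
  new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  return jax.lax.cond(jnp.isnan(loss), lambda: params, lambda: new_params)

params = {'w': jnp.ones(3)}
grads = {'w': jnp.full(3, jnp.nan)}
safe = sgd_step_with_reset(params, grads, loss=jnp.float32('nan'))
# safe['w'] is still [1., 1., 1.]: the NaN update was discarded.
```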
@@ -295,6 +312,10 @@ def make_kfac_training_step(mcmc_step, damping: float,
   shared_mom = kfac_jax.utils.replicate_all_local_devices(jnp.zeros([]))
   shared_damping = kfac_jax.utils.replicate_all_local_devices(
       jnp.asarray(damping))
+  # Due to some KFAC cleverness related to donated buffers, we need to copy
+  # the trees here to make the state resettable.
+  copy_tree = constants.pmap(
+      functools.partial(jax.tree_util.tree_map, lambda x: 1.0 * x))
 
   def step(
       data: networks.FermiNetData,
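The "donated buffers" comment deserves a gloss: kfac_jax's optimizer donates its input buffers to the update, so after optimizer.step the old params/state arrays may be invalidated on device, and merely aliasing the tree would not preserve them; the pmapped 1.0 * x map forces fresh buffers. A hedged illustration of the same hazard using plain jax.jit (buffer donation is a no-op on CPU, where JAX only emits a warning):

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=0)
def update(params):
  return jax.tree_util.tree_map(lambda x: x + 1.0, params)

params = {'w': jnp.ones(3)}
backup = jax.tree_util.tree_map(lambda x: 1.0 * x, params)  # fresh buffers, not aliases
new_params = update(params)
# On GPU/TPU, `params` may now point at donated, invalidated buffers;
# `backup` stays valid because the multiply allocated new memory.
```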
@@ -311,16 +332,24 @@ def step(
     mcmc_keys, loss_keys = kfac_jax.utils.p_split(key)
     data, pmove = mcmc_step(params, data, mcmc_keys, mcmc_width)
 
+    if reset_if_nan:
+      old_params = copy_tree(params)
+      old_state = copy_tree(state)
+
     # Optimization step
-    new_params, state, stats = optimizer.step(
+    new_params, new_state, stats = optimizer.step(
         params=params,
         state=state,
         rng=loss_keys,
         batch=data,
         momentum=shared_mom,
         damping=shared_damping,
     )
-    return data, new_params, state, stats['loss'], stats['aux'], pmove
+
+    if reset_if_nan and jnp.isnan(stats['loss']):
+      new_params = old_params
+      new_state = old_state
+    return data, new_params, new_state, stats['loss'], stats['aux'], pmove
 
   return step
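Contrast with the non-KFAC path: there the reset had to live inside the traced step as a jax.lax.cond, while here the step body runs at the host level, so a plain Python conditional can inspect the concrete loss. A hedged side-by-side sketch of the two idioms:

```python
import jax
import jax.numpy as jnp

def revert_traced(loss, old, new):
  # Safe inside jit/pmap: both branches stay in the traced computation.
  return jax.lax.cond(jnp.isnan(loss), lambda: old, lambda: new)

def revert_eager(loss, old, new):
  # Only valid outside jit, where `loss` is a concrete scalar.
  return old if bool(jnp.isnan(loss)) else new
```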

@@ -652,12 +681,14 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
     opt_state = opt_state_ckpt or opt_state  # avoid overwriting ckpted state
     step = make_training_step(
         mcmc_step=mcmc_step,
-        optimizer_step=make_opt_update_step(evaluate_loss, optimizer))
+        optimizer_step=make_opt_update_step(evaluate_loss, optimizer),
+        reset_if_nan=cfg.optim.reset_if_nan)
   elif isinstance(optimizer, kfac_jax.Optimizer):
     step = make_kfac_training_step(
         mcmc_step=mcmc_step,
         damping=cfg.optim.kfac.damping,
-        optimizer=optimizer)
+        optimizer=optimizer,
+        reset_if_nan=cfg.optim.reset_if_nan)
   else:
     raise ValueError(f'Unknown optimizer: {optimizer}')
@@ -709,6 +740,7 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
       log=False)
   with writer_manager as writer:
     # Main training loop
+    num_resets = 0  # used if reset_if_nan is true
     for t in range(t_init, cfg.optim.iterations):
       sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
       data, params, opt_state, loss, unused_aux_data, pmove = step(
@@ -740,7 +772,16 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
         tree = {'params': params, 'loss': loss}
         if cfg.optim.optimizer != 'none':
           tree['optim'] = opt_state
-        chex.assert_tree_all_finite(tree)
+        try:
+          chex.assert_tree_all_finite(tree)
+          num_resets = 0  # Reset counter if check passes.
+        except AssertionError as e:
+          if cfg.optim.reset_if_nan:  # Allow a certain number of NaNs.
+            num_resets += 1
+            if num_resets > 100:
+              raise e
+          else:
+            raise e
 
       # Logging
       if t % cfg.log.stats_frequency == 0:
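The try/except block above tolerates up to 100 consecutive non-finite steps before giving up. A standalone sketch of that bounded-retry pattern (toy loop and names; the threshold mirrors the hard-coded 100 in train.py):

```python
import chex
import jax.numpy as jnp

MAX_CONSECUTIVE_NANS = 100  # mirrors the hard-coded threshold above
num_resets = 0
for t in range(1000):
  tree = {'loss': jnp.float32('nan') if t < 3 else jnp.float32(1.0)}
  try:
    chex.assert_tree_all_finite(tree)
    num_resets = 0  # any finite step resets the counter
  except AssertionError:
    num_resets += 1
    if num_resets > MAX_CONSECUTIVE_NANS:
      raise
```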
