Skip to content

Commit

Permalink
Fix checkpointing to handle saving and restoring FermiNetData dataclass.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 536688175
Change-Id: Ic15c3b2d4530fa5403210e344aff4c4f4c4a3ecb
  • Loading branch information
jsspencer committed May 31, 2023
1 parent 5bbffc2 commit d3acd19
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
20 changes: 13 additions & 7 deletions ferminet/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

"""Super simple checkpoints using numpy."""

import dataclasses
import datetime
import os
from typing import Optional
import zipfile

from absl import logging
from ferminet import networks
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -108,7 +110,7 @@ def save(save_path: str, t: int, data, params, opt_state, mcmc_width) -> str:
np.savez(
f,
t=t,
data=data,
data=dataclasses.asdict(data),
params=params,
opt_state=opt_state,
mcmc_width=mcmc_width)
Expand Down Expand Up @@ -144,16 +146,20 @@ def restore(restore_filename: str, batch_size: Optional[int] = None):
# Retrieve data from npz file. Non-array variables need to be converted back
# to native types using .tolist().
t = ckpt_data['t'].tolist() + 1 # Return the iterations completed.
data = ckpt_data['data']
data = networks.FermiNetData(**ckpt_data['data'].item())
params = ckpt_data['params'].tolist()
opt_state = ckpt_data['opt_state'].tolist()
mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist())
if data.shape[0] != jax.device_count():
if data.positions.shape[0] != jax.device_count():
raise ValueError(
f'Incorrect number of devices found. Expected {data.shape[0]}, found '
f'{jax.device_count()}.')
if batch_size and data.shape[0] * data.shape[1] != batch_size:
'Incorrect number of devices found. Expected'
f' {data.positions.shape[0]}, found {jax.device_count()}.'
)
if (
batch_size
and data.positions.shape[0] * data.positions.shape[1] != batch_size
):
raise ValueError(
f'Wrong batch size in loaded data. Expected {batch_size}, found '
f'{data.shape[0] * data.shape[1]}.')
f'{data.positions.shape[0] * data.positions.shape[1]}.')
return t, data, params, opt_state, mcmc_width
21 changes: 8 additions & 13 deletions ferminet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ def log_network(*args, **kwargs):
pos = kfac_jax.utils.broadcast_all_local_devices(pos)
spins = jnp.reshape(spins, data_shape + spins.shape[1:])
spins = kfac_jax.utils.broadcast_all_local_devices(spins)
data = networks.FermiNetData(
positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges
)

t_init = 0
opt_state_ckpt = None
Expand All @@ -523,12 +526,12 @@ def log_network(*args, **kwargs):
network.orbitals, in_axes=(None, 0, 0, 0, 0), out_axes=0
)
sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
params, pos = pretrain.pretrain_hartree_fock(
params, data.positions = pretrain.pretrain_hartree_fock(
params=params,
positions=pos,
positions=data.positions,
spins=pretrain_spins,
atoms=batch_atoms,
charges=batch_charges,
atoms=data.atoms,
charges=data.charges,
batch_network=batch_network,
batch_orbitals=batch_orbitals,
network_options=network.options,
Expand Down Expand Up @@ -620,13 +623,7 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
# debug=True
)
sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
opt_state = optimizer.init(
params,
subkeys,
networks.FermiNetData(
positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges
),
)
opt_state = optimizer.init(params, subkeys, data)
opt_state = opt_state_ckpt or opt_state # avoid overwriting ckpted state
else:
raise ValueError(f'Not a recognized optimizer: {cfg.optim.optimizer}')
Expand Down Expand Up @@ -658,8 +655,6 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
jnp.asarray(cfg.mcmc.move_width))
pmoves = np.zeros(cfg.mcmc.adapt_frequency)

data = networks.FermiNetData(
positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges)
if t_init == 0:
logging.info('Burning in MCMC chain for %d steps', cfg.mcmc.burn_in)

Expand Down

0 comments on commit d3acd19

Please sign in to comment.