diff --git a/ferminet/checkpoint.py b/ferminet/checkpoint.py
index 192fd7a..87f443b 100644
--- a/ferminet/checkpoint.py
+++ b/ferminet/checkpoint.py
@@ -14,12 +14,14 @@
 """Super simple checkpoints using numpy."""
 
+import dataclasses
 import datetime
 import os
 from typing import Optional
 import zipfile
 
 from absl import logging
+from ferminet import networks
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -108,7 +110,7 @@ def save(save_path: str, t: int, data, params, opt_state, mcmc_width) -> str:
     np.savez(
         f,
         t=t,
-        data=data,
+        data=dataclasses.asdict(data),
         params=params,
         opt_state=opt_state,
         mcmc_width=mcmc_width)
@@ -144,16 +146,20 @@ def restore(restore_filename: str, batch_size: Optional[int] = None):
   # Retrieve data from npz file. Non-array variables need to be converted back
   # to natives types using .tolist().
   t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
-  data = ckpt_data['data']
+  data = networks.FermiNetData(**ckpt_data['data'].item())
   params = ckpt_data['params'].tolist()
   opt_state = ckpt_data['opt_state'].tolist()
   mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist())
-  if data.shape[0] != jax.device_count():
+  if data.positions.shape[0] != jax.device_count():
     raise ValueError(
-        f'Incorrect number of devices found. Expected {data.shape[0]}, found '
-        f'{jax.device_count()}.')
-  if batch_size and data.shape[0] * data.shape[1] != batch_size:
+        'Incorrect number of devices found. Expected'
+        f' {data.positions.shape[0]}, found {jax.device_count()}.'
+    )
+  if (
+      batch_size
+      and data.positions.shape[0] * data.positions.shape[1] != batch_size
+  ):
     raise ValueError(
         f'Wrong batch size in loaded data. Expected {batch_size}, found '
-        f'{data.shape[0] * data.shape[1]}.')
+        f'{data.positions.shape[0] * data.positions.shape[1]}.')
   return t, data, params, opt_state, mcmc_width
diff --git a/ferminet/train.py b/ferminet/train.py
index 5ebbac0..8b8ea32 100644
--- a/ferminet/train.py
+++ b/ferminet/train.py
@@ -499,6 +499,9 @@ def log_network(*args, **kwargs):
     pos = kfac_jax.utils.broadcast_all_local_devices(pos)
     spins = jnp.reshape(spins, data_shape + spins.shape[1:])
     spins = kfac_jax.utils.broadcast_all_local_devices(spins)
+    data = networks.FermiNetData(
+        positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges
+    )
 
   t_init = 0
   opt_state_ckpt = None
@@ -523,12 +526,12 @@
         network.orbitals, in_axes=(None, 0, 0, 0, 0), out_axes=0
     )
     sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
-    params, pos = pretrain.pretrain_hartree_fock(
+    params, data.positions = pretrain.pretrain_hartree_fock(
         params=params,
-        positions=pos,
+        positions=data.positions,
         spins=pretrain_spins,
-        atoms=batch_atoms,
-        charges=batch_charges,
+        atoms=data.atoms,
+        charges=data.charges,
         batch_network=batch_network,
         batch_orbitals=batch_orbitals,
         network_options=network.options,
@@ -620,13 +623,7 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
         # debug=True
     )
     sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
-    opt_state = optimizer.init(
-        params,
-        subkeys,
-        networks.FermiNetData(
-            positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges
-        ),
-    )
+    opt_state = optimizer.init(params, subkeys, data)
     opt_state = opt_state_ckpt or opt_state  # avoid overwriting ckpted state
   else:
     raise ValueError(f'Not a recognized optimizer: {cfg.optim.optimizer}')
@@ -658,8 +655,6 @@ def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
         jnp.asarray(cfg.mcmc.move_width))
   pmoves = np.zeros(cfg.mcmc.adapt_frequency)
 
-  data = networks.FermiNetData(
-      positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges)
-
   if t_init == 0:
     logging.info('Burning in MCMC chain for %d steps', cfg.mcmc.burn_in)