Enable pretraining with samples from SCF wavefunction and pseudopotentials

PiperOrigin-RevId: 643407243
Change-Id: I1f59587232666c8c634b974f939440a82963aae2
dpfau committed Aug 22, 2024
1 parent 2383bb3 commit 2b034ad
Showing 5 changed files with 246 additions and 166 deletions.
3 changes: 3 additions & 0 deletions ferminet/base_config.py
@@ -298,6 +298,9 @@ def default() -> ml_collections.ConfigDict:
'method': 'hf', # Currently only 'hf' is supported.
'iterations': 1000, # Only used if method is 'hf'.
'basis': 'ccpvdz', # Larger than STO-6G, but good for excited states
# Fraction of SCF to use in pretraining MCMC. This enables pretraining
# similar to the original FermiNet paper.
'scf_fraction': 1.0,
# The way to construct different states for excited state pretraining.
# One of 'ordered' or 'random'. 'Ordered' tends to work better, but
# 'random' is necessary for some systems, especially double
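A minimal sketch of how the new knob might be set (the cfg.pretrain path and the value below are illustrative assumptions based on the keys shown above, not part of this commit):

from ferminet import base_config

cfg = base_config.default()
# Sample half from the SCF ansatz and half from the network during pretraining.
cfg.pretrain.scf_fraction = 0.5
# scf_fraction=1.0 (the default) samples purely from the SCF wavefunction, as
# in the original FermiNet paper; scf_fraction=0.0 samples purely from the
# neural network wavefunction.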
252 changes: 141 additions & 111 deletions ferminet/pretrain.py
Expand Up @@ -14,7 +14,7 @@

"""Utilities for pretraining and importing PySCF models."""

from typing import Callable, Optional, Sequence, Tuple, Union
from typing import Callable, Mapping, Sequence, Tuple, Union

from absl import logging
import chex
@@ -31,11 +31,13 @@
import pyscf


def get_hf(molecule: Optional[Sequence[system.Atom]] = None,
nspins: Optional[Tuple[int, int]] = None,
basis: Optional[str] = 'sto-3g',
pyscf_mol: Optional[pyscf.gto.Mole] = None,
restricted: Optional[bool] = False,
def get_hf(molecule: Sequence[system.Atom] | None = None,
nspins: Tuple[int, int] | None = None,
basis: str | None = 'sto-3g',
ecp: Mapping[str, str] | None = None,
core_electrons: Mapping[str, int] | None = None,
pyscf_mol: pyscf.gto.Mole | None = None,
restricted: bool | None = False,
states: int = 0,
excitation_type: str = 'ordered') -> scf.Scf:
"""Returns an Scf object with the Hartree-Fock solution to the system.
@@ -44,6 +46,9 @@ def get_hf(molecule: Optional[Sequence[system.Atom]] = None,
molecule: the molecule in internal format.
nspins: tuple with number of spin up and spin down electrons.
basis: basis set to use in Hartree-Fock calculation.
ecp: dictionary of the ECP to use for different atoms.
core_electrons: dictionary of the number of core electrons excluded by the
pseudopotential/effective core potential.
pyscf_mol: pyscf Mole object defining the molecule. If supplied,
molecule, nspins and basis are ignored.
restricted: If true, perform a restricted Hartree-Fock calculation,
@@ -56,10 +61,15 @@ def get_hf(molecule: Optional[Sequence[system.Atom]] = None,
but 'random' is necessary for some systems, especially double excitations.
"""
if pyscf_mol:
scf_approx = scf.Scf(pyscf_mol=pyscf_mol, restricted=restricted)
scf_approx = scf.Scf(pyscf_mol=pyscf_mol,
restricted=restricted)
else:
scf_approx = scf.Scf(
molecule, nelectrons=nspins, basis=basis, restricted=restricted)
scf_approx = scf.Scf(molecule,
nelectrons=nspins,
basis=basis,
ecp=ecp,
core_electrons=core_electrons,
restricted=restricted)
scf_approx.run(excitations=max(states - 1, 0),
excitation_type=excitation_type)
return scf_approx
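A usage sketch for the new ECP arguments (the element, basis name, and core count are illustrative assumptions, not taken from this commit; the system.Atom construction follows ferminet's internal molecule format):

from ferminet import pretrain
from ferminet.utils import system

# Hypothetical silver atom with a 28-electron core removed by an effective
# core potential, leaving 19 valence electrons (10 spin-up, 9 spin-down).
scf_approx = pretrain.get_hf(
    molecule=[system.Atom(symbol='Ag', coords=(0.0, 0.0, 0.0))],
    nspins=(10, 9),
    basis='ccecp-ccpvdz',
    ecp={'Ag': 'ccecp'},
    core_electrons={'Ag': 28},
    restricted=False)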
@@ -100,34 +110,14 @@ def eval_orbitals(scf_approx: scf.Scf, pos: Union[np.ndarray, jnp.ndarray],
return alpha_spin, beta_spin


def eval_slater(scf_approx: scf.Scf, pos: Union[jnp.ndarray, np.ndarray],
nspins: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]:
"""Evaluates the Slater determinant.
Args:
scf_approx: an object that contains the result of a PySCF calculation.
pos: an array of electron positions to evaluate the orbitals at.
nspins: tuple with number of spin up and spin down electrons.
Returns:
tuple with sign and log absolute value of Slater determinant.
"""
matrices = eval_orbitals(scf_approx, pos, nspins)
slogdets = [np.linalg.slogdet(elem) for elem in matrices]
sign_alpha, sign_beta = [elem[0] for elem in slogdets]
log_abs_wf_alpha, log_abs_wf_beta = [elem[1] for elem in slogdets]
log_abs_slater_determinant = log_abs_wf_alpha + log_abs_wf_beta
sign = sign_alpha * sign_beta
return sign, log_abs_slater_determinant
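The eval_slater helper deleted here lives on as a method of scf.Scf (the new mcmc_network below calls full_params['scf'].eval_slater). For reference, a standalone toy check of the per-spin sign/log-determinant combination it performed:

import numpy as np

a = np.array([[1.0, 0.2], [0.2, 1.0]])  # hypothetical alpha-spin orbital matrix
b = np.array([[0.5, 0.0], [0.0, 2.0]])  # hypothetical beta-spin orbital matrix
(sign_a, log_a), (sign_b, log_b) = np.linalg.slogdet(a), np.linalg.slogdet(b)
sign, log_abs = sign_a * sign_b, log_a + log_b
# (sign, log_abs) represents det(a) * det(b) without overflow or underflow:
assert np.isclose(sign * np.exp(log_abs), np.linalg.det(a) * np.linalg.det(b))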


def make_pretrain_step(
batch_orbitals: networks.OrbitalFnLike,
batch_network: networks.LogFermiNetLike,
optimizer_update: optax.TransformUpdateFn,
scf_approx: scf.Scf,
electrons: Tuple[int, int],
batch_size: int = 0,
full_det: bool = False,
scf_fraction: float = 0.0,
states: int = 0,
):
"""Creates function for performing one step of Hartre-Fock pretraining.
@@ -141,93 +131,129 @@ def make_pretrain_step(
magnitude of the (wavefunction) network evaluated at those positions.
optimizer_update: callable for transforming the gradients into an update (ie
conforms to the optax API).
scf_approx: an scf.Scf object that contains the result of a PySCF
calculation.
electrons: number of spin-up and spin-down electrons.
batch_size: number of walkers per device, used to make MCMC step.
full_det: If true, evaluate all electrons in a single determinant.
Otherwise, evaluate products of alpha- and beta-spin determinants.
scf_fraction: fraction of the sampling distribution drawn from the SCF
wavefunction; the remainder comes from the neural network wavefunction.
states: Number of excited states, if not 0.
Returns:
Callable for performing a single pretraining optimisation step.
"""

def pretrain_step(data, params, state, key, logprob):
"""One iteration of pretraining to match HF."""

cnorm = lambda x, y: (x - y) * jnp.conj(x - y) # complex norm
def loss_fn(
params: networks.ParamTree,
data: networks.FermiNetData,
):
pos = data.positions
spins = data.spins
if states:
# Make vmap-ed versions of eval_orbitals and batch_orbitals over the
# states dimension.
# (batch, states, nelec*ndim)
pos = jnp.reshape(pos, pos.shape[:-1] + (states, -1))
# (batch, states, nelec)
spins = jnp.reshape(spins, spins.shape[:-1] + (states, -1))

scf_orbitals = jax.vmap(
scf_approx.eval_orbitals, in_axes=(-2, None), out_axes=-4
)

def net_orbitals(params, pos, spins, atoms, charges):
vmapped_orbitals = jax.vmap(
batch_orbitals, in_axes=(None, -2, -2, None, None), out_axes=-4
)
# Dimensions of result are
# [(batch, states, ndet*states, nelec, nelec)]
result = vmapped_orbitals(params, pos, spins, atoms, charges)
result = [
jnp.reshape(r, r.shape[:-3] + (states, -1) + r.shape[-2:])
for r in result
]
result = [jnp.transpose(r, (0, 3, 1, 2, 4, 5)) for r in result]
# We draw distinct samples for each excited state (electron
# configuration), and then evaluate each state within each sample.
# Output dimensions are:
# (batch, det, electron configuration,
# excited state, electron, orbital)
return result

else:
scf_orbitals = scf_approx.eval_orbitals
net_orbitals = batch_orbitals

target = scf_orbitals(pos, electrons)
orbitals = net_orbitals(params, pos, spins, data.atoms, data.charges)
if full_det:
dims = target[0].shape[:-2] # (batch) or (batch, states).
na = target[0].shape[-2]
nb = target[1].shape[-2]
target = jnp.concatenate(
(
jnp.concatenate(
(target[0], jnp.zeros(dims + (na, nb))), axis=-1),
jnp.concatenate(
(jnp.zeros(dims + (nb, na)), target[1]), axis=-1),
),
axis=-2,
# Create a function which gives either the SCF ansatz, the neural network
# ansatz, or a weighted mixture of the two.
if scf_fraction > 1 or scf_fraction < 0:
raise ValueError('scf_fraction must be between 0 and 1, inclusive.')

if states:
def scf_network(fn, x):
x = x.reshape(x.shape[:-1] + (states, -1))
slater_fn = jax.vmap(fn, in_axes=(-2, None), out_axes=-2)
slogdets = slater_fn(x, electrons)
# logsumexp trick
maxlogdet = jnp.max(slogdets[1])
dets = slogdets[0] * jnp.exp(slogdets[1] - maxlogdet)
result = jnp.linalg.slogdet(dets)
return result[1] + maxlogdet * slogdets[1].shape[-1]
else:
scf_network = lambda fn, x: fn(x, electrons)[1]

if scf_fraction < 1e-6:
def mcmc_network(full_params, pos, spins, atoms, charges):
return batch_network(full_params['ferminet'], pos, spins, atoms, charges)
elif scf_fraction > 0.999999:
def mcmc_network(full_params, pos, spins, atoms, charges):
del spins, atoms, charges
return scf_network(full_params['scf'].eval_slater, pos)
else:
def mcmc_network(full_params, pos, spins, atoms, charges):
log_ferminet = batch_network(full_params['ferminet'], pos, spins, atoms,
charges)
log_scf = scf_network(full_params['scf'].eval_slater, pos)
return (1 - scf_fraction) * log_ferminet + scf_fraction * log_scf

mcmc_step = mcmc.make_mcmc_step(
mcmc_network, batch_per_device=batch_size, steps=1)

def loss_fn(
params: networks.ParamTree,
data: networks.FermiNetData,
scf_approx: scf.Scf,
):
pos = data.positions
spins = data.spins
if states:
# Make vmap-ed versions of eval_orbitals and batch_orbitals over the
# states dimension.
# (batch, states, nelec*ndim)
pos = jnp.reshape(pos, pos.shape[:-1] + (states, -1))
# (batch, states, nelec)
spins = jnp.reshape(spins, spins.shape[:-1] + (states, -1))

scf_orbitals = jax.vmap(
scf_approx.eval_orbitals, in_axes=(-2, None), out_axes=-4
)

def net_orbitals(params, pos, spins, atoms, charges):
vmapped_orbitals = jax.vmap(
batch_orbitals, in_axes=(None, -2, -2, None, None), out_axes=-4
)
result = jnp.mean(cnorm(target[:, None, ...], orbitals[0])).real
else:
result = jnp.array([
jnp.mean(cnorm(t[:, None, ...], o)).real
for t, o in zip(target, orbitals)
]).sum()
return constants.pmean(result)

# Dimensions of result are
# [(batch, states, ndet*states, nelec, nelec)]
result = vmapped_orbitals(params, pos, spins, atoms, charges)
result = [
jnp.reshape(r, r.shape[:-3] + (states, -1) + r.shape[-2:])
for r in result
]
result = [jnp.transpose(r, (0, 3, 1, 2, 4, 5)) for r in result]
# We draw distinct samples for each excited state (electron
# configuration), and then evaluate each state within each sample.
# Output dimensions are:
# (batch, det, electron configuration,
# excited state, electron, orbital)
return result

else:
scf_orbitals = scf_approx.eval_orbitals
net_orbitals = batch_orbitals

target = scf_orbitals(pos, electrons)
orbitals = net_orbitals(params, pos, spins, data.atoms, data.charges)
cnorm = lambda x, y: (x - y) * jnp.conj(x - y) # complex norm
if full_det:
dims = target[0].shape[:-2] # (batch) or (batch, states).
na = target[0].shape[-2]
nb = target[1].shape[-2]
target = jnp.concatenate(
(
jnp.concatenate(
(target[0], jnp.zeros(dims + (na, nb))), axis=-1),
jnp.concatenate(
(jnp.zeros(dims + (nb, na)), target[1]), axis=-1),
),
axis=-2,
)
result = jnp.mean(cnorm(target[:, None, ...], orbitals[0])).real
else:
result = jnp.array([
jnp.mean(cnorm(t[:, None, ...], o)).real
for t, o in zip(target, orbitals)
]).sum()
return constants.pmean(result)
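The loss above is the device-averaged mean squared difference between the network orbitals and the SCF targets; cnorm is simply the squared complex modulus, e.g.:

import jax.numpy as jnp

cnorm = lambda x, y: (x - y) * jnp.conj(x - y)
x, y = jnp.array(1.0 + 2.0j), jnp.array(0.5 + 0.0j)
assert jnp.isclose(cnorm(x, y).real, jnp.abs(x - y) ** 2)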

def pretrain_step(data, params, state, key, scf_approx):
"""One iteration of pretraining to match HF."""
val_and_grad = jax.value_and_grad(loss_fn, argnums=0)
loss_val, search_direction = val_and_grad(params, data)
loss_val, search_direction = val_and_grad(params, data, scf_approx)
search_direction = constants.pmean(search_direction)
updates, state = optimizer_update(search_direction, state, params)
params = optax.apply_updates(params, updates)
data, key, logprob, _ = mcmc.mh_update(params, batch_network, data, key,
logprob, 0)
return data, params, state, loss_val, logprob
full_params = {'ferminet': params, 'scf': scf_approx}
data, pmove = mcmc_step(full_params, data, key, width=0.02)
return data, params, state, loss_val, pmove

return pretrain_step
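For 0 < scf_fraction < 1, the mcmc_network defined above samples from a geometric mixture of the two ansatzes, |psi_NN|^(2(1-f)) * |psi_SCF|^(2f) up to normalization, since MCMC treats the returned value as a log-magnitude. A toy sketch of the interpolation with made-up log-amplitudes:

import jax.numpy as jnp

scf_fraction = 0.5                       # f in the expression above
log_ferminet = jnp.array([-1.0, -2.0])   # hypothetical log|psi_NN| per walker
log_scf = jnp.array([-1.5, -1.0])        # hypothetical log|psi_SCF| per walker
log_mix = (1 - scf_fraction) * log_ferminet + scf_fraction * log_scf
# The walkers then equilibrate towards the density exp(2 * log_mix).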

@@ -246,7 +272,9 @@ def pretrain_hartree_fock(
electrons: Tuple[int, int],
scf_approx: scf.Scf,
iterations: int = 1000,
logger: Optional[Callable[[int, float], None]] = None,
batch_size: int = 0,
logger: Callable[[int, float], None] | None = None,
scf_fraction: float = 0.0,
states: int = 0,
):
"""Performs training to match initialization as closely as possible to HF.
@@ -271,8 +299,11 @@ scf_approx: an scf.Scf object that contains the result of a PySCF
scf_approx: an scf.Scf object that contains the result of a PySCF
calculation.
iterations: number of pretraining iterations to perform.
batch_size: number of walkers per device, used to make MCMC step.
logger: Callable with signature (step, value) which externally logs the
pretraining loss.
scf_fraction: fraction of the sampling distribution drawn from the SCF
wavefunction; the remainder comes from the neural network wavefunction.
states: Number of excited states, if not 0.
Returns:
@@ -292,26 +323,25 @@ def pretrain_hartree_fock(
batch_orbitals,
batch_network,
optimizer.update,
scf_approx=scf_approx,
electrons=electrons,
batch_size=batch_size,
full_det=network_options.full_det,
scf_fraction=scf_fraction,
states=states,
)
pretrain_step = constants.pmap(pretrain_step)
pnetwork = constants.pmap(batch_network)

batch_spins = jnp.tile(spins[None], [positions.shape[1], 1])
pmap_spins = kfac_jax.utils.replicate_all_local_devices(batch_spins)
data = networks.FermiNetData(
positions=positions, spins=pmap_spins, atoms=atoms, charges=charges
)
logprob = 2.0 * pnetwork(params, positions, pmap_spins, atoms, charges)

for t in range(iterations):
sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
data, params, opt_state_pt, loss, logprob = pretrain_step(
data, params, opt_state_pt, subkeys, logprob)
logging.info('Pretrain iter %05d: %g', t, loss[0])
data, params, opt_state_pt, loss, pmove = pretrain_step(
data, params, opt_state_pt, subkeys, scf_approx)
logging.info('Pretrain iter %05d: %g %g', t, loss[0], pmove[0])
if logger:
logger(t, loss[0])
return params, data.positions
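A sketch of how the updated entry point might be called (the elided leading arguments and all values are placeholders; only parameters visible in the diff above are named):

params, positions = pretrain_hartree_fock(
    ...,  # network params, initial walkers, geometry, networks, keys as before
    electrons=(2, 2),
    scf_approx=scf_approx,
    iterations=1000,
    batch_size=256,
    scf_fraction=0.5,
    states=0,
)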