Skip to content

Commit

Permalink
Initial commit of natural excited states. Basic functionality only.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571350535
Change-Id: I2d9b8ae42a97ebf4381aec06e172a43d9dc223d0
  • Loading branch information
dpfau authored and jsspencer committed Nov 24, 2023
1 parent 18d1a1c commit 3fb9e5e
Show file tree
Hide file tree
Showing 12 changed files with 364 additions and 51 deletions.
4 changes: 4 additions & 0 deletions ferminet/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def default() -> ml_collections.ConfigDict:
# Dimensionality. Change with care. FermiNet implementation currently
# assumes 3D systems.
'ndim': 3,
# Number of excited states. If 0, use normal ground state machinery.
# If 1, compute ground state using excited state machinery. If >1,
# compute that many excited states.
'states': 0,
# Units of *input* coords of atoms. Either 'bohr' or
# 'angstrom'. Internally work in a.u.; positions in
# Angstroms are converted to Bohr.
Expand Down
78 changes: 78 additions & 0 deletions ferminet/configs/li_excited.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example excited states config for lithium atom with FermiNet."""

from ferminet import base_config
from ferminet.utils import elements
from ferminet.utils import system
import ml_collections


def _adjust_nuclear_charge(cfg):
  """Fills in the molecule and electron counts for a single-atom system.

  Note: function name predates this logic but is kept for compatibility with
  xm_expt.py.

  Args:
    cfg: ml_collections.ConfigDict after all argument parsing.

  Returns:
    The same ml_collections.ConfigDict, with cfg.system.molecule set to a
    one-atom list (carrying the shifted nuclear charge when
    cfg.system.delta_charge is non-negligible) and, if it was not already set,
    cfg.system.electrons set to the (alpha, beta) electron counts.
  """
  # Reuse the first atom of an existing molecule, otherwise build the atom
  # named in the config at the origin.
  existing_molecule = cfg.system.molecule
  if existing_molecule:
    atom = existing_molecule[0]
  else:
    atom = system.Atom(symbol=cfg.system.atom, coords=(0, 0, 0))

  # Only rebuild the atom when the requested charge shift is non-negligible.
  charge_shift = cfg.system.delta_charge
  if abs(charge_shift) > 1.e-8:
    shifted_charge = atom.charge + charge_shift
    final_atom = system.Atom(atom.symbol, atom.coords, shifted_charge)
  else:
    final_atom = atom
  cfg.system.molecule = [final_atom]

  # An explicit electron configuration always wins.
  if cfg.system.electrons:
    return cfg

  # Electron count: atomic number of the element, less any net charge.
  n_electrons = elements.SYMBOLS[atom.symbol].atomic_number
  if 'charge' in cfg.system:
    n_electrons -= cfg.system.charge

  # Use the configured spin polarisation if one is given, else fall back to
  # the spin configuration tabulated for the element with this many electrons.
  polarisation = None
  if 'spin_polarisation' in cfg.system:
    polarisation = cfg.system.spin_polarisation
  if polarisation is None:
    polarisation = elements.ATOMIC_NUMS[n_electrons].spin_config

  n_alpha = (n_electrons + polarisation) // 2
  cfg.system.electrons = (n_alpha, n_electrons - n_alpha)
  return cfg


def get_config():
  """Returns the config for an excited-state QMC run on the lithium atom.

  Starts from base_config.default(), selects a neutral Li atom with no nuclear
  charge shift, requests 3 states, sets the number of pretraining iterations to
  zero and enables optim.reset_if_nan. The molecule and electron counts are
  filled in later through the cfg.system.set_molecule hook.
  """
  cfg = base_config.default()

  # Training behaviour.
  cfg.pretrain.iterations = 0
  cfg.optim.reset_if_nan = True

  # System definition.
  cfg.system.atom = 'Li'
  cfg.system.charge = 0
  cfg.system.delta_charge = 0.0
  cfg.system.states = 3
  # Declared as an optional int so that it can be overridden after creation.
  cfg.system.spin_polarisation = ml_collections.FieldReference(
      None, field_type=int)

  # Storing a callable needs the ConfigDict type checks switched off.
  with cfg.ignore_type():
    cfg.system.set_molecule = _adjust_nuclear_charge
    cfg.config_module = '.atom'
  return cfg
114 changes: 90 additions & 24 deletions ferminet/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

"""Evaluating the Hamiltonian on a wavefunction."""

from typing import Any, Callable, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import chex
from ferminet import networks
from ferminet.utils import utils
import jax
from jax import lax
import jax.numpy as jnp
Expand All @@ -35,7 +36,7 @@ def __call__(
params: networks.ParamTree,
key: chex.PRNGKey,
data: networks.FermiNetData,
) -> jnp.ndarray:
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Returns the local energy of a Hamiltonian at a configuration.
Args:
Expand Down Expand Up @@ -74,16 +75,6 @@ def __call__(
]


def select_output(f: Callable[..., Sequence[Any]],
                  argnum: int) -> Callable[..., Any]:
  """Wraps f so that only one element of its sequence result is returned.

  Args:
    f: callable returning a sequence.
    argnum: index of the element of f's result to keep.

  Returns:
    Callable taking the same arguments as f and returning f(...)[argnum].
  """

  def f_selected(*args, **kwargs):
    outputs = f(*args, **kwargs)
    return outputs[argnum]

  return f_selected


def local_kinetic_energy(
f: networks.FermiNetLike,
use_scan: bool = False,
Expand All @@ -102,8 +93,8 @@ def local_kinetic_energy(
-1/2f \nabla^2 f = -1/2 (\nabla^2 log|f| + (\nabla log|f|)^2).
"""

phase_f = select_output(f, 0)
logabs_f = select_output(f, 1)
phase_f = utils.select_output(f, 0)
logabs_f = utils.select_output(f, 1)

def _lapl_over_f(params, data):
n = data.positions.shape[0]
Expand Down Expand Up @@ -142,6 +133,52 @@ def grad_phase_closure(x):
return _lapl_over_f


def excited_kinetic_energy_matrix(f: networks.FermiNetLike,
                                  states: int) -> KineticEnergy:
  """Builds a function evaluating the matrices psi and K psi.

  Args:
    f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where
      each array contains one element per excited state.
    states: the number of excited states

  Returns:
    A function of (params, data) returning the pair (psi, K psi): the value of
    the wavefunction and the kinetic energy applied to the wavefunction for all
    combinations of electron sets and excited states. Both matrices share a
    common rescaling (see the note in the inner function).
  """
  # params, atoms and charges are shared; positions and spins carry one row
  # per electron set and are mapped over their leading axis.
  per_set_axes = (None, 0, 0, None, None)
  logabs_f = utils.select_output(f, 1)

  def _lapl_all_states(params, pos, spins, atoms, charges):
    """Return K psi/psi for each excited state at one electron set."""
    n_coords = pos.shape[0]
    basis = jnp.eye(n_coords)
    jac_logabs = jax.jacrev(logabs_f, argnums=1)

    def jac_at(x):
      return jac_logabs(params, x, spins, atoms, charges)

    # grad_logabs is the Jacobian of log|psi| at pos; hessian_vp pushes a
    # tangent vector through that Jacobian (i.e. a Hessian-vector product).
    grad_logabs, hessian_vp = jax.linearize(jac_at, pos)

    def add_diagonal_term(i, running_sum):
      # Column i of the product with the i-th basis vector is the i-th
      # diagonal second derivative, one value per state.
      return running_sum + hessian_vp(basis[i])[:, i]

    laplacian = lax.fori_loop(
        0, n_coords, add_diagonal_term, jnp.zeros(states))
    result = -0.5 * laplacian

    return result - 0.5 * jnp.sum(grad_logabs ** 2, axis=-1)

  def _lapl_over_f(params, data):
    """Return the matrices psi_i(r_j) and K psi_i(r_j).

    The flat positions and spins in data are split into `states` electron
    sets. The leading axis of each returned matrix indexes the electron set
    and the trailing axis is the per-state output of f.
    """
    set_positions = jnp.reshape(data.positions, [states, -1])
    set_spins = jnp.reshape(data.spins, [states, -1])

    batched_f = jax.vmap(f, per_set_axes)
    sign_mat, log_mat = batched_f(
        params, set_positions, set_spins, data.atoms, data.charges)

    batched_lapl = jax.vmap(_lapl_all_states, per_set_axes)
    # K psi_i(r_j) / psi_i(r_j)
    lapl = batched_lapl(
        params, set_positions, set_spins, data.atoms, data.charges)

    # Subtract off the largest log-amplitude to avoid under/overflow; psi and
    # K psi are therefore both scaled by the same constant factor.
    largest_log = jnp.max(log_mat)
    psi_mat = sign_mat * jnp.exp(log_mat - largest_log)  # psi_i(r_j)
    kpsi_mat = lapl * psi_mat  # K psi_i(r_j)
    return psi_mat, kpsi_mat

  return _lapl_over_f


def potential_electron_electron(r_ee: Array) -> jnp.ndarray:
"""Returns the electron-electron potential.
Expand Down Expand Up @@ -201,6 +238,7 @@ def local_energy(
nspins: Sequence[int],
use_scan: bool = False,
complex_output: bool = False,
states: int = 0,
) -> LocalEnergy:
"""Creates the function to evaluate the local energy.
Expand All @@ -211,20 +249,28 @@ def local_energy(
nspins: Number of particles of each spin.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
states: Number of excited states to compute. If 0, compute ground state with
default machinery. If 1, compute ground state with excited state machinery.
If >1, compute that many excited states.
Returns:
Callable with signature e_l(params, key, data) which evaluates the local
energy of the wavefunction given the parameters params, RNG state key,
and a single MCMC configuration in data.
"""
if complex_output and states > 1:
raise NotImplementedError(
'Excited states not implemented with complex output')
del nspins
ke = local_kinetic_energy(f,
use_scan=use_scan,
complex_output=complex_output)
if states:
ke = excited_kinetic_energy_matrix(f, states)
else:
ke = local_kinetic_energy(f,
use_scan=use_scan,
complex_output=complex_output)

def _e_l(
params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData
) -> jnp.ndarray:
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Returns the total energy.
Args:
Expand All @@ -233,11 +279,31 @@ def _e_l(
data: MCMC configuration.
"""
del key # unused
_, _, r_ae, r_ee = networks.construct_input_features(
data.positions, data.atoms
)
potential = potential_energy(r_ae, r_ee, data.atoms, charges)
kinetic = ke(params, data)
return potential + kinetic
if states:
# Compute features
vmap_features = jax.vmap(networks.construct_input_features, (0, None))
positions = jnp.reshape(data.positions, [states, -1])
_, _, r_ae, r_ee = vmap_features(positions, data.atoms)

# Compute potential energy
vmap_pot = jax.vmap(potential_energy, (0, 0, None, None))
pot_spectrum = vmap_pot(r_ae, r_ee, data.atoms, charges)[:, None]

# Compute kinetic energy and matrix of states
psi_mat, kin_mat = ke(params, data)

# Combine terms
hpsi_mat = kin_mat + psi_mat * pot_spectrum
energy_mat = jnp.linalg.solve(psi_mat, hpsi_mat)
total_energy = jnp.trace(energy_mat)
else:
_, _, r_ae, r_ee = networks.construct_input_features(
data.positions, data.atoms
)
potential = potential_energy(r_ae, r_ee, data.atoms, charges)
kinetic = ke(params, data)
total_energy = potential + kinetic
energy_mat = None # Not necessary for ground state
return total_energy, energy_mat

return _e_l
14 changes: 9 additions & 5 deletions ferminet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class AuxiliaryLossData:
local_energy: local energy for each MCMC configuration.
clipped_energy: local energy after clipping has been applied
grad_local_energy: gradient of the local energy.
local_energy_mat: for excited states, the local energy matrix.
"""
variance: jax.Array
local_energy: jax.Array
clipped_energy: jax.Array
grad_local_energy: jax.Array | None
local_energy_mat: jax.Array | None


class LossFn(Protocol):
Expand Down Expand Up @@ -172,7 +174,7 @@ def make_loss(network: networks.LogFermiNetLike,
0,
networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0),
),
out_axes=0,
out_axes=(0, 0)
)
batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)

Expand Down Expand Up @@ -200,7 +202,7 @@ def total_energy(
over the batch and over all devices inside a pmap.
"""
keys = jax.random.split(key, num=data.positions.shape[0])
e_l = batch_local_energy(params, keys, data)
e_l, e_l_mat = batch_local_energy(params, keys, data)
loss = constants.pmean(jnp.mean(e_l))
loss_diff = e_l - loss
variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff)))
Expand All @@ -209,6 +211,7 @@ def total_energy(
local_energy=e_l,
clipped_energy=e_l,
grad_local_energy=None,
local_energy_mat=e_l_mat,
)

@total_energy.defjvp
Expand Down Expand Up @@ -303,7 +306,7 @@ def make_wqmc_loss(
0,
networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0),
),
out_axes=0,
out_axes=(0, 0)
)
batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)

Expand Down Expand Up @@ -331,7 +334,7 @@ def total_energy(
over the batch and over all devices inside a pmap.
"""
keys = jax.random.split(key, num=data.positions.shape[0])
e_l = batch_local_energy(params, keys, data)
e_l, e_l_mat = batch_local_energy(params, keys, data)
loss = constants.pmean(jnp.mean(e_l))
loss_diff = e_l - loss
variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff)))
Expand All @@ -343,7 +346,7 @@ def batch_local_energy_pos(pos):
atoms=data.atoms,
charges=data.charges,
)
return batch_local_energy(params, keys, network_data).sum()
return batch_local_energy(params, keys, network_data)[0].sum()

grad_e_l = jax.grad(batch_local_energy_pos)(data.positions)
grad_e_l = jnp.tanh(jax.lax.stop_gradient(grad_e_l))
Expand All @@ -352,6 +355,7 @@ def batch_local_energy_pos(pos):
local_energy=e_l,
clipped_energy=e_l,
grad_local_energy=grad_e_l,
local_energy_mat=e_l_mat,
)

@total_energy.defjvp
Expand Down
Loading

0 comments on commit 3fb9e5e

Please sign in to comment.