From 3fb9e5e9db51ca81e8d28f85ce65d1040afd43b9 Mon Sep 17 00:00:00 2001 From: David Pfau Date: Fri, 6 Oct 2023 16:56:11 +0100 Subject: [PATCH] Initial commit of natural excited states. Basic functionality only. PiperOrigin-RevId: 571350535 Change-Id: I2d9b8ae42a97ebf4381aec06e172a43d9dc223d0 --- ferminet/base_config.py | 4 + ferminet/configs/li_excited.py | 78 +++++++++++++++++ ferminet/hamiltonian.py | 114 +++++++++++++++++++------ ferminet/loss.py | 14 +-- ferminet/networks.py | 102 +++++++++++++++++++++- ferminet/pbc/hamiltonian.py | 11 ++- ferminet/pbc/tests/hamiltonian_test.py | 4 +- ferminet/psiformer.py | 13 ++- ferminet/tests/hamiltonian_test.py | 2 +- ferminet/tests/train_test.py | 16 ++-- ferminet/train.py | 30 +++++-- ferminet/utils/utils.py | 27 ++++++ 12 files changed, 364 insertions(+), 51 deletions(-) create mode 100644 ferminet/configs/li_excited.py create mode 100644 ferminet/utils/utils.py diff --git a/ferminet/base_config.py b/ferminet/base_config.py index 391e0b0..a13e2a4 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -131,6 +131,10 @@ def default() -> ml_collections.ConfigDict: # Dimensionality. Change with care. FermiNet implementation currently # assumes 3D systems. 'ndim': 3, + # Number of excited states. If 0, use normal ground state machinery. + # If 1, compute ground state using excited state machinery. If >1, + # compute that many excited states. + 'states': 0, # Units of *input* coords of atoms. Either 'bohr' or # 'angstrom'. Internally work in a.u.; positions in # Angstroms are converged to Bohr. diff --git a/ferminet/configs/li_excited.py b/ferminet/configs/li_excited.py new file mode 100644 index 0000000..433dded --- /dev/null +++ b/ferminet/configs/li_excited.py @@ -0,0 +1,78 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example excited states config for lithium atom with FermiNet.""" + +from ferminet import base_config +from ferminet.utils import elements +from ferminet.utils import system +import ml_collections + + +def _adjust_nuclear_charge(cfg): + """Sets the molecule, nuclear charge and electrons for the atom. + + Note: function name predates this logic but is kept for compatibility with + xm_expt.py. + + Args: + cfg: ml_collections.ConfigDict after all argument parsing. + + Returns: + ml_collections.ConfigDict with the nuclear charge for the atom in + cfg.system.molecule and cfg.system.charge appropriately set. 
+ """ + if cfg.system.molecule: + atom = cfg.system.molecule[0] + else: + atom = system.Atom(symbol=cfg.system.atom, coords=(0, 0, 0)) + + if abs(cfg.system.delta_charge) > 1.e-8: + nuclear_charge = atom.charge + cfg.system.delta_charge + cfg.system.molecule = [ + system.Atom(atom.symbol, atom.coords, nuclear_charge) + ] + else: + cfg.system.molecule = [atom] + + if not cfg.system.electrons: + atomic_number = elements.SYMBOLS[atom.symbol].atomic_number + if 'charge' in cfg.system: + atomic_number -= cfg.system.charge + if ('spin_polarisation' in cfg.system + and cfg.system.spin_polarisation is not None): + spin_polarisation = cfg.system.spin_polarisation + else: + spin_polarisation = elements.ATOMIC_NUMS[atomic_number].spin_config + nalpha = (atomic_number + spin_polarisation) // 2 + cfg.system.electrons = (nalpha, atomic_number - nalpha) + + return cfg + + +def get_config(): + """Returns config for running generic atoms with qmc.""" + cfg = base_config.default() + cfg.system.atom = 'Li' + cfg.system.charge = 0 + cfg.system.delta_charge = 0.0 + cfg.system.states = 3 + cfg.pretrain.iterations = 0 + cfg.optim.reset_if_nan = True + cfg.system.spin_polarisation = ml_collections.FieldReference( + None, field_type=int) + with cfg.ignore_type(): + cfg.system.set_molecule = _adjust_nuclear_charge + cfg.config_module = '.atom' + return cfg diff --git a/ferminet/hamiltonian.py b/ferminet/hamiltonian.py index 9072d64..35fe254 100644 --- a/ferminet/hamiltonian.py +++ b/ferminet/hamiltonian.py @@ -14,10 +14,11 @@ """Evaluating the Hamiltonian on a wavefunction.""" -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import chex from ferminet import networks +from ferminet.utils import utils import jax from jax import lax import jax.numpy as jnp @@ -35,7 +36,7 @@ def __call__( params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData, - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray, 
Optional[jnp.ndarray]]: """Returns the local energy of a Hamiltonian at a configuration. Args: @@ -74,16 +75,6 @@ def __call__( ] -def select_output(f: Callable[..., Sequence[Any]], - argnum: int) -> Callable[..., Any]: - """Return the argnum-th result from callable f.""" - - def f_selected(*args, **kwargs): - return f(*args, **kwargs)[argnum] - - return f_selected - - def local_kinetic_energy( f: networks.FermiNetLike, use_scan: bool = False, @@ -102,8 +93,8 @@ def local_kinetic_energy( -1/2f \nabla^2 f = -1/2 (\nabla^2 log|f| + (\nabla log|f|)^2). """ - phase_f = select_output(f, 0) - logabs_f = select_output(f, 1) + phase_f = utils.select_output(f, 0) + logabs_f = utils.select_output(f, 1) def _lapl_over_f(params, data): n = data.positions.shape[0] @@ -142,6 +133,52 @@ def grad_phase_closure(x): return _lapl_over_f +def excited_kinetic_energy_matrix(f: networks.FermiNetLike, + states: int) -> KineticEnergy: + """Creates a f'n which evaluates the matrix of local kinetic energies. + + Args: + f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where + each array contains one element per excited state. + states: the number of excited states + + Returns: + A function which computes the matrices (psi) and (K psi), which are the + value of the wavefunction and the kinetic energy applied to the + wavefunction for all combinations of electron sets and excited states. 
+ """ + + def _lapl_all_states(params, pos, spins, atoms, charges): + """Return K psi/psi for each excited state.""" + n = pos.shape[0] + eye = jnp.eye(n) + grad_f = jax.jacrev(utils.select_output(f, 1), argnums=1) + grad_f_closure = lambda x: grad_f(params, x, spins, atoms, charges) + primal, dgrad_f = jax.linearize(grad_f_closure, pos) + + result = -0.5 * lax.fori_loop( + 0, n, lambda i, val: val + dgrad_f(eye[i])[:, i], jnp.zeros(states)) + + return result - 0.5 * jnp.sum(primal ** 2, axis=-1) + + def _lapl_over_f(params, data): + """Return the kinetic energy (divided by psi) summed over excited states.""" + pos_ = jnp.reshape(data.positions, [states, -1]) + spins_ = jnp.reshape(data.spins, [states, -1]) + vmap_f = jax.vmap(f, (None, 0, 0, None, None)) + sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges) + vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None)) + lapl = vmap_lapl(params, pos_, spins_, data.atoms, + data.charges) # K psi_i(r_j) / psi_i(r_j) + + # subtract off largest value to avoid under/overflow + psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) # psi_i(r_j) + kpsi_mat = lapl * psi_mat # K psi_i(r_j) + return psi_mat, kpsi_mat + + return _lapl_over_f + + def potential_electron_electron(r_ee: Array) -> jnp.ndarray: """Returns the electron-electron potential. @@ -201,6 +238,7 @@ def local_energy( nspins: Sequence[int], use_scan: bool = False, complex_output: bool = False, + states: int = 0, ) -> LocalEnergy: """Creates the function to evaluate the local energy. @@ -211,20 +249,28 @@ def local_energy( nspins: Number of particles of each spin. use_scan: Whether to use a `lax.scan` for computing the laplacian. complex_output: If true, the output of f is complex-valued. + states: Number of excited states to compute. If 0, compute ground state with + default machinery. 
If 1, compute ground state with excited state machinery Returns: Callable with signature e_l(params, key, data) which evaluates the local energy of the wavefunction given the parameters params, RNG state key, and a single MCMC configuration in data. """ + if complex_output and states > 1: + raise NotImplementedError( + 'Excited states not implemented with complex output') del nspins - ke = local_kinetic_energy(f, - use_scan=use_scan, - complex_output=complex_output) + if states: + ke = excited_kinetic_energy_matrix(f, states) + else: + ke = local_kinetic_energy(f, + use_scan=use_scan, + complex_output=complex_output) def _e_l( params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Returns the total energy. Args: @@ -233,11 +279,31 @@ def _e_l( data: MCMC configuration. """ del key # unused - _, _, r_ae, r_ee = networks.construct_input_features( - data.positions, data.atoms - ) - potential = potential_energy(r_ae, r_ee, data.atoms, charges) - kinetic = ke(params, data) - return potential + kinetic + if states: + # Compute features + vmap_features = jax.vmap(networks.construct_input_features, (0, None)) + positions = jnp.reshape(data.positions, [states, -1]) + _, _, r_ae, r_ee = vmap_features(positions, data.atoms) + + # Compute potential energy + vmap_pot = jax.vmap(potential_energy, (0, 0, None, None)) + pot_spectrum = vmap_pot(r_ae, r_ee, data.atoms, charges)[:, None] + + # Compute kinetic energy and matrix of states + psi_mat, kin_mat = ke(params, data) + + # Combine terms + hpsi_mat = kin_mat + psi_mat * pot_spectrum + energy_mat = jnp.linalg.solve(psi_mat, hpsi_mat) + total_energy = jnp.trace(energy_mat) + else: + _, _, r_ae, r_ee = networks.construct_input_features( + data.positions, data.atoms + ) + potential = potential_energy(r_ae, r_ee, data.atoms, charges) + kinetic = ke(params, data) + total_energy = potential + kinetic + energy_mat = None # Not necessary for 
ground state + return total_energy, energy_mat return _e_l diff --git a/ferminet/loss.py b/ferminet/loss.py index 9fbfd80..14459ac 100644 --- a/ferminet/loss.py +++ b/ferminet/loss.py @@ -35,11 +35,13 @@ class AuxiliaryLossData: local_energy: local energy for each MCMC configuration. clipped_energy: local energy after clipping has been applied grad_local_energy: gradient of the local energy. + local_energy_mat: for excited states, the local energy matrix. """ variance: jax.Array local_energy: jax.Array clipped_energy: jax.Array grad_local_energy: jax.Array | None + local_energy_mat: jax.Array | None class LossFn(Protocol): @@ -172,7 +174,7 @@ def make_loss(network: networks.LogFermiNetLike, 0, networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0), ), - out_axes=0, + out_axes=(0, 0) ) batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) @@ -200,7 +202,7 @@ def total_energy( over the batch and over all devices inside a pmap. """ keys = jax.random.split(key, num=data.positions.shape[0]) - e_l = batch_local_energy(params, keys, data) + e_l, e_l_mat = batch_local_energy(params, keys, data) loss = constants.pmean(jnp.mean(e_l)) loss_diff = e_l - loss variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff))) @@ -209,6 +211,7 @@ def total_energy( local_energy=e_l, clipped_energy=e_l, grad_local_energy=None, + local_energy_mat=e_l_mat, ) @total_energy.defjvp @@ -303,7 +306,7 @@ def make_wqmc_loss( 0, networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0), ), - out_axes=0, + out_axes=(0, 0) ) batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) @@ -331,7 +334,7 @@ def total_energy( over the batch and over all devices inside a pmap. 
""" keys = jax.random.split(key, num=data.positions.shape[0]) - e_l = batch_local_energy(params, keys, data) + e_l, e_l_mat = batch_local_energy(params, keys, data) loss = constants.pmean(jnp.mean(e_l)) loss_diff = e_l - loss variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff))) @@ -343,7 +346,7 @@ def batch_local_energy_pos(pos): atoms=data.atoms, charges=data.charges, ) - return batch_local_energy(params, keys, network_data).sum() + return batch_local_energy(params, keys, network_data)[0].sum() grad_e_l = jax.grad(batch_local_energy_pos)(data.positions) grad_e_l = jnp.tanh(jax.lax.stop_gradient(grad_e_l)) @@ -352,6 +355,7 @@ def batch_local_energy_pos(pos): local_energy=e_l, clipped_energy=e_l, grad_local_energy=grad_e_l, + local_energy_mat=e_l_mat, ) @total_energy.defjvp diff --git a/ferminet/networks.py b/ferminet/networks.py index d8075b8..ef7a0c1 100644 --- a/ferminet/networks.py +++ b/ferminet/networks.py @@ -14,6 +14,7 @@ """Implementation of Fermionic Neural Network in JAX.""" import enum +import functools from typing import Any, Iterable, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import attr @@ -267,6 +268,7 @@ class BaseNetworkOptions: Attributes: ndim: dimension of system. Change only with caution. determinants: Number of determinants to use. + states: Number of outputs, one per excited (or ground) state. Ignored if 0. full_det: If true, evaluate determinants over all electrons. Otherwise, block-diagonalise determinants into spin channels. rescale_inputs: If true, rescale the inputs so they grow as log(|r|). @@ -281,6 +283,7 @@ class BaseNetworkOptions: ndim: int = 3 determinants: int = 16 + states: int = 0 full_det: bool = True rescale_inputs: bool = False bias_orbitals: bool = False @@ -1096,14 +1099,15 @@ def init(key: chex.PRNGKey) -> ParamTree: # How many spin-orbitals do we need to create per spin channel? 
nspin_orbitals = [] + num_states = max(options.states, 1) for nspin in active_spin_channels: if options.full_det: # Dense determinant. Need N orbitals per electron per determinant. - norbitals = sum(nspins) * options.determinants + norbitals = sum(nspins) * options.determinants * num_states else: # Spin-factored block-diagonal determinant. Need nspin orbitals per # electron per determinant. - norbitals = nspin * options.determinants + norbitals = nspin * options.determinants * num_states if options.complex_output: norbitals *= 2 # one output is real, one is imaginary nspin_orbitals.append(norbitals) @@ -1241,6 +1245,83 @@ def apply( return init, apply +## Excited States ## + + +def make_state_matrix(signed_network: FermiNetLike, n: int) -> FermiNetLike: + """Construct a matrix-output ansatz which gives the Slater matrix of states. + + Let signed_network(params, pos, spins, options) be a function which returns + psi_1(pos), psi_2(pos), ... psi_n(pos) as a pair of arrays of length n, one + with values of sign(psi_k), one with values of log(psi_k). Then this function + returns a new function which computes the matrix psi_i(pos_j), given an array + of positions (and possibly spins) which has n times as many dimensions as + expected by signed_network. The output of this new meta-matrix is also given + as a sign, log pair. + + Args: + signed_network: A function with the same calling convention as the FermiNet. + n: the number of excited states, needed to know how to shape the determinant + + Returns: + A function with two outputs which combines the individual excited states + into a matrix of wavefunctions, one with the sign and one with the log. 
+ """ + + def state_matrix(params, pos, spins, atoms, charges, **kwargs): + """Evaluate state_matrix for a given ansatz.""" + # `pos` has shape (n*nelectron*ndim), but can be reshaped as + # (n, nelectron, ndim), that is, the first dimension indexes which excited + # state we are considering, the second indexes electrons, and the third + # indexes spatial dimensions. `spins` has the same ordering of indices, + # but does not have the spatial dimensions. `atoms` does not have the + # leading index of number of excited states, as the different states are + # always evaluated at the same atomic geometry. + pos_ = jnp.reshape(pos, [n, -1]) + partial_network = functools.partial( + signed_network, atoms=atoms, charges=charges, **kwargs) + spins_ = jnp.reshape(spins, [n, -1]) + vmap_network = jax.vmap(partial_network, (None, 0, 0)) + sign_mat, log_mat = vmap_network(params, pos_, spins_) + return sign_mat, log_mat + + return state_matrix + + +def make_total_ansatz(signed_network: FermiNetLike, n: int) -> FermiNetLike: + """Construct a single-output ansatz which gives the meta-Slater determinant. + + Let signed_network(params, pos, spins, options) be a function which returns + psi_1(pos), psi_2(pos), ... psi_n(pos) as a pair of arrays, one with values + of sign(psi_k), one with values of log(psi_k). Then this function returns a + new function which computes det[psi_i(pos_j)], given an array of positions + (and possibly spins) which has n times as many dimensions as expected by + signed_network. The output of this new meta-determinant is also given as a + sign, log pair. + + Args: + signed_network: A function with the same calling convention as the FermiNet. + n: the number of excited states, needed to know how to shape the determinant + + Returns: + A function with a single output which combines the individual excited states + into a greater wavefunction given by the meta-Slater determinant. 
+ """ + state_matrix = make_state_matrix(signed_network, n) + + def total_ansatz(params, pos, spins, atoms, charges, **kwargs): + """Evaluate meta_determinant for a given ansatz.""" + sign_mat, log_mat = state_matrix( + params, pos, spins, atoms=atoms, charges=charges, **kwargs) + + logmax = jnp.max(log_mat) # logsumexp trick + sign_out, log_out = jnp.linalg.slogdet(sign_mat * jnp.exp(log_mat - logmax)) + log_out += n * logmax + return sign_out, log_out + + return total_ansatz + + ## FermiNet ## @@ -1250,6 +1331,7 @@ def make_fermi_net( *, ndim: int = 3, determinants: int = 16, + states: int = 0, envelope: Optional[envelopes.Envelope] = None, feature_layer: Optional[FeatureLayer] = None, jastrow: Union[str, jastrows.JastrowType] = jastrows.JastrowType.NONE, @@ -1273,6 +1355,7 @@ def make_fermi_net( charges: (natom) array of atom nuclear charges. ndim: dimension of system. Change only with caution. determinants: Number of determinants to use. + states: Number of outputs, one per excited (or ground) state. Ignored if 0. envelope: Envelope to use to impose orbitals go to zero at infinity. feature_layer: Input feature construction. jastrow: Type of Jastrow factor if used, or no jastrow if 'default'. @@ -1306,7 +1389,9 @@ def make_fermi_net( Network object containing init, apply, orbitals, options, where init and apply are callables which initialise the network parameters and apply the network respectively, orbitals is a callable which applies the network up to - the orbitals, and options specifies the settings used in the network. + the orbitals, and options specifies the settings used in the network. If + options.states > 1, the length of the vectors returned by apply are equal + to the number of states. 
""" if sum([nspin for nspin in nspins if nspin > 0]) == 0: raise ValueError('No electrons present!') @@ -1329,6 +1414,7 @@ def make_fermi_net( options = FermiNetOptions( ndim=ndim, determinants=determinants, + states=states, rescale_inputs=rescale_inputs, envelope=envelope, feature_layer=feature_layer, @@ -1384,7 +1470,15 @@ def apply( """ orbitals = orbitals_apply(params, pos, spins, atoms, charges) - return network_blocks.logdet_matmul(orbitals) + if options.states: + batch_logdet_matmul = jax.vmap(network_blocks.logdet_matmul, in_axes=0) + orbitals = [ + jnp.reshape(orbital, (options.states, -1) + orbital.shape[1:]) + for orbital in orbitals + ] + return batch_logdet_matmul(*orbitals) + else: + return network_blocks.logdet_matmul(orbitals) return Network( options=options, init=init, apply=apply, orbitals=orbitals_apply diff --git a/ferminet/pbc/hamiltonian.py b/ferminet/pbc/hamiltonian.py index 9e62f8d..c68a9a7 100644 --- a/ferminet/pbc/hamiltonian.py +++ b/ferminet/pbc/hamiltonian.py @@ -20,7 +20,7 @@ """ import itertools -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Tuple import chex from ferminet import hamiltonian @@ -156,6 +156,7 @@ def local_energy( nspins: Sequence[int], use_scan: bool = False, complex_output: bool = False, + states: int = 0, lattice: Optional[jnp.ndarray] = None, heg: bool = True, convergence_radius: int = 5, @@ -169,6 +170,8 @@ def local_energy( nspins: Number of particles of each spin. use_scan: Whether to use a `lax.scan` for computing the laplacian. complex_output: If true, the output of f is complex-valued. + states: Number of excited states to compute. Not implemented, only present + for consistency of calling convention. lattice: Shape (ndim, ndim). Matrix of lattice vectors. Default: identity matrix. heg: bool. Flag to enable features specific to the electron gas. 
@@ -179,6 +182,8 @@ def local_energy( energy of the wavefunction given the parameters params, RNG state key, and a single MCMC configuration in data. """ + if states: + raise NotImplementedError('Excited states not implemented with PBC.') del nspins if lattice is None: lattice = jnp.eye(3) @@ -188,7 +193,7 @@ def local_energy( def _e_l( params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Returns the total energy. Args: @@ -204,6 +209,6 @@ def _e_l( data.positions, data.atoms) potential = potential_energy(ae, ee) kinetic = ke(params, data) - return potential + kinetic + return potential + kinetic, None return _e_l diff --git a/ferminet/pbc/tests/hamiltonian_test.py b/ferminet/pbc/tests/hamiltonian_test.py index 37f1821..7efd323 100644 --- a/ferminet/pbc/tests/hamiltonian_test.py +++ b/ferminet/pbc/tests/hamiltonian_test.py @@ -73,7 +73,7 @@ def test_periodicity(self): ) key, subkey = jax.random.split(key) - e1 = local_energy(params, subkey, data) + e1, _ = local_energy(params, subkey, data) # Select random electron coordinate to displace by a random lattice vec key, subkey = jax.random.split(key) @@ -87,7 +87,7 @@ def test_periodicity(self): ) key, subkey = jax.random.split(key) - e2 = local_energy(params, subkey, data2) + e2, _ = local_energy(params, subkey, data2) atol, rtol = 4.e-3, 4.e-3 np.testing.assert_allclose(e1, e2, atol=atol, rtol=rtol) diff --git a/ferminet/psiformer.py b/ferminet/psiformer.py index c5e058e..179142f 100644 --- a/ferminet/psiformer.py +++ b/ferminet/psiformer.py @@ -329,6 +329,7 @@ def make_fermi_net( *, ndim: int = 3, determinants: int = 16, + states: int = 0, envelope: Optional[envelopes.Envelope] = None, feature_layer: Optional[networks.FeatureLayer] = None, jastrow: Union[str, jastrows.JastrowType] = jastrows.JastrowType.SIMPLE_EE, @@ -351,6 +352,7 @@ def make_fermi_net( charges: (natom) array of atom nuclear charges. 
ndim: Dimension of the system. Change only with caution. determinants: Number of determinants. + states: Number of outputs, one per excited (or ground) state. Ignored if 0. envelope: Envelope to use to impose orbitals go to zero at infinity. feature_layer: Input feature construction. jastrow: Type of Jastrow factor if used, or 'simple_ee' if 'default'. @@ -389,6 +391,7 @@ def make_fermi_net( options = PsiformerOptions( ndim=ndim, determinants=determinants, + states=states, envelope=envelope, feature_layer=feature_layer, jastrow=jastrow, @@ -436,7 +439,15 @@ def network_apply( of and log absolute value of the network evaluated at x. """ orbitals = orbitals_apply(params, pos, spins, atoms, charges) - return network_blocks.logdet_matmul(orbitals) + if options.states: + batch_logdet_matmul = jax.vmap(network_blocks.logdet_matmul, in_axes=0) + orbitals = [ + jnp.reshape(orbital, (options.states, -1) + orbital.shape[1:]) + for orbital in orbitals + ] + return batch_logdet_matmul(*orbitals) + else: + return network_blocks.logdet_matmul(orbitals) return networks.Network( options=options, diff --git a/ferminet/tests/hamiltonian_test.py b/ferminet/tests/hamiltonian_test.py index e473d92..699e4d0 100644 --- a/ferminet/tests/hamiltonian_test.py +++ b/ferminet/tests/hamiltonian_test.py @@ -138,7 +138,7 @@ def test_local_energy(self): ), ), ) - energies = batch_local_energy( + energies, _ = batch_local_energy( dummy_params, keys, networks.FermiNetData( diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 239d5b7..fc79a6d 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -49,14 +49,19 @@ def setUpModule(): def _config_params(): - for system, optimizer, complex_ in itertools.product( - ('Li', 'LiH'), ('kfac', 'adam'), (True, False)): - yield {'system': system, 'optimizer': optimizer, 'complex_': complex_} + for system, optimizer, complex_, states in itertools.product( + ('Li', 'LiH'), ('kfac', 'adam'), (True, False), (0, 2)): 
+ if states == 0 or not complex_: + yield {'system': system, + 'optimizer': optimizer, + 'complex_': complex_, + 'states': states} for optimizer in ('kfac', 'adam', 'lamb', 'none'): yield { 'system': 'H' if optimizer in ('kfac', 'adam') else 'Li', 'optimizer': optimizer, 'complex_': False, + 'states': 0, } @@ -70,7 +75,7 @@ def setUp(self): pyscf.lib.param.TMPDIR = None @parameterized.parameters(_config_params()) - def test_training_step(self, system, optimizer, complex_): + def test_training_step(self, system, optimizer, complex_, states): if system in ('H', 'Li'): cfg = atom.get_config() cfg.system.atom = system @@ -81,7 +86,8 @@ def test_training_step(self, system, optimizer, complex_): cfg.network.determinants = 2 cfg.network.complex = complex_ cfg.batch_size = 32 - cfg.pretrain.iterations = 10 + cfg.system.states = states + cfg.pretrain.iterations = 10 if states == 0 else 0 cfg.mcmc.burn_in = 10 cfg.optim.optimizer = optimizer cfg.optim.iterations = 3 diff --git a/ferminet/train.py b/ferminet/train.py index e93109b..6e33e7a 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -33,6 +33,7 @@ from ferminet import psiformer from ferminet.utils import statistics from ferminet.utils import system +from ferminet.utils import utils from ferminet.utils import writers import jax from jax.experimental import multihost_utils @@ -369,6 +370,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): # Device logging num_devices = jax.local_device_count() num_hosts = jax.device_count() // num_devices + num_states = cfg.system.get('states', 0) or 1 # avoid 0/1 confusion logging.info('Starting QMC with %i XLA devices per host ' 'across %i hosts.', num_devices, num_hosts) if cfg.batch_size % (num_devices * num_hosts) != 0: @@ -376,6 +378,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): f'got batch size {cfg.batch_size} for ' f'{num_devices * num_hosts} devices.') host_batch_size = cfg.batch_size // num_hosts # batch size per host + 
total_host_batch_size = host_batch_size * num_states device_batch_size = host_batch_size // num_devices # batch size per device data_shape = (num_devices, device_batch_size) @@ -405,6 +408,9 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): # Create parameters, network, and vmaped/pmaped derivations if cfg.pretrain.method == 'hf' and cfg.pretrain.iterations > 0: + if cfg.system.states > 1: + raise NotImplementedError( + 'Pretraining not yet implemented for excited states') hartree_fock = pretrain.get_hf( pyscf_mol=cfg.system.get('pyscf_mol'), molecule=cfg.system.molecule, @@ -451,6 +457,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): charges, ndim=cfg.system.ndim, determinants=cfg.network.determinants, + states=cfg.system.states, envelope=envelope, feature_layer=feature_layer, jastrow=cfg.network.get('jastrow', 'default'), @@ -466,6 +473,7 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): charges, ndim=cfg.system.ndim, determinants=cfg.network.determinants, + states=cfg.system.states, envelope=envelope, feature_layer=feature_layer, jastrow=cfg.network.get('jastrow', 'default'), @@ -479,7 +487,12 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): params = kfac_jax.utils.replicate_all_local_devices(params) signed_network = network.apply # Often just need log|psi(x)|. 
- logabs_network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1] + if cfg.system.get('states', 0): + logabs_network = utils.select_output( + networks.make_total_ansatz(signed_network, + cfg.system.get('states', 0)), 1) + else: + logabs_network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1] batch_network = jax.vmap( logabs_network, in_axes=(None, 0, 0, 0, 0), out_axes=0 ) # batched network @@ -521,12 +534,15 @@ def log_network(*args, **kwargs): subkey, cfg.system.molecule, cfg.system.electrons, - batch_size=host_batch_size, + batch_size=total_host_batch_size, init_width=cfg.mcmc.init_width, ) - pos = jnp.reshape(pos, data_shape + pos.shape[1:]) + # For excited states, each device has a batch of walkers, where each walker + # is nstates * nelectrons. The vmap over nstates is handled in the function + # created in make_total_ansatz + pos = jnp.reshape(pos, data_shape + (-1,)) pos = kfac_jax.utils.broadcast_all_local_devices(pos) - spins = jnp.reshape(spins, data_shape + spins.shape[1:]) + spins = jnp.reshape(spins, data_shape + (-1,)) spins = kfac_jax.utils.broadcast_all_local_devices(spins) data = networks.FermiNetData( positions=pos, spins=spins, atoms=batch_atoms, charges=batch_charges @@ -579,7 +595,7 @@ def log_network(*args, **kwargs): device_batch_size, steps=cfg.mcmc.steps, atoms=atoms_to_mcmc, - blocks=cfg.mcmc.blocks, + blocks=cfg.mcmc.blocks * num_states, ) # Construct loss and optimizer if cfg.system.make_local_energy_fn: @@ -592,6 +608,7 @@ def log_network(*args, **kwargs): charges=charges, nspins=nspins, use_scan=False, + states=cfg.system.get('states', 0), **cfg.system.make_local_energy_kwargs) else: local_energy = hamiltonian.local_energy( @@ -599,7 +616,8 @@ def log_network(*args, **kwargs): charges=charges, nspins=nspins, use_scan=False, - complex_output=cfg.network.get('complex', False)) + complex_output=cfg.network.get('complex', False), + states=cfg.system.get('states', 0)) if cfg.optim.objective == 'vmc': evaluate_loss 
= qmc_loss_functions.make_loss( log_network if cfg.network.get('complex', False) else logabs_network, diff --git a/ferminet/utils/utils.py b/ferminet/utils/utils.py new file mode 100644 index 0000000..515193d --- /dev/null +++ b/ferminet/utils/utils.py @@ -0,0 +1,27 @@ +# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic utils for all QMC calculations.""" + +from typing import Any, Callable, Sequence + + +def select_output(f: Callable[..., Sequence[Any]], + argnum: int) -> Callable[..., Any]: + """Return the argnum-th result from callable f.""" + + def f_selected(*args, **kwargs): + return f(*args, **kwargs)[argnum] + + return f_selected