From 2b034adc35908dcbcf5dbbc46d85a6fdadbd7878 Mon Sep 17 00:00:00 2001 From: David Pfau Date: Fri, 14 Jun 2024 19:32:16 +0100 Subject: [PATCH] Enable pretraining with samples from SCF wavefunction and pseudopotentials PiperOrigin-RevId: 643407243 Change-Id: I1f59587232666c8c634b974f939440a82963aae2 --- ferminet/base_config.py | 3 + ferminet/pretrain.py | 252 ++++++++++++++++++++--------------- ferminet/tests/train_test.py | 24 +++- ferminet/train.py | 27 ++-- ferminet/utils/scf.py | 106 +++++++++------ 5 files changed, 246 insertions(+), 166 deletions(-) diff --git a/ferminet/base_config.py b/ferminet/base_config.py index b6e52e0..83991a6 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -298,6 +298,9 @@ def default() -> ml_collections.ConfigDict: 'method': 'hf', # Currently only 'hf' is supported. 'iterations': 1000, # Only used if method is 'hf'. 'basis': 'ccpvdz', # Larger than STO-6G, but good for excited states + # Fraction of SCF to use in pretraining MCMC. This enables pretraining + # similar to the original FermiNet paper. + 'scf_fraction': 1.0, # The way to construct different states for excited state pretraining. # One of 'ordered' or 'random'. 'Ordered' tends to work better, but # 'random' is necessary for some systems, especially double diff --git a/ferminet/pretrain.py b/ferminet/pretrain.py index 802c430..7a35dff 100644 --- a/ferminet/pretrain.py +++ b/ferminet/pretrain.py @@ -14,7 +14,7 @@ """Utilities for pretraining and importing PySCF models.""" -from typing import Callable, Optional, Sequence, Tuple, Union +from typing import Callable, Mapping, Sequence, Tuple, Union from absl import logging import chex @@ -31,11 +31,13 @@ import pyscf -def get_hf(molecule: Optional[Sequence[system.Atom]] = None, - nspins: Optional[Tuple[int, int]] = None, - basis: Optional[str] = 'sto-3g', - pyscf_mol: Optional[pyscf.gto.Mole] = None, - restricted: Optional[bool] = False, +def get_hf(molecule: Sequence[system.Atom] | None = None, + nspins: Tuple[int, int] | None = None, + basis: str | None = 'sto-3g', + ecp: Mapping[str, str] | None = None, + core_electrons: Mapping[str, int] | None = None, + pyscf_mol: pyscf.gto.Mole | None = None, + restricted: bool | None = False, states: int = 0, excitation_type: str = 'ordered') -> scf.Scf: """Returns an Scf object with the Hartree-Fock solution to the system. @@ -44,6 +46,9 @@ def get_hf(molecule: Optional[Sequence[system.Atom]] = None, molecule: the molecule in internal format. nspins: tuple with number of spin up and spin down electrons. basis: basis set to use in Hartree-Fock calculation. + ecp: dictionary of the ECP to use for different atoms. + core_electrons: dictionary of the number of core electrons excluded by the + pseudopotential/effective core potential. pyscf_mol: pyscf Mole object defining the molecule. If supplied, molecule, nspins and basis are ignored. restricted: If true, perform a restricted Hartree-Fock calculation, @@ -56,10 +61,15 @@ def get_hf(molecule: Optional[Sequence[system.Atom]] = None, but 'random' is necessary for some systems, especially double excitaitons. """ if pyscf_mol: - scf_approx = scf.Scf(pyscf_mol=pyscf_mol, restricted=restricted) + scf_approx = scf.Scf(pyscf_mol=pyscf_mol, + restricted=restricted) else: - scf_approx = scf.Scf( - molecule, nelectrons=nspins, basis=basis, restricted=restricted) + scf_approx = scf.Scf(molecule, + nelectrons=nspins, + basis=basis, + ecp=ecp, + core_electrons=core_electrons, + restricted=restricted) scf_approx.run(excitations=max(states - 1, 0), excitation_type=excitation_type) return scf_approx @@ -100,34 +110,14 @@ def eval_orbitals(scf_approx: scf.Scf, pos: Union[np.ndarray, jnp.ndarray], return alpha_spin, beta_spin -def eval_slater(scf_approx: scf.Scf, pos: Union[jnp.ndarray, np.ndarray], - nspins: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]: - """Evaluates the Slater determinant. - - Args: - scf_approx: an object that contains the result of a PySCF calculation. - pos: an array of electron positions to evaluate the orbitals at. - nspins: tuple with number of spin up and spin down electrons. - - Returns: - tuple with sign and log absolute value of Slater determinant. - """ - matrices = eval_orbitals(scf_approx, pos, nspins) - slogdets = [np.linalg.slogdet(elem) for elem in matrices] - sign_alpha, sign_beta = [elem[0] for elem in slogdets] - log_abs_wf_alpha, log_abs_wf_beta = [elem[1] for elem in slogdets] - log_abs_slater_determinant = log_abs_wf_alpha + log_abs_wf_beta - sign = sign_alpha * sign_beta - return sign, log_abs_slater_determinant - - def make_pretrain_step( batch_orbitals: networks.OrbitalFnLike, batch_network: networks.LogFermiNetLike, optimizer_update: optax.TransformUpdateFn, - scf_approx: scf.Scf, electrons: Tuple[int, int], + batch_size: int = 0, full_det: bool = False, + scf_fraction: float = 0.0, states: int = 0, ): """Creates function for performing one step of Hartre-Fock pretraining. @@ -141,93 +131,129 @@ def make_pretrain_step( magnitude of the (wavefunction) network evaluated at those positions. optimizer_update: callable for transforming the gradients into an update (ie conforms to the optax API). - scf_approx: an scf.Scf object that contains the result of a PySCF - calculation. electrons: number of spin-up and spin-down electrons. + batch_size: number of walkers per device, used to make MCMC step. full_det: If true, evaluate all electrons in a single determinant. Otherwise, evaluate products of alpha- and beta-spin determinants. + scf_fraction: What fraction of the wavefunction sampled from is the SCF + wavefunction and what fraction is the neural network wavefunction? states: Number of excited states, if not 0. Returns: Callable for performing a single pretraining optimisation step. """ - def pretrain_step(data, params, state, key, logprob): - """One iteration of pretraining to match HF.""" - - cnorm = lambda x, y: (x - y) * jnp.conj(x - y) # complex norm - def loss_fn( - params: networks.ParamTree, - data: networks.FermiNetData, - ): - pos = data.positions - spins = data.spins - if states: - # Make vmap-ed versions of eval_orbitals and batch_orbitals over the - # states dimension. - # (batch, states, nelec*ndim) - pos = jnp.reshape(pos, pos.shape[:-1] + (states, -1)) - # (batch, states, nelec) - spins = jnp.reshape(spins, spins.shape[:-1] + (states, -1)) - - scf_orbitals = jax.vmap( - scf_approx.eval_orbitals, in_axes=(-2, None), out_axes=-4 - ) - - def net_orbitals(params, pos, spins, atoms, charges): - vmapped_orbitals = jax.vmap( - batch_orbitals, in_axes=(None, -2, -2, None, None), out_axes=-4 - ) - # Dimensions of result are - # [(batch, states, ndet*states, nelec, nelec)] - result = vmapped_orbitals(params, pos, spins, atoms, charges) - result = [ - jnp.reshape(r, r.shape[:-3] + (states, -1) + r.shape[-2:]) - for r in result - ] - result = [jnp.transpose(r, (0, 3, 1, 2, 4, 5)) for r in result] - # We draw distinct samples for each excited state (electron - # configuration), and then evaluate each state within each sample. - # Output dimensions are: - # (batch, det, electron configuration, - # excited state, electron, orbital) - return result - - else: - scf_orbitals = scf_approx.eval_orbitals - net_orbitals = batch_orbitals - - target = scf_orbitals(pos, electrons) - orbitals = net_orbitals(params, pos, spins, data.atoms, data.charges) - if full_det: - dims = target[0].shape[:-2] # (batch) or (batch, states). - na = target[0].shape[-2] - nb = target[1].shape[-2] - target = jnp.concatenate( - ( - jnp.concatenate( - (target[0], jnp.zeros(dims + (na, nb))), axis=-1), - jnp.concatenate( - (jnp.zeros(dims + (nb, na)), target[1]), axis=-1), - ), - axis=-2, + # Create a function which gives either the SCF ansatz, the neural network + # ansatz, or a weighted mixture of the two. + if scf_fraction > 1 or scf_fraction < 0: + raise ValueError('scf_fraction must be in between 0 and 1, inclusive.') + + if states: + def scf_network(fn, x): + x = x.reshape(x.shape[:-1] + (states, -1)) + slater_fn = jax.vmap(fn, in_axes=(-2, None), out_axes=-2) + slogdets = slater_fn(x, electrons) + # logsumexp trick + maxlogdet = jnp.max(slogdets[1]) + dets = slogdets[0] * jnp.exp(slogdets[1] - maxlogdet) + result = jnp.linalg.slogdet(dets) + return result[1] + maxlogdet * slogdets[1].shape[-1] + else: + scf_network = lambda fn, x: fn(x, electrons)[1] + + if scf_fraction < 1e-6: + def mcmc_network(full_params, pos, spins, atoms, charges): + return batch_network(full_params['ferminet'], pos, spins, atoms, charges) + elif scf_fraction > 0.999999: + def mcmc_network(full_params, pos, spins, atoms, charges): + del spins, atoms, charges + return scf_network(full_params['scf'].eval_slater, pos) + else: + def mcmc_network(full_params, pos, spins, atoms, charges): + log_ferminet = batch_network(full_params['ferminet'], pos, spins, atoms, + charges) + log_scf = scf_network(full_params['scf'].eval_slater, pos) + return (1 - scf_fraction) * log_ferminet + scf_fraction * log_scf + + mcmc_step = mcmc.make_mcmc_step( + mcmc_network, batch_per_device=batch_size, steps=1) + + def loss_fn( + params: networks.ParamTree, + data: networks.FermiNetData, + scf_approx: scf.Scf, + ): + pos = data.positions + spins = data.spins + if states: + # Make vmap-ed versions of eval_orbitals and batch_orbitals over the + # states dimension. + # (batch, states, nelec*ndim) + pos = jnp.reshape(pos, pos.shape[:-1] + (states, -1)) + # (batch, states, nelec) + spins = jnp.reshape(spins, spins.shape[:-1] + (states, -1)) + + scf_orbitals = jax.vmap( + scf_approx.eval_orbitals, in_axes=(-2, None), out_axes=-4 + ) + + def net_orbitals(params, pos, spins, atoms, charges): + vmapped_orbitals = jax.vmap( + batch_orbitals, in_axes=(None, -2, -2, None, None), out_axes=-4 ) - result = jnp.mean(cnorm(target[:, None, ...], orbitals[0])).real - else: - result = jnp.array([ - jnp.mean(cnorm(t[:, None, ...], o)).real - for t, o in zip(target, orbitals) - ]).sum() - return constants.pmean(result) - + # Dimensions of result are + # [(batch, states, ndet*states, nelec, nelec)] + result = vmapped_orbitals(params, pos, spins, atoms, charges) + result = [ + jnp.reshape(r, r.shape[:-3] + (states, -1) + r.shape[-2:]) + for r in result + ] + result = [jnp.transpose(r, (0, 3, 1, 2, 4, 5)) for r in result] + # We draw distinct samples for each excited state (electron + # configuration), and then evaluate each state within each sample. + # Output dimensions are: + # (batch, det, electron configuration, + # excited state, electron, orbital) + return result + + else: + scf_orbitals = scf_approx.eval_orbitals + net_orbitals = batch_orbitals + + target = scf_orbitals(pos, electrons) + orbitals = net_orbitals(params, pos, spins, data.atoms, data.charges) + cnorm = lambda x, y: (x - y) * jnp.conj(x - y) # complex norm + if full_det: + dims = target[0].shape[:-2] # (batch) or (batch, states). + na = target[0].shape[-2] + nb = target[1].shape[-2] + target = jnp.concatenate( + ( + jnp.concatenate( + (target[0], jnp.zeros(dims + (na, nb))), axis=-1), + jnp.concatenate( + (jnp.zeros(dims + (nb, na)), target[1]), axis=-1), + ), + axis=-2, + ) + result = jnp.mean(cnorm(target[:, None, ...], orbitals[0])).real + else: + result = jnp.array([ + jnp.mean(cnorm(t[:, None, ...], o)).real + for t, o in zip(target, orbitals) + ]).sum() + return constants.pmean(result) + + def pretrain_step(data, params, state, key, scf_approx): + """One iteration of pretraining to match HF.""" val_and_grad = jax.value_and_grad(loss_fn, argnums=0) - loss_val, search_direction = val_and_grad(params, data) + loss_val, search_direction = val_and_grad(params, data, scf_approx) search_direction = constants.pmean(search_direction) updates, state = optimizer_update(search_direction, state, params) params = optax.apply_updates(params, updates) - data, key, logprob, _ = mcmc.mh_update(params, batch_network, data, key, - logprob, 0) - return data, params, state, loss_val, logprob + full_params = {'ferminet': params, 'scf': scf_approx} + data, pmove = mcmc_step(full_params, data, key, width=0.02) + return data, params, state, loss_val, pmove return pretrain_step @@ -246,7 +272,9 @@ def pretrain_hartree_fock( electrons: Tuple[int, int], scf_approx: scf.Scf, iterations: int = 1000, - logger: Optional[Callable[[int, float], None]] = None, + batch_size: int = 0, + logger: Callable[[int, float], None] | None = None, + scf_fraction: float = 0.0, states: int = 0, ): """Performs training to match initialization as closely as possible to HF. @@ -271,8 +299,11 @@ def pretrain_hartree_fock( scf_approx: an scf.Scf object that contains the result of a PySCF calculation. iterations: number of pretraining iterations to perform. + batch_size: number of walkers per device, used to make MCMC step. logger: Callable with signature (step, value) which externally logs the pretraining loss. + scf_fraction: What fraction of the wavefunction sampled from is the SCF + wavefunction and what fraction is the neural network wavefunction? states: Number of excited states, if not 0. Returns: @@ -292,26 +323,25 @@ def pretrain_hartree_fock( batch_orbitals, batch_network, optimizer.update, - scf_approx=scf_approx, electrons=electrons, + batch_size=batch_size, full_det=network_options.full_det, + scf_fraction=scf_fraction, states=states, ) pretrain_step = constants.pmap(pretrain_step) - pnetwork = constants.pmap(batch_network) batch_spins = jnp.tile(spins[None], [positions.shape[1], 1]) pmap_spins = kfac_jax.utils.replicate_all_local_devices(batch_spins) data = networks.FermiNetData( positions=positions, spins=pmap_spins, atoms=atoms, charges=charges ) - logprob = 2.0 * pnetwork(params, positions, pmap_spins, atoms, charges) for t in range(iterations): sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key) - data, params, opt_state_pt, loss, logprob = pretrain_step( - data, params, opt_state_pt, subkeys, logprob) - logging.info('Pretrain iter %05d: %g', t, loss[0]) + data, params, opt_state_pt, loss, pmove = pretrain_step( + data, params, opt_state_pt, subkeys, scf_approx) + logging.info('Pretrain iter %05d: %g %g', t, loss[0], pmove[0]) if logger: logger(t, loss[0]) return params, data.positions diff --git a/ferminet/tests/train_test.py b/ferminet/tests/train_test.py index 7ed03fe..2883953 100644 --- a/ferminet/tests/train_test.py +++ b/ferminet/tests/train_test.py @@ -25,6 +25,7 @@ from ferminet import train from ferminet.configs import atom from ferminet.configs import diatomic +import jax import pyscf FLAGS = flags.FLAGS @@ -56,7 +57,15 @@ def _config_params(): 'optimizer': optimizer, 'complex_': complex_, 'states': states, - 'laplacian': 'default'} + 'laplacian': 'default', + 'scf_fraction': 0.0} + for states, scf_fraction in itertools.product((0, 2), (0.0, 0.5, 1.0)): + yield {'system': 'LiH', + 'optimizer': 'kfac', + 'complex_': False, + 'states': states, + 'laplacian': 'default', + 'scf_fraction': scf_fraction} for optimizer in ('kfac', 'adam', 'lamb', 'none'): yield { 'system': 'H' if optimizer in ('kfac', 'adam') else 'Li', @@ -64,6 +73,7 @@ def _config_params(): 'complex_': False, 'states': 0, 'laplacian': 'default', + 'scf_fraction': 0.0 } for states, laplacian, complex_ in itertools.product( (0, 2), ('default', 'folx'), (True, False)): @@ -72,7 +82,8 @@ def _config_params(): 'optimizer': 'kfac', 'complex_': complex_, 'states': states, - 'laplacian': laplacian + 'laplacian': laplacian, + 'scf_fraction': 0.0 } @@ -84,9 +95,13 @@ def setUp(self): # Test calculations are small enough to fit in RAM and we don't need # checkpoint files. pyscf.lib.param.TMPDIR = None + # Prevents issues related to the mcmc step in pretraining if multiple + # training runs are executed in the same session. + jax.clear_caches() @parameterized.parameters(_config_params()) - def test_training_step(self, system, optimizer, complex_, states, laplacian): + def test_training_step( + self, system, optimizer, complex_, states, laplacian, scf_fraction): if system in ('H', 'Li'): cfg = atom.get_config() cfg.system.atom = system @@ -99,6 +114,7 @@ def test_training_step(self, system, optimizer, complex_, states, laplacian): cfg.batch_size = 32 cfg.system.states = states cfg.pretrain.iterations = 10 + cfg.pretrain.scf_fraction = scf_fraction cfg.mcmc.burn_in = 10 cfg.optim.optimizer = optimizer cfg.optim.laplacian = laplacian @@ -231,7 +247,7 @@ def test_pseudopotential_step(self, states): cfg.network.ferminet.hidden_dims = ((16, 4),) * 2 cfg.network.determinants = 2 cfg.batch_size = 32 - cfg.pretrain.iterations = 0 + cfg.pretrain.iterations = 10 cfg.mcmc.burn_in = 0 cfg.system.use_pp = True cfg.system.pp.symbols = ['Li'] diff --git a/ferminet/train.py b/ferminet/train.py index 9c55587..823d86c 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -415,6 +415,19 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): seed = int(multihost_utils.broadcast_one_to_all(seed)[0]) key = jax.random.PRNGKey(seed) + # extract number of electrons of each spin around each atom removed because + # of pseudopotentials + if cfg.system.pyscf_mol: + cfg.system.pyscf_mol.build() + core_electrons = { + atom: ecp_table[0] + for atom, ecp_table in cfg.system.pyscf_mol._ecp.items() # pylint: disable=protected-access + } + ecp = cfg.system.pyscf_mol.ecp + else: + ecp = {} + core_electrons = {} + # Create parameters, network, and vmaped/pmaped derivations if cfg.pretrain.method == 'hf' and cfg.pretrain.iterations > 0: @@ -424,6 +437,8 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): nspins=nspins, restricted=False, basis=cfg.pretrain.basis, + ecp=ecp, + core_electrons=core_electrons, states=cfg.system.states, excitation_type=cfg.pretrain.get('excitation_type', 'ordered')) # broadcast the result of PySCF from host 0 to all other hosts @@ -557,16 +572,6 @@ def log_network(*args, **kwargs): key, subkey = jax.random.split(key) # make sure data on each host is initialized differently subkey = jax.random.fold_in(subkey, jax.process_index()) - # extract number of electrons of each spin around each atom removed because - # of pseudopotentials - if cfg.system.pyscf_mol: - cfg.system.pyscf_mol.build() - core_electrons = { - atom: ecp_table[0] - for atom, ecp_table in cfg.system.pyscf_mol._ecp.items() # pylint: disable=protected-access - } - else: - core_electrons = {} # create electron state (position and spin) pos, spins = init_electrons( subkey, @@ -672,6 +677,8 @@ def log_network(*args, **kwargs): electrons=cfg.system.electrons, scf_approx=hartree_fock, iterations=cfg.pretrain.iterations, + batch_size=device_batch_size, + scf_fraction=cfg.pretrain.get('scf_fraction', 0.0), states=cfg.system.states, ) diff --git a/ferminet/utils/scf.py b/ferminet/utils/scf.py index 7dd326c..221a65f 100644 --- a/ferminet/utils/scf.py +++ b/ferminet/utils/scf.py @@ -29,7 +29,7 @@ # are solutions to the Hartree-Fock equations. -from typing import Optional, Sequence, Tuple, Union +from typing import Mapping, Optional, Sequence, Tuple, Union from absl import logging from ferminet.utils import elements @@ -75,31 +75,58 @@ class Scf: """ def __init__(self, - molecule: Optional[Sequence[system.Atom]] = None, - nelectrons: Optional[Tuple[int, int]] = None, - basis: Optional[str] = 'cc-pVTZ', - pyscf_mol: Optional[pyscf.gto.Mole] = None, + molecule: Sequence[system.Atom] | None = None, + nelectrons: Tuple[int, int] | None = None, + basis: str | None = 'cc-pVTZ', + ecp: Mapping[str, str] | None = None, + core_electrons: Mapping[str, int] | None = None, + pyscf_mol: pyscf.gto.Mole | None = None, restricted: bool = True): + pyscf.lib.param.TMPDIR = None + if pyscf_mol: self._mol = pyscf_mol - # Create pure-JAX Mol object so that GTOs can be evaluated in traced - # JAX functions + else: + # If not passed a pyscf molecule, create one + if any(atom.atomic_number - atom.charge > 1.e-8 + for atom in molecule): + logging.info( + 'Fractional nuclear charge detected. ' + 'Running SCF on atoms with integer charge.' + ) + ecp = ecp or {} + core_electrons = core_electrons or {} + + nuclear_charge = 0 + for atom in molecule: + nuclear_charge += atom.atomic_number + if atom.symbol in core_electrons: + nuclear_charge -= core_electrons[atom.symbol] + charge = nuclear_charge - sum(nelectrons) + self._mol = pyscf.gto.Mole( + atom=[[atom.symbol, atom.coords] for atom in molecule], + unit='bohr') + self._mol.basis = basis + self._mol.spin = nelectrons[0] - nelectrons[1] + self._mol.charge = charge + self._mol.ecp = ecp + self._mol.build() + if self._mol.nelectron != sum(nelectrons): + raise RuntimeError('PySCF molecule not consistent with QMC molecule.') self._mol_jax = gto.Mol.from_pyscf_mol(self._mol) + if restricted: + self.mean_field = pyscf.scf.RHF(self._mol) else: - self.molecule = molecule - self.nelectrons = nelectrons - self.basis = basis - self._spin = nelectrons[0] - nelectrons[1] - self._mol = None + self.mean_field = pyscf.scf.UHF(self._mol) + # Create pure-JAX Mol object so that GTOs can be evaluated in traced + # JAX functions + self._mol_jax = gto.Mol.from_pyscf_mol(self._mol) self.restricted = restricted - self.mean_field = None self.excitations = None - pyscf.lib.param.TMPDIR = None - def run(self, - dm0: Optional[np.ndarray] = None, + dm0: np.ndarray | None = None, excitations: int = 0, excitation_type: str = 'ordered'): """Runs the Hartree-Fock calculation. @@ -121,31 +148,6 @@ def run(self, RuntimeError: If the number of electrons in the PySCF molecule is not consistent with self.nelectrons. """ - # If not passed a pyscf molecule, create one - if not self._mol: - if any(atom.atomic_number - atom.charge > 1.e-8 - for atom in self.molecule): - logging.info( - 'Fractional nuclear charge detected. ' - 'Running SCF on atoms with integer charge.' - ) - - nuclear_charge = sum(atom.atomic_number for atom in self.molecule) - charge = nuclear_charge - sum(self.nelectrons) - self._mol = pyscf.gto.Mole( - atom=[[atom.symbol, atom.coords] for atom in self.molecule], - unit='bohr') - self._mol.basis = self.basis - self._mol.spin = self._spin - self._mol.charge = charge - self._mol.build() - if self._mol.nelectron != sum(self.nelectrons): - raise RuntimeError('PySCF molecule not consistent with QMC molecule.') - self._mol_jax = gto.Mol.from_pyscf_mol(self._mol) - if self.restricted: - self.mean_field = pyscf.scf.RHF(self._mol) - else: - self.mean_field = pyscf.scf.UHF(self._mol) try: self.mean_field.kernel(dm0=dm0) except TypeError: @@ -191,12 +193,14 @@ def mo_coeff(self) -> Optional[np.ndarray]: @mo_coeff.setter def mo_coeff(self, mo_coeff): + # pytype: disable=attribute-error if (self.mean_field is not None and self.mean_field.mo_coeff is not None and self.mean_field.mo_coeff.ndim != mo_coeff.ndim): raise ValueError('Attempting to override mo_coeffs with different rank. ' f'Got {mo_coeff.shape=}, have ' f'{self.mean_field.mo_coeff.shape=}') + # pytype: enable=attribute-error self.mean_field.mo_coeff = mo_coeff def eval_mos(self, positions: NDArray) -> Tuple[NDArray, NDArray]: @@ -306,6 +310,26 @@ def eval_orbitals(self, return alpha_spin, beta_spin + def eval_slater(self, + pos: Union[jnp.ndarray, np.ndarray], + nspins: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray]: + """Evaluates the Slater determinant. + + Args: + pos: an array of electron positions to evaluate the orbitals at. + nspins: tuple with number of spin up and spin down electrons. + + Returns: + tuple with sign and log absolute value of Slater determinant. + """ + matrices = self.eval_orbitals(pos, nspins) + slogdets = [jnp.linalg.slogdet(elem) for elem in matrices] + sign_alpha, sign_beta = [elem[0] for elem in slogdets] + log_abs_wf_alpha, log_abs_wf_beta = [elem[1] for elem in slogdets] + log_abs_slater_determinant = log_abs_wf_alpha + log_abs_wf_beta + sign = sign_alpha * sign_beta + return sign, log_abs_slater_determinant + # pylint: disable=protected-access def scf_flatten(scf: Scf):