Skip to content

Commit

Permalink
Revert changes in the hamiltonian, update config
Browse files Browse the repository at this point in the history
  • Loading branch information
necludov committed Aug 8, 2023
1 parent 1cbaa63 commit dda112b
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ferminet/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def default() -> ml_collections.ConfigDict:
# importlib.import_module.
'config_module': __name__,
'optim': {
'objective': 'vmc', # objective type. Either 'vmc' or 'wvmc'
'objective': 'vmc', # objective type. Either 'vmc' or 'wqmc'
'iterations': 1000000, # number of iterations
'optimizer': 'kfac', # one of adam, kfac, lamb, none
'lr': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _adjust_nuclear_charge(cfg):
def get_config():
"""Returns config for running generic atoms with qmc."""
cfg = base_config.default()
cfg.system.atom = ''
cfg.system.atom = 'Li'
cfg.system.charge = 0
cfg.system.delta_charge = 0.0
cfg.system.spin_polarisation = ml_collections.FieldReference(
Expand All @@ -78,4 +78,5 @@ def get_config():
cfg.optim.clip_median = True
cfg.debug.deterministic = True
cfg.optim.kfac.norm_constraint = 1e-3
cfg.optim.objective = 'wqmc'
return cfg
12 changes: 6 additions & 6 deletions ferminet/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,25 +174,25 @@ def potential_nuclear_nuclear(charges: Array, atoms: Array) -> jnp.ndarray:
charges: Shape (natoms). Nuclear charges of the atoms.
atoms: Shape (natoms, ndim). Positions of the atoms.
"""
natoms = atoms.shape[0]
r_aa = jnp.sqrt(jnp.sum((atoms[None, ...] - atoms[:, None] +
jnp.eye(natoms)[..., None])**2, axis=-1))
r_aa = jnp.linalg.norm(atoms[None, ...] - atoms[:, None], axis=-1)
return jnp.sum(
jnp.triu((charges[None, ...] * charges[..., None]) / r_aa, k=1))


def potential_energy(r_ae: Array, pos: Array, atoms: Array,
def potential_energy(r_ae: Array, r_ee: Array, atoms: Array,
charges: Array) -> jnp.ndarray:
"""Returns the potential energy for this electron configuration.
Args:
r_ae: Shape (nelectrons, natoms). r_ae[i, j] gives the distance between
electron i and atom j.
pos: Shape (neletrons, ndim). Electron positions.
r_ee: Shape (neletrons, nelectrons, :). r_ee[i,j,0] gives the distance
between electrons i and j. Other elements in the final axes are not
required.
atoms: Shape (natoms, ndim). Positions of the atoms.
charges: Shape (natoms). Nuclear charges of the atoms.
"""
return (potential_electron_electron(pos) +
return (potential_electron_electron(r_ee) +
potential_electron_nuclear(charges, r_ae) +
potential_nuclear_nuclear(charges, atoms))

Expand Down
6 changes: 3 additions & 3 deletions ferminet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def total_energy_jvp(primals, tangents): # pylint: disable=unused-variable
return total_energy


def make_wvmc_loss(network: networks.LogFermiNetLike,
def make_wqmc_loss(network: networks.LogFermiNetLike,
local_energy: hamiltonian.LocalEnergy,
clip_local_energy: float = 0.0,
clip_from_median: bool = True,
center_at_clipped_energy: bool = True,
complex_output: bool = False) -> LossFn:
"""Creates the loss function, including custom gradients.
"""Creates the WQMC loss function, including custom gradients.
Args:
network: callable which evaluates the log of the magnitude of the
Expand Down Expand Up @@ -362,7 +362,7 @@ def log_q(_params, _pos, _spins, _atoms, _charges):
return out.sum()

score = jax.grad(log_q, argnums=1)
primals = (primals[0], data.positions, data.spins, data.atoms, data.charges)
primals = (params, data.positions, data.spins, data.atoms, data.charges)
tangents = (
tangents[0],
tangents[2].positions,
Expand Down
4 changes: 2 additions & 2 deletions ferminet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,8 @@ def log_network(*args, **kwargs):
center_at_clipped_energy=cfg.optim.center_at_clip,
complex_output=cfg.network.get('complex', False)
)
elif cfg.optim.objective == 'wvmc':
evaluate_loss = qmc_loss_functions.make_wvmc_loss(
elif cfg.optim.objective == 'wqmc':
evaluate_loss = qmc_loss_functions.make_wqmc_loss(
log_network if cfg.network.get('complex', False) else logabs_network,
local_energy,
clip_local_energy=cfg.optim.clip_local_energy,
Expand Down

0 comments on commit dda112b

Please sign in to comment.