diff --git a/ferminet/base_config.py b/ferminet/base_config.py index 00e8552..70d653b 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -53,7 +53,7 @@ def default() -> ml_collections.ConfigDict: # importlib.import_module. 'config_module': __name__, 'optim': { - 'objective': 'vmc', # objective type. Either 'vmc' or 'wvmc' + 'objective': 'vmc', # objective type. Either 'vmc' or 'wqmc' 'iterations': 1000000, # number of iterations 'optimizer': 'kfac', # one of adam, kfac, lamb, none 'lr': { diff --git a/ferminet/configs/atom_test.py b/ferminet/configs/li_wqmc.py similarity index 97% rename from ferminet/configs/atom_test.py rename to ferminet/configs/li_wqmc.py index 6b5f9c2..3d35d00 100644 --- a/ferminet/configs/atom_test.py +++ b/ferminet/configs/li_wqmc.py @@ -64,7 +64,7 @@ def _adjust_nuclear_charge(cfg): def get_config(): """Returns config for running generic atoms with qmc.""" cfg = base_config.default() - cfg.system.atom = '' + cfg.system.atom = 'Li' cfg.system.charge = 0 cfg.system.delta_charge = 0.0 cfg.system.spin_polarisation = ml_collections.FieldReference( @@ -78,4 +78,5 @@ def get_config(): cfg.optim.clip_median = True cfg.debug.deterministic = True cfg.optim.kfac.norm_constraint = 1e-3 + cfg.optim.objective = 'wqmc' return cfg diff --git a/ferminet/hamiltonian.py b/ferminet/hamiltonian.py index 448bfaa..a7257d7 100644 --- a/ferminet/hamiltonian.py +++ b/ferminet/hamiltonian.py @@ -174,25 +174,25 @@ def potential_nuclear_nuclear(charges: Array, atoms: Array) -> jnp.ndarray: charges: Shape (natoms). Nuclear charges of the atoms. atoms: Shape (natoms, ndim). Positions of the atoms. """ - natoms = atoms.shape[0] - r_aa = jnp.sqrt(jnp.sum((atoms[None, ...] - atoms[:, None] + - jnp.eye(natoms)[..., None])**2, axis=-1)) + r_aa = jnp.linalg.norm(atoms[None, ...] - atoms[:, None], axis=-1) return jnp.sum( jnp.triu((charges[None, ...] * charges[..., None]) / r_aa, k=1)) -def potential_energy(r_ae: Array, pos: Array, atoms: Array, +def potential_energy(r_ae: Array, r_ee: Array, atoms: Array, charges: Array) -> jnp.ndarray: """Returns the potential energy for this electron configuration. Args: r_ae: Shape (nelectrons, natoms). r_ae[i, j] gives the distance between electron i and atom j. - pos: Shape (neletrons, ndim). Electron positions. + r_ee: Shape (neletrons, nelectrons, :). r_ee[i,j,0] gives the distance + between electrons i and j. Other elements in the final axes are not + required. atoms: Shape (natoms, ndim). Positions of the atoms. charges: Shape (natoms). Nuclear charges of the atoms. """ - return (potential_electron_electron(pos) + + return (potential_electron_electron(r_ee) + potential_electron_nuclear(charges, r_ae) + potential_nuclear_nuclear(charges, atoms)) diff --git a/ferminet/loss.py b/ferminet/loss.py index 884b60e..df6c2b9 100644 --- a/ferminet/loss.py +++ b/ferminet/loss.py @@ -258,13 +258,13 @@ def total_energy_jvp(primals, tangents): # pylint: disable=unused-variable return total_energy -def make_wvmc_loss(network: networks.LogFermiNetLike, +def make_wqmc_loss(network: networks.LogFermiNetLike, local_energy: hamiltonian.LocalEnergy, clip_local_energy: float = 0.0, clip_from_median: bool = True, center_at_clipped_energy: bool = True, complex_output: bool = False) -> LossFn: - """Creates the loss function, including custom gradients. + """Creates the WQMC loss function, including custom gradients. Args: network: callable which evaluates the log of the magnitude of the @@ -362,7 +362,7 @@ def log_q(_params, _pos, _spins, _atoms, _charges): return out.sum() score = jax.grad(log_q, argnums=1) - primals = (primals[0], data.positions, data.spins, data.atoms, data.charges) + primals = (params, data.positions, data.spins, data.atoms, data.charges) tangents = ( tangents[0], tangents[2].positions, diff --git a/ferminet/train.py b/ferminet/train.py index a90321b..a260ba7 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -580,8 +580,8 @@ def log_network(*args, **kwargs): center_at_clipped_energy=cfg.optim.center_at_clip, complex_output=cfg.network.get('complex', False) ) - elif cfg.optim.objective == 'wvmc': - evaluate_loss = qmc_loss_functions.make_wvmc_loss( + elif cfg.optim.objective == 'wqmc': + evaluate_loss = qmc_loss_functions.make_wqmc_loss( log_network if cfg.network.get('complex', False) else logabs_network, local_energy, clip_local_energy=cfg.optim.clip_local_energy,