Skip to content

Commit

Permalink
potential function fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
necludov committed Aug 5, 2023
1 parent 9314bda commit 1cbaa63
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
16 changes: 7 additions & 9 deletions ferminet/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,15 @@ def grad_phase_closure(x):
return _lapl_over_f


def potential_electron_electron(pos: Array) -> jnp.ndarray:
def potential_electron_electron(r_ee: Array) -> jnp.ndarray:
"""Returns the electron-electron potential.
Args:
pos: Shape (neletrons, ndim). Electron positions.
r_ee: Shape (neletrons, nelectrons, :). r_ee[i,j,0] gives the distance
between electrons i and j. Other elements in the final axes are not
required.
"""
nelectrons = pos.shape[0]
r_ee = jnp.sqrt(jnp.sum((pos[None, ...] - pos[:, None] +
jnp.eye(nelectrons)[..., None])**2, axis=-1))
r_ee = r_ee[jnp.triu_indices_from(r_ee, 1)]
r_ee = r_ee[jnp.triu_indices_from(r_ee[..., 0], 1)]
return (1./r_ee).sum()


Expand Down Expand Up @@ -236,9 +235,8 @@ def _e_l(
data: MCMC configuration.
"""
del key # unused
_, _, r_ae, _ = networks.construct_input_features(data.positions, data.atoms)
pos = data.positions.reshape(r_ae.shape[0],-1)
potential = potential_energy(r_ae, pos, data.atoms, charges)
_, _, r_ae, r_ee = networks.construct_input_features(data.positions, data.atoms)
potential = potential_energy(r_ae, r_ee, data.atoms, charges)
kinetic = ke(params, data)
return potential + kinetic

Expand Down
5 changes: 2 additions & 3 deletions ferminet/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,8 @@ def construct_input_features(
r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True)
# Avoid computing the norm of zero, as is has undefined grad
n = ee.shape[0]
# r_ee = (
# jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n)))
r_ee = jnp.sqrt(((ee + jnp.eye(n)[..., None])**2).sum(2)) * (1.0 - jnp.eye(n))
r_ee = (
jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n)))
return ae, ee, r_ae, r_ee[..., None]


Expand Down

0 comments on commit 1cbaa63

Please sign in to comment.