diff --git a/ferminet/hamiltonian.py b/ferminet/hamiltonian.py index 252a4fe..448bfaa 100644 --- a/ferminet/hamiltonian.py +++ b/ferminet/hamiltonian.py @@ -144,16 +144,15 @@ def grad_phase_closure(x): return _lapl_over_f -def potential_electron_electron(pos: Array) -> jnp.ndarray: +def potential_electron_electron(r_ee: Array) -> jnp.ndarray: """Returns the electron-electron potential. Args: - pos: Shape (neletrons, ndim). Electron positions. + r_ee: Shape (neletrons, nelectrons, :). r_ee[i,j,0] gives the distance + between electrons i and j. Other elements in the final axes are not + required. """ - nelectrons = pos.shape[0] - r_ee = jnp.sqrt(jnp.sum((pos[None, ...] - pos[:, None] + - jnp.eye(nelectrons)[..., None])**2, axis=-1)) - r_ee = r_ee[jnp.triu_indices_from(r_ee, 1)] + r_ee = r_ee[jnp.triu_indices_from(r_ee[..., 0], 1)] return (1./r_ee).sum() @@ -236,9 +235,8 @@ def _e_l( data: MCMC configuration. """ del key # unused - _, _, r_ae, _ = networks.construct_input_features(data.positions, data.atoms) - pos = data.positions.reshape(r_ae.shape[0],-1) - potential = potential_energy(r_ae, pos, data.atoms, charges) + _, _, r_ae, r_ee = networks.construct_input_features(data.positions, data.atoms) + potential = potential_energy(r_ae, r_ee, data.atoms, charges) kinetic = ke(params, data) return potential + kinetic diff --git a/ferminet/networks.py b/ferminet/networks.py index 070ba72..d8075b8 100644 --- a/ferminet/networks.py +++ b/ferminet/networks.py @@ -471,9 +471,8 @@ def construct_input_features( r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True) # Avoid computing the norm of zero, as is has undefined grad n = ee.shape[0] - # r_ee = ( - # jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n))) - r_ee = jnp.sqrt(((ee + jnp.eye(n)[..., None])**2).sum(2)) * (1.0 - jnp.eye(n)) + r_ee = ( + jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n))) return ae, ee, r_ae, r_ee[..., None]