Skip to content

Commit

Permalink
Enable optimization of complex wavefunctions.
Browse files Browse the repository at this point in the history
This is described in Cassella et al., PRL 130, 036401 (2023)

PiperOrigin-RevId: 518242819
Change-Id: Ia87bc23e049797cdea1725ffd15e4e574c87c37b
  • Loading branch information
dpfau authored and jsspencer committed Apr 25, 2023
1 parent c1d235a commit 7ae9502
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 40 deletions.
2 changes: 2 additions & 0 deletions ferminet/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def default() -> ml_collections.ConfigDict:
},
'network': {
'network_type': 'ferminet', # One of 'ferminet' or 'psiformer'.
# If true, the network outputs complex numbers rather than real.
'complex': False,
# Config specific to original FermiNet architecture.
# Only used if network_type is 'ferminet'.
'ferminet': {
Expand Down
54 changes: 46 additions & 8 deletions ferminet/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __call__(
charges: jnp.ndarray,
nspins: Sequence[int],
use_scan: bool = False,
complex_output: bool = False,
**kwargs: Any
) -> LocalEnergy:
"""Builds the LocalEnergy function.
Expand All @@ -61,6 +62,7 @@ def __call__(
charges: nuclear charges.
nspins: Number of particles of each spin.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
**kwargs: additional kwargs to use for creating the specific Hamiltonian.
"""

Expand All @@ -70,37 +72,70 @@ def __call__(
]


def select_output(f: Callable[..., Sequence[Any]],
                  argnum: int) -> Callable[..., Any]:
  """Returns a wrapper around f which yields only its argnum-th result.

  Args:
    f: Callable returning an indexable sequence of results.
    argnum: Index of the result to pick out of the output of f.

  Returns:
    Callable which forwards all positional and keyword arguments to f
    unchanged and returns element `argnum` of whatever f returns.
  """

  def f_selected(*args, **kwargs):
    # f is evaluated in full on every call; only the selected element is kept.
    return f(*args, **kwargs)[argnum]

  return f_selected


def local_kinetic_energy(
f: networks.LogFermiNetLike, use_scan: bool = False
f: networks.FermiNetLike,
use_scan: bool = False,
complex_output: bool = False,
) -> KineticEnergy:
r"""Creates a function for the local kinetic energy, -1/2 \nabla^2 ln|f|.
Args:
f: Callable which evaluates the log of the magnitude of the wavefunction.
f: Callable which evaluates the wavefunction as a
(sign or phase, log magnitude) tuple.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
Returns:
Callable which evaluates the local kinetic energy,
-1/2f \nabla^2 f = -1/2 (\nabla^2 log|f| + (\nabla log|f|)^2).
If complex_output is true, the phase (output 0 of f) additionally
contributes
-1/2 (i \nabla^2 phase - (\nabla phase)^2 + 2i \nabla log|f| . \nabla phase),
i.e. the result is -1/2 (\nabla^2 g + (\nabla g)^2) with
g = log|f| + i phase.
"""

# f returns a (sign or phase, log magnitude) tuple; split it into two
# functions so that each output can be differentiated separately.
phase_f = select_output(f, 0)
logabs_f = select_output(f, 1)

def _lapl_over_f(params, data):
n = data.positions.shape[0]
eye = jnp.eye(n)
grad_f = jax.grad(f, argnums=1)
grad_f = jax.grad(logabs_f, argnums=1)
def grad_f_closure(x):
return grad_f(params, x, data.spins, data.atoms, data.charges)

# primal is the gradient of log|f| at data.positions; dgrad_f is its
# Jacobian-vector product, so dgrad_f(eye[i])[i] is the i-th diagonal element
# of the Hessian of log|f|.
primal, dgrad_f = jax.linearize(grad_f_closure, data.positions)

if complex_output:
# Same construction for the phase. jax.grad is applied to phase_f, so the
# phase is presumably a real-valued angle here -- TODO confirm against the
# network's output convention.
grad_phase = jax.grad(phase_f, argnums=1)
def grad_phase_closure(x):
return grad_phase(params, x, data.spins, data.atoms, data.charges)
phase_primal, dgrad_phase = jax.linearize(
grad_phase_closure, data.positions)
# The Hessian diagonal of the phase enters as the imaginary part.
hessian_diagonal = (
lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i]
)
else:
hessian_diagonal = lambda i: dgrad_f(eye[i])[i]

# Sum the n diagonal Hessian elements (the Laplacian), either by stacking
# them with lax.scan or by accumulating them in a lax.fori_loop.
if use_scan:
_, diagonal = lax.scan(
lambda i, _: (i + 1, dgrad_f(eye[i])[i]), 0, None, length=n)
lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n)
result = -0.5 * jnp.sum(diagonal)
else:
result = -0.5 * lax.fori_loop(
0, n, lambda i, val: val + dgrad_f(eye[i])[i], 0.0)
return result - 0.5 * jnp.sum(primal ** 2)
0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
# Squared-gradient term -1/2 (grad log|f| + i grad phase)^2, expanded into
# its real part (first two updates) and imaginary cross term (last update).
result -= 0.5 * jnp.sum(primal ** 2)
if complex_output:
result += 0.5 * jnp.sum(phase_primal ** 2)
result -= 1.j * jnp.sum(primal * phase_primal)
return result

return _lapl_over_f

Expand Down Expand Up @@ -165,6 +200,7 @@ def local_energy(
charges: jnp.ndarray,
nspins: Sequence[int],
use_scan: bool = False,
complex_output: bool = False,
) -> LocalEnergy:
"""Creates the function to evaluate the local energy.
Expand All @@ -175,15 +211,17 @@ def local_energy(
charges: Shape (natoms). Nuclear charges of the atoms.
nspins: Number of particles of each spin.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
Returns:
Callable with signature e_l(params, key, data) which evaluates the local
energy of the wavefunction given the parameters params, RNG state key,
and a single MCMC configuration in data.
"""
del nspins
log_abs_f = lambda *args, **kwargs: f(*args, **kwargs)[1]
ke = local_kinetic_energy(log_abs_f, use_scan=use_scan)
ke = local_kinetic_energy(f,
use_scan=use_scan,
complex_output=complex_output)

def _e_l(
params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData
Expand Down
50 changes: 39 additions & 11 deletions ferminet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def clip_local_values(
clip_scale: float,
clip_from_median: bool,
center_at_clipped_value: bool,
complex_output: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Clips local operator estimates to remove outliers.
Expand All @@ -91,6 +92,7 @@ def clip_local_values(
center_at_clipped_value: If true, center the local energy differences passed
back to the gradient around the clipped quantities, so the mean difference
across the batch is guaranteed to be zero.
complex_output: If true, the local energies will be complex valued.
Returns:
Tuple of the central value (estimate of the expectation value of the
Expand All @@ -101,16 +103,27 @@ def clip_local_values(

batch_mean = lambda values: constants.pmean(jnp.mean(values))

def clip_at_total_variation(values, center, scale):
tv = batch_mean(jnp.abs(values- center))
return jnp.clip(values, center - scale * tv, center + scale * tv)

if clip_from_median:
# More natural place to center the clipping, but expensive due to both
# the median and all_gather (at least on multihost)
clip_center = jnp.median(constants.all_gather(local_values))
else:
clip_center = mean_local_values
# roughly, the total variation of the local energies
tv = batch_mean(jnp.abs(local_values - clip_center))
clipped_local_values = jnp.clip(local_values, clip_center - clip_scale * tv,
clip_center + clip_scale * tv)
if complex_output:
clipped_local_values = (
clip_at_total_variation(
local_values.real, clip_center.real, clip_scale) +
1.j * clip_at_total_variation(
local_values.imag, clip_center.imag, clip_scale)
)
else:
clipped_local_values = clip_at_total_variation(
local_values, clip_center, clip_scale)
if center_at_clipped_value:
diff_center = batch_mean(clipped_local_values)
else:
Expand All @@ -123,7 +136,8 @@ def make_loss(network: networks.LogFermiNetLike,
local_energy: hamiltonian.LocalEnergy,
clip_local_energy: float = 0.0,
clip_from_median: bool = True,
center_at_clipped_energy: bool = True) -> LossFn:
center_at_clipped_energy: bool = True,
complex_output: bool = False) -> LossFn:
"""Creates the loss function, including custom gradients.
Args:
Expand All @@ -142,6 +156,7 @@ def make_loss(network: networks.LogFermiNetLike,
center_at_clipped_energy: If true, center the local energy differences
passed back to the gradient around the clipped local energy, so the mean
difference across the batch is guaranteed to be zero.
complex_output: If true, the local energies will be complex valued.
Returns:
Callable with signature (params, data) and returns (loss, aux_data), where
Expand Down Expand Up @@ -185,9 +200,10 @@ def total_energy(
keys = jax.random.split(key, num=data.positions.shape[0])
e_l = batch_local_energy(params, keys, data)
loss = constants.pmean(jnp.mean(e_l))
variance = constants.pmean(jnp.mean((e_l - loss)**2))
loss_diff = e_l - loss
variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff)))
return loss, AuxiliaryLossData(
variance=variance, local_energy=e_l, clipped_energy=e_l)
variance=variance.real, local_energy=e_l, clipped_energy=e_l)

@total_energy.defjvp
def total_energy_jvp(primals, tangents): # pylint: disable=unused-variable
Expand All @@ -201,7 +217,8 @@ def total_energy_jvp(primals, tangents): # pylint: disable=unused-variable
loss,
clip_local_energy,
clip_from_median,
center_at_clipped_energy)
center_at_clipped_energy,
complex_output)
else:
diff = aux_data.local_energy - loss

Expand All @@ -220,10 +237,21 @@ def total_energy_jvp(primals, tangents): # pylint: disable=unused-variable
data_tangents.charges,
)
psi_primal, psi_tangent = jax.jvp(batch_network, primals, tangents)
kfac_jax.register_normal_predictive_distribution(psi_primal[:, None])
primals_out = loss, aux_data
device_batch_size = jnp.shape(aux_data.local_energy)[0]
tangents_out = (jnp.dot(psi_tangent, diff) / device_batch_size, aux_data)
if complex_output:
clipped_el = diff + aux_data.clipped_energy
term1 = (jnp.dot(clipped_el, jnp.conjugate(psi_tangent)) +
jnp.dot(jnp.conjugate(clipped_el), psi_tangent))
term2 = jnp.sum(aux_data.clipped_energy*psi_tangent.real)
kfac_jax.register_normal_predictive_distribution(
psi_primal.real[:, None])
primals_out = loss.real, aux_data
device_batch_size = jnp.shape(aux_data.local_energy)[0]
tangents_out = ((term1 - 2*term2).real / device_batch_size, aux_data)
else:
kfac_jax.register_normal_predictive_distribution(psi_primal[:, None])
primals_out = loss, aux_data
device_batch_size = jnp.shape(aux_data.local_energy)[0]
tangents_out = (jnp.dot(psi_tangent, diff) / device_batch_size, aux_data)
return primals_out, tangents_out

return total_energy
17 changes: 12 additions & 5 deletions ferminet/network_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def slogdet(x):
sign, (natural) logarithm of the determinant of x.
"""
if x.shape[-1] == 1:
sign = jnp.sign(x[..., 0, 0])
if x.dtype == jnp.complex64 or x.dtype == jnp.complex128:
sign = x[..., 0, 0] / jnp.abs(x[..., 0, 0])
else:
sign = jnp.sign(x[..., 0, 0])
logdet = jnp.log(jnp.abs(x[..., 0, 0]))
else:
sign, logdet = jnp.linalg.slogdet(x)
Expand Down Expand Up @@ -154,17 +157,21 @@ def logdet_matmul(
[x.reshape(-1) for x in xs if x.shape[-1] == 1], 1)
# Pass initial value to functools so sign_in = 1, logdet = 0 if all matrices
# are 1x1.
sign_in, logdet = functools.reduce(
phase_in, logdet = functools.reduce(
lambda a, b: (a[0] * b[0], a[1] + b[1]),
[slogdet(x) for x in xs if x.shape[-1] > 1], (1, 0))

# log-sum-exp trick
maxlogdet = jnp.max(logdet)
det = sign_in * det1d * jnp.exp(logdet - maxlogdet)
det = phase_in * det1d * jnp.exp(logdet - maxlogdet)
if w is None:
result = jnp.sum(det)
else:
result = jnp.matmul(det, w)[0]
sign_out = jnp.sign(result)
# return phase as an angle in radians for complex results (or a sign for real
# results), rather than as a unit-norm complex number
if result.dtype == jnp.complex64 or result.dtype == jnp.complex128:
phase_out = jnp.angle(result) # result / jnp.abs(result)
else:
phase_out = jnp.sign(result)
log_out = jnp.log(jnp.abs(result)) + maxlogdet
return sign_out, log_out
return phase_out, log_out
17 changes: 16 additions & 1 deletion ferminet/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class BaseNetworkOptions:
feature_layer: Feature object to create and apply the input features for the
one- and two-electron layers.
jastrow: Type of Jastrow factor if used, or 'none' if no Jastrow factor.
complex_output: If true, the network outputs complex numbers.
"""

ndim: int = 3
Expand All @@ -289,6 +290,7 @@ class BaseNetworkOptions:
takes_self=False))
feature_layer: FeatureLayer = None
jastrow: jastrows.JastrowType = jastrows.JastrowType.NONE
complex_output: bool = False


@attr.s(auto_attribs=True, kw_only=True)
Expand Down Expand Up @@ -1103,6 +1105,8 @@ def init(key: chex.PRNGKey) -> ParamTree:
# Spin-factored block-diagonal determinant. Need nspin orbitals per
# electron per determinant.
norbitals = nspin * options.determinants
if options.complex_output:
norbitals *= 2 # one output is real, one is imaginary
nspin_orbitals.append(norbitals)

# create envelope params
Expand All @@ -1112,7 +1116,10 @@ def init(key: chex.PRNGKey) -> ParamTree:
output_dims = dims_orbital_in
elif options.envelope.apply_type == envelopes.EnvelopeType.PRE_DETERMINANT:
# Applied to orbitals.
output_dims = nspin_orbitals
if options.complex_output:
output_dims = [nspin_orbital // 2 for nspin_orbital in nspin_orbitals]
else:
output_dims = nspin_orbitals
else:
raise ValueError('Unknown envelope type')
params['envelope'] = options.envelope.init(
Expand Down Expand Up @@ -1191,6 +1198,11 @@ def apply(
network_blocks.linear_layer(h, **p)
for h, p in zip(h_to_orbitals, params['orbital'])
]
if options.complex_output:
# create imaginary orbitals
orbitals = [
orbital[..., ::2] + 1.0j * orbital[..., 1::2] for orbital in orbitals
]

# Apply envelopes if required.
if options.envelope.apply_type == envelopes.EnvelopeType.PRE_DETERMINANT:
Expand Down Expand Up @@ -1242,6 +1254,7 @@ def make_fermi_net(
envelope: Optional[envelopes.Envelope] = None,
feature_layer: Optional[FeatureLayer] = None,
jastrow: Union[str, jastrows.JastrowType] = jastrows.JastrowType.NONE,
complex_output: bool = False,
bias_orbitals: bool = False,
full_det: bool = True,
rescale_inputs: bool = False,
Expand All @@ -1264,6 +1277,7 @@ def make_fermi_net(
envelope: Envelope to use to impose orbitals go to zero at infinity.
feature_layer: Input feature construction.
jastrow: Type of Jastrow factor if used, or no jastrow if 'default'.
complex_output: If true, the network outputs complex numbers.
bias_orbitals: If true, include a bias in the final linear layer to shape
the outputs into orbitals.
full_det: If true, evaluate determinants over all electrons. Otherwise,
Expand Down Expand Up @@ -1320,6 +1334,7 @@ def make_fermi_net(
envelope=envelope,
feature_layer=feature_layer,
jastrow=jastrow,
complex_output=complex_output,
bias_orbitals=bias_orbitals,
full_det=full_det,
hidden_dims=hidden_dims,
Expand Down
6 changes: 4 additions & 2 deletions ferminet/pbc/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def local_energy(
charges: jnp.ndarray,
nspins: Sequence[int],
use_scan: bool = False,
complex_output: bool = False,
lattice: Optional[jnp.ndarray] = None,
heg: bool = True,
convergence_radius: int = 5,
Expand All @@ -169,6 +170,7 @@ def local_energy(
charges: Shape (natoms). Nuclear charges of the atoms.
nspins: Number of particles of each spin.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
lattice: Shape (ndim, ndim). Matrix of lattice vectors. Default: identity
matrix.
heg: bool. Flag to enable features specific to the electron gas.
Expand All @@ -183,8 +185,8 @@ def local_energy(
if lattice is None:
lattice = jnp.eye(3)

log_abs_f = lambda *args, **kwargs: f(*args, **kwargs)[1]
ke = hamiltonian.local_kinetic_energy(log_abs_f, use_scan=use_scan)
ke = hamiltonian.local_kinetic_energy(f, use_scan=use_scan,
complex_output=complex_output)
potential_energy = make_ewald_potential(
lattice, atoms, charges, convergence_radius, heg
)
Expand Down
Loading

0 comments on commit 7ae9502

Please sign in to comment.