Skip to content

Commit

Permalink
Add bottleneck envelope to FermiNet and Psiformer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580528927
Change-Id: If0a69dec087b128a9eb8459b5eaa3afa4b87ec5a
  • Loading branch information
dpfau authored and jsspencer committed Nov 24, 2023
1 parent c920271 commit 5394dba
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions ferminet/envelopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class EnvelopeType(enum.Enum):
class EnvelopeLabel(enum.Enum):
"""Available multiplicative envelope functions."""
ISOTROPIC = enum.auto()
BOTTLENECK = enum.auto()
DIAGONAL = enum.auto()
FULL = enum.auto()
NULL = enum.auto()
Expand Down Expand Up @@ -123,6 +124,47 @@ def apply(*, ae: jnp.ndarray, r_ae: jnp.ndarray, r_ee: jnp.ndarray,
return Envelope(EnvelopeType.PRE_DETERMINANT, init, apply)


def make_bottleneck_envelope(nenv: int = 16) -> Envelope:
"""Each orbital has a linear projection of a small number of envelopes.
Rather than a separate envelope for all num_determinant*num_electron
orbitals, construct a fixed number of envelopes and then project those
envelopes linearly into a num_determinant*num_electron space, one per
orbital. This has minimal impact on time but a significant impact on space.
This also is *slightly* more expressive than the isotropic envelope in some
cases, leading to improved accuracy on some systems, while being noisier on
others. This also leads to occasional numerical instability, so it is
recommended to set reset_if_nan to True when using this.
Args:
nenv: the fixed number of envelopes. Ideally smaller than num_determinants*
num_electrons
Returns:
An Envelope object with a type specifier, init and apply functions.
"""

def init(
natom: int, output_dims: Sequence[int], ndim: int = 3
) -> Sequence[Mapping[str, jnp.ndarray]]:
del ndim # unused
params = []
for output_dim in output_dims:
params.append({
'pi': jnp.ones(shape=(natom, nenv)),
'sigma': jnp.ones(shape=(natom, nenv)),
'w': jnp.ones(shape=(nenv, output_dim)) / nenv,
})
return params

def apply(*, ae: jnp.ndarray, r_ae: jnp.ndarray, r_ee: jnp.ndarray,
pi: jnp.ndarray, sigma: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
del ae, r_ee # unused
return jnp.sum(jnp.exp(-r_ae * sigma) * pi, axis=1) @ w

return Envelope(EnvelopeType.PRE_DETERMINANT, init, apply)


def make_diagonal_envelope() -> Envelope:
"""Creates a diagonal exponentially-decaying multiplicative envelope."""

Expand Down Expand Up @@ -268,6 +310,7 @@ def get_envelope(
EnvelopeLabel.STO: make_sto_envelope,
EnvelopeLabel.STO_POLY: make_sto_poly_envelope,
EnvelopeLabel.ISOTROPIC: make_isotropic_envelope,
EnvelopeLabel.BOTTLENECK: make_bottleneck_envelope,
EnvelopeLabel.DIAGONAL: make_diagonal_envelope,
EnvelopeLabel.FULL: make_full_envelope,
EnvelopeLabel.NULL: make_null_envelope,
Expand Down

0 comments on commit 5394dba

Please sign in to comment.