Add option to use folx.batched_vmap in loss to reduce memory overhead
PiperOrigin-RevId: 675496946
Change-Id: Ib78b2445ae7577743142d006060e3541ae2bf645
dpfau authored and jsspencer committed Sep 24, 2024
1 parent 67cf795 commit fa11e80
Showing 3 changed files with 31 additions and 9 deletions.
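
For context: folx.batched_vmap is used here as a drop-in replacement for jax.vmap that evaluates the mapped function in chunks of at most max_batch_size along the batch axis, trading a little compute for a smaller peak memory footprint. Below is a minimal standalone sketch of the selection pattern this commit introduces in ferminet/loss.py; select_vmap and log_prob are illustrative names, not part of the commit.

import functools

import folx
import jax
import jax.numpy as jnp


def select_vmap(max_vmap_batch_size: int = 0):
  """Returns jax.vmap if the batch size is 0, else a chunked batched_vmap."""
  if max_vmap_batch_size == 0:
    return jax.vmap
  return functools.partial(
      folx.batched_vmap, max_batch_size=max_vmap_batch_size)


# Toy per-walker function; the commit maps the local energy the same way.
def log_prob(params, x):
  return -jnp.sum((x - params) ** 2)

batched = select_vmap(64)(log_prob, in_axes=(None, 0), out_axes=0)
out = batched(jnp.zeros(3), jnp.ones((1024, 3)))  # evaluated 64 walkers at a time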
2 changes: 2 additions & 0 deletions ferminet/base_config.py
@@ -61,6 +61,8 @@ def default() -> ml_collections.ConfigDict:
         'iterations': 1000000,  # number of iterations
         'optimizer': 'kfac',  # one of adam, kfac, lamb, none
         'laplacian': 'default',  # one of default or folx (for forward lapl)
+        # If 0, use standard vmap. If >0, the max batch size for batched_vmap
+        'max_vmap_batch_size': 0,
         'lr': {
             'rate': 0.05,  # learning rate
             'decay': 1.0,  # exponent of learning rate decay
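
For example, the option can be enabled from a config like so (a hedged sketch; the default of 0 keeps plain jax.vmap):

from ferminet import base_config

cfg = base_config.default()
# Evaluate local energies in chunks of at most 256 walkers to cap peak memory.
cfg.optim.max_vmap_batch_size = 256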
33 changes: 25 additions & 8 deletions ferminet/loss.py
@@ -14,12 +14,14 @@

 """Helper functions to create the loss and custom gradient of the loss."""

+import functools
 from typing import Tuple

 import chex
 from ferminet import constants
 from ferminet import hamiltonian
 from ferminet import networks
+import folx
 import jax
 import jax.numpy as jnp
 import kfac_jax
@@ -153,7 +155,8 @@ def make_loss(network: networks.LogFermiNetLike,
               clip_local_energy: float = 0.0,
               clip_from_median: bool = True,
               center_at_clipped_energy: bool = True,
-              complex_output: bool = False) -> LossFn:
+              complex_output: bool = False,
+              max_vmap_batch_size: int = 0) -> LossFn:
   """Creates the loss function, including custom gradients.

   Args:
@@ -173,13 +176,17 @@ def make_loss(network: networks.LogFermiNetLike,
       passed back to the gradient around the clipped local energy, so the mean
       difference across the batch is guaranteed to be zero.
     complex_output: If true, the local energies will be complex valued.
+    max_vmap_batch_size: If 0, use standard vmap. If >0, use batched_vmap with
+      the given batch size.

   Returns:
     Callable with signature (params, data) and returns (loss, aux_data), where
     loss is the mean energy, and aux_data is an AuxiliaryLossData object. The
     loss is averaged over the batch and over all devices inside a pmap.
   """
-  batch_local_energy = jax.vmap(
+  vmap = jax.vmap if max_vmap_batch_size == 0 else functools.partial(
+      folx.batched_vmap, max_batch_size=max_vmap_batch_size)
+  batch_local_energy = vmap(
       local_energy,
       in_axes=(
           None,
@@ -188,7 +195,7 @@
       ),
       out_axes=(0, 0)
   )
-  batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)
+  batch_network = vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)

   @jax.custom_jvp
   def total_energy(
Expand Down Expand Up @@ -285,6 +292,7 @@ def make_wqmc_loss(
clip_from_median: bool = True,
center_at_clipped_energy: bool = True,
complex_output: bool = False,
max_vmap_batch_size: int = 0,
vmc_weight: float = 0.0
) -> LossFn:
"""Creates the WQMC loss function, including custom gradients.
@@ -306,6 +314,8 @@
       passed back to the gradient around the clipped local energy, so the mean
       difference across the batch is guaranteed to be zero.
     complex_output: If true, the local energies will be complex valued.
+    max_vmap_batch_size: If 0, use standard vmap. If >0, use batched_vmap with
+      the given batch size.
     vmc_weight: The weight of the contribution from the standard VMC energy
       gradient.
@@ -314,7 +324,9 @@
     loss is the mean energy, and aux_data is an AuxiliaryLossData object. The
     loss is averaged over the batch and over all devices inside a pmap.
   """
-  batch_local_energy = jax.vmap(
+  vmap = jax.vmap if max_vmap_batch_size == 0 else functools.partial(
+      folx.batched_vmap, max_batch_size=max_vmap_batch_size)
+  batch_local_energy = vmap(
       local_energy,
       in_axes=(
           None,
@@ -323,7 +335,7 @@
       ),
       out_axes=(0, 0)
   )
-  batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)
+  batch_network = vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)

   @jax.custom_jvp
   def total_energy(
Expand Down Expand Up @@ -434,7 +446,8 @@ def make_energy_overlap_loss(network: networks.LogFermiNetLike,
center_at_clipped_energy: bool = True,
overlap_penalty: float = 1.0,
overlap_weight: Tuple[float, ...] = (1.0,),
complex_output: bool = False) -> LossFn:
complex_output: bool = False,
max_vmap_batch_size: int = 0) -> LossFn:
"""Creates the loss function for the penalty method for excited states.
Args:
@@ -462,15 +475,19 @@ def make_energy_overlap_loss(network: networks.LogFermiNetLike,
     overlap_weight: The weight to apply to each individual energy in the overall
       optimization.
     complex_output: If true, the network output is complex-valued.
+    max_vmap_batch_size: If 0, use standard vmap. If >0, use batched_vmap with
+      the given batch size.

   Returns:
     LossFn callable which evaluates the total energy of the system.
   """

+  vmap = jax.vmap if max_vmap_batch_size == 0 else functools.partial(
+      folx.batched_vmap, max_batch_size=max_vmap_batch_size)
   data_axes = networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0)
-  batch_local_energy = jax.vmap(
+  batch_local_energy = vmap(
       local_energy, in_axes=(None, 0, data_axes), out_axes=(0, 0))
-  batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)
+  batch_network = vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)
   overlap_weight = jnp.array(overlap_weight)

   # TODO(pfau): how much of this can be factored out with make_loss?
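
As a sanity check on the substitution (a sketch, not part of the commit): batched_vmap should agree numerically with vmap, since it only changes how the batch axis is chunked. The in_axes/out_axes keywords follow the usage demonstrated in the diff above.

import folx
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2)
xs = jnp.arange(512.0).reshape(128, 4)

full = jax.vmap(f)(xs)
# Same mapped function, evaluated 16 rows at a time to bound peak memory.
chunked = folx.batched_vmap(f, max_batch_size=16, in_axes=0, out_axes=0)(xs)
assert jnp.allclose(full, chunked)  # identical values, smaller memory footprint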
5 changes: 4 additions & 1 deletion ferminet/train.py
@@ -770,6 +770,7 @@ def local_energy_and_s2_fn(params, keys, data):
         clip_from_median=cfg.optim.clip_median,
         center_at_clipped_energy=cfg.optim.center_at_clip,
         complex_output=use_complex,
+        max_vmap_batch_size=cfg.optim.get('max_vmap_batch_size', 0),
     )
   elif cfg.optim.objective == 'wqmc':
     evaluate_loss = qmc_loss_functions.make_wqmc_loss(
@@ -779,6 +780,7 @@ def local_energy_and_s2_fn(params, keys, data):
         clip_from_median=cfg.optim.clip_median,
         center_at_clipped_energy=cfg.optim.center_at_clip,
         complex_output=use_complex,
+        max_vmap_batch_size=cfg.optim.get('max_vmap_batch_size', 0),
         vmc_weight=cfg.optim.get('vmc_weight', 1.0)
     )
   elif cfg.optim.objective == 'vmc_overlap':
@@ -798,7 +800,8 @@ def local_energy_and_s2_fn(params, keys, data):
         center_at_clipped_energy=cfg.optim.center_at_clip,
         overlap_penalty=cfg.optim.overlap.penalty,
         overlap_weight=overlap_weight,
-        complex_output=cfg.network.get('complex', False))
+        complex_output=cfg.network.get('complex', False),
+        max_vmap_batch_size=cfg.optim.get('max_vmap_batch_size', 0))
   else:
     raise ValueError(f'Not a recognized objective: {cfg.optim.objective}')
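
Because the flag is read with cfg.optim.get('max_vmap_batch_size', 0), configs built before this change still work and fall back to plain jax.vmap. A minimal sketch of that fallback:

import ml_collections

old_optim = ml_collections.ConfigDict({'iterations': 1000})  # pre-change config
assert old_optim.get('max_vmap_batch_size', 0) == 0  # defaults to standard vmap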
