From fa11e8044e5c9eaae5d021521031d6d048a40989 Mon Sep 17 00:00:00 2001 From: David Pfau Date: Tue, 17 Sep 2024 11:29:17 +0100 Subject: [PATCH] Add option to use folx.batched_vmap in loss to reduce memory overhead PiperOrigin-RevId: 675496946 Change-Id: Ib78b2445ae7577743142d006060e3541ae2bf645 --- ferminet/base_config.py | 2 ++ ferminet/loss.py | 33 +++++++++++++++++++++++++-------- ferminet/train.py | 5 ++++- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/ferminet/base_config.py b/ferminet/base_config.py index 19db657..cadf1d0 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -61,6 +61,8 @@ def default() -> ml_collections.ConfigDict: 'iterations': 1000000, # number of iterations 'optimizer': 'kfac', # one of adam, kfac, lamb, none 'laplacian': 'default', # of of default or folx (for forward lapl) + # If 0, use standard vmap. If >0, the max batch size for batched_vmap + 'max_vmap_batch_size': 0, 'lr': { 'rate': 0.05, # learning rate 'decay': 1.0, # exponent of learning rate decay diff --git a/ferminet/loss.py b/ferminet/loss.py index d076ed1..06821ae 100644 --- a/ferminet/loss.py +++ b/ferminet/loss.py @@ -14,12 +14,14 @@ """Helper functions to create the loss and custom gradient of the loss.""" +import functools from typing import Tuple import chex from ferminet import constants from ferminet import hamiltonian from ferminet import networks +import folx import jax import jax.numpy as jnp import kfac_jax @@ -153,7 +155,8 @@ def make_loss(network: networks.LogFermiNetLike, clip_local_energy: float = 0.0, clip_from_median: bool = True, center_at_clipped_energy: bool = True, - complex_output: bool = False) -> LossFn: + complex_output: bool = False, + max_vmap_batch_size: int = 0) -> LossFn: """Creates the loss function, including custom gradients. Args: @@ -173,13 +176,17 @@ def make_loss(network: networks.LogFermiNetLike, passed back to the gradient around the clipped local energy, so the mean difference across the batch is guaranteed to be zero. complex_output: If true, the local energies will be complex valued. + max_vmap_batch_size: If 0, use standard vmap. If >0, use batched_vmap with + the given batch size. Returns: Callable with signature (params, data) and returns (loss, aux_data), where loss is the mean energy, and aux_data is an AuxiliaryLossDataobject. The loss is averaged over the batch and over all devices inside a pmap. """ - batch_local_energy = jax.vmap( + vmap = jax.vmap if max_vmap_batch_size == 0 else functools.partial( + folx.batched_vmap, max_batch_size=max_vmap_batch_size) + batch_local_energy = vmap( local_energy, in_axes=( None, @@ -188,7 +195,7 @@ def make_loss(network: networks.LogFermiNetLike, ), out_axes=(0, 0) ) - batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) + batch_network = vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) @jax.custom_jvp def total_energy( @@ -285,6 +292,7 @@ def make_wqmc_loss( clip_from_median: bool = True, center_at_clipped_energy: bool = True, complex_output: bool = False, + max_vmap_batch_size: int = 0, vmc_weight: float = 0.0 ) -> LossFn: """Creates the WQMC loss function, including custom gradients. @@ -306,6 +314,8 @@ def make_wqmc_loss( passed back to the gradient around the clipped local energy, so the mean difference across the batch is guaranteed to be zero. complex_output: If true, the local energies will be complex valued. + max_vmap_batch_size: If 0, use standard vmap. If >0, use batched_vmap with + the given batch size. vmc_weight: The weight of the contribution from the standard VMC energy gradient. @@ -314,7 +324,9 @@ def make_wqmc_loss( loss is the mean energy, and aux_data is an AuxiliaryLossDataobject. The loss is averaged over the batch and over all devices inside a pmap. """ - batch_local_energy = jax.vmap( + vmap = jax.vmap if max_vmap_batch_size == 0 else functools.partial( + folx.batched_vmap, max_batch_size=max_vmap_batch_size) + batch_local_energy = vmap( local_energy, in_axes=( None, @@ -323,7 +335,7 @@ def make_wqmc_loss( ), out_axes=(0, 0) ) - batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) + batch_network = vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) @jax.custom_jvp def total_energy( @@ -434,7 +446,8 @@ def make_energy_overlap_loss(network: networks.LogFermiNetLike, center_at_clipped_energy: bool = True, overlap_penalty: float = 1.0, overlap_weight: Tuple[float, ...] = (1.0,), - complex_output: bool = False) -> LossFn: + complex_output: bool = False, + max_vmap_batch_size: int = 0) -> LossFn: """Creates the loss function for the penalty method for excited states. Args: @@ -462,15 +475,19 @@ def make_energy_overlap_loss(network: networks.LogFermiNetLike, overlap_weight: The weight to apply to each individual energy in the overall optimization. complex_output: If true, the network output is complex-valued. + max_vmap_batch_size: If 0, use standard vmap. If >0, use batched_vmap with + the given batch size. Returns: LossFn callable which evaluates the total energy of the system. """ + vmap = jax.vmap if max_vmap_batch_size == 0 else functools.partial( + folx.batched_vmap, max_batch_size=max_vmap_batch_size) data_axes = networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0) - batch_local_energy = jax.vmap( + batch_local_energy = vmap( local_energy, in_axes=(None, 0, data_axes), out_axes=(0, 0)) - batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) + batch_network = vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0) overlap_weight = jnp.array(overlap_weight) # TODO(pfau): how much of this can be factored out with make_loss? diff --git a/ferminet/train.py b/ferminet/train.py index 5c4adbb..7abd138 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -770,6 +770,7 @@ def local_energy_and_s2_fn(params, keys, data): clip_from_median=cfg.optim.clip_median, center_at_clipped_energy=cfg.optim.center_at_clip, complex_output=use_complex, + max_vmap_batch_size=cfg.optim.get('max_vmap_batch_size', 0), ) elif cfg.optim.objective == 'wqmc': evaluate_loss = qmc_loss_functions.make_wqmc_loss( @@ -779,6 +780,7 @@ def local_energy_and_s2_fn(params, keys, data): clip_from_median=cfg.optim.clip_median, center_at_clipped_energy=cfg.optim.center_at_clip, complex_output=use_complex, + max_vmap_batch_size=cfg.optim.get('max_vmap_batch_size', 0), vmc_weight=cfg.optim.get('vmc_weight', 1.0) ) elif cfg.optim.objective == 'vmc_overlap': @@ -798,7 +800,8 @@ def local_energy_and_s2_fn(params, keys, data): center_at_clipped_energy=cfg.optim.center_at_clip, overlap_penalty=cfg.optim.overlap.penalty, overlap_weight=overlap_weight, - complex_output=cfg.network.get('complex', False)) + complex_output=cfg.network.get('complex', False), + max_vmap_batch_size=cfg.optim.get('max_vmap_batch_size', 0)) else: raise ValueError(f'Not a recognized objective: {cfg.optim.objective}')