Skip to content

Commit

Permalink
Merge pull request #109 from mila-iqia/langevin_adaptative_corrector
Browse files Browse the repository at this point in the history
Langevin adaptative corrector
  • Loading branch information
sblackburn86 authored Dec 3, 2024
2 parents c382f6e + 568bd78 commit 2f6bbb3
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Optional

import torch

from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \
LangevinGenerator
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
PredictorCorrectorSamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \
ScoreNetwork
from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
NoiseParameters


class AdaptiveCorrectorGenerator(LangevinGenerator):
    """Langevin generator moving the relative coordinates only in the corrector, with an adaptive step size.

    The predictor step of this generator leaves the relative coordinates unchanged. The corrector
    step size is computed from the norms of the score and of the gaussian noise, following the
    Langevin corrector described in Song et al. 2021, namely:
        "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS"
    """

    def __init__(
        self,
        noise_parameters: NoiseParameters,
        sampling_parameters: PredictorCorrectorSamplingParameters,
        axl_network: ScoreNetwork,
    ):
        """Init method.

        Args:
            noise_parameters: noise parameters, forwarded to the parent class. The ``corrector_r``
                field is read from this object and stored on the generator.
            sampling_parameters: sampling parameters, forwarded to the parent class.
            axl_network: score network, forwarded to the parent class.
        """
        parent_arguments = dict(
            noise_parameters=noise_parameters,
            sampling_parameters=sampling_parameters,
            axl_network=axl_network,
        )
        super().__init__(**parent_arguments)
        # hyper-parameter "r" scaling the adaptive corrector step size.
        self.corrector_r = noise_parameters.corrector_r

    def _relative_coordinates_update_predictor_step(
        self,
        relative_coordinates: torch.Tensor,
        sigma_normalized_scores: torch.Tensor,
        sigma_i: torch.Tensor,
        score_weight: torch.Tensor,
        gaussian_noise_weight: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the relative coordinates as-is: the predictor does not move them.

        All arguments other than ``relative_coordinates`` are ignored; they are present so that the
        signature stays compatible with the parent class.

        Returns:
            relative_coordinates: the very same tensor that was passed in.
        """
        return relative_coordinates

    def _get_corrector_step_size(
        self,
        index_i: int,
        sigma_i: torch.Tensor,
        model_predictions_i: AXL,
        z: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the adaptive size of the corrector step for the relative coordinates update.

        The prefactors driving the relative coordinates update are derived from the sigma normalized
        score at the current corrector step. The relative coordinates update is given by:

        .. math::

            x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt{2 \epsilon_i} z

        where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal
        distribution and :math:`\epsilon_i` is given by:

        .. math::

            \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2

        where :math:`r` is a hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm.

        Args:
            index_i: index of the time step. Not used by this implementation; kept so the signature
                matches the parent class.
            sigma_i: noise level at this step; the sigma normalized score norm is divided by it.
            model_predictions_i: model predictions; only the ``X`` field (sigma normalized score of
                the relative coordinates) is read.
            z: gaussian noise sample used for the coordinates update.

        Returns:
            eps_i: step size of the corrector, a tensor of shape [1, 1, 1].
        """

        def batch_averaged_norm(tensor: torch.Tensor) -> torch.Tensor:
            # The norm over the last two dimensions sums the squared components over the atoms and
            # the spatial dimension before taking the square root; the mean then averages over the
            # mini-batch. The result is reshaped to [1, 1, 1].
            return torch.linalg.norm(tensor, dim=[-2, -1]).mean().view(1, 1, 1)

        score_norm = batch_averaged_norm(model_predictions_i.X)
        # The model output is \sigma * s(x, t): dividing (in place) by sigma recovers the norm of the score.
        score_norm /= sigma_i

        noise_norm = batch_averaged_norm(z)

        # Clip the denominator from below to avoid dividing by a vanishing score norm.
        safe_score_norm = score_norm.clip(min=self.small_epsilon)
        norm_ratio = self.corrector_r * noise_norm / safe_score_norm

        return 2 * norm_ratio**2
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from diffusion_for_multi_scale_molecular_dynamics.generators.adaptive_corrector import \
AdaptiveCorrectorGenerator
from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \
SamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \
Expand All @@ -22,7 +24,8 @@ def instantiate_generator(
"ode",
"sde",
"predictor_corrector",
], "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'"
"adaptive_corrector",
], "Unknown algorithm. Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptive_corrector'"

match sampling_parameters.algorithm:
case "predictor_corrector":
Expand All @@ -31,6 +34,12 @@ def instantiate_generator(
noise_parameters=noise_parameters,
axl_network=axl_network,
)
case "adaptive_corrector":
generator = AdaptiveCorrectorGenerator(
sampling_parameters=sampling_parameters,
noise_parameters=noise_parameters,
axl_network=axl_network,
)
case "ode":
generator = ExplodingVarianceODEAXLGenerator(
sampling_parameters=sampling_parameters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def _relative_coordinates_update(
sigma_i: torch.Tensor,
score_weight: torch.Tensor,
gaussian_noise_weight: torch.Tensor,
z: torch.Tensor,
) -> torch.Tensor:
r"""Generic update for the relative coordinates.
Expand All @@ -186,13 +187,16 @@ def _relative_coordinates_update(
eps_i in the corrector step. Dimension: [number_of_samples]
gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step
and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples]
z: gaussian noise used to update the coordinates. A sample drawn from the normal distribution.
Dimension: [number_of_samples, number_of_atoms, spatial_dimension].
Returns:
updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms,
spatial_dimension].
"""
number_of_samples = relative_coordinates.shape[0]
z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates)
if z is None:
z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates)
updated_coordinates = (
relative_coordinates
+ score_weight * sigma_normalized_scores / sigma_i
Expand All @@ -202,6 +206,50 @@ def _relative_coordinates_update(
updated_coordinates = map_relative_coordinates_to_unit_cell(updated_coordinates)
return updated_coordinates

def _relative_coordinates_update_predictor_step(
    self,
    relative_coordinates: torch.Tensor,
    sigma_normalized_scores: torch.Tensor,
    sigma_i: torch.Tensor,
    score_weight: torch.Tensor,
    gaussian_noise_weight: torch.Tensor,
    z: torch.Tensor,
) -> torch.Tensor:
    """Update the relative coordinates in the predictor step.

    This is a hook that subclasses can override; here it simply forwards every argument, in order,
    to the generic _relative_coordinates_update and returns its result.

    Args:
        relative_coordinates: relative coordinates to update.
        sigma_normalized_scores: sigma normalized scores used in the update.
        sigma_i: noise level at this step.
        score_weight: prefactor in front of the score term.
        gaussian_noise_weight: prefactor in front of the random noise term.
        z: gaussian noise sample used to update the coordinates.

    Returns:
        updated_coordinates: relative coordinates after the update.
    """
    updated_coordinates = self._relative_coordinates_update(
        relative_coordinates, sigma_normalized_scores, sigma_i, score_weight, gaussian_noise_weight, z
    )
    return updated_coordinates

def _relative_coordinates_update_corrector_step(
    self,
    relative_coordinates: torch.Tensor,
    sigma_normalized_scores: torch.Tensor,
    sigma_i: torch.Tensor,
    score_weight: torch.Tensor,
    gaussian_noise_weight: torch.Tensor,
    z: torch.Tensor,
) -> torch.Tensor:
    """Update the relative coordinates in the corrector step.

    This is a hook that subclasses can override; here it simply forwards every argument, in order,
    to the generic _relative_coordinates_update and returns its result.

    Args:
        relative_coordinates: relative coordinates to update.
        sigma_normalized_scores: sigma normalized scores used in the update.
        sigma_i: noise level at this step.
        score_weight: prefactor in front of the score term.
        gaussian_noise_weight: prefactor in front of the random noise term.
        z: gaussian noise sample used to update the coordinates.

    Returns:
        updated_coordinates: relative coordinates after the update.
    """
    updated_coordinates = self._relative_coordinates_update(
        relative_coordinates, sigma_normalized_scores, sigma_i, score_weight, gaussian_noise_weight, z
    )
    return updated_coordinates

def _atom_types_update(
self,
predicted_logits: torch.Tensor,
Expand Down Expand Up @@ -476,8 +524,11 @@ def predictor_step(
assert (a_im1 != self.masked_atom_type_index).all(), \
"There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input."

x_im1 = self._relative_coordinates_update(
composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i
# draw a gaussian noise sample and update the positions accordingly
z = self._draw_gaussian_sample(number_of_samples).to(composition_i.X)

x_im1 = self._relative_coordinates_update_predictor_step(
composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i, z
)

composition_im1 = AXL(
Expand Down Expand Up @@ -509,6 +560,18 @@ def predictor_step(

return composition_im1

def _get_corrector_step_size(
    self,
    index_i: int,
    sigma_i: torch.Tensor,
    model_predictions_i: AXL,
    z: torch.Tensor,
) -> torch.Tensor:
    """Compute the size of the corrector step for the relative coordinates update.

    The step size is looked up in the tabulated Langevin dynamics epsilon array. The arguments
    sigma_i and z are not used by this implementation.

    Args:
        index_i: index of the time step, used to index the epsilon array ([0,..., N-1]).
        sigma_i: noise level at this step (unused here).
        model_predictions_i: model predictions; its X field is only used as the target of the
            ``.to(...)`` conversion applied to the step size.
        z: gaussian noise sample (unused here).

    Returns:
        eps_i: the tabulated step size, converted with ``.to(model_predictions_i.X)``.
    """
    tabulated_epsilon = self.langevin_dynamics.epsilon
    return tabulated_epsilon[index_i].to(model_predictions_i.X)

def corrector_step(
self,
composition_i: AXL,
Expand All @@ -518,7 +581,8 @@ def corrector_step(
) -> AXL:
"""Corrector Step.
Note this is not affecting the atom types. Only the reduced coordinates and lattice vectors.
Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector
argument. Always affect the reduced coordinates and lattice vectors.
Args:
composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i.
Expand All @@ -533,9 +597,8 @@ def corrector_step(
"The corrector step can only be invoked for index_i between 0 and "
"the total number of discretization steps minus 1."
)
# The Langevin dynamics array are indexed with [0,..., N-1]
eps_i = self.langevin_dynamics.epsilon[index_i].to(composition_i.X)
sqrt_2eps_i = self.langevin_dynamics.sqrt_2_epsilon[index_i].to(composition_i.X)

number_of_samples = composition_i.X.shape[0]

if index_i == 0:
# TODO: we are extrapolating here; the score network will never have seen this time step...
Expand All @@ -553,8 +616,16 @@ def corrector_step(
composition_i, t_i, sigma_i, unit_cell, cartesian_forces
)

corrected_x_i = self._relative_coordinates_update(
composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i
# draw a gaussian noise sample and update the positions accordingly
z = self._draw_gaussian_sample(number_of_samples).to(composition_i.X)

# get the step size eps_i
eps_i = self._get_corrector_step_size(index_i, sigma_i, model_predictions_i, z)
# the size for the noise part is sqrt(2 * eps_i)
sqrt_2eps_i = torch.sqrt(2 * eps_i)

corrected_x_i = self._relative_coordinates_update_corrector_step(
composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z
)

if self.atom_type_transition_in_corrector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class SDESamplingParameters(SamplingParameters):
algorithm: str = "sde"
sde_type: str = "ito"
method: str = "euler"
adaptative: bool = False
adaptive: bool = False
absolute_solver_tolerance: float = (
1.0e-7 # the absolute error tolerance passed to the SDE solver.
)
Expand Down Expand Up @@ -325,7 +325,7 @@ def sample(
sde_times,
method=self.sampling_parameters.method,
dt=dt,
adaptive=self.sampling_parameters.adaptative,
adaptive=self.sampling_parameters.adaptive,
atol=self.sampling_parameters.absolute_solver_tolerance,
rtol=self.sampling_parameters.relative_solver_tolerance,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,8 @@ class NoiseParameters:

# Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution"
corrector_step_epsilon: float = 2e-5

# Step size scaling for the Adaptive Corrector Generator. Default value comes from github implementation
# https://github.com/yang-song/score_sde/blob/main/configs/default_celeba_configs.py
# for the celeba dataset. Note the suggested value for CIFAR10 is 0.16 in that repo.
corrector_r: float = 0.17
4 changes: 4 additions & 0 deletions tests/generators/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def _forward_unchecked(
class BaseTestGenerator:
"""A base class that contains common test fixtures useful for testing generators."""

@pytest.fixture(scope="class", autouse=True)
def set_random_seed(self):
    """Seed the torch random number generator with a fixed value.

    Declared as a class-scoped, autouse fixture.
    """
    fixed_seed = 34534534
    torch.manual_seed(fixed_seed)

@pytest.fixture()
def unit_cell_size(self):
return 10
Expand Down
Loading

0 comments on commit 2f6bbb3

Please sign in to comment.