Langevin adaptative corrector #109

Merged 10 commits on Dec 3, 2024
Changes from 6 commits
@@ -0,0 +1,257 @@
import torch

from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \
    LangevinGenerator
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
    PredictorCorrectorSamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \
    ScoreNetwork
from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
    NoiseParameters


class AdaptativeCorrectorGenerator(LangevinGenerator):
Collaborator:
"Adaptatif" is French. "Adaptive" is a better choice in English.

Collaborator (Author):
I changed it. Interesting grammatical fact, though: both "adaptative" and "adaptive" are in the Oxford dictionary and act as synonyms.

"""Langevin Dynamics Generator using only a corrector step with adaptative step size for relative coordinates.

This class implements the Langevin Corrector generation of position samples, following
Song et. al. 2021, namely:
"SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS"
"""

def __init__(
self,
noise_parameters: NoiseParameters,
sampling_parameters: PredictorCorrectorSamplingParameters,
axl_network: ScoreNetwork,
):
"""Init method."""
super().__init__(
noise_parameters=noise_parameters,
sampling_parameters=sampling_parameters,
axl_network=axl_network,
)
self.corrector_r = noise_parameters.corrector_r

    def predictor_step(
Collaborator:
If I'm reading this right, this predictor_step is identical to the predictor_step in LangevinGenerator, with the only difference being that x_im1 = x_i. I note that the two methods have diverged (I believe you copy-pasted an older version).

I propose that you create methods

  • LangevinGenerator._relative_coordinates_update_predictor_step
  • LangevinGenerator._relative_coordinates_update_corrector_step

In LangevinGenerator, both methods could just be dummies that call LangevinGenerator._relative_coordinates_update.

In the class AdaptativeCorrectorGenerator, you could overload _relative_coordinates_update_predictor_step to return its own input.

This way, we avoid a lot of duplicated code.

Collaborator (Author):
I agree. I changed the code accordingly.
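For illustration, here is a minimal, self-contained sketch of the proposed hooks. Only the hook names follow the suggestion above; the Sketch class names, the signatures and the exact update formula are assumptions, not the merged code.

import torch


class LangevinGeneratorSketch:
    """Stand-in for LangevinGenerator, reduced to the proposed hooks."""

    def _relative_coordinates_update(
        self, relative_coordinates, sigma_normalized_scores, sigma_i, score_weight, gaussian_noise_weight
    ):
        # Shared update: x <- x + score_weight * sigma_normalized_score / sigma + gaussian_noise_weight * z.
        z = torch.randn_like(relative_coordinates)
        return relative_coordinates + score_weight * sigma_normalized_scores / sigma_i + gaussian_noise_weight * z

    # Both hooks default to the shared update, so predictor and corrector keep a single code path.
    def _relative_coordinates_update_predictor_step(self, *args, **kwargs):
        return self._relative_coordinates_update(*args, **kwargs)

    def _relative_coordinates_update_corrector_step(self, *args, **kwargs):
        return self._relative_coordinates_update(*args, **kwargs)


class AdaptativeCorrectorGeneratorSketch(LangevinGeneratorSketch):
    # The adaptative corrector applies no predictor update to X: the hook returns its input unchanged.
    def _relative_coordinates_update_predictor_step(self, relative_coordinates, *args, **kwargs):
        return relative_coordinates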

        self,
        composition_i: AXL,
        index_i: int,
        unit_cell: torch.Tensor,  # TODO replace with AXL-L
        cartesian_forces: torch.Tensor,
    ) -> AXL:
        """Predictor step.

        Args:
            composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i.
            index_i : index of the time step.
            unit_cell: sampled unit cell at time step i.
            cartesian_forces: forces conditioning the sampling process

        Returns:
            composition_im1 : sampled composition, at time step i - 1.
        """
        assert (
            1 <= index_i <= self.number_of_discretization_steps
        ), "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps."

        idx = index_i - 1  # python starts indices at zero
        t_i = self.noise.time[idx].to(composition_i.X)
        sigma_i = self.noise.sigma[idx].to(composition_i.X)
        q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X)
        q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X)
        q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X)

        model_predictions_i = self._get_model_predictions(
            composition_i, t_i, sigma_i, unit_cell, cartesian_forces
        )

        # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition
        # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end
        # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where
        # the "first" time step is at index_i = T and the "last" time step is index_i = 1.
        this_is_last_time_step = idx == 0
        one_atom_type_transition_per_step = (
            self.one_atom_type_transition_per_step and not this_is_last_time_step
        )

        # atom types update
        a_im1 = self._atom_types_update(
            model_predictions_i.A,
            composition_i.A,
            q_matrices_i,
            q_bar_matrices_i,
            q_bar_tm1_matrices_i,
            atom_type_greedy_sampling=self.atom_type_greedy_sampling,
            one_atom_type_transition_per_step=one_atom_type_transition_per_step,
        )

        if this_is_last_time_step:
            assert (a_im1 != self.masked_atom_type_index).all(), \
                "There remain MASKED atoms at the last time step: review code, there must be a bug or invalid input."

        # In the adaptative corrector approach, there is no predictor step applied to the X component.
        composition_im1 = AXL(
            A=a_im1, X=composition_i.X, L=unit_cell
        )  # TODO : Deal with L correctly

        if self.record:
            # TODO : Deal with L correctly
            composition_i_for_recording = AXL(
                A=composition_i.A, X=composition_i.X, L=unit_cell
            )
            # Keep the record on the CPU
            entry = dict(time_step_index=index_i)
            list_keys = ["composition_i", "composition_im1", "model_predictions_i"]
            list_axl = [
                composition_i_for_recording,
                composition_im1,
                model_predictions_i,
            ]

            for key, axl in zip(list_keys, list_axl):
                record_axl = AXL(
                    A=axl.A.detach().cpu(),
                    X=axl.X.detach().cpu(),
                    L=axl.L.detach().cpu(),
                )
                entry[key] = record_axl
            self.sample_trajectory_recorder.record(key="predictor_step", entry=entry)

        return composition_im1

    def corrector_step(
Collaborator:
This method also seems very close to LangevinGenerator.corrector_step.

I propose that you create a method LangevinGenerator._get_corrector_step_epsilon. In LangevinGenerator, it could just return the correct entry from the tabulated values, and you could overload this method here to compute epsilon in a more sophisticated way.

We could also get rid of the array sqrt_2_epsilon in LangevinDynamics: surely it's not taking a sqrt that slows us down.

Again, this would avoid a lot of duplication.

Collaborator (Author):
Changed. I did not delete sqrt_2_epsilon yet; it requires checking a few tests to make sure it is not called anywhere. We could either do it right now or leave it as a TODO as part of a code cleanup effort.
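A rough sketch of what such a hook could look like follows; the tabulated-epsilon attribute and the keyword arguments are assumptions made for illustration, not the actual LangevinGenerator API.

import torch


class LangevinGeneratorEpsilonSketch:
    def _get_corrector_step_epsilon(self, index_i, **unused):
        # Base class: return the tabulated corrector step size for this time step
        # (the attribute name is assumed here).
        return self.tabulated_epsilon[index_i]


class AdaptativeCorrectorEpsilonSketch(LangevinGeneratorEpsilonSketch):
    corrector_r = 0.17
    small_epsilon = 1.0e-8

    def _get_corrector_step_epsilon(self, index_i, z=None, sigma_score_norm=None, **unused):
        # Adaptive variant: eps_i = 2 * (r * ||z|| / ||s||)^2, matching the corrector_step docstring below.
        z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1)
        return 2 * (self.corrector_r * z_norm / sigma_score_norm.clip(min=self.small_epsilon)) ** 2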

        self,
        composition_i: AXL,
        index_i: int,
        unit_cell: torch.Tensor,  # TODO replace with AXL-L
        cartesian_forces: torch.Tensor,
    ) -> AXL:
        r"""Corrector Step.

        Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector argument.
        It always affects the reduced coordinates and lattice vectors. The prefactors governing the changes in the X
        and L variables are determined using the sigma normalized score at that corrector step. The relative
        coordinates update is given by:

        .. math::

            x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt{2 \epsilon_i} z

        where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal distribution and
        :math:`\epsilon_i` is given by:

        .. math::

            \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2

        where :math:`r` is a hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm.

        Args:
            composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i.
            index_i : index of the time step.
            unit_cell: sampled unit cell at time step i.  # TODO replace with AXL-L
            cartesian_forces: forces conditioning the sampling

        Returns:
            corrected_composition_i : sampled composition, after corrector step.
        """
        assert 0 <= index_i <= self.number_of_discretization_steps - 1, (
            "The corrector step can only be invoked for index_i between 0 and "
            "the total number of discretization steps minus 1."
        )

        if index_i == 0:
            # TODO: we are extrapolating here; the score network will never have seen this time step...
            sigma_i = (
                self.noise_parameters.sigma_min
            )  # no need to change device, this is a float
            t_i = 0.0  # same for device - this is a float
            idx = index_i
        else:
            idx = index_i - 1  # python starts indices at zero
            sigma_i = self.noise.sigma[idx].to(composition_i.X)
            t_i = self.noise.time[idx].to(composition_i.X)

        model_predictions_i = self._get_model_predictions(
            composition_i, t_i, sigma_i, unit_cell, cartesian_forces
        )

        # to compute epsilon_i, we need the norm of the score. We average over the atoms.
        relative_coordinates_sigma_score_norm = (
Collaborator:
The norm in the paper is calculated over all pixels. I think you are taking a per-atom norm and averaging it within each sample, whereas the right way of doing it is to take a single norm over all the coordinates of a sample. The paper also talks about averaging over the batch dimension, while you are averaging over the atoms within a single sample. I don't think we should do the batch averaging just yet...

Collaborator (Author):
I wasn't sure what the best approach was for taking the norm, so I went with something reasonable. Let's stick to the paper's recommendations for now.
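For reference, a small self-contained illustration of the two conventions discussed above; the tensor shape and variable names are purely illustrative.

import torch

scores = torch.randn(4, 8, 3)  # [number_of_samples, number_of_atoms, spatial_dimension]

# What the code below does: per-atom L2 norm, then an average over the atoms of each sample.
atom_averaged_norm = torch.linalg.norm(scores, dim=-1).mean(dim=-1)  # shape: [4]

# The paper's convention: a single L2 norm over all components of each sample, then an average
# over the batch dimension, which yields one step size shared by the whole batch.
full_norm_per_sample = torch.linalg.norm(scores.flatten(start_dim=1), dim=-1)  # shape: [4]
batch_averaged_norm = full_norm_per_sample.mean()  # scalar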

            torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1)
        ).view(-1, 1, 1)
        # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size
        relative_coordinates_sigma_score_norm /= sigma_i
        # draw random noise
        z = self._draw_gaussian_sample(relative_coordinates_sigma_score_norm.shape[0]).to(composition_i.X)
        # and compute the norm
        z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1)

        eps_i = (
            2
            * (
                self.corrector_r
                * z_norm
                / (relative_coordinates_sigma_score_norm.clip(min=self.small_epsilon))
            )
            ** 2
        )
        sqrt_2eps_i = torch.sqrt(2 * eps_i)

        corrected_x_i = self._relative_coordinates_update(
            composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z
        )

        if self.atom_type_transition_in_corrector:
            q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X)
            q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X)
            q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X)
            # atom types update
            corrected_a_i = self._atom_types_update(
                model_predictions_i.A,
                composition_i.A,
                q_matrices_i,
                q_bar_matrices_i,
                q_bar_tm1_matrices_i,
                atom_type_greedy_sampling=self.atom_type_greedy_sampling,
                one_atom_type_transition_per_step=self.one_atom_type_transition_per_step,
            )
        else:
            corrected_a_i = composition_i.A

        corrected_composition_i = AXL(
            A=corrected_a_i,
            X=corrected_x_i,
            L=unit_cell,  # TODO replace with AXL-L
        )

        if self.record and self.record_corrector:
            # TODO : Deal with L correctly
            composition_i_for_recording = AXL(
                A=composition_i.A, X=composition_i.X, L=unit_cell
            )
            # Keep the record on the CPU
            entry = dict(time_step_index=index_i)
            list_keys = [
                "composition_i",
                "corrected_composition_i",
                "model_predictions_i",
            ]
            list_axl = [
                composition_i_for_recording,
                corrected_composition_i,
                model_predictions_i,
            ]

            for key, axl in zip(list_keys, list_axl):
                record_axl = AXL(
                    A=axl.A.detach().cpu(),
                    X=axl.X.detach().cpu(),
                    L=axl.L.detach().cpu(),
                )
                entry[key] = record_axl

            self.sample_trajectory_recorder.record(key="corrector_step", entry=entry)

        return corrected_composition_i
@@ -1,3 +1,5 @@
from diffusion_for_multi_scale_molecular_dynamics.generators.adaptative_corrector import \
    AdaptativeCorrectorGenerator
Collaborator:
"Adaptative" -> "Adaptive"

Collaborator (Author):
Changed.

from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \
    SamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \
@@ -22,7 +24,8 @@ def instantiate_generator(
"ode",
"sde",
"predictor_corrector",
], "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'"
"adaptative_corrector",
], "Unknown algorithm. Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptative_corrector'"

match sampling_parameters.algorithm:
case "predictor_corrector":
@@ -31,6 +34,12 @@ def instantiate_generator(
                noise_parameters=noise_parameters,
                axl_network=axl_network,
            )
        case "adaptative_corrector":
            generator = AdaptativeCorrectorGenerator(
                sampling_parameters=sampling_parameters,
                noise_parameters=noise_parameters,
                axl_network=axl_network,
            )
        case "ode":
            generator = ExplodingVarianceODEAXLGenerator(
                sampling_parameters=sampling_parameters,
@@ -1,5 +1,5 @@
import dataclasses
from typing import Tuple
from typing import Optional, Tuple

import einops
import torch
@@ -168,6 +168,7 @@ def _relative_coordinates_update(
        sigma_i: torch.Tensor,
        score_weight: torch.Tensor,
        gaussian_noise_weight: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Generic update for the relative coordinates.

@@ -186,13 +187,16 @@
                eps_i in the corrector step. Dimension: [number_of_samples]
            gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step
                and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples]
            z: gaussian noise used to update the coordinates. If None, a sample is drawn from the normal distribution.
                Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. Defaults to None.

        Returns:
            updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms,
                spatial_dimension].
        """
        number_of_samples = relative_coordinates.shape[0]
        z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates)
        if z is None:
            z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates)
        updated_coordinates = (
            relative_coordinates
            + score_weight * sigma_normalized_scores / sigma_i
@@ -518,7 +522,8 @@ def corrector_step(
    ) -> AXL:
        """Corrector Step.

        Note this is not affecting the atom types. Only the reduced coordinates and lattice vectors.
        Note this dones not affect the atom types unless specified with the atom_type_transition_in_corrector argument.
Collaborator:
typo: dones -> doesn't.

Collaborator (Author):
Fixed.

        Always affects the reduced coordinates and lattice vectors.

        Args:
            composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i.
@@ -21,3 +21,8 @@ class NoiseParameters:

    # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution"
    corrector_step_epsilon: float = 2e-5

    # Step size scaling for the Adaptative Corrector Generator. Default value comes from the github implementation
    # https://github.com/yang-song/score_sde/blob/main/configs/default_celeba_configs.py
    # for the celeba dataset. Note the suggested value for CIFAR10 in that repo is 0.16.
    corrector_r: float = 0.17
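For a rough sense of the step sizes this hyper-parameter implies, a small numeric illustration follows; the norm values below are made up.

# eps_i = 2 * (r * ||z|| / ||s||)^2, as documented in AdaptativeCorrectorGenerator.corrector_step.
r = 0.17
z_norm = 1.0       # norm of the Gaussian noise sample (illustrative value)
score_norm = 25.0  # norm of the sigma normalized score divided by sigma (illustrative value)
eps_i = 2 * (r * z_norm / score_norm) ** 2
print(eps_i)  # ~9.2e-5; the corrector step shrinks as the score norm grows relative to the noise norm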