Langevin adaptative corrector #109
Changes from 6 commits
@@ -0,0 +1,257 @@
import torch

from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \
    LangevinGenerator
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
    PredictorCorrectorSamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \
    ScoreNetwork
from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
    NoiseParameters

class AdaptativeCorrectorGenerator(LangevinGenerator):
    """Langevin dynamics generator using only a corrector step, with an adaptative step size for relative coordinates.

    This class implements the Langevin corrector generation of position samples, following
    Song et al. (2021), "Score-Based Generative Modeling through Stochastic Differential Equations".
    """

    def __init__(
        self,
        noise_parameters: NoiseParameters,
        sampling_parameters: PredictorCorrectorSamplingParameters,
        axl_network: ScoreNetwork,
    ):
        """Init method."""
        super().__init__(
            noise_parameters=noise_parameters,
            sampling_parameters=sampling_parameters,
            axl_network=axl_network,
        )
        self.corrector_r = noise_parameters.corrector_r

Review comment: If I'm reading this right, this […]. I propose that you create methods […]. In […]. In the class […]. This way, we avoid a lot of duplicated code.
Reply: I agree. I changed the code accordingly.

    def predictor_step(
        self,
        composition_i: AXL,
        index_i: int,
        unit_cell: torch.Tensor,  # TODO replace with AXL-L
        cartesian_forces: torch.Tensor,
    ) -> AXL:
"""Predictor step. | ||
|
||
Args: | ||
composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. | ||
index_i : index of the time step. | ||
unit_cell: sampled unit cell at time step i. | ||
cartesian_forces: forces conditioning the sampling process | ||
|
||
Returns: | ||
composition_im1 : sampled composition, at time step i - 1. | ||
""" | ||
assert ( | ||
1 <= index_i <= self.number_of_discretization_steps | ||
), "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps." | ||

        idx = index_i - 1  # python starts indices at zero
        t_i = self.noise.time[idx].to(composition_i.X)
        sigma_i = self.noise.sigma[idx].to(composition_i.X)
        q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X)
        q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X)
        q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X)

        model_predictions_i = self._get_model_predictions(
            composition_i, t_i, sigma_i, unit_cell, cartesian_forces
        )

        # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition
        # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end
        # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where
        # the "first" time step is at index_i = T and the "last" time step is at index_i = 1.
        this_is_last_time_step = idx == 0
        one_atom_type_transition_per_step = (
            self.one_atom_type_transition_per_step and not this_is_last_time_step
        )

        # atom types update
        a_im1 = self._atom_types_update(
            model_predictions_i.A,
            composition_i.A,
            q_matrices_i,
            q_bar_matrices_i,
            q_bar_tm1_matrices_i,
            atom_type_greedy_sampling=self.atom_type_greedy_sampling,
            one_atom_type_transition_per_step=one_atom_type_transition_per_step,
        )

        if this_is_last_time_step:
            assert (a_im1 != self.masked_atom_type_index).all(), \
                "There remain MASKED atoms at the last time step: review the code, there must be a bug or invalid input."

        # in the adaptative corrector approach, no predictor step is applied to the X component
        composition_im1 = AXL(
            A=a_im1, X=composition_i.X, L=unit_cell
        )  # TODO : Deal with L correctly

        if self.record:
            # TODO : Deal with L correctly
            composition_i_for_recording = AXL(
                A=composition_i.A, X=composition_i.X, L=unit_cell
            )
            # Keep the record on the CPU
            entry = dict(time_step_index=index_i)
            list_keys = ["composition_i", "composition_im1", "model_predictions_i"]
            list_axl = [
                composition_i_for_recording,
                composition_im1,
                model_predictions_i,
            ]

            for key, axl in zip(list_keys, list_axl):
                record_axl = AXL(
                    A=axl.A.detach().cpu(),
                    X=axl.X.detach().cpu(),
                    L=axl.L.detach().cpu(),
                )
                entry[key] = record_axl
            self.sample_trajectory_recorder.record(key="predictor_step", entry=entry)

        return composition_im1
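
As an aside, the refactor agreed on in the review thread above amounts to a template-method pattern: the base class keeps the corrector skeleton, and subclasses override only how the step size is chosen. Below is a minimal runnable sketch of that idea; every name in it is invented for illustration, and none of it is the repository's actual API.

import torch

class ToyBaseGenerator:
    """Shared corrector skeleton; subclasses only choose the step size."""

    def score(self, x: torch.Tensor) -> torch.Tensor:
        return -x  # stand-in score pointing toward the origin

    def corrector_step(self, x: torch.Tensor) -> torch.Tensor:
        score = self.score(x)
        z = torch.randn_like(x)
        eps = self._step_size(score, z)  # the only subclass-specific piece
        return x + eps * score + torch.sqrt(2 * eps) * z

    def _step_size(self, score: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return torch.tensor(1e-3)  # fixed step size in the base version

class ToyAdaptiveGenerator(ToyBaseGenerator):
    def _step_size(self, score: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # epsilon = 2 * (r * ||z|| / ||score||)^2, mirroring the rule documented below
        r = 0.17
        z_norm = torch.linalg.norm(z)
        score_norm = torch.linalg.norm(score).clip(min=1e-8)
        return 2 * (r * z_norm / score_norm) ** 2

x_new = ToyAdaptiveGenerator().corrector_step(torch.rand(4, 3))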

Review comment: This method also seems like it is very close to […]. I propose that you create a method […]. We could get rid of the array sqrt_2_epsilon. Again, this would avoid a lot of duplication.
Reply: Changed. I did not delete sqrt_2_epsilon yet; it requires checking a few tests to make sure it is not called anywhere. We could either do it right now or leave it as a TODO as part of a code cleanup effort.

    def corrector_step(
        self,
        composition_i: AXL,
        index_i: int,
        unit_cell: torch.Tensor,  # TODO replace with AXL-L
        cartesian_forces: torch.Tensor,
    ) -> AXL:
r"""Corrector Step. | ||
|
||
Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector argument. | ||
Always affect the reduced coordinates and lattice vectors. The prefactors determining the changes in the X and L | ||
variables are determined using the sigma normalized score at that corrector step. The relative coordinates | ||
update is given by: | ||
|
||
.. math:: | ||
|
||
x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt(2 \epsilon_i) z | ||
|
||
where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal distribution and | ||
:math:`\epsilon_i` is given by: | ||
|
||
.. math:: | ||
|
||
\epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 | ||
|
||
where :math:`r` is an hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm. | ||
|
||
Args: | ||
composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. | ||
index_i : index of the time step. | ||
unit_cell: sampled unit cell at time step i. # TODO replace with AXL-L | ||
cartesian_forces: forces conditioning the sampling | ||
|
||
Returns: | ||
corrected_composition_i : sampled composition, after corrector step. | ||
""" | ||
assert 0 <= index_i <= self.number_of_discretization_steps - 1, ( | ||
"The corrector step can only be invoked for index_i between 0 and " | ||
"the total number of discretization steps minus 1." | ||
) | ||

        if index_i == 0:
            # TODO: we are extrapolating here; the score network will never have seen this time step...
            sigma_i = (
                self.noise_parameters.sigma_min
            )  # no need to change device, this is a float
            t_i = 0.0  # same for device - this is a float
            idx = index_i
        else:
            idx = index_i - 1  # python starts indices at zero
            sigma_i = self.noise.sigma[idx].to(composition_i.X)
            t_i = self.noise.time[idx].to(composition_i.X)

        model_predictions_i = self._get_model_predictions(
            composition_i, t_i, sigma_i, unit_cell, cartesian_forces
        )

        # to compute epsilon_i, we need the norm of the score. We average over the atoms.
        relative_coordinates_sigma_score_norm = (
            torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1)
        ).view(-1, 1, 1)

Review comment: […]
Reply: I wasn't sure what was the best approach to take the norm, so I went with something reasonable. Let's stick to the paper recommendations for now.

        # note that sigma_score is \sigma * s(x, t), so we need to divide the norm
        # by sigma to get the correct step size
        relative_coordinates_sigma_score_norm /= sigma_i
        # draw random noise
        z = self._draw_gaussian_sample(
            relative_coordinates_sigma_score_norm.shape[0]
        ).to(composition_i.X)
        # and compute the norm
        z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1)

        eps_i = (
            2
            * (
                self.corrector_r
                * z_norm
                / (relative_coordinates_sigma_score_norm.clip(min=self.small_epsilon))
            )
            ** 2
        )
        sqrt_2eps_i = torch.sqrt(2 * eps_i)

        corrected_x_i = self._relative_coordinates_update(
            composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z
        )

        if self.atom_type_transition_in_corrector:
            q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X)
            q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X)
            q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X)
            # atom types update
            corrected_a_i = self._atom_types_update(
                model_predictions_i.A,
                composition_i.A,
                q_matrices_i,
                q_bar_matrices_i,
                q_bar_tm1_matrices_i,
                atom_type_greedy_sampling=self.atom_type_greedy_sampling,
                one_atom_type_transition_per_step=self.one_atom_type_transition_per_step,
            )
        else:
            corrected_a_i = composition_i.A

        corrected_composition_i = AXL(
            A=corrected_a_i,
            X=corrected_x_i,
            L=unit_cell,  # TODO replace with AXL-L
        )

        if self.record and self.record_corrector:
            # TODO : Deal with L correctly
            composition_i_for_recording = AXL(
                A=composition_i.A, X=composition_i.X, L=unit_cell
            )
            # Keep the record on the CPU
            entry = dict(time_step_index=index_i)
            list_keys = [
                "composition_i",
                "corrected_composition_i",
                "model_predictions_i",
            ]
            list_axl = [
                composition_i_for_recording,
                corrected_composition_i,
                model_predictions_i,
            ]

            for key, axl in zip(list_keys, list_axl):
                record_axl = AXL(
                    A=axl.A.detach().cpu(),
                    X=axl.X.detach().cpu(),
                    L=axl.L.detach().cpu(),
                )
                entry[key] = record_axl

            self.sample_trajectory_recorder.record(key="corrector_step", entry=entry)

        return corrected_composition_i
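
Pulling the pieces of corrector_step together, here is a standalone sketch of the adaptive update on toy tensors. The per-atom norm averaging and the epsilon formula follow the diff above, but the function name and the plain-tensor interface are assumptions made for illustration, not the class's API.

import torch

def adaptive_corrector_update(x, sigma_score, sigma, r=0.17, small_epsilon=1e-8):
    """One adaptive Langevin corrector step on relative coordinates.

    x:           [n_samples, n_atoms, dim] relative coordinates
    sigma_score: [n_samples, n_atoms, dim] model output, i.e. sigma * s(x, t)
    sigma:       noise scale at this time step
    """
    # norm of the score, averaged over atoms; sigma_score is sigma * s(x, t),
    # so divide by sigma to recover the norm of the score itself
    score_norm = torch.linalg.norm(sigma_score, dim=-1).mean(dim=-1).view(-1, 1, 1)
    score_norm = score_norm / sigma

    # draw the noise first: the same z sets the step size and perturbs x
    z = torch.randn_like(x)
    z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1)

    # epsilon_i = 2 * (r * ||z|| / ||s||)^2, guarded against division by zero
    eps = 2.0 * (r * z_norm / score_norm.clip(min=small_epsilon)) ** 2

    # x <- x + eps * s + sqrt(2 * eps) * z, with s = sigma_score / sigma
    return x + eps * sigma_score / sigma + torch.sqrt(2.0 * eps) * z

x_corrected = adaptive_corrector_update(
    x=torch.rand(2, 4, 3), sigma_score=0.1 * torch.randn(2, 4, 3), sigma=0.5
)

With the default r = 0.17 and comparable noise and score norms, eps is about 2 * 0.17^2 ≈ 0.058; where the predicted score is large relative to the noise, eps shrinks and the corrector takes proportionally smaller steps.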

@@ -1,3 +1,5 @@
+from diffusion_for_multi_scale_molecular_dynamics.generators.adaptative_corrector import \
+    AdaptativeCorrectorGenerator

Review comment: "Adaptative" -> "Adaptive"
Reply: changed

 from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \
     SamplingParameters
 from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \

@@ -22,7 +24,8 @@ def instantiate_generator(
         "ode",
         "sde",
         "predictor_corrector",
-    ], "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'"
+        "adaptative_corrector",
+    ], "Unknown algorithm. Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptative_corrector'"
 
     match sampling_parameters.algorithm:
         case "predictor_corrector":

@@ -31,6 +34,12 @@ def instantiate_generator(
                 noise_parameters=noise_parameters,
                 axl_network=axl_network,
             )
+        case "adaptative_corrector":
+            generator = AdaptativeCorrectorGenerator(
+                sampling_parameters=sampling_parameters,
+                noise_parameters=noise_parameters,
+                axl_network=axl_network,
+            )
         case "ode":
             generator = ExplodingVarianceODEAXLGenerator(
                 sampling_parameters=sampling_parameters,
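
The factory change above follows a guard-then-dispatch shape: assert that the algorithm name is known, then match on it to build the generator. Here is a runnable miniature of the same pattern, using stand-in classes rather than the repository's real generators.

from dataclasses import dataclass

@dataclass
class ToySamplingParameters:
    algorithm: str

class ToyLangevinGenerator:
    pass

class ToyAdaptativeCorrectorGenerator(ToyLangevinGenerator):
    pass

KNOWN_ALGORITHMS = ["predictor_corrector", "adaptative_corrector"]

def toy_instantiate_generator(sampling_parameters: ToySamplingParameters):
    assert (
        sampling_parameters.algorithm in KNOWN_ALGORITHMS
    ), f"Unknown algorithm. Possible choices are {KNOWN_ALGORITHMS}"
    match sampling_parameters.algorithm:
        case "predictor_corrector":
            return ToyLangevinGenerator()
        case "adaptative_corrector":
            return ToyAdaptativeCorrectorGenerator()

generator = toy_instantiate_generator(ToySamplingParameters(algorithm="adaptative_corrector"))
assert isinstance(generator, ToyAdaptativeCorrectorGenerator)

Because AdaptativeCorrectorGenerator subclasses LangevinGenerator and takes the same constructor arguments, the new case is a one-for-one copy of the predictor_corrector branch with a different class, which keeps the factory trivial to extend.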

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Tuple
+from typing import Optional, Tuple
 
 import einops
 import torch

@@ -168,6 +168,7 @@ def _relative_coordinates_update(
         sigma_i: torch.Tensor,
         score_weight: torch.Tensor,
         gaussian_noise_weight: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""Generic update for the relative coordinates.

@@ -186,13 +187,16 @@ def _relative_coordinates_update(
                 eps_i in the corrector step. Dimension: [number_of_samples]
             gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step
                 and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples]
+            z: gaussian noise used to update the coordinates. If None, a sample is drawn from the normal distribution.
+                Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. Defaults to None.
 
         Returns:
             updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms,
                 spatial_dimension].
         """
         number_of_samples = relative_coordinates.shape[0]
-        z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates)
+        if z is None:
+            z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates)
         updated_coordinates = (
             relative_coordinates
             + score_weight * sigma_normalized_scores / sigma_i
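
The new optional z argument exists so that the adaptive corrector can reuse a single noise draw both to size the step (through eps_i) and to perturb the coordinates. Below is a self-contained sketch of the generic update with that plumbing; the shapes and the standalone-function form are assumptions, since only fragments of the method appear in this hunk.

from typing import Optional

import torch

def relative_coordinates_update(
    relative_coordinates: torch.Tensor,     # [n_samples, n_atoms, dim]
    sigma_normalized_scores: torch.Tensor,  # [n_samples, n_atoms, dim]
    sigma_i: float,
    score_weight: torch.Tensor,             # eps_i (corrector) or the predictor prefactor
    gaussian_noise_weight: torch.Tensor,    # sqrt(2 * eps_i) (corrector) or g_i (predictor)
    z: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # draw fresh noise only when the caller did not supply any; the adaptive
    # corrector passes in the z it already used to compute eps_i
    if z is None:
        z = torch.randn_like(relative_coordinates)
    return (
        relative_coordinates
        + score_weight * sigma_normalized_scores / sigma_i
        + gaussian_noise_weight * z
    )

# reusing a pre-drawn z makes the update deterministic given that z
x, scores = torch.rand(2, 4, 3), torch.randn(2, 4, 3)
z = torch.randn_like(x)
out1 = relative_coordinates_update(x, scores, 0.5, torch.tensor(0.05), torch.tensor(0.3), z=z)
out2 = relative_coordinates_update(x, scores, 0.5, torch.tensor(0.05), torch.tensor(0.3), z=z)
assert torch.allclose(out1, out2)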

@@ -518,7 +522,8 @@ def corrector_step(
     ) -> AXL:
         """Corrector Step.
 
-        Note this is not affecting the atom types. Only the reduced coordinates and lattice vectors.
+        Note this dones not affect the atom types unless specified with the atom_type_transition_in_corrector argument.
+        Always affect the reduced coordinates and lattice vectors.
 
         Args:
             composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i.

Review comment: typo: "dones"
Reply: fixed

Review comment: "Adaptatif" is French. "Adaptive" is a better choice in English.
Reply: I changed it. But interesting grammatical fact: both "adaptative" and "adaptive" are in the Oxford dictionary and act as synonyms.