From 495e0bdb6ed62a534839b804c64b570f04465060 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Fri, 29 Nov 2024 11:09:56 -0500 Subject: [PATCH 01/10] adaptative corrector generator --- .../generators/adaptative_corrector.py | 238 ++++++++++++++++++ .../generators/instantiate_generator.py | 11 +- .../generators/langevin_generator.py | 11 +- 3 files changed, 256 insertions(+), 4 deletions(-) create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py new file mode 100644 index 00000000..23a535ee --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py @@ -0,0 +1,238 @@ +import torch + +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ + PredictorCorrectorSamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ + ScoreNetwork +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters + + +class AdaptativeCorrectorGenerator(LangevinGenerator): + """Langevin Dynamics Generator using only a corrector step with adaptative step size for relative coordinates. + + This class implements the Langevin Corrector generation of position samples, following + Song et. al. 2021, namely: + "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" + """ + + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + axl_network: ScoreNetwork, + ): + """Init method.""" + super().__init__( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + axl_network=axl_network, + ) + self.corrector_r = noise_parameters.corrector_r + + def predictor_step( + self, + composition_i: AXL, + index_i: int, + unit_cell: torch.Tensor, # TODO replace with AXL-L + cartesian_forces: torch.Tensor, + ) -> AXL: + """Predictor step. + + Args: + composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. + index_i : index of the time step. + unit_cell: sampled unit cell at time step i. + cartesian_forces: forces conditioning the sampling process + + Returns: + composition_im1 : sampled composition, at time step i - 1. + """ + assert ( + 1 <= index_i <= self.number_of_discretization_steps + ), "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps." 
+ + idx = index_i - 1 # python starts indices at zero + t_i = self.noise.time[idx].to(composition_i.X) + sigma_i = self.noise.sigma[idx].to(composition_i.X) + q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X) + q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) + q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) + + model_predictions_i = self._get_model_predictions( + composition_i, t_i, sigma_i, unit_cell, cartesian_forces + ) + + # atom types update + a_im1 = self.atom_types_update( + model_predictions_i.A, + composition_i.A, + q_matrices_i, + q_bar_matrices_i, + q_bar_tm1_matrices_i, + ) + + # in this approach, there is no predictor step applied on the X component + composition_im1 = AXL( + A=a_im1, X=composition_i.X, L=unit_cell + ) # TODO : Deal with L correctly + + if self.record: + # TODO : Deal with L correctly + composition_i_for_recording = AXL( + A=composition_i.A, X=composition_i.X, L=unit_cell + ) + # Keep the record on the CPU + entry = dict(time_step_index=index_i) + list_keys = ["composition_i", "composition_im1", "model_predictions_i"] + list_axl = [ + composition_i_for_recording, + composition_im1, + model_predictions_i, + ] + + for key, axl in zip(list_keys, list_axl): + record_axl = AXL( + A=axl.A.detach().cpu(), + X=axl.X.detach().cpu(), + L=axl.L.detach().cpu(), + ) + entry[key] = record_axl + self.sample_trajectory_recorder.record(key="predictor_step", entry=entry) + + return composition_im1 + + def corrector_step( + self, + composition_i: AXL, + index_i: int, + unit_cell: torch.Tensor, # TODO replace with AXL-L + cartesian_forces: torch.Tensor, + ) -> AXL: + r"""Corrector Step. + + Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector argument. + Always affect the reduced coordinates and lattice vectors. The prefactors determining the changes in the X and L + variables are determined using the sigma normalized score at that corrector step. The relative coordinates + update is given by: + + .. math:: + + x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt(2 \epsilon_i) z + + where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal distribution and + :math:`\epsilon_i` is given by: + + .. math:: + + \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 + + where :math:`r` is an hyper-parameter (0.15 by default) and :math:`||\cdot||_2` is the L2 norm. + + Args: + composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. + index_i : index of the time step. + unit_cell: sampled unit cell at time step i. # TODO replace with AXL-L + cartesian_forces: forces conditioning the sampling + + Returns: + corrected_composition_i : sampled composition, after corrector step. + """ + assert 0 <= index_i <= self.number_of_discretization_steps - 1, ( + "The corrector step can only be invoked for index_i between 0 and " + "the total number of discretization steps minus 1." + ) + + if index_i == 0: + # TODO: we are extrapolating here; the score network will never have seen this time step... 
+ sigma_i = ( + self.noise_parameters.sigma_min + ) # no need to change device, this is a float + t_i = 0.0 # same for device - this is a float + idx = index_i + else: + idx = index_i - 1 # python starts indices at zero + sigma_i = self.noise.sigma[idx].to(composition_i.X) + t_i = self.noise.time[idx].to(composition_i.X) + + model_predictions_i = self._get_model_predictions( + composition_i, t_i, sigma_i, unit_cell, cartesian_forces + ) + + # to compute epsilon_i, we need the norm of the score. We average over the atoms. + relative_coordinates_sigma_score_norm = ( + torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1) + ).view(-1, 1, 1) + # draw random noise + z = self._draw_gaussian_sample(relative_coordinates_sigma_score_norm.shape[0]) + # and compute the norm + z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1) + # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size + eps_i = ( + 2 + * ( + self.corrector_r + * z_norm + / (relative_coordinates_sigma_score_norm.clip(min=self.small_epsilon)) + ) + ** 2 + ) + sqrt_2eps_i = torch.sqrt(2 * eps_i) + + corrected_x_i = self.relative_coordinates_update( + composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z + ) + + if self.atom_type_transition_in_corrector: + q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X) + q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) + q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) + # atom types update + corrected_a_i = self.atom_types_update( + model_predictions_i.A, + composition_i.A, + q_matrices_i, + q_bar_matrices_i, + q_bar_tm1_matrices_i, + ) + else: + corrected_a_i = composition_i.A + + corrected_composition_i = AXL( + A=corrected_a_i, + X=corrected_x_i, + L=unit_cell, # TODO replace with AXL-L + ) + + if self.record and self.record_corrector: + # TODO : Deal with L correctly + composition_i_for_recording = AXL( + A=composition_i.A, X=composition_i.X, L=unit_cell + ) + # Keep the record on the CPU + entry = dict(time_step_index=index_i) + list_keys = [ + "composition_i", + "corrected_composition_i", + "model_predictions_i", + ] + list_axl = [ + composition_i_for_recording, + corrected_composition_i, + model_predictions_i, + ] + + for key, axl in zip(list_keys, list_axl): + record_axl = AXL( + A=axl.A.detach().cpu(), + X=axl.X.detach().cpu(), + L=axl.L.detach().cpu(), + ) + entry[key] = record_axl + + self.sample_trajectory_recorder.record(key="corrector_step", entry=entry) + + return corrected_composition_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py index af897328..433eb656 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py @@ -1,3 +1,5 @@ +from diffusion_for_multi_scale_molecular_dynamics.generators.adaptative_corrector import \ + AdaptativeCorrectorGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ @@ -22,7 +24,8 @@ def instantiate_generator( "ode", "sde", "predictor_corrector", - ], "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'" + "adaptative_corrector", + ], "Unknown algorithm. 
Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptative_corrector'" match sampling_parameters.algorithm: case "predictor_corrector": @@ -31,6 +34,12 @@ def instantiate_generator( noise_parameters=noise_parameters, axl_network=axl_network, ) + case "adaptative_corrector": + generator = AdaptativeCorrectorGenerator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + axl_network=axl_network, + ) case "ode": generator = ExplodingVarianceODEAXLGenerator( sampling_parameters=sampling_parameters, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py index 2d68dc77..9ceb2c71 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Tuple +from typing import Optional, Tuple import einops import torch @@ -168,6 +168,7 @@ def _relative_coordinates_update( sigma_i: torch.Tensor, score_weight: torch.Tensor, gaussian_noise_weight: torch.Tensor, + z: Optional[torch.Tensor] = None, ) -> torch.Tensor: r"""Generic update for the relative coordinates. @@ -186,13 +187,16 @@ def _relative_coordinates_update( eps_i in the corrector step. Dimension: [number_of_samples] gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples] + z: gaussian noise used to update the coordinates. If None, a sample is drawn from the normal distribution. + Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. Defaults to None. Returns: updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. """ number_of_samples = relative_coordinates.shape[0] - z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates) + if z is None: + z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates) updated_coordinates = ( relative_coordinates + score_weight * sigma_normalized_scores / sigma_i @@ -518,7 +522,8 @@ def corrector_step( ) -> AXL: """Corrector Step. - Note this is not affecting the atom types. Only the reduced coordinates and lattice vectors. + Note this dones not affect the atom types unless specified with the atom_type_transition_in_corrector argument. + Always affect the reduced coordinates and lattice vectors. Args: composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. 
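The heart of [PATCH 01/10] is the adaptive step-size rule quoted in the corrector_step docstring above. What follows is a minimal, self-contained torch sketch of that rule, not the class itself: the function name is hypothetical, `score` stands for the true score s(x, t) (the sigma-normalized network output divided by sigma), and the 1.0e-8 floor stands in for the generator's small_epsilon.

import torch

def adaptive_corrector_update(x: torch.Tensor, score: torch.Tensor, r: float = 0.15) -> torch.Tensor:
    # One corrector step: x <- x + eps * s(x, t) + sqrt(2 * eps) * z,
    # with eps = 2 * (r * ||z||_2 / ||s(x, t)||_2)^2.
    # x, score: [number_of_samples, number_of_atoms, spatial_dimension]
    z = torch.randn_like(x)
    # per-sample norms over the spatial dimension, averaged over atoms
    z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1)
    score_norm = torch.linalg.norm(score, dim=-1).mean(dim=-1).view(-1, 1, 1)
    eps = 2.0 * (r * z_norm / score_norm.clamp_min(1.0e-8)) ** 2
    x_new = x + eps * score + torch.sqrt(2.0 * eps) * z
    # the generator maps the result back into the unit cell; a plain modulo does the same here
    return torch.remainder(x_new, 1.0)

Because eps shrinks wherever the predicted score is large, the corrector takes small steps in steep regions of the log-density and larger steps elsewhere, which is what makes the step size adaptive. Note that r defaults to 0.15 in this first patch's docstring; later patches in the series settle on 0.17.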
From e9af6c13312304998da4aeb36dc449685a86caad Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Fri, 29 Nov 2024 13:18:39 -0500 Subject: [PATCH 02/10] adding unit tests --- tests/generators/test_adaptative_corrector.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 tests/generators/test_adaptative_corrector.py diff --git a/tests/generators/test_adaptative_corrector.py b/tests/generators/test_adaptative_corrector.py new file mode 100644 index 00000000..0dac2c0d --- /dev/null +++ b/tests/generators/test_adaptative_corrector.py @@ -0,0 +1,124 @@ +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.generators.adaptative_corrector import \ + AdaptativeCorrectorGenerator +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from tests.generators.test_langevin_generator import TestLangevinGenerator + + +class TestAdaptativeCorrectorGenerator(TestLangevinGenerator): + + @pytest.fixture() + def noise_parameters(self, total_time_steps): + noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + time_delta=0.1, + sigma_min=0.15, + corrector_r=0.15, + ) + return noise_parameters + + @pytest.fixture() + def pc_generator(self, noise_parameters, sampling_parameters, axl_network): + # override the base class + generator = AdaptativeCorrectorGenerator( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + axl_network=axl_network, + ) + + return generator + + def test_predictor_step_relative_coordinates( + self, + mocker, + pc_generator, + noise_parameters, + axl_i, + total_time_steps, + number_of_samples, + unit_cell_sample, + num_atomic_classes, + device, + ): + # override the base class + forces = torch.zeros_like(axl_i.X) + + for index_i in range(1, total_time_steps + 1): + computed_sample = pc_generator.predictor_step( + axl_i, index_i, unit_cell_sample, forces + ) + + expected_coordinates = axl_i.X + expected_coordinates = map_relative_coordinates_to_unit_cell( + expected_coordinates + ) + # this is almost trivial - the coordinates should not change in a predictor step + torch.testing.assert_close(computed_sample.X, expected_coordinates) + + @pytest.mark.parametrize("corrector_r", [0.1, 0.5, 1.2]) + def test_corrector_step( + self, + mocker, + corrector_r, + pc_generator, + noise_parameters, + axl_i, + total_time_steps, + number_of_samples, + unit_cell_sample, + num_atomic_classes, + ): + pc_generator.corrector_r = corrector_r + sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes) + noise, _ = sampler.get_all_sampling_parameters() + sigma_min = noise_parameters.sigma_min + list_sigma = noise.sigma + list_time = noise.time + forces = torch.zeros_like(axl_i.X) + + z = pc_generator._draw_gaussian_sample(number_of_samples).to(axl_i.X) + mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) + z_norm = torch.sqrt((z**2).sum(dim=-1)).mean( + dim=-1 + ) # norm of z averaged over atoms + + for index_i in range(0, total_time_steps): + computed_sample = pc_generator.corrector_step( + axl_i, index_i, unit_cell_sample, forces + ) + + if index_i == 0: + sigma_i = sigma_min + t_i = 0.0 + else: + sigma_i = list_sigma[index_i - 1] + t_i = list_time[index_i - 1] + + s_i = ( 
+ pc_generator._get_model_predictions( + axl_i, t_i, sigma_i, unit_cell_sample, forces + ).X + / sigma_i + ) + s_i_norm = torch.sqrt((s_i**2).sum(dim=-1)).mean(dim=-1) + # \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 + eps_i = ( + 2 + * (corrector_r * z_norm / s_i_norm.clip(min=pc_generator.small_epsilon)) + ** 2 + ) + eps_i = eps_i.view(-1, 1, 1) + + expected_coordinates = axl_i.X + eps_i * s_i + torch.sqrt(2.0 * eps_i) * z + expected_coordinates = map_relative_coordinates_to_unit_cell( + expected_coordinates + ) + + torch.testing.assert_close(computed_sample.X, expected_coordinates) From 7e74b7fc2be24b81ef822cfe2b2fbbabc081c402 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Fri, 29 Nov 2024 13:19:43 -0500 Subject: [PATCH 03/10] fixes --- .../generators/adaptative_corrector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py index 23a535ee..ccd1f6d8 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py @@ -166,11 +166,13 @@ def corrector_step( relative_coordinates_sigma_score_norm = ( torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1) ).view(-1, 1, 1) + # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size + relative_coordinates_sigma_score_norm /= sigma_i # draw random noise - z = self._draw_gaussian_sample(relative_coordinates_sigma_score_norm.shape[0]) + z = self._draw_gaussian_sample(relative_coordinates_sigma_score_norm.shape[0]).to(composition_i.X) # and compute the norm z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1) - # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size + eps_i = ( 2 * ( From 81a220096368f4a8e09065e096c8d8108517f91f Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Fri, 29 Nov 2024 13:23:29 -0500 Subject: [PATCH 04/10] typo in docstring --- .../generators/adaptative_corrector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py index ccd1f6d8..14414082 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py @@ -130,7 +130,7 @@ def corrector_step( \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 - where :math:`r` is an hyper-parameter (0.15 by default) and :math:`||\cdot||_2` is the L2 norm. + where :math:`r` is an hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm. Args: composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. 
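The one-line fix in [PATCH 03/10] above is easy to miss but important: the AXL network outputs the sigma-normalized score, sigma * s(x, t), so its norm must be divided by sigma before it enters the step-size formula. A small illustration, with made-up shapes and values:

import torch

number_of_samples, number_of_atoms, spatial_dimension = 4, 8, 3
sigma_i = torch.tensor(0.5)

# the network predicts the *sigma-normalized* score, i.e. sigma * s(x, t)
sigma_normalized_score = torch.randn(number_of_samples, number_of_atoms, spatial_dimension)

# norm over the spatial dimension, averaged over atoms: one value per sample
sigma_score_norm = (
    torch.linalg.norm(sigma_normalized_score, dim=-1).mean(dim=-1).view(-1, 1, 1)
)

# dividing by sigma recovers ||s(x, t)||, the quantity epsilon_i must be built from
score_norm = sigma_score_norm / sigma_i

Skipping the division leaves epsilon_i too large by a factor of 1 / sigma^2 whenever sigma < 1, and that factor blows up at the small noise levels reached near the end of sampling.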
From accde32f395dc39354b6f4eb7c2ad05551373e59 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Mon, 2 Dec 2024 07:52:11 -0500 Subject: [PATCH 05/10] dynamic_corrector parameter r --- .../noise_schedulers/noise_parameters.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py index ae34bb85..de105c85 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py @@ -21,3 +21,8 @@ class NoiseParameters: # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution" corrector_step_epsilon: float = 2e-5 + + # Step size scaling for the Adaptative Corrector Generator. Default value comes from github implementation + # https://github.com/yang-song/score_sde/blob/main/configs/default_celeba_configs.py + # for the celeba dataset. Note the suggested value for CIFAR10 is 0.16 in that repo. + corrector_r: float = 0.17 From 4e308707130042cb2fb9e0ce38977f409daefae7 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Mon, 2 Dec 2024 09:10:48 -0500 Subject: [PATCH 06/10] fixes to reflect updates in langevin generator --- .../generators/adaptative_corrector.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py index 14414082..2d46dcc3 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py @@ -66,16 +66,31 @@ def predictor_step( composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) + + # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition + # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end + # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where + # the "first" time step is at index_i = T and the "last" time step is index_i = 1. + this_is_last_time_step = idx == 0 + one_atom_type_transition_per_step = ( + self.one_atom_type_transition_per_step and not this_is_last_time_step + ) + # atom types update - a_im1 = self.atom_types_update( + a_im1 = self._atom_types_update( model_predictions_i.A, composition_i.A, q_matrices_i, q_bar_matrices_i, q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, ) - # in this approach, there is no predictor step applied on the X component + if this_is_last_time_step: + assert (a_im1 != self.masked_atom_type_index).all(), \ + "There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input."
+ + # in the adaptative corrector approach, there is no predictor step applied on the X component composition_im1 = AXL( A=a_im1, X=composition_i.X, L=unit_cell ) # TODO : Deal with L correctly @@ -184,7 +199,7 @@ def corrector_step( ) sqrt_2eps_i = torch.sqrt(2 * eps_i) - corrected_x_i = self.relative_coordinates_update( + corrected_x_i = self._relative_coordinates_update( composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z ) @@ -193,12 +208,14 @@ def corrector_step( q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) # atom types update - corrected_a_i = self.atom_types_update( + corrected_a_i = self._atom_types_update( model_predictions_i.A, composition_i.A, q_matrices_i, q_bar_matrices_i, q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=self.one_atom_type_transition_per_step, ) else: corrected_a_i = composition_i.A From 0948cbac37950e77ef45ab3493a6ba9e26ac1585 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Mon, 2 Dec 2024 14:00:55 -0500 Subject: [PATCH 07/10] code refactor after code review --- .../generators/adaptative_corrector.py | 257 ------------------ .../generators/adaptive_corrector.py | 95 +++++++ .../generators/instantiate_generator.py | 12 +- .../generators/langevin_generator.py | 94 ++++++- tests/generators/test_adaptative_corrector.py | 6 +- 5 files changed, 185 insertions(+), 279 deletions(-) delete mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py deleted file mode 100644 index 2d46dcc3..00000000 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py +++ /dev/null @@ -1,257 +0,0 @@ -import torch - -from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ - LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ - PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ - ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters - - -class AdaptativeCorrectorGenerator(LangevinGenerator): - """Langevin Dynamics Generator using only a corrector step with adaptative step size for relative coordinates. - - This class implements the Langevin Corrector generation of position samples, following - Song et. al. 
2021, namely: - "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" - """ - - def __init__( - self, - noise_parameters: NoiseParameters, - sampling_parameters: PredictorCorrectorSamplingParameters, - axl_network: ScoreNetwork, - ): - """Init method.""" - super().__init__( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - axl_network=axl_network, - ) - self.corrector_r = noise_parameters.corrector_r - - def predictor_step( - self, - composition_i: AXL, - index_i: int, - unit_cell: torch.Tensor, # TODO replace with AXL-L - cartesian_forces: torch.Tensor, - ) -> AXL: - """Predictor step. - - Args: - composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. - index_i : index of the time step. - unit_cell: sampled unit cell at time step i. - cartesian_forces: forces conditioning the sampling process - - Returns: - composition_im1 : sampled composition, at time step i - 1. - """ - assert ( - 1 <= index_i <= self.number_of_discretization_steps - ), "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps." - - idx = index_i - 1 # python starts indices at zero - t_i = self.noise.time[idx].to(composition_i.X) - sigma_i = self.noise.sigma[idx].to(composition_i.X) - q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X) - q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) - q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) - - model_predictions_i = self._get_model_predictions( - composition_i, t_i, sigma_i, unit_cell, cartesian_forces - ) - - # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition - # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end - # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where - # the "first" time step is at index_i = T and the "last" time step is index_i = 1. - this_is_last_time_step = idx == 0 - one_atom_type_transition_per_step = ( - self.one_atom_type_transition_per_step and not this_is_last_time_step - ) - - # atom types update - a_im1 = self._atom_types_update( - model_predictions_i.A, - composition_i.A, - q_matrices_i, - q_bar_matrices_i, - q_bar_tm1_matrices_i, - atom_type_greedy_sampling=self.atom_type_greedy_sampling, - one_atom_type_transition_per_step=one_atom_type_transition_per_step, - ) - - if this_is_last_time_step: - assert (a_im1 != self.masked_atom_type_index).all(), \ - "There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input." 
- - # in the adaptative corrector approach, there is no predictor step applied on the X component - composition_im1 = AXL( - A=a_im1, X=composition_i.X, L=unit_cell - ) # TODO : Deal with L correctly - - if self.record: - # TODO : Deal with L correctly - composition_i_for_recording = AXL( - A=composition_i.A, X=composition_i.X, L=unit_cell - ) - # Keep the record on the CPU - entry = dict(time_step_index=index_i) - list_keys = ["composition_i", "composition_im1", "model_predictions_i"] - list_axl = [ - composition_i_for_recording, - composition_im1, - model_predictions_i, - ] - - for key, axl in zip(list_keys, list_axl): - record_axl = AXL( - A=axl.A.detach().cpu(), - X=axl.X.detach().cpu(), - L=axl.L.detach().cpu(), - ) - entry[key] = record_axl - self.sample_trajectory_recorder.record(key="predictor_step", entry=entry) - - return composition_im1 - - def corrector_step( - self, - composition_i: AXL, - index_i: int, - unit_cell: torch.Tensor, # TODO replace with AXL-L - cartesian_forces: torch.Tensor, - ) -> AXL: - r"""Corrector Step. - - Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector argument. - Always affect the reduced coordinates and lattice vectors. The prefactors determining the changes in the X and L - variables are determined using the sigma normalized score at that corrector step. The relative coordinates - update is given by: - - .. math:: - - x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt(2 \epsilon_i) z - - where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal distribution and - :math:`\epsilon_i` is given by: - - .. math:: - - \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 - - where :math:`r` is an hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm. - - Args: - composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. - index_i : index of the time step. - unit_cell: sampled unit cell at time step i. # TODO replace with AXL-L - cartesian_forces: forces conditioning the sampling - - Returns: - corrected_composition_i : sampled composition, after corrector step. - """ - assert 0 <= index_i <= self.number_of_discretization_steps - 1, ( - "The corrector step can only be invoked for index_i between 0 and " - "the total number of discretization steps minus 1." - ) - - if index_i == 0: - # TODO: we are extrapolating here; the score network will never have seen this time step... - sigma_i = ( - self.noise_parameters.sigma_min - ) # no need to change device, this is a float - t_i = 0.0 # same for device - this is a float - idx = index_i - else: - idx = index_i - 1 # python starts indices at zero - sigma_i = self.noise.sigma[idx].to(composition_i.X) - t_i = self.noise.time[idx].to(composition_i.X) - - model_predictions_i = self._get_model_predictions( - composition_i, t_i, sigma_i, unit_cell, cartesian_forces - ) - - # to compute epsilon_i, we need the norm of the score. We average over the atoms. 
- relative_coordinates_sigma_score_norm = ( - torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1) - ).view(-1, 1, 1) - # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size - relative_coordinates_sigma_score_norm /= sigma_i - # draw random noise - z = self._draw_gaussian_sample(relative_coordinates_sigma_score_norm.shape[0]).to(composition_i.X) - # and compute the norm - z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1) - - eps_i = ( - 2 - * ( - self.corrector_r - * z_norm - / (relative_coordinates_sigma_score_norm.clip(min=self.small_epsilon)) - ) - ** 2 - ) - sqrt_2eps_i = torch.sqrt(2 * eps_i) - - corrected_x_i = self._relative_coordinates_update( - composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z - ) - - if self.atom_type_transition_in_corrector: - q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X) - q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) - q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) - # atom types update - corrected_a_i = self._atom_types_update( - model_predictions_i.A, - composition_i.A, - q_matrices_i, - q_bar_matrices_i, - q_bar_tm1_matrices_i, - atom_type_greedy_sampling=self.atom_type_greedy_sampling, - one_atom_type_transition_per_step=self.one_atom_type_transition_per_step, - ) - else: - corrected_a_i = composition_i.A - - corrected_composition_i = AXL( - A=corrected_a_i, - X=corrected_x_i, - L=unit_cell, # TODO replace with AXL-L - ) - - if self.record and self.record_corrector: - # TODO : Deal with L correctly - composition_i_for_recording = AXL( - A=composition_i.A, X=composition_i.X, L=unit_cell - ) - # Keep the record on the CPU - entry = dict(time_step_index=index_i) - list_keys = [ - "composition_i", - "corrected_composition_i", - "model_predictions_i", - ] - list_axl = [ - composition_i_for_recording, - corrected_composition_i, - model_predictions_i, - ] - - for key, axl in zip(list_keys, list_axl): - record_axl = AXL( - A=axl.A.detach().cpu(), - X=axl.X.detach().cpu(), - L=axl.L.detach().cpu(), - ) - entry[key] = record_axl - - self.sample_trajectory_recorder.record(key="corrector_step", entry=entry) - - return corrected_composition_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py new file mode 100644 index 00000000..3b1a5193 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py @@ -0,0 +1,95 @@ +from typing import Optional + +import torch + +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ + PredictorCorrectorSamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ + ScoreNetwork +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters + + +class AdaptiveCorrectorGenerator(LangevinGenerator): + """Langevin Dynamics Generator using only a corrector step with adaptative step size for relative coordinates. + + This class implements the Langevin Corrector generation of position samples, following + Song et. al. 
2021, namely: + "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" + """ + + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + axl_network: ScoreNetwork, + ): + """Init method.""" + super().__init__( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + axl_network=axl_network, + ) + self.corrector_r = noise_parameters.corrector_r + + def _relative_coordinates_update_predictor_step( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + z: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Do not update the relative coordinates in the predictor.""" + return relative_coordinates + + def _get_corrector_step_size( + self, + index_i: int, + sigma_i: torch.Tensor, + model_predictions_i: AXL, + z: torch.Tensor, + ) -> torch.Tensor: + r"""Compute the size of the corrector step for the relative coordinates update. + + Always affect the reduced coordinates and lattice vectors. The prefactors determining the changes in the + relative coordinates are determined using the sigma normalized score at that corrector step. The relative + coordinates update is given by: + + .. math:: + + x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt(2 \epsilon_i) z + + where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal distribution and + :math:`\epsilon_i` is given by: + + .. math:: + + \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 + + where :math:`r` is an hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm. + """ + # to compute epsilon_i, we need the norm of the score. We average over the atoms. + relative_coordinates_sigma_score_norm = ( + torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1) + ).view(-1, 1, 1) + # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size + relative_coordinates_sigma_score_norm /= sigma_i + # compute the norm of the z random noise + z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1) + + eps_i = ( + 2 + * ( + self.corrector_r + * z_norm + / (relative_coordinates_sigma_score_norm.clip(min=self.small_epsilon)) + ) + ** 2 + ) + + return eps_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py index 433eb656..76dd530c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py @@ -1,5 +1,5 @@ -from diffusion_for_multi_scale_molecular_dynamics.generators.adaptative_corrector import \ - AdaptativeCorrectorGenerator +from diffusion_for_multi_scale_molecular_dynamics.generators.adaptive_corrector import \ + AdaptiveCorrectorGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ @@ -24,8 +24,8 @@ def instantiate_generator( "ode", "sde", "predictor_corrector", - "adaptative_corrector", - ], "Unknown algorithm. Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptative_corrector'" + "adaptive_corrector", + ], "Unknown algorithm. 
Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptive_corrector'" match sampling_parameters.algorithm: case "predictor_corrector": @@ -34,8 +34,8 @@ def instantiate_generator( noise_parameters=noise_parameters, axl_network=axl_network, ) - case "adaptative_corrector": - generator = AdaptativeCorrectorGenerator( + case "adaptive_corrector": + generator = AdaptiveCorrectorGenerator( sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, axl_network=axl_network, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py index 9ceb2c71..d12d683f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Optional, Tuple +from typing import Tuple import einops import torch @@ -168,7 +168,7 @@ def _relative_coordinates_update( sigma_i: torch.Tensor, score_weight: torch.Tensor, gaussian_noise_weight: torch.Tensor, - z: Optional[torch.Tensor] = None, + z: torch.Tensor, ) -> torch.Tensor: r"""Generic update for the relative coordinates. @@ -187,8 +187,8 @@ def _relative_coordinates_update( eps_i in the corrector step. Dimension: [number_of_samples] gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples] - z: gaussian noise used to update the coordinates. If None, a sample is drawn from the normal distribution. - Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. Defaults to None. + z: gaussian noise used to update the coordinates. A sample drawn from the normal distribution. + Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. Returns: updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms, @@ -206,6 +206,50 @@ def _relative_coordinates_update( updated_coordinates = map_relative_coordinates_to_unit_cell(updated_coordinates) return updated_coordinates + def _relative_coordinates_update_predictor_step( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + """Relative coordinates update for the predictor step. + + This returns the generic _relative_coordinates_update. + """ + return self._relative_coordinates_update( + relative_coordinates, + sigma_normalized_scores, + sigma_i, + score_weight, + gaussian_noise_weight, + z, + ) + + def _relative_coordinates_update_corrector_step( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + """Relative coordinates update for the corrector step. + + This returns the generic _relative_coordinates_update. 
+ """ + return self._relative_coordinates_update( + relative_coordinates, + sigma_normalized_scores, + sigma_i, + score_weight, + gaussian_noise_weight, + z, + ) + def _atom_types_update( self, predicted_logits: torch.Tensor, @@ -480,8 +524,11 @@ def predictor_step( assert (a_im1 != self.masked_atom_type_index).all(), \ "There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input." - x_im1 = self._relative_coordinates_update( - composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i + # draw a gaussian noise sample and update the positions accordingly + z = self._draw_gaussian_sample(number_of_samples).to(composition_i.X) + + x_im1 = self._relative_coordinates_update_predictor_step( + composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i, z ) composition_im1 = AXL( @@ -513,6 +560,20 @@ def predictor_step( return composition_im1 + def _get_corrector_step_size( + self, + index_i: int, + sigma_i: torch.Tensor, + model_predictions_i: AXL, + z: torch.Tensor, + ) -> torch.Tensor: + """Compute the size of the corrector step for the relative coordinates update.""" + del sigma_i # sigma_i is not needed to compute \epsilon_i because it is pre-tabulated + del z # noise tensor is not needed + # Get the epsilon from the tabulated Langevin dynamics array indexed with [0,..., N-1]. + eps_i = self.langevin_dynamics.epsilon[index_i].to(model_predictions_i.X) + return eps_i + def corrector_step( self, composition_i: AXL, @@ -522,8 +583,8 @@ def corrector_step( ) -> AXL: """Corrector Step. - Note this dones not affect the atom types unless specified with the atom_type_transition_in_corrector argument. - Always affect the reduced coordinates and lattice vectors. + Note this does not not affect the atom types unless specified with the atom_type_transition_in_corrector + argument. Always affect the reduced coordinates and lattice vectors. Args: composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. @@ -538,9 +599,8 @@ def corrector_step( "The corrector step can only be invoked for index_i between 0 and " "the total number of discretization steps minus 1." ) - # The Langevin dynamics array are indexed with [0,..., N-1] - eps_i = self.langevin_dynamics.epsilon[index_i].to(composition_i.X) - sqrt_2eps_i = self.langevin_dynamics.sqrt_2_epsilon[index_i].to(composition_i.X) + + number_of_samples = composition_i.X.shape[0] if index_i == 0: # TODO: we are extrapolating here; the score network will never have seen this time step... 
@@ -558,8 +618,16 @@ def corrector_step( composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) - corrected_x_i = self._relative_coordinates_update( - composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i + # draw a gaussian noise sample and update the positions accordingly + z = self._draw_gaussian_sample(number_of_samples).to(composition_i.X) + + # get the step size eps_i + eps_i = self._get_corrector_step_size(index_i, sigma_i, model_predictions_i, z) + # the size for the noise part is sqrt(2 * eps_i) + sqrt_2eps_i = torch.sqrt(2 * eps_i) + + corrected_x_i = self._relative_coordinates_update_corrector_step( + composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z ) if self.atom_type_transition_in_corrector: diff --git a/tests/generators/test_adaptative_corrector.py b/tests/generators/test_adaptative_corrector.py index 0dac2c0d..012d513a 100644 --- a/tests/generators/test_adaptative_corrector.py +++ b/tests/generators/test_adaptative_corrector.py @@ -1,8 +1,8 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.adaptative_corrector import \ - AdaptativeCorrectorGenerator +from diffusion_for_multi_scale_molecular_dynamics.generators.adaptive_corrector import \ + AdaptiveCorrectorGenerator from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ @@ -27,7 +27,7 @@ def noise_parameters(self, total_time_steps): @pytest.fixture() def pc_generator(self, noise_parameters, sampling_parameters, axl_network): # override the base class - generator = AdaptativeCorrectorGenerator( + generator = AdaptiveCorrectorGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, axl_network=axl_network, From 45facdd2bb5a23bd4eabf238e93287ee137b4917 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Mon, 2 Dec 2024 14:13:25 -0500 Subject: [PATCH 08/10] changing the norm calculation in adaptive langevin and unit test --- .../generators/adaptive_corrector.py | 13 ++++++++----- tests/generators/test_adaptative_corrector.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py index 3b1a5193..1020ca49 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py @@ -73,14 +73,17 @@ def _get_corrector_step_size( where :math:`r` is an hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm. """ - # to compute epsilon_i, we need the norm of the score. We average over the atoms. + # to compute epsilon_i, we need the norm of the score summed over the atoms and averaged over the mini-batch. + # taking the norm over the last 2 dimensions means summing the squared components over the spatial dimension and + # the atoms, then taking the square-root. 
+ # the mean averages over the mini-batch relative_coordinates_sigma_score_norm = ( - torch.linalg.norm(model_predictions_i.X, dim=-1).mean(dim=-1) - ).view(-1, 1, 1) + torch.linalg.norm(model_predictions_i.X, dim=[-2, -1]).mean() + ).view(1, 1, 1) # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size relative_coordinates_sigma_score_norm /= sigma_i - # compute the norm of the z random noise - z_norm = torch.linalg.norm(z, dim=-1).mean(dim=-1).view(-1, 1, 1) + # compute the norm of the z random noise similarly + z_norm = torch.linalg.norm(z, dim=[-2, -1]).mean().view(1, 1, 1) eps_i = ( 2 diff --git a/tests/generators/test_adaptative_corrector.py b/tests/generators/test_adaptative_corrector.py index 012d513a..e4e732fa 100644 --- a/tests/generators/test_adaptative_corrector.py +++ b/tests/generators/test_adaptative_corrector.py @@ -85,7 +85,7 @@ def test_corrector_step( z = pc_generator._draw_gaussian_sample(number_of_samples).to(axl_i.X) mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) - z_norm = torch.sqrt((z**2).sum(dim=-1)).mean( + z_norm = torch.sqrt((z**2).sum(dim=-1).sum(dim=-1)).mean( dim=-1 ) # norm of z averaged over atoms @@ -107,7 +107,7 @@ def test_corrector_step( ).X / sigma_i ) - s_i_norm = torch.sqrt((s_i**2).sum(dim=-1)).mean(dim=-1) + s_i_norm = torch.sqrt((s_i**2).sum(dim=-1).sum(dim=-1)).mean(dim=-1) # \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 eps_i = ( 2 From 535335ed81ea9ddc54a32147512b71d71bb0b58b Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Tue, 3 Dec 2024 10:12:25 -0500 Subject: [PATCH 09/10] typos and small fixes for code review --- .../generators/adaptive_corrector.py | 2 +- .../generators/langevin_generator.py | 4 +--- .../generators/sde_position_generator.py | 4 ++-- .../noise_schedulers/noise_parameters.py | 2 +- ...est_adaptative_corrector.py => test_adaptive_corrector.py} | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) rename tests/generators/{test_adaptative_corrector.py => test_adaptive_corrector.py} (98%) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py index 1020ca49..d24c885a 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py @@ -14,7 +14,7 @@ class AdaptiveCorrectorGenerator(LangevinGenerator): - """Langevin Dynamics Generator using only a corrector step with adaptative step size for relative coordinates. + """Langevin Dynamics Generator using only a corrector step with adaptive step size for relative coordinates. This class implements the Langevin Corrector generation of position samples, following Song et. al. 
2021, namely: "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" """ diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py index d12d683f..60b1ad0a 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py @@ -568,8 +568,6 @@ def _get_corrector_step_size( z: torch.Tensor, ) -> torch.Tensor: """Compute the size of the corrector step for the relative coordinates update.""" - del sigma_i # sigma_i is not needed to compute \epsilon_i because it is pre-tabulated - del z # noise tensor is not needed # Get the epsilon from the tabulated Langevin dynamics array indexed with [0,..., N-1]. eps_i = self.langevin_dynamics.epsilon[index_i].to(model_predictions_i.X) return eps_i @@ -583,7 +581,7 @@ def corrector_step( ) -> AXL: """Corrector Step. - Note this does not not affect the atom types unless specified with the atom_type_transition_in_corrector + Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector argument. Always affect the reduced coordinates and lattice vectors. Args: diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py index 3531b9aa..5ead1442 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py @@ -31,7 +31,7 @@ class SDESamplingParameters(SamplingParameters): algorithm: str = "sde" sde_type: str = "ito" method: str = "euler" - adaptative: bool = False + adaptive: bool = False absolute_solver_tolerance: float = ( 1.0e-7 # the absolute error tolerance passed to the SDE solver. ) @@ -325,7 +325,7 @@ def sample( sde_times, method=self.sampling_parameters.method, dt=dt, - adaptive=self.sampling_parameters.adaptative, + adaptive=self.sampling_parameters.adaptive, atol=self.sampling_parameters.absolute_solver_tolerance, rtol=self.sampling_parameters.relative_solver_tolerance, ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py index de105c85..d8b6093f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py @@ -22,7 +22,7 @@ class NoiseParameters: # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution" corrector_step_epsilon: float = 2e-5 - # Step size scaling for the Adaptative Corrector Generator. Default value comes from github implementation + # Step size scaling for the Adaptive Corrector Generator. Default value comes from github implementation # https://github.com/yang-song/score_sde/blob/main/configs/default_celeba_configs.py # for the celeba dataset. Note the suggested value for CIFAR10 is 0.16 in that repo.
corrector_r: float = 0.17 diff --git a/tests/generators/test_adaptative_corrector.py b/tests/generators/test_adaptive_corrector.py similarity index 98% rename from tests/generators/test_adaptative_corrector.py rename to tests/generators/test_adaptive_corrector.py index e4e732fa..64821431 100644 --- a/tests/generators/test_adaptative_corrector.py +++ b/tests/generators/test_adaptive_corrector.py @@ -12,7 +12,7 @@ from tests.generators.test_langevin_generator import TestLangevinGenerator -class TestAdaptativeCorrectorGenerator(TestLangevinGenerator): +class TestAdaptiveCorrectorGenerator(TestLangevinGenerator): @pytest.fixture() def noise_parameters(self, total_time_steps): From 568bd7862e23067e0ceb478598e69e8fb48c7889 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 11:10:53 -0500 Subject: [PATCH 10/10] Set a seed for the Generator tests. --- tests/generators/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/generators/conftest.py b/tests/generators/conftest.py index 572e27c4..23eead1b 100644 --- a/tests/generators/conftest.py +++ b/tests/generators/conftest.py @@ -29,6 +29,10 @@ def _forward_unchecked( class BaseTestGenerator: """A base class that contains common test fixtures useful for testing generators.""" + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(34534534) + @pytest.fixture() def unit_cell_size(self): return 10
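Two details of the final state deserve a note. First, [PATCH 08/10] changed how the norms entering epsilon_i are reduced: instead of per-atom norms averaged within each sample, the norm is now taken jointly over atoms and spatial dimensions and then averaged over the mini-batch, so every sample shares one step size. A toy comparison of the two reductions, with made-up shapes:

import torch

x = torch.randn(4, 8, 3)  # [number_of_samples, number_of_atoms, spatial_dimension]

# before PATCH 08: norm over the spatial dimension, averaged over atoms, one value per sample
per_sample_norm = torch.linalg.norm(x, dim=-1).mean(dim=-1).view(-1, 1, 1)  # shape [4, 1, 1]

# after PATCH 08: one norm over (atoms, spatial dimensions), averaged over the mini-batch
shared_norm = torch.linalg.norm(x, dim=[-2, -1]).mean().view(1, 1, 1)  # shape [1, 1, 1]

The rewritten unit test mirrors this by summing squared components over both trailing dimensions before taking the square root.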
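Second, with [PATCH 07/10] the sampler is wired into instantiate_generator, so switching to it is a configuration change. A sketch of a call site, assuming a trained ScoreNetwork is in scope as axl_network; the remaining PredictorCorrectorSamplingParameters fields are elided because their signature is defined outside this series, and the total_time_steps and sigma_min values are placeholders:

from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \
    instantiate_generator
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
    PredictorCorrectorSamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
    NoiseParameters

noise_parameters = NoiseParameters(
    total_time_steps=100,
    sigma_min=0.001,
    corrector_r=0.17,  # the adaptive step-size scaling introduced in PATCH 05
)
sampling_parameters = PredictorCorrectorSamplingParameters(
    algorithm="adaptive_corrector",  # routed to AdaptiveCorrectorGenerator by the new match case
    # ... other required sampling fields go here
)
generator = instantiate_generator(
    sampling_parameters=sampling_parameters,
    noise_parameters=noise_parameters,
    axl_network=axl_network,  # trained ScoreNetwork, assumed available
)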