diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py new file mode 100644 index 00000000..d24c885a --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptive_corrector.py @@ -0,0 +1,98 @@ +from typing import Optional + +import torch + +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ + PredictorCorrectorSamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ + ScoreNetwork +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters + + +class AdaptiveCorrectorGenerator(LangevinGenerator): + """Langevin Dynamics Generator using only a corrector step with adaptive step size for relative coordinates. + + This class implements the Langevin Corrector generation of position samples, following + Song et. al. 2021, namely: + "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" + """ + + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + axl_network: ScoreNetwork, + ): + """Init method.""" + super().__init__( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + axl_network=axl_network, + ) + self.corrector_r = noise_parameters.corrector_r + + def _relative_coordinates_update_predictor_step( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + z: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Do not update the relative coordinates in the predictor.""" + return relative_coordinates + + def _get_corrector_step_size( + self, + index_i: int, + sigma_i: torch.Tensor, + model_predictions_i: AXL, + z: torch.Tensor, + ) -> torch.Tensor: + r"""Compute the size of the corrector step for the relative coordinates update. + + Always affect the reduced coordinates and lattice vectors. The prefactors determining the changes in the + relative coordinates are determined using the sigma normalized score at that corrector step. The relative + coordinates update is given by: + + .. math:: + + x_i \leftarrow x_i + \epsilon_i * s(x_i, t_i) + \sqrt(2 \epsilon_i) z + + where :math:`s(x_i, t_i)` is the score, :math:`z` is a random variable drawn from a normal distribution and + :math:`\epsilon_i` is given by: + + .. math:: + + \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 + + where :math:`r` is an hyper-parameter (0.17 by default) and :math:`||\cdot||_2` is the L2 norm. + """ + # to compute epsilon_i, we need the norm of the score summed over the atoms and averaged over the mini-batch. + # taking the norm over the last 2 dimensions means summing the squared components over the spatial dimension and + # the atoms, then taking the square-root. + # the mean averages over the mini-batch + relative_coordinates_sigma_score_norm = ( + torch.linalg.norm(model_predictions_i.X, dim=[-2, -1]).mean() + ).view(1, 1, 1) + # note that sigma_score is \sigma * s(x, t), so we need to divide the norm by sigma to get the correct step size + relative_coordinates_sigma_score_norm /= sigma_i + # compute the norm of the z random noise similarly + z_norm = torch.linalg.norm(z, dim=[-2, -1]).mean().view(1, 1, 1) + + eps_i = ( + 2 + * ( + self.corrector_r + * z_norm + / (relative_coordinates_sigma_score_norm.clip(min=self.small_epsilon)) + ) + ** 2 + ) + + return eps_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py index af897328..76dd530c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py @@ -1,3 +1,5 @@ +from diffusion_for_multi_scale_molecular_dynamics.generators.adaptive_corrector import \ + AdaptiveCorrectorGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ @@ -22,7 +24,8 @@ def instantiate_generator( "ode", "sde", "predictor_corrector", - ], "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'" + "adaptive_corrector", + ], "Unknown algorithm. Possible choices are 'ode', 'sde', 'predictor_corrector' and 'adaptive_corrector'" match sampling_parameters.algorithm: case "predictor_corrector": @@ -31,6 +34,12 @@ def instantiate_generator( noise_parameters=noise_parameters, axl_network=axl_network, ) + case "adaptive_corrector": + generator = AdaptiveCorrectorGenerator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + axl_network=axl_network, + ) case "ode": generator = ExplodingVarianceODEAXLGenerator( sampling_parameters=sampling_parameters, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py index 2d68dc77..60b1ad0a 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py @@ -168,6 +168,7 @@ def _relative_coordinates_update( sigma_i: torch.Tensor, score_weight: torch.Tensor, gaussian_noise_weight: torch.Tensor, + z: torch.Tensor, ) -> torch.Tensor: r"""Generic update for the relative coordinates. @@ -186,13 +187,16 @@ def _relative_coordinates_update( eps_i in the corrector step. Dimension: [number_of_samples] gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples] + z: gaussian noise used to update the coordinates. A sample drawn from the normal distribution. + Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. Returns: updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms, spatial_dimension]. """ number_of_samples = relative_coordinates.shape[0] - z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates) + if z is None: + z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates) updated_coordinates = ( relative_coordinates + score_weight * sigma_normalized_scores / sigma_i @@ -202,6 +206,50 @@ def _relative_coordinates_update( updated_coordinates = map_relative_coordinates_to_unit_cell(updated_coordinates) return updated_coordinates + def _relative_coordinates_update_predictor_step( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + """Relative coordinates update for the predictor step. + + This returns the generic _relative_coordinates_update. + """ + return self._relative_coordinates_update( + relative_coordinates, + sigma_normalized_scores, + sigma_i, + score_weight, + gaussian_noise_weight, + z, + ) + + def _relative_coordinates_update_corrector_step( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + """Relative coordinates update for the corrector step. + + This returns the generic _relative_coordinates_update. + """ + return self._relative_coordinates_update( + relative_coordinates, + sigma_normalized_scores, + sigma_i, + score_weight, + gaussian_noise_weight, + z, + ) + def _atom_types_update( self, predicted_logits: torch.Tensor, @@ -476,8 +524,11 @@ def predictor_step( assert (a_im1 != self.masked_atom_type_index).all(), \ "There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input." - x_im1 = self._relative_coordinates_update( - composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i + # draw a gaussian noise sample and update the positions accordingly + z = self._draw_gaussian_sample(number_of_samples).to(composition_i.X) + + x_im1 = self._relative_coordinates_update_predictor_step( + composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i, z ) composition_im1 = AXL( @@ -509,6 +560,18 @@ def predictor_step( return composition_im1 + def _get_corrector_step_size( + self, + index_i: int, + sigma_i: torch.Tensor, + model_predictions_i: AXL, + z: torch.Tensor, + ) -> torch.Tensor: + """Compute the size of the corrector step for the relative coordinates update.""" + # Get the epsilon from the tabulated Langevin dynamics array indexed with [0,..., N-1]. + eps_i = self.langevin_dynamics.epsilon[index_i].to(model_predictions_i.X) + return eps_i + def corrector_step( self, composition_i: AXL, @@ -518,7 +581,8 @@ def corrector_step( ) -> AXL: """Corrector Step. - Note this is not affecting the atom types. Only the reduced coordinates and lattice vectors. + Note this does not affect the atom types unless specified with the atom_type_transition_in_corrector + argument. Always affect the reduced coordinates and lattice vectors. Args: composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. @@ -533,9 +597,8 @@ def corrector_step( "The corrector step can only be invoked for index_i between 0 and " "the total number of discretization steps minus 1." ) - # The Langevin dynamics array are indexed with [0,..., N-1] - eps_i = self.langevin_dynamics.epsilon[index_i].to(composition_i.X) - sqrt_2eps_i = self.langevin_dynamics.sqrt_2_epsilon[index_i].to(composition_i.X) + + number_of_samples = composition_i.X.shape[0] if index_i == 0: # TODO: we are extrapolating here; the score network will never have seen this time step... @@ -553,8 +616,16 @@ def corrector_step( composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) - corrected_x_i = self._relative_coordinates_update( - composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i + # draw a gaussian noise sample and update the positions accordingly + z = self._draw_gaussian_sample(number_of_samples).to(composition_i.X) + + # get the step size eps_i + eps_i = self._get_corrector_step_size(index_i, sigma_i, model_predictions_i, z) + # the size for the noise part is sqrt(2 * eps_i) + sqrt_2eps_i = torch.sqrt(2 * eps_i) + + corrected_x_i = self._relative_coordinates_update_corrector_step( + composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z ) if self.atom_type_transition_in_corrector: diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py index 3531b9aa..5ead1442 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py @@ -31,7 +31,7 @@ class SDESamplingParameters(SamplingParameters): algorithm: str = "sde" sde_type: str = "ito" method: str = "euler" - adaptative: bool = False + adaptive: bool = False absolute_solver_tolerance: float = ( 1.0e-7 # the absolute error tolerance passed to the SDE solver. ) @@ -325,7 +325,7 @@ def sample( sde_times, method=self.sampling_parameters.method, dt=dt, - adaptive=self.sampling_parameters.adaptative, + adaptive=self.sampling_parameters.adaptive, atol=self.sampling_parameters.absolute_solver_tolerance, rtol=self.sampling_parameters.relative_solver_tolerance, ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py index ae34bb85..d8b6093f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py @@ -21,3 +21,8 @@ class NoiseParameters: # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution" corrector_step_epsilon: float = 2e-5 + + # Step size scaling for the Adaptive Corrector Generator. Default value comes from github implementation + # https: // github.com / yang - song / score_sde / blob / main / configs / default_celeba_configs.py + # for the celeba dataset. Note the suggested value for CIFAR10 is 0.16 in that repo. + corrector_r: float = 0.17 diff --git a/tests/generators/conftest.py b/tests/generators/conftest.py index 572e27c4..23eead1b 100644 --- a/tests/generators/conftest.py +++ b/tests/generators/conftest.py @@ -29,6 +29,10 @@ def _forward_unchecked( class BaseTestGenerator: """A base class that contains common test fixtures useful for testing generators.""" + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(34534534) + @pytest.fixture() def unit_cell_size(self): return 10 diff --git a/tests/generators/test_adaptive_corrector.py b/tests/generators/test_adaptive_corrector.py new file mode 100644 index 00000000..64821431 --- /dev/null +++ b/tests/generators/test_adaptive_corrector.py @@ -0,0 +1,124 @@ +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.generators.adaptive_corrector import \ + AdaptiveCorrectorGenerator +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from tests.generators.test_langevin_generator import TestLangevinGenerator + + +class TestAdaptiveCorrectorGenerator(TestLangevinGenerator): + + @pytest.fixture() + def noise_parameters(self, total_time_steps): + noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + time_delta=0.1, + sigma_min=0.15, + corrector_r=0.15, + ) + return noise_parameters + + @pytest.fixture() + def pc_generator(self, noise_parameters, sampling_parameters, axl_network): + # override the base class + generator = AdaptiveCorrectorGenerator( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + axl_network=axl_network, + ) + + return generator + + def test_predictor_step_relative_coordinates( + self, + mocker, + pc_generator, + noise_parameters, + axl_i, + total_time_steps, + number_of_samples, + unit_cell_sample, + num_atomic_classes, + device, + ): + # override the base class + forces = torch.zeros_like(axl_i.X) + + for index_i in range(1, total_time_steps + 1): + computed_sample = pc_generator.predictor_step( + axl_i, index_i, unit_cell_sample, forces + ) + + expected_coordinates = axl_i.X + expected_coordinates = map_relative_coordinates_to_unit_cell( + expected_coordinates + ) + # this is almost trivial - the coordinates should not change in a predictor step + torch.testing.assert_close(computed_sample.X, expected_coordinates) + + @pytest.mark.parametrize("corrector_r", [0.1, 0.5, 1.2]) + def test_corrector_step( + self, + mocker, + corrector_r, + pc_generator, + noise_parameters, + axl_i, + total_time_steps, + number_of_samples, + unit_cell_sample, + num_atomic_classes, + ): + pc_generator.corrector_r = corrector_r + sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes) + noise, _ = sampler.get_all_sampling_parameters() + sigma_min = noise_parameters.sigma_min + list_sigma = noise.sigma + list_time = noise.time + forces = torch.zeros_like(axl_i.X) + + z = pc_generator._draw_gaussian_sample(number_of_samples).to(axl_i.X) + mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) + z_norm = torch.sqrt((z**2).sum(dim=-1).sum(dim=-1)).mean( + dim=-1 + ) # norm of z averaged over atoms + + for index_i in range(0, total_time_steps): + computed_sample = pc_generator.corrector_step( + axl_i, index_i, unit_cell_sample, forces + ) + + if index_i == 0: + sigma_i = sigma_min + t_i = 0.0 + else: + sigma_i = list_sigma[index_i - 1] + t_i = list_time[index_i - 1] + + s_i = ( + pc_generator._get_model_predictions( + axl_i, t_i, sigma_i, unit_cell_sample, forces + ).X + / sigma_i + ) + s_i_norm = torch.sqrt((s_i**2).sum(dim=-1).sum(dim=-1)).mean(dim=-1) + # \epsilon_i = 2 \left(r \frac{||z||_2}{||s(x_i, t_i)||_2}\right)^2 + eps_i = ( + 2 + * (corrector_r * z_norm / s_i_norm.clip(min=pc_generator.small_epsilon)) + ** 2 + ) + eps_i = eps_i.view(-1, 1, 1) + + expected_coordinates = axl_i.X + eps_i * s_i + torch.sqrt(2.0 * eps_i) * z + expected_coordinates = map_relative_coordinates_to_unit_cell( + expected_coordinates + ) + + torch.testing.assert_close(computed_sample.X, expected_coordinates)