From 37a778105a4fb0b639a2b2b8c85222393f185ad0 Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Wed, 27 Nov 2024 19:37:24 -0500
Subject: [PATCH 01/10] Record more stuff.

---
 .../generators/axl_generator.py      |  1 +
 .../generators/langevin_generator.py | 17 ++++++++++++++++-
 .../utils/sample_trajectory.py       |  5 +++++
 3 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py
index 18a7f047..1c4c1989 100644
--- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py
+++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py
@@ -28,6 +28,7 @@ class SamplingParameters:
         False  # should the predictor and corrector steps be recorded to a file
     )
     record_samples_corrector_steps: bool = False
+    record_atom_type_update: bool = False  # record the information pertaining to generating atom types.


 class AXLGenerator(ABC):
diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py
index 73335adf..a2c2580b 100644
--- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py
+++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py
@@ -63,6 +63,10 @@ def __init__(

         self.record = sampling_parameters.record_samples
         self.record_corrector = sampling_parameters.record_samples_corrector_steps
+        self.record_atom_type_update = sampling_parameters.record_atom_type_update
+
+        if self.record_corrector or self.record_atom_type_update:
+            assert self.record, "Corrector steps or atom_type_update can only be recorded if record_samples is True."

         if self.record:
             self.sample_trajectory_recorder = SampleTrajectory()
@@ -274,6 +278,17 @@ def atom_types_update(
                 ]
             )
             # TODO some sanity check at the last step because this approach does not guarantee a full transition...
+
+        if self.record_atom_type_update:
+            # Keep the record on the CPU
+            entry = dict(predicted_logits=predicted_logits.detach().cpu(),
+                         one_step_transition_probabilities=one_step_transition_probs.detach().cpu(),
+                         gumbel_sample=u.cpu(),
+                         a_i=atom_types_i.cpu(),
+                         a_im1=a_im1.cpu())
+
+            self.sample_trajectory_recorder.record(key='atom_type_update', entry=entry)
+
         return a_im1

     def adjust_atom_types_probabilities_for_greedy_sampling(
@@ -468,7 +483,7 @@ def corrector_step(
                 L=unit_cell,  # TODO replace with AXL-L
             )

-        if self.record and self.record_corrector:
+        if self.record_corrector:
             # TODO : Deal with L correctly
             composition_i_for_recording = AXL(A=composition_i.A,
                                               X=composition_i.X,
diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py
index 606a466c..089b9ef2 100644
--- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py
+++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py
@@ -35,5 +35,10 @@ def record(self, key: str, entry: Union[Dict[str, Any], NamedTuple]):

     def write_to_pickle(self, path_to_pickle: str):
         """Write data to pickle file."""
+        self._internal_data = dict(self._internal_data)
+        for key, value in self._internal_data.items():
+            if len(value) == 1:
+                self._internal_data[key] = value[0]
+
         with open(path_to_pickle, "wb") as fd:
             torch.save(self._internal_data, fd)

From 1bb875cc9d2b55deb36fc1f1c4105b16743e5ec0 Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Thu, 28 Nov 2024 14:58:39 -0500
Subject: [PATCH 02/10] Align with how things are recorded.

---
 .../analysis/sample_trajectory_analyser.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py
index c4591ab3..039d960c 100644
--- a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py
+++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py
@@ -32,7 +32,7 @@ def __init__(self, pickle_path: Path, num_classes: int):
         data = torch.load(pickle_path, map_location=torch.device("cpu"))
         logger.info("Done reading data.")

-        noise_parameters = NoiseParameters(**data['noise_parameters'][0])
+        noise_parameters = NoiseParameters(**data['noise_parameters'])
         sampler = NoiseScheduler(noise_parameters, num_classes=num_classes)

         self.noise, _ = sampler.get_all_sampling_parameters()

From c9e77ed73201b2a80a37ae9e5f9b5ad7e0bf7799 Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Thu, 28 Nov 2024 15:00:18 -0500
Subject: [PATCH 03/10] Better probability calculation.

---
 .../utils/d3pm_utils.py        | 34 ++++++++++---
 tests/utils/test_d3pm_utils.py | 51 ++++++++++++++++++-
 2 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py
index c9fb17fb..8e36755b 100644
--- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py
+++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py
@@ -99,9 +99,8 @@ def get_probability_at_previous_time_step(
         one-step transition normalized probabilities of dimension [batch_size, number_of_atoms, num_type_atoms]
     """
     if probability_at_zeroth_timestep_are_logits:
-        probability_at_zeroth_timestep = torch.nn.functional.softmax(
-            probability_at_zeroth_timestep, dim=-1
-        )
+        probability_at_zeroth_timestep = get_probability_from_logits(probability_at_zeroth_timestep,
+                                                                     lowest_probability_value=small_epsilon)

     numerator1 = einops.einsum(
         probability_at_zeroth_timestep, q_bar_matrices, "... j, ... j i -> ... i"
@@ -116,12 +115,35 @@ def get_probability_at_previous_time_step(
         one_hot_probability_at_current_timestep,
         "... i j, ... j -> ... i",
     )
-    den2 = einops.einsum(
-        probability_at_zeroth_timestep, den1, "... j, ... j -> ..."
-    ).clip(min=small_epsilon)
+    den2 = einops.einsum(probability_at_zeroth_timestep, den1, "... j, ... j -> ...")

     denominator = einops.repeat(
         den2, "... -> ... num_classes", num_classes=numerator.shape[-1]
     )

     return numerator / denominator
+
+
+def get_probability_from_logits(logits: torch.Tensor, lowest_probability_value: float) -> torch.Tensor:
+    """Get probability from logits.
+
+    Compute the probabilities from the logits, imposing that no class probability can be lower than
+    lowest_probability_value.
+
+    Args:
+        logits: Unnormalized values that can be turned into probabilities. Dimensions [..., num_classes]
+        lowest_probability_value: imposed lowest probability value for any class.
+
+    Returns:
+        probabilities: derived from the logits, with values clipped from below. Dimensions [..., num_classes].
+
+    """
+    raw_probabilities = torch.nn.functional.softmax(logits, dim=-1)
+    probability_sum = raw_probabilities.sum(dim=-1)
+    torch.testing.assert_close(probability_sum, torch.ones_like(probability_sum),
+                               msg="Logits are pathological: the probabilities do not sum to one.")
+
+    clipped_probabilities = raw_probabilities.clip(min=lowest_probability_value)
+
+    probabilities = clipped_probabilities / clipped_probabilities.sum(dim=-1).unsqueeze(-1)
+    return probabilities
diff --git a/tests/utils/test_d3pm_utils.py b/tests/utils/test_d3pm_utils.py
index ecd7234d..8cd4e7af 100644
--- a/tests/utils/test_d3pm_utils.py
+++ b/tests/utils/test_d3pm_utils.py
@@ -1,3 +1,5 @@
+from copy import copy
+
 import pytest
 import torch

@@ -7,7 +9,7 @@
     NoiseScheduler
 from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import (
     class_index_to_onehot, compute_q_at_given_a0, compute_q_at_given_atm1,
-    get_probability_at_previous_time_step)
+    get_probability_at_previous_time_step, get_probability_from_logits)
 from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \
     broadcast_batch_matrix_tensor_to_all_dimensions

@@ -321,3 +323,50 @@ def test_prob_a0_given_a1_is_never_mask(number_of_atoms, num_classes, total_time

     total_probability = p_a0_given_a1.sum(dim=-1)
     torch.testing.assert_allclose(total_probability, torch.ones_like(total_probability))
+
+
+@pytest.fixture()
+def logits(batch_size, num_atom_types, num_classes):
+    return torch.rand(batch_size, num_atom_types, num_classes)
+
+
+@pytest.mark.parametrize("lowest_probability_value", [1e-12, 1e-8, 1e-3])
+def test_get_probability_from_logits_general(logits, lowest_probability_value):
+    probabilities = get_probability_from_logits(logits, lowest_probability_value)
+
+    approximate_probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+    torch.testing.assert_close(probabilities, approximate_probabilities)
+
+    computed_sums = probabilities.sum(dim=-1)
+    torch.testing.assert_close(computed_sums, torch.ones_like(computed_sums))
+
+
+@pytest.mark.parametrize("lowest_probability_value", [1e-12, 1e-8, 1e-3])
+def test_get_probability_from_logits_some_zero_probabilities(logits, lowest_probability_value):
+
+    mask = torch.randint(0, 2, logits.shape).to(torch.bool)
+    mask[:, :, 0] = False  # make sure no mask is all True.
+
+    edge_case_logits = copy(logits)
+    edge_case_logits[mask] = -torch.inf
+
+    computed_probabilities = get_probability_from_logits(edge_case_logits, lowest_probability_value)
+
+    computed_sums = computed_probabilities.sum(dim=-1)
+    torch.testing.assert_close(computed_sums, torch.ones_like(computed_sums))
+
+    assert torch.all(computed_probabilities[mask] > 0.1 * lowest_probability_value)
+
+
+@pytest.mark.parametrize("lowest_probability_value", [1e-12, 1e-8, 1e-3])
+def test_get_probability_from_logits_pathological(logits, lowest_probability_value):
+
+    mask = torch.randint(0, 2, logits.shape).to(torch.bool)
+    mask[0, 0, :] = True  # All bad logits
+
+    bad_logits = copy(logits)
+    bad_logits[mask] = -torch.inf
+
+    with pytest.raises(AssertionError):
+        _ = get_probability_from_logits(bad_logits, lowest_probability_value)

From e2ee046fbaf22cea6a27d150c5e7ba9e0bc5ba18 Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Fri, 29 Nov 2024 06:53:01 -0500
Subject: [PATCH 04/10] Better docstring.

--- tests/generators/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generators/conftest.py b/tests/generators/conftest.py index a07fc7b6..572e27c4 100644 --- a/tests/generators/conftest.py +++ b/tests/generators/conftest.py @@ -12,7 +12,7 @@ class FakeAXLNetwork(ScoreNetwork): - """A fake, smooth score network for the ODE solver.""" + """A fake score network for tests.""" def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False From b4db573be0374585eed1b9be17ffc5f4e104209a Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 29 Nov 2024 14:59:24 -0500 Subject: [PATCH 05/10] Do not allow T=1. --- .../generators/predictor_corrector_axl_generator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py index 67c760a1..2757ac2b 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py @@ -36,9 +36,10 @@ def __init__( **kwargs, ): """Init method.""" + # T = 1 is a dangerous and meaningless edge case. assert ( - number_of_discretization_steps > 0 - ), "The number of discretization steps should be larger than zero" + number_of_discretization_steps > 1 + ), "The number of discretization steps should be larger than one" assert ( number_of_corrector_steps >= 0 ), "The number of corrector steps should be non-negative" From 5f25f77f0c2b0908233292c71905ad44ef0b9236 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 29 Nov 2024 15:01:01 -0500 Subject: [PATCH 06/10] Refactor how atom types are determined. --- .../generators/langevin_generator.py | 257 +++++++--- tests/generators/test_langevin_generator.py | 460 +++++++++++++----- 2 files changed, 515 insertions(+), 202 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py index a2c2580b..e0b04d32 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py @@ -1,6 +1,7 @@ import dataclasses from typing import Tuple +import einops import torch from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import ( @@ -42,11 +43,8 @@ def __init__( spatial_dimension=sampling_parameters.spatial_dimension, num_atom_types=sampling_parameters.num_atom_types, ) - self.noise_parameters = noise_parameters - sampler = NoiseScheduler( - noise_parameters, num_classes=sampling_parameters.num_atom_types + 1 - ) + sampler = NoiseScheduler(noise_parameters, num_classes=self.num_classes) self.noise, self.langevin_dynamics = sampler.get_all_sampling_parameters() self.number_of_atoms = sampling_parameters.number_of_atoms self.masked_atom_type_index = self.num_classes - 1 @@ -66,24 +64,29 @@ def __init__( self.record_atom_type_update = sampling_parameters.record_atom_type_update if self.record_corrector or self.record_atom_type_update: - assert self.record, "Corrector steps or atom_type_update can only be recorded if record_samples is True." 
+ assert ( + self.record + ), "Corrector steps or atom_type_update can only be recorded if record_samples is True." if self.record: self.sample_trajectory_recorder = SampleTrajectory() self.sample_trajectory_recorder.record(key="noise", entry=self.noise) - self.sample_trajectory_recorder.record(key="noise_parameters", - entry=dataclasses.asdict(noise_parameters)) - self.sample_trajectory_recorder.record(key="sampling_parameters", - entry=dataclasses.asdict(sampling_parameters)) + self.sample_trajectory_recorder.record( + key="noise_parameters", entry=dataclasses.asdict(noise_parameters) + ) + self.sample_trajectory_recorder.record( + key="sampling_parameters", entry=dataclasses.asdict(sampling_parameters) + ) def initialize( self, number_of_samples: int, device: torch.device = torch.device("cpu") ): """This method must initialize the samples from the fully noised distribution.""" # all atoms are initialized as masked - atom_types = torch.ones(number_of_samples, self.number_of_atoms).long().to( - device - ) * (self.num_classes - 1) + atom_types = ( + torch.ones(number_of_samples, self.number_of_atoms).long().to(device) + * self.masked_atom_type_index + ) # relative coordinates are sampled from the uniform distribution relative_coordinates = torch.rand( number_of_samples, self.number_of_atoms, self.spatial_dimension @@ -158,7 +161,7 @@ def _get_model_predictions( model_predictions = self.axl_network(augmented_batch, conditional=False) return model_predictions - def relative_coordinates_update( + def _relative_coordinates_update( self, relative_coordinates: torch.Tensor, sigma_normalized_scores: torch.Tensor, @@ -199,13 +202,15 @@ def relative_coordinates_update( updated_coordinates = map_relative_coordinates_to_unit_cell(updated_coordinates) return updated_coordinates - def atom_types_update( + def _atom_types_update( self, predicted_logits: torch.Tensor, atom_types_i: torch.LongTensor, q_matrices_i: torch.Tensor, q_bar_matrices_i: torch.Tensor, q_bar_tm1_matrices_i: torch.Tensor, + atom_type_greedy_sampling: bool, + one_atom_type_transition_per_step: bool, ) -> torch.LongTensor: """Generic update of the atom types. @@ -222,12 +227,17 @@ def atom_types_update( number_of_atoms, num_classes, num_classes]. q_bar_tm1_matrices_i: cumulative transition matrix at time step 'i - 1'. Dimension: [number_of_samples, number_of_atoms, num_classes, num_classes]. + atom_type_greedy_sampling: boolean flag that sets whether the atom types should be selected greedily. + one_atom_type_transition_per_step: boolean flag that sets whether a single atom type transition can + occur per time step. Returns: - a_im1: updated atom type indices. Dimension: [number_of_samples, number_of_atoms] + atom_types_im1: updated atom type indices. Dimension: [number_of_samples, number_of_atoms] """ number_of_samples = predicted_logits.shape[0] - u = self._draw_gumbel_sample(number_of_samples).to(predicted_logits.device) + gumbel_random_variable = self._draw_gumbel_sample(number_of_samples).to( + predicted_logits.device + ) one_hot_atom_types_i = class_index_to_onehot( atom_types_i, num_classes=self.num_classes ) @@ -241,61 +251,97 @@ def atom_types_update( probability_at_zeroth_timestep_are_logits=True, ) # p(a_{t-1} | a_t) as a [num_samples, num_atoms, num_classes] tensor - if self.atom_type_greedy_sampling: - # if we use greedy sampling, we will update the transition probabilities for the MASK token + if atom_type_greedy_sampling: + # if we use greedy sampling, we will update the transition probabilities for the MASK token. 
# For a_i = MASK, we define "greedy sampling" as first determining if a_{i-1} should also be MASK based on # p(a_{i-1} = MASK | a_i = MASK). If a_{i-1} should be unmasked, its atom type is selected as the one with # the highest probability (i.e., no stochastic sampling). Stochasticity is removed by setting the relevant - # row of u to zero. - one_step_transition_probs, u = ( - self.adjust_atom_types_probabilities_for_greedy_sampling( - one_step_transition_probs, atom_types_i, u + # row of gumbel_random_variable to zero. + one_step_transition_probs, gumbel_random_variable = ( + self._adjust_atom_types_probabilities_for_greedy_sampling( + one_step_transition_probs, atom_types_i, gumbel_random_variable ) ) - # find the updated atom types by sampling from the transition probabilities using the gumbel-softmax trick - # we also keep the associated scores in memory, so we can compare which transitions are the most likely - max_logits_per_atom, updated_atom_types = torch.max( - torch.log(one_step_transition_probs + self.small_epsilon) + u, dim=-1 + # Use the Gumbel-softmax trick to sample atomic types. + # We also keep the associated values in memory, so we can compare which transitions are the most likely. + # Dimensions: [num_samples, num_atoms]. + max_gumbel_values, sampled_atom_types = torch.max( + torch.log(one_step_transition_probs + self.small_epsilon) + + gumbel_random_variable, + dim=-1, ) - if not self.one_atom_type_transition_per_step: - a_im1 = updated_atom_types # we are done - - else: + if one_atom_type_transition_per_step: # force a single transition for each sample - atoms_have_changed_types = ( - updated_atom_types != atom_types_i - ) # num_samples, num_atoms bool tensor - max_transition_per_sample = torch.argmax( - torch.where(atoms_have_changed_types, max_logits_per_atom, -torch.inf), - dim=-1, + atom_types_im1 = self._get_updated_atom_types_for_one_transition_per_step( + atom_types_i, max_gumbel_values, sampled_atom_types ) - a_im1 = atom_types_i.clone() - a_im1[torch.arange(number_of_samples), max_transition_per_sample] = ( - updated_atom_types[ - torch.arange(number_of_samples), max_transition_per_sample - ] - ) - # TODO some sanity check at the last step because this approach does not guarantee a full transition... + else: + atom_types_im1 = sampled_atom_types if self.record_atom_type_update: # Keep the record on the CPU - entry = dict(predicted_logits=predicted_logits.detach().cpu(), - one_step_transition_probabilities=one_step_transition_probs.detach().cpu(), - gumbel_sample=u.cpu(), - a_i=atom_types_i.cpu(), - a_im1=a_im1.cpu()) + entry = dict( + predicted_logits=predicted_logits.detach().cpu(), + one_step_transition_probabilities=one_step_transition_probs.detach().cpu(), + gumbel_sample=gumbel_random_variable.cpu(), + a_i=atom_types_i.cpu(), + a_im1=atom_types_im1.cpu(), + ) + + self.sample_trajectory_recorder.record(key="atom_type_update", entry=entry) + + return atom_types_im1 + + def _get_updated_atom_types_for_one_transition_per_step( + self, + current_atom_types: torch.Tensor, + max_gumbel_values: torch.Tensor, + sampled_atom_types: torch.Tensor, + ): + """Get updated atom types for one transition per step. + + Assuming the Gumbel softmax trick was used to create a new sample of atom types, this method + restrict the transitions from the current atom types to only the most likely one per sample. + + Args: + current_atom_types: current indices of the atom types. Dimension: [number_of_samples, number_of_atoms] + max_gumbel_values: maximum Gumbel softmax values. 
Dimension: [number_of_samples, number_of_atoms] + sampled_atom_types: indices of the atom types resulting from the gumbel softmax sampling. + Dimension: [number_of_samples, number_of_atoms] + + Returns: + updated_atom_types: atom types resulting from only making one transition per sample on current_atom_types. + Dimension: [number_of_samples, number_of_atoms] + """ + number_of_samples = current_atom_types.shape[0] + sample_indices = torch.arange(number_of_samples) + + # Boolean mask of dimensions [number_of_samples, number_of_atoms] + atoms_have_changed_types = sampled_atom_types != current_atom_types - self.sample_trajectory_recorder.record(key='atom_type_update', entry=entry) + # Identify the most likely transition amongst the proposed changes. + max_gumbel_values_restricted_to_proposed_changes = torch.where( + atoms_have_changed_types, max_gumbel_values, -torch.inf + ) + most_likely_transition_atom_indices = torch.argmax( + max_gumbel_values_restricted_to_proposed_changes, dim=-1 + ) - return a_im1 + # Restrict transitions to only the most likely ones. + updated_atom_types = current_atom_types.clone() + updated_atom_types[sample_indices, most_likely_transition_atom_indices] = ( + sampled_atom_types[sample_indices, most_likely_transition_atom_indices] + ) - def adjust_atom_types_probabilities_for_greedy_sampling( + return updated_atom_types + + def _adjust_atom_types_probabilities_for_greedy_sampling( self, one_step_transition_probs: torch.Tensor, atom_types_i: torch.LongTensor, - u: torch.Tensor, + gumbel_random_variable: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Update the transition probabilities and the gumbel random variables to allow greedy sampling. @@ -307,7 +353,8 @@ def adjust_atom_types_probabilities_for_greedy_sampling( Args: one_step_transition_probs: class distributions at time t-1 given distribution at time t. p(a_{t-1} | a_t) atom_types_i: indices of atom types at time i. Dimension: [number_of_samples, number_of_atoms] - u: gumbel noise used for sampling. Dimension: [number_of_samples, number_of_atoms, num_classes] + gumbel_random_variable: gumbel noise used for sampling. + Dimension: [number_of_samples, number_of_atoms, num_classes] Returns: one_step_transition_probs: probabilities are updated so a MASK to non-MASK transition can happen @@ -344,8 +391,10 @@ def adjust_atom_types_probabilities_for_greedy_sampling( # In the current choice of \beta_t = 1 / (T-t+1), a greedy sampling will always select the MASK type if that # probability is not set to zero - except at the last generation step. This might not hold if the \beta schedule # is modified. - u = torch.where(all_masked.view(-1, 1, 1), u, 0.0) - return one_step_transition_probs, u + gumbel_random_variable = torch.where( + all_masked.view(-1, 1, 1), gumbel_random_variable, 0.0 + ) + return one_step_transition_probs, gumbel_random_variable def predictor_step( self, @@ -369,46 +418,88 @@ def predictor_step( 1 <= index_i <= self.number_of_discretization_steps ), "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps." 
+ number_of_samples = composition_i.X.shape[0] + number_of_atoms = composition_i.X.shape[1] + idx = index_i - 1 # python starts indices at zero t_i = self.noise.time[idx].to(composition_i.X) g_i = self.noise.g[idx].to(composition_i.X) g2_i = self.noise.g_squared[idx].to(composition_i.X) sigma_i = self.noise.sigma[idx].to(composition_i.X) - q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X) - q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) - q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) + + # Broadcast the q matrices to the expected dimensions. + q_matrices_i = einops.repeat( + self.noise.q_matrix[idx].to(composition_i.X), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + q_bar_matrices_i = einops.repeat( + self.noise.q_bar_matrix[idx].to(composition_i.X), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + q_bar_tm1_matrices_i = einops.repeat( + self.noise.q_bar_tm1_matrix[idx].to(composition_i.X), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) model_predictions_i = self._get_model_predictions( composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) - # atom types update - a_im1 = self.atom_types_update( + # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition + # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end + # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where + # the "first" time step is at index_i = T and the "last" time step is index_i = 1. + this_is_last_time_step = idx == 0 + one_atom_type_transition_per_step = ( + self.one_atom_type_transition_per_step and not this_is_last_time_step + ) + + a_im1 = self._atom_types_update( model_predictions_i.A, composition_i.A, q_matrices_i, q_bar_matrices_i, q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, ) - x_im1 = self.relative_coordinates_update( + x_im1 = self._relative_coordinates_update( composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i ) - composition_im1 = AXL(A=a_im1, X=x_im1, L=unit_cell) # TODO : Deal with L correctly + composition_im1 = AXL( + A=a_im1, X=x_im1, L=unit_cell + ) # TODO : Deal with L correctly if self.record: # TODO : Deal with L correctly - composition_i_for_recording = AXL(A=composition_i.A, - X=composition_i.X, - L=unit_cell) + composition_i_for_recording = AXL( + A=composition_i.A, X=composition_i.X, L=unit_cell + ) # Keep the record on the CPU entry = dict(time_step_index=index_i) - list_keys = ['composition_i', 'composition_im1', 'model_predictions_i'] - list_axl = [composition_i_for_recording, composition_im1, model_predictions_i] + list_keys = ["composition_i", "composition_im1", "model_predictions_i"] + list_axl = [ + composition_i_for_recording, + composition_im1, + model_predictions_i, + ] for key, axl in zip(list_keys, list_axl): - record_axl = AXL(A=axl.A.detach().cpu(), X=axl.X.detach().cpu(), L=axl.L.detach().cpu()) + record_axl = AXL( + A=axl.A.detach().cpu(), + X=axl.X.detach().cpu(), + L=axl.L.detach().cpu(), + ) entry[key] = record_axl self.sample_trajectory_recorder.record(key="predictor_step", entry=entry) @@ -458,7 +549,7 @@ def corrector_step( composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) - corrected_x_i = 
self.relative_coordinates_update( + corrected_x_i = self._relative_coordinates_update( composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i ) @@ -467,12 +558,14 @@ def corrector_step( q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) # atom types update - corrected_a_i = self.atom_types_update( + corrected_a_i = self._atom_types_update( model_predictions_i.A, composition_i.A, q_matrices_i, q_bar_matrices_i, q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=self.one_atom_type_transition_per_step, ) else: corrected_a_i = composition_i.A @@ -485,16 +578,28 @@ def corrector_step( if self.record_corrector: # TODO : Deal with L correctly - composition_i_for_recording = AXL(A=composition_i.A, - X=composition_i.X, - L=unit_cell) + composition_i_for_recording = AXL( + A=composition_i.A, X=composition_i.X, L=unit_cell + ) # Keep the record on the CPU entry = dict(time_step_index=index_i) - list_keys = ['composition_i', 'corrected_composition_i', 'model_predictions_i'] - list_axl = [composition_i_for_recording, corrected_composition_i, model_predictions_i] + list_keys = [ + "composition_i", + "corrected_composition_i", + "model_predictions_i", + ] + list_axl = [ + composition_i_for_recording, + corrected_composition_i, + model_predictions_i, + ] for key, axl in zip(list_keys, list_axl): - record_axl = AXL(A=axl.A.detach().cpu(), X=axl.X.detach().cpu(), L=axl.L.detach().cpu()) + record_axl = AXL( + A=axl.A.detach().cpu(), + X=axl.X.detach().cpu(), + L=axl.L.detach().cpu(), + ) entry[key] = record_axl self.sample_trajectory_recorder.record(key="corrector_step", entry=entry) diff --git a/tests/generators/test_langevin_generator.py b/tests/generators/test_langevin_generator.py index f8ad144e..0f66cb3c 100644 --- a/tests/generators/test_langevin_generator.py +++ b/tests/generators/test_langevin_generator.py @@ -1,3 +1,4 @@ +import einops import pytest import torch @@ -10,8 +11,6 @@ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell -from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import ( - class_index_to_onehot, get_probability_at_previous_time_step) from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ NoiseScheduler from tests.generators.conftest import BaseTestGenerator @@ -19,28 +18,32 @@ class TestLangevinGenerator(BaseTestGenerator): - @pytest.fixture(params=[1, 5, 10]) - def num_atom_types(self, request): - return request.param + @pytest.fixture() + def num_atom_types(self): + return 4 @pytest.fixture() def num_atomic_classes(self, num_atom_types): return num_atom_types + 1 - @pytest.fixture(params=[0, 1, 2]) + @pytest.fixture(params=[0, 2]) def number_of_corrector_steps(self, request): return request.param - @pytest.fixture(params=[1, 5, 10]) + @pytest.fixture(params=[2, 5, 10]) def total_time_steps(self, request): return request.param @pytest.fixture() - def noise_parameters(self, total_time_steps): + def sigma_min(self): + return 0.15 + + @pytest.fixture() + def noise_parameters(self, total_time_steps, sigma_min): noise_parameters = NoiseParameters( total_time_steps=total_time_steps, time_delta=0.1, - sigma_min=0.15, + sigma_min=sigma_min, corrector_step_epsilon=0.25, ) return noise_parameters @@ -49,6 +52,18 @@ def noise_parameters(self, total_time_steps): def 
small_epsilon(self): return 1e-6 + @pytest.fixture(params=[True, False]) + def one_atom_type_transition_per_step(self, request): + return request.param + + @pytest.fixture(params=[True, False]) + def atom_type_greedy_sampling(self, request): + return request.param + + @pytest.fixture(params=[True, False]) + def atom_type_transition_in_corrector(self, request): + return request.param + @pytest.fixture() def sampling_parameters( self, @@ -59,6 +74,9 @@ def sampling_parameters( number_of_corrector_steps, unit_cell_size, num_atom_types, + one_atom_type_transition_per_step, + atom_type_greedy_sampling, + atom_type_transition_in_corrector, small_epsilon, ): sampling_parameters = PredictorCorrectorSamplingParameters( @@ -68,11 +86,22 @@ def sampling_parameters( cell_dimensions=cell_dimensions, spatial_dimension=spatial_dimension, num_atom_types=num_atom_types, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, + atom_type_greedy_sampling=atom_type_greedy_sampling, + atom_type_transition_in_corrector=atom_type_transition_in_corrector, small_epsilon=small_epsilon, ) return sampling_parameters + @pytest.fixture() + def noise(self, noise_parameters, num_atomic_classes, device): + sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes).to( + device + ) + noise, _ = sampler.get_all_sampling_parameters() + return noise + @pytest.fixture() def pc_generator(self, noise_parameters, sampling_parameters, axl_network): generator = LangevinGenerator( @@ -116,20 +145,13 @@ def test_predictor_step_relative_coordinates( self, mocker, pc_generator, - noise_parameters, + noise, + sigma_min, axl_i, total_time_steps, number_of_samples, unit_cell_sample, - num_atomic_classes, - device, ): - - sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes).to( - device - ) - noise, _ = sampler.get_all_sampling_parameters() - sigma_min = noise_parameters.sigma_min list_sigma = noise.sigma list_time = noise.time forces = torch.zeros_like(axl_i.X) @@ -165,133 +187,298 @@ def test_predictor_step_relative_coordinates( torch.testing.assert_close(computed_sample.X, expected_coordinates) - @pytest.mark.parametrize("one_atom_type_transition_per_step", [True, False]) - @pytest.mark.parametrize("atom_type_greedy_sampling", [True, False]) - def test_predictor_step_atom_types( + def test_adjust_atom_types_probabilities_for_greedy_sampling( + self, pc_generator, number_of_atoms, num_atomic_classes + ): + # Test that all_masked atom types are unaffected. + fully_masked_row = pc_generator.masked_atom_type_index * torch.ones( + number_of_atoms, dtype=torch.int64 + ) + + partially_unmasked_row = fully_masked_row.clone() + partially_unmasked_row[0] = 0 + + atom_types_i = torch.stack([fully_masked_row, partially_unmasked_row]) + + number_of_samples = atom_types_i.shape[0] + u = pc_generator._draw_gumbel_sample(number_of_samples) + + one_step_transition_probs = torch.rand( + number_of_samples, number_of_atoms, num_atomic_classes + ).softmax(dim=-1) + # Use cloned values because the method overrides the inputs. 
+ updated_one_step_transition_probs, updated_u = ( + pc_generator._adjust_atom_types_probabilities_for_greedy_sampling( + one_step_transition_probs.clone(), atom_types_i, u.clone() + ) + ) + + # Test that the fully masked row is unaffected + torch.testing.assert_close( + updated_one_step_transition_probs[0], one_step_transition_probs[0] + ) + torch.testing.assert_close(u[0], updated_u[0]) + + # Test that when an atom is unmasked, the probabilities are set up for greedy sampling: + # - the probabilities for the real atomic classes are unchanged. + # - the probability for the MASK class (last index) is either unchanged or set to zero. + # - the Gumbel sample is set to zero so that the unmasking is greedy. + + torch.testing.assert_close( + updated_one_step_transition_probs[1, :, :-1], + one_step_transition_probs[1, :, :-1], + ) + + m1 = ( + updated_one_step_transition_probs[1, :, -1] + == one_step_transition_probs[1, :, -1] + ) + m2 = updated_one_step_transition_probs[1, :, -1] == 0.0 + assert torch.logical_or(m1, m2).all() + torch.testing.assert_close(updated_u[1], torch.zeros_like(updated_u[1])) + + def test_get_updated_atom_types_for_one_transition_per_step_is_idempotent( self, - mocker, - one_atom_type_transition_per_step, - atom_type_greedy_sampling, - noise_parameters, - sampling_parameters, - axl_network, - axl_i, - total_time_steps, + pc_generator, number_of_samples, - unit_cell_sample, - num_atomic_classes, - small_epsilon, number_of_atoms, + num_atomic_classes, device, ): - sampling_parameters.one_atom_type_transition_per_step = ( - one_atom_type_transition_per_step + # Test that the method returns the current atom types if there is no proposed changes. + current_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + sampled_atom_types = current_atom_types.clone() + max_gumbel_values = torch.rand(number_of_samples, number_of_atoms).to(device) + + updated_atom_types = ( + pc_generator._get_updated_atom_types_for_one_transition_per_step( + current_atom_types, max_gumbel_values, sampled_atom_types + ) ) - sampling_parameters.atom_type_greedy_sampling = atom_type_greedy_sampling - pc_generator = LangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - axl_network=axl_network, + torch.testing.assert_close(updated_atom_types, current_atom_types) + + def test_get_updated_atom_types_for_one_transition_per_step( + self, + pc_generator, + number_of_samples, + number_of_atoms, + num_atomic_classes, + device, + ): + assert ( + num_atomic_classes > 0 + ), "Cannot run this test with a single atomic class." + current_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + sampled_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + # Make sure at least one atom is different in every sample. 
+ while not (current_atom_types != sampled_atom_types).any(dim=-1).all(): + sampled_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + + proposed_difference_mask = current_atom_types != sampled_atom_types + + max_gumbel_values = torch.rand(number_of_samples, number_of_atoms).to(device) + + updated_atom_types = ( + pc_generator._get_updated_atom_types_for_one_transition_per_step( + current_atom_types, max_gumbel_values, sampled_atom_types + ) ) - sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes).to( - device + difference_mask = updated_atom_types != current_atom_types + + # Check that there is a single difference per sample + number_of_changes = difference_mask.sum(dim=-1) + torch.testing.assert_close( + number_of_changes, torch.ones(number_of_samples).to(number_of_changes) ) - noise, _ = sampler.get_all_sampling_parameters() - list_sigma = noise.sigma - list_time = noise.time - list_q_matrices = noise.q_matrix - list_q_bar_matrices = noise.q_bar_matrix - list_q_bar_tm1_matrices = noise.q_bar_tm1_matrix - forces = torch.zeros_like(axl_i.X) - u = pc_generator._draw_gumbel_sample(number_of_samples).to( - device=axl_i.A.device + # Check that the difference is at the location of the maximum value of the Gumbel random variable over the + # possible changes. + computed_changed_atom_indices = torch.where(difference_mask)[1] + + expected_changed_atom_indices = [] + for sample_idx in range(number_of_samples): + sample_gumbel_values = max_gumbel_values[sample_idx].clone() + sample_proposed_difference_mask = proposed_difference_mask[sample_idx] + sample_gumbel_values[~sample_proposed_difference_mask] = -torch.inf + max_index = torch.argmax(sample_gumbel_values) + expected_changed_atom_indices.append(max_index) + expected_changed_atom_indices = torch.tensor(expected_changed_atom_indices).to( + computed_changed_atom_indices ) - mocker.patch.object(pc_generator, "_draw_gumbel_sample", return_value=u) - binary_sample = pc_generator._draw_binary_sample(number_of_samples).to( - device=axl_i.A.device + torch.testing.assert_close( + computed_changed_atom_indices, expected_changed_atom_indices ) - mocker.patch.object( - pc_generator, "_draw_binary_sample", return_value=binary_sample + + def test_atom_types_update( + self, + pc_generator, + noise, + total_time_steps, + num_atomic_classes, + number_of_samples, + number_of_atoms, + device, + ): + + # Initialize to fully masked + a_i = pc_generator.masked_atom_type_index * torch.ones( + number_of_samples, number_of_atoms, dtype=torch.int64 + ).to(device) + + for time_index_i in range(total_time_steps, 0, -1): + this_is_last_time_step = time_index_i == 1 + idx = time_index_i - 1 + q_matrices_i = einops.repeat( + noise.q_matrix[idx], + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + q_bar_matrices_i = einops.repeat( + noise.q_bar_matrix[idx], + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + q_bar_tm1_matrices_i = einops.repeat( + noise.q_bar_tm1_matrix[idx], + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + random_logits = torch.rand( + number_of_samples, number_of_atoms, num_atomic_classes + ).to(device) + random_logits[:, :, -1] = -torch.inf + + one_atom_type_transition_per_step = ( + pc_generator.one_atom_type_transition_per_step + and not this_is_last_time_step + ) + + a_im1 = pc_generator._atom_types_update( + random_logits, + a_i, + 
q_matrices_i, + q_bar_matrices_i, + q_bar_tm1_matrices_i, + atom_type_greedy_sampling=pc_generator.atom_type_greedy_sampling, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, + ) + + difference_mask = a_im1 != a_i + + # Test that the changes are from MASK to not-MASK + assert (a_i[difference_mask] == pc_generator.masked_atom_type_index).all() + assert (a_im1[difference_mask] != pc_generator.masked_atom_type_index).all() + + if one_atom_type_transition_per_step: + # Test that there is at most one change + assert torch.all(difference_mask.sum(dim=-1) <= 1.0) + + if pc_generator.atom_type_greedy_sampling: + # Test that the changes are the most probable (greedy) + sample_indices, atom_indices = torch.where(difference_mask) + for sample_idx, atom_idx in zip(sample_indices, atom_indices): + # Greedy sampling only applies if at least one atom was already unmasked. + if (a_i[sample_idx] == pc_generator.masked_atom_type_index).all(): + continue + computed_atom_type = a_im1[sample_idx, atom_idx] + expected_atom_type = random_logits[sample_idx, atom_idx].argmax() + assert computed_atom_type == expected_atom_type + + a_i = a_im1 + + # Test that no MASKED states remain + assert not (a_i == pc_generator.masked_atom_type_index).any() + + def test_predictor_step_atom_types( + self, + mocker, + pc_generator, + total_time_steps, + number_of_samples, + number_of_atoms, + num_atomic_classes, + spatial_dimension, + unit_cell_sample, + device, + ): + zeros = torch.zeros(number_of_samples, number_of_atoms, spatial_dimension).to( + device ) + forces = zeros + + random_x = map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device) + + random_l = torch.zeros( + number_of_samples, spatial_dimension, spatial_dimension + ).to(device) + + # Initialize to fully masked + a_ip1 = pc_generator.masked_atom_type_index * torch.ones( + number_of_samples, number_of_atoms, dtype=torch.int64 + ).to(device) + axl_ip1 = AXL(A=a_ip1, X=random_x, L=random_l) + + for idx in range(total_time_steps - 1, -1, -1): + + # Inject reasonable logits + logits = torch.rand( + number_of_samples, number_of_atoms, num_atomic_classes + ).to(device) + logits[:, :, -1] = -torch.inf + fake_model_predictions = AXL(A=logits, X=zeros, L=zeros) + mocker.patch.object( + pc_generator, + "_get_model_predictions", + return_value=fake_model_predictions, + ) - for index_i in range(total_time_steps - 1, -1, -1): - computed_sample = pc_generator.predictor_step( - axl_i, index_i + 1, unit_cell_sample, forces + axl_i = pc_generator.predictor_step( + axl_ip1, idx + 1, unit_cell_sample, forces ) - sigma_i = list_sigma[index_i] - t_i = list_time[index_i] - - p_a0_given_at_i = pc_generator._get_model_predictions( - axl_i, t_i, sigma_i, unit_cell_sample, forces - ).A - - onehot_at = class_index_to_onehot(axl_i.A, num_classes=num_atomic_classes) - q_matrices = list_q_matrices[index_i] - q_bar_matrices = list_q_bar_matrices[index_i] - q_bar_tm1_matrices = list_q_bar_tm1_matrices[index_i] - - p_atm1_given_at = get_probability_at_previous_time_step( - probability_at_zeroth_timestep=p_a0_given_at_i, - one_hot_probability_at_current_timestep=onehot_at, - q_matrices=q_matrices, - q_bar_matrices=q_bar_matrices, - q_bar_tm1_matrices=q_bar_tm1_matrices, - small_epsilon=small_epsilon, - probability_at_zeroth_timestep_are_logits=True, + + this_is_last_time_step = idx == 0 + a_i = axl_i.A + a_ip1 = axl_ip1.A + + difference_mask = a_ip1 != a_i + + # Test that the changes are from MASK to not-MASK + 
assert (a_ip1[difference_mask] == pc_generator.masked_atom_type_index).all() + assert (a_i[difference_mask] != pc_generator.masked_atom_type_index).all() + + one_atom_type_transition_per_step = ( + pc_generator.one_atom_type_transition_per_step + and not this_is_last_time_step ) - updated_atm1_given_at = p_atm1_given_at.clone() - if atom_type_greedy_sampling: - # remove the noise component, so we are sampling the max value from the prob distribution - # also, set the probability of getting a mask to zero based on the binary_sample drawn earlier - samples_with_only_masks = torch.all( - axl_i.A == num_atomic_classes - 1, dim=-1 - ) - for sample_idx, sample_is_just_mask in enumerate( - samples_with_only_masks - ): - if not sample_is_just_mask: - u[sample_idx, :, :] = 0.0 - # replace mask probability if random number is larger than prob. of staying mask - for atom_idx in range(number_of_atoms): - if axl_i.A[sample_idx, atom_idx] == num_atomic_classes - 1: - updated_atm1_given_at[sample_idx, atom_idx, -1] *= ( - binary_sample[sample_idx, atom_idx] - < p_atm1_given_at[sample_idx, atom_idx, -1] - ) # multiply by 1 if random number is low (do nothing), or replace with 0 otherwise - - gumbel_distribution = ( - torch.log(updated_atm1_given_at + 1e-8) + u - ) # avoid log(zero) - - expected_atom_types = torch.argmax(gumbel_distribution, dim=-1) if one_atom_type_transition_per_step: - new_atom_types = axl_i.A.clone() - for sample_idx in range(number_of_samples): - # find the prob scores for each transition in this sample - sample_probs = [] - for atom_idx in range(number_of_atoms): - old_atom_id = axl_i.A[sample_idx, atom_idx] - new_atom_id = expected_atom_types[sample_idx, atom_idx] - # compare old id to new id - if same, no transition - if old_atom_id != new_atom_id: - # different, record the gumbel score - sample_probs.append( - gumbel_distribution[sample_idx, atom_idx, :].max() - ) - else: - sample_probs.append(-torch.inf) - highest_score_transition = torch.argmax(torch.tensor(sample_probs)) - new_atom_types[sample_idx, highest_score_transition] = ( - expected_atom_types[sample_idx, highest_score_transition] - ) - expected_atom_types = new_atom_types - - assert torch.all(computed_sample.A == expected_atom_types) + # Test that there is at most one change + assert torch.all(difference_mask.sum(dim=-1) <= 1.0) + + axl_ip1 = AXL(A=a_i, X=random_x, L=random_l) + + # Test that no MASKED states remain + a_i = axl_i.A + assert not (a_i == pc_generator.masked_atom_type_index).any() def test_corrector_step( self, @@ -344,4 +531,25 @@ def test_corrector_step( ) torch.testing.assert_close(computed_sample.X, expected_coordinates) - assert torch.all(computed_sample.A == axl_i.A) + + if pc_generator.atom_type_transition_in_corrector: + a_i = axl_i.A + corrected_a_i = computed_sample.A + + difference_mask = corrected_a_i != a_i + + # Test that the changes are from MASK to not-MASK + assert ( + a_i[difference_mask] == pc_generator.masked_atom_type_index + ).all() + assert ( + corrected_a_i[difference_mask] + != pc_generator.masked_atom_type_index + ).all() + + if pc_generator.one_atom_type_transition_per_step: + # Test that there is at most one change + assert torch.all(difference_mask.sum(dim=-1) <= 1.0) + + else: + assert torch.all(computed_sample.A == axl_i.A) From d4a8d240099a1d6cd5a24f1ec07f3a52169ebc36 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 29 Nov 2024 15:01:10 -0500 Subject: [PATCH 07/10] name fix --- .../patches/identity_relative_coordinates_langevin_generator.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-)

diff --git a/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py b/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py
index 535a2c11..cdd5e063 100644
--- a/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py
+++ b/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py
@@ -52,7 +52,7 @@ def initialize(

         return fixed_init_composition

-    def relative_coordinates_update(
+    def _relative_coordinates_update(
         self,
         relative_coordinates: torch.Tensor,
         sigma_normalized_scores: torch.Tensor,

From d792c3eddef11ecf2b8ed8cea09083fc0f5e36e6 Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Fri, 29 Nov 2024 15:06:20 -0500
Subject: [PATCH 08/10] Add explicit check.

---
 .../generators/langevin_generator.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py
index e0b04d32..2d68dc77 100644
--- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py
+++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py
@@ -472,6 +472,10 @@ def predictor_step(
             one_atom_type_transition_per_step=one_atom_type_transition_per_step,
         )

+        if this_is_last_time_step:
+            assert (a_im1 != self.masked_atom_type_index).all(), \
+                "MASKED atoms remain at the last time step: review the code, there must be a bug or invalid input."
+
         x_im1 = self._relative_coordinates_update(
             composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i
         )

From 0a5c6d29000692164f04053970c3a2f3ea40dede Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Fri, 29 Nov 2024 15:15:43 -0500
Subject: [PATCH 09/10] Remove meaningless, useless broken test.

---
 tests/loss/test_atom_type_loss_calculator.py | 33 --------------------
 1 file changed, 33 deletions(-)

diff --git a/tests/loss/test_atom_type_loss_calculator.py b/tests/loss/test_atom_type_loss_calculator.py
index d94f261d..33d70702 100644
--- a/tests/loss/test_atom_type_loss_calculator.py
+++ b/tests/loss/test_atom_type_loss_calculator.py
@@ -332,39 +332,6 @@ def test_vb_loss_predicting_a0(

         torch.testing.assert_close(computed_kl_loss, torch.zeros_like(computed_kl_loss))

-    def test_kl_loss_diagonal_q_matrices(
-        self,
-        num_classes,
-        d3pm_calculator,
-    ):
-        # with diagonal Q matrices, the KL is ALWAYS ZERO. This is because either:
-        # 1) the posterior is all zero
-        # or
-        # 2) the prediction is equal to the posterior; this follows because the prediction is normalized.
-        predicted_logits = torch.rand(1, 1, num_classes)
-        time_indices = torch.tensor([1])  # non-zero to compute the KL
-
-        q_matrices = torch.eye(num_classes).view(1, 1, num_classes, num_classes)
-        q_bar_matrices = torch.eye(num_classes).view(1, 1, num_classes, num_classes)
-        q_bar_tm1_matrices = torch.eye(num_classes).view(1, 1, num_classes, num_classes)
-        for i in range(num_classes):
-            for j in range(num_classes):
-                one_hot_a0 = torch.zeros(1, 1, num_classes)
-                one_hot_at = torch.zeros(1, 1, num_classes)
-                one_hot_a0[0, 0, i] = 1.0
-                one_hot_at[0, 0, j] = 1.0
-
-                computed_kl = d3pm_calculator.variational_bound_loss_term(
-                    predicted_logits,
-                    one_hot_a0,
-                    one_hot_at,
-                    q_matrices,
-                    q_bar_matrices,
-                    q_bar_tm1_matrices,
-                    time_indices,
-                )
-                torch.testing.assert_close(computed_kl, torch.zeros_like(computed_kl))
-
     def test_cross_entropy_loss_term(self, predicted_logits, one_hot_a0, d3pm_calculator):
         computed_ce_loss = d3pm_calculator.cross_entropy_loss_term(predicted_logits, one_hot_a0)

From eca50154dbd74c8a93f71b3bf7f97543f14bf808 Mon Sep 17 00:00:00 2001
From: Bruno Rousseau
Date: Fri, 29 Nov 2024 15:17:54 -0500
Subject: [PATCH 10/10] Fix test.

---
 tests/generators/test_predictor_corrector_position_generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/generators/test_predictor_corrector_position_generator.py b/tests/generators/test_predictor_corrector_position_generator.py
index 92319548..8b9a0a34 100644
--- a/tests/generators/test_predictor_corrector_position_generator.py
+++ b/tests/generators/test_predictor_corrector_position_generator.py
@@ -60,7 +60,7 @@ def corrector_step(
         return updated_axl


-@pytest.mark.parametrize("number_of_discretization_steps", [1, 5, 10])
+@pytest.mark.parametrize("number_of_discretization_steps", [2, 5, 10])
 @pytest.mark.parametrize("number_of_corrector_steps", [0, 1, 2])
 class TestPredictorCorrectorPositionGenerator(BaseTestGenerator):
     @pytest.fixture(scope="class", autouse=True)