From dc5865a4c1de5749a143a3819d8490ddaf2b9a2e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 2 Oct 2024 13:29:19 -0400 Subject: [PATCH 01/18] Noodling analysis scripts. --- .../analyse_stability.py | 78 +++++++++++++++ .../score_stability_analysis/util.py | 99 +++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 experiment_analysis/score_stability_analysis/analyse_stability.py create mode 100644 experiment_analysis/score_stability_analysis/util.py diff --git a/experiment_analysis/score_stability_analysis/analyse_stability.py b/experiment_analysis/score_stability_analysis/analyse_stability.py new file mode 100644 index 00000000..e968d8b5 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/analyse_stability.py @@ -0,0 +1,78 @@ +import logging + +import numpy as np +import scipy.optimize as so +import torch + +from crystal_diffusion.models.score_networks.mlp_score_network import \ + MLPScoreNetworkParameters +from crystal_diffusion.models.score_networks.score_network_factory import \ + create_score_network +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import ( + create_fixed_time_vector_field_function, get_flat_vector_field_function, + get_hessian_function) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + +checkpoint_path = ("/network/scratch/r/rousseab/experiments/sept21_egnn_2x2x2/run4/" + "output/best_model/best_model-epoch=024-step=019550.ckpt") + + +spatial_dimension = 3 +number_of_atoms = 64 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 10.86 +basis_vectors = torch.diag(torch.tensor([acell, acell, acell])) + +total_time_steps = 1000 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=0.0001, + sigma_max=0.2, +) + + +if __name__ == "__main__": + """ + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + """ + + score_network_parameters = MLPScoreNetworkParameters( + number_of_atoms=number_of_atoms, + n_hidden_dimensions=3, + embedding_dimensions_size=8, + hidden_dimensions_size=8, + spatial_dimension=spatial_dimension, + ) + sigma_normalized_score_network = create_score_network(score_network_parameters) + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + vector_field_fn = create_fixed_time_vector_field_function(sigma_normalized_score_network, + noise_parameters, + time=0.1, + basis_vectors=basis_vectors) + + hessian_fn = get_hessian_function(vector_field_fn) + + batch_size = 12 + + relative_coordinates = torch.rand(batch_size, number_of_atoms, spatial_dimension) + + vector_field = vector_field_fn(relative_coordinates) + + hessian = hessian_fn(relative_coordinates) + + eigenvalues, _ = torch.linalg.eig(hessian) + + func = get_flat_vector_field_function(vector_field_fn, number_of_atoms, spatial_dimension) + x0 = torch.rand(192).numpy() + out = so.root(func, x0) diff --git a/experiment_analysis/score_stability_analysis/util.py b/experiment_analysis/score_stability_analysis/util.py new file mode 100644 index 00000000..abaa45c6 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/util.py @@ -0,0 +1,99 @@ +from typing import Callable + +import einops +import numpy as np +import torch +from torch.func import jacrev + +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, + NOISY_RELATIVE_COORDINATES, TIME, + UNIT_CELL) +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell + + +def create_fixed_time_vector_field_function( + sigma_normalized_score_network: ScoreNetwork, + noise_parameters: NoiseParameters, + time: float, + basis_vectors: torch.Tensor, +): + """Create the vector field function.""" + variance_calculator = ExplodingVariance(noise_parameters) + + def vector_field_function(relative_coordinates: torch.Tensor) -> torch.Tensor: + batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape + + times = einops.repeat( + torch.tensor([time]).to(relative_coordinates), "1 -> b 1", b=batch_size + ) + unit_cells = einops.repeat( + basis_vectors.to(relative_coordinates), "s1 s2 -> b s1 s2", b=batch_size + ) + + forces = torch.zeros_like(relative_coordinates) + sigmas = variance_calculator.get_sigma(times) + g2 = variance_calculator.get_g_squared(times) + + vector_field_prefactor = -g2 / sigmas + + augmented_batch = { + NOISY_RELATIVE_COORDINATES: relative_coordinates, + TIME: times, + NOISE: sigmas, + UNIT_CELL: unit_cells, + CARTESIAN_FORCES: forces, + } + + sigma_normalized_scores = sigma_normalized_score_network( + augmented_batch, conditional=False + ) + + prefactor = einops.repeat( + vector_field_prefactor, + "b 1 -> b n s", + n=number_of_atoms, + s=spatial_dimension, + ) + + return prefactor * sigma_normalized_scores + + return vector_field_function + + +def get_hessian_function(vector_field_function: Callable) -> Callable: + """Get hessian function.""" + batch_hessian_function = jacrev(vector_field_function, argnums=0) + + def hessian_function(relative_coordinates: torch.Tensor) -> torch.Tensor: + # The batch hessian has dimension [batch_size, natoms, space, batch_size, natoms, space] + batch_hessian = batch_hessian_function(relative_coordinates) + + # Diagonal dumps the batch dimension at the end. + hessian = torch.diagonal(batch_hessian, dim1=0, dim2=3) + + flat_hessian = einops.rearrange(hessian, "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") + return flat_hessian + + return hessian_function + + +def get_flat_vector_field_function( + vector_field_function: Callable, number_of_atoms: int, spatial_dimension: int +) -> Callable: + """Get a flat vector field function.""" + def flat_vector_field_function(x: np.ndarray) -> np.ndarray: + cast_x = torch.from_numpy(x).to(torch.float32) + relative_coordinates = einops.rearrange( + cast_x, "(n s) -> 1 n s", n=number_of_atoms, s=spatial_dimension + ) + relative_coordinates = map_relative_coordinates_to_unit_cell( + relative_coordinates + ) + vector_field = vector_field_function(relative_coordinates) + return einops.rearrange(vector_field, "1 n s -> (n s)").numpy() + + return flat_vector_field_function From a8728b76eb15f832c4c27d4d01797ce820744c3e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 3 Oct 2024 09:20:37 -0400 Subject: [PATCH 02/18] WIP scripts --- .../analyse_stability.py | 100 +++++++++++------- .../score_stability_analysis/util.py | 50 +++++---- 2 files changed, 94 insertions(+), 56 deletions(-) diff --git a/experiment_analysis/score_stability_analysis/analyse_stability.py b/experiment_analysis/score_stability_analysis/analyse_stability.py index e968d8b5..0cc55ae6 100644 --- a/experiment_analysis/score_stability_analysis/analyse_stability.py +++ b/experiment_analysis/score_stability_analysis/analyse_stability.py @@ -1,31 +1,32 @@ import logging +import matplotlib.pyplot as plt import numpy as np -import scipy.optimize as so import torch -from crystal_diffusion.models.score_networks.mlp_score_network import \ - MLPScoreNetworkParameters -from crystal_diffusion.models.score_networks.score_network_factory import \ - create_score_network +from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE +from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger from experiment_analysis.score_stability_analysis.util import ( - create_fixed_time_vector_field_function, get_flat_vector_field_function, - get_hessian_function) + create_fixed_time_normalized_score_function, get_hessian_function, get_square_norm_and_grad_functions) + + +plt.style.use(PLOT_STYLE_PATH) logger = logging.getLogger(__name__) setup_analysis_logger() -checkpoint_path = ("/network/scratch/r/rousseab/experiments/sept21_egnn_2x2x2/run4/" - "output/best_model/best_model-epoch=024-step=019550.ckpt") +checkpoint_path = "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt" spatial_dimension = 3 -number_of_atoms = 64 +number_of_atoms = 8 atom_types = np.ones(number_of_atoms, dtype=int) -acell = 10.86 +acell = 5.43 basis_vectors = torch.diag(torch.tensor([acell, acell, acell])) total_time_steps = 1000 @@ -35,44 +36,71 @@ sigma_max=0.2, ) - +device = torch.device('cuda') if __name__ == "__main__": - """ + variance_calculator = ExplodingVariance(noise_parameters) + + logger.info("Loading checkpoint...") pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) pl_model.eval() sigma_normalized_score_network = pl_model.sigma_normalized_score_network - """ - - score_network_parameters = MLPScoreNetworkParameters( - number_of_atoms=number_of_atoms, - n_hidden_dimensions=3, - embedding_dimensions_size=8, - hidden_dimensions_size=8, - spatial_dimension=spatial_dimension, - ) - sigma_normalized_score_network = create_score_network(score_network_parameters) + for parameter in sigma_normalized_score_network.parameters(): parameter.requires_grad_(False) - vector_field_fn = create_fixed_time_vector_field_function(sigma_normalized_score_network, - noise_parameters, - time=0.1, - basis_vectors=basis_vectors) + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Probing model along specific dimension") + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + + ax1.set_title("Score Norm") + ax2.set_title("Gradient of Score Norm") + + for ax in [ax1, ax2]: + ax.set_xlabel('x[2][0]') + + eq_rel_coords = get_silicon_supercell(supercell_factor=1) + for t in [1.0, 0.8, 0.5, 0.1, 0.01]: + print(f"Doing t = {t}") + + sigma = float(variance_calculator.get_sigma(t)) + + vector_field_fn = create_fixed_time_normalized_score_function(sigma_normalized_score_network, + noise_parameters, + time=t, + basis_vectors=basis_vectors) + + func, grad_func = get_square_norm_and_grad_functions(vector_field_fn, + number_of_atoms, + spatial_dimension, + device) + + + x0 = eq_rel_coords.flatten() + + list_dx = np.linspace(-0.5, 0.5, 101) - hessian_fn = get_hessian_function(vector_field_fn) + list_f = [] + list_g = [] + for dx in list_dx: + r = eq_rel_coords.copy() + r[2][0] += dx + x = r.flatten() + list_f.append(func(x)) + list_g.append(grad_func(x)[6]) - batch_size = 12 + list_f = np.array(list_f) + list_g = np.array(list_g) - relative_coordinates = torch.rand(batch_size, number_of_atoms, spatial_dimension) + ax1.plot(list_dx, list_f/np.sqrt(sigma), '-', label=f't = {t}, $\sigma$ = {sigma: 5.2e}') + ax2.plot(list_dx, list_g, '-') - vector_field = vector_field_fn(relative_coordinates) + ax1.legend(loc=0) - hessian = hessian_fn(relative_coordinates) + fig.tight_layout() - eigenvalues, _ = torch.linalg.eig(hessian) + plt.show() - func = get_flat_vector_field_function(vector_field_fn, number_of_atoms, spatial_dimension) - x0 = torch.rand(192).numpy() - out = so.root(func, x0) + #out = so.minimize(func, x0, jac=grad_func) \ No newline at end of file diff --git a/experiment_analysis/score_stability_analysis/util.py b/experiment_analysis/score_stability_analysis/util.py index abaa45c6..1d4d3692 100644 --- a/experiment_analysis/score_stability_analysis/util.py +++ b/experiment_analysis/score_stability_analysis/util.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Tuple import einops import numpy as np @@ -15,7 +15,7 @@ map_relative_coordinates_to_unit_cell -def create_fixed_time_vector_field_function( +def create_fixed_time_normalized_score_function( sigma_normalized_score_network: ScoreNetwork, noise_parameters: NoiseParameters, time: float, @@ -36,9 +36,6 @@ def vector_field_function(relative_coordinates: torch.Tensor) -> torch.Tensor: forces = torch.zeros_like(relative_coordinates) sigmas = variance_calculator.get_sigma(times) - g2 = variance_calculator.get_g_squared(times) - - vector_field_prefactor = -g2 / sigmas augmented_batch = { NOISY_RELATIVE_COORDINATES: relative_coordinates, @@ -52,14 +49,7 @@ def vector_field_function(relative_coordinates: torch.Tensor) -> torch.Tensor: augmented_batch, conditional=False ) - prefactor = einops.repeat( - vector_field_prefactor, - "b 1 -> b n s", - n=number_of_atoms, - s=spatial_dimension, - ) - - return prefactor * sigma_normalized_scores + return sigma_normalized_scores return vector_field_function @@ -81,19 +71,39 @@ def hessian_function(relative_coordinates: torch.Tensor) -> torch.Tensor: return hessian_function -def get_flat_vector_field_function( - vector_field_function: Callable, number_of_atoms: int, spatial_dimension: int -) -> Callable: +def get_square_norm_and_grad_functions( + vector_field_function: Callable, number_of_atoms: int, spatial_dimension: int, device: torch.device +) -> Tuple[Callable, Callable]: """Get a flat vector field function.""" - def flat_vector_field_function(x: np.ndarray) -> np.ndarray: - cast_x = torch.from_numpy(x).to(torch.float32) + + hessian_function = get_hessian_function(vector_field_function) + + def _get_relative_coordinates(x: np.ndarray) -> torch.Tensor: + cast_x = torch.from_numpy(x).to(torch.float32).to(device) relative_coordinates = einops.rearrange( cast_x, "(n s) -> 1 n s", n=number_of_atoms, s=spatial_dimension ) relative_coordinates = map_relative_coordinates_to_unit_cell( relative_coordinates ) + return relative_coordinates + + def square_norm_function(x: np.ndarray) -> np.ndarray: + relative_coordinates = _get_relative_coordinates(x) vector_field = vector_field_function(relative_coordinates) - return einops.rearrange(vector_field, "1 n s -> (n s)").numpy() + square_norm = 0.5 * (vector_field**2).flatten().sum() + return square_norm.cpu().numpy() + + def gradient_function(x: np.ndarray) -> np.ndarray: + relative_coordinates = _get_relative_coordinates(x) + + vector_field = vector_field_function(relative_coordinates) + + flat_vector_field = einops.rearrange(vector_field, "1 n s -> (n s)") + flat_hessian = einops.rearrange(hessian_function(relative_coordinates), "1 ns1 ns2 -> ns1 ns2") + + gradient = torch.matmul(flat_vector_field, flat_hessian) + + return gradient.cpu().numpy() - return flat_vector_field_function + return square_norm_function, gradient_function \ No newline at end of file From 7b1b5f37c1f954baf7a282fb263058a497e83c31 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 3 Oct 2024 09:49:57 -0400 Subject: [PATCH 03/18] Plotting the score norm along specific directions. --- .../plot_score_norm.py | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 experiment_analysis/score_stability_analysis/plot_score_norm.py diff --git a/experiment_analysis/score_stability_analysis/plot_score_norm.py b/experiment_analysis/score_stability_analysis/plot_score_norm.py new file mode 100644 index 00000000..96fa6057 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/plot_score_norm.py @@ -0,0 +1,123 @@ +import logging + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch +from tqdm import tqdm + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.models.score_networks.mlp_score_network import \ + MLPScoreNetworkParameters +from crystal_diffusion.models.score_networks.score_network_factory import \ + create_score_network +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import \ + create_fixed_time_normalized_score_function + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + + +checkpoint_path = ("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" + "output/last_model/last_model-epoch=049-step=039100.ckpt") + +spatial_dimension = 3 +number_of_atoms = 8 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 5.43 +basis_vectors = torch.diag(torch.tensor([acell, acell, acell])) + +total_time_steps = 1000 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=0.0001, + sigma_max=0.2, +) + +device = torch.device("cuda") +if __name__ == "__main__": + # For debugging + score_network_parameters = MLPScoreNetworkParameters( + number_of_atoms=number_of_atoms, + n_hidden_dimensions=1, + hidden_dimensions_size=16, + embedding_dimensions_size=8, + condition_embedding_size=8, + ) + + sigma_normalized_score_network = create_score_network(score_network_parameters) + + variance_calculator = ExplodingVariance(noise_parameters) + + """ + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + """ + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + equilibrium_relative_coordinates = torch.from_numpy( + get_silicon_supercell(supercell_factor=1) + ).to(torch.float32) + + direction = torch.zeros_like(equilibrium_relative_coordinates) + direction[2, 0] = 1.0 + + list_delta = torch.linspace(-0.5, 0.5, 101) + + relative_coordinates = [] + for delta in list_delta: + relative_coordinates.append( + equilibrium_relative_coordinates + delta * direction + ) + relative_coordinates = map_relative_coordinates_to_unit_cell( + torch.stack(relative_coordinates) + ) + + list_t = torch.tensor([1.0, 0.8, 0.5, 0.1, 0.01]) + list_sigmas = variance_calculator.get_sigma(list_t) + list_norms = [] + for t in tqdm(list_t, "norms"): + vector_field_fn = create_fixed_time_normalized_score_function( + sigma_normalized_score_network, + noise_parameters, + time=t, + basis_vectors=basis_vectors, + ) + + normalized_scores = vector_field_fn(relative_coordinates) + flat_normalized_scores = einops.rearrange( + normalized_scores, " b n s -> b (n s)" + ) + list_norms.append(flat_normalized_scores.norm(dim=-1)) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Normalized Score Norm Along Specific Direction") + ax1 = fig.add_subplot(111) + ax1.set_xlabel(r"$\delta$") + ax1.set_ylabel(r"$|{\bf n}({\bf x}, t)|$") + + for t, sigma, norms in zip(list_t, list_sigmas, list_norms): + ax1.plot( + list_delta, norms, "-", label=f"t = {t: 3.2f}, $\\sigma$ = {sigma: 5.2e}" + ) + + ax1.legend(loc=0) + + fig.tight_layout() + + plt.show() From 4dbb3d7f1cc9811320e97be3012f15d90a1af752 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 4 Oct 2024 10:45:54 -0400 Subject: [PATCH 04/18] Drawing some samples. --- .../draw_samples_from_equilibrium.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py diff --git a/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py new file mode 100644 index 00000000..2cfd6fa7 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py @@ -0,0 +1,209 @@ +import logging + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch + +from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE +from crystal_diffusion.analysis.analytic_score.exploring_langevin_generator.generate_sample_energies import \ + EnergyCalculator +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.generators.langevin_generator import LangevinGenerator +from crystal_diffusion.generators.ode_position_generator import ExplodingVarianceODEPositionGenerator, \ + ODESamplingParameters +from crystal_diffusion.generators.predictor_corrector_position_generator import PredictorCorrectorSamplingParameters +from crystal_diffusion.generators.sde_position_generator import ExplodingVarianceSDEPositionGenerator, \ + SDESamplingParameters +from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + + +class ForcedStartingPointLangevinGenerator(LangevinGenerator): + + def __init__(self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor + ): + super().__init__(noise_parameters, sampling_parameters, sigma_normalized_score_network) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat(self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples) + return relative_coordinates + +class ForcedStartingPointODEPositionGenerator(ExplodingVarianceODEPositionGenerator): + + def __init__(self, + noise_parameters: NoiseParameters, + sampling_parameters: ODESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor + ): + super().__init__(noise_parameters, sampling_parameters, sigma_normalized_score_network) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat(self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples) + return relative_coordinates + + +class ForcedStartingPointSDEPositionGenerator(ExplodingVarianceSDEPositionGenerator): + def __init__(self, + noise_parameters: NoiseParameters, + sampling_parameters: SDESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor + ): + super().__init__(noise_parameters, sampling_parameters, sigma_normalized_score_network) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat(self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples) + return relative_coordinates + + +checkpoint_path = ("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" + "output/last_model/last_model-epoch=049-step=039100.ckpt") + +spatial_dimension = 3 +number_of_atoms = 8 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 5.43 + +total_time_steps = 1000 +number_of_corrector_steps = 10 +epsilon =2.0e-7 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + corrector_step_epsilon=epsilon, + sigma_min=0.0001, + sigma_max=0.2, +) +number_of_samples = 1000 + +base_sampling_parameters_dict = dict(number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + cell_dimensions=[acell, acell, acell], + number_of_samples=number_of_samples) + +ode_sampling_parameters = ODESamplingParameters(absolute_solver_tolerance=1.0e-5, + relative_solver_tolerance=1.0e-5, + **base_sampling_parameters_dict) + +# Fiddling with SDE is PITA. Also, is there a bug in there? +sde_sampling_parameters = SDESamplingParameters(adaptative=False, **base_sampling_parameters_dict) + + +langevin_sampling_parameters = PredictorCorrectorSamplingParameters(number_of_corrector_steps=number_of_corrector_steps, + **base_sampling_parameters_dict) + +device = torch.device("cuda") +if __name__ == "__main__": + + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + equilibrium_relative_coordinates = torch.from_numpy( + get_silicon_supercell(supercell_factor=1) + ).to(torch.float32).to(device) + + ode_generator = ForcedStartingPointODEPositionGenerator( + noise_parameters=noise_parameters, + sampling_parameters=ode_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates) + + sde_generator = ForcedStartingPointSDEPositionGenerator( + noise_parameters=noise_parameters, + sampling_parameters=sde_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates) + + langevin_generator = ForcedStartingPointLangevinGenerator( + noise_parameters=noise_parameters, + sampling_parameters=langevin_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates) + + unit_cells = einops.repeat(basis_vectors, "s1 s2 -> b s1 s2", b=number_of_samples) + + ode_samples = ode_generator.sample(number_of_samples=number_of_samples, device=device, unit_cell=unit_cells) + sde_samples = sde_generator.sample(number_of_samples=number_of_samples, device=device, unit_cell=unit_cells) + + langevin_samples = langevin_generator.sample(number_of_samples=number_of_samples, + device=device, + unit_cell=unit_cells) + + energy_calculator = EnergyCalculator(unit_cell=basis_vectors, number_of_atoms=number_of_atoms) + + ode_energies = energy_calculator.compute_oracle_energies(ode_samples) + sde_energies = energy_calculator.compute_oracle_energies(sde_samples) + langevin_energies = energy_calculator.compute_oracle_energies(langevin_samples) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Energies of Samples Drawn from Equilibrium Coordinates") + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + ax1.set_title("Zoom In") + ax2.set_title("Broad View") + + list_q = np.linspace(0, 1, 101) + + list_energies = [ode_energies, sde_energies, langevin_energies] + + list_colors = ['blue', 'red', 'green'] + + langevin_label = (f"LANGEVIN (time steps = {total_time_steps}, " + f"corrector steps = {number_of_corrector_steps}, epsilon ={epsilon: 5.2e})") + + list_labels = ["ODE", "SDE", langevin_label] + + for ax in [ax1, ax2]: + + for energies, label in zip(list_energies, list_labels): + quantiles = np.quantile(energies, list_q) + ax.plot( 100 * list_q, quantiles, "-", label=label) + + ax.fill_between([0, 100], y1=-34.6, y2=-34.1, color='yellow', alpha=0.25, label='Training Energy Range') + + ax.set_xlabel("Quantile (%)") + ax.set_ylabel("Energy (eV)") + ax.set_xlim(-0.1, 100.1) + ax1.set_ylim(-35, -34.) + ax2.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) + + fig.tight_layout() + + plt.show() From 251feb78d70fcac586461289e16ba0d8f1d5a363 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 4 Oct 2024 10:46:13 -0400 Subject: [PATCH 05/18] Fix device bjork in energy calculator. --- .../exploring_langevin_generator/generate_sample_energies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py index 02340d4d..64d5c4fc 100644 --- a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py +++ b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py @@ -49,7 +49,7 @@ def compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> n logger.info("Compute energy from Oracle") with tempfile.TemporaryDirectory() as tmp_work_dir: - for positions, box in zip(batch_cartesian_positions.numpy(), batched_unit_cells.numpy()): + for positions, box in zip(batch_cartesian_positions.cpu().numpy(), batched_unit_cells.cpu().numpy()): energy, forces = get_energy_and_forces_from_lammps(positions, box, self.atom_types, From 46af909e0eb2d0068ec821f0e288a38595db46ad Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 4 Oct 2024 12:49:57 -0400 Subject: [PATCH 06/18] plotting norms. --- .../plot_score_norm.py | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/experiment_analysis/score_stability_analysis/plot_score_norm.py b/experiment_analysis/score_stability_analysis/plot_score_norm.py index 96fa6057..c44f96c2 100644 --- a/experiment_analysis/score_stability_analysis/plot_score_norm.py +++ b/experiment_analysis/score_stability_analysis/plot_score_norm.py @@ -9,10 +9,7 @@ from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH from crystal_diffusion.analysis.analytic_score.utils import \ get_silicon_supercell -from crystal_diffusion.models.score_networks.mlp_score_network import \ - MLPScoreNetworkParameters -from crystal_diffusion.models.score_networks.score_network_factory import \ - create_score_network +from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel from crystal_diffusion.samplers.exploding_variance import ExplodingVariance from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.basis_transformations import \ @@ -46,26 +43,13 @@ device = torch.device("cuda") if __name__ == "__main__": - # For debugging - score_network_parameters = MLPScoreNetworkParameters( - number_of_atoms=number_of_atoms, - n_hidden_dimensions=1, - hidden_dimensions_size=16, - embedding_dimensions_size=8, - condition_embedding_size=8, - ) - - sigma_normalized_score_network = create_score_network(score_network_parameters) - variance_calculator = ExplodingVariance(noise_parameters) - """ logger.info("Loading checkpoint...") pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) pl_model.eval() sigma_normalized_score_network = pl_model.sigma_normalized_score_network - """ for parameter in sigma_normalized_score_network.parameters(): parameter.requires_grad_(False) @@ -75,9 +59,16 @@ ).to(torch.float32) direction = torch.zeros_like(equilibrium_relative_coordinates) - direction[2, 0] = 1.0 - list_delta = torch.linspace(-0.5, 0.5, 101) + # Move a single atom + # direction[0, 0] = 1.0 + # list_delta = torch.linspace(-0.5, 0.5, 101) + + # Put two particles on top of each other + dv = equilibrium_relative_coordinates[0] - equilibrium_relative_coordinates[1] + direction[0] = -0.5 * dv + direction[1] = 0.5 * dv + list_delta = torch.linspace(0., 2.0, 201) relative_coordinates = [] for delta in list_delta: @@ -86,9 +77,9 @@ ) relative_coordinates = map_relative_coordinates_to_unit_cell( torch.stack(relative_coordinates) - ) + ).to(device) - list_t = torch.tensor([1.0, 0.8, 0.5, 0.1, 0.01]) + list_t = torch.tensor([0.8, 0.7, 0.5, 0.3, 0.1, 0.01]) list_sigmas = variance_calculator.get_sigma(list_t) list_norms = [] for t in tqdm(list_t, "norms"): @@ -103,7 +94,7 @@ flat_normalized_scores = einops.rearrange( normalized_scores, " b n s -> b (n s)" ) - list_norms.append(flat_normalized_scores.norm(dim=-1)) + list_norms.append(flat_normalized_scores.norm(dim=-1).cpu()) fig = plt.figure(figsize=PLEASANT_FIG_SIZE) fig.suptitle("Normalized Score Norm Along Specific Direction") From d1f72d6692dd51b16c279a866037d7211d1e84d9 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 4 Oct 2024 20:22:13 -0400 Subject: [PATCH 07/18] Plotting hessian eigenvalues. --- .../analyse_stability.py | 106 -------------- .../plot_hessian_eigenvalues.py | 131 ++++++++++++++++++ 2 files changed, 131 insertions(+), 106 deletions(-) delete mode 100644 experiment_analysis/score_stability_analysis/analyse_stability.py create mode 100644 experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py diff --git a/experiment_analysis/score_stability_analysis/analyse_stability.py b/experiment_analysis/score_stability_analysis/analyse_stability.py deleted file mode 100644 index 0cc55ae6..00000000 --- a/experiment_analysis/score_stability_analysis/analyse_stability.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging -import matplotlib.pyplot as plt - -import numpy as np -import torch - -from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE -from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell -from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel -from crystal_diffusion.samplers.exploding_variance import ExplodingVariance -from crystal_diffusion.samplers.variance_sampler import NoiseParameters -from crystal_diffusion.utils.logging_utils import setup_analysis_logger -from experiment_analysis.score_stability_analysis.util import ( - create_fixed_time_normalized_score_function, get_hessian_function, get_square_norm_and_grad_functions) - - -plt.style.use(PLOT_STYLE_PATH) - -logger = logging.getLogger(__name__) -setup_analysis_logger() - - -checkpoint_path = "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt" - -spatial_dimension = 3 -number_of_atoms = 8 -atom_types = np.ones(number_of_atoms, dtype=int) - -acell = 5.43 -basis_vectors = torch.diag(torch.tensor([acell, acell, acell])) - -total_time_steps = 1000 -noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, - sigma_min=0.0001, - sigma_max=0.2, -) - -device = torch.device('cuda') -if __name__ == "__main__": - variance_calculator = ExplodingVariance(noise_parameters) - - - logger.info("Loading checkpoint...") - pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) - pl_model.eval() - - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - for parameter in sigma_normalized_score_network.parameters(): - parameter.requires_grad_(False) - - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - fig.suptitle("Probing model along specific dimension") - ax1 = fig.add_subplot(121) - ax2 = fig.add_subplot(122) - - ax1.set_title("Score Norm") - ax2.set_title("Gradient of Score Norm") - - for ax in [ax1, ax2]: - ax.set_xlabel('x[2][0]') - - eq_rel_coords = get_silicon_supercell(supercell_factor=1) - for t in [1.0, 0.8, 0.5, 0.1, 0.01]: - print(f"Doing t = {t}") - - sigma = float(variance_calculator.get_sigma(t)) - - vector_field_fn = create_fixed_time_normalized_score_function(sigma_normalized_score_network, - noise_parameters, - time=t, - basis_vectors=basis_vectors) - - func, grad_func = get_square_norm_and_grad_functions(vector_field_fn, - number_of_atoms, - spatial_dimension, - device) - - - x0 = eq_rel_coords.flatten() - - list_dx = np.linspace(-0.5, 0.5, 101) - - list_f = [] - list_g = [] - for dx in list_dx: - r = eq_rel_coords.copy() - r[2][0] += dx - x = r.flatten() - list_f.append(func(x)) - list_g.append(grad_func(x)[6]) - - list_f = np.array(list_f) - list_g = np.array(list_g) - - ax1.plot(list_dx, list_f/np.sqrt(sigma), '-', label=f't = {t}, $\sigma$ = {sigma: 5.2e}') - ax2.plot(list_dx, list_g, '-') - - ax1.legend(loc=0) - - fig.tight_layout() - - plt.show() - - #out = so.minimize(func, x0, jac=grad_func) \ No newline at end of file diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py new file mode 100644 index 00000000..15f5afa5 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -0,0 +1,131 @@ +import logging +from pathlib import Path + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.func import jacrev + +from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE +from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import get_normalized_score_function + +plt.style.use(PLOT_STYLE_PATH) + + + + + +logger = logging.getLogger(__name__) +setup_analysis_logger() + + +checkpoint_path = Path("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt") + +spatial_dimension = 3 +number_of_atoms = 8 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 5.43 + +total_time_steps = 1000 +sigma_min = 0.0001 +sigma_max = 0.2 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=sigma_min, + sigma_max=sigma_max, +) + + +nsteps = 501 + +hessian_batch_size = 10 + +device = torch.device('cuda') +if __name__ == "__main__": + variance_calculator = ExplodingVariance(noise_parameters) + + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + equilibrium_relative_coordinates = torch.from_numpy( + get_silicon_supercell(supercell_factor=1)).to(torch.float32).to(device) + + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + normalized_score_function = get_normalized_score_function( + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + basis_vectors=basis_vectors) + + + times = torch.linspace(1, 0, nsteps).unsqueeze(-1) + sigmas = variance_calculator.get_sigma(times) + g2 = variance_calculator.get_g_squared(times) + + prefactor = -g2 / sigmas + + relative_coordinates = einops.repeat(equilibrium_relative_coordinates, "n s -> b n s", b=nsteps) + + + batch_hessian_function = jacrev(normalized_score_function, argnums=0) + + list_flat_hessians = [] + for x, t in zip(torch.split(relative_coordinates, hessian_batch_size), torch.split(times, hessian_batch_size)): + batch_hessian = batch_hessian_function(x, t) + flat_hessian = einops.rearrange(torch.diagonal(batch_hessian, dim1=0, dim2=3), + "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") + list_flat_hessians.append(flat_hessian) + + flat_hessian = torch.concat(list_flat_hessians) + + p = einops.repeat(prefactor ,"b 1 -> b d1 d2", + d1=number_of_atoms * spatial_dimension, + d2=number_of_atoms * spatial_dimension).to(flat_hessian) + + normalized_hessian = p * flat_hessian + + eigenvalues, eigenvectors = torch.linalg.eigh(normalized_hessian) + eigenvalues = eigenvalues.cpu().transpose(1, 0) + + small_count = (eigenvalues < 5e-4).sum(dim=0) + list_times = times.flatten().cpu() + list_sigmas = sigmas.flatten().cpu() + + + fig = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) + fig.suptitle("Hessian Eigenvalues") + ax1 = fig.add_subplot(311) + ax2 = fig.add_subplot(312) + ax3 = fig.add_subplot(313) + + ax3.set_xlabel(r'$\sigma(t)$') + ax3.set_ylabel('Small Count') + ax3.set_ylim(0, number_of_atoms * spatial_dimension) + + ax3.semilogx(list_sigmas, small_count, '-', color='black') + + for ax in [ax1, ax2]: + ax.set_xlabel(r'$\sigma(t)$') + ax.set_ylabel('Eigenvalue') + + for list_e in eigenvalues: + ax.semilogx(list_sigmas, list_e, '.', color='grey') + + ax2.set_ylim([-2.5e-4, 2.5e-4]) + for ax in [ax1, ax2, ax3]: + ax.set_xlim(sigma_min, sigma_max) + fig.tight_layout() + + plt.show() From 02e6b0c769665cd9f39f0e13960f8f49e13bc22e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sat, 5 Oct 2024 18:09:26 -0400 Subject: [PATCH 08/18] Scripts to do group theory. --- .../fun_times_with_point_groups.py | 176 ++++++++++++++++++ .../plot_hessian_eigenvalues.py | 4 - .../score_stability_analysis/util.py | 81 ++------ 3 files changed, 191 insertions(+), 70 deletions(-) create mode 100644 experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py diff --git a/experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py b/experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py new file mode 100644 index 00000000..28df3ca7 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py @@ -0,0 +1,176 @@ +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +from einops import einops + +from torch.func import jacrev + +from crystal_diffusion.analysis import PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.basis_transformations import map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import get_cubic_point_group_symmetries, \ + get_normalized_score_function +from tests.fake_data_utils import find_aligning_permutation + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + +checkpoint_path = Path("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt") + +spatial_dimension = 3 +number_of_atoms = 8 + +acell = 5.43 + +total_time_steps = 1000 +sigma_min = 0.0001 +sigma_max = 0.2 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=sigma_min, + sigma_max=sigma_max, +) + +device = torch.device('cuda') +if __name__ == "__main__": + equilibrium_relative_coordinates = torch.from_numpy(get_silicon_supercell(supercell_factor=1)).to(torch.float32) + + list_g = get_cubic_point_group_symmetries() + nsym = len(list_g) + + # Find the group operations that leave the set of equilibrium coordinates unchanged + list_stabilizing_g = [] + list_permutations = [] + + for g in list_g: + x = map_relative_coordinates_to_unit_cell(equilibrium_relative_coordinates @ g.transpose(1, 0)) + try: + permutation = find_aligning_permutation(equilibrium_relative_coordinates, x, tol=1e-8) + list_stabilizing_g.append(g) + permutation_op = torch.diag(torch.ones(number_of_atoms))[permutation, :] + list_permutations.append(permutation_op) + print("Found stabilizing operation.") + except: + continue + + # Confirm that the equilibrium coordinates are indeed left unchanged. + for g, p in zip(list_stabilizing_g, list_permutations): + x = map_relative_coordinates_to_unit_cell(equilibrium_relative_coordinates @ g.transpose(1, 0)) + error = torch.linalg.norm(equilibrium_relative_coordinates - p @ x) + torch.testing.assert_close(error, torch.tensor(0.0)) + + + # build operators in the flat number_of_atoms x spatial_dimension space + list_op = [] + for g, p in zip(list_stabilizing_g, list_permutations): + flat_g = torch.block_diag(*(number_of_atoms * [g])) + flat_p = torch.zeros(number_of_atoms * spatial_dimension, number_of_atoms * spatial_dimension) + for i in range(number_of_atoms): + for j in range(number_of_atoms): + if p[i, j] == 1: + for k in range(spatial_dimension): + flat_p[spatial_dimension * i + k, spatial_dimension * j + k] = 1. + + op = flat_p @ flat_g + list_op.append(op) + + # Double check that positions are left invariant + x0 = einops.rearrange(equilibrium_relative_coordinates, "n d -> (n d)") + for op in list_op: + x = map_relative_coordinates_to_unit_cell(op @ x0) + torch.testing.assert_close(torch.norm(x - x0), torch.tensor(0.)) + + # Double check that the operators form a group + # Inverse is present + for op in list_op: + inv_op = op.transpose(1, 0) + + found = False + for check_op in list_op: + if torch.linalg.norm(inv_op - check_op) < 1e-8: + found = True + assert found + + # closed to product + for op1 in list_op: + for op2 in list_op: + new_op = op1 @ op2 + + found = False + for op3 in list_op: + if torch.linalg.norm(new_op - op3) < 1e-8: + found = True + assert found + + + times = torch.ones(nsym).unsqueeze(-1) + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + + normalized_score_function = get_normalized_score_function( + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + basis_vectors=basis_vectors) + + + normalized_scores = normalized_score_function(relative_coordinates, times) + + batch_hessian_function = jacrev(normalized_score_function, argnums=0) + + hessian_batch_size = 10 + list_flat_hessians = [] + for x, t in zip(torch.split(relative_coordinates, hessian_batch_size), torch.split(times, hessian_batch_size)): + batch_hessian = batch_hessian_function(x, t) + flat_hessian = einops.rearrange(torch.diagonal(batch_hessian, dim1=0, dim2=3), + "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") + list_flat_hessians.append(flat_hessian) + + flat_hessian = torch.concat(list_flat_hessians) + + + list_flat_g = [] + for g in list_g: + flat_g = torch.block_diag(*(number_of_atoms * [g])) + list_flat_g.append(flat_g) + + list_flat_g = torch.stack(list_flat_g).to(device) + + identity_id = 7 # by inspection + + x0 = relative_coordinates[identity_id] + h0 = flat_hessian[identity_id] + + list_errors = [] + for h, g in zip(flat_hessian, list_flat_g): + new_h = (g.transpose(1, 0) @ h) @ g + error = (new_h - h0).abs().max() + + list_errors.append(error) + + print(torch.tensor(list_errors).max()) + + n = number_of_atoms * spatial_dimension + random_h = torch.zeros(n, n).to(device) + + for g in list_flat_g: + r = torch.rand(n, n).to(device) + r = 0.5 * (r + r.transpose(1, 0)) + + random_h += (g.transpose(1, 0) @ r) @ g + + random_h = random_h / len(list_g) diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py index 15f5afa5..77e056ff 100644 --- a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -17,10 +17,6 @@ plt.style.use(PLOT_STYLE_PATH) - - - - logger = logging.getLogger(__name__) setup_analysis_logger() diff --git a/experiment_analysis/score_stability_analysis/util.py b/experiment_analysis/score_stability_analysis/util.py index 1d4d3692..efec96f5 100644 --- a/experiment_analysis/score_stability_analysis/util.py +++ b/experiment_analysis/score_stability_analysis/util.py @@ -1,3 +1,4 @@ +import itertools from typing import Callable, Tuple import einops @@ -14,22 +15,15 @@ from crystal_diffusion.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell +def get_normalized_score_function( noise_parameters: NoiseParameters, + sigma_normalized_score_network: ScoreNetwork, + basis_vectors: torch.Tensor)-> Callable: -def create_fixed_time_normalized_score_function( - sigma_normalized_score_network: ScoreNetwork, - noise_parameters: NoiseParameters, - time: float, - basis_vectors: torch.Tensor, -): - """Create the vector field function.""" variance_calculator = ExplodingVariance(noise_parameters) - def vector_field_function(relative_coordinates: torch.Tensor) -> torch.Tensor: - batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape + def normalized_score_function(relative_coordinates: torch.Tensor, times: torch.Tensor) -> torch.Tensor: - times = einops.repeat( - torch.tensor([time]).to(relative_coordinates), "1 -> b 1", b=batch_size - ) + batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape unit_cells = einops.repeat( basis_vectors.to(relative_coordinates), "s1 s2 -> b s1 s2", b=batch_size ) @@ -51,59 +45,14 @@ def vector_field_function(relative_coordinates: torch.Tensor) -> torch.Tensor: return sigma_normalized_scores - return vector_field_function - - -def get_hessian_function(vector_field_function: Callable) -> Callable: - """Get hessian function.""" - batch_hessian_function = jacrev(vector_field_function, argnums=0) - - def hessian_function(relative_coordinates: torch.Tensor) -> torch.Tensor: - # The batch hessian has dimension [batch_size, natoms, space, batch_size, natoms, space] - batch_hessian = batch_hessian_function(relative_coordinates) - - # Diagonal dumps the batch dimension at the end. - hessian = torch.diagonal(batch_hessian, dim1=0, dim2=3) - - flat_hessian = einops.rearrange(hessian, "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") - return flat_hessian - - return hessian_function - - -def get_square_norm_and_grad_functions( - vector_field_function: Callable, number_of_atoms: int, spatial_dimension: int, device: torch.device -) -> Tuple[Callable, Callable]: - """Get a flat vector field function.""" - - hessian_function = get_hessian_function(vector_field_function) - - def _get_relative_coordinates(x: np.ndarray) -> torch.Tensor: - cast_x = torch.from_numpy(x).to(torch.float32).to(device) - relative_coordinates = einops.rearrange( - cast_x, "(n s) -> 1 n s", n=number_of_atoms, s=spatial_dimension - ) - relative_coordinates = map_relative_coordinates_to_unit_cell( - relative_coordinates - ) - return relative_coordinates - - def square_norm_function(x: np.ndarray) -> np.ndarray: - relative_coordinates = _get_relative_coordinates(x) - vector_field = vector_field_function(relative_coordinates) - square_norm = 0.5 * (vector_field**2).flatten().sum() - return square_norm.cpu().numpy() - - def gradient_function(x: np.ndarray) -> np.ndarray: - relative_coordinates = _get_relative_coordinates(x) - - vector_field = vector_field_function(relative_coordinates) - - flat_vector_field = einops.rearrange(vector_field, "1 n s -> (n s)") - flat_hessian = einops.rearrange(hessian_function(relative_coordinates), "1 ns1 ns2 -> ns1 ns2") - - gradient = torch.matmul(flat_vector_field, flat_hessian) + return normalized_score_function - return gradient.cpu().numpy() +def get_cubic_point_group_symmetries(): + permutations = [torch.diag(torch.ones(3))[[idx]] for idx in itertools.permutations([0, 1, 2])] + sign_changes = [torch.diag(torch.tensor(diag)) for diag in itertools.product([-1., 1.], repeat=3)] + symmetries = [] + for permutation in permutations: + for sign_change in sign_changes: + symmetries.append(permutation @ sign_change) - return square_norm_function, gradient_function \ No newline at end of file + return symmetries \ No newline at end of file From 7875ab4d6a202a17e1fc756be58b4136794b3945 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 7 Oct 2024 11:56:38 -0400 Subject: [PATCH 09/18] Plotting more. --- .../plot_hessian_eigenvalues.py | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py index 77e056ff..42957844 100644 --- a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -74,7 +74,6 @@ relative_coordinates = einops.repeat(equilibrium_relative_coordinates, "n s -> b n s", b=nsteps) - batch_hessian_function = jacrev(normalized_score_function, argnums=0) list_flat_hessians = [] @@ -86,7 +85,7 @@ flat_hessian = torch.concat(list_flat_hessians) - p = einops.repeat(prefactor ,"b 1 -> b d1 d2", + p = einops.repeat(prefactor,"b 1 -> b d1 d2", d1=number_of_atoms * spatial_dimension, d2=number_of_atoms * spatial_dimension).to(flat_hessian) @@ -125,3 +124,50 @@ fig.tight_layout() plt.show() + + fig2 = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) + fig2.suptitle("Hessian Eigenvalues At Small Time") + ax1 = fig2.add_subplot(111) + + ax1.set_xlabel(r'$\sigma(t)$') + ax1.set_ylabel('Eigenvalues') + + for list_e in eigenvalues: + ax1.loglog(list_sigmas, list_e, '.', color='grey') + + ax1.set_xlim(sigma_min, 1e-2) + + fig2.tight_layout() + + plt.show() + + + fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig3.suptitle("Hessian Eigenvalues of Normalized Score") + ax1 = fig3.add_subplot(121) + ax2 = fig3.add_subplot(122) + + ax1.set_xlabel(r'$\sigma(t)$') + ax1.set_ylabel('Eigenvalues') + + ax2.set_ylabel(r'$g^2(t) / \sigma(t)$') + ax2.set_xlabel(r'$\sigma(t)$') + + label1 = r'$\sigma(t)/g^2 \times \bf H$' + label2 = r'$\bf H$' + for list_e in eigenvalues: + ax1.semilogx(list_sigmas, list_e/(-prefactor.flatten()), '-', color='red', label=label1) + ax1.semilogx(list_sigmas, list_e, '-', color='grey', label=label2) + label1 = '__nolabel__' + label2 = '__nolabel__' + + ax1.legend(loc=0) + ax2.semilogx(list_sigmas, (-prefactor.flatten()), '-', color='blue') + + for ax in [ax1, ax2]: + ax.set_xlim(sigma_min, sigma_max) + + fig3.tight_layout() + + plt.show() + From 58a97d014b5a46a64a72a879b5be6c51685698da Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 7 Oct 2024 14:01:13 -0400 Subject: [PATCH 10/18] minor changes --- .../plot_hessian_eigenvalues.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py index 42957844..785d6714 100644 --- a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -6,6 +6,7 @@ import numpy as np import torch from torch.func import jacrev +from tqdm import tqdm from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell @@ -20,15 +21,27 @@ logger = logging.getLogger(__name__) setup_analysis_logger() +system = 'Si_1x1x1' -checkpoint_path = Path("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt") +if system == 'Si_1x1x1': + checkpoint_path = Path("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt") + number_of_atoms = 8 + acell = 5.43 + supercell_factor = 1 + + hessian_batch_size = 10 + + +elif system == 'Si_2x2x2': + pickle_path = Path("/home/mila/r/rousseab/scratch/checkpoints/sota_egnn_2x2x2.pkl") + number_of_atoms = 64 + acell = 10.86 + supercell_factor = 2 + hessian_batch_size = 1 spatial_dimension = 3 -number_of_atoms = 8 atom_types = np.ones(number_of_atoms, dtype=int) -acell = 5.43 - total_time_steps = 1000 sigma_min = 0.0001 sigma_max = 0.2 @@ -41,7 +54,6 @@ nsteps = 501 -hessian_batch_size = 10 device = torch.device('cuda') if __name__ == "__main__": @@ -49,13 +61,17 @@ basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) equilibrium_relative_coordinates = torch.from_numpy( - get_silicon_supercell(supercell_factor=1)).to(torch.float32).to(device) + get_silicon_supercell(supercell_factor=supercell_factor)).to(torch.float32).to(device) logger.info("Loading checkpoint...") - pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) - pl_model.eval() - sigma_normalized_score_network = pl_model.sigma_normalized_score_network + if system == 'Si_1x1x1': + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + elif system == 'Si_2x2x2': + sigma_normalized_score_network = torch.load(pickle_path) for parameter in sigma_normalized_score_network.parameters(): parameter.requires_grad_(False) @@ -65,7 +81,6 @@ sigma_normalized_score_network=sigma_normalized_score_network, basis_vectors=basis_vectors) - times = torch.linspace(1, 0, nsteps).unsqueeze(-1) sigmas = variance_calculator.get_sigma(times) g2 = variance_calculator.get_g_squared(times) @@ -77,7 +92,8 @@ batch_hessian_function = jacrev(normalized_score_function, argnums=0) list_flat_hessians = [] - for x, t in zip(torch.split(relative_coordinates, hessian_batch_size), torch.split(times, hessian_batch_size)): + for x, t in tqdm(zip(torch.split(relative_coordinates, hessian_batch_size), + torch.split(times, hessian_batch_size)), "Hessian"): batch_hessian = batch_hessian_function(x, t) flat_hessian = einops.rearrange(torch.diagonal(batch_hessian, dim1=0, dim2=3), "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") @@ -116,7 +132,7 @@ ax.set_ylabel('Eigenvalue') for list_e in eigenvalues: - ax.semilogx(list_sigmas, list_e, '.', color='grey') + ax.semilogx(list_sigmas, list_e, '-', color='grey') ax2.set_ylim([-2.5e-4, 2.5e-4]) for ax in [ax1, ax2, ax3]: @@ -133,7 +149,7 @@ ax1.set_ylabel('Eigenvalues') for list_e in eigenvalues: - ax1.loglog(list_sigmas, list_e, '.', color='grey') + ax1.loglog(list_sigmas, list_e, '-', color='grey') ax1.set_xlim(sigma_min, 1e-2) @@ -156,8 +172,8 @@ label1 = r'$\sigma(t)/g^2 \times \bf H$' label2 = r'$\bf H$' for list_e in eigenvalues: - ax1.semilogx(list_sigmas, list_e/(-prefactor.flatten()), '-', color='red', label=label1) - ax1.semilogx(list_sigmas, list_e, '-', color='grey', label=label2) + ax1.semilogx(list_sigmas, list_e/(-prefactor.flatten()), '-', lw=1, color='red', label=label1) + ax1.semilogx(list_sigmas, list_e, '-', color='grey', lw=1, label=label2) label1 = '__nolabel__' label2 = '__nolabel__' @@ -171,3 +187,4 @@ plt.show() + jacobian_eig_at_t0 = eigenvalues[:, -1] / (-prefactor[-1]) From 7e179f7309928aee9f80b0ec801990fc2d7d4a7b Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sun, 13 Oct 2024 18:36:01 -0400 Subject: [PATCH 11/18] A nice little script to compute the covariance matrices. --- .../dataset_analysis/dataset_covariance.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 experiment_analysis/dataset_analysis/dataset_covariance.py diff --git a/experiment_analysis/dataset_analysis/dataset_covariance.py b/experiment_analysis/dataset_analysis/dataset_covariance.py new file mode 100644 index 00000000..eed61e7e --- /dev/null +++ b/experiment_analysis/dataset_analysis/dataset_covariance.py @@ -0,0 +1,78 @@ +"""Effective Dataset Variance. + +The goal of this script is to compute the effective "sigma_d" of the +actual datasets, that is, the standard deviation of the displacement +from equilibrium, in fractional coordinates. +""" +import logging + +import einops +import torch +from tqdm import tqdm + +from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR +from crystal_diffusion.data.diffusion.data_loader import ( + LammpsForDiffusionDataModule, LammpsLoaderParameters) +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +logger = logging.getLogger(__name__) +dataset_name = 'si_diffusion_2x2x2' +# dataset_name = 'si_diffusion_1x1x1' + +output_dir = ANALYSIS_RESULTS_DIR / "covariances" +output_dir.mkdir(exist_ok=True) + + +if dataset_name == 'si_diffusion_1x1x1': + max_atom = 8 + translation = torch.tensor([0.125, 0.125, 0.125]) +elif dataset_name == 'si_diffusion_2x2x2': + max_atom = 64 + translation = torch.tensor([0.0625, 0.0625, 0.0625]) + +lammps_run_dir = DATA_DIR / dataset_name +processed_dataset_dir = lammps_run_dir / "processed" + +cache_dir = lammps_run_dir / "cache" + +data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom) + +if __name__ == '__main__': + setup_analysis_logger() + logger.info(f"Computing the covariance matrix for {dataset_name}") + + datamodule = LammpsForDiffusionDataModule( + lammps_run_dir=lammps_run_dir, + processed_dataset_dir=processed_dataset_dir, + hyper_params=data_params, + working_cache_dir=cache_dir, + ) + datamodule.setup() + + train_dataset = datamodule.train_dataset + + list_means = [] + for batch in tqdm(datamodule.train_dataloader(), "Mean"): + x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation) + list_means.append(x.mean(dim=0)) + + # Drop the last batch, which might not have dimension batch_size + x0 = torch.stack(list_means[:-1]).mean(dim=0) + + list_covariances = [] + list_sizes = [] + for batch in tqdm(datamodule.train_dataloader(), "displacements"): + x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation) + list_sizes.append(x.shape[0]) + displacements = einops.rearrange(x - x0, "batch natoms space -> batch (natoms space)") + covariance = (displacements[:, None, :] * displacements[:, :, None]).sum(dim=0) + list_covariances.append(covariance) + + covariance = torch.stack(list_covariances).sum(dim=0) / sum(list_sizes) + + output_file = output_dir / f"covariance_{dataset_name}.pkl" + logger.info(f"Writing to file {output_file}...") + with open(output_file, 'wb') as fd: + torch.save(covariance, fd) From ac7ebd8e06b0f369d06e32fbd5db5ad96e94389a Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 15 Oct 2024 10:17:09 -0400 Subject: [PATCH 12/18] Plot phonon eigenvalues from the covariance matrix. --- .../dataset_analysis/plot_si_phonon_DOS.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py diff --git a/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py b/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py new file mode 100644 index 00000000..3d719b22 --- /dev/null +++ b/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py @@ -0,0 +1,86 @@ +"""Silicon phonon Density of States. + +The displacement covariance is related to the phonon dynamical matrix. +Here we extract the corresponding phonon density of state, based on this covariance, +to see if the energy scales match up. +""" +import matplotlib.pyplot as plt +import torch + +from crystal_diffusion import ANALYSIS_RESULTS_DIR +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH + +plt.style.use(PLOT_STYLE_PATH) + +# Define some constants. +kelvin_in_Ha = 0.000_003_166_78 +T_in_kelvin = 300.0 +bohr_in_angst = 0.529177 +Si_mass = 28.0855 +proton_mass = 1836.152673426 + +Ha_in_meV = 27211.0 + +THz_in_meV = 4.136 + +acell = 5.43 + + +dataset_name_2x2x2 = "si_diffusion_2x2x2" +dataset_name_1x1x1 = "si_diffusion_1x1x1" + +output_dir = ANALYSIS_RESULTS_DIR / "covariances" +output_dir.mkdir(exist_ok=True) + + +if __name__ == "__main__": + kBT = kelvin_in_Ha * T_in_kelvin + a = acell / bohr_in_angst + + M = Si_mass * proton_mass + + constant_1x1x1 = M * a**2 / kBT / Ha_in_meV**2 + constant_2x2x2 = M * (2.0 * a) ** 2 / kBT / Ha_in_meV**2 + + covariance_file_1x1x1 = output_dir / f"covariance_{dataset_name_1x1x1}.pkl" + sigma_1x1x1 = torch.load(covariance_file_1x1x1) + sigma_inv_1x1x1 = torch.linalg.pinv(sigma_1x1x1) + Omega_1x1x1 = sigma_inv_1x1x1 / constant_1x1x1 + omega2_1x1x1 = torch.linalg.eigvalsh(Omega_1x1x1) + list_omega_in_meV_1x1x1 = torch.sqrt(torch.abs(omega2_1x1x1)) + + covariance_file_2x2x2 = output_dir / f"covariance_{dataset_name_2x2x2}.pkl" + sigma_2x2x2 = torch.load(covariance_file_2x2x2) + sigma_inv_2x2x2 = torch.linalg.pinv(sigma_2x2x2) + Omega_2x2x2 = sigma_inv_2x2x2 / constant_2x2x2 + omega2_2x2x2 = torch.linalg.eigvalsh(Omega_2x2x2) + list_omega_in_meV_2x2x2 = torch.sqrt(torch.abs(omega2_2x2x2)) + + max_hw = torch.max(list_omega_in_meV_2x2x2) / THz_in_meV + + bins = torch.linspace(0.0, max_hw, 50) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Eigenvalues of Dynamical Matrix, from Displacement Covariance") + ax = fig.add_subplot(111) + + ax.set_xlim(0, max_hw + 1) + ax.set_xlabel(r"$\hbar \omega$ (THz)") + ax.set_ylabel("Count") + + ax.hist( + list_omega_in_meV_1x1x1 / THz_in_meV, + bins=bins, + label="Si 1x1x1", + color="green", + alpha=0.5, + ) + ax.hist( + list_omega_in_meV_2x2x2 / THz_in_meV, + bins=bins, + label="Si 2x2x2", + color="blue", + alpha=0.25, + ) + ax.legend(loc=0) + plt.show() From 16d852c4faa437ae991680c27a408b439cda42ef Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 15 Oct 2024 10:24:23 -0400 Subject: [PATCH 13/18] Some analysis knick knack. --- .../draw_samples_from_equilibrium.py | 193 +++++++++++------- .../plot_hessian_eigenvalues.py | 115 +++++++---- .../plot_score_norm.py | 5 +- .../score_stability_analysis/util.py | 33 +-- 4 files changed, 215 insertions(+), 131 deletions(-) diff --git a/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py index 2cfd6fa7..700b0b4c 100644 --- a/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py +++ b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py @@ -5,18 +5,20 @@ import numpy as np import torch -from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH from crystal_diffusion.analysis.analytic_score.exploring_langevin_generator.generate_sample_energies import \ EnergyCalculator from crystal_diffusion.analysis.analytic_score.utils import \ get_silicon_supercell from crystal_diffusion.generators.langevin_generator import LangevinGenerator -from crystal_diffusion.generators.ode_position_generator import ExplodingVarianceODEPositionGenerator, \ - ODESamplingParameters -from crystal_diffusion.generators.predictor_corrector_position_generator import PredictorCorrectorSamplingParameters -from crystal_diffusion.generators.sde_position_generator import ExplodingVarianceSDEPositionGenerator, \ - SDESamplingParameters -from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.generators.ode_position_generator import ( + ExplodingVarianceODEPositionGenerator, ODESamplingParameters) +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.generators.sde_position_generator import ( + ExplodingVarianceSDEPositionGenerator, SDESamplingParameters) +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel from crystal_diffusion.models.score_networks import ScoreNetwork from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger @@ -28,65 +30,87 @@ class ForcedStartingPointLangevinGenerator(LangevinGenerator): - - def __init__(self, - noise_parameters: NoiseParameters, - sampling_parameters: PredictorCorrectorSamplingParameters, - sigma_normalized_score_network: ScoreNetwork, - starting_relative_coordinates: torch.Tensor - ): - super().__init__(noise_parameters, sampling_parameters, sigma_normalized_score_network) + """Langevin Generator with Forced Starting point.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) self._starting_relative_coordinates = starting_relative_coordinates def initialize(self, number_of_samples: int): """This method must initialize the samples from the fully noised distribution.""" - relative_coordinates = einops.repeat(self._starting_relative_coordinates, - "natoms space -> batch_size natoms space", - batch_size=number_of_samples) + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) return relative_coordinates -class ForcedStartingPointODEPositionGenerator(ExplodingVarianceODEPositionGenerator): - def __init__(self, - noise_parameters: NoiseParameters, - sampling_parameters: ODESamplingParameters, - sigma_normalized_score_network: ScoreNetwork, - starting_relative_coordinates: torch.Tensor - ): - super().__init__(noise_parameters, sampling_parameters, sigma_normalized_score_network) +class ForcedStartingPointODEPositionGenerator(ExplodingVarianceODEPositionGenerator): + """Forced starting point ODE position generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: ODESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) self._starting_relative_coordinates = starting_relative_coordinates def initialize(self, number_of_samples: int): """This method must initialize the samples from the fully noised distribution.""" - relative_coordinates = einops.repeat(self._starting_relative_coordinates, - "natoms space -> batch_size natoms space", - batch_size=number_of_samples) + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) return relative_coordinates class ForcedStartingPointSDEPositionGenerator(ExplodingVarianceSDEPositionGenerator): - def __init__(self, - noise_parameters: NoiseParameters, - sampling_parameters: SDESamplingParameters, - sigma_normalized_score_network: ScoreNetwork, - starting_relative_coordinates: torch.Tensor - ): - super().__init__(noise_parameters, sampling_parameters, sigma_normalized_score_network) + """Forced Starting Point SDE position generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: SDESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) self._starting_relative_coordinates = starting_relative_coordinates def initialize(self, number_of_samples: int): """This method must initialize the samples from the fully noised distribution.""" - relative_coordinates = einops.repeat(self._starting_relative_coordinates, - "natoms space -> batch_size natoms space", - batch_size=number_of_samples) + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) return relative_coordinates -checkpoint_path = ("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" - "output/last_model/last_model-epoch=049-step=039100.ckpt") +checkpoint_path = ( + "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" + "output/last_model/last_model-epoch=049-step=039100.ckpt" +) spatial_dimension = 3 number_of_atoms = 8 @@ -96,7 +120,7 @@ def initialize(self, number_of_samples: int): total_time_steps = 1000 number_of_corrector_steps = 10 -epsilon =2.0e-7 +epsilon = 2.0e-7 noise_parameters = NoiseParameters( total_time_steps=total_time_steps, corrector_step_epsilon=epsilon, @@ -105,25 +129,31 @@ def initialize(self, number_of_samples: int): ) number_of_samples = 1000 -base_sampling_parameters_dict = dict(number_of_atoms=number_of_atoms, - spatial_dimension=spatial_dimension, - cell_dimensions=[acell, acell, acell], - number_of_samples=number_of_samples) +base_sampling_parameters_dict = dict( + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + cell_dimensions=[acell, acell, acell], + number_of_samples=number_of_samples, +) -ode_sampling_parameters = ODESamplingParameters(absolute_solver_tolerance=1.0e-5, - relative_solver_tolerance=1.0e-5, - **base_sampling_parameters_dict) +ode_sampling_parameters = ODESamplingParameters( + absolute_solver_tolerance=1.0e-5, + relative_solver_tolerance=1.0e-5, + **base_sampling_parameters_dict, +) # Fiddling with SDE is PITA. Also, is there a bug in there? -sde_sampling_parameters = SDESamplingParameters(adaptative=False, **base_sampling_parameters_dict) +sde_sampling_parameters = SDESamplingParameters( + adaptative=False, **base_sampling_parameters_dict +) -langevin_sampling_parameters = PredictorCorrectorSamplingParameters(number_of_corrector_steps=number_of_corrector_steps, - **base_sampling_parameters_dict) +langevin_sampling_parameters = PredictorCorrectorSamplingParameters( + number_of_corrector_steps=number_of_corrector_steps, **base_sampling_parameters_dict +) device = torch.device("cuda") if __name__ == "__main__": - basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) logger.info("Loading checkpoint...") @@ -135,38 +165,49 @@ def initialize(self, number_of_samples: int): for parameter in sigma_normalized_score_network.parameters(): parameter.requires_grad_(False) - equilibrium_relative_coordinates = torch.from_numpy( - get_silicon_supercell(supercell_factor=1) - ).to(torch.float32).to(device) + equilibrium_relative_coordinates = ( + torch.from_numpy(get_silicon_supercell(supercell_factor=1)) + .to(torch.float32) + .to(device) + ) ode_generator = ForcedStartingPointODEPositionGenerator( noise_parameters=noise_parameters, sampling_parameters=ode_sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, - starting_relative_coordinates=equilibrium_relative_coordinates) + starting_relative_coordinates=equilibrium_relative_coordinates, + ) sde_generator = ForcedStartingPointSDEPositionGenerator( noise_parameters=noise_parameters, sampling_parameters=sde_sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, - starting_relative_coordinates=equilibrium_relative_coordinates) + starting_relative_coordinates=equilibrium_relative_coordinates, + ) langevin_generator = ForcedStartingPointLangevinGenerator( noise_parameters=noise_parameters, sampling_parameters=langevin_sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, - starting_relative_coordinates=equilibrium_relative_coordinates) + starting_relative_coordinates=equilibrium_relative_coordinates, + ) unit_cells = einops.repeat(basis_vectors, "s1 s2 -> b s1 s2", b=number_of_samples) - ode_samples = ode_generator.sample(number_of_samples=number_of_samples, device=device, unit_cell=unit_cells) - sde_samples = sde_generator.sample(number_of_samples=number_of_samples, device=device, unit_cell=unit_cells) + ode_samples = ode_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) + sde_samples = sde_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) - langevin_samples = langevin_generator.sample(number_of_samples=number_of_samples, - device=device, - unit_cell=unit_cells) + langevin_samples = langevin_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) - energy_calculator = EnergyCalculator(unit_cell=basis_vectors, number_of_atoms=number_of_atoms) + energy_calculator = EnergyCalculator( + unit_cell=basis_vectors, number_of_atoms=number_of_atoms + ) ode_energies = energy_calculator.compute_oracle_energies(ode_samples) sde_energies = energy_calculator.compute_oracle_energies(sde_samples) @@ -183,25 +224,33 @@ def initialize(self, number_of_samples: int): list_energies = [ode_energies, sde_energies, langevin_energies] - list_colors = ['blue', 'red', 'green'] + list_colors = ["blue", "red", "green"] - langevin_label = (f"LANGEVIN (time steps = {total_time_steps}, " - f"corrector steps = {number_of_corrector_steps}, epsilon ={epsilon: 5.2e})") + langevin_label = ( + f"LANGEVIN (time steps = {total_time_steps}, " + f"corrector steps = {number_of_corrector_steps}, epsilon ={epsilon: 5.2e})" + ) list_labels = ["ODE", "SDE", langevin_label] for ax in [ax1, ax2]: - for energies, label in zip(list_energies, list_labels): quantiles = np.quantile(energies, list_q) - ax.plot( 100 * list_q, quantiles, "-", label=label) + ax.plot(100 * list_q, quantiles, "-", label=label) - ax.fill_between([0, 100], y1=-34.6, y2=-34.1, color='yellow', alpha=0.25, label='Training Energy Range') + ax.fill_between( + [0, 100], + y1=-34.6, + y2=-34.1, + color="yellow", + alpha=0.25, + label="Training Energy Range", + ) ax.set_xlabel("Quantile (%)") ax.set_ylabel("Energy (eV)") ax.set_xlim(-0.1, 100.1) - ax1.set_ylim(-35, -34.) + ax1.set_ylim(-35, -34.0) ax2.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) fig.tight_layout() diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py index 785d6714..b48669be 100644 --- a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -8,23 +8,29 @@ from torch.func import jacrev from tqdm import tqdm -from crystal_diffusion.analysis import PLOT_STYLE_PATH, PLEASANT_FIG_SIZE -from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell -from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel from crystal_diffusion.samplers.exploding_variance import ExplodingVariance from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger -from experiment_analysis.score_stability_analysis.util import get_normalized_score_function +from experiment_analysis.score_stability_analysis.util import \ + get_normalized_score_function plt.style.use(PLOT_STYLE_PATH) logger = logging.getLogger(__name__) setup_analysis_logger() -system = 'Si_1x1x1' +system = "Si_1x1x1" -if system == 'Si_1x1x1': - checkpoint_path = Path("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt") +if system == "Si_1x1x1": + checkpoint_path = Path( + "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/" + "last_model/last_model-epoch=049-step=039100.ckpt" + ) number_of_atoms = 8 acell = 5.43 supercell_factor = 1 @@ -32,7 +38,7 @@ hessian_batch_size = 10 -elif system == 'Si_2x2x2': +elif system == "Si_2x2x2": pickle_path = Path("/home/mila/r/rousseab/scratch/checkpoints/sota_egnn_2x2x2.pkl") number_of_atoms = 64 acell = 10.86 @@ -55,22 +61,25 @@ nsteps = 501 -device = torch.device('cuda') +device = torch.device("cuda") if __name__ == "__main__": variance_calculator = ExplodingVariance(noise_parameters) basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) - equilibrium_relative_coordinates = torch.from_numpy( - get_silicon_supercell(supercell_factor=supercell_factor)).to(torch.float32).to(device) + equilibrium_relative_coordinates = ( + torch.from_numpy(get_silicon_supercell(supercell_factor=supercell_factor)) + .to(torch.float32) + .to(device) + ) logger.info("Loading checkpoint...") - if system == 'Si_1x1x1': + if system == "Si_1x1x1": pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) pl_model.eval() sigma_normalized_score_network = pl_model.sigma_normalized_score_network - elif system == 'Si_2x2x2': + elif system == "Si_2x2x2": sigma_normalized_score_network = torch.load(pickle_path) for parameter in sigma_normalized_score_network.parameters(): @@ -79,7 +88,8 @@ normalized_score_function = get_normalized_score_function( noise_parameters=noise_parameters, sigma_normalized_score_network=sigma_normalized_score_network, - basis_vectors=basis_vectors) + basis_vectors=basis_vectors, + ) times = torch.linspace(1, 0, nsteps).unsqueeze(-1) sigmas = variance_calculator.get_sigma(times) @@ -87,23 +97,35 @@ prefactor = -g2 / sigmas - relative_coordinates = einops.repeat(equilibrium_relative_coordinates, "n s -> b n s", b=nsteps) + relative_coordinates = einops.repeat( + equilibrium_relative_coordinates, "n s -> b n s", b=nsteps + ) batch_hessian_function = jacrev(normalized_score_function, argnums=0) list_flat_hessians = [] - for x, t in tqdm(zip(torch.split(relative_coordinates, hessian_batch_size), - torch.split(times, hessian_batch_size)), "Hessian"): + for x, t in tqdm( + zip( + torch.split(relative_coordinates, hessian_batch_size), + torch.split(times, hessian_batch_size), + ), + "Hessian", + ): batch_hessian = batch_hessian_function(x, t) - flat_hessian = einops.rearrange(torch.diagonal(batch_hessian, dim1=0, dim2=3), - "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") + flat_hessian = einops.rearrange( + torch.diagonal(batch_hessian, dim1=0, dim2=3), + "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)", + ) list_flat_hessians.append(flat_hessian) flat_hessian = torch.concat(list_flat_hessians) - p = einops.repeat(prefactor,"b 1 -> b d1 d2", - d1=number_of_atoms * spatial_dimension, - d2=number_of_atoms * spatial_dimension).to(flat_hessian) + p = einops.repeat( + prefactor, + "b 1 -> b d1 d2", + d1=number_of_atoms * spatial_dimension, + d2=number_of_atoms * spatial_dimension, + ).to(flat_hessian) normalized_hessian = p * flat_hessian @@ -114,25 +136,24 @@ list_times = times.flatten().cpu() list_sigmas = sigmas.flatten().cpu() - fig = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) fig.suptitle("Hessian Eigenvalues") ax1 = fig.add_subplot(311) ax2 = fig.add_subplot(312) ax3 = fig.add_subplot(313) - ax3.set_xlabel(r'$\sigma(t)$') - ax3.set_ylabel('Small Count') + ax3.set_xlabel(r"$\sigma(t)$") + ax3.set_ylabel("Small Count") ax3.set_ylim(0, number_of_atoms * spatial_dimension) - ax3.semilogx(list_sigmas, small_count, '-', color='black') + ax3.semilogx(list_sigmas, small_count, "-", color="black") for ax in [ax1, ax2]: - ax.set_xlabel(r'$\sigma(t)$') - ax.set_ylabel('Eigenvalue') + ax.set_xlabel(r"$\sigma(t)$") + ax.set_ylabel("Eigenvalue") for list_e in eigenvalues: - ax.semilogx(list_sigmas, list_e, '-', color='grey') + ax.semilogx(list_sigmas, list_e, "-", color="grey") ax2.set_ylim([-2.5e-4, 2.5e-4]) for ax in [ax1, ax2, ax3]: @@ -145,11 +166,11 @@ fig2.suptitle("Hessian Eigenvalues At Small Time") ax1 = fig2.add_subplot(111) - ax1.set_xlabel(r'$\sigma(t)$') - ax1.set_ylabel('Eigenvalues') + ax1.set_xlabel(r"$\sigma(t)$") + ax1.set_ylabel("Eigenvalues") for list_e in eigenvalues: - ax1.loglog(list_sigmas, list_e, '-', color='grey') + ax1.loglog(list_sigmas, list_e, "-", color="grey") ax1.set_xlim(sigma_min, 1e-2) @@ -157,28 +178,34 @@ plt.show() - fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE) fig3.suptitle("Hessian Eigenvalues of Normalized Score") ax1 = fig3.add_subplot(121) ax2 = fig3.add_subplot(122) - ax1.set_xlabel(r'$\sigma(t)$') - ax1.set_ylabel('Eigenvalues') + ax1.set_xlabel(r"$\sigma(t)$") + ax1.set_ylabel("Eigenvalues") - ax2.set_ylabel(r'$g^2(t) / \sigma(t)$') - ax2.set_xlabel(r'$\sigma(t)$') + ax2.set_ylabel(r"$g^2(t) / \sigma(t)$") + ax2.set_xlabel(r"$\sigma(t)$") - label1 = r'$\sigma(t)/g^2 \times \bf H$' - label2 = r'$\bf H$' + label1 = r"$\sigma(t)/g^2 \times \bf H$" + label2 = r"$\bf H$" for list_e in eigenvalues: - ax1.semilogx(list_sigmas, list_e/(-prefactor.flatten()), '-', lw=1, color='red', label=label1) - ax1.semilogx(list_sigmas, list_e, '-', color='grey', lw=1, label=label2) - label1 = '__nolabel__' - label2 = '__nolabel__' + ax1.semilogx( + list_sigmas, + list_e / (-prefactor.flatten()), + "-", + lw=1, + color="red", + label=label1, + ) + ax1.semilogx(list_sigmas, list_e, "-", color="grey", lw=1, label=label2) + label1 = "__nolabel__" + label2 = "__nolabel__" ax1.legend(loc=0) - ax2.semilogx(list_sigmas, (-prefactor.flatten()), '-', color='blue') + ax2.semilogx(list_sigmas, (-prefactor.flatten()), "-", color="blue") for ax in [ax1, ax2]: ax.set_xlim(sigma_min, sigma_max) diff --git a/experiment_analysis/score_stability_analysis/plot_score_norm.py b/experiment_analysis/score_stability_analysis/plot_score_norm.py index c44f96c2..ef4ddbc7 100644 --- a/experiment_analysis/score_stability_analysis/plot_score_norm.py +++ b/experiment_analysis/score_stability_analysis/plot_score_norm.py @@ -9,7 +9,8 @@ from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH from crystal_diffusion.analysis.analytic_score.utils import \ get_silicon_supercell -from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel from crystal_diffusion.samplers.exploding_variance import ExplodingVariance from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.basis_transformations import \ @@ -67,7 +68,7 @@ # Put two particles on top of each other dv = equilibrium_relative_coordinates[0] - equilibrium_relative_coordinates[1] direction[0] = -0.5 * dv - direction[1] = 0.5 * dv + direction[1] = 0.5 * dv list_delta = torch.linspace(0., 2.0, 201) relative_coordinates = [] diff --git a/experiment_analysis/score_stability_analysis/util.py b/experiment_analysis/score_stability_analysis/util.py index efec96f5..6e272894 100644 --- a/experiment_analysis/score_stability_analysis/util.py +++ b/experiment_analysis/score_stability_analysis/util.py @@ -1,10 +1,8 @@ import itertools -from typing import Callable, Tuple +from typing import Callable import einops -import numpy as np import torch -from torch.func import jacrev from crystal_diffusion.models.score_networks import ScoreNetwork from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, @@ -12,17 +10,19 @@ UNIT_CELL) from crystal_diffusion.samplers.exploding_variance import ExplodingVariance from crystal_diffusion.samplers.variance_sampler import NoiseParameters -from crystal_diffusion.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell -def get_normalized_score_function( noise_parameters: NoiseParameters, - sigma_normalized_score_network: ScoreNetwork, - basis_vectors: torch.Tensor)-> Callable: +def get_normalized_score_function( + noise_parameters: NoiseParameters, + sigma_normalized_score_network: ScoreNetwork, + basis_vectors: torch.Tensor, +) -> Callable: + """Get normalizd score function.""" variance_calculator = ExplodingVariance(noise_parameters) - def normalized_score_function(relative_coordinates: torch.Tensor, times: torch.Tensor) -> torch.Tensor: - + def normalized_score_function( + relative_coordinates: torch.Tensor, times: torch.Tensor + ) -> torch.Tensor: batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape unit_cells = einops.repeat( basis_vectors.to(relative_coordinates), "s1 s2 -> b s1 s2", b=batch_size @@ -47,12 +47,19 @@ def normalized_score_function(relative_coordinates: torch.Tensor, times: torch.T return normalized_score_function + def get_cubic_point_group_symmetries(): - permutations = [torch.diag(torch.ones(3))[[idx]] for idx in itertools.permutations([0, 1, 2])] - sign_changes = [torch.diag(torch.tensor(diag)) for diag in itertools.product([-1., 1.], repeat=3)] + """Get cubic point group symmetries.""" + permutations = [ + torch.diag(torch.ones(3))[[idx]] for idx in itertools.permutations([0, 1, 2]) + ] + sign_changes = [ + torch.diag(torch.tensor(diag)) + for diag in itertools.product([-1.0, 1.0], repeat=3) + ] symmetries = [] for permutation in permutations: for sign_change in sign_changes: symmetries.append(permutation @ sign_change) - return symmetries \ No newline at end of file + return symmetries From 6fd1cd8c64f29a45a7c53be898ea25b015d50e58 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 16 Oct 2024 19:34:31 -0400 Subject: [PATCH 14/18] removing needless file --- .../fun_times_with_point_groups.py | 176 ------------------ 1 file changed, 176 deletions(-) delete mode 100644 experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py diff --git a/experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py b/experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py deleted file mode 100644 index 28df3ca7..00000000 --- a/experiment_analysis/score_stability_analysis/fun_times_with_point_groups.py +++ /dev/null @@ -1,176 +0,0 @@ -import logging -from pathlib import Path - -import matplotlib.pyplot as plt -import torch -from einops import einops - -from torch.func import jacrev - -from crystal_diffusion.analysis import PLOT_STYLE_PATH -from crystal_diffusion.analysis.analytic_score.utils import get_silicon_supercell -from crystal_diffusion.models.position_diffusion_lightning_model import PositionDiffusionLightningModel -from crystal_diffusion.samplers.variance_sampler import NoiseParameters -from crystal_diffusion.utils.basis_transformations import map_relative_coordinates_to_unit_cell -from crystal_diffusion.utils.logging_utils import setup_analysis_logger -from experiment_analysis.score_stability_analysis.util import get_cubic_point_group_symmetries, \ - get_normalized_score_function -from tests.fake_data_utils import find_aligning_permutation - -plt.style.use(PLOT_STYLE_PATH) - -logger = logging.getLogger(__name__) -setup_analysis_logger() - -checkpoint_path = Path("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/last_model/last_model-epoch=049-step=039100.ckpt") - -spatial_dimension = 3 -number_of_atoms = 8 - -acell = 5.43 - -total_time_steps = 1000 -sigma_min = 0.0001 -sigma_max = 0.2 -noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, - sigma_min=sigma_min, - sigma_max=sigma_max, -) - -device = torch.device('cuda') -if __name__ == "__main__": - equilibrium_relative_coordinates = torch.from_numpy(get_silicon_supercell(supercell_factor=1)).to(torch.float32) - - list_g = get_cubic_point_group_symmetries() - nsym = len(list_g) - - # Find the group operations that leave the set of equilibrium coordinates unchanged - list_stabilizing_g = [] - list_permutations = [] - - for g in list_g: - x = map_relative_coordinates_to_unit_cell(equilibrium_relative_coordinates @ g.transpose(1, 0)) - try: - permutation = find_aligning_permutation(equilibrium_relative_coordinates, x, tol=1e-8) - list_stabilizing_g.append(g) - permutation_op = torch.diag(torch.ones(number_of_atoms))[permutation, :] - list_permutations.append(permutation_op) - print("Found stabilizing operation.") - except: - continue - - # Confirm that the equilibrium coordinates are indeed left unchanged. - for g, p in zip(list_stabilizing_g, list_permutations): - x = map_relative_coordinates_to_unit_cell(equilibrium_relative_coordinates @ g.transpose(1, 0)) - error = torch.linalg.norm(equilibrium_relative_coordinates - p @ x) - torch.testing.assert_close(error, torch.tensor(0.0)) - - - # build operators in the flat number_of_atoms x spatial_dimension space - list_op = [] - for g, p in zip(list_stabilizing_g, list_permutations): - flat_g = torch.block_diag(*(number_of_atoms * [g])) - flat_p = torch.zeros(number_of_atoms * spatial_dimension, number_of_atoms * spatial_dimension) - for i in range(number_of_atoms): - for j in range(number_of_atoms): - if p[i, j] == 1: - for k in range(spatial_dimension): - flat_p[spatial_dimension * i + k, spatial_dimension * j + k] = 1. - - op = flat_p @ flat_g - list_op.append(op) - - # Double check that positions are left invariant - x0 = einops.rearrange(equilibrium_relative_coordinates, "n d -> (n d)") - for op in list_op: - x = map_relative_coordinates_to_unit_cell(op @ x0) - torch.testing.assert_close(torch.norm(x - x0), torch.tensor(0.)) - - # Double check that the operators form a group - # Inverse is present - for op in list_op: - inv_op = op.transpose(1, 0) - - found = False - for check_op in list_op: - if torch.linalg.norm(inv_op - check_op) < 1e-8: - found = True - assert found - - # closed to product - for op1 in list_op: - for op2 in list_op: - new_op = op1 @ op2 - - found = False - for op3 in list_op: - if torch.linalg.norm(new_op - op3) < 1e-8: - found = True - assert found - - - times = torch.ones(nsym).unsqueeze(-1) - logger.info("Loading checkpoint...") - pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) - pl_model.eval() - - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - for parameter in sigma_normalized_score_network.parameters(): - parameter.requires_grad_(False) - - basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) - - normalized_score_function = get_normalized_score_function( - noise_parameters=noise_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, - basis_vectors=basis_vectors) - - - normalized_scores = normalized_score_function(relative_coordinates, times) - - batch_hessian_function = jacrev(normalized_score_function, argnums=0) - - hessian_batch_size = 10 - list_flat_hessians = [] - for x, t in zip(torch.split(relative_coordinates, hessian_batch_size), torch.split(times, hessian_batch_size)): - batch_hessian = batch_hessian_function(x, t) - flat_hessian = einops.rearrange(torch.diagonal(batch_hessian, dim1=0, dim2=3), - "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)") - list_flat_hessians.append(flat_hessian) - - flat_hessian = torch.concat(list_flat_hessians) - - - list_flat_g = [] - for g in list_g: - flat_g = torch.block_diag(*(number_of_atoms * [g])) - list_flat_g.append(flat_g) - - list_flat_g = torch.stack(list_flat_g).to(device) - - identity_id = 7 # by inspection - - x0 = relative_coordinates[identity_id] - h0 = flat_hessian[identity_id] - - list_errors = [] - for h, g in zip(flat_hessian, list_flat_g): - new_h = (g.transpose(1, 0) @ h) @ g - error = (new_h - h0).abs().max() - - list_errors.append(error) - - print(torch.tensor(list_errors).max()) - - n = number_of_atoms * spatial_dimension - random_h = torch.zeros(n, n).to(device) - - for g in list_flat_g: - r = torch.rand(n, n).to(device) - r = 0.5 * (r + r.transpose(1, 0)) - - random_h += (g.transpose(1, 0) @ r) @ g - - random_h = random_h / len(list_g) From 1dcb211a9169e56f2bde1077b4823204e55f0cd2 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 17 Oct 2024 10:28:53 -0400 Subject: [PATCH 15/18] A proper script to draw samples from a checkpoint. --- crystal_diffusion/sample_diffusion.py | 193 ++++++++++++++++++++++++++ tests/test_sample_diffusion.py | 157 +++++++++++++++++++++ 2 files changed, 350 insertions(+) create mode 100644 crystal_diffusion/sample_diffusion.py create mode 100644 tests/test_sample_diffusion.py diff --git a/crystal_diffusion/sample_diffusion.py b/crystal_diffusion/sample_diffusion.py new file mode 100644 index 00000000..81f09e8c --- /dev/null +++ b/crystal_diffusion/sample_diffusion.py @@ -0,0 +1,193 @@ +"""Sample Diffusion. + +This script is the entry point to draw samples from a pre-trained model checkpoint. +""" +import argparse +import logging +import os +import socket +from pathlib import Path +from typing import Any, AnyStr, Dict, Optional, Union + +import torch + +from crystal_diffusion.generators.instantiate_generator import \ + instantiate_generator +from crystal_diffusion.generators.load_sampling_parameters import \ + load_sampling_parameters +from crystal_diffusion.generators.position_generator import SamplingParameters +from crystal_diffusion.main_utils import load_and_backup_hyperparameters +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.oracle.energies import compute_oracle_energies +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.samples.sampling import create_batch_of_samples +from crystal_diffusion.utils.logging_utils import (get_git_hash, + setup_console_logger) + +logger = logging.getLogger(__name__) + + +def main(args: Optional[Any] = None): + """Load a diffusion model and draw samples. + + This main.py file is meant to be called using the cli. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + help="config file with sampling parameters in yaml format.", + ) + parser.add_argument( + "--checkpoint", required=True, help="path to checkpoint model to be loaded." + ) + parser.add_argument( + "--output", required=True, help="path to outputs - will store files here" + ) + parser.add_argument( + "--device", default="gpu", help="Device to use. Defaults to cuda." + ) + args = parser.parse_args(args) + if os.path.exists(args.output): + logger.info(f"WARNING: the output directory {args.output} already exists!") + else: + os.makedirs(args.output) + + setup_console_logger(experiment_dir=args.output) + assert os.path.exists( + args.checkpoint + ), f"The path {args.checkpoint} does not exist. Cannot go on." + + script_location = os.path.realpath(__file__) + git_hash = get_git_hash(script_location) + hostname = socket.gethostname() + logger.info("Sampling Experiment info:") + logger.info(f" Hostname : {hostname}") + logger.info(f" Git Hash : {git_hash}") + logger.info(f" Checkpoint : {args.checkpoint}") + + # Very opinionated logger, which writes to the output folder. + logger.info(f"Start Generating Samples with checkpoint {args.checkpoint}") + + hyper_params = load_and_backup_hyperparameters( + config_file_path=args.config, output_directory=args.output + ) + + device = torch.device(args.device) + noise_parameters, sampling_parameters = extract_and_validate_parameters( + hyper_params + ) + + create_samples_and_write_to_disk( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + device=device, + checkpoint_path=args.checkpoint, + output_path=args.output, + ) + + +def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]): + """Extract and validate parameters. + + Args: + hyper_params : Dictionary of hyper-parameters for drawing samples. + + Returns: + noise_parameters: object that defines the noise schedule + sampling_parameters: object that defines how to draw samples, and how many. + """ + assert ( + "noise" in hyper_params + ), "The noise parameters must be defined to draw samples." + noise_parameters = NoiseParameters(**hyper_params["noise"]) + + assert ( + "sampling" in hyper_params + ), "The sampling parameters must be defined to draw samples." + sampling_parameters = load_sampling_parameters(hyper_params["sampling"]) + + return noise_parameters, sampling_parameters + + +def get_sigma_normalized_score_network( + checkpoint_path: Union[str, Path] +) -> ScoreNetwork: + """Get sigma-normalized score network. + + Args: + checkpoint_path : path where the checkpoint is written. + + Returns: + sigma_normalized score network: read from the checkpoint. + """ + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + return sigma_normalized_score_network + + +def create_samples_and_write_to_disk( + noise_parameters: NoiseParameters, + sampling_parameters: SamplingParameters, + device: torch.device, + checkpoint_path: Union[str, Path], + output_path: Union[str, Path], +): + """Create Samples and write to disk. + + Method that drives the creation of samples. + + Args: + noise_parameters: object that defines the noise schedule + sampling_parameters: object that defines how to draw samples, and how many. + device: which device should be used to draw samples. + checkpoint_path : path to checkpoint of model to be loaded. + output_path: where the outputs should be written. + + Returns: + None + """ + sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path) + + logger.info("Instantiate generator...") + position_generator = instantiate_generator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + ) + + logger.info("Generating samples...") + with torch.no_grad(): + samples_batch = create_batch_of_samples( + generator=position_generator, + sampling_parameters=sampling_parameters, + device=device, + ) + logger.info("Done Generating Samples.") + + logger.info("Writing samples to disk...") + output_directory = Path(output_path) + with open(output_directory / "samples.pt", "wb") as fd: + torch.save(samples_batch, fd) + + logger.info("Compute energy from Oracle...") + sample_energies = compute_oracle_energies(samples_batch) + + logger.info("Writing energies to disk...") + with open(output_directory / "energies.pt", "wb") as fd: + torch.save(sample_energies, fd) + + if sampling_parameters.record_samples: + logger.info("Writing sampling trajectories to disk...") + position_generator.sample_trajectory_recorder.write_to_pickle( + output_directory / "trajectories.pt" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_sample_diffusion.py b/tests/test_sample_diffusion.py new file mode 100644 index 00000000..aa652e40 --- /dev/null +++ b/tests/test_sample_diffusion.py @@ -0,0 +1,157 @@ +import dataclasses + +import pytest +import torch +import yaml + +from crystal_diffusion import sample_diffusion +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.models.loss import MSELossParameters +from crystal_diffusion.models.optimizer import OptimizerParameters +from crystal_diffusion.models.position_diffusion_lightning_model import ( + PositionDiffusionLightningModel, PositionDiffusionParameters) +from crystal_diffusion.models.score_networks.mlp_score_network import \ + MLPScoreNetworkParameters +from crystal_diffusion.namespace import RELATIVE_COORDINATES +from crystal_diffusion.samplers.variance_sampler import NoiseParameters + + +@pytest.fixture() +def spatial_dimension(): + return 3 + + +@pytest.fixture() +def number_of_atoms(): + return 8 + + +@pytest.fixture() +def number_of_samples(): + return 12 + + +@pytest.fixture() +def cell_dimensions(): + return [5.1, 6.2, 7.3] + + +@pytest.fixture(params=[True, False]) +def record_samples(request): + return request.param + + +@pytest.fixture() +def noise_parameters(): + return NoiseParameters(total_time_steps=10) + + +@pytest.fixture() +def sampling_parameters( + number_of_atoms, + spatial_dimension, + number_of_samples, + cell_dimensions, + record_samples, +): + return PredictorCorrectorSamplingParameters( + number_of_corrector_steps=1, + spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + number_of_samples=number_of_samples, + cell_dimensions=cell_dimensions, + record_samples=record_samples, + ) + + +@pytest.fixture() +def sigma_normalized_score_network(number_of_atoms, noise_parameters): + score_network_parameters = MLPScoreNetworkParameters( + number_of_atoms=number_of_atoms, + embedding_dimensions_size=8, + n_hidden_dimensions=2, + hidden_dimensions_size=16, + ) + + diffusion_params = PositionDiffusionParameters( + score_network_parameters=score_network_parameters, + loss_parameters=MSELossParameters(), + optimizer_parameters=OptimizerParameters(name="adam", learning_rate=1e-3), + scheduler_parameters=None, + noise_parameters=noise_parameters, + diffusion_sampling_parameters=None, + ) + + model = PositionDiffusionLightningModel(diffusion_params) + return model.sigma_normalized_score_network + + +@pytest.fixture() +def config_path(tmp_path, noise_parameters, sampling_parameters): + config_path = str(tmp_path / "test_config.yaml") + + config = dict( + noise=dataclasses.asdict(noise_parameters), + sampling=dataclasses.asdict(sampling_parameters), + ) + + with open(config_path, "w") as fd: + yaml.dump(config, fd) + + return config_path + + +@pytest.fixture() +def checkpoint_path(tmp_path): + path_to_checkpoint = tmp_path / "fake_checkpoint.pt" + with open(path_to_checkpoint, "w") as fd: + fd.write("This is a dummy checkpoint file.") + return path_to_checkpoint + + +@pytest.fixture() +def output_path(tmp_path): + output = tmp_path / "output" + return output + + +@pytest.fixture() +def args(config_path, checkpoint_path, output_path): + """Input arguments for main.""" + input_args = [ + f"--config={config_path}", + f"--checkpoint={checkpoint_path}", + f"--output={output_path}", + "--device=cpu", + ] + + return input_args + + +def test_sample_diffusion( + mocker, + args, + sigma_normalized_score_network, + output_path, + number_of_samples, + number_of_atoms, + spatial_dimension, + record_samples, +): + mocker.patch( + "crystal_diffusion.sample_diffusion.get_sigma_normalized_score_network", + return_value=sigma_normalized_score_network, + ) + + sample_diffusion.main(args) + + assert (output_path / "samples.pt").exists() + samples = torch.load(output_path / "samples.pt") + assert samples[RELATIVE_COORDINATES].shape == ( + number_of_samples, + number_of_atoms, + spatial_dimension, + ) + + assert (output_path / "trajectories.pt").exists() == record_samples From 8d315cdabaf4a9d2549936e6e88b794714d5b12f Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 17 Oct 2024 16:57:44 -0400 Subject: [PATCH 16/18] More rigid radial cutoff parameters in EGNN. --- .../score_networks/egnn_score_network.py | 12 +++++++-- .../diffusion/config_diffusion_egnn.yaml | 1 + .../score_network/test_score_network.py | 25 ++++++++++++++++--- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/crystal_diffusion/models/score_networks/egnn_score_network.py b/crystal_diffusion/models/score_networks/egnn_score_network.py index 24f30d7a..17f6d39e 100644 --- a/crystal_diffusion/models/score_networks/egnn_score_network.py +++ b/crystal_diffusion/models/score_networks/egnn_score_network.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import AnyStr, Dict +from typing import AnyStr, Dict, Union import einops import torch @@ -31,7 +31,7 @@ class EGNNScoreNetworkParameters(ScoreNetworkParameters): message_agg: str = "mean" n_layers: int = 4 edges: str = 'fully_connected' - radial_cutoff: float = 4.0 + radial_cutoff: Union[float, None] = None drop_duplicate_edges: bool = True @@ -60,7 +60,15 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters): self.edges = hyper_params.edges assert self.edges in ["fully_connected", "radial_cutoff"], \ f'Edges type should be fully_connected or radial_cutoff. Got {self.edges}' + self.radial_cutoff = hyper_params.radial_cutoff + + if self.edges == "fully_connected": + assert self.radial_cutoff is None, "Specifying a radial cutoff is inconsistent with edges=fully_connected." + else: + assert type(self.radial_cutoff) is float, \ + "A floating point value for the radial cutoff is needed for edges=radial_cutoff." + self.drop_duplicate_edges = hyper_params.drop_duplicate_edges self.egnn = EGNN( diff --git a/examples/config_files/diffusion/config_diffusion_egnn.yaml b/examples/config_files/diffusion/config_diffusion_egnn.yaml index 31ab42b7..4d04b0b4 100644 --- a/examples/config_files/diffusion/config_diffusion_egnn.yaml +++ b/examples/config_files/diffusion/config_diffusion_egnn.yaml @@ -33,6 +33,7 @@ model: normalize: True residual: True tanh: False + edges: fully_connected noise: total_time_steps: 1000 sigma_min: 0.0001 diff --git a/tests/models/score_network/test_score_network.py b/tests/models/score_network/test_score_network.py index b5445cdf..fcb54d0c 100644 --- a/tests/models/score_network/test_score_network.py +++ b/tests/models/score_network/test_score_network.py @@ -301,7 +301,6 @@ def score_network(self, score_network_parameters): return DiffusionMACEScoreNetwork(score_network_parameters) -@pytest.mark.parametrize("spatial_dimension", [3]) class TestEGNNScoreNetwork(BaseTestScoreNetwork): @pytest.fixture(scope="class", autouse=True) @@ -314,8 +313,21 @@ def set_default_type_to_float64(self): torch.set_default_dtype(torch.float32) @pytest.fixture() - def score_network_parameters(self): - return EGNNScoreNetworkParameters() # Use the defaults + def spatial_dimension(self): + return 3 + + @pytest.fixture() + def basis_vectors(self, batch_size, spatial_dimension): + # The basis vectors should form a cube in order to test the equivariance of the current implementation + # of the EGNN model. The octaheral point group only applies in this case! + acell = 5.5 + cubes = torch.stack([torch.diag(acell * torch.ones(spatial_dimension)) for _ in range(batch_size)]) + return cubes + + @pytest.fixture(params=[("fully_connected", None), ("radial_cutoff", 3.0)]) + def score_network_parameters(self, request): + edges, radial_cutoff = request.param + return EGNNScoreNetworkParameters(edges=edges, radial_cutoff=radial_cutoff) @pytest.fixture() def score_network(self, score_network_parameters): @@ -334,6 +346,13 @@ def octahedral_point_group_symmetries(self): return symmetries + @pytest.mark.parametrize("edges, radial_cutoff", [("fully_connected", 3.0), ("radial_cutoff", None)]) + def test_score_network_parameters(self, edges, radial_cutoff): + score_network_parameters = EGNNScoreNetworkParameters(edges=edges, radial_cutoff=radial_cutoff) + with pytest.raises(AssertionError): + # Check that the code crashes when inconsistent parameters are fed in. + EGNNScoreNetwork(score_network_parameters) + def test_create_block_diagonal_projection_matrices(self, score_network, spatial_dimension): expected_matrices = [] for space_idx in range(spatial_dimension): From fe7e86942ae46cb0b37627b9902da835567a4944 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 17 Oct 2024 19:42:06 -0400 Subject: [PATCH 17/18] Cap the amount of samples that will be kept in memory. --- .../metrics/kolmogorov_smirnov_metrics.py | 22 +++++++++++++++---- .../position_diffusion_lightning_model.py | 4 +++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py b/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py index 5141ff50..7af4e714 100644 --- a/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py +++ b/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py @@ -7,23 +7,37 @@ class KolmogorovSmirnovMetrics: """Kolmogorov Smirnov metrics.""" - def __init__(self): - """Init method.""" + def __init__(self, maximum_number_of_samples: int = 1_000_000): + """Init method. + + Args: + maximum_number_of_samples : maximum number of samples that will be aggregated. This is to avoid + memory use explosion. + """ self.reference_samples_metric = CatMetric() self.predicted_samples_metric = CatMetric() + self.maximum_count = maximum_number_of_samples + self.reference_count = 0 + self.predicted_count = 0 def register_reference_samples(self, reference_samples): """Register reference samples.""" - self.reference_samples_metric.update(reference_samples) + if self.reference_count < self.maximum_count: + self.reference_count += len(reference_samples) + self.reference_samples_metric.update(reference_samples) def register_predicted_samples(self, predicted_samples): """Register predicted samples.""" - self.predicted_samples_metric.update(predicted_samples) + if self.predicted_count < self.maximum_count: + self.predicted_count += len(predicted_samples) + self.predicted_samples_metric.update(predicted_samples) def reset(self): """reset.""" self.reference_samples_metric.reset() self.predicted_samples_metric.reset() + self.reference_count = 0 + self.predicted_count = 0 def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]: """Compute Kolmogorov Smirnov Distance. diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 43974c4f..2fc5e6e1 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -387,6 +387,7 @@ def on_validation_epoch_end(self) -> None: logger.info(" * Registering sample energies") self.energy_ks_metric.register_predicted_samples(sample_energies.cpu()) + logger.info(" * Computing KS distance for energies") ( ks_distance, p_value, @@ -400,7 +401,7 @@ def on_validation_epoch_end(self) -> None: self.log( "validation_ks_p_value_energy", p_value, on_step=False, on_epoch=True ) - logger.info(" * Done logging sample energies") + logger.info(" * Done logging KS distance for energies") if self.draw_samples and self.metrics_parameters.compute_structure_factor: logger.info(" * Computing sample distances") @@ -413,6 +414,7 @@ def on_validation_epoch_end(self) -> None: logger.info(" * Registering sample distances") self.structure_ks_metric.register_predicted_samples(sample_distances.cpu()) + logger.info(" * Computing KS distance for distances") ( ks_distance, p_value, From 0e10484fdd3f0dbb9be405a478fcb40c342835c8 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sat, 19 Oct 2024 07:22:54 -0400 Subject: [PATCH 18/18] Fix bug. --- crystal_diffusion/sample_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/sample_diffusion.py b/crystal_diffusion/sample_diffusion.py index 81f09e8c..ea598c63 100644 --- a/crystal_diffusion/sample_diffusion.py +++ b/crystal_diffusion/sample_diffusion.py @@ -47,7 +47,7 @@ def main(args: Optional[Any] = None): "--output", required=True, help="path to outputs - will store files here" ) parser.add_argument( - "--device", default="gpu", help="Device to use. Defaults to cuda." + "--device", default="cuda", help="Device to use. Defaults to cuda." ) args = parser.parse_args(args) if os.path.exists(args.output):