diff --git a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py index 02340d4d..64d5c4fc 100644 --- a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py +++ b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py @@ -49,7 +49,7 @@ def compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> n logger.info("Compute energy from Oracle") with tempfile.TemporaryDirectory() as tmp_work_dir: - for positions, box in zip(batch_cartesian_positions.numpy(), batched_unit_cells.numpy()): + for positions, box in zip(batch_cartesian_positions.cpu().numpy(), batched_unit_cells.cpu().numpy()): energy, forces = get_energy_and_forces_from_lammps(positions, box, self.atom_types, diff --git a/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py b/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py index 5141ff50..7af4e714 100644 --- a/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py +++ b/crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py @@ -7,23 +7,37 @@ class KolmogorovSmirnovMetrics: """Kolmogorov Smirnov metrics.""" - def __init__(self): - """Init method.""" + def __init__(self, maximum_number_of_samples: int = 1_000_000): + """Init method. + + Args: + maximum_number_of_samples : maximum number of samples that will be aggregated. This is to avoid + memory use explosion. 
+ """ self.reference_samples_metric = CatMetric() self.predicted_samples_metric = CatMetric() + self.maximum_count = maximum_number_of_samples + self.reference_count = 0 + self.predicted_count = 0 def register_reference_samples(self, reference_samples): """Register reference samples.""" - self.reference_samples_metric.update(reference_samples) + if self.reference_count < self.maximum_count: + self.reference_count += len(reference_samples) + self.reference_samples_metric.update(reference_samples) def register_predicted_samples(self, predicted_samples): """Register predicted samples.""" - self.predicted_samples_metric.update(predicted_samples) + if self.predicted_count < self.maximum_count: + self.predicted_count += len(predicted_samples) + self.predicted_samples_metric.update(predicted_samples) def reset(self): """reset.""" self.reference_samples_metric.reset() self.predicted_samples_metric.reset() + self.reference_count = 0 + self.predicted_count = 0 def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]: """Compute Kolmogorov Smirnov Distance. 
diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 43974c4f..2fc5e6e1 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -387,6 +387,7 @@ def on_validation_epoch_end(self) -> None: logger.info(" * Registering sample energies") self.energy_ks_metric.register_predicted_samples(sample_energies.cpu()) + logger.info(" * Computing KS distance for energies") ( ks_distance, p_value, @@ -400,7 +401,7 @@ def on_validation_epoch_end(self) -> None: self.log( "validation_ks_p_value_energy", p_value, on_step=False, on_epoch=True ) - logger.info(" * Done logging sample energies") + logger.info(" * Done logging KS distance for energies") if self.draw_samples and self.metrics_parameters.compute_structure_factor: logger.info(" * Computing sample distances") @@ -413,6 +414,7 @@ def on_validation_epoch_end(self) -> None: logger.info(" * Registering sample distances") self.structure_ks_metric.register_predicted_samples(sample_distances.cpu()) + logger.info(" * Computing KS distance for distances") ( ks_distance, p_value, diff --git a/crystal_diffusion/models/score_networks/egnn_score_network.py b/crystal_diffusion/models/score_networks/egnn_score_network.py index 24f30d7a..17f6d39e 100644 --- a/crystal_diffusion/models/score_networks/egnn_score_network.py +++ b/crystal_diffusion/models/score_networks/egnn_score_network.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import AnyStr, Dict +from typing import AnyStr, Dict, Union import einops import torch @@ -31,7 +31,7 @@ class EGNNScoreNetworkParameters(ScoreNetworkParameters): message_agg: str = "mean" n_layers: int = 4 edges: str = 'fully_connected' - radial_cutoff: float = 4.0 + radial_cutoff: Union[float, None] = None drop_duplicate_edges: bool = True @@ -60,7 +60,15 @@ def __init__(self, hyper_params: 
EGNNScoreNetworkParameters): self.edges = hyper_params.edges assert self.edges in ["fully_connected", "radial_cutoff"], \ f'Edges type should be fully_connected or radial_cutoff. Got {self.edges}' + self.radial_cutoff = hyper_params.radial_cutoff + + if self.edges == "fully_connected": + assert self.radial_cutoff is None, "Specifying a radial cutoff is inconsistent with edges=fully_connected." + else: + assert type(self.radial_cutoff) is float, \ + "A floating point value for the radial cutoff is needed for edges=radial_cutoff." + self.drop_duplicate_edges = hyper_params.drop_duplicate_edges self.egnn = EGNN( diff --git a/crystal_diffusion/sample_diffusion.py b/crystal_diffusion/sample_diffusion.py new file mode 100644 index 00000000..ea598c63 --- /dev/null +++ b/crystal_diffusion/sample_diffusion.py @@ -0,0 +1,193 @@ +"""Sample Diffusion. + +This script is the entry point to draw samples from a pre-trained model checkpoint. +""" +import argparse +import logging +import os +import socket +from pathlib import Path +from typing import Any, AnyStr, Dict, Optional, Union + +import torch + +from crystal_diffusion.generators.instantiate_generator import \ + instantiate_generator +from crystal_diffusion.generators.load_sampling_parameters import \ + load_sampling_parameters +from crystal_diffusion.generators.position_generator import SamplingParameters +from crystal_diffusion.main_utils import load_and_backup_hyperparameters +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.oracle.energies import compute_oracle_energies +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.samples.sampling import create_batch_of_samples +from crystal_diffusion.utils.logging_utils import (get_git_hash, + setup_console_logger) + +logger = logging.getLogger(__name__) + + +def main(args: Optional[Any] = 
None): + """Load a diffusion model and draw samples. + + This main.py file is meant to be called using the cli. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + help="config file with sampling parameters in yaml format.", + ) + parser.add_argument( + "--checkpoint", required=True, help="path to checkpoint model to be loaded." + ) + parser.add_argument( + "--output", required=True, help="path to outputs - will store files here" + ) + parser.add_argument( + "--device", default="cuda", help="Device to use. Defaults to cuda." + ) + args = parser.parse_args(args) + if os.path.exists(args.output): + logger.info(f"WARNING: the output directory {args.output} already exists!") + else: + os.makedirs(args.output) + + setup_console_logger(experiment_dir=args.output) + assert os.path.exists( + args.checkpoint + ), f"The path {args.checkpoint} does not exist. Cannot go on." + + script_location = os.path.realpath(__file__) + git_hash = get_git_hash(script_location) + hostname = socket.gethostname() + logger.info("Sampling Experiment info:") + logger.info(f" Hostname : {hostname}") + logger.info(f" Git Hash : {git_hash}") + logger.info(f" Checkpoint : {args.checkpoint}") + + # Very opinionated logger, which writes to the output folder. + logger.info(f"Start Generating Samples with checkpoint {args.checkpoint}") + + hyper_params = load_and_backup_hyperparameters( + config_file_path=args.config, output_directory=args.output + ) + + device = torch.device(args.device) + noise_parameters, sampling_parameters = extract_and_validate_parameters( + hyper_params + ) + + create_samples_and_write_to_disk( + noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters, + device=device, + checkpoint_path=args.checkpoint, + output_path=args.output, + ) + + +def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]): + """Extract and validate parameters. 
+ + Args: + hyper_params : Dictionary of hyper-parameters for drawing samples. + + Returns: + noise_parameters: object that defines the noise schedule + sampling_parameters: object that defines how to draw samples, and how many. + """ + assert ( + "noise" in hyper_params + ), "The noise parameters must be defined to draw samples." + noise_parameters = NoiseParameters(**hyper_params["noise"]) + + assert ( + "sampling" in hyper_params + ), "The sampling parameters must be defined to draw samples." + sampling_parameters = load_sampling_parameters(hyper_params["sampling"]) + + return noise_parameters, sampling_parameters + + +def get_sigma_normalized_score_network( + checkpoint_path: Union[str, Path] +) -> ScoreNetwork: + """Get sigma-normalized score network. + + Args: + checkpoint_path : path where the checkpoint is written. + + Returns: + sigma_normalized score network: read from the checkpoint. + """ + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + return sigma_normalized_score_network + + +def create_samples_and_write_to_disk( + noise_parameters: NoiseParameters, + sampling_parameters: SamplingParameters, + device: torch.device, + checkpoint_path: Union[str, Path], + output_path: Union[str, Path], +): + """Create Samples and write to disk. + + Method that drives the creation of samples. + + Args: + noise_parameters: object that defines the noise schedule + sampling_parameters: object that defines how to draw samples, and how many. + device: which device should be used to draw samples. + checkpoint_path : path to checkpoint of model to be loaded. + output_path: where the outputs should be written. 
+ + Returns: + None + """ + sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path) + + logger.info("Instantiate generator...") + position_generator = instantiate_generator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + ) + + logger.info("Generating samples...") + with torch.no_grad(): + samples_batch = create_batch_of_samples( + generator=position_generator, + sampling_parameters=sampling_parameters, + device=device, + ) + logger.info("Done Generating Samples.") + + logger.info("Writing samples to disk...") + output_directory = Path(output_path) + with open(output_directory / "samples.pt", "wb") as fd: + torch.save(samples_batch, fd) + + logger.info("Compute energy from Oracle...") + sample_energies = compute_oracle_energies(samples_batch) + + logger.info("Writing energies to disk...") + with open(output_directory / "energies.pt", "wb") as fd: + torch.save(sample_energies, fd) + + if sampling_parameters.record_samples: + logger.info("Writing sampling trajectories to disk...") + position_generator.sample_trajectory_recorder.write_to_pickle( + output_directory / "trajectories.pt" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/config_files/diffusion/config_diffusion_egnn.yaml b/examples/config_files/diffusion/config_diffusion_egnn.yaml index 31ab42b7..4d04b0b4 100644 --- a/examples/config_files/diffusion/config_diffusion_egnn.yaml +++ b/examples/config_files/diffusion/config_diffusion_egnn.yaml @@ -33,6 +33,7 @@ model: normalize: True residual: True tanh: False + edges: fully_connected noise: total_time_steps: 1000 sigma_min: 0.0001 diff --git a/experiment_analysis/dataset_analysis/dataset_covariance.py b/experiment_analysis/dataset_analysis/dataset_covariance.py new file mode 100644 index 00000000..eed61e7e --- /dev/null +++ b/experiment_analysis/dataset_analysis/dataset_covariance.py @@ -0,0 +1,78 @@ +"""Effective 
Dataset Variance. + +The goal of this script is to compute the effective "sigma_d" of the +actual datasets, that is, the standard deviation of the displacement +from equilibrium, in fractional coordinates. +""" +import logging + +import einops +import torch +from tqdm import tqdm + +from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR +from crystal_diffusion.data.diffusion.data_loader import ( + LammpsForDiffusionDataModule, LammpsLoaderParameters) +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +logger = logging.getLogger(__name__) +dataset_name = 'si_diffusion_2x2x2' +# dataset_name = 'si_diffusion_1x1x1' + +output_dir = ANALYSIS_RESULTS_DIR / "covariances" +output_dir.mkdir(exist_ok=True) + + +if dataset_name == 'si_diffusion_1x1x1': + max_atom = 8 + translation = torch.tensor([0.125, 0.125, 0.125]) +elif dataset_name == 'si_diffusion_2x2x2': + max_atom = 64 + translation = torch.tensor([0.0625, 0.0625, 0.0625]) + +lammps_run_dir = DATA_DIR / dataset_name +processed_dataset_dir = lammps_run_dir / "processed" + +cache_dir = lammps_run_dir / "cache" + +data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom) + +if __name__ == '__main__': + setup_analysis_logger() + logger.info(f"Computing the covariance matrix for {dataset_name}") + + datamodule = LammpsForDiffusionDataModule( + lammps_run_dir=lammps_run_dir, + processed_dataset_dir=processed_dataset_dir, + hyper_params=data_params, + working_cache_dir=cache_dir, + ) + datamodule.setup() + + train_dataset = datamodule.train_dataset + + list_means = [] + for batch in tqdm(datamodule.train_dataloader(), "Mean"): + x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation) + list_means.append(x.mean(dim=0)) + + # Drop the last batch, which might not have dimension batch_size + x0 = torch.stack(list_means[:-1]).mean(dim=0) + + 
list_covariances = [] + list_sizes = [] + for batch in tqdm(datamodule.train_dataloader(), "displacements"): + x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation) + list_sizes.append(x.shape[0]) + displacements = einops.rearrange(x - x0, "batch natoms space -> batch (natoms space)") + covariance = (displacements[:, None, :] * displacements[:, :, None]).sum(dim=0) + list_covariances.append(covariance) + + covariance = torch.stack(list_covariances).sum(dim=0) / sum(list_sizes) + + output_file = output_dir / f"covariance_{dataset_name}.pkl" + logger.info(f"Writing to file {output_file}...") + with open(output_file, 'wb') as fd: + torch.save(covariance, fd) diff --git a/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py b/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py new file mode 100644 index 00000000..3d719b22 --- /dev/null +++ b/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py @@ -0,0 +1,86 @@ +"""Silicon phonon Density of States. + +The displacement covariance is related to the phonon dynamical matrix. +Here we extract the corresponding phonon density of state, based on this covariance, +to see if the energy scales match up. +""" +import matplotlib.pyplot as plt +import torch + +from crystal_diffusion import ANALYSIS_RESULTS_DIR +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH + +plt.style.use(PLOT_STYLE_PATH) + +# Define some constants. 
+kelvin_in_Ha = 0.000_003_166_78 +T_in_kelvin = 300.0 +bohr_in_angst = 0.529177 +Si_mass = 28.0855 +proton_mass = 1836.152673426 + +Ha_in_meV = 27211.0 + +THz_in_meV = 4.136 + +acell = 5.43 + + +dataset_name_2x2x2 = "si_diffusion_2x2x2" +dataset_name_1x1x1 = "si_diffusion_1x1x1" + +output_dir = ANALYSIS_RESULTS_DIR / "covariances" +output_dir.mkdir(exist_ok=True) + + +if __name__ == "__main__": + kBT = kelvin_in_Ha * T_in_kelvin + a = acell / bohr_in_angst + + M = Si_mass * proton_mass + + constant_1x1x1 = M * a**2 / kBT / Ha_in_meV**2 + constant_2x2x2 = M * (2.0 * a) ** 2 / kBT / Ha_in_meV**2 + + covariance_file_1x1x1 = output_dir / f"covariance_{dataset_name_1x1x1}.pkl" + sigma_1x1x1 = torch.load(covariance_file_1x1x1) + sigma_inv_1x1x1 = torch.linalg.pinv(sigma_1x1x1) + Omega_1x1x1 = sigma_inv_1x1x1 / constant_1x1x1 + omega2_1x1x1 = torch.linalg.eigvalsh(Omega_1x1x1) + list_omega_in_meV_1x1x1 = torch.sqrt(torch.abs(omega2_1x1x1)) + + covariance_file_2x2x2 = output_dir / f"covariance_{dataset_name_2x2x2}.pkl" + sigma_2x2x2 = torch.load(covariance_file_2x2x2) + sigma_inv_2x2x2 = torch.linalg.pinv(sigma_2x2x2) + Omega_2x2x2 = sigma_inv_2x2x2 / constant_2x2x2 + omega2_2x2x2 = torch.linalg.eigvalsh(Omega_2x2x2) + list_omega_in_meV_2x2x2 = torch.sqrt(torch.abs(omega2_2x2x2)) + + max_hw = torch.max(list_omega_in_meV_2x2x2) / THz_in_meV + + bins = torch.linspace(0.0, max_hw, 50) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Eigenvalues of Dynamical Matrix, from Displacement Covariance") + ax = fig.add_subplot(111) + + ax.set_xlim(0, max_hw + 1) + ax.set_xlabel(r"$\hbar \omega$ (THz)") + ax.set_ylabel("Count") + + ax.hist( + list_omega_in_meV_1x1x1 / THz_in_meV, + bins=bins, + label="Si 1x1x1", + color="green", + alpha=0.5, + ) + ax.hist( + list_omega_in_meV_2x2x2 / THz_in_meV, + bins=bins, + label="Si 2x2x2", + color="blue", + alpha=0.25, + ) + ax.legend(loc=0) + plt.show() diff --git 
a/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py new file mode 100644 index 00000000..700b0b4c --- /dev/null +++ b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py @@ -0,0 +1,258 @@ +import logging + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.exploring_langevin_generator.generate_sample_energies import \ + EnergyCalculator +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.generators.langevin_generator import LangevinGenerator +from crystal_diffusion.generators.ode_position_generator import ( + ExplodingVarianceODEPositionGenerator, ODESamplingParameters) +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.generators.sde_position_generator import ( + ExplodingVarianceSDEPositionGenerator, SDESamplingParameters) +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + + +class ForcedStartingPointLangevinGenerator(LangevinGenerator): + """Langevin Generator with Forced Starting point.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + 
noise_parameters, sampling_parameters, sigma_normalized_score_network + ) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) + return relative_coordinates + + +class ForcedStartingPointODEPositionGenerator(ExplodingVarianceODEPositionGenerator): + """Forced starting point ODE position generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: ODESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) + return relative_coordinates + + +class ForcedStartingPointSDEPositionGenerator(ExplodingVarianceSDEPositionGenerator): + """Forced Starting Point SDE position generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: SDESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised 
distribution.""" + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) + return relative_coordinates + + +checkpoint_path = ( + "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" + "output/last_model/last_model-epoch=049-step=039100.ckpt" +) + +spatial_dimension = 3 +number_of_atoms = 8 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 5.43 + +total_time_steps = 1000 +number_of_corrector_steps = 10 +epsilon = 2.0e-7 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + corrector_step_epsilon=epsilon, + sigma_min=0.0001, + sigma_max=0.2, +) +number_of_samples = 1000 + +base_sampling_parameters_dict = dict( + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + cell_dimensions=[acell, acell, acell], + number_of_samples=number_of_samples, +) + +ode_sampling_parameters = ODESamplingParameters( + absolute_solver_tolerance=1.0e-5, + relative_solver_tolerance=1.0e-5, + **base_sampling_parameters_dict, +) + +# Fiddling with SDE is PITA. Also, is there a bug in there? 
+sde_sampling_parameters = SDESamplingParameters( + adaptative=False, **base_sampling_parameters_dict +) + + +langevin_sampling_parameters = PredictorCorrectorSamplingParameters( + number_of_corrector_steps=number_of_corrector_steps, **base_sampling_parameters_dict +) + +device = torch.device("cuda") +if __name__ == "__main__": + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + equilibrium_relative_coordinates = ( + torch.from_numpy(get_silicon_supercell(supercell_factor=1)) + .to(torch.float32) + .to(device) + ) + + ode_generator = ForcedStartingPointODEPositionGenerator( + noise_parameters=noise_parameters, + sampling_parameters=ode_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates, + ) + + sde_generator = ForcedStartingPointSDEPositionGenerator( + noise_parameters=noise_parameters, + sampling_parameters=sde_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates, + ) + + langevin_generator = ForcedStartingPointLangevinGenerator( + noise_parameters=noise_parameters, + sampling_parameters=langevin_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates, + ) + + unit_cells = einops.repeat(basis_vectors, "s1 s2 -> b s1 s2", b=number_of_samples) + + ode_samples = ode_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) + sde_samples = sde_generator.sample( + number_of_samples=number_of_samples, 
device=device, unit_cell=unit_cells + ) + + langevin_samples = langevin_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) + + energy_calculator = EnergyCalculator( + unit_cell=basis_vectors, number_of_atoms=number_of_atoms + ) + + ode_energies = energy_calculator.compute_oracle_energies(ode_samples) + sde_energies = energy_calculator.compute_oracle_energies(sde_samples) + langevin_energies = energy_calculator.compute_oracle_energies(langevin_samples) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Energies of Samples Drawn from Equilibrium Coordinates") + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + ax1.set_title("Zoom In") + ax2.set_title("Broad View") + + list_q = np.linspace(0, 1, 101) + + list_energies = [ode_energies, sde_energies, langevin_energies] + + list_colors = ["blue", "red", "green"] + + langevin_label = ( + f"LANGEVIN (time steps = {total_time_steps}, " + f"corrector steps = {number_of_corrector_steps}, epsilon ={epsilon: 5.2e})" + ) + + list_labels = ["ODE", "SDE", langevin_label] + + for ax in [ax1, ax2]: + for energies, label in zip(list_energies, list_labels): + quantiles = np.quantile(energies, list_q) + ax.plot(100 * list_q, quantiles, "-", label=label) + + ax.fill_between( + [0, 100], + y1=-34.6, + y2=-34.1, + color="yellow", + alpha=0.25, + label="Training Energy Range", + ) + + ax.set_xlabel("Quantile (%)") + ax.set_ylabel("Energy (eV)") + ax.set_xlim(-0.1, 100.1) + ax1.set_ylim(-35, -34.0) + ax2.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) + + fig.tight_layout() + + plt.show() diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py new file mode 100644 index 00000000..b48669be --- /dev/null +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -0,0 +1,217 @@ +import logging +from pathlib import Path + 
+import einops +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.func import jacrev +from tqdm import tqdm + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import \ + get_normalized_score_function + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + +system = "Si_1x1x1" + +if system == "Si_1x1x1": + checkpoint_path = Path( + "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/" + "last_model/last_model-epoch=049-step=039100.ckpt" + ) + number_of_atoms = 8 + acell = 5.43 + supercell_factor = 1 + + hessian_batch_size = 10 + + +elif system == "Si_2x2x2": + pickle_path = Path("/home/mila/r/rousseab/scratch/checkpoints/sota_egnn_2x2x2.pkl") + number_of_atoms = 64 + acell = 10.86 + supercell_factor = 2 + hessian_batch_size = 1 + +spatial_dimension = 3 +atom_types = np.ones(number_of_atoms, dtype=int) + +total_time_steps = 1000 +sigma_min = 0.0001 +sigma_max = 0.2 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=sigma_min, + sigma_max=sigma_max, +) + + +nsteps = 501 + + +device = torch.device("cuda") +if __name__ == "__main__": + variance_calculator = ExplodingVariance(noise_parameters) + + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + equilibrium_relative_coordinates = ( + torch.from_numpy(get_silicon_supercell(supercell_factor=supercell_factor)) + .to(torch.float32) + .to(device) + ) + + logger.info("Loading 
checkpoint...") + + if system == "Si_1x1x1": + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + elif system == "Si_2x2x2": + sigma_normalized_score_network = torch.load(pickle_path) + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + normalized_score_function = get_normalized_score_function( + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + basis_vectors=basis_vectors, + ) + + times = torch.linspace(1, 0, nsteps).unsqueeze(-1) + sigmas = variance_calculator.get_sigma(times) + g2 = variance_calculator.get_g_squared(times) + + prefactor = -g2 / sigmas + + relative_coordinates = einops.repeat( + equilibrium_relative_coordinates, "n s -> b n s", b=nsteps + ) + + batch_hessian_function = jacrev(normalized_score_function, argnums=0) + + list_flat_hessians = [] + for x, t in tqdm( + zip( + torch.split(relative_coordinates, hessian_batch_size), + torch.split(times, hessian_batch_size), + ), + "Hessian", + ): + batch_hessian = batch_hessian_function(x, t) + flat_hessian = einops.rearrange( + torch.diagonal(batch_hessian, dim1=0, dim2=3), + "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)", + ) + list_flat_hessians.append(flat_hessian) + + flat_hessian = torch.concat(list_flat_hessians) + + p = einops.repeat( + prefactor, + "b 1 -> b d1 d2", + d1=number_of_atoms * spatial_dimension, + d2=number_of_atoms * spatial_dimension, + ).to(flat_hessian) + + normalized_hessian = p * flat_hessian + + eigenvalues, eigenvectors = torch.linalg.eigh(normalized_hessian) + eigenvalues = eigenvalues.cpu().transpose(1, 0) + + small_count = (eigenvalues < 5e-4).sum(dim=0) + list_times = times.flatten().cpu() + list_sigmas = sigmas.flatten().cpu() + + fig = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) + fig.suptitle("Hessian Eigenvalues") + ax1 = 
fig.add_subplot(311) + ax2 = fig.add_subplot(312) + ax3 = fig.add_subplot(313) + + ax3.set_xlabel(r"$\sigma(t)$") + ax3.set_ylabel("Small Count") + ax3.set_ylim(0, number_of_atoms * spatial_dimension) + + ax3.semilogx(list_sigmas, small_count, "-", color="black") + + for ax in [ax1, ax2]: + ax.set_xlabel(r"$\sigma(t)$") + ax.set_ylabel("Eigenvalue") + + for list_e in eigenvalues: + ax.semilogx(list_sigmas, list_e, "-", color="grey") + + ax2.set_ylim([-2.5e-4, 2.5e-4]) + for ax in [ax1, ax2, ax3]: + ax.set_xlim(sigma_min, sigma_max) + fig.tight_layout() + + plt.show() + + fig2 = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) + fig2.suptitle("Hessian Eigenvalues At Small Time") + ax1 = fig2.add_subplot(111) + + ax1.set_xlabel(r"$\sigma(t)$") + ax1.set_ylabel("Eigenvalues") + + for list_e in eigenvalues: + ax1.loglog(list_sigmas, list_e, "-", color="grey") + + ax1.set_xlim(sigma_min, 1e-2) + + fig2.tight_layout() + + plt.show() + + fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig3.suptitle("Hessian Eigenvalues of Normalized Score") + ax1 = fig3.add_subplot(121) + ax2 = fig3.add_subplot(122) + + ax1.set_xlabel(r"$\sigma(t)$") + ax1.set_ylabel("Eigenvalues") + + ax2.set_ylabel(r"$g^2(t) / \sigma(t)$") + ax2.set_xlabel(r"$\sigma(t)$") + + label1 = r"$\sigma(t)/g^2 \times \bf H$" + label2 = r"$\bf H$" + for list_e in eigenvalues: + ax1.semilogx( + list_sigmas, + list_e / (-prefactor.flatten()), + "-", + lw=1, + color="red", + label=label1, + ) + ax1.semilogx(list_sigmas, list_e, "-", color="grey", lw=1, label=label2) + label1 = "__nolabel__" + label2 = "__nolabel__" + + ax1.legend(loc=0) + ax2.semilogx(list_sigmas, (-prefactor.flatten()), "-", color="blue") + + for ax in [ax1, ax2]: + ax.set_xlim(sigma_min, sigma_max) + + fig3.tight_layout() + + plt.show() + + jacobian_eig_at_t0 = eigenvalues[:, -1] / (-prefactor[-1]) diff --git a/experiment_analysis/score_stability_analysis/plot_score_norm.py 
b/experiment_analysis/score_stability_analysis/plot_score_norm.py
new file mode 100644
index 00000000..ef4ddbc7
--- /dev/null
+++ b/experiment_analysis/score_stability_analysis/plot_score_norm.py
@@ -0,0 +1,127 @@
+"""Plot the norm of the normalized score along a chosen direction in configuration space.
+
+Atoms 0 and 1 of the equilibrium silicon cell are moved towards each other along the line
+joining them (they coincide at delta = 1 and have swapped places at delta = 2), and the norm
+of the model's normalized score is plotted against delta for a few diffusion times.
+"""
+import logging
+
+import einops
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
+from crystal_diffusion.analysis.analytic_score.utils import \
+    get_silicon_supercell
+from crystal_diffusion.models.position_diffusion_lightning_model import \
+    PositionDiffusionLightningModel
+from crystal_diffusion.samplers.exploding_variance import ExplodingVariance
+from crystal_diffusion.samplers.variance_sampler import NoiseParameters
+from crystal_diffusion.utils.basis_transformations import \
+    map_relative_coordinates_to_unit_cell
+from crystal_diffusion.utils.logging_utils import setup_analysis_logger
+# NOTE(review): the util.py created by this same patch does not define this name - confirm it exists.
+from experiment_analysis.score_stability_analysis.util import \
+    create_fixed_time_normalized_score_function
+
+plt.style.use(PLOT_STYLE_PATH)
+
+logger = logging.getLogger(__name__)
+setup_analysis_logger()
+
+
+# Checkpoint of the trained model to analyse (machine-specific path).
+checkpoint_path = ("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/"
+                   "output/last_model/last_model-epoch=049-step=039100.ckpt")
+
+spatial_dimension = 3
+number_of_atoms = 8
+atom_types = np.ones(number_of_atoms, dtype=int)
+
+# Cubic cell: the basis vectors are acell times the identity matrix.
+acell = 5.43
+basis_vectors = torch.diag(torch.tensor([acell, acell, acell]))
+
+total_time_steps = 1000
+noise_parameters = NoiseParameters(
+    total_time_steps=total_time_steps,
+    sigma_min=0.0001,
+    sigma_max=0.2,
+)
+
+# Use the GPU when one is available instead of failing on a CPU-only host.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if __name__ == "__main__":
+    variance_calculator = ExplodingVariance(noise_parameters)
+
+    logger.info("Loading checkpoint...")
+    pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path)
+    pl_model.eval()
+
+    # Put the network on the same device as the coordinates that are fed to it below.
+    sigma_normalized_score_network = pl_model.sigma_normalized_score_network.to(device)
+
+    # Inference only: freeze the weights.
+    for parameter in sigma_normalized_score_network.parameters():
+        parameter.requires_grad_(False)
+
+    equilibrium_relative_coordinates = torch.from_numpy(
+        get_silicon_supercell(supercell_factor=1)
+    ).to(torch.float32)
+
+    direction = torch.zeros_like(equilibrium_relative_coordinates)
+
+    # Move a single atom
+    # direction[0, 0] = 1.0
+    # list_delta = torch.linspace(-0.5, 0.5, 101)
+
+    # Put two particles on top of each other
+    dv = equilibrium_relative_coordinates[0] - equilibrium_relative_coordinates[1]
+    direction[0] = -0.5 * dv
+    direction[1] = 0.5 * dv
+    list_delta = torch.linspace(0., 2.0, 201)
+
+    # One displaced configuration per value of delta, stacked into a single batch.
+    displaced_coordinates = [
+        equilibrium_relative_coordinates + delta * direction for delta in list_delta
+    ]
+    relative_coordinates = map_relative_coordinates_to_unit_cell(
+        torch.stack(displaced_coordinates)
+    ).to(device)
+
+    list_t = torch.tensor([0.8, 0.7, 0.5, 0.3, 0.1, 0.01])
+    list_sigmas = variance_calculator.get_sigma(list_t)
+    list_norms = []
+    for t in tqdm(list_t, "norms"):
+        vector_field_fn = create_fixed_time_normalized_score_function(
+            sigma_normalized_score_network,
+            noise_parameters,
+            time=t,
+            basis_vectors=basis_vectors,
+        )
+
+        normalized_scores = vector_field_fn(relative_coordinates)
+        flat_normalized_scores = einops.rearrange(
+            normalized_scores, " b n s -> b (n s)"
+        )
+        # Norm over all atoms and spatial components, one value per configuration.
+        list_norms.append(flat_normalized_scores.norm(dim=-1).cpu())
+
+    fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
+    fig.suptitle("Normalized Score Norm Along Specific Direction")
+    ax1 = fig.add_subplot(111)
+    ax1.set_xlabel(r"$\delta$")
+    ax1.set_ylabel(r"$|{\bf n}({\bf x}, t)|$")
+
+    for t, sigma, norms in zip(list_t, list_sigmas, list_norms):
+        ax1.plot(
+            list_delta, norms, "-", label=f"t = {t: 3.2f}, $\\sigma$ = {sigma: 5.2e}"
+        )
+
+    ax1.legend(loc=0)
+
+    fig.tight_layout()
+
+    plt.show()
diff --git a/experiment_analysis/score_stability_analysis/util.py b/experiment_analysis/score_stability_analysis/util.py
new file mode 100644
index 00000000..6e272894
--- /dev/null
+++ b/experiment_analysis/score_stability_analysis/util.py
@@ -0,0 +1,82 @@
+import itertools
+from typing import Callable
+
+import einops
+import torch
+
+from crystal_diffusion.models.score_networks import ScoreNetwork
+from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE,
+                                         NOISY_RELATIVE_COORDINATES, TIME,
+                                         UNIT_CELL)
+from crystal_diffusion.samplers.exploding_variance import ExplodingVariance
+from crystal_diffusion.samplers.variance_sampler import NoiseParameters
+
+
+def get_normalized_score_function(
+    noise_parameters: NoiseParameters,
+    sigma_normalized_score_network: ScoreNetwork,
+    basis_vectors: torch.Tensor,
+) -> Callable:
+    """Get normalized score function.
+
+    Args:
+        noise_parameters: noise parameters used to build the variance calculator that turns times into sigmas.
+        sigma_normalized_score_network: score network that is evaluated by the returned function.
+        basis_vectors: basis vectors of the unit cell, a matrix that is shared by every configuration in a batch.
+
+    Returns:
+        normalized_score_function: a callable of (relative_coordinates, times) that returns the output of the
+            score network, called with conditional=False and with cartesian forces set to zero.
+    """
+    variance_calculator = ExplodingVariance(noise_parameters)
+
+    def normalized_score_function(
+        relative_coordinates: torch.Tensor, times: torch.Tensor
+    ) -> torch.Tensor:
+        # The unpacking also checks that the coordinates have exactly three dimensions.
+        batch_size, _, _ = relative_coordinates.shape
+        # Same cell for every configuration, on the device and dtype of the coordinates.
+        unit_cells = einops.repeat(
+            basis_vectors.to(relative_coordinates), "s1 s2 -> b s1 s2", b=batch_size
+        )
+
+        forces = torch.zeros_like(relative_coordinates)
+        sigmas = variance_calculator.get_sigma(times)
+
+        augmented_batch = {
+            NOISY_RELATIVE_COORDINATES: relative_coordinates,
+            TIME: times,
+            NOISE: sigmas,
+            UNIT_CELL: unit_cells,
+            CARTESIAN_FORCES: forces,
+        }
+
+        return sigma_normalized_score_network(augmented_batch, conditional=False)
+
+    return normalized_score_function
+
+
+def get_cubic_point_group_symmetries():
+    """Get cubic point group symmetries.
+
+    Every symmetry is a signed permutation matrix: one of the 6 permutations of the three axes
+    multiplied by one of the 8 diagonal matrices with entries +1 or -1.
+
+    Returns:
+        symmetries: list of 48 tensors of shape [3, 3].
+    """
+    identity = torch.eye(3)
+    # Index with a flat list of row indices. The previous `[[idx]]` form (a list holding a tuple) permutes
+    # the rows only through the legacy rule that treats a non-tuple sequence index as a tuple.
+    permutations = [
+        identity[list(idx)] for idx in itertools.permutations([0, 1, 2])
+    ]
+    sign_changes = [
+        torch.diag(torch.tensor(diag))
+        for diag in itertools.product([-1.0, 1.0], repeat=3)
+    ]
+    # Outer loop on permutations, inner loop on sign changes.
+    return [
+        permutation @ sign_change
+        for permutation, sign_change in itertools.product(permutations, sign_changes)
+    ]
diff --git
a/tests/models/score_network/test_score_network.py b/tests/models/score_network/test_score_network.py
index b5445cdf..fcb54d0c 100644
--- a/tests/models/score_network/test_score_network.py
+++ b/tests/models/score_network/test_score_network.py
@@ -301,7 +301,6 @@ def score_network(self, score_network_parameters):
         return DiffusionMACEScoreNetwork(score_network_parameters)
 
 
-@pytest.mark.parametrize("spatial_dimension", [3])
 class TestEGNNScoreNetwork(BaseTestScoreNetwork):
 
     @pytest.fixture(scope="class", autouse=True)
@@ -314,8 +313,21 @@ def set_default_type_to_float64(self):
         torch.set_default_dtype(torch.float32)
 
     @pytest.fixture()
-    def score_network_parameters(self):
-        return EGNNScoreNetworkParameters()  # Use the defaults
+    def spatial_dimension(self):
+        return 3
+
+    @pytest.fixture()
+    def basis_vectors(self, batch_size, spatial_dimension):
+        # The basis vectors should form a cube in order to test the equivariance of the current implementation
+        # of the EGNN model. The octahedral point group only applies in this case!
+        acell = 5.5
+        cubes = torch.stack([torch.diag(acell * torch.ones(spatial_dimension)) for _ in range(batch_size)])
+        return cubes
+
+    @pytest.fixture(params=[("fully_connected", None), ("radial_cutoff", 3.0)])
+    def score_network_parameters(self, request):
+        edges, radial_cutoff = request.param
+        return EGNNScoreNetworkParameters(edges=edges, radial_cutoff=radial_cutoff)
 
     @pytest.fixture()
     def score_network(self, score_network_parameters):
@@ -334,6 +346,13 @@ def octahedral_point_group_symmetries(self):
 
         return symmetries
 
+    @pytest.mark.parametrize("edges, radial_cutoff", [("fully_connected", 3.0), ("radial_cutoff", None)])
+    def test_score_network_parameters(self, edges, radial_cutoff):
+        score_network_parameters = EGNNScoreNetworkParameters(edges=edges, radial_cutoff=radial_cutoff)
+        with pytest.raises(AssertionError):
+            # Check that the code crashes when inconsistent parameters are fed in.
+            EGNNScoreNetwork(score_network_parameters)
+
     def test_create_block_diagonal_projection_matrices(self, score_network, spatial_dimension):
         expected_matrices = []
         for space_idx in range(spatial_dimension):
diff --git a/tests/test_sample_diffusion.py b/tests/test_sample_diffusion.py
new file mode 100644
index 00000000..aa652e40
--- /dev/null
+++ b/tests/test_sample_diffusion.py
@@ -0,0 +1,173 @@
+"""Test the main function of crystal_diffusion.sample_diffusion."""
+import dataclasses
+
+import pytest
+import torch
+import yaml
+
+from crystal_diffusion import sample_diffusion
+from crystal_diffusion.generators.predictor_corrector_position_generator import \
+    PredictorCorrectorSamplingParameters
+from crystal_diffusion.models.loss import MSELossParameters
+from crystal_diffusion.models.optimizer import OptimizerParameters
+from crystal_diffusion.models.position_diffusion_lightning_model import (
+    PositionDiffusionLightningModel, PositionDiffusionParameters)
+from crystal_diffusion.models.score_networks.mlp_score_network import \
+    MLPScoreNetworkParameters
+from crystal_diffusion.namespace import RELATIVE_COORDINATES
+from crystal_diffusion.samplers.variance_sampler import NoiseParameters
+
+
+@pytest.fixture()
+def spatial_dimension():
+    """Spatial dimension of the generated samples."""
+    return 3
+
+
+@pytest.fixture()
+def number_of_atoms():
+    """Number of atoms in each generated sample."""
+    return 8
+
+
+@pytest.fixture()
+def number_of_samples():
+    """Number of samples to generate."""
+    return 12
+
+
+@pytest.fixture()
+def cell_dimensions():
+    """Cell dimensions passed to the sampling parameters."""
+    return [5.1, 6.2, 7.3]
+
+
+@pytest.fixture(params=[True, False])
+def record_samples(request):
+    """Run the test both with and without recording samples."""
+    return request.param
+
+
+@pytest.fixture()
+def noise_parameters():
+    """Noise parameters with only 10 time steps."""
+    return NoiseParameters(total_time_steps=10)
+
+
+@pytest.fixture()
+def sampling_parameters(
+    number_of_atoms,
+    spatial_dimension,
+    number_of_samples,
+    cell_dimensions,
+    record_samples,
+):
+    """Sampling parameters built from the fixtures above."""
+    return PredictorCorrectorSamplingParameters(
+        number_of_corrector_steps=1,
+        spatial_dimension=spatial_dimension,
+        number_of_atoms=number_of_atoms,
+        number_of_samples=number_of_samples,
+        cell_dimensions=cell_dimensions,
+        record_samples=record_samples,
+    )
+
+
+@pytest.fixture()
+def sigma_normalized_score_network(number_of_atoms, noise_parameters):
+    """Score network of a small, untrained MLP-based lightning model."""
+    score_network_parameters = MLPScoreNetworkParameters(
+        number_of_atoms=number_of_atoms,
+        embedding_dimensions_size=8,
+        n_hidden_dimensions=2,
+        hidden_dimensions_size=16,
+    )
+
+    diffusion_params = PositionDiffusionParameters(
+        score_network_parameters=score_network_parameters,
+        loss_parameters=MSELossParameters(),
+        optimizer_parameters=OptimizerParameters(name="adam", learning_rate=1e-3),
+        scheduler_parameters=None,
+        noise_parameters=noise_parameters,
+        diffusion_sampling_parameters=None,
+    )
+
+    model = PositionDiffusionLightningModel(diffusion_params)
+    return model.sigma_normalized_score_network
+
+
+@pytest.fixture()
+def config_path(tmp_path, noise_parameters, sampling_parameters):
+    """Path to a yaml configuration file holding the noise and sampling parameters."""
+    config_path = str(tmp_path / "test_config.yaml")
+
+    config = dict(
+        noise=dataclasses.asdict(noise_parameters),
+        sampling=dataclasses.asdict(sampling_parameters),
+    )
+
+    with open(config_path, "w") as fd:
+        yaml.dump(config, fd)
+
+    return config_path
+
+
+@pytest.fixture()
+def checkpoint_path(tmp_path):
+    """Path to a placeholder checkpoint file: it only contains a line of text."""
+    path_to_checkpoint = tmp_path / "fake_checkpoint.pt"
+    with open(path_to_checkpoint, "w") as fd:
+        fd.write("This is a dummy checkpoint file.")
+    return path_to_checkpoint
+
+
+@pytest.fixture()
+def output_path(tmp_path):
+    """Path to the output directory; the fixture does not create it."""
+    output = tmp_path / "output"
+    return output
+
+
+@pytest.fixture()
+def args(config_path, checkpoint_path, output_path):
+    """Input arguments for main."""
+    input_args = [
+        f"--config={config_path}",
+        f"--checkpoint={checkpoint_path}",
+        f"--output={output_path}",
+        "--device=cpu",
+    ]
+
+    return input_args
+
+
+def test_sample_diffusion(
+    mocker,
+    args,
+    sigma_normalized_score_network,
+    output_path,
+    number_of_samples,
+    number_of_atoms,
+    spatial_dimension,
+    record_samples,
+):
+    """Check that main writes samples of the expected shape, and trajectories only when recording."""
+    # The checkpoint file is a placeholder, so the function that provides the network is patched
+    # to return the small test network instead.
+    mocker.patch(
+        "crystal_diffusion.sample_diffusion.get_sigma_normalized_score_network",
+        return_value=sigma_normalized_score_network,
+    )
+
+    sample_diffusion.main(args)
+
+    assert (output_path / "samples.pt").exists()
+    samples = torch.load(output_path / "samples.pt")
+    assert samples[RELATIVE_COORDINATES].shape == (
+        number_of_samples,
+        number_of_atoms,
+        spatial_dimension,
+    )
+
+    # The trajectories file should exist exactly when samples are recorded.
+    assert (output_path / "trajectories.pt").exists() == record_samples