diff --git a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py index 02340d4d..64d5c4fc 100644 --- a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py +++ b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py @@ -49,7 +49,7 @@ def compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> n logger.info("Compute energy from Oracle") with tempfile.TemporaryDirectory() as tmp_work_dir: - for positions, box in zip(batch_cartesian_positions.numpy(), batched_unit_cells.numpy()): + for positions, box in zip(batch_cartesian_positions.cpu().numpy(), batched_unit_cells.cpu().numpy()): energy, forces = get_energy_and_forces_from_lammps(positions, box, self.atom_types, diff --git a/experiment_analysis/dataset_analysis/dataset_covariance.py b/experiment_analysis/dataset_analysis/dataset_covariance.py new file mode 100644 index 00000000..eed61e7e --- /dev/null +++ b/experiment_analysis/dataset_analysis/dataset_covariance.py @@ -0,0 +1,78 @@ +"""Effective Dataset Variance. + +The goal of this script is to compute the effective "sigma_d" of the +actual datasets, that is, the standard deviation of the displacement +from equilibrium, in fractional coordinates. +""" +import logging + +import einops +import torch +from tqdm import tqdm + +from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR +from crystal_diffusion.data.diffusion.data_loader import ( + LammpsForDiffusionDataModule, LammpsLoaderParameters) +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +logger = logging.getLogger(__name__) +dataset_name = 'si_diffusion_2x2x2' +# dataset_name = 'si_diffusion_1x1x1' + +output_dir = ANALYSIS_RESULTS_DIR / "covariances" +output_dir.mkdir(exist_ok=True) + + +if dataset_name == 'si_diffusion_1x1x1': + max_atom = 8 + translation = torch.tensor([0.125, 0.125, 0.125]) +elif dataset_name == 'si_diffusion_2x2x2': + max_atom = 64 + translation = torch.tensor([0.0625, 0.0625, 0.0625]) + +lammps_run_dir = DATA_DIR / dataset_name +processed_dataset_dir = lammps_run_dir / "processed" + +cache_dir = lammps_run_dir / "cache" + +data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom) + +if __name__ == '__main__': + setup_analysis_logger() + logger.info(f"Computing the covariance matrix for {dataset_name}") + + datamodule = LammpsForDiffusionDataModule( + lammps_run_dir=lammps_run_dir, + processed_dataset_dir=processed_dataset_dir, + hyper_params=data_params, + working_cache_dir=cache_dir, + ) + datamodule.setup() + + train_dataset = datamodule.train_dataset + + list_means = [] + for batch in tqdm(datamodule.train_dataloader(), "Mean"): + x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation) + list_means.append(x.mean(dim=0)) + + # Drop the last batch, which might not have dimension batch_size + x0 = torch.stack(list_means[:-1]).mean(dim=0) + + list_covariances = [] + list_sizes = [] + for batch in tqdm(datamodule.train_dataloader(), "displacements"): + x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation) + list_sizes.append(x.shape[0]) + displacements = einops.rearrange(x - x0, "batch natoms space -> batch (natoms space)") + covariance = (displacements[:, None, :] * displacements[:, :, None]).sum(dim=0) + list_covariances.append(covariance) + + covariance = torch.stack(list_covariances).sum(dim=0) / sum(list_sizes) + + output_file = output_dir / f"covariance_{dataset_name}.pkl" + logger.info(f"Writing to file {output_file}...") + with open(output_file, 'wb') as fd: + torch.save(covariance, fd) diff --git a/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py b/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py new file mode 100644 index 00000000..3d719b22 --- /dev/null +++ b/experiment_analysis/dataset_analysis/plot_si_phonon_DOS.py @@ -0,0 +1,86 @@ +"""Silicon phonon Density of States. + +The displacement covariance is related to the phonon dynamical matrix. +Here we extract the corresponding phonon density of state, based on this covariance, +to see if the energy scales match up. +""" +import matplotlib.pyplot as plt +import torch + +from crystal_diffusion import ANALYSIS_RESULTS_DIR +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH + +plt.style.use(PLOT_STYLE_PATH) + +# Define some constants. +kelvin_in_Ha = 0.000_003_166_78 +T_in_kelvin = 300.0 +bohr_in_angst = 0.529177 +Si_mass = 28.0855 +proton_mass = 1836.152673426 + +Ha_in_meV = 27211.0 + +THz_in_meV = 4.136 + +acell = 5.43 + + +dataset_name_2x2x2 = "si_diffusion_2x2x2" +dataset_name_1x1x1 = "si_diffusion_1x1x1" + +output_dir = ANALYSIS_RESULTS_DIR / "covariances" +output_dir.mkdir(exist_ok=True) + + +if __name__ == "__main__": + kBT = kelvin_in_Ha * T_in_kelvin + a = acell / bohr_in_angst + + M = Si_mass * proton_mass + + constant_1x1x1 = M * a**2 / kBT / Ha_in_meV**2 + constant_2x2x2 = M * (2.0 * a) ** 2 / kBT / Ha_in_meV**2 + + covariance_file_1x1x1 = output_dir / f"covariance_{dataset_name_1x1x1}.pkl" + sigma_1x1x1 = torch.load(covariance_file_1x1x1) + sigma_inv_1x1x1 = torch.linalg.pinv(sigma_1x1x1) + Omega_1x1x1 = sigma_inv_1x1x1 / constant_1x1x1 + omega2_1x1x1 = torch.linalg.eigvalsh(Omega_1x1x1) + list_omega_in_meV_1x1x1 = torch.sqrt(torch.abs(omega2_1x1x1)) + + covariance_file_2x2x2 = output_dir / f"covariance_{dataset_name_2x2x2}.pkl" + sigma_2x2x2 = torch.load(covariance_file_2x2x2) + sigma_inv_2x2x2 = torch.linalg.pinv(sigma_2x2x2) + Omega_2x2x2 = sigma_inv_2x2x2 / constant_2x2x2 + omega2_2x2x2 = torch.linalg.eigvalsh(Omega_2x2x2) + list_omega_in_meV_2x2x2 = torch.sqrt(torch.abs(omega2_2x2x2)) + + max_hw = torch.max(list_omega_in_meV_2x2x2) / THz_in_meV + + bins = torch.linspace(0.0, max_hw, 50) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Eigenvalues of Dynamical Matrix, from Displacement Covariance") + ax = fig.add_subplot(111) + + ax.set_xlim(0, max_hw + 1) + ax.set_xlabel(r"$\hbar \omega$ (THz)") + ax.set_ylabel("Count") + + ax.hist( + list_omega_in_meV_1x1x1 / THz_in_meV, + bins=bins, + label="Si 1x1x1", + color="green", + alpha=0.5, + ) + ax.hist( + list_omega_in_meV_2x2x2 / THz_in_meV, + bins=bins, + label="Si 2x2x2", + color="blue", + alpha=0.25, + ) + ax.legend(loc=0) + plt.show() diff --git a/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py new file mode 100644 index 00000000..700b0b4c --- /dev/null +++ b/experiment_analysis/score_stability_analysis/draw_samples_from_equilibrium.py @@ -0,0 +1,258 @@ +import logging + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.exploring_langevin_generator.generate_sample_energies import \ + EnergyCalculator +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.generators.langevin_generator import LangevinGenerator +from crystal_diffusion.generators.ode_position_generator import ( + ExplodingVarianceODEPositionGenerator, ODESamplingParameters) +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.generators.sde_position_generator import ( + ExplodingVarianceSDEPositionGenerator, SDESamplingParameters) +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + + +class ForcedStartingPointLangevinGenerator(LangevinGenerator): + """Langevin Generator with Forced Starting point.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) + return relative_coordinates + + +class ForcedStartingPointODEPositionGenerator(ExplodingVarianceODEPositionGenerator): + """Forced starting point ODE position generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: ODESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) + return relative_coordinates + + +class ForcedStartingPointSDEPositionGenerator(ExplodingVarianceSDEPositionGenerator): + """Forced Starting Point SDE position generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: SDESamplingParameters, + sigma_normalized_score_network: ScoreNetwork, + starting_relative_coordinates: torch.Tensor, + ): + """Init method.""" + super().__init__( + noise_parameters, sampling_parameters, sigma_normalized_score_network + ) + + self._starting_relative_coordinates = starting_relative_coordinates + + def initialize(self, number_of_samples: int): + """This method must initialize the samples from the fully noised distribution.""" + relative_coordinates = einops.repeat( + self._starting_relative_coordinates, + "natoms space -> batch_size natoms space", + batch_size=number_of_samples, + ) + return relative_coordinates + + +checkpoint_path = ( + "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" + "output/last_model/last_model-epoch=049-step=039100.ckpt" +) + +spatial_dimension = 3 +number_of_atoms = 8 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 5.43 + +total_time_steps = 1000 +number_of_corrector_steps = 10 +epsilon = 2.0e-7 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + corrector_step_epsilon=epsilon, + sigma_min=0.0001, + sigma_max=0.2, +) +number_of_samples = 1000 + +base_sampling_parameters_dict = dict( + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + cell_dimensions=[acell, acell, acell], + number_of_samples=number_of_samples, +) + +ode_sampling_parameters = ODESamplingParameters( + absolute_solver_tolerance=1.0e-5, + relative_solver_tolerance=1.0e-5, + **base_sampling_parameters_dict, +) + +# Fiddling with SDE is PITA. Also, is there a bug in there? +sde_sampling_parameters = SDESamplingParameters( + adaptative=False, **base_sampling_parameters_dict +) + + +langevin_sampling_parameters = PredictorCorrectorSamplingParameters( + number_of_corrector_steps=number_of_corrector_steps, **base_sampling_parameters_dict +) + +device = torch.device("cuda") +if __name__ == "__main__": + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + equilibrium_relative_coordinates = ( + torch.from_numpy(get_silicon_supercell(supercell_factor=1)) + .to(torch.float32) + .to(device) + ) + + ode_generator = ForcedStartingPointODEPositionGenerator( + noise_parameters=noise_parameters, + sampling_parameters=ode_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates, + ) + + sde_generator = ForcedStartingPointSDEPositionGenerator( + noise_parameters=noise_parameters, + sampling_parameters=sde_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates, + ) + + langevin_generator = ForcedStartingPointLangevinGenerator( + noise_parameters=noise_parameters, + sampling_parameters=langevin_sampling_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + starting_relative_coordinates=equilibrium_relative_coordinates, + ) + + unit_cells = einops.repeat(basis_vectors, "s1 s2 -> b s1 s2", b=number_of_samples) + + ode_samples = ode_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) + sde_samples = sde_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) + + langevin_samples = langevin_generator.sample( + number_of_samples=number_of_samples, device=device, unit_cell=unit_cells + ) + + energy_calculator = EnergyCalculator( + unit_cell=basis_vectors, number_of_atoms=number_of_atoms + ) + + ode_energies = energy_calculator.compute_oracle_energies(ode_samples) + sde_energies = energy_calculator.compute_oracle_energies(sde_samples) + langevin_energies = energy_calculator.compute_oracle_energies(langevin_samples) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Energies of Samples Drawn from Equilibrium Coordinates") + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + ax1.set_title("Zoom In") + ax2.set_title("Broad View") + + list_q = np.linspace(0, 1, 101) + + list_energies = [ode_energies, sde_energies, langevin_energies] + + list_colors = ["blue", "red", "green"] + + langevin_label = ( + f"LANGEVIN (time steps = {total_time_steps}, " + f"corrector steps = {number_of_corrector_steps}, epsilon ={epsilon: 5.2e})" + ) + + list_labels = ["ODE", "SDE", langevin_label] + + for ax in [ax1, ax2]: + for energies, label in zip(list_energies, list_labels): + quantiles = np.quantile(energies, list_q) + ax.plot(100 * list_q, quantiles, "-", label=label) + + ax.fill_between( + [0, 100], + y1=-34.6, + y2=-34.1, + color="yellow", + alpha=0.25, + label="Training Energy Range", + ) + + ax.set_xlabel("Quantile (%)") + ax.set_ylabel("Energy (eV)") + ax.set_xlim(-0.1, 100.1) + ax1.set_ylim(-35, -34.0) + ax2.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) + + fig.tight_layout() + + plt.show() diff --git a/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py new file mode 100644 index 00000000..b48669be --- /dev/null +++ b/experiment_analysis/score_stability_analysis/plot_hessian_eigenvalues.py @@ -0,0 +1,217 @@ +import logging +from pathlib import Path + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.func import jacrev +from tqdm import tqdm + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import \ + get_normalized_score_function + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + +system = "Si_1x1x1" + +if system == "Si_1x1x1": + checkpoint_path = Path( + "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/output/" + "last_model/last_model-epoch=049-step=039100.ckpt" + ) + number_of_atoms = 8 + acell = 5.43 + supercell_factor = 1 + + hessian_batch_size = 10 + + +elif system == "Si_2x2x2": + pickle_path = Path("/home/mila/r/rousseab/scratch/checkpoints/sota_egnn_2x2x2.pkl") + number_of_atoms = 64 + acell = 10.86 + supercell_factor = 2 + hessian_batch_size = 1 + +spatial_dimension = 3 +atom_types = np.ones(number_of_atoms, dtype=int) + +total_time_steps = 1000 +sigma_min = 0.0001 +sigma_max = 0.2 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=sigma_min, + sigma_max=sigma_max, +) + + +nsteps = 501 + + +device = torch.device("cuda") +if __name__ == "__main__": + variance_calculator = ExplodingVariance(noise_parameters) + + basis_vectors = torch.diag(torch.tensor([acell, acell, acell])).to(device) + equilibrium_relative_coordinates = ( + torch.from_numpy(get_silicon_supercell(supercell_factor=supercell_factor)) + .to(torch.float32) + .to(device) + ) + + logger.info("Loading checkpoint...") + + if system == "Si_1x1x1": + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + elif system == "Si_2x2x2": + sigma_normalized_score_network = torch.load(pickle_path) + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + normalized_score_function = get_normalized_score_function( + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + basis_vectors=basis_vectors, + ) + + times = torch.linspace(1, 0, nsteps).unsqueeze(-1) + sigmas = variance_calculator.get_sigma(times) + g2 = variance_calculator.get_g_squared(times) + + prefactor = -g2 / sigmas + + relative_coordinates = einops.repeat( + equilibrium_relative_coordinates, "n s -> b n s", b=nsteps + ) + + batch_hessian_function = jacrev(normalized_score_function, argnums=0) + + list_flat_hessians = [] + for x, t in tqdm( + zip( + torch.split(relative_coordinates, hessian_batch_size), + torch.split(times, hessian_batch_size), + ), + "Hessian", + ): + batch_hessian = batch_hessian_function(x, t) + flat_hessian = einops.rearrange( + torch.diagonal(batch_hessian, dim1=0, dim2=3), + "n1 s1 n2 s2 b -> b (n1 s1) (n2 s2)", + ) + list_flat_hessians.append(flat_hessian) + + flat_hessian = torch.concat(list_flat_hessians) + + p = einops.repeat( + prefactor, + "b 1 -> b d1 d2", + d1=number_of_atoms * spatial_dimension, + d2=number_of_atoms * spatial_dimension, + ).to(flat_hessian) + + normalized_hessian = p * flat_hessian + + eigenvalues, eigenvectors = torch.linalg.eigh(normalized_hessian) + eigenvalues = eigenvalues.cpu().transpose(1, 0) + + small_count = (eigenvalues < 5e-4).sum(dim=0) + list_times = times.flatten().cpu() + list_sigmas = sigmas.flatten().cpu() + + fig = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) + fig.suptitle("Hessian Eigenvalues") + ax1 = fig.add_subplot(311) + ax2 = fig.add_subplot(312) + ax3 = fig.add_subplot(313) + + ax3.set_xlabel(r"$\sigma(t)$") + ax3.set_ylabel("Small Count") + ax3.set_ylim(0, number_of_atoms * spatial_dimension) + + ax3.semilogx(list_sigmas, small_count, "-", color="black") + + for ax in [ax1, ax2]: + ax.set_xlabel(r"$\sigma(t)$") + ax.set_ylabel("Eigenvalue") + + for list_e in eigenvalues: + ax.semilogx(list_sigmas, list_e, "-", color="grey") + + ax2.set_ylim([-2.5e-4, 2.5e-4]) + for ax in [ax1, ax2, ax3]: + ax.set_xlim(sigma_min, sigma_max) + fig.tight_layout() + + plt.show() + + fig2 = plt.figure(figsize=(PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0])) + fig2.suptitle("Hessian Eigenvalues At Small Time") + ax1 = fig2.add_subplot(111) + + ax1.set_xlabel(r"$\sigma(t)$") + ax1.set_ylabel("Eigenvalues") + + for list_e in eigenvalues: + ax1.loglog(list_sigmas, list_e, "-", color="grey") + + ax1.set_xlim(sigma_min, 1e-2) + + fig2.tight_layout() + + plt.show() + + fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig3.suptitle("Hessian Eigenvalues of Normalized Score") + ax1 = fig3.add_subplot(121) + ax2 = fig3.add_subplot(122) + + ax1.set_xlabel(r"$\sigma(t)$") + ax1.set_ylabel("Eigenvalues") + + ax2.set_ylabel(r"$g^2(t) / \sigma(t)$") + ax2.set_xlabel(r"$\sigma(t)$") + + label1 = r"$\sigma(t)/g^2 \times \bf H$" + label2 = r"$\bf H$" + for list_e in eigenvalues: + ax1.semilogx( + list_sigmas, + list_e / (-prefactor.flatten()), + "-", + lw=1, + color="red", + label=label1, + ) + ax1.semilogx(list_sigmas, list_e, "-", color="grey", lw=1, label=label2) + label1 = "__nolabel__" + label2 = "__nolabel__" + + ax1.legend(loc=0) + ax2.semilogx(list_sigmas, (-prefactor.flatten()), "-", color="blue") + + for ax in [ax1, ax2]: + ax.set_xlim(sigma_min, sigma_max) + + fig3.tight_layout() + + plt.show() + + jacobian_eig_at_t0 = eigenvalues[:, -1] / (-prefactor[-1]) diff --git a/experiment_analysis/score_stability_analysis/plot_score_norm.py b/experiment_analysis/score_stability_analysis/plot_score_norm.py new file mode 100644 index 00000000..ef4ddbc7 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/plot_score_norm.py @@ -0,0 +1,115 @@ +import logging + +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch +from tqdm import tqdm + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.analysis.analytic_score.utils import \ + get_silicon_supercell +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from experiment_analysis.score_stability_analysis.util import \ + create_fixed_time_normalized_score_function + +plt.style.use(PLOT_STYLE_PATH) + +logger = logging.getLogger(__name__) +setup_analysis_logger() + + +checkpoint_path = ("/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" + "output/last_model/last_model-epoch=049-step=039100.ckpt") + +spatial_dimension = 3 +number_of_atoms = 8 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 5.43 +basis_vectors = torch.diag(torch.tensor([acell, acell, acell])) + +total_time_steps = 1000 +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=0.0001, + sigma_max=0.2, +) + +device = torch.device("cuda") +if __name__ == "__main__": + variance_calculator = ExplodingVariance(noise_parameters) + + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = pl_model.sigma_normalized_score_network + + for parameter in sigma_normalized_score_network.parameters(): + parameter.requires_grad_(False) + + equilibrium_relative_coordinates = torch.from_numpy( + get_silicon_supercell(supercell_factor=1) + ).to(torch.float32) + + direction = torch.zeros_like(equilibrium_relative_coordinates) + + # Move a single atom + # direction[0, 0] = 1.0 + # list_delta = torch.linspace(-0.5, 0.5, 101) + + # Put two particles on top of each other + dv = equilibrium_relative_coordinates[0] - equilibrium_relative_coordinates[1] + direction[0] = -0.5 * dv + direction[1] = 0.5 * dv + list_delta = torch.linspace(0., 2.0, 201) + + relative_coordinates = [] + for delta in list_delta: + relative_coordinates.append( + equilibrium_relative_coordinates + delta * direction + ) + relative_coordinates = map_relative_coordinates_to_unit_cell( + torch.stack(relative_coordinates) + ).to(device) + + list_t = torch.tensor([0.8, 0.7, 0.5, 0.3, 0.1, 0.01]) + list_sigmas = variance_calculator.get_sigma(list_t) + list_norms = [] + for t in tqdm(list_t, "norms"): + vector_field_fn = create_fixed_time_normalized_score_function( + sigma_normalized_score_network, + noise_parameters, + time=t, + basis_vectors=basis_vectors, + ) + + normalized_scores = vector_field_fn(relative_coordinates) + flat_normalized_scores = einops.rearrange( + normalized_scores, " b n s -> b (n s)" + ) + list_norms.append(flat_normalized_scores.norm(dim=-1).cpu()) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle("Normalized Score Norm Along Specific Direction") + ax1 = fig.add_subplot(111) + ax1.set_xlabel(r"$\delta$") + ax1.set_ylabel(r"$|{\bf n}({\bf x}, t)|$") + + for t, sigma, norms in zip(list_t, list_sigmas, list_norms): + ax1.plot( + list_delta, norms, "-", label=f"t = {t: 3.2f}, $\\sigma$ = {sigma: 5.2e}" + ) + + ax1.legend(loc=0) + + fig.tight_layout() + + plt.show() diff --git a/experiment_analysis/score_stability_analysis/util.py b/experiment_analysis/score_stability_analysis/util.py new file mode 100644 index 00000000..6e272894 --- /dev/null +++ b/experiment_analysis/score_stability_analysis/util.py @@ -0,0 +1,65 @@ +import itertools +from typing import Callable + +import einops +import torch + +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, + NOISY_RELATIVE_COORDINATES, TIME, + UNIT_CELL) +from crystal_diffusion.samplers.exploding_variance import ExplodingVariance +from crystal_diffusion.samplers.variance_sampler import NoiseParameters + + +def get_normalized_score_function( + noise_parameters: NoiseParameters, + sigma_normalized_score_network: ScoreNetwork, + basis_vectors: torch.Tensor, +) -> Callable: + """Get normalizd score function.""" + variance_calculator = ExplodingVariance(noise_parameters) + + def normalized_score_function( + relative_coordinates: torch.Tensor, times: torch.Tensor + ) -> torch.Tensor: + batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape + unit_cells = einops.repeat( + basis_vectors.to(relative_coordinates), "s1 s2 -> b s1 s2", b=batch_size + ) + + forces = torch.zeros_like(relative_coordinates) + sigmas = variance_calculator.get_sigma(times) + + augmented_batch = { + NOISY_RELATIVE_COORDINATES: relative_coordinates, + TIME: times, + NOISE: sigmas, + UNIT_CELL: unit_cells, + CARTESIAN_FORCES: forces, + } + + sigma_normalized_scores = sigma_normalized_score_network( + augmented_batch, conditional=False + ) + + return sigma_normalized_scores + + return normalized_score_function + + +def get_cubic_point_group_symmetries(): + """Get cubic point group symmetries.""" + permutations = [ + torch.diag(torch.ones(3))[[idx]] for idx in itertools.permutations([0, 1, 2]) + ] + sign_changes = [ + torch.diag(torch.tensor(diag)) + for diag in itertools.product([-1.0, 1.0], repeat=3) + ] + symmetries = [] + for permutation in permutations: + for sign_change in sign_changes: + symmetries.append(permutation @ sign_change) + + return symmetries