Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ti_scale_molecular_dynamics into egnn_radial_cutoff
  • Loading branch information
sblackburn-mila committed Oct 4, 2024
2 parents 6be1155 + 9feaa7f commit 92abb0c
Show file tree
Hide file tree
Showing 56 changed files with 2,944 additions and 1,088 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LANGEVIN_EXPLORATION_DIRECTORY
from crystal_diffusion.analysis.analytic_score.utils import (
get_exact_samples, get_silicon_supercell)
from crystal_diffusion.callbacks.sampling_callback import logger
from crystal_diffusion.callbacks.sampling_visualization_callback import logger
from crystal_diffusion.generators.langevin_generator import LangevinGenerator
from crystal_diffusion.generators.predictor_corrector_position_generator import \
PredictorCorrectorSamplingParameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_exact_samples, get_silicon_supercell)
from crystal_diffusion.callbacks.loss_monitoring_callback import \
LossMonitoringCallback
from crystal_diffusion.callbacks.sampling_callback import \
from crystal_diffusion.callbacks.sampling_visualization_callback import \
PredictorCorrectorDiffusionSamplingCallback
from crystal_diffusion.generators.predictor_corrector_position_generator import \
PredictorCorrectorSamplingParameters
Expand Down
64 changes: 35 additions & 29 deletions crystal_diffusion/analysis/generator_sample_analysis_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from einops import einops

from crystal_diffusion.generators.ode_position_generator import \
ExplodingVarianceODEPositionGenerator
from crystal_diffusion.generators.ode_position_generator import (
ExplodingVarianceODEPositionGenerator, ODESamplingParameters)
from crystal_diffusion.models.mace_utils import get_adj_matrix
from crystal_diffusion.models.score_networks.score_network import ScoreNetwork
from crystal_diffusion.samplers.variance_sampler import NoiseParameters
Expand All @@ -17,40 +17,40 @@ class PartialODEPositionGenerator(ExplodingVarianceODEPositionGenerator):
2- providing a fixed starting point, initial_relative_coordinates, instead of a random starting point.
"""

def __init__(self,
noise_parameters: NoiseParameters,
number_of_atoms: int,
spatial_dimension: int,
sigma_normalized_score_network: ScoreNetwork,
initial_relative_coordinates: torch.Tensor,
record_samples: bool = False,
absolute_solver_tolerance: float = 1.0e-3,
relative_solver_tolerance: float = 1.0e-2,
tf: float = 1.0,
):
def __init__(
self,
noise_parameters: NoiseParameters,
sampling_parameters: ODESamplingParameters,
sigma_normalized_score_network: ScoreNetwork,
initial_relative_coordinates: torch.Tensor,
tf: float = 1.0,
):
"""Init method."""
super(PartialODEPositionGenerator, self).__init__(noise_parameters,
number_of_atoms,
spatial_dimension,
sigma_normalized_score_network,
record_samples,
absolute_solver_tolerance,
relative_solver_tolerance)
super(PartialODEPositionGenerator, self).__init__(
noise_parameters, sampling_parameters, sigma_normalized_score_network
)

self.tf = tf
assert initial_relative_coordinates.shape[1:] == (number_of_atoms, spatial_dimension), "Inconsistent shape"
assert initial_relative_coordinates.shape[1:] == (
sampling_parameters.number_of_atoms,
sampling_parameters.spatial_dimension,
), "Inconsistent shape"

self.initial_relative_coordinates = initial_relative_coordinates

def initialize(self, number_of_samples: int):
"""This method must initialize the samples from the fully noised distribution."""
assert number_of_samples == self.initial_relative_coordinates.shape[0], "Inconsistent number of samples"
assert (
number_of_samples == self.initial_relative_coordinates.shape[0]
), "Inconsistent number of samples"
return self.initial_relative_coordinates


def get_interatomic_distances(cartesian_positions: torch.Tensor,
basis_vectors: torch.Tensor,
radial_cutoff: float = 5.0):
def get_interatomic_distances(
cartesian_positions: torch.Tensor,
basis_vectors: torch.Tensor,
radial_cutoff: float = 5.0,
):
"""Get Interatomic Distances.
Args:
Expand All @@ -61,12 +61,18 @@ def get_interatomic_distances(cartesian_positions: torch.Tensor,
Returns:
distances : all distances up to cutoff.
"""
shifted_adjacency_matrix, shifts, batch_indices = get_adj_matrix(positions=cartesian_positions,
basis_vectors=basis_vectors,
radial_cutoff=radial_cutoff)
shifted_adjacency_matrix, shifts, batch_indices = get_adj_matrix(
positions=cartesian_positions,
basis_vectors=basis_vectors,
radial_cutoff=radial_cutoff,
)

flat_positions = einops.rearrange(cartesian_positions, "b n d -> (b n) d")

displacements = flat_positions[shifted_adjacency_matrix[1]] - flat_positions[shifted_adjacency_matrix[0]] + shifts
displacements = (
flat_positions[shifted_adjacency_matrix[1]]
- flat_positions[shifted_adjacency_matrix[0]]
+ shifts
)
interatomic_distances = torch.linalg.norm(displacements, dim=1)
return interatomic_distances
8 changes: 4 additions & 4 deletions crystal_diffusion/callbacks/analysis_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from crystal_diffusion.analysis import PLOT_STYLE_PATH
from crystal_diffusion.analysis.analytic_score.utils import \
get_relative_harmonic_energy
from crystal_diffusion.callbacks.sampling_callback import \
DiffusionSamplingCallback
from crystal_diffusion.callbacks.sampling_visualization_callback import \
SamplingVisualizationCallback
from crystal_diffusion.generators.position_generator import SamplingParameters
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

Expand All @@ -22,7 +22,7 @@
plt.style.use(PLOT_STYLE_PATH)


class HarmonicEnergyDiffusionSamplingCallback(DiffusionSamplingCallback):
class HarmonicEnergyDiffusionSamplingCallback(SamplingVisualizationCallback):
"""Callback class to periodically generate samples and log their energies."""

def __init__(self, noise_parameters: NoiseParameters,
Expand Down Expand Up @@ -54,7 +54,7 @@ def _compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) ->
@staticmethod
def _plot_energy_histogram(sample_energies: np.ndarray, validation_dataset_energies: np.array,
epoch: int) -> plt.figure:
fig = DiffusionSamplingCallback._plot_energy_histogram(sample_energies, validation_dataset_energies, epoch)
fig = SamplingVisualizationCallback._plot_energy_histogram(sample_energies, validation_dataset_energies, epoch)

fig.suptitle(f'Sampling Unitless Harmonic Potential Energy Distributions\nEpoch {epoch}')
ax1 = fig.axes[0]
Expand Down
6 changes: 3 additions & 3 deletions crystal_diffusion/callbacks/callback_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from crystal_diffusion.callbacks.loss_monitoring_callback import \
instantiate_loss_monitoring_callback
from crystal_diffusion.callbacks.sampling_callback import \
instantiate_diffusion_sampling_callback
from crystal_diffusion.callbacks.sampling_visualization_callback import \
instantiate_sampling_visualization_callback
from crystal_diffusion.callbacks.standard_callbacks import (
CustomProgressBar, instantiate_early_stopping_callback,
instantiate_model_checkpoint_callbacks)

OPTIONAL_CALLBACK_DICTIONARY = dict(early_stopping=instantiate_early_stopping_callback,
model_checkpoint=instantiate_model_checkpoint_callbacks,
diffusion_sampling=instantiate_diffusion_sampling_callback,
sampling_visualization=instantiate_sampling_visualization_callback,
loss_monitoring=instantiate_loss_monitoring_callback)


Expand Down
Loading

0 comments on commit 92abb0c

Please sign in to comment.