Merge pull request #86 from mila-iqia/refactor_oct19
Refactor oct19
rousseab authored Oct 21, 2024
2 parents 6ce1dbd + 0e10484 commit 3389c89
Showing 14 changed files with 1,224 additions and 11 deletions.
@@ -49,7 +49,7 @@ def compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> n

logger.info("Compute energy from Oracle")
with tempfile.TemporaryDirectory() as tmp_work_dir:
- for positions, box in zip(batch_cartesian_positions.numpy(), batched_unit_cells.numpy()):
+ for positions, box in zip(batch_cartesian_positions.cpu().numpy(), batched_unit_cells.cpu().numpy()):
energy, forces = get_energy_and_forces_from_lammps(positions,
box,
self.atom_types,
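For context on the one-line fix above: torch.Tensor.numpy() only works on CPU tensors and raises a TypeError for CUDA tensors, while .cpu() returns the tensor unchanged when it already lives on the host. A minimal sketch of the failure mode, using a hypothetical tensor rather than the repository's data:

import torch

positions = torch.rand(8, 3)
if torch.cuda.is_available():
    positions = positions.to("cuda")
    # positions.numpy() would now raise:
    #   TypeError: can't convert cuda:0 device type tensor to numpy.
# Copying to host memory first always works; .cpu() is a no-op
# for a tensor that is already on the CPU.
host_array = positions.cpu().numpy()
print(host_array.shape)  # (8, 3)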
22 changes: 18 additions & 4 deletions crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py
@@ -7,23 +7,37 @@
class KolmogorovSmirnovMetrics:
"""Kolmogorov Smirnov metrics."""

- def __init__(self):
-     """Init method."""
+ def __init__(self, maximum_number_of_samples: int = 1_000_000):
+     """Init method.
+
+     Args:
+         maximum_number_of_samples : maximum number of samples that will be aggregated,
+             to avoid runaway memory use.
+     """
self.reference_samples_metric = CatMetric()
self.predicted_samples_metric = CatMetric()
+ self.maximum_count = maximum_number_of_samples
+ self.reference_count = 0
+ self.predicted_count = 0

def register_reference_samples(self, reference_samples):
"""Register reference samples."""
- self.reference_samples_metric.update(reference_samples)
+ if self.reference_count < self.maximum_count:
+     self.reference_count += len(reference_samples)
+     self.reference_samples_metric.update(reference_samples)

def register_predicted_samples(self, predicted_samples):
"""Register predicted samples."""
- self.predicted_samples_metric.update(predicted_samples)
+ if self.predicted_count < self.maximum_count:
+     self.predicted_count += len(predicted_samples)
+     self.predicted_samples_metric.update(predicted_samples)

def reset(self):
"""reset."""
self.reference_samples_metric.reset()
self.predicted_samples_metric.reset()
+ self.reference_count = 0
+ self.predicted_count = 0

def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]:
"""Compute Kolmogorov Smirnov Distance.
@@ -387,6 +387,7 @@ def on_validation_epoch_end(self) -> None:
logger.info(" * Registering sample energies")
self.energy_ks_metric.register_predicted_samples(sample_energies.cpu())

logger.info(" * Computing KS distance for energies")
(
ks_distance,
p_value,
@@ -400,7 +401,7 @@ def on_validation_epoch_end(self) -> None:
self.log(
"validation_ks_p_value_energy", p_value, on_step=False, on_epoch=True
)
logger.info(" * Done logging sample energies")
logger.info(" * Done logging KS distance for energies")

if self.draw_samples and self.metrics_parameters.compute_structure_factor:
logger.info(" * Computing sample distances")
@@ -413,6 +414,7 @@ def on_validation_epoch_end(self) -> None:
logger.info(" * Registering sample distances")
self.structure_ks_metric.register_predicted_samples(sample_distances.cpu())

logger.info(" * Computing KS distance for distances")
(
ks_distance,
p_value,
12 changes: 10 additions & 2 deletions crystal_diffusion/models/score_networks/egnn_score_network.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
- from typing import AnyStr, Dict
+ from typing import AnyStr, Dict, Union

import einops
import torch
@@ -31,7 +31,7 @@ class EGNNScoreNetworkParameters(ScoreNetworkParameters):
message_agg: str = "mean"
n_layers: int = 4
edges: str = 'fully_connected'
- radial_cutoff: float = 4.0
+ radial_cutoff: Union[float, None] = None
drop_duplicate_edges: bool = True


@@ -60,7 +60,15 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters):
self.edges = hyper_params.edges
assert self.edges in ["fully_connected", "radial_cutoff"], \
f'Edges type should be fully_connected or radial_cutoff. Got {self.edges}'

self.radial_cutoff = hyper_params.radial_cutoff
+
+ if self.edges == "fully_connected":
+     assert self.radial_cutoff is None, "Specifying a radial cutoff is inconsistent with edges=fully_connected."
+ else:
+     assert type(self.radial_cutoff) is float, \
+         "A floating point value for the radial cutoff is needed for edges=radial_cutoff."
+
self.drop_duplicate_edges = hyper_params.drop_duplicate_edges

self.egnn = EGNN(
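The net effect of this hunk is that edges and radial_cutoff must now be chosen together consistently. A sketch of the two valid combinations; any other constructor arguments of the dataclass are elided, and the score-network class name is inferred from its parameters dataclass (the asserts live in the network's __init__, not in the dataclass):

from crystal_diffusion.models.score_networks.egnn_score_network import (
    EGNNScoreNetwork, EGNNScoreNetworkParameters)

# Fully connected graph: radial_cutoff must stay at its new default of None.
params_fc = EGNNScoreNetworkParameters(edges="fully_connected")

# Radial-cutoff graph: a float cutoff is now mandatory.
params_rc = EGNNScoreNetworkParameters(edges="radial_cutoff", radial_cutoff=4.0)

# Mixing the two, e.g.
#   EGNNScoreNetworkParameters(edges="fully_connected", radial_cutoff=4.0)
# now fails fast when the score network is constructed, instead of the
# cutoff being silently ignored.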
193 changes: 193 additions & 0 deletions crystal_diffusion/sample_diffusion.py
@@ -0,0 +1,193 @@
"""Sample Diffusion.
This script is the entry point to draw samples from a pre-trained model checkpoint.
"""
import argparse
import logging
import os
import socket
from pathlib import Path
from typing import Any, AnyStr, Dict, Optional, Union

import torch

from crystal_diffusion.generators.instantiate_generator import \
instantiate_generator
from crystal_diffusion.generators.load_sampling_parameters import \
load_sampling_parameters
from crystal_diffusion.generators.position_generator import SamplingParameters
from crystal_diffusion.main_utils import load_and_backup_hyperparameters
from crystal_diffusion.models.position_diffusion_lightning_model import \
PositionDiffusionLightningModel
from crystal_diffusion.models.score_networks import ScoreNetwork
from crystal_diffusion.oracle.energies import compute_oracle_energies
from crystal_diffusion.samplers.variance_sampler import NoiseParameters
from crystal_diffusion.samples.sampling import create_batch_of_samples
from crystal_diffusion.utils.logging_utils import (get_git_hash,
setup_console_logger)

logger = logging.getLogger(__name__)


def main(args: Optional[Any] = None):
"""Load a diffusion model and draw samples.
This script is meant to be called from the command line.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
required=True,
help="config file with sampling parameters in yaml format.",
)
parser.add_argument(
"--checkpoint", required=True, help="path to checkpoint model to be loaded."
)
parser.add_argument(
"--output", required=True, help="path to outputs - will store files here"
)
parser.add_argument(
"--device", default="cuda", help="Device to use. Defaults to cuda."
)
args = parser.parse_args(args)
if os.path.exists(args.output):
logger.info(f"WARNING: the output directory {args.output} already exists!")
else:
os.makedirs(args.output)

# Very opinionated logger, which writes to the output folder.
setup_console_logger(experiment_dir=args.output)
assert os.path.exists(
args.checkpoint
), f"The path {args.checkpoint} does not exist. Cannot go on."

script_location = os.path.realpath(__file__)
git_hash = get_git_hash(script_location)
hostname = socket.gethostname()
logger.info("Sampling Experiment info:")
logger.info(f" Hostname : {hostname}")
logger.info(f" Git Hash : {git_hash}")
logger.info(f" Checkpoint : {args.checkpoint}")

logger.info(f"Start Generating Samples with checkpoint {args.checkpoint}")

hyper_params = load_and_backup_hyperparameters(
config_file_path=args.config, output_directory=args.output
)

device = torch.device(args.device)
noise_parameters, sampling_parameters = extract_and_validate_parameters(
hyper_params
)

create_samples_and_write_to_disk(
noise_parameters=noise_parameters,
sampling_parameters=sampling_parameters,
device=device,
checkpoint_path=args.checkpoint,
output_path=args.output,
)


def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]):
"""Extract and validate parameters.
Args:
hyper_params : Dictionary of hyper-parameters for drawing samples.
Returns:
noise_parameters: object that defines the noise schedule
sampling_parameters: object that defines how to draw samples, and how many.
"""
assert (
"noise" in hyper_params
), "The noise parameters must be defined to draw samples."
noise_parameters = NoiseParameters(**hyper_params["noise"])

assert (
"sampling" in hyper_params
), "The sampling parameters must be defined to draw samples."
sampling_parameters = load_sampling_parameters(hyper_params["sampling"])

return noise_parameters, sampling_parameters


def get_sigma_normalized_score_network(
checkpoint_path: Union[str, Path]
) -> ScoreNetwork:
"""Get sigma-normalized score network.
Args:
checkpoint_path : path where the checkpoint is written.
Returns:
sigma_normalized score network: read from the checkpoint.
"""
logger.info("Loading checkpoint...")
pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path)
pl_model.eval()

sigma_normalized_score_network = pl_model.sigma_normalized_score_network
return sigma_normalized_score_network


def create_samples_and_write_to_disk(
noise_parameters: NoiseParameters,
sampling_parameters: SamplingParameters,
device: torch.device,
checkpoint_path: Union[str, Path],
output_path: Union[str, Path],
):
"""Create Samples and write to disk.
Method that drives the creation of samples.
Args:
noise_parameters: object that defines the noise schedule
sampling_parameters: object that defines how to draw samples, and how many.
device: which device should be used to draw samples.
checkpoint_path : path to checkpoint of model to be loaded.
output_path: where the outputs should be written.
Returns:
None
"""
sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path)

logger.info("Instantiate generator...")
position_generator = instantiate_generator(
sampling_parameters=sampling_parameters,
noise_parameters=noise_parameters,
sigma_normalized_score_network=sigma_normalized_score_network,
)

logger.info("Generating samples...")
with torch.no_grad():
samples_batch = create_batch_of_samples(
generator=position_generator,
sampling_parameters=sampling_parameters,
device=device,
)
logger.info("Done Generating Samples.")

logger.info("Writing samples to disk...")
output_directory = Path(output_path)
with open(output_directory / "samples.pt", "wb") as fd:
torch.save(samples_batch, fd)

logger.info("Compute energy from Oracle...")
sample_energies = compute_oracle_energies(samples_batch)

logger.info("Writing energies to disk...")
with open(output_directory / "energies.pt", "wb") as fd:
torch.save(sample_energies, fd)

if sampling_parameters.record_samples:
logger.info("Writing sampling trajectories to disk...")
position_generator.sample_trajectory_recorder.write_to_pickle(
output_directory / "trajectories.pt"
)


if __name__ == "__main__":
main()
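Usage sketch for the new entry point. main() hands its argument straight to argparse, so the script can be exercised from the shell or programmatically; all paths below are hypothetical:

from crystal_diffusion.sample_diffusion import main

# Shell equivalent (one line):
#   python crystal_diffusion/sample_diffusion.py --config sampling.yaml
#       --checkpoint last.ckpt --output samples --device cpu
main([
    "--config", "sampling.yaml",   # must define "noise" and "sampling" sections
    "--checkpoint", "last.ckpt",   # pre-trained PositionDiffusionLightningModel
    "--output", "samples",         # created if it does not already exist
    "--device", "cpu",             # defaults to "cuda"
])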
1 change: 1 addition & 0 deletions examples/config_files/diffusion/config_diffusion_egnn.yaml
@@ -33,6 +33,7 @@ model:
normalize: True
residual: True
tanh: False
+ edges: fully_connected
noise:
total_time_steps: 1000
sigma_min: 0.0001
78 changes: 78 additions & 0 deletions experiment_analysis/dataset_analysis/dataset_covariance.py
@@ -0,0 +1,78 @@
"""Effective Dataset Variance.
The goal of this script is to compute the effective "sigma_d" of the
actual datasets, that is, the standard deviation of the displacement
from equilibrium, in fractional coordinates.
"""
import logging

import einops
import torch
from tqdm import tqdm

from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR
from crystal_diffusion.data.diffusion.data_loader import (
LammpsForDiffusionDataModule, LammpsLoaderParameters)
from crystal_diffusion.utils.basis_transformations import \
map_relative_coordinates_to_unit_cell
from crystal_diffusion.utils.logging_utils import setup_analysis_logger

logger = logging.getLogger(__name__)
dataset_name = 'si_diffusion_2x2x2'
# dataset_name = 'si_diffusion_1x1x1'

output_dir = ANALYSIS_RESULTS_DIR / "covariances"
output_dir.mkdir(exist_ok=True)


if dataset_name == 'si_diffusion_1x1x1':
max_atom = 8
translation = torch.tensor([0.125, 0.125, 0.125])
elif dataset_name == 'si_diffusion_2x2x2':
max_atom = 64
translation = torch.tensor([0.0625, 0.0625, 0.0625])

lammps_run_dir = DATA_DIR / dataset_name
processed_dataset_dir = lammps_run_dir / "processed"

cache_dir = lammps_run_dir / "cache"

data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom)

if __name__ == '__main__':
setup_analysis_logger()
logger.info(f"Computing the covariance matrix for {dataset_name}")

datamodule = LammpsForDiffusionDataModule(
lammps_run_dir=lammps_run_dir,
processed_dataset_dir=processed_dataset_dir,
hyper_params=data_params,
working_cache_dir=cache_dir,
)
datamodule.setup()

train_dataset = datamodule.train_dataset

list_means = []
for batch in tqdm(datamodule.train_dataloader(), "Mean"):
x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation)
list_means.append(x.mean(dim=0))

# Drop the last batch, which might not have dimension batch_size
x0 = torch.stack(list_means[:-1]).mean(dim=0)

list_covariances = []
list_sizes = []
for batch in tqdm(datamodule.train_dataloader(), "displacements"):
x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation)
list_sizes.append(x.shape[0])
displacements = einops.rearrange(x - x0, "batch natoms space -> batch (natoms space)")
covariance = (displacements[:, None, :] * displacements[:, :, None]).sum(dim=0)
list_covariances.append(covariance)

covariance = torch.stack(list_covariances).sum(dim=0) / sum(list_sizes)

output_file = output_dir / f"covariance_{dataset_name}.pkl"
logger.info(f"Writing to file {output_file}...")
with open(output_file, 'wb') as fd:
torch.save(covariance, fd)
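In equations, the two loops implement the empirical covariance of the flattened fractional displacements. Writing \bar{x}_b for the per-batch means (the last, possibly smaller, batch is dropped so that averaging batch means gives the grand mean), the script computes

\bar{x} = \frac{1}{B} \sum_{b=1}^{B} \bar{x}_b,
\qquad
C = \frac{1}{N} \sum_{n=1}^{N} (x_n - \bar{x})(x_n - \bar{x})^{\top},
\qquad
x_n \in \mathbb{R}^{3 n_{\mathrm{atoms}}}.

One natural way to read off the effective \sigma_d of the docstring is then \sigma_d^2 = \operatorname{tr} C / (3 n_{\mathrm{atoms}}), although the script stores the full matrix C.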