Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor oct19 #86

Merged
merged 21 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> n

logger.info("Compute energy from Oracle")
with tempfile.TemporaryDirectory() as tmp_work_dir:
for positions, box in zip(batch_cartesian_positions.numpy(), batched_unit_cells.numpy()):
for positions, box in zip(batch_cartesian_positions.cpu().numpy(), batched_unit_cells.cpu().numpy()):
energy, forces = get_energy_and_forces_from_lammps(positions,
box,
self.atom_types,
Expand Down
22 changes: 18 additions & 4 deletions crystal_diffusion/metrics/kolmogorov_smirnov_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,37 @@
class KolmogorovSmirnovMetrics:
"""Kolmogorov Smirnov metrics."""

def __init__(self):
"""Init method."""
def __init__(self, maximum_number_of_samples: int = 1_000_000):
"""Init method.

Args:
maximum_number_of_samples : maximum number of samples that will be aggregated. This is to avoid
memory use explosion.
"""
self.reference_samples_metric = CatMetric()
self.predicted_samples_metric = CatMetric()
self.maximum_count = maximum_number_of_samples
self.reference_count = 0
self.predicted_count = 0

def register_reference_samples(self, reference_samples):
"""Register reference samples."""
self.reference_samples_metric.update(reference_samples)
if self.reference_count < self.maximum_count:
self.reference_count += len(reference_samples)
self.reference_samples_metric.update(reference_samples)

def register_predicted_samples(self, predicted_samples):
"""Register predicted samples."""
self.predicted_samples_metric.update(predicted_samples)
if self.predicted_count < self.maximum_count:
self.predicted_count += len(predicted_samples)
self.predicted_samples_metric.update(predicted_samples)

def reset(self):
"""reset."""
self.reference_samples_metric.reset()
self.predicted_samples_metric.reset()
self.reference_count = 0
self.predicted_count = 0

def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]:
"""Compute Kolmogorov Smirnov Distance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def on_validation_epoch_end(self) -> None:
logger.info(" * Registering sample energies")
self.energy_ks_metric.register_predicted_samples(sample_energies.cpu())

logger.info(" * Computing KS distance for energies")
(
ks_distance,
p_value,
Expand All @@ -400,7 +401,7 @@ def on_validation_epoch_end(self) -> None:
self.log(
"validation_ks_p_value_energy", p_value, on_step=False, on_epoch=True
)
logger.info(" * Done logging sample energies")
logger.info(" * Done logging KS distance for energies")

if self.draw_samples and self.metrics_parameters.compute_structure_factor:
logger.info(" * Computing sample distances")
Expand All @@ -413,6 +414,7 @@ def on_validation_epoch_end(self) -> None:
logger.info(" * Registering sample distances")
self.structure_ks_metric.register_predicted_samples(sample_distances.cpu())

logger.info(" * Computing KS distance for distances")
(
ks_distance,
p_value,
Expand Down
12 changes: 10 additions & 2 deletions crystal_diffusion/models/score_networks/egnn_score_network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AnyStr, Dict
from typing import AnyStr, Dict, Union

import einops
import torch
Expand Down Expand Up @@ -31,7 +31,7 @@ class EGNNScoreNetworkParameters(ScoreNetworkParameters):
message_agg: str = "mean"
n_layers: int = 4
edges: str = 'fully_connected'
radial_cutoff: float = 4.0
radial_cutoff: Union[float, None] = None
drop_duplicate_edges: bool = True


Expand Down Expand Up @@ -60,7 +60,15 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters):
self.edges = hyper_params.edges
assert self.edges in ["fully_connected", "radial_cutoff"], \
f'Edges type should be fully_connected or radial_cutoff. Got {self.edges}'

self.radial_cutoff = hyper_params.radial_cutoff

if self.edges == "fully_connected":
assert self.radial_cutoff is None, "Specifying a radial cutoff is inconsistent with edges=fully_connected."
else:
assert type(self.radial_cutoff) is float, \
"A floating point value for the radial cutoff is needed for edges=radial_cutoff."

self.drop_duplicate_edges = hyper_params.drop_duplicate_edges

self.egnn = EGNN(
Expand Down
193 changes: 193 additions & 0 deletions crystal_diffusion/sample_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Sample Diffusion.

This script is the entry point to draw samples from a pre-trained model checkpoint.
"""
import argparse
import logging
import os
import socket
from pathlib import Path
from typing import Any, AnyStr, Dict, Optional, Union

import torch

from crystal_diffusion.generators.instantiate_generator import \
instantiate_generator
from crystal_diffusion.generators.load_sampling_parameters import \
load_sampling_parameters
from crystal_diffusion.generators.position_generator import SamplingParameters
from crystal_diffusion.main_utils import load_and_backup_hyperparameters
from crystal_diffusion.models.position_diffusion_lightning_model import \
PositionDiffusionLightningModel
from crystal_diffusion.models.score_networks import ScoreNetwork
from crystal_diffusion.oracle.energies import compute_oracle_energies
from crystal_diffusion.samplers.variance_sampler import NoiseParameters
from crystal_diffusion.samples.sampling import create_batch_of_samples
from crystal_diffusion.utils.logging_utils import (get_git_hash,
setup_console_logger)

# Module-level logger, named after this module, used by every function in this script.
logger = logging.getLogger(__name__)


def main(args: Optional[Any] = None):
    """Load a diffusion model and draw samples.

    This script is meant to be called from the command line. It parses the arguments,
    prepares the output directory and the logger, reads the sampling configuration and
    then drives the creation of samples.

    Args:
        args : optional list of command line arguments, forwarded to argparse. When None,
            argparse reads the arguments of the running process.

    Returns:
        None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        required=True,
        help="config file with sampling parameters in yaml format.",
    )
    parser.add_argument(
        "--checkpoint", required=True, help="path to checkpoint model to be loaded."
    )
    parser.add_argument(
        "--output", required=True, help="path to outputs - will store files here"
    )
    parser.add_argument(
        "--device", default="cuda", help="Device to use. Defaults to cuda."
    )
    args = parser.parse_args(args)

    # Remember whether the directory was already there, but only report it once the
    # logger is configured: a record emitted before that point is likely to be dropped.
    output_already_exists = os.path.exists(args.output)
    if not output_already_exists:
        os.makedirs(args.output)

    # Very opinionated logger, which writes to the output folder.
    setup_console_logger(experiment_dir=args.output)
    if output_already_exists:
        logger.warning(f"the output directory {args.output} already exists!")

    assert os.path.exists(
        args.checkpoint
    ), f"The path {args.checkpoint} does not exist. Cannot go on."

    script_location = os.path.realpath(__file__)
    git_hash = get_git_hash(script_location)
    hostname = socket.gethostname()
    logger.info("Sampling Experiment info:")
    logger.info(f" Hostname : {hostname}")
    logger.info(f" Git Hash : {git_hash}")
    logger.info(f" Checkpoint : {args.checkpoint}")

    logger.info(f"Start Generating Samples with checkpoint {args.checkpoint}")

    hyper_params = load_and_backup_hyperparameters(
        config_file_path=args.config, output_directory=args.output
    )

    device = torch.device(args.device)
    noise_parameters, sampling_parameters = extract_and_validate_parameters(
        hyper_params
    )

    create_samples_and_write_to_disk(
        noise_parameters=noise_parameters,
        sampling_parameters=sampling_parameters,
        device=device,
        checkpoint_path=args.checkpoint,
        output_path=args.output,
    )


def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]):
    """Build the noise and sampling parameter objects from the hyper-parameters.

    The "noise" section is checked and converted first, then the "sampling" section.

    Args:
        hyper_params : Dictionary of hyper-parameters for drawing samples. It must contain
            a "noise" entry and a "sampling" entry.

    Returns:
        noise_parameters: object that defines the noise schedule
        sampling_parameters: object that defines how to draw samples, and how many.
    """
    assert "noise" in hyper_params, "The noise parameters must be defined to draw samples."
    noise_section = hyper_params["noise"]
    noise_parameters = NoiseParameters(**noise_section)

    assert "sampling" in hyper_params, "The sampling parameters must be defined to draw samples."
    sampling_section = hyper_params["sampling"]
    sampling_parameters = load_sampling_parameters(sampling_section)

    return noise_parameters, sampling_parameters


def get_sigma_normalized_score_network(
    checkpoint_path: Union[str, Path]
) -> ScoreNetwork:
    """Load a lightning model from a checkpoint and return its score network.

    Args:
        checkpoint_path : path where the checkpoint is written.

    Returns:
        sigma_normalized score network: the attribute of the same name on the model
            read from the checkpoint.
    """
    logger.info("Loading checkpoint...")
    lightning_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path)
    # The model is only used to draw samples here, so switch it to evaluation mode.
    lightning_model.eval()

    return lightning_model.sigma_normalized_score_network


def create_samples_and_write_to_disk(
    noise_parameters: NoiseParameters,
    sampling_parameters: SamplingParameters,
    device: torch.device,
    checkpoint_path: Union[str, Path],
    output_path: Union[str, Path],
):
    """Draw samples from a checkpointed model and save the results.

    The score network is read from the checkpoint, a generator is instantiated around it,
    a batch of samples is drawn, and the samples, their oracle energies and (optionally)
    the sampling trajectories are written in the output directory.

    Args:
        noise_parameters: object that defines the noise schedule
        sampling_parameters: object that defines how to draw samples, and how many.
        device: which device should be used to draw samples.
        checkpoint_path : path to checkpoint of model to be loaded.
        output_path: where the outputs should be written.

    Returns:
        None
    """
    score_network = get_sigma_normalized_score_network(checkpoint_path)

    logger.info("Instantiate generator...")
    position_generator = instantiate_generator(
        sampling_parameters=sampling_parameters,
        noise_parameters=noise_parameters,
        sigma_normalized_score_network=score_network,
    )

    logger.info("Generating samples...")
    # No gradient is needed to draw samples.
    with torch.no_grad():
        samples_batch = create_batch_of_samples(
            generator=position_generator,
            sampling_parameters=sampling_parameters,
            device=device,
        )
    logger.info("Done Generating Samples.")

    logger.info("Writing samples to disk...")
    output_directory = Path(output_path)
    samples_file_path = output_directory / "samples.pt"
    with open(samples_file_path, "wb") as samples_file:
        torch.save(samples_batch, samples_file)

    logger.info("Compute energy from Oracle...")
    sample_energies = compute_oracle_energies(samples_batch)

    logger.info("Writing energies to disk...")
    energies_file_path = output_directory / "energies.pt"
    with open(energies_file_path, "wb") as energies_file:
        torch.save(sample_energies, energies_file)

    # Trajectories are only written when recording was requested.
    if not sampling_parameters.record_samples:
        return

    logger.info("Writing sampling trajectories to disk...")
    trajectories_file_path = output_directory / "trajectories.pt"
    position_generator.sample_trajectory_recorder.write_to_pickle(trajectories_file_path)


# Script entry point: the command line arguments are parsed inside main().
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions examples/config_files/diffusion/config_diffusion_egnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ model:
normalize: True
residual: True
tanh: False
edges: fully_connected
noise:
total_time_steps: 1000
sigma_min: 0.0001
Expand Down
78 changes: 78 additions & 0 deletions experiment_analysis/dataset_analysis/dataset_covariance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Effective Dataset Variance.

The goal of this script is to compute the effective "sigma_d" of the
actual datasets, that is, the standard deviation of the displacement
from equilibrium, in fractional coordinates.
"""
import logging

import einops
import torch
from tqdm import tqdm

from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR
from crystal_diffusion.data.diffusion.data_loader import (
LammpsForDiffusionDataModule, LammpsLoaderParameters)
from crystal_diffusion.utils.basis_transformations import \
map_relative_coordinates_to_unit_cell
from crystal_diffusion.utils.logging_utils import setup_analysis_logger

logger = logging.getLogger(__name__)

# Name of the dataset to analyse; swap the commented line to select the other dataset.
dataset_name = 'si_diffusion_2x2x2'
# dataset_name = 'si_diffusion_1x1x1'

output_dir = ANALYSIS_RESULTS_DIR / "covariances"
# parents=True so that a missing analysis results directory does not abort the script.
output_dir.mkdir(parents=True, exist_ok=True)


# Dataset specific settings. The translation is added to the relative coordinates before
# they are mapped back to the unit cell.
# NOTE(review): presumably this shifts the equilibrium positions away from the cell
# boundaries so that the periodic wrapping does not split them - confirm.
if dataset_name == 'si_diffusion_1x1x1':
    max_atom = 8
    translation = torch.tensor([0.125, 0.125, 0.125])
elif dataset_name == 'si_diffusion_2x2x2':
    max_atom = 64
    translation = torch.tensor([0.0625, 0.0625, 0.0625])
else:
    # Fail here with a clear message rather than later with a NameError on max_atom.
    raise ValueError(
        f"Unsupported dataset '{dataset_name}'. "
        "Expected 'si_diffusion_1x1x1' or 'si_diffusion_2x2x2'."
    )

lammps_run_dir = DATA_DIR / dataset_name
processed_dataset_dir = lammps_run_dir / "processed"

cache_dir = lammps_run_dir / "cache"

data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom)

if __name__ == '__main__':
    # Compute the covariance matrix of the displacements from the mean positions over the
    # training set, in two passes over the data, and write it to disk.
    setup_analysis_logger()
    logger.info(f"Computing the covariance matrix for {dataset_name}")

    datamodule = LammpsForDiffusionDataModule(
        lammps_run_dir=lammps_run_dir,
        processed_dataset_dir=processed_dataset_dir,
        hyper_params=data_params,
        working_cache_dir=cache_dir,
    )
    datamodule.setup()

    # First pass: mean position of every atom. Sums and counts are accumulated so that
    # every sample carries the same weight, whatever the size of its batch; no batch has
    # to be dropped and a single-batch dataset is handled.
    sum_of_positions = None
    number_of_samples = 0
    for batch in tqdm(datamodule.train_dataloader(), "Mean"):
        x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation)
        batch_sum = x.sum(dim=0)
        sum_of_positions = batch_sum if sum_of_positions is None else sum_of_positions + batch_sum
        number_of_samples += x.shape[0]

    if number_of_samples == 0:
        raise RuntimeError(f"The training set of {dataset_name} is empty: cannot compute a covariance.")

    x0 = sum_of_positions / number_of_samples

    # Second pass: accumulate the sum over samples of the outer products of displacements.
    sum_of_outer_products = None
    number_of_displacements = 0
    for batch in tqdm(datamodule.train_dataloader(), "displacements"):
        x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation)
        number_of_displacements += x.shape[0]
        displacements = einops.rearrange(x - x0, "batch natoms space -> batch (natoms space)")
        # Same as summing the per-sample outer products over the batch, without
        # materializing a (batch, dim, dim) intermediate tensor.
        batch_outer_products = displacements.T @ displacements
        if sum_of_outer_products is None:
            sum_of_outer_products = batch_outer_products
        else:
            sum_of_outer_products = sum_of_outer_products + batch_outer_products

    covariance = sum_of_outer_products / number_of_displacements

    output_file = output_dir / f"covariance_{dataset_name}.pkl"
    logger.info(f"Writing to file {output_file}...")
    with open(output_file, 'wb') as fd:
        torch.save(covariance, fd)
Loading
Loading