Merge pull request #86 from mila-iqia/refactor_oct19
Refactor oct19
Showing 14 changed files with 1,224 additions and 11 deletions.
@@ -0,0 +1,193 @@
"""Sample Diffusion. | ||
This script is the entry point to draw samples from a pre-trained model checkpoint. | ||
""" | ||
import argparse | ||
import logging | ||
import os | ||
import socket | ||
from pathlib import Path | ||
from typing import Any, AnyStr, Dict, Optional, Union | ||
|
||
import torch | ||
|
||
from crystal_diffusion.generators.instantiate_generator import \ | ||
instantiate_generator | ||
from crystal_diffusion.generators.load_sampling_parameters import \ | ||
load_sampling_parameters | ||
from crystal_diffusion.generators.position_generator import SamplingParameters | ||
from crystal_diffusion.main_utils import load_and_backup_hyperparameters | ||
from crystal_diffusion.models.position_diffusion_lightning_model import \ | ||
PositionDiffusionLightningModel | ||
from crystal_diffusion.models.score_networks import ScoreNetwork | ||
from crystal_diffusion.oracle.energies import compute_oracle_energies | ||
from crystal_diffusion.samplers.variance_sampler import NoiseParameters | ||
from crystal_diffusion.samples.sampling import create_batch_of_samples | ||
from crystal_diffusion.utils.logging_utils import (get_git_hash, | ||
setup_console_logger) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||


def main(args: Optional[Any] = None):
    """Load a diffusion model and draw samples.

    This script is meant to be called from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        required=True,
        help="config file with sampling parameters in yaml format.",
    )
    parser.add_argument(
        "--checkpoint", required=True, help="path to the model checkpoint to be loaded."
    )
    parser.add_argument(
        "--output", required=True, help="path to the output directory where files will be written."
    )
    parser.add_argument(
        "--device", default="cuda", help="Device to use. Defaults to cuda."
    )
    args = parser.parse_args(args)
    if os.path.exists(args.output):
        logger.warning(f"The output directory {args.output} already exists!")
    else:
        os.makedirs(args.output)

    # Very opinionated logger, which writes to the output folder.
    setup_console_logger(experiment_dir=args.output)
    assert os.path.exists(
        args.checkpoint
    ), f"The path {args.checkpoint} does not exist. Cannot go on."

    script_location = os.path.realpath(__file__)
    git_hash = get_git_hash(script_location)
    hostname = socket.gethostname()
    logger.info("Sampling Experiment info:")
    logger.info(f"  Hostname : {hostname}")
    logger.info(f"  Git Hash : {git_hash}")
    logger.info(f"  Checkpoint : {args.checkpoint}")

    logger.info(f"Start generating samples with checkpoint {args.checkpoint}")

    hyper_params = load_and_backup_hyperparameters(
        config_file_path=args.config, output_directory=args.output
    )

    device = torch.device(args.device)
    noise_parameters, sampling_parameters = extract_and_validate_parameters(
        hyper_params
    )

    create_samples_and_write_to_disk(
        noise_parameters=noise_parameters,
        sampling_parameters=sampling_parameters,
        device=device,
        checkpoint_path=args.checkpoint,
        output_path=args.output,
    )


def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]):
    """Extract and validate parameters.

    Args:
        hyper_params : Dictionary of hyper-parameters for drawing samples.

    Returns:
        noise_parameters: object that defines the noise schedule.
        sampling_parameters: object that defines how to draw samples, and how many.
    """
    assert (
        "noise" in hyper_params
    ), "The noise parameters must be defined to draw samples."
    noise_parameters = NoiseParameters(**hyper_params["noise"])

    assert (
        "sampling" in hyper_params
    ), "The sampling parameters must be defined to draw samples."
    sampling_parameters = load_sampling_parameters(hyper_params["sampling"])

    return noise_parameters, sampling_parameters


def get_sigma_normalized_score_network(
    checkpoint_path: Union[str, Path]
) -> ScoreNetwork:
    """Get the sigma-normalized score network.

    Args:
        checkpoint_path : path where the checkpoint is written.

    Returns:
        sigma_normalized_score_network: the score network read from the checkpoint.
    """
    logger.info("Loading checkpoint...")
    pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path)
    pl_model.eval()

    sigma_normalized_score_network = pl_model.sigma_normalized_score_network
    return sigma_normalized_score_network


def create_samples_and_write_to_disk(
    noise_parameters: NoiseParameters,
    sampling_parameters: SamplingParameters,
    device: torch.device,
    checkpoint_path: Union[str, Path],
    output_path: Union[str, Path],
):
    """Create samples and write them to disk.

    Method that drives the creation of samples.

    Args:
        noise_parameters: object that defines the noise schedule.
        sampling_parameters: object that defines how to draw samples, and how many.
        device: which device should be used to draw samples.
        checkpoint_path : path to the checkpoint of the model to be loaded.
        output_path: where the outputs should be written.

    Returns:
        None
    """
    sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path)

    logger.info("Instantiate generator...")
    position_generator = instantiate_generator(
        sampling_parameters=sampling_parameters,
        noise_parameters=noise_parameters,
        sigma_normalized_score_network=sigma_normalized_score_network,
    )

    logger.info("Generating samples...")
    with torch.no_grad():
        samples_batch = create_batch_of_samples(
            generator=position_generator,
            sampling_parameters=sampling_parameters,
            device=device,
        )
    logger.info("Done generating samples.")

    logger.info("Writing samples to disk...")
    output_directory = Path(output_path)
    with open(output_directory / "samples.pt", "wb") as fd:
        torch.save(samples_batch, fd)

    logger.info("Computing energy from the oracle...")
    sample_energies = compute_oracle_energies(samples_batch)

    logger.info("Writing energies to disk...")
    with open(output_directory / "energies.pt", "wb") as fd:
        torch.save(sample_energies, fd)

    if sampling_parameters.record_samples:
        logger.info("Writing sampling trajectories to disk...")
        position_generator.sample_trajectory_recorder.write_to_pickle(
            output_directory / "trajectories.pt"
        )


if __name__ == "__main__":
    main()
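
As a usage sketch: since parse_args accepts an explicit argument list, the script can also be driven programmatically through main(), and its outputs read back with torch.load. The import path and file names below are hypothetical (the diff does not show where this script lives), and the YAML config must contain the noise and sampling sections validated above.

import torch

# Hypothetical module path and file names -- this diff does not show them.
from crystal_diffusion.sample_diffusion import main

main([
    "--config", "sampling_config.yaml",  # must define the 'noise' and 'sampling' sections
    "--checkpoint", "model.ckpt",
    "--output", "samples_output",
    "--device", "cpu",
])

# The script writes its artifacts with torch.save, so torch.load reads them back.
samples_batch = torch.load("samples_output/samples.pt")
sample_energies = torch.load("samples_output/energies.pt")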
78 changes: 78 additions & 0 deletions
experiment_analysis/dataset_analysis/dataset_covariance.py
@@ -0,0 +1,78 @@
"""Effective Dataset Variance. | ||
The goal of this script is to compute the effective "sigma_d" of the | ||
actual datasets, that is, the standard deviation of the displacement | ||
from equilibrium, in fractional coordinates. | ||
""" | ||
import logging | ||
|
||
import einops | ||
import torch | ||
from tqdm import tqdm | ||
|
||
from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR | ||
from crystal_diffusion.data.diffusion.data_loader import ( | ||
LammpsForDiffusionDataModule, LammpsLoaderParameters) | ||
from crystal_diffusion.utils.basis_transformations import \ | ||
map_relative_coordinates_to_unit_cell | ||
from crystal_diffusion.utils.logging_utils import setup_analysis_logger | ||
|
||
logger = logging.getLogger(__name__) | ||
dataset_name = 'si_diffusion_2x2x2' | ||
# dataset_name = 'si_diffusion_1x1x1' | ||
|
||
output_dir = ANALYSIS_RESULTS_DIR / "covariances" | ||
output_dir.mkdir(exist_ok=True) | ||
|
||
|
||
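
# The constant translation below shifts every equilibrium site away from the
# cell boundary, so displacements do not wrap around when mapped back to the
# unit cell (assumed intent; the values correspond to 1/8 and 1/16 of the
# cell edge for the 1x1x1 and 2x2x2 silicon supercells, respectively).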
if dataset_name == 'si_diffusion_1x1x1':
    max_atom = 8
    translation = torch.tensor([0.125, 0.125, 0.125])
elif dataset_name == 'si_diffusion_2x2x2':
    max_atom = 64
    translation = torch.tensor([0.0625, 0.0625, 0.0625])

lammps_run_dir = DATA_DIR / dataset_name
processed_dataset_dir = lammps_run_dir / "processed"

cache_dir = lammps_run_dir / "cache"

data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom)

if __name__ == '__main__':
    setup_analysis_logger()
    logger.info(f"Computing the covariance matrix for {dataset_name}")

    datamodule = LammpsForDiffusionDataModule(
        lammps_run_dir=lammps_run_dir,
        processed_dataset_dir=processed_dataset_dir,
        hyper_params=data_params,
        working_cache_dir=cache_dir,
    )
    datamodule.setup()

    train_dataset = datamodule.train_dataset

    list_means = []
    for batch in tqdm(datamodule.train_dataloader(), "Mean"):
        x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation)
        list_means.append(x.mean(dim=0))

    # Drop the last batch, which might not have dimension batch_size.
    x0 = torch.stack(list_means[:-1]).mean(dim=0)

    list_covariances = []
    list_sizes = []
    for batch in tqdm(datamodule.train_dataloader(), "displacements"):
        x = map_relative_coordinates_to_unit_cell(batch['relative_coordinates'] + translation)
        list_sizes.append(x.shape[0])
        displacements = einops.rearrange(x - x0, "batch natoms space -> batch (natoms space)")
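        # Sum of outer products d d^T over the batch: an unnormalized
        # (natoms * space) x (natoms * space) contribution to the covariance.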
        covariance = (displacements[:, None, :] * displacements[:, :, None]).sum(dim=0)
        list_covariances.append(covariance)

    covariance = torch.stack(list_covariances).sum(dim=0) / sum(list_sizes)

    output_file = output_dir / f"covariance_{dataset_name}.pkl"
    logger.info(f"Writing to file {output_file}...")
    with open(output_file, 'wb') as fd:
        torch.save(covariance, fd)
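
A minimal sketch of recovering an effective scalar sigma_d downstream, assuming sigma_d is taken as the root-mean-square of the diagonal of this covariance matrix (one plausible reading of the docstring, not confirmed by this diff). Note that the file is written with torch.save despite its .pkl extension, so torch.load is the matching reader:

import torch

# Path assumes ANALYSIS_RESULTS_DIR resolves to "analysis_results"; adjust as needed.
covariance = torch.load("analysis_results/covariances/covariance_si_diffusion_2x2x2.pkl")
sigma_d = covariance.diagonal().mean().sqrt()  # RMS displacement per degree of freedom
print(f"Effective sigma_d: {sigma_d:.4f}")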