Commit: Merge branch 'sampling_script' into refactor_oct19
Showing 2 changed files with 350 additions and 0 deletions.

crystal_diffusion/sample_diffusion.py (new file, +193 lines)

"""Sample Diffusion. | ||
This script is the entry point to draw samples from a pre-trained model checkpoint. | ||
""" | ||
import argparse | ||
import logging | ||
import os | ||
import socket | ||
from pathlib import Path | ||
from typing import Any, AnyStr, Dict, Optional, Union | ||
|
||
import torch | ||
|
||
from crystal_diffusion.generators.instantiate_generator import \ | ||
instantiate_generator | ||
from crystal_diffusion.generators.load_sampling_parameters import \ | ||
load_sampling_parameters | ||
from crystal_diffusion.generators.position_generator import SamplingParameters | ||
from crystal_diffusion.main_utils import load_and_backup_hyperparameters | ||
from crystal_diffusion.models.position_diffusion_lightning_model import \ | ||
PositionDiffusionLightningModel | ||
from crystal_diffusion.models.score_networks import ScoreNetwork | ||
from crystal_diffusion.oracle.energies import compute_oracle_energies | ||
from crystal_diffusion.samplers.variance_sampler import NoiseParameters | ||
from crystal_diffusion.samples.sampling import create_batch_of_samples | ||
from crystal_diffusion.utils.logging_utils import (get_git_hash, | ||
setup_console_logger) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def main(args: Optional[Any] = None): | ||
"""Load a diffusion model and draw samples. | ||
This main.py file is meant to be called using the cli. | ||
""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--config", | ||
required=True, | ||
help="config file with sampling parameters in yaml format.", | ||
) | ||
parser.add_argument( | ||
"--checkpoint", required=True, help="path to checkpoint model to be loaded." | ||
) | ||
parser.add_argument( | ||
"--output", required=True, help="path to outputs - will store files here" | ||
) | ||
    parser.add_argument(
        "--device", default="cuda", help="Device to use. Defaults to cuda."
    )
    args = parser.parse_args(args)
    if os.path.exists(args.output):
        logger.info(f"WARNING: the output directory {args.output} already exists!")
    else:
        os.makedirs(args.output)

    setup_console_logger(experiment_dir=args.output)
    assert os.path.exists(
        args.checkpoint
    ), f"The path {args.checkpoint} does not exist. Cannot go on."

    script_location = os.path.realpath(__file__)
    git_hash = get_git_hash(script_location)
    hostname = socket.gethostname()
    logger.info("Sampling Experiment info:")
    logger.info(f" Hostname : {hostname}")
    logger.info(f" Git Hash : {git_hash}")
    logger.info(f" Checkpoint : {args.checkpoint}")

    # Very opinionated logger, which writes to the output folder.
    logger.info(f"Start Generating Samples with checkpoint {args.checkpoint}")

    hyper_params = load_and_backup_hyperparameters(
        config_file_path=args.config, output_directory=args.output
    )

    device = torch.device(args.device)
    noise_parameters, sampling_parameters = extract_and_validate_parameters(
        hyper_params
    )

    create_samples_and_write_to_disk(
        noise_parameters=noise_parameters,
        sampling_parameters=sampling_parameters,
        device=device,
        checkpoint_path=args.checkpoint,
        output_path=args.output,
    )


def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]):
    """Extract and validate parameters.

    Args:
        hyper_params : Dictionary of hyper-parameters for drawing samples.

    Returns:
        noise_parameters: object that defines the noise schedule.
        sampling_parameters: object that defines how to draw samples, and how many.
    """
    assert (
        "noise" in hyper_params
    ), "The noise parameters must be defined to draw samples."
    noise_parameters = NoiseParameters(**hyper_params["noise"])

    assert (
        "sampling" in hyper_params
    ), "The sampling parameters must be defined to draw samples."
    sampling_parameters = load_sampling_parameters(hyper_params["sampling"])

    return noise_parameters, sampling_parameters


def get_sigma_normalized_score_network(
    checkpoint_path: Union[str, Path]
) -> ScoreNetwork:
    """Get sigma-normalized score network.

    Args:
        checkpoint_path : path where the checkpoint is written.

    Returns:
        sigma_normalized_score_network: read from the checkpoint.
    """
    logger.info("Loading checkpoint...")
    pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path)
    pl_model.eval()

    sigma_normalized_score_network = pl_model.sigma_normalized_score_network
    return sigma_normalized_score_network


def create_samples_and_write_to_disk(
    noise_parameters: NoiseParameters,
    sampling_parameters: SamplingParameters,
    device: torch.device,
    checkpoint_path: Union[str, Path],
    output_path: Union[str, Path],
):
    """Create samples and write to disk.

    Method that drives the creation of samples.

    Args:
        noise_parameters: object that defines the noise schedule.
        sampling_parameters: object that defines how to draw samples, and how many.
        device: which device should be used to draw samples.
        checkpoint_path : path to checkpoint of model to be loaded.
        output_path: where the outputs should be written.

    Returns:
        None
    """
    sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path)

    logger.info("Instantiate generator...")
    position_generator = instantiate_generator(
        sampling_parameters=sampling_parameters,
        noise_parameters=noise_parameters,
        sigma_normalized_score_network=sigma_normalized_score_network,
    )

    logger.info("Generating samples...")
    with torch.no_grad():
        samples_batch = create_batch_of_samples(
            generator=position_generator,
            sampling_parameters=sampling_parameters,
            device=device,
        )
    logger.info("Done Generating Samples.")

    logger.info("Writing samples to disk...")
    output_directory = Path(output_path)
    with open(output_directory / "samples.pt", "wb") as fd:
        torch.save(samples_batch, fd)

    logger.info("Compute energy from Oracle...")
    sample_energies = compute_oracle_energies(samples_batch)

    logger.info("Writing energies to disk...")
    with open(output_directory / "energies.pt", "wb") as fd:
        torch.save(sample_energies, fd)

    if sampling_parameters.record_samples:
        logger.info("Writing sampling trajectories to disk...")
        position_generator.sample_trajectory_recorder.write_to_pickle(
            output_directory / "trajectories.pt"
        )


if __name__ == "__main__":
    main()
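
For context, a minimal, hypothetical usage sketch (not part of this commit): it builds the YAML config expected by extract_and_validate_parameters, with "noise" and "sampling" sections whose field values mirror the test fixtures below, then calls main() with CLI-style arguments. The checkpoint and output paths are placeholders; the checkpoint must be a real PositionDiffusionLightningModel checkpoint for a run to succeed.

# Hypothetical usage sketch; parameter values mirror the test fixtures, paths are placeholders.
import dataclasses

import yaml

from crystal_diffusion import sample_diffusion
from crystal_diffusion.generators.predictor_corrector_position_generator import \
    PredictorCorrectorSamplingParameters
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

# Build the two required config sections: "noise" and "sampling".
config = dict(
    noise=dataclasses.asdict(NoiseParameters(total_time_steps=10)),
    sampling=dataclasses.asdict(
        PredictorCorrectorSamplingParameters(
            number_of_corrector_steps=1,
            spatial_dimension=3,
            number_of_atoms=8,
            number_of_samples=12,
            cell_dimensions=[5.1, 6.2, 7.3],
            record_samples=True,
        )
    ),
)

with open("sampling_config.yaml", "w") as fd:
    yaml.dump(config, fd)

# Invoke the entry point with CLI-style arguments, as the test does.
sample_diffusion.main(
    [
        "--config=sampling_config.yaml",
        "--checkpoint=/path/to/model.ckpt",  # placeholder: a real Lightning checkpoint
        "--output=./samples_output",  # placeholder output directory
        "--device=cuda",
    ]
)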
New test file for sample_diffusion (+157 lines)

import dataclasses

import pytest
import torch
import yaml

from crystal_diffusion import sample_diffusion
from crystal_diffusion.generators.predictor_corrector_position_generator import \
    PredictorCorrectorSamplingParameters
from crystal_diffusion.models.loss import MSELossParameters
from crystal_diffusion.models.optimizer import OptimizerParameters
from crystal_diffusion.models.position_diffusion_lightning_model import (
    PositionDiffusionLightningModel, PositionDiffusionParameters)
from crystal_diffusion.models.score_networks.mlp_score_network import \
    MLPScoreNetworkParameters
from crystal_diffusion.namespace import RELATIVE_COORDINATES
from crystal_diffusion.samplers.variance_sampler import NoiseParameters


@pytest.fixture()
def spatial_dimension():
    return 3


@pytest.fixture()
def number_of_atoms():
    return 8


@pytest.fixture()
def number_of_samples():
    return 12


@pytest.fixture()
def cell_dimensions():
    return [5.1, 6.2, 7.3]


@pytest.fixture(params=[True, False])
def record_samples(request):
    return request.param


@pytest.fixture()
def noise_parameters():
    return NoiseParameters(total_time_steps=10)


@pytest.fixture()
def sampling_parameters(
    number_of_atoms,
    spatial_dimension,
    number_of_samples,
    cell_dimensions,
    record_samples,
):
    return PredictorCorrectorSamplingParameters(
        number_of_corrector_steps=1,
        spatial_dimension=spatial_dimension,
        number_of_atoms=number_of_atoms,
        number_of_samples=number_of_samples,
        cell_dimensions=cell_dimensions,
        record_samples=record_samples,
    )


@pytest.fixture()
def sigma_normalized_score_network(number_of_atoms, noise_parameters):
    score_network_parameters = MLPScoreNetworkParameters(
        number_of_atoms=number_of_atoms,
        embedding_dimensions_size=8,
        n_hidden_dimensions=2,
        hidden_dimensions_size=16,
    )

    diffusion_params = PositionDiffusionParameters(
        score_network_parameters=score_network_parameters,
        loss_parameters=MSELossParameters(),
        optimizer_parameters=OptimizerParameters(name="adam", learning_rate=1e-3),
        scheduler_parameters=None,
        noise_parameters=noise_parameters,
        diffusion_sampling_parameters=None,
    )

    model = PositionDiffusionLightningModel(diffusion_params)
    return model.sigma_normalized_score_network


@pytest.fixture()
def config_path(tmp_path, noise_parameters, sampling_parameters):
    config_path = str(tmp_path / "test_config.yaml")

    config = dict(
        noise=dataclasses.asdict(noise_parameters),
        sampling=dataclasses.asdict(sampling_parameters),
    )

    with open(config_path, "w") as fd:
        yaml.dump(config, fd)

    return config_path


@pytest.fixture()
def checkpoint_path(tmp_path):
    path_to_checkpoint = tmp_path / "fake_checkpoint.pt"
    with open(path_to_checkpoint, "w") as fd:
        fd.write("This is a dummy checkpoint file.")
    return path_to_checkpoint


@pytest.fixture()
def output_path(tmp_path):
    output = tmp_path / "output"
    return output


@pytest.fixture()
def args(config_path, checkpoint_path, output_path):
    """Input arguments for main."""
    input_args = [
        f"--config={config_path}",
        f"--checkpoint={checkpoint_path}",
        f"--output={output_path}",
        "--device=cpu",
    ]

    return input_args


def test_sample_diffusion(
    mocker,
    args,
    sigma_normalized_score_network,
    output_path,
    number_of_samples,
    number_of_atoms,
    spatial_dimension,
    record_samples,
):
    """Smoke test: run the sampling entry point on CPU with a mocked checkpoint loader and check the files written to disk."""
    mocker.patch(
        "crystal_diffusion.sample_diffusion.get_sigma_normalized_score_network",
        return_value=sigma_normalized_score_network,
    )

    sample_diffusion.main(args)

    assert (output_path / "samples.pt").exists()
    samples = torch.load(output_path / "samples.pt")
    assert samples[RELATIVE_COORDINATES].shape == (
        number_of_samples,
        number_of_atoms,
        spatial_dimension,
    )

    assert (output_path / "trajectories.pt").exists() == record_samples
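
As a follow-up, a short hypothetical sketch (not part of this commit) of how the files written by the script can be inspected after a run: samples.pt holds the sample batch keyed by RELATIVE_COORDINATES, as asserted in the test above, and energies.pt holds the oracle energies; both were written with torch.save. The output directory path is a placeholder.

# Hypothetical inspection sketch; the output directory is a placeholder.
from pathlib import Path

import torch

from crystal_diffusion.namespace import RELATIVE_COORDINATES

output_directory = Path("./samples_output")  # placeholder: the --output directory of a previous run

samples_batch = torch.load(output_directory / "samples.pt")
energies = torch.load(output_directory / "energies.pt")

# Expected shape: (number_of_samples, number_of_atoms, spatial_dimension).
print(samples_batch[RELATIVE_COORDINATES].shape)
print(energies)  # format determined by compute_oracle_energies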