Commit

Merge branch 'sampling_script' into refactor_oct19
rousseab committed Oct 19, 2024
2 parents fe7e869 + 1dcb211 commit 0aaa0c8
Showing 2 changed files with 350 additions and 0 deletions.
193 changes: 193 additions & 0 deletions crystal_diffusion/sample_diffusion.py
@@ -0,0 +1,193 @@
"""Sample Diffusion.
This script is the entry point to draw samples from a pre-trained model checkpoint.
"""
import argparse
import logging
import os
import socket
from pathlib import Path
from typing import Any, AnyStr, Dict, Optional, Union

import torch

from crystal_diffusion.generators.instantiate_generator import \
    instantiate_generator
from crystal_diffusion.generators.load_sampling_parameters import \
    load_sampling_parameters
from crystal_diffusion.generators.position_generator import SamplingParameters
from crystal_diffusion.main_utils import load_and_backup_hyperparameters
from crystal_diffusion.models.position_diffusion_lightning_model import \
    PositionDiffusionLightningModel
from crystal_diffusion.models.score_networks import ScoreNetwork
from crystal_diffusion.oracle.energies import compute_oracle_energies
from crystal_diffusion.samplers.variance_sampler import NoiseParameters
from crystal_diffusion.samples.sampling import create_batch_of_samples
from crystal_diffusion.utils.logging_utils import (get_git_hash,
                                                   setup_console_logger)

logger = logging.getLogger(__name__)


def main(args: Optional[Any] = None):
    """Load a diffusion model and draw samples.

    This script is meant to be invoked from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        required=True,
        help="config file with sampling parameters in yaml format.",
    )
    parser.add_argument(
        "--checkpoint", required=True, help="path to checkpoint model to be loaded."
    )
    parser.add_argument(
        "--output", required=True, help="path to outputs - will store files here"
    )
    parser.add_argument(
        "--device", default="cuda", help="Device to use. Defaults to cuda."
    )
    args = parser.parse_args(args)
    if os.path.exists(args.output):
        logger.warning(f"The output directory {args.output} already exists!")
    else:
        os.makedirs(args.output)

    # Very opinionated logger, which writes to the output folder.
    setup_console_logger(experiment_dir=args.output)
    assert os.path.exists(
        args.checkpoint
    ), f"The path {args.checkpoint} does not exist. Cannot go on."

    script_location = os.path.realpath(__file__)
    git_hash = get_git_hash(script_location)
    hostname = socket.gethostname()
    logger.info("Sampling Experiment info:")
    logger.info(f" Hostname : {hostname}")
    logger.info(f" Git Hash : {git_hash}")
    logger.info(f" Checkpoint : {args.checkpoint}")

    logger.info(f"Start Generating Samples with checkpoint {args.checkpoint}")

    hyper_params = load_and_backup_hyperparameters(
        config_file_path=args.config, output_directory=args.output
    )

    device = torch.device(args.device)
    noise_parameters, sampling_parameters = extract_and_validate_parameters(
        hyper_params
    )

    create_samples_and_write_to_disk(
        noise_parameters=noise_parameters,
        sampling_parameters=sampling_parameters,
        device=device,
        checkpoint_path=args.checkpoint,
        output_path=args.output,
    )


def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]):
    """Extract and validate parameters.

    Args:
        hyper_params : Dictionary of hyper-parameters for drawing samples.

    Returns:
        noise_parameters: object that defines the noise schedule
        sampling_parameters: object that defines how to draw samples, and how many.
    """
    assert (
        "noise" in hyper_params
    ), "The noise parameters must be defined to draw samples."
    noise_parameters = NoiseParameters(**hyper_params["noise"])

    assert (
        "sampling" in hyper_params
    ), "The sampling parameters must be defined to draw samples."
    sampling_parameters = load_sampling_parameters(hyper_params["sampling"])

    return noise_parameters, sampling_parameters

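# For reference, a config file with the two required sections could look roughly
# like the YAML sketch below; the values mirror the fixtures in
# tests/test_sample_diffusion.py, and the exact set of sampling keys depends on
# the sampling parameters dataclass in use:
#
#   noise:
#     total_time_steps: 10
#   sampling:
#     number_of_corrector_steps: 1
#     spatial_dimension: 3
#     number_of_atoms: 8
#     number_of_samples: 12
#     cell_dimensions: [5.1, 6.2, 7.3]
#     record_samples: true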

def get_sigma_normalized_score_network(
    checkpoint_path: Union[str, Path]
) -> ScoreNetwork:
    """Get sigma-normalized score network.

    Args:
        checkpoint_path : path where the checkpoint is written.

    Returns:
        sigma_normalized score network: read from the checkpoint.
    """
    logger.info("Loading checkpoint...")
    pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path)
    pl_model.eval()

    sigma_normalized_score_network = pl_model.sigma_normalized_score_network
    return sigma_normalized_score_network


def create_samples_and_write_to_disk(
    noise_parameters: NoiseParameters,
    sampling_parameters: SamplingParameters,
    device: torch.device,
    checkpoint_path: Union[str, Path],
    output_path: Union[str, Path],
):
    """Create Samples and write to disk.

    Method that drives the creation of samples.

    Args:
        noise_parameters: object that defines the noise schedule
        sampling_parameters: object that defines how to draw samples, and how many.
        device: which device should be used to draw samples.
        checkpoint_path : path to checkpoint of model to be loaded.
        output_path: where the outputs should be written.

    Returns:
        None
    """
    sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path)

    logger.info("Instantiate generator...")
    position_generator = instantiate_generator(
        sampling_parameters=sampling_parameters,
        noise_parameters=noise_parameters,
        sigma_normalized_score_network=sigma_normalized_score_network,
    )

    logger.info("Generating samples...")
    with torch.no_grad():
        samples_batch = create_batch_of_samples(
            generator=position_generator,
            sampling_parameters=sampling_parameters,
            device=device,
        )
    logger.info("Done Generating Samples.")

    logger.info("Writing samples to disk...")
    output_directory = Path(output_path)
    with open(output_directory / "samples.pt", "wb") as fd:
        torch.save(samples_batch, fd)

    logger.info("Compute energy from Oracle...")
    sample_energies = compute_oracle_energies(samples_batch)

    logger.info("Writing energies to disk...")
    with open(output_directory / "energies.pt", "wb") as fd:
        torch.save(sample_energies, fd)

    if sampling_parameters.record_samples:
        logger.info("Writing sampling trajectories to disk...")
        position_generator.sample_trajectory_recorder.write_to_pickle(
            output_directory / "trajectories.pt"
        )


if __name__ == "__main__":
    main()
157 changes: 157 additions & 0 deletions tests/test_sample_diffusion.py
@@ -0,0 +1,157 @@
import dataclasses

import pytest
import torch
import yaml

from crystal_diffusion import sample_diffusion
from crystal_diffusion.generators.predictor_corrector_position_generator import \
    PredictorCorrectorSamplingParameters
from crystal_diffusion.models.loss import MSELossParameters
from crystal_diffusion.models.optimizer import OptimizerParameters
from crystal_diffusion.models.position_diffusion_lightning_model import (
    PositionDiffusionLightningModel, PositionDiffusionParameters)
from crystal_diffusion.models.score_networks.mlp_score_network import \
    MLPScoreNetworkParameters
from crystal_diffusion.namespace import RELATIVE_COORDINATES
from crystal_diffusion.samplers.variance_sampler import NoiseParameters


@pytest.fixture()
def spatial_dimension():
    return 3


@pytest.fixture()
def number_of_atoms():
    return 8


@pytest.fixture()
def number_of_samples():
    return 12


@pytest.fixture()
def cell_dimensions():
    return [5.1, 6.2, 7.3]


@pytest.fixture(params=[True, False])
def record_samples(request):
    return request.param


@pytest.fixture()
def noise_parameters():
    return NoiseParameters(total_time_steps=10)


@pytest.fixture()
def sampling_parameters(
    number_of_atoms,
    spatial_dimension,
    number_of_samples,
    cell_dimensions,
    record_samples,
):
    return PredictorCorrectorSamplingParameters(
        number_of_corrector_steps=1,
        spatial_dimension=spatial_dimension,
        number_of_atoms=number_of_atoms,
        number_of_samples=number_of_samples,
        cell_dimensions=cell_dimensions,
        record_samples=record_samples,
    )


@pytest.fixture()
def sigma_normalized_score_network(number_of_atoms, noise_parameters):
    score_network_parameters = MLPScoreNetworkParameters(
        number_of_atoms=number_of_atoms,
        embedding_dimensions_size=8,
        n_hidden_dimensions=2,
        hidden_dimensions_size=16,
    )

    diffusion_params = PositionDiffusionParameters(
        score_network_parameters=score_network_parameters,
        loss_parameters=MSELossParameters(),
        optimizer_parameters=OptimizerParameters(name="adam", learning_rate=1e-3),
        scheduler_parameters=None,
        noise_parameters=noise_parameters,
        diffusion_sampling_parameters=None,
    )

    model = PositionDiffusionLightningModel(diffusion_params)
    return model.sigma_normalized_score_network


@pytest.fixture()
def config_path(tmp_path, noise_parameters, sampling_parameters):
    config_path = str(tmp_path / "test_config.yaml")

    config = dict(
        noise=dataclasses.asdict(noise_parameters),
        sampling=dataclasses.asdict(sampling_parameters),
    )

    with open(config_path, "w") as fd:
        yaml.dump(config, fd)

    return config_path


@pytest.fixture()
def checkpoint_path(tmp_path):
    path_to_checkpoint = tmp_path / "fake_checkpoint.pt"
    with open(path_to_checkpoint, "w") as fd:
        fd.write("This is a dummy checkpoint file.")
    return path_to_checkpoint


@pytest.fixture()
def output_path(tmp_path):
    output = tmp_path / "output"
    return output


@pytest.fixture()
def args(config_path, checkpoint_path, output_path):
    """Input arguments for main."""
    input_args = [
        f"--config={config_path}",
        f"--checkpoint={checkpoint_path}",
        f"--output={output_path}",
        "--device=cpu",
    ]

    return input_args


def test_sample_diffusion(
    mocker,
    args,
    sigma_normalized_score_network,
    output_path,
    number_of_samples,
    number_of_atoms,
    spatial_dimension,
    record_samples,
):
    mocker.patch(
        "crystal_diffusion.sample_diffusion.get_sigma_normalized_score_network",
        return_value=sigma_normalized_score_network,
    )

    sample_diffusion.main(args)

    assert (output_path / "samples.pt").exists()
    samples = torch.load(output_path / "samples.pt")
    assert samples[RELATIVE_COORDINATES].shape == (
        number_of_samples,
        number_of_atoms,
        spatial_dimension,
    )

    assert (output_path / "trajectories.pt").exists() == record_samples
