Skip to content

Commit

Permalink
test constrained sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
rousseab committed Dec 3, 2024
1 parent cfb6ae9 commit 5f858e6
Showing 1 changed file with 51 additions and 8 deletions.
59 changes: 51 additions & 8 deletions tests/test_sample_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
OptimizerParameters
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import \
MLPScoreNetworkParameters
from diffusion_for_multi_scale_molecular_dynamics.namespace import \
AXL_COMPOSITION
from diffusion_for_multi_scale_molecular_dynamics.namespace import (
AXL, AXL_COMPOSITION)
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
NoiseParameters

Expand Down Expand Up @@ -46,6 +46,19 @@ def cell_dimensions():
return [5.1, 6.2, 7.3]


@pytest.fixture()
def reference_composition(num_atom_types, number_of_atoms, spatial_dimension):
a = torch.randint(0, num_atom_types, (number_of_atoms,))
x = torch.rand(number_of_atoms, spatial_dimension)
lat = torch.rand(spatial_dimension, spatial_dimension)
return AXL(A=a, X=x, L=lat)


@pytest.fixture()
def constrained_atom_indices(number_of_atoms):
return torch.sort(torch.randperm(number_of_atoms)[:number_of_atoms // 2]).values


@pytest.fixture(params=[True, False])
def record_samples(request):
return request.param
Expand Down Expand Up @@ -77,7 +90,7 @@ def sampling_parameters(


@pytest.fixture()
def axl_network(number_of_atoms, noise_parameters, num_atom_types):
def axl_network(number_of_atoms, noise_parameters, num_atom_types, device):
score_network_parameters = MLPScoreNetworkParameters(
number_of_atoms=number_of_atoms,
num_atom_types=num_atom_types,
Expand All @@ -97,7 +110,7 @@ def axl_network(number_of_atoms, noise_parameters, num_atom_types):
diffusion_sampling_parameters=None,
)

model = AXLDiffusionLightningModel(diffusion_params)
model = AXLDiffusionLightningModel(diffusion_params).to(device)
return model.axl_network


Expand All @@ -116,6 +129,20 @@ def config_path(tmp_path, noise_parameters, sampling_parameters):
return config_path


@pytest.fixture(params=[True, False])
def apply_constraint(request):
return request.param


@pytest.fixture()
def constraint_data_pickle_path(tmp_path, reference_composition, constrained_atom_indices):
path_to_pickle = tmp_path / "pickle_path.pkl"
data = dict(reference_composition=reference_composition,
constrained_atom_indices=constrained_atom_indices)
torch.save(data, path_to_pickle)
return path_to_pickle


@pytest.fixture()
def checkpoint_path(tmp_path):
path_to_checkpoint = tmp_path / "fake_checkpoint.pt"
Expand All @@ -131,27 +158,34 @@ def output_path(tmp_path):


@pytest.fixture()
def args(config_path, checkpoint_path, output_path):
def args(config_path, checkpoint_path, output_path, constraint_data_pickle_path, apply_constraint, device):
"""Input arguments for main."""
input_args = [
f"--config={config_path}",
f"--checkpoint={checkpoint_path}",
f"--output={output_path}",
"--device=cpu",
f"--device={device}",
]

if apply_constraint:
input_args.append(f"--path_to_constraint_data_pickle={constraint_data_pickle_path}")

return input_args


def test_sample_diffusion(
mocker,
device,
args,
axl_network,
output_path,
number_of_samples,
number_of_atoms,
spatial_dimension,
record_samples,
apply_constraint,
reference_composition,
constrained_atom_indices
):
mocker.patch(
"diffusion_for_multi_scale_molecular_dynamics.sample_diffusion.get_axl_network",
Expand All @@ -162,14 +196,23 @@ def test_sample_diffusion(

assert (output_path / "samples.pt").exists()
samples = torch.load(output_path / "samples.pt")
assert samples[AXL_COMPOSITION].X.shape == (
compositions = samples[AXL_COMPOSITION]

assert compositions.X.shape == (
number_of_samples,
number_of_atoms,
spatial_dimension,
)
assert samples[AXL_COMPOSITION].A.shape == (
assert compositions.A.shape == (
number_of_samples,
number_of_atoms,
)

assert (output_path / "trajectories.pt").exists() == record_samples

if apply_constraint:
reference_x = reference_composition.X[constrained_atom_indices].to(device)
reference_a = reference_composition.A[constrained_atom_indices].to(device)

assert (compositions.X[:, constrained_atom_indices] == reference_x).all()
assert (compositions.A[:, constrained_atom_indices] == reference_a).all()

0 comments on commit 5f858e6

Please sign in to comment.