Skip to content

Commit

Permalink
isort error
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed May 30, 2024
1 parent 4993c42 commit 79661ba
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
25 changes: 15 additions & 10 deletions tests/data/diffusion/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@

from crystal_diffusion.data.diffusion.data_loader import (
LammpsForDiffusionDataModule, LammpsLoaderParameters)
from crystal_diffusion.namespace import (CARTESIAN_POSITIONS,
from crystal_diffusion.namespace import (CARTESIAN_FORCES, CARTESIAN_POSITIONS,
RELATIVE_COORDINATES)
from tests.conftest import TestDiffusionDataBase
from tests.fake_data_utils import Configuration, find_aligning_permutation


def convert_configurations_to_dataset(configurations: List[Configuration]) -> Dict[str, torch.Tensor]:
"""Convert the input configuration into a dict of torch tensors comparable to a pytorch dataset."""
# The expected dataset keys are {'natom', 'box', 'position', 'relative_positions', 'type', 'potential_energy'}
# The expected dataset keys are {'natom', 'box', 'cartesian_positions', 'relative_positions', 'type',
# 'cartesian_forces', 'potential_energy'}
data = defaultdict(list)
for configuration in configurations:
data['natom'].append(len(configuration.ids))
data['box'].append(configuration.cell_dimensions)
data[CARTESIAN_POSITIONS].append(configuration.positions)
data[CARTESIAN_FORCES].append(configuration.cartesian_forces)
data[CARTESIAN_POSITIONS].append(configuration.cartesian_positions)
data[RELATIVE_COORDINATES].append(configuration.relative_coordinates)
data['type'].append(configuration.types)
data['potential_energy'].append(configuration.potential_energy)


configuration_dataset = dict()
for key, array in data.items():
configuration_dataset[key] = torch.tensor(array)
Expand All @@ -37,16 +40,17 @@ def input_data_to_transform(self):
return {
'natom': [2], # batch size of 1
'box': [[1.0, 1.0, 1.0]],
'position': [[1., 2., 3, 4., 5, 6]], # for one batch, two atoms, 3D positions
'relative_positions': [[1., 2., 3, 4., 5, 6]],
CARTESIAN_POSITIONS: [[1., 2., 3, 4., 5, 6]], # for one batch, two atoms, 3D positions
CARTESIAN_FORCES: [[11., 12., 13, 14., 15, 16]], # for one batch, two atoms, 3D forces
RELATIVE_COORDINATES: [[1., 2., 3, 4., 5, 6]],
'type': [[1, 2]],
'potential_energy': [23.233],
}

def test_dataset_transform(self, input_data_to_transform):
result = LammpsForDiffusionDataModule.dataset_transform(input_data_to_transform)
# Check keys in result
assert set(result.keys()) == {'natom', CARTESIAN_POSITIONS, RELATIVE_COORDINATES,
assert set(result.keys()) == {'natom', CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES,
'box', 'type', 'potential_energy'}

# Check tensor types and shapes
Expand All @@ -68,8 +72,9 @@ def input_data_to_pad(self):
return {
'natom': 2, # batch size of 1
'box': [1.0, 1.0, 1.0],
'position': [1., 2., 3, 4., 5, 6], # for one batch, two atoms, 3D positions
'relative_positions': [1., 2., 3, 4., 5, 6],
CARTESIAN_POSITIONS: [1., 2., 3, 4., 5, 6], # for one batch, two atoms, 3D positions
CARTESIAN_FORCES: [11., 12., 13, 14., 15, 16],
RELATIVE_COORDINATES: [1., 2., 3, 4., 5, 6],
'type': [1, 2],
'potential_energy': 23.233,
}
Expand Down Expand Up @@ -124,8 +129,8 @@ def real_and_test_datasets(self, mode, data_module, all_train_configurations, al
return data_module_dataset, configuration_dataset

def test_dataset_feature_names(self, data_module):
expected_feature_names = {'natom', 'box', 'position', 'relative_positions', 'type', 'potential_energy',
CARTESIAN_POSITIONS, RELATIVE_COORDINATES}
expected_feature_names = {'natom', 'box', 'type', 'potential_energy', CARTESIAN_FORCES, CARTESIAN_POSITIONS,
RELATIVE_COORDINATES}
assert set(data_module.train_dataset.features.keys()) == expected_feature_names
assert set(data_module.valid_dataset.features.keys()) == expected_feature_names

Expand Down
3 changes: 2 additions & 1 deletion tests/data/diffusion/test_data_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from crystal_diffusion.data.diffusion.data_preprocess import \
LammpsProcessorForDiffusion
from crystal_diffusion.namespace import CARTESIAN_POSITIONS, CARTESIAN_FORCES, RELATIVE_COORDINATES
from crystal_diffusion.namespace import (CARTESIAN_FORCES, CARTESIAN_POSITIONS,
RELATIVE_COORDINATES)
from tests.conftest import TestDiffusionDataBase
from tests.fake_data_utils import generate_parquet_dataframe

Expand Down

0 comments on commit 79661ba

Please sign in to comment.