Skip to content

Commit

Permalink
Merge pull request #117 from luigibonati/fix_device_checkpoint_test
Browse files Browse the repository at this point in the history
Fix device issue checkpoint test
  • Loading branch information
EnricoTrizio authored Jan 16, 2024
2 parents 973ef81 + 06f75b2 commit aa90b43
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions mlcolvar/tests/test_cvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def dataset():
def test_resume_from_checkpoint(cv_model, dataset):
"""CVs correctly resume from a checkpoint."""
datamodule = DictModule(dataset, lengths=[1.0,0.], batch_size=len(dataset))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run a few steps of training in a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir_path:
Expand All @@ -94,9 +95,9 @@ def test_resume_from_checkpoint(cv_model, dataset):
cv_model2 = cv_model.__class__.load_from_checkpoint(checkpoint_file_path)

# Check that state is the same.
x = dataset['data']
cv_model.eval()
cv_model2.eval()
x = dataset['data'].to(device)
cv_model.to(device).eval()
cv_model2.to(device).eval()
assert torch.allclose(cv_model(x), cv_model2(x))

def test_lr_scheduler():
Expand Down

0 comments on commit aa90b43

Please sign in to comment.