diff --git a/mlcolvar/tests/test_cvs.py b/mlcolvar/tests/test_cvs.py index c7bc0c30..40d1425a 100644 --- a/mlcolvar/tests/test_cvs.py +++ b/mlcolvar/tests/test_cvs.py @@ -74,6 +74,7 @@ def dataset(): def test_resume_from_checkpoint(cv_model, dataset): """CVs correctly resume from a checkpoint.""" datamodule = DictModule(dataset, lengths=[1.0,0.], batch_size=len(dataset)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Run a few steps of training in a temporary directory. with tempfile.TemporaryDirectory() as tmp_dir_path: @@ -94,9 +95,9 @@ def test_resume_from_checkpoint(cv_model, dataset): cv_model2 = cv_model.__class__.load_from_checkpoint(checkpoint_file_path) # Check that state is the same. - x = dataset['data'] - cv_model.eval() - cv_model2.eval() + x = dataset['data'].to(device) + cv_model.to(device).eval() + cv_model2.to(device).eval() assert torch.allclose(cv_model(x), cv_model2(x)) def test_lr_scheduler():