From 06f75b216c09cb63e88deaabeacc9003dc7a60b5 Mon Sep 17 00:00:00 2001 From: EnricoTrizio Date: Mon, 15 Jan 2024 14:26:48 +0100 Subject: [PATCH] Moved model to device in checkpoint test --- mlcolvar/tests/test_cvs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlcolvar/tests/test_cvs.py b/mlcolvar/tests/test_cvs.py index a76c1a99..20d09c93 100644 --- a/mlcolvar/tests/test_cvs.py +++ b/mlcolvar/tests/test_cvs.py @@ -74,6 +74,7 @@ def dataset(): def test_resume_from_checkpoint(cv_model, dataset): """CVs correctly resume from a checkpoint.""" datamodule = DictModule(dataset, lengths=[1.0,0.], batch_size=len(dataset)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Run a few steps of training in a temporary directory. with tempfile.TemporaryDirectory() as tmp_dir_path: @@ -94,7 +95,7 @@ def test_resume_from_checkpoint(cv_model, dataset): cv_model2 = cv_model.__class__.load_from_checkpoint(checkpoint_file_path) # Check that state is the same. - x = dataset['data'] - cv_model.eval() - cv_model2.eval() + x = dataset['data'].to(device) + cv_model.to(device).eval() + cv_model2.to(device).eval() assert torch.allclose(cv_model(x), cv_model2(x))