From c9a48b94bfeb38a200078601039294273418a9c8 Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Mon, 10 Jul 2023 01:17:19 +0200 Subject: [PATCH 1/4] add option to use only train loader in datamodule --- mlcolvar/data/datamodule.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mlcolvar/data/datamodule.py b/mlcolvar/data/datamodule.py index c37b7ce8..87856a2b 100644 --- a/mlcolvar/data/datamodule.py +++ b/mlcolvar/data/datamodule.py @@ -170,6 +170,10 @@ def train_dataloader(self): def val_dataloader(self): """Return validation dataloader.""" self._check_setup() + if len(self.lengths) < 2: + raise NotImplementedError( + "Validation dataset not available, you need to pass two lengths to datamodule." + ) if self.valid_loader is None: self.valid_loader = DictLoader( self._dataset_split[1], @@ -182,7 +186,7 @@ def test_dataloader(self): """Return test dataloader.""" self._check_setup() if len(self.lengths) < 3: - raise ValueError( + raise NotImplementedError( "Test dataset not available, you need to pass three lengths to datamodule." 
) if self.test_loader is None: @@ -202,7 +206,8 @@ def teardown(self, stage: str): def __repr__(self) -> str: string = f"DictModule(dataset -> {self.dataset.__repr__()}" string += f",\n\t\t train_loader -> DictLoader(length={self.lengths[0]}, batch_size={self.batch_size[0]}, shuffle={self.shuffle[0]})" - string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})" + if len(self.lengths) >= 2: + string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})" if len(self.lengths) >= 3: string += f",\n\t\t\ttest_loader =DictLoader(length={self.lengths[2]}, batch_size={self.batch_size[2]}, shuffle={self.shuffle[2]})" string += f")" From a74d3a2c53db179ab3a53eaf452a6e5549b87a9a Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Mon, 10 Jul 2023 01:17:36 +0200 Subject: [PATCH 2/4] add info on how to disable validation --- .../tutorials/intro_3_loss_optim.ipynb | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/notebooks/tutorials/intro_3_loss_optim.ipynb b/docs/notebooks/tutorials/intro_3_loss_optim.ipynb index 5a9c7656..2454c3af 100644 --- a/docs/notebooks/tutorials/intro_3_loss_optim.ipynb +++ b/docs/notebooks/tutorials/intro_3_loss_optim.ipynb @@ -450,6 +450,35 @@ "\n", "# After the training is over the metrics can be accessed with the dictionary .metrics" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Disable validation loop" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to disable the validation loop you need to:\n", + "1. tell the `DictModule` not to split the dataset, with `lengths=[1.0]`\n", + "2. 
pass the two options below to the `lightning.Trainer`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from mlcolvar.data import DictModule\n", + "\n", + "#datamodule = DictModule(dataset,lengths=[1.0])\n", + "\n", + "trainer = lightning.Trainer(limit_val_batches=0, num_sanity_val_steps=0)" + ] } ], "metadata": { From 19450ddd8ae4e683b0ebf32614f4eea324792a09 Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Mon, 10 Jul 2023 01:32:12 +0200 Subject: [PATCH 3/4] fix type of raised error in test datamodule --- mlcolvar/tests/test_utils_data_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlcolvar/tests/test_utils_data_datamodule.py b/mlcolvar/tests/test_utils_data_datamodule.py index 53b86c4c..00d7e030 100644 --- a/mlcolvar/tests/test_utils_data_datamodule.py +++ b/mlcolvar/tests/test_utils_data_datamodule.py @@ -25,7 +25,7 @@ # ============================================================================= -@pytest.mark.parametrize("lengths", [[0.8, 0.2], [0.7, 0.2, 0.1]]) +@pytest.mark.parametrize("lengths", [[1.0], [0.8, 0.2], [0.7, 0.2, 0.1]]) @pytest.mark.parametrize("fields", [[], ["labels", "weights"]]) @pytest.mark.parametrize("random_split", [True, False]) def test_dictionary_data_module_split(lengths, fields, random_split): @@ -75,7 +75,7 @@ def test_dictionary_data_module_split(lengths, fields, random_split): # An error is raised if the length of the test set has not been specified. 
if len(lengths) < 3: - with pytest.raises(ValueError, match="you need to pass three lengths"): + with pytest.raises(NotImplementedError, match="you need to pass three lengths"): datamodule.test_dataloader() From 2a43254d617d3f0c3c1285d733e2cc39a8a7f3a3 Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Wed, 25 Oct 2023 01:31:20 +0200 Subject: [PATCH 4/4] [utils] change metrics callback from on_validation_epoch_end to on_train_epoch_end (#88) --- mlcolvar/utils/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlcolvar/utils/trainer.py b/mlcolvar/utils/trainer.py index 72db60e0..b9242940 100644 --- a/mlcolvar/utils/trainer.py +++ b/mlcolvar/utils/trainer.py @@ -26,7 +26,7 @@ def __init__(self): super().__init__() self.metrics = {"epoch": []} - def on_validation_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, trainer, pl_module): metrics = trainer.callback_metrics if not trainer.sanity_checking: self.metrics["epoch"].append(trainer.current_epoch)