diff --git a/docs/notebooks/tutorials/intro_3_loss_optim.ipynb b/docs/notebooks/tutorials/intro_3_loss_optim.ipynb
index 5a9c7656..2454c3af 100644
--- a/docs/notebooks/tutorials/intro_3_loss_optim.ipynb
+++ b/docs/notebooks/tutorials/intro_3_loss_optim.ipynb
@@ -450,6 +450,35 @@
     "\n",
     "# After the training is over the metrics can be accessed with the dictionary .metrics"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Disable validation loop"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To disable the validation loop you need to:\n",
+    "1. tell the `DictModule` not to split the dataset, with `lengths=[1.0]`\n",
+    "2. pass the two options below to the `lightning.Trainer`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from mlcolvar.data import DictModule\n",
+    "\n",
+    "# datamodule = DictModule(dataset, lengths=[1.0])\n",
+    "\n",
+    "trainer = lightning.Trainer(limit_val_batches=0, num_sanity_val_steps=0)"
+   ]
   }
  ],
  "metadata": {
diff --git a/mlcolvar/data/datamodule.py b/mlcolvar/data/datamodule.py
index c37b7ce8..87856a2b 100644
--- a/mlcolvar/data/datamodule.py
+++ b/mlcolvar/data/datamodule.py
@@ -170,6 +170,10 @@ def train_dataloader(self):
     def val_dataloader(self):
         """Return validation dataloader."""
         self._check_setup()
+        if len(self.lengths) < 2:
+            raise NotImplementedError(
+                "Validation dataset not available, you need to pass two lengths to datamodule."
+            )
         if self.valid_loader is None:
             self.valid_loader = DictLoader(
                 self._dataset_split[1],
@@ -182,7 +186,7 @@ def test_dataloader(self):
         """Return test dataloader."""
         self._check_setup()
         if len(self.lengths) < 3:
-            raise ValueError(
+            raise NotImplementedError(
                 "Test dataset not available, you need to pass three lengths to datamodule."
             )
         if self.test_loader is None:
@@ -202,7 +206,8 @@ def teardown(self, stage: str):
     def __repr__(self) -> str:
         string = f"DictModule(dataset -> {self.dataset.__repr__()}"
         string += f",\n\t\t train_loader -> DictLoader(length={self.lengths[0]}, batch_size={self.batch_size[0]}, shuffle={self.shuffle[0]})"
-        string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})"
+        if len(self.lengths) >= 2:
+            string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})"
         if len(self.lengths) >= 3:
             string += f",\n\t\t\ttest_loader =DictLoader(length={self.lengths[2]}, batch_size={self.batch_size[2]}, shuffle={self.shuffle[2]})"
         string += f")"
diff --git a/mlcolvar/tests/test_utils_data_datamodule.py b/mlcolvar/tests/test_utils_data_datamodule.py
index 53b86c4c..00d7e030 100644
--- a/mlcolvar/tests/test_utils_data_datamodule.py
+++ b/mlcolvar/tests/test_utils_data_datamodule.py
@@ -25,7 +25,7 @@
 # =============================================================================
 
 
-@pytest.mark.parametrize("lengths", [[0.8, 0.2], [0.7, 0.2, 0.1]])
+@pytest.mark.parametrize("lengths", [[1.0], [0.8, 0.2], [0.7, 0.2, 0.1]])
 @pytest.mark.parametrize("fields", [[], ["labels", "weights"]])
 @pytest.mark.parametrize("random_split", [True, False])
 def test_dictionary_data_module_split(lengths, fields, random_split):
@@ -75,7 +75,7 @@ def test_dictionary_data_module_split(lengths, fields, random_split):
     # An error is raised if the length of the test set has not been specified.
     if len(lengths) < 3:
-        with pytest.raises(ValueError, match="you need to pass three lengths"):
+        with pytest.raises(NotImplementedError, match="you need to pass three lengths"):
             datamodule.test_dataloader()

diff --git a/mlcolvar/utils/trainer.py b/mlcolvar/utils/trainer.py
index 72db60e0..b9242940 100644
--- a/mlcolvar/utils/trainer.py
+++ b/mlcolvar/utils/trainer.py
@@ -26,7 +26,7 @@ def __init__(self):
         super().__init__()
         self.metrics = {"epoch": []}
 
-    def on_validation_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module):
         metrics = trainer.callback_metrics
         if not trainer.sanity_checking:
             self.metrics["epoch"].append(trainer.current_epoch)
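Taken together, the changes above let mlcolvar train without a validation split. A minimal sketch of the resulting workflow follows, assuming `dataset` and `model` are built as in the earlier tutorial cells (both names are placeholders) and that `MetricsCallback` is the name of the callback class patched in mlcolvar/utils/trainer.py; the class name is not visible in this diff.

import lightning
from mlcolvar.data import DictModule
from mlcolvar.utils.trainer import MetricsCallback  # assumed class name

# 1. keep the whole dataset for training -- no validation split
datamodule = DictModule(dataset, lengths=[1.0])

# 2. disable the validation loop and the pre-fit sanity check
metrics = MetricsCallback()
trainer = lightning.Trainer(
    callbacks=[metrics],
    limit_val_batches=0,
    num_sanity_val_steps=0,
)
trainer.fit(model, datamodule)

# Because the callback now hooks on_train_epoch_end instead of
# on_validation_epoch_end, per-epoch metrics are still recorded even
# though val_dataloader() is never called.
print(metrics.metrics["epoch"])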