Merge pull request #88 from luigibonati/fix_disable_validation
allow to disable validation step
luigibonati authored Oct 24, 2023
2 parents 51d1f66 + 2a43254 commit f948417
Showing 4 changed files with 39 additions and 5 deletions.
29 changes: 29 additions & 0 deletions docs/notebooks/tutorials/intro_3_loss_optim.ipynb
@@ -450,6 +450,35 @@
"\n",
"# After the training is over the metrics can be accessed with the dictionary .metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Disable validation loop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to disable the validation loop you need to:\n",
"1. tell the `DictModule` not to split the dataset, with `lengths=[1.0]`\n",
"2. pass the two options below to the `lightning.trainer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from mlcolvar.data import DictModule\n",
"\n",
"#datamodule = DictModule(dataset,lengths=[1.0])\n",
"\n",
"trainer = lightning.Trainer(limit_val_batches=0, num_sanity_val_steps=0)"
]
}
],
"metadata": {
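Taken together, the recipe added by these cells reads as the short sketch below. This is a minimal sketch, assuming `dataset` and `model` were created in earlier cells of the tutorial; `limit_val_batches=0` and `num_sanity_val_steps=0` are standard Lightning Trainer flags.

import lightning
from mlcolvar.data import DictModule

# 1. A single relative length keeps the whole dataset for training,
#    so no validation split is created.
datamodule = DictModule(dataset, lengths=[1.0])  # `dataset` from earlier cells

# 2. Skip the validation loop and the pre-fit sanity check.
trainer = lightning.Trainer(limit_val_batches=0, num_sanity_val_steps=0)

trainer.fit(model, datamodule)  # `model` from earlier cells
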
9 changes: 7 additions & 2 deletions mlcolvar/data/datamodule.py
@@ -170,6 +170,10 @@ def train_dataloader(self):
     def val_dataloader(self):
         """Return validation dataloader."""
         self._check_setup()
+        if len(self.lengths) < 2:
+            raise NotImplementedError(
+                "Validation dataset not available, you need to pass two lengths to datamodule."
+            )
         if self.valid_loader is None:
             self.valid_loader = DictLoader(
                 self._dataset_split[1],
@@ -182,7 +186,7 @@ def test_dataloader(self):
         """Return test dataloader."""
         self._check_setup()
         if len(self.lengths) < 3:
-            raise ValueError(
+            raise NotImplementedError(
                 "Test dataset not available, you need to pass three lengths to datamodule."
             )
         if self.test_loader is None:
@@ -202,7 +206,8 @@ def teardown(self, stage: str):
     def __repr__(self) -> str:
         string = f"DictModule(dataset -> {self.dataset.__repr__()}"
         string += f",\n\t\t train_loader -> DictLoader(length={self.lengths[0]}, batch_size={self.batch_size[0]}, shuffle={self.shuffle[0]})"
-        string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})"
+        if len(self.lengths) >= 2:
+            string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})"
         if len(self.lengths) >= 3:
             string += f",\n\t\t\ttest_loader =DictLoader(length={self.lengths[2]}, batch_size={self.batch_size[2]}, shuffle={self.shuffle[2]})"
         string += f")"
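The guard added to `val_dataloader` fails fast with an explicit message instead of an opaque indexing error from `self._dataset_split[1]`. A minimal sketch of the new behavior, assuming `DictDataset` accepts a plain dictionary of tensors (the toy data is illustrative only):

import torch
from mlcolvar.data import DictDataset, DictModule

dataset = DictDataset({"data": torch.randn(100, 2)})  # toy data, for illustration

# Single split: the whole dataset is used for training.
datamodule = DictModule(dataset, lengths=[1.0])
datamodule.setup("fit")

loader = datamodule.train_dataloader()  # works as usual
datamodule.val_dataloader()  # raises NotImplementedError: "Validation dataset not available, ..."
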
4 changes: 2 additions & 2 deletions mlcolvar/tests/test_utils_data_datamodule.py
@@ -25,7 +25,7 @@
 # =============================================================================


-@pytest.mark.parametrize("lengths", [[0.8, 0.2], [0.7, 0.2, 0.1]])
+@pytest.mark.parametrize("lengths", [[1.0], [0.8, 0.2], [0.7, 0.2, 0.1]])
 @pytest.mark.parametrize("fields", [[], ["labels", "weights"]])
 @pytest.mark.parametrize("random_split", [True, False])
 def test_dictionary_data_module_split(lengths, fields, random_split):
@@ -75,7 +75,7 @@ def test_dictionary_data_module_split(lengths, fields, random_split):

     # An error is raised if the length of the test set has not been specified.
     if len(lengths) < 3:
-        with pytest.raises(ValueError, match="you need to pass three lengths"):
+        with pytest.raises(NotImplementedError, match="you need to pass three lengths"):
             datamodule.test_dataloader()

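The new `[1.0]` case exercises the single-split path through the datamodule. A companion test for the validation guard introduced in `val_dataloader` could look like the sketch below; it is not part of this commit, and the `DictDataset` construction is illustrative:

import pytest
import torch
from mlcolvar.data import DictDataset, DictModule

def test_val_dataloader_requires_two_lengths():
    # Mirrors the guard added in DictModule.val_dataloader.
    dataset = DictDataset({"data": torch.randn(10, 2)})
    datamodule = DictModule(dataset, lengths=[1.0])
    datamodule.setup("fit")
    with pytest.raises(NotImplementedError, match="you need to pass two lengths"):
        datamodule.val_dataloader()
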
2 changes: 1 addition & 1 deletion mlcolvar/utils/trainer.py
@@ -26,7 +26,7 @@ def __init__(self):
         super().__init__()
         self.metrics = {"epoch": []}

-    def on_validation_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module):
         metrics = trainer.callback_metrics
         if not trainer.sanity_checking:
             self.metrics["epoch"].append(trainer.current_epoch)
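This rename is what keeps metric collection alive when validation is disabled: `on_validation_epoch_end` never fires if the validation loop is skipped, while `on_train_epoch_end` runs in every fit. A usage sketch, assuming this class is the `MetricsCallback` exported by `mlcolvar.utils.trainer` (the class name sits outside this hunk):

import lightning
from mlcolvar.utils.trainer import MetricsCallback

metrics = MetricsCallback()
trainer = lightning.Trainer(
    callbacks=[metrics],
    limit_val_batches=0,     # validation disabled...
    num_sanity_val_steps=0,  # ...yet metrics still accumulate every training epoch
    max_epochs=5,
)
# After trainer.fit(model, datamodule), metrics.metrics["epoch"]
# holds one entry per completed training epoch.
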

