diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index fb9a6679ba..e3f316b9e7 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -104,7 +104,20 @@ def is_better(a, b):
         "checkpoint_last{}.pt".format(suffix)
     ] = not cfg.no_last_checkpoints
 
-    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
+    extra_state = {
+        "train_iterator": epoch_itr.state_dict(),
+        "val_loss": val_loss,
+    }
+
+    # Going forward, different tasks could expose an API like this to dump
+    # all of their checkpoint-worthy attributes in a dictionary, which is
+    # then merged with the parent dictionary to create "extra_state". This
+    # allows for an extensible yet simple design for checkpointing
+    # task-level attributes.
+    if hasattr(trainer.task, "get_checkpoint_dict"):
+        extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
+        logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
+
     if hasattr(save_checkpoint, "best"):
         extra_state.update({"best": save_checkpoint.best})
@@ -275,6 +288,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
             epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
         )
         epoch_itr.load_state_dict(itr_state)
+
+        # Restore any task-level state stored in the checkpoint
+        task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
+        if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
+            trainer.task.set_checkpoint_dict(task_cp_dict)
     else:
         epoch_itr = trainer.get_train_iterator(
             epoch=1, load_dataset=True, **passthrough_args
diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py
index 0bc85562c7..f8cd943cfa 100644
--- a/tests/test_checkpoint_utils.py
+++ b/tests/test_checkpoint_utils.py
@@ -11,8 +11,6 @@
 from io import StringIO
 from unittest.mock import patch
 
-from omegaconf import OmegaConf
-
 from fairseq import checkpoint_utils
 from tests.utils import (
     create_dummy_data,
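
Note that the hasattr checks above make both hooks optional and duck-typed: save_checkpoint merges whatever get_checkpoint_dict() returns into extra_state, and load_checkpoint looks the task's state back up under trainer.task.__class__.__name__ before handing that sub-dictionary to set_checkpoint_dict(). Tasks are therefore expected to key their state by their own class name. A minimal sketch of a participating task follows; MyQuantizationTask and its observer_stats attribute are hypothetical stand-ins, not part of this diff:

from fairseq.tasks import FairseqTask


class MyQuantizationTask(FairseqTask):
    """Hypothetical task that persists its quantization observer stats."""

    def __init__(self, cfg):
        super().__init__(cfg)
        # Assumed task-level state worth persisting across restarts.
        self.observer_stats = {}

    def get_checkpoint_dict(self):
        # save_checkpoint merges this dict into extra_state, and
        # load_checkpoint looks it up by trainer.task.__class__.__name__,
        # so the state is keyed by the class name.
        return {self.__class__.__name__: {"observer_stats": self.observer_stats}}

    def set_checkpoint_dict(self, state_dict):
        # Receives only this task's sub-dictionary, not the full extra_state.
        self.observer_stats = state_dict.get("observer_stats", {})
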
diff --git a/tests/test_checkpoint_utils_for_task_level_attributes.py b/tests/test_checkpoint_utils_for_task_level_attributes.py
new file mode 100644
index 0000000000..53ab401f03
--- /dev/null
+++ b/tests/test_checkpoint_utils_for_task_level_attributes.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env fbpython
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import contextlib
+import logging
+import unittest
+from io import StringIO
+from unittest.mock import MagicMock, patch
+
+import torch
+from fairseq import checkpoint_utils, data
+from omegaconf import OmegaConf
+
+
+def mock_trainer(epoch, num_updates, iterations_in_epoch):
+    trainer = MagicMock()
+    trainer.load_checkpoint.return_value = {
+        "train_iterator": {
+            "epoch": epoch,
+            "iterations_in_epoch": iterations_in_epoch,
+            "shuffle": False,
+        },
+        "FakeTask": checkpoint_dict()["FakeTask"],
+    }
+    trainer.get_num_updates.return_value = num_updates
+    trainer.task.__class__.__name__ = "FakeTask"
+    trainer.task.get_checkpoint_dict.return_value = checkpoint_dict()
+    trainer.task.set_checkpoint_dict = MagicMock()
+
+    return trainer
+
+
+def checkpoint_dict():
+    return {
+        "FakeTask": {
+            "observer_stats": {
+                (
+                    4,
+                    16,
+                    "MovingAveragePerChannelMinMax",
+                    "MovingAveragePerChannelMinMax",
+                ): {"mod1": 1, "mod2": 2, "mod3": 3}
+            }
+        }
+    }
+
+
+def mock_dict():
+    d = MagicMock()
+    d.pad.return_value = 1
+    d.eos.return_value = 2
+    d.unk.return_value = 3
+    return d
+
+
+def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
+    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
+    tokens_ds = data.TokenBlockDataset(
+        tokens,
+        sizes=[tokens.size(-1)],
+        block_size=1,
+        pad=0,
+        eos=1,
+        include_targets=False,
+    )
+    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
+    dataset = data.LanguagePairDataset(
+        tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
+    )
+    epoch_itr = data.EpochBatchIterator(
+        dataset=dataset,
+        collate_fn=dataset.collater,
+        batch_sampler=[[i] for i in range(epoch_size)],
+    )
+    return trainer, epoch_itr
+
+
+def get_mock_cfg(finetune_from_model):
+    cfg_mock = OmegaConf.create(
+        {
+            "checkpoint": {
+                "save_dir": None,
+                "optimizer_overrides": "{}",
+                "reset_dataloader": False,
+                "reset_meters": False,
+                "reset_optimizer": False,
+                "reset_lr_scheduler": False,
+                "finetune_from_model": finetune_from_model,
+                "model_parallel_size": 1,
+                "restore_file": "checkpoint_last.pt",
+                "no_save": False,
+                "save_interval_updates": 0,
+                "no_last_checkpoints": False,
+                "keep_interval_updates": 0,
+                "keep_last_epochs": 0,
+                "keep_best_checkpoints": 0,
+            },
+            "common": {
+                "model_parallel_size": 1,
+            },
+        }
+    )
+    return cfg_mock
+
+
+class TestCheckpointsForTaskLevelAttributes(unittest.TestCase):
+    def setUp(self) -> None:
+        self.cfg_mock = get_mock_cfg(None)
+        self.patches = {
+            "os.makedirs": MagicMock(),
+            "os.path.join": MagicMock(),
+            "os.path.isfile": MagicMock(return_value=True),
+            "os.path.isabs": MagicMock(return_value=False),
+            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
+        }
+        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
+        [p.start() for p in self.applied_patches]
+        logging.disable(logging.CRITICAL)
+
+        self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
+        self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr)
+        self.epoch_itr.next_epoch_itr(shuffle=False)
+
+        checkpoint_utils.save_checkpoint(
+            self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None
+        )
+
+    def tearDown(self):
+        patch.stopall()
+        logging.disable(logging.NOTSET)
+
+    def test_verify_checkpoint(self) -> None:
+        cp_dict = self.trainer.task.get_checkpoint_dict()
+        self.assertTrue(len(cp_dict) == 1)
+        self.assertTrue("FakeTask" in cp_dict)
+        self.assertTrue("observer_stats" in cp_dict["FakeTask"])
+        self.assertTrue(len(cp_dict["FakeTask"]["observer_stats"]) == 1)
+        self.assertTrue(
+            (
+                4,
+                16,
+                "MovingAveragePerChannelMinMax",
+                "MovingAveragePerChannelMinMax",
+            )
+            in cp_dict["FakeTask"]["observer_stats"]
+        )
+        self.assertTrue(
+            cp_dict["FakeTask"]["observer_stats"][
+                (
+                    4,
+                    16,
+                    "MovingAveragePerChannelMinMax",
+                    "MovingAveragePerChannelMinMax",
+                )
+            ]
+            == {"mod1": 1, "mod2": 2, "mod3": 3}
+        )
+
+    def test_load_checkpoint(self) -> None:
+        with contextlib.redirect_stdout(StringIO()):
+            # Now load the checkpoint to ensure the task-level state is
+            # restored as expected
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, self.trainer
+            )
+
+            self.trainer.task.set_checkpoint_dict.assert_called_once_with(
+                checkpoint_dict()["FakeTask"]
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
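
For reference, given a task like the sketch above, the extra_state that save_checkpoint persists would look roughly as follows; all values are illustrative, with the observer_stats entries borrowed from the test fixtures:

# Approximate shape of extra_state in the saved checkpoint (illustrative).
extra_state = {
    "train_iterator": {"epoch": 2, "iterations_in_epoch": 50, "shuffle": False},
    "val_loss": 1.25,  # whatever val_loss was passed to save_checkpoint
    "best": 1.25,  # only present once save_checkpoint.best has been set
    "MyQuantizationTask": {  # merged in from task.get_checkpoint_dict()
        "observer_stats": {"mod1": 1, "mod2": 2, "mod3": 3},
    },
}

On load, only the "MyQuantizationTask" entry is passed back to set_checkpoint_dict, which mirrors the assert_called_once_with(checkpoint_dict()["FakeTask"]) check in test_load_checkpoint.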