From e29f53bfea67fd9e81c3da374daac4b472ba6bda Mon Sep 17 00:00:00 2001
From: Piyush Kansal
Date: Fri, 15 Sep 2023 12:01:49 -0700
Subject: [PATCH 1/2] initial revision (#5328)

---
 fairseq/checkpoint_utils.py                   |  20 +-
 tests/test_checkpoint_utils.py                |   2 -
 ...ckpoint_utils_for_task_level_attributes.py | 172 ++++++++++++++++++
 3 files changed, 191 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_checkpoint_utils_for_task_level_attributes.py

diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index fb9a6679ba..4eff7807e2 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -104,7 +104,20 @@ def is_better(a, b):
         "checkpoint_last{}.pt".format(suffix)
     ] = not cfg.no_last_checkpoints
 
-    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
+    extra_state = {
+        "train_iterator": epoch_itr.state_dict(),
+        "val_loss": val_loss,
+    }
+
+    # Going forward, different tasks could expose an API like this to dump all
+    # the checkpoint-worthy attributes in a dictionary, which will then be
+    # merged with the parent dictionary to create the "extra_state". This
+    # allows for an extensible yet simple design to checkpoint task-level
+    # attributes
+    if hasattr(trainer.task, "get_checkpoint_dict"):
+        extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
+        logger.info(f"{trainer.task.__class__} checkpoint worthy attributes are ready to be persisted with the checkpoint")
+
     if hasattr(save_checkpoint, "best"):
         extra_state.update({"best": save_checkpoint.best})
 
@@ -275,6 +288,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
             epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
         )
         epoch_itr.load_state_dict(itr_state)
+
+        # Preload the observer stats for Supernet
+        supernet_cp_dict = extra_state.get("supernet", {})
+        if supernet_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
+            trainer.task.set_checkpoint_dict(supernet_cp_dict)
     else:
         epoch_itr = trainer.get_train_iterator(
             epoch=1, load_dataset=True, **passthrough_args
diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py
index 0bc85562c7..f8cd943cfa 100644
--- a/tests/test_checkpoint_utils.py
+++ b/tests/test_checkpoint_utils.py
@@ -11,8 +11,6 @@
 from io import StringIO
 from unittest.mock import patch
 
-from omegaconf import OmegaConf
-
 from fairseq import checkpoint_utils
 from tests.utils import (
     create_dummy_data,
diff --git a/tests/test_checkpoint_utils_for_task_level_attributes.py b/tests/test_checkpoint_utils_for_task_level_attributes.py
new file mode 100644
index 0000000000..ed7ba59110
--- /dev/null
+++ b/tests/test_checkpoint_utils_for_task_level_attributes.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env fbpython
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+ +import contextlib +import logging +import unittest +from io import StringIO +from unittest.mock import MagicMock, patch + +import torch +from fairseq import checkpoint_utils, data +from omegaconf import OmegaConf + + +def mock_trainer(epoch, num_updates, iterations_in_epoch): + trainer = MagicMock() + trainer.load_checkpoint.return_value = { + "train_iterator": { + "epoch": epoch, + "iterations_in_epoch": iterations_in_epoch, + "shuffle": False, + }, + "supernet": checkpoint_dict()["supernet"], + } + trainer.get_num_updates.return_value = num_updates + trainer.task.get_checkpoint_dict.return_value = checkpoint_dict() + trainer.task.set_checkpoint_dict = MagicMock() + + return trainer + + +def checkpoint_dict(): + return { + "supernet": { + "observer_stats": { + ( + 4, + 16, + "MovingAveragePerChannelMinMax", + "MovingAveragePerChannelMinMax", + ): {"mod1": 1, "mod2": 2, "mod3": 3} + } + } + } + + +def mock_dict(): + d = MagicMock() + d.pad.return_value = 1 + d.eos.return_value = 2 + d.unk.return_value = 3 + return d + + +def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): + tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1) + tokens_ds = data.TokenBlockDataset( + tokens, + sizes=[tokens.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) + dataset = data.LanguagePairDataset( + tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False + ) + epoch_itr = data.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=[[i] for i in range(epoch_size)], + ) + return trainer, epoch_itr + + +def get_mock_cfg(finetune_from_model): + cfg_mock = OmegaConf.create( + { + "checkpoint": { + "save_dir": None, + "optimizer_overrides": "{}", + "reset_dataloader": False, + "reset_meters": False, + "reset_optimizer": False, + "reset_lr_scheduler": False, + "finetune_from_model": finetune_from_model, + "model_parallel_size": 1, + "restore_file": "checkpoint_last.pt", + "no_save": False, + "save_interval_updates": 0, + "no_last_checkpoints": False, + "keep_interval_updates": 0, + "keep_last_epochs": 0, + "keep_best_checkpoints": 0, + }, + "common": { + "model_parallel_size": 1, + }, + } + ) + return cfg_mock + + +class TestCheckpointsForTaskLevelAttributes(unittest.TestCase): + def setUp(self) -> None: + self.cfg_mock = get_mock_cfg(None) + self.patches = { + "os.makedirs": MagicMock(), + "os.path.join": MagicMock(), + "os.path.isfile": MagicMock(return_value=True), + "os.path.isabs": MagicMock(return_value=False), + "fairseq.file_io.PathManager.exists": MagicMock(return_value=False), + } + self.applied_patches = [patch(p, d) for p, d in self.patches.items()] + [p.start() for p in self.applied_patches] + logging.disable(logging.CRITICAL) + + self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) + self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr) + self.epoch_itr.next_epoch_itr(shuffle=False) + + checkpoint_utils.save_checkpoint( + self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None + ) + + def tearDown(self): + patch.stopall() + logging.disable(logging.NOTSET) + + def test_verify_checkpoint(self) -> None: + cp_dict = self.trainer.task.get_checkpoint_dict() + self.assertTrue(len(cp_dict) == 1) + self.assertTrue("supernet" in cp_dict) + self.assertTrue("observer_stats" in cp_dict["supernet"]) + self.assertTrue(len(cp_dict["supernet"]["observer_stats"]) == 1) + self.assertTrue( + ( + 4, + 16, + 
"MovingAveragePerChannelMinMax", + "MovingAveragePerChannelMinMax", + ) + in cp_dict["supernet"]["observer_stats"] + ) + self.assertTrue( + cp_dict["supernet"]["observer_stats"][ + ( + 4, + 16, + "MovingAveragePerChannelMinMax", + "MovingAveragePerChannelMinMax", + ) + ] + == {"mod1": 1, "mod2": 2, "mod3": 3} + ) + + def test_load_checkpoint(self) -> None: + with contextlib.redirect_stdout(StringIO()): + # Now, load checkpoint to ensure the respective logic works as expected + _, epoch_itr = checkpoint_utils.load_checkpoint( + self.cfg_mock.checkpoint, self.trainer + ) + + self.trainer.task.set_checkpoint_dict.assert_called_once_with( + checkpoint_dict()["supernet"] + ) + + +if __name__ == "__main__": + unittest.main() + From 7409af7f9a7b6ddac4cbfe7cafccc715b3c1b21e Mon Sep 17 00:00:00 2001 From: Piyush Kansal Date: Fri, 15 Sep 2023 16:15:19 -0700 Subject: [PATCH 2/2] Keep task level checkpoint key name generic (#5330) --- fairseq/checkpoint_utils.py | 10 +++++----- ...eckpoint_utils_for_task_level_attributes.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 4eff7807e2..e3f316b9e7 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -116,7 +116,7 @@ def is_better(a, b): # attributes if hasattr(trainer.task, "get_checkpoint_dict"): extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()} - logger.info(f"{trainer.task.__class__} checkpoint worthy attributes are ready to be persisted with the checkpoint") + logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint") if hasattr(save_checkpoint, "best"): extra_state.update({"best": save_checkpoint.best}) @@ -289,10 +289,10 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): ) epoch_itr.load_state_dict(itr_state) - # Preload the observer stats for Supernet - supernet_cp_dict = extra_state.get("supernet", {}) - if supernet_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"): - trainer.task.set_checkpoint_dict(supernet_cp_dict) + # Preload the checkpoint for the task + task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {}) + if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"): + trainer.task.set_checkpoint_dict(task_cp_dict) else: epoch_itr = trainer.get_train_iterator( epoch=1, load_dataset=True, **passthrough_args diff --git a/tests/test_checkpoint_utils_for_task_level_attributes.py b/tests/test_checkpoint_utils_for_task_level_attributes.py index ed7ba59110..53ab401f03 100644 --- a/tests/test_checkpoint_utils_for_task_level_attributes.py +++ b/tests/test_checkpoint_utils_for_task_level_attributes.py @@ -20,9 +20,10 @@ def mock_trainer(epoch, num_updates, iterations_in_epoch): "iterations_in_epoch": iterations_in_epoch, "shuffle": False, }, - "supernet": checkpoint_dict()["supernet"], + "FakeTask": checkpoint_dict()["FakeTask"], } trainer.get_num_updates.return_value = num_updates + trainer.task.__class__.__name__ = "FakeTask" trainer.task.get_checkpoint_dict.return_value = checkpoint_dict() trainer.task.set_checkpoint_dict = MagicMock() @@ -31,7 +32,7 @@ def mock_trainer(epoch, num_updates, iterations_in_epoch): def checkpoint_dict(): return { - "supernet": { + "FakeTask": { "observer_stats": { ( 4, @@ -131,9 +132,9 @@ def tearDown(self): def test_verify_checkpoint(self) -> None: cp_dict = self.trainer.task.get_checkpoint_dict() self.assertTrue(len(cp_dict) == 1) - self.assertTrue("supernet" in cp_dict) - 
self.assertTrue("observer_stats" in cp_dict["supernet"]) - self.assertTrue(len(cp_dict["supernet"]["observer_stats"]) == 1) + self.assertTrue("FakeTask" in cp_dict) + self.assertTrue("observer_stats" in cp_dict["FakeTask"]) + self.assertTrue(len(cp_dict["FakeTask"]["observer_stats"]) == 1) self.assertTrue( ( 4, @@ -141,10 +142,10 @@ def test_verify_checkpoint(self) -> None: "MovingAveragePerChannelMinMax", "MovingAveragePerChannelMinMax", ) - in cp_dict["supernet"]["observer_stats"] + in cp_dict["FakeTask"]["observer_stats"] ) self.assertTrue( - cp_dict["supernet"]["observer_stats"][ + cp_dict["FakeTask"]["observer_stats"][ ( 4, 16, @@ -163,10 +164,9 @@ def test_load_checkpoint(self) -> None: ) self.trainer.task.set_checkpoint_dict.assert_called_once_with( - checkpoint_dict()["supernet"] + checkpoint_dict()["FakeTask"] ) if __name__ == "__main__": unittest.main() -
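
Note: the two patches above define a small opt-in contract. A task may
implement get_checkpoint_dict() to contribute state to "extra_state" at save
time, and set_checkpoint_dict() to receive that state back at load time,
keyed (after #5330) by the task's class name. Below is a minimal sketch of a
task implementing that contract; the class ObserverStatsTask and its
observer_stats attribute are hypothetical stand-ins invented for
illustration, not part of the patches. Only the hook names and the
class-name keying come from the diffs above.

    # Hypothetical task sketch; ObserverStatsTask and observer_stats are
    # illustrative, not part of fairseq.
    from fairseq.tasks import FairseqTask


    class ObserverStatsTask(FairseqTask):
        def __init__(self, cfg):
            super().__init__(cfg)
            # Task-level state worth persisting across restarts, e.g.
            # quantization observer stats for a supernet (illustrative).
            self.observer_stats = {}

        def get_checkpoint_dict(self):
            # save_checkpoint() merges this dict into extra_state, so the
            # state is keyed by the class name that load_checkpoint() looks up.
            return {
                self.__class__.__name__: {
                    "observer_stats": self.observer_stats,
                }
            }

        def set_checkpoint_dict(self, state):
            # load_checkpoint() passes only this task's sub-dict, with the
            # class-name key already stripped off.
            self.observer_stats = state.get("observer_stats", {})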
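
A quick round-trip of those hooks, mirroring what test_verify_checkpoint and
test_load_checkpoint assert (again illustrative; extra_state is built by hand
here rather than by the trainer):

    task = ObserverStatsTask(cfg=None)  # cfg is unused in this sketch
    task.observer_stats = {"mod1": 1, "mod2": 2, "mod3": 3}

    # Save side: save_checkpoint() merges the task's dict into extra_state.
    extra_state = {"train_iterator": {}, "val_loss": None}
    extra_state = {**extra_state, **task.get_checkpoint_dict()}
    assert "ObserverStatsTask" in extra_state

    # Load side: load_checkpoint() looks up the class name and hands the
    # sub-dict back to the task.
    task.set_checkpoint_dict(extra_state.get(task.__class__.__name__, {}))
    assert task.observer_stats == {"mod1": 1, "mod2": 2, "mod3": 3}

Keying the state by the task's class name, rather than the fixed "supernet"
key used in the first patch, keeps extra_state namespaced per task type;
that is precisely the generalization #5330 introduces.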