ADD: custom checkpoint callbacks #16

Merged: 2 commits merged into main from feat/custm-checkpoint-callbacks on Jan 1, 2025

Conversation

@haru-256 (Owner) commented on Jan 1, 2025

User description

Summary

Changes

  • add custom checkpoint callbacks

How to Test

  • cd common && make test
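  • Assuming the Makefile's test target wraps pytest, the new tests alone can likely be run with cd common && pytest tests/test_utils/test_callbacks.py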

Additional Context

  • none

PR Type

Enhancement, Tests


Description

  • Introduced CustomModelCheckpoint for model saving (see the usage sketch after this list).

  • Added comprehensive tests for checkpoint functionality.

  • Updated project dependencies for testing utilities.

  • Cleaned up VSCode and project configuration files.
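
A minimal usage sketch for reviewers is shown below. The import path (derived from the new file common/utils/callbacks.py) and the MyModule / MyDataModule classes are illustrative assumptions, not code from this PR:

    import lightning as L

    # Assumed import path, based on the new file common/utils/callbacks.py.
    from utils.callbacks import CustomModelCheckpoint

    checkpoint_cb = CustomModelCheckpoint(
        monitor="val_loss",   # metric name logged by the LightningModule via self.log(...)
        mode="min",           # lower val_loss is better
        every_n_epochs=1,     # save at the end of every epoch
        save_top_k=2,         # keep only the two best checkpoints on disk
    )

    trainer = L.Trainer(max_epochs=10, callbacks=[checkpoint_cb])
    trainer.fit(MyModule(), datamodule=MyDataModule())  # hypothetical module / datamodule

    # After fitting, the callback exposes the tracked best checkpoint.
    print(checkpoint_cb.best_model_path, checkpoint_cb.best_model_score)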


Changes walkthrough 📝

Relevant files:

Tests
  test_callbacks.py (common/tests/test_utils/test_callbacks.py, +194/-0)
  Tests for CustomModelCheckpoint functionality
    • Added tests for CustomModelCheckpoint
    • Validated checkpoint saving logic
    • Checked error handling for invalid parameters

Enhancement
  callbacks.py (common/utils/callbacks.py, +245/-0)
  Custom Model Checkpoint Implementation
    • Implemented CustomModelCheckpoint class
    • Added methods for checkpoint management
    • Included error handling for parameters

Configuration changes
  settings.json (common/.vscode/settings.json, +0/-1)
  Update VSCode settings
    • Removed formatOnSave setting

Dependencies
  pyproject.toml (common/pyproject.toml, +5/-7)
  Update project dependencies
    • Added dependencies for testing
    • Cleaned up dependency groups

💡 PR-Agent usage: Comment /help "your question" on any pull request to receive relevant information

    github-actions bot commented Jan 1, 2025

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    ⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review

    Test Coverage

    Ensure that all edge cases and potential failure points in the CustomModelCheckpoint class are adequately tested, especially around the handling of checkpoints and model saving logic.

    class TestCustomModelCheckpoint:
        def test_invalid_save_top_k(self) -> None:
            with pytest.raises(ValueError, match="save_top_k should be positive or -1"):
                CustomModelCheckpoint(monitor="val_loss", mode="min", save_top_k=0)
    
        def test_invalid_mode(self) -> None:
            with pytest.raises(ValueError, match="mode should be 'min' or 'max'"):
                CustomModelCheckpoint(monitor="val_loss", mode="invalid")  # type: ignore
    
        def test_best_model_properties_before_setting(self) -> None:
            callback = _custom_model_checkpoint_factory(mode="min")
            with pytest.raises(ValueError, match="No best model path found"):
                _ = callback.best_model_path
            with pytest.raises(ValueError, match="No best model score found"):
                _ = callback.best_model_score
    
        def test_get_top_k_checkpoints(self) -> None:
            checkpoints = [
                (0.5, pathlib.Path("model1.ckpt")),
                (0.3, pathlib.Path("model2.ckpt")),
                (0.7, pathlib.Path("model3.ckpt")),
            ]
            callback = _custom_model_checkpoint_factory(mode="min", checkpoints=checkpoints)
            top_k = callback.get_top_k_checkpoints()
        assert len(top_k) == 3 and not DeepDiff(top_k, checkpoints, ignore_order=True)
    
            callback = _custom_model_checkpoint_factory(
                mode="min", checkpoints=checkpoints, save_top_k=2
            )
            top_k = callback.get_top_k_checkpoints()
        assert len(top_k) == 2 and not DeepDiff(top_k, [checkpoints[1], checkpoints[0]])
    
        def test_update_best_model(self) -> None:
            checkpoints = [
                (0.5, pathlib.Path("model1.ckpt")),
                (0.3, pathlib.Path("model2.ckpt")),
                (0.7, pathlib.Path("model3.ckpt")),
            ]
    
            callback = _custom_model_checkpoint_factory(mode="min", checkpoints=checkpoints)
            callback._update_best_model()
            assert callback._best_model_score == 0.3
            assert callback._best_model_path == pathlib.Path("model2.ckpt")
    
            # Test with max mode
            callback = _custom_model_checkpoint_factory(mode="max", checkpoints=checkpoints)
            callback._update_best_model()
            assert callback._best_model_score == 0.7
            assert callback._best_model_path == pathlib.Path("model3.ckpt")
    
        # Test with empty checkpoints
        callback = _custom_model_checkpoint_factory(mode="min", checkpoints=[])
        callback._update_best_model()
        assert callback._best_model_score is None
        assert callback._best_model_path is None
    
        def test_get_trainer_log_dir(self, mocker: MockerFixture) -> None:
            callback = CustomModelCheckpoint(monitor="val_loss", mode="min")
    
            # Test with valid log dir
            mock_trainer = mocker.MagicMock(L.Trainer)
            mock_trainer.logger.log_dir = "test/path"
            log_dir = callback._get_trainer_log_dir(mock_trainer)
            assert log_dir == pathlib.Path("test/path")
    
            # Test with None logger
            mock_trainer.logger = None
            with pytest.raises(ValueError, match="Trainer logger is None"):
                callback._get_trainer_log_dir(mock_trainer)
    
            # Test with None log_dir
            mock_trainer.logger = mocker.MagicMock()
            mock_trainer.logger.log_dir = None
            with pytest.raises(ValueError, match="Trainer logger has no log directory"):
                callback._get_trainer_log_dir(mock_trainer)
    
        def test_on_train_epoch_end(self, mocker: MockerFixture, tmp_path: pathlib.Path) -> None:
            checkpoints = [
                (0.5, pathlib.Path("model1.ckpt")),
                (0.3, pathlib.Path("model2.ckpt")),
            ]
    
        # Create the callback under test
            callback = _custom_model_checkpoint_factory(
                mode="min", checkpoints=checkpoints, save_top_k=2
            )
            mocker.patch.object(
                callback, "_should_skip_saving_checkpoints", return_value=False, autospec=True
            )
        # Mock the injected trainer and LightningModule with MagicMock
            trainer = mocker.MagicMock(L.Trainer)
            trainer.logger.log_dir = str(tmp_path)
            trainer.current_epoch = 10
            trainer.global_step = 0
            trainer.callback_metrics = {"metric": torch.tensor(0.1)}
            data = {"state_dict": torch.tensor(0.1)}
            _checkpoint_connector = mocker.MagicMock()
            _checkpoint_connector.dump_checkpoint.return_value = data
            trainer._checkpoint_connector = _checkpoint_connector
        # NOTE: the line below raises "AttributeError: Mock object has no attribute 'profiler'",
        #       probably because methods on an underscore-prefixed attribute cannot be mocked this way.
            # trainer._checkpoint_connector.dump_checkpoint.return_value = data
            pl_module = mocker.MagicMock(L.LightningModule)
    
            callback.on_train_epoch_end(trainer, pl_module)
            checkpoint_path = tmp_path / "checkpoints" / "epoch=10-step=0000000000-metric=0.1.ckpt"
            assert checkpoint_path.exists()
            assert torch.load(checkpoint_path, weights_only=True) == data
            torch.testing.assert_close(callback.best_model_score, 0.1)
            assert callback.best_model_path == checkpoint_path
            actual_top_k_path = [
                path for _, path in callback.get_top_k_checkpoints()
        ]  # Compare file paths only, since the scores are numerically unstable
        assert not DeepDiff(actual_top_k_path, [checkpoint_path, checkpoints[1][1]])
    
        def test_on_fit_end(self, mocker: MockerFixture) -> None:
            checkpoints = [
                (0.5, pathlib.Path("model1.ckpt")),
                (0.3, pathlib.Path("model2.ckpt")),
                (0.7, pathlib.Path("model3.ckpt")),
            ]
            callback = _custom_model_checkpoint_factory(
                mode="min", checkpoints=checkpoints, save_top_k=2
            )
    
            mocker.patch.object(
                callback, "_should_skip_saving_checkpoints", side_effect=[True, False], autospec=True
            )
        # The following would also work:
            # callback._should_skip_saving_checkpoints = mocker.MagicMock(
            #     side_effect=[True, False], autospec=True
            # )
            open_mock = mocker.patch("builtins.open", new_callable=mocker.mock_open)
        # Mock the injected objects with MagicMock
            trainer = mocker.MagicMock(L.Trainer)
            trainer.logger.log_dir = "tests/path"
            pl_module = mocker.MagicMock(L.LightningModule)
    
            # Test skip
            callback.on_fit_end(trainer, pl_module)
            assert open_mock.call_count == 0
    
            # Test save
            callback.on_fit_end(trainer, pl_module)
            assert open_mock.call_count == 1
        # TODO: verify the data that gets written
            # handle = open_mock()
            # data = {
            #     "monitor": "metric",
            #     "mode": "min",
            #     "top_k_models": [
            #         {
            #             "rank": rnk + 1,
            #             "metric": metric,
            #             "file_name": str(checkpoint),
            #             "path": str(checkpoint),
            #         }
            #         for rnk, (metric, checkpoint) in enumerate(
            #             [(0.3, "model2.ckpt"), (0.5, "model1.ckpt")]
            #         )
            #     ],
            # }
            # assert handle.write.assert_called_once_with(json.dumps(data))
    Exception Handling

    Review the exception handling in the CustomModelCheckpoint methods to ensure that all potential errors are caught and handled gracefully, particularly in methods that interact with the file system.

    class CustomModelCheckpoint(Checkpoint):
        def __init__(
            self,
            monitor: str,
            mode: Literal["min", "max"],
            every_n_epochs: int = 1,
            save_top_k: int = -1,
        ):
            """Custom ModelCheckpoint callback, GCS fuseに保存するとファイルシステムの違いでエラーが出るため、custom化
            issue: https://github.com/Lightning-AI/pytorch-lightning/issues/20270
    
            Args:
                monitor: _description_
                mode: _description_
                every_n_epochs: _description_. Defaults to 1.
                save_top_k: _description_. Defaults to -1.
    
            Raises:
                ValueError: _description_
                ValueError: _description_
            """
            super().__init__()
    
            if not (save_top_k > 0 or save_top_k == -1):
                raise ValueError("save_top_k should be positive or -1")
            if mode not in ["min", "max"]:
                raise ValueError("mode should be 'min' or 'max'")
    
            self.monitor = monitor
            self.mode = mode
            self.every_n_epochs = every_n_epochs
            self.save_top_k = save_top_k
    
            self._checkpoints: list[tuple[float, pathlib.Path]] = []
            self._best_model_path: Optional[pathlib.Path] = None
            self._best_model_score: Optional[float] = None
    
        def _should_skip_saving_checkpoints(self, trainer: L.Trainer) -> bool:
            """Determine if checkpoint saving should be skipped.
    
            Args:
                trainer (L.Trainer): Lightning trainer instance.
    
            Returns:
                bool: True if saving should be skipped, False otherwise.
            """
            # from: lightning.pytorch.callbacks.model_checkpoint::ModelCheckpoint::_should_skip_saving_checkpoints
            return (
                bool(trainer.fast_dev_run)
                or trainer.state.fn != TrainerFn.FITTING
                or trainer.sanity_checking
            )
    
        @property
        def best_model_path(self) -> Optional[pathlib.Path]:
            """Get the path to the best model checkpoint.
    
            Returns:
                Optional[pathlib.Path]: Path to the best model checkpoint.
    
            Raises:
                ValueError: If no best model path has been set.
            """
            if self._best_model_path is None:
                raise ValueError("No best model path found")
            return self._best_model_path
    
        @property
        def best_model_score(self) -> Optional[float]:
            """Get the score of the best model.
    
            Returns:
                Optional[float]: Score of the best model.
    
            Raises:
                ValueError: If no best model score has been set.
            """
            if self._best_model_score is None:
                raise ValueError("No best model score found")
            return self._best_model_score
    
        def _sorted_checkpoints(self) -> list[tuple[float, pathlib.Path]]:
            """Sort checkpoints based on their scores.
    
            Returns:
                list[tuple[float, pathlib.Path]]: List of (score, path) tuples sorted by score.
            """
            return sorted(self._checkpoints, key=lambda x: x[0], reverse=self.mode == "max")
    
        def _update_best_model(self) -> None:
            """Update the best model path and score based on current checkpoints."""
            if len(self._checkpoints) == 0:
                return
    
            sorted_checkpoints = self._sorted_checkpoints()
            self._best_model_score, self._best_model_path = sorted_checkpoints[0]
    
        def get_top_k_checkpoints(self) -> list[tuple[float, pathlib.Path]]:
            """Get the top K checkpoints based on their scores.
    
            Returns:
                list[tuple[float, pathlib.Path]]: List of top K (score, path) tuples.
            """
            if self.save_top_k == -1:
                return self._sorted_checkpoints()
            return self._sorted_checkpoints()[: self.save_top_k]
    
        def _clean_checkpoints(self) -> None:
            """Remove checkpoints that are not in the top K.
    
            Does nothing if save_top_k is -1 (keep all checkpoints).
            """
            if self.save_top_k == -1:
                return
    
            sorted_checkpoints = self._sorted_checkpoints()
            deleted_checkpoints = sorted_checkpoints[self.save_top_k :]
    
            for _, checkpoint in deleted_checkpoints:
                if checkpoint.exists():
                    checkpoint.unlink()
    
        def _get_trainer_log_dir(self, trainer: L.Trainer) -> pathlib.Path:
            if trainer.logger is None:
                raise ValueError("Trainer logger is None")
            log_dir = trainer.logger.log_dir
            if log_dir is None:
                raise ValueError("Trainer logger has no log directory")
            return pathlib.Path(log_dir)
    
        @override
        def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
            """Save model checkpoints at the end of each training epoch.
    
            This method is called at the end of each training epoch to:
            1. Skip checkpoint saving if conditions are met (fast_dev_run, sanity_checking)
            2. Save model checkpoint if the current epoch matches the save frequency
            3. Update the best model tracking based on the monitored metric
    
            Args:
                trainer (L.Trainer): The Lightning trainer instance.
                pl_module (L.LightningModule): The Lightning module being trained.
    
            Side Effects:
                - Creates checkpoint directory if it doesn't exist
                - Saves model checkpoint to disk
                - Updates internal checkpoint tracking
                - Updates best model path and score
            """
            if self._should_skip_saving_checkpoints(trainer):
                return
            if not (self.every_n_epochs >= 1 and trainer.current_epoch % self.every_n_epochs == 0):
                return
    
            # Retrieve the monitored metric
            metrics = trainer.callback_metrics.get(self.monitor)
            if metrics is None:
                raise ValueError(f"Metric '{self.monitor}' not found in callback metrics")
            # breakpoint()
            metrics = float(metrics.item())
            num_epochs = trainer.current_epoch
            num_steps = trainer.global_step
    
            # Make sure the save directory exists
            log_dir = self._get_trainer_log_dir(trainer)
            save_dir = pathlib.Path(log_dir) / "checkpoints"
            save_dir.mkdir(parents=True, exist_ok=True)
    
            # Save the checkpoint
            checkpoint_path = (
                save_dir
                / f"epoch={num_epochs:02d}-step={num_steps:010d}-{self.monitor}={metrics:.6g}.ckpt"
            )
            logger.info(f"Saving checkpoint to {checkpoint_path}")
            _start = time.perf_counter()
            # Adapted from trainer.save_checkpoint(checkpoint_path)
            checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"Checkpoint saved in {time.perf_counter() - _start:.2f} seconds")
    
            # update the best model
            self._checkpoints.append((metrics, checkpoint_path))
            self._update_best_model()
            self._clean_checkpoints()
            # breakpoint()
    
        @override
        def on_fit_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
            """Finalize the model checkpoint tracking at the end of training.
    
            This method is called once at the end of training to:
            1. Skip if in fast_dev_run or sanity checking mode
            2. Update the best model tracking
            3. Save the best model information to a JSON file
    
            Args:
                trainer (L.Trainer): The Lightning trainer instance.
                pl_module (L.LightningModule): The Lightning module that was trained.
    
            Side Effects:
                - Creates a 'best_model.json' file containing:
                    - Monitored metric name
                    - Monitor mode (min/max)
                    - Top K model information (rank, metric value, file paths)
            """
            if self._should_skip_saving_checkpoints(trainer):
                logger.info("Skip saving checkpoints because of fast_dev_run or sanity_checking")
                return
    
            self._update_best_model()
            logger.info(f"Best model score: {self.best_model_score:.6g} at {self.best_model_path}")
            log_dir = self._get_trainer_log_dir(trainer)
            with open(log_dir / "best_model.json", "w") as f:
                json.dump(
                    {
                        "monitor_metric": self.monitor,
                        "monitor_mode": self.mode,
                        "top_k_models": [
                            {
                                "rank": rnk + 1,
                                "metric": metric,
                                "file_name": checkpoint.name,
                                "path": str(checkpoint),
                            }
                            for rnk, (metric, checkpoint) in enumerate(self.get_top_k_checkpoints())
                        ],
                        "best_model_score": self.best_model_score,
                        "best_model_path": str(self.best_model_path),
                    },
                    f,
                    indent=4,
                )
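
    As a reading aid, here is a sketch (not part of this PR) of how the best_model.json summary written by on_fit_end could be consumed afterwards; the log_dir value is an assumed example, and the checkpoint file is the plain torch.save dump produced in on_train_epoch_end:

    import json
    import pathlib

    import torch

    log_dir = pathlib.Path("lightning_logs/version_0")  # hypothetical logger log_dir
    with open(log_dir / "best_model.json") as f:
        best = json.load(f)

    print(best["monitor_metric"], best["monitor_mode"], best["best_model_score"])
    # The checkpoint was written with torch.save(dump_checkpoint(...)), so torch.load
    # reads back the full checkpoint dict, including its "state_dict".
    checkpoint = torch.load(best["best_model_path"], weights_only=False)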

    github-actions bot commented Jan 1, 2025

    PR Code Suggestions ✨

    Explore these optional code suggestions:

    Category: Possible issue

    Suggestion 1: Validate the monitored metric before accessing it to avoid potential runtime errors

    Ensure that the monitor argument is a valid metric before using it in the
    on_train_epoch_end method to prevent runtime errors.

    common/utils/callbacks.py [169]

    -metrics = trainer.callback_metrics.get(self.monitor)
    +if self.monitor not in trainer.callback_metrics:
    +    raise ValueError(f"Metric '{self.monitor}' not found in callback metrics")
    +metrics = trainer.callback_metrics[self.monitor]
    Suggestion importance[1-10]: 8

    Why: This suggestion addresses a potential runtime error by ensuring that the monitored metric exists in the callback metrics before accessing it, which is crucial for robustness.

    Suggestion 2: Prevent saving a None checkpoint to avoid runtime errors during checkpoint saving

    Add a check to ensure that checkpoint is not None before attempting to save it in
    the on_train_epoch_end method to avoid saving errors.

    common/utils/callbacks.py [190]

    +if checkpoint is None:
    +    raise ValueError("Checkpoint data is None, cannot save.")
     torch.save(checkpoint, checkpoint_path)
    Suggestion importance[1-10]: 7

    Why: This suggestion improves error handling by checking if the checkpoint is None before attempting to save it, which can prevent runtime errors during the saving process.


    @haru-256 merged commit a725861 into main on Jan 1, 2025
    1 of 2 checks passed
    @haru-256 deleted the feat/custm-checkpoint-callbacks branch on January 1, 2025 at 10:43