diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 92a4867bb9..210ba20ec7 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -8,6 +8,7 @@
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
+from mmengine import symlink
 from mmengine.dist import is_main_process, master_only
 from mmengine.fileio import FileClient, get_file_backend
 from mmengine.logging import print_log
@@ -96,6 +97,9 @@ class CheckpointHook(Hook):
             at which checkpoint saving begins. Defaults to 0, which means
             saving at the beginning.
             `New in version 0.8.3.`
+        create_symlink (bool): Whether to create a symbolic link named
+            'latest.pth' that points to the latest checkpoint.
+            Defaults to False. `New in version 0.8.5.`
 
     Examples:
         >>> # Save best based on single metric
@@ -145,6 +149,7 @@ def __init__(self,
                  backend_args: Optional[dict] = None,
                  published_keys: Union[str, List[str], None] = None,
                  save_begin: int = 0,
+                 create_symlink: bool = False,
                  **kwargs) -> None:
         self.interval = interval
         self.by_epoch = by_epoch
@@ -153,6 +158,7 @@ def __init__(self,
         self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
+        self.create_symlink = create_symlink
         self.args = kwargs
 
         if file_client_args is not None:
@@ -460,6 +466,18 @@ def _save_checkpoint_with_step(self, runner, step, meta):
         with open(save_file, 'w') as f:
             f.write(self.last_ckpt)  # type: ignore
 
+        # `os.symlink` is unsupported in some environments; if it fails, warn
+        # once and disable `create_symlink` so later saves skip the attempt.
+        if self.create_symlink:
+            dst_file = osp.join(runner.work_dir, 'latest.pth')
+            try:
+                symlink(ckpt_filename, dst_file)
+            except OSError as err:
+                print_log(
+                    'create_symlink is set as False because failed to create a symbolic '
+                    f'link for {err}', logger='current', level=logging.WARNING)
+                self.create_symlink = False
+
     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
 