Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update checkpoint_hook.py #1281

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine import symlink
from mmengine.dist import is_main_process, master_only
from mmengine.fileio import FileClient, get_file_backend
from mmengine.logging import print_log
Expand Down Expand Up @@ -96,6 +97,9 @@ class CheckpointHook(Hook):
at which checkpoint saving begins. Defaults to 0, which means
saving at the beginning.
`New in version 0.8.3.`
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved

Examples:
>>> # Save best based on single metric
Expand Down Expand Up @@ -145,6 +149,7 @@ def __init__(self,
backend_args: Optional[dict] = None,
published_keys: Union[str, List[str], None] = None,
save_begin: int = 0,
creat_symlink: Optional[bool] = True,
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
Expand All @@ -153,6 +158,7 @@ def __init__(self,
self.out_dir = out_dir # type: ignore
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.creat_symlink = creat_symlink
self.args = kwargs

if file_client_args is not None:
Expand Down Expand Up @@ -460,6 +466,20 @@ def _save_checkpoint_with_step(self, runner, step, meta):
with open(save_file, 'w') as f:
f.write(self.last_ckpt) # type: ignore

# in some environments, `os.symlink` is not supported, then set
# 'create_symlink' to False.
if self.create_symlink:
dst_file = osp.join(runner.work_dir, 'latest.pth')
try:
symlink(ckpt_filename, dst_file)
except SystemError:
print_log(
'create_symlink is set as False because creating symbolic '
f'link is not allowed in {self.file_client.name}',
logger='current',
level=logging.WARNING)
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
self.create_symlink = False

def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

Expand Down