From 478a80a8a825bc0d0207171d65676be5349b5000 Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Sun, 18 Feb 2024 17:16:06 +0800 Subject: [PATCH] Move checkpoint funtions from runner to a new sub-package --- docs/en/api/checkpoint.rst | 34 + docs/en/api/runner.rst | 4 + docs/en/index.rst | 1 + docs/zh_cn/api/checkpoint.rst | 34 + docs/zh_cn/api/runner.rst | 4 + docs/zh_cn/index.rst | 1 + mmengine/checkpoint/__init__.py | 14 + mmengine/checkpoint/io.py | 376 ++++++++ mmengine/checkpoint/loader.py | 322 +++++++ mmengine/checkpoint/utils.py | 140 +++ mmengine/runner/checkpoint.py | 835 +----------------- mmengine/testing/__init__.py | 10 +- mmengine/testing/compare.py | 4 + .../test_io.py} | 186 +--- tests/test_checkpoint/test_loader.py | 127 +++ 15 files changed, 1127 insertions(+), 965 deletions(-) create mode 100644 docs/en/api/checkpoint.rst create mode 100644 docs/zh_cn/api/checkpoint.rst create mode 100644 mmengine/checkpoint/__init__.py create mode 100644 mmengine/checkpoint/io.py create mode 100644 mmengine/checkpoint/loader.py create mode 100644 mmengine/checkpoint/utils.py rename tests/{test_runner/test_checkpoint.py => test_checkpoint/test_io.py} (74%) create mode 100644 tests/test_checkpoint/test_loader.py diff --git a/docs/en/api/checkpoint.rst b/docs/en/api/checkpoint.rst new file mode 100644 index 0000000000..18c0c73c91 --- /dev/null +++ b/docs/en/api/checkpoint.rst @@ -0,0 +1,34 @@ +.. role:: hidden + :class: hidden-section + +mmengine.checkpoint +=================================== + +.. contents:: mmengine.checkpoint + :depth: 2 + :local: + :backlinks: top + +.. currentmodule:: mmengine.checkpoint + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + CheckpointLoader + +.. autosummary:: + :toctree: generated + :nosignatures: + + load_checkpoint + save_checkpoint + load_state_dict + get_state_dict + weights_to_cpu + find_latest_checkpoint + get_deprecated_model_names + get_external_models + get_mmcls_models + get_torchvision_models diff --git a/docs/en/api/runner.rst b/docs/en/api/runner.rst index 8738472b78..4daec61347 100644 --- a/docs/en/api/runner.rst +++ b/docs/en/api/runner.rst @@ -39,6 +39,10 @@ Loop Checkpoints ---------------- +.. warn:: + + All functions and classes in this file have been moved to `mmengine.checkpoint`. Please import them from `mmengine.checkpoint`. + .. autosummary:: :toctree: generated :nosignatures: diff --git a/docs/en/index.rst b/docs/en/index.rst index 66e7f8fd1c..f72284b193 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -103,6 +103,7 @@ You can switch between Chinese and English documents in the lower-left corner of mmengine.dataset mmengine.infer mmengine.device + mmengine.checkpoint mmengine.hub mmengine.logging mmengine.visualization diff --git a/docs/zh_cn/api/checkpoint.rst b/docs/zh_cn/api/checkpoint.rst new file mode 100644 index 0000000000..18c0c73c91 --- /dev/null +++ b/docs/zh_cn/api/checkpoint.rst @@ -0,0 +1,34 @@ +.. role:: hidden + :class: hidden-section + +mmengine.checkpoint +=================================== + +.. contents:: mmengine.checkpoint + :depth: 2 + :local: + :backlinks: top + +.. currentmodule:: mmengine.checkpoint + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + CheckpointLoader + +.. autosummary:: + :toctree: generated + :nosignatures: + + load_checkpoint + save_checkpoint + load_state_dict + get_state_dict + weights_to_cpu + find_latest_checkpoint + get_deprecated_model_names + get_external_models + get_mmcls_models + get_torchvision_models diff --git a/docs/zh_cn/api/runner.rst b/docs/zh_cn/api/runner.rst index 8738472b78..ad674b43c4 100644 --- a/docs/zh_cn/api/runner.rst +++ b/docs/zh_cn/api/runner.rst @@ -39,6 +39,10 @@ Loop Checkpoints ---------------- +.. warn:: + + 所有的函数和类在这个文件中已经被移动到 `mmengine.checkpoint`。请从 `mmengine.checkpoint` 导入它们。 + .. autosummary:: :toctree: generated :nosignatures: diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 6aa95e75e9..a64f1acc61 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -103,6 +103,7 @@ mmengine.dataset mmengine.infer mmengine.device + mmengine.checkpoint mmengine.hub mmengine.logging mmengine.visualization diff --git a/mmengine/checkpoint/__init__.py b/mmengine/checkpoint/__init__.py new file mode 100644 index 0000000000..ea8d32cfc7 --- /dev/null +++ b/mmengine/checkpoint/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .io import (get_state_dict, load_checkpoint, load_state_dict, + save_checkpoint, weights_to_cpu) +from .loader import CheckpointLoader +from .utils import (find_latest_checkpoint, get_deprecated_model_names, + get_external_models, get_mmcls_models, + get_torchvision_models) + +__all__ = [ + 'CheckpointLoader', 'find_latest_checkpoint', 'get_deprecated_model_names', + 'get_external_models', 'get_mmcls_models', 'get_state_dict', + 'get_torchvision_models', 'load_checkpoint', 'load_state_dict', + 'save_checkpoint', 'weights_to_cpu' +] diff --git a/mmengine/checkpoint/io.py b/mmengine/checkpoint/io.py new file mode 100644 index 0000000000..ad33deaf98 --- /dev/null +++ b/mmengine/checkpoint/io.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import logging +import os.path as osp +import re +from collections import OrderedDict, namedtuple +from tempfile import TemporaryDirectory + +import torch + +from mmengine.dist import get_dist_info +from mmengine.fileio import FileClient, get_file_backend +from mmengine.logging import print_log +from mmengine.model import BaseTTAModel, is_model_wrapper +from mmengine.utils import apply_to, deprecated_function +from .loader import CheckpointLoader + + +def _load_checkpoint(filename, map_location=None, logger=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str, optional): Same as :func:`torch.load`. + Defaults to None. + logger (:mod:`logging.Logger`, optional): The logger for error message. + Defaults to None + + Returns: + dict or OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + return CheckpointLoader.load_checkpoint(filename, map_location, logger) + + +def _load_checkpoint_with_prefix(prefix, filename, map_location=None): + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. + Defaults to None. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = _load_checkpoint(filename, map_location=map_location) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix.endswith('.'): + prefix += '.' + prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict + + +def _load_checkpoint_to_model(model, + checkpoint, + strict=False, + logger=None, + revise_keys=[(r'^module\.', '')]): + + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + # strip prefix of state_dict + metadata = getattr(state_dict, '_metadata', OrderedDict()) + for p, r in revise_keys: + state_dict = OrderedDict( + {re.sub(p, r, k): v + for k, v in state_dict.items()}) + # Keep metadata in state_dict + state_dict._metadata = metadata + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +class _IncompatibleKeys( + namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): + + def __repr__(self): + if not self.missing_keys and not self.unexpected_keys: + return '' + return super().__repr__() + + __str__ = __repr__ + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Defaults to False. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + missing_keys = [] + err_msg = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, local_state_dict, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_model_wrapper(module) or isinstance(module, BaseTTAModel): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(local_state_dict, prefix, local_metadata, + True, missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + '.' + child_state_dict = { + k: v + for k, v in local_state_dict.items() + if k.startswith(child_prefix) + } + load(child, child_state_dict, child_prefix) + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + if hasattr(module, '_load_state_dict_post_hooks'): + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + 'Hooks registered with ' + '``register_load_state_dict_post_hook`` are not expected ' + 'to return new values, if incompatible_keys need to be ' + 'modified, it should be done inplace.') + + load(module, state_dict) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + else: + print_log(err_msg, logger=logger, level=logging.WARNING) + + +def load_checkpoint(model, + filename, + map_location=None, + strict=False, + logger=None, + revise_keys=[(r'^module\.', '')]): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Defaults to strip + the prefix 'module.' by [(r'^module\\.', '')]. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location, logger) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + + return _load_checkpoint_to_model(model, checkpoint, strict, logger, + revise_keys) + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + # stash metadata to put in state_dict later + metadata = getattr(state_dict, '_metadata', OrderedDict()) + state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'), + lambda x: x.cpu()) + state_dict._metadata = metadata + return state_dict + + +@deprecated_function( + since='0.3.0', + removed_in='0.5.0', + instructions='`_save_to_state_dict` will be deprecated in the future, ' + 'please use `nn.Module._save_to_state_dict` directly.') +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + keep_vars (bool): Whether to keep the variable property of the + parameters. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + if buf is not None and name not in module._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Defaults to False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_model_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + module._save_to_state_dict(destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(checkpoint, + filename, + file_client_args=None, + backend_args=None): + """Save checkpoint to file. + + Args: + checkpoint (dict): Module whose params are to be saved. + filename (str): Checkpoint filename. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + `backend_args` instead. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + """ + if file_client_args is not None: + print_log( + '"file_client_args" will be deprecated in future. ' + 'Please use "backend_args" instead', + logger='current', + level=logging.WARNING) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set ' + 'at the same time.') + + if filename.startswith('pavi://'): + if file_client_args is not None or backend_args is not None: + raise ValueError( + '"file_client_args" or "backend_args" should be "None" if ' + 'filename starts with "pavi://"') + try: + from pavi import exception, modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except exception.NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + file_client = FileClient.infer_client(file_client_args, filename) + if file_client_args is None: + file_backend = get_file_backend( + filename, backend_args=backend_args) + else: + file_backend = file_client + + with io.BytesIO() as f: + torch.save(checkpoint, f) + file_backend.put(f.getvalue(), filename) diff --git a/mmengine/checkpoint/loader.py b/mmengine/checkpoint/loader.py new file mode 100644 index 0000000000..00b409a1a0 --- /dev/null +++ b/mmengine/checkpoint/loader.py @@ -0,0 +1,322 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import logging +import os +import os.path as osp +import re +from collections import OrderedDict +from tempfile import TemporaryDirectory +from typing import Callable, Dict + +import torch + +from mmengine.dist import get_dist_info +from mmengine.fileio import get_file_backend +from mmengine.logging import print_log +from mmengine.utils.dl_utils import load_url +from .utils import (_get_mmengine_home, get_deprecated_model_names, + get_external_models, get_mmcls_models, + get_torchvision_models) + + +class CheckpointLoader: + """A general checkpoint loader to manage all schemes.""" + + _schemes: Dict[str, Callable] = {} + + @classmethod + def _register_scheme(cls, prefixes, loader, force=False): + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if (prefix not in cls._schemes) or force: + cls._schemes[prefix] = loader + else: + raise KeyError( + f'{prefix} is already registered as a loader backend, ' + 'add "force=True" if you want to override it') + # sort, longer prefixes take priority + cls._schemes = OrderedDict( + sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) + + @classmethod + def register_scheme(cls, prefixes, loader=None, force=False): + """Register a loader to CheckpointLoader. + + This method can be used as a normal class method or a decorator. + + Args: + prefixes (str or list[str] or tuple[str]): + The prefix of the registered loader. + loader (function, optional): The loader function to be registered. + When this method is used as a decorator, loader is None. + Defaults to None. + force (bool, optional): Whether to override the loader + if the prefix has already been registered. Defaults to False. + """ + + if loader is not None: + cls._register_scheme(prefixes, loader, force=force) + return + + def _register(loader_cls): + cls._register_scheme(prefixes, loader_cls, force=force) + return loader_cls + + return _register + + @classmethod + def _get_checkpoint_loader(cls, path): + """Finds a loader that supports the given path. Falls back to the local + loader if no other loader is found. + + Args: + path (str): checkpoint path + + Returns: + callable: checkpoint loader + """ + for p in cls._schemes: + # use regular match to handle some cases that where the prefix of + # loader has a prefix. For example, both 's3://path' and + # 'open-mmlab:s3://path' should return `load_from_ceph` + if re.match(p, path) is not None: + return cls._schemes[p] + + @classmethod + def load_checkpoint(cls, filename, map_location=None, logger='current'): + """load checkpoint through URL scheme path. + + Args: + filename (str): checkpoint file name with given prefix + map_location (str, optional): Same as :func:`torch.load`. + Defaults to None + logger (str): The logger for message. Defaults to 'current'. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint_loader = cls._get_checkpoint_loader(filename) + class_name = checkpoint_loader.__name__ + print_log( + f'Loads checkpoint by {class_name[10:]} backend from path: ' + f'{filename}', + logger=logger) + return checkpoint_loader(filename, map_location) + + +@CheckpointLoader.register_scheme(prefixes='') +def load_from_local(filename, map_location): + """load checkpoint by local file path. + + Args: + filename (str): local checkpoint file path + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + filename = osp.expanduser(filename) + if not osp.isfile(filename): + raise FileNotFoundError(f'{filename} can not be found.') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) +def load_from_http(filename, + map_location=None, + model_dir=None, + progress=os.isatty(0)): + """load checkpoint through HTTP or HTTPS scheme path. In distributed + setting, this function only download checkpoint at local rank 0. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + model_dir (string, optional): directory in which to save the object, + Defaults to None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + rank, world_size = get_dist_info() + if rank == 0: + checkpoint = load_url( + filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = load_url( + filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes='pavi://') +def load_from_pavi(filename, map_location=None): + """load checkpoint through the file path prefixed with pavi. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with pavi prefix + map_location (str, optional): Same as :func:`torch.load`. + Defaults to None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + assert filename.startswith('pavi://'), \ + f'Expected filename startswith `pavi://`, but get {filename}' + model_path = filename[7:] + + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme( + prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) +def load_from_ceph(filename, map_location=None, backend='petrel'): + """load checkpoint through the file path prefixed with s3. In distributed + setting, this function download ckpt at all ranks to different temporary + directories. + + Args: + filename (str): checkpoint file path with s3 prefix + map_location (str, optional): Same as :func:`torch.load`. + backend (str, optional): The storage backend type. + Defaults to 'petrel'. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + file_backend = get_file_backend( + filename, backend_args={'backend': backend}) + with io.BytesIO(file_backend.get(filename)) as buffer: + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) +def load_from_torchvision(filename, map_location=None): + """load checkpoint through the file path prefixed with modelzoo or + torchvision. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + model_urls = get_torchvision_models() + if filename.startswith('modelzoo://'): + print_log( + 'The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead', + logger='current', + level=logging.WARNING) + model_name = filename[11:] + else: + model_name = filename[14:] + return load_from_http(model_urls[model_name], map_location=map_location) + + +@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) +def load_from_openmmlab(filename, map_location=None): + """load checkpoint through the file path prefixed with open-mmlab or + openmmlab. + + Args: + filename (str): checkpoint file path with open-mmlab or + openmmlab prefix + map_location (str, optional): Same as :func:`torch.load`. + Defaults to None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_external_models() + prefix_str = 'open-mmlab://' + if filename.startswith(prefix_str): + model_name = filename[13:] + else: + model_name = filename[12:] + prefix_str = 'openmmlab://' + + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + print_log( + f'{prefix_str}{model_name} is deprecated in favor ' + f'of {prefix_str}{deprecated_urls[model_name]}', + logger='current', + level=logging.WARNING) + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_from_http(model_url, map_location=map_location) + else: + filename = osp.join(_get_mmengine_home(), model_url) + if not osp.isfile(filename): + raise FileNotFoundError(f'{filename} can not be found.') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def _process_mmcls_checkpoint(checkpoint): + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + # Some checkpoints converted from 3rd-party repo don't + # have the "state_dict" key. + state_dict = checkpoint + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +@CheckpointLoader.register_scheme(prefixes='mmcls://') +def load_from_mmcls(filename, map_location=None): + """load checkpoint through the file path prefixed with mmcls. + + Args: + filename (str): checkpoint file path with mmcls prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_from_http( + model_urls[model_name], map_location=map_location) + checkpoint = _process_mmcls_checkpoint(checkpoint) + return checkpoint diff --git a/mmengine/checkpoint/utils.py b/mmengine/checkpoint/utils.py new file mode 100644 index 0000000000..79b7105e30 --- /dev/null +++ b/mmengine/checkpoint/utils.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import pkgutil +from importlib import import_module +from typing import Optional + +import mmengine +from mmengine.fileio import load as load_file +from mmengine.logging import print_log +from mmengine.utils import digit_version, mkdir_or_exist + +# `MMENGINE_HOME` is the highest priority directory to save checkpoints +# downloaded from Internet. If it is not set, as a workaround, using +# `XDG_CACHE_HOME`` or `~/.cache` instead. +# Note that `XDG_CACHE_HOME` defines the base directory relative to which +# user-specific non-essential data files should be stored. If `XDG_CACHE_HOME` +# is either not set or empty, a default equal to `~/.cache` should be used. +ENV_MMENGINE_HOME = 'MMENGINE_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmengine_home(): + mmengine_home = os.path.expanduser( + os.getenv( + ENV_MMENGINE_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmengine'))) + + mkdir_or_exist(mmengine_home) + return mmengine_home + + +def get_torchvision_models(): + import torchvision + if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): + model_urls = dict() + # When the version of torchvision is lower than 0.13, the model url is + # not declared in `torchvision.model.__init__.py`, so we need to + # iterate through `torchvision.models.__path__` to get the url for each + # model. + for _, name, ispkg in pkgutil.walk_packages( + torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + else: + # Since torchvision bumps to v0.13, the weight loading logic, + # model keys and model urls have been changed. Here the URLs of old + # version is loaded to avoid breaking back compatibility. If the + # torchvision version>=0.13.0, new URLs will be added. Users can get + # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', + # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. + json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json') + model_urls = mmengine.load(json_path) + if digit_version(torchvision.__version__) < digit_version('0.14.0a0'): + weights_list = [ + cls for cls_name, cls in torchvision.models.__dict__.items() + if cls_name.endswith('_Weights') + ] + else: + weights_list = [ + torchvision.models.get_model_weights(model) + for model in torchvision.models.list_models(torchvision.models) + ] + + for cls in weights_list: + # The name of torchvision model weights classes ends with + # `_Weights` such as `ResNet18_Weights`. However, some model weight + # classes, such as `MNASNet0_75_Weights` does not have any urls in + # torchvision 0.13.0 and cannot be iterated. Here we simply check + # `DEFAULT` attribute to ensure the class is not empty. + if not hasattr(cls, 'DEFAULT'): + continue + # Since `cls.DEFAULT` can not be accessed by iterating cls, we set + # default urls explicitly. + cls_name = cls.__name__ + cls_key = cls_name.replace('_Weights', '').lower() + model_urls[f'{cls_key}.default'] = cls.DEFAULT.url + for weight_enum in cls: + cls_key = cls_name.replace('_Weights', '').lower() + cls_key = f'{cls_key}.{weight_enum.name.lower()}' + model_urls[cls_key] = weight_enum.url + + return model_urls + + +def get_external_models(): + mmengine_home = _get_mmengine_home() + default_json_path = osp.join(mmengine.__path__[0], 'hub/openmmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmengine_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmengine.__path__[0], 'hub/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmengine.__path__[0], 'hub/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def find_latest_checkpoint(path: str) -> Optional[str]: + """Find the latest checkpoint from the given path. + + Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501 + + Args: + path(str): The path to find checkpoints. + + Returns: + str or None: File path of the latest checkpoint. + """ + save_file = osp.join(path, 'last_checkpoint') + last_saved: Optional[str] + if os.path.exists(save_file): + with open(save_file) as f: + last_saved = f.read().strip() + else: + print_log('Did not find last_checkpoint to be resumed.') + last_saved = None + return last_saved diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 60d71a735b..ead628f0ae 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -1,815 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. -import io -import logging -import os -import os.path as osp -import pkgutil -import re -from collections import OrderedDict, namedtuple -from importlib import import_module -from tempfile import TemporaryDirectory -from typing import Callable, Dict, Optional - -import torch - -import mmengine -from mmengine.dist import get_dist_info -from mmengine.fileio import FileClient, get_file_backend -from mmengine.fileio import load as load_file -from mmengine.logging import print_log -from mmengine.model import BaseTTAModel, is_model_wrapper -from mmengine.utils import (apply_to, deprecated_function, digit_version, - mkdir_or_exist) -from mmengine.utils.dl_utils import load_url - -# `MMENGINE_HOME` is the highest priority directory to save checkpoints -# downloaded from Internet. If it is not set, as a workaround, using -# `XDG_CACHE_HOME`` or `~/.cache` instead. -# Note that `XDG_CACHE_HOME` defines the base directory relative to which -# user-specific non-essential data files should be stored. If `XDG_CACHE_HOME` -# is either not set or empty, a default equal to `~/.cache` should be used. -ENV_MMENGINE_HOME = 'MMENGINE_HOME' -ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' -DEFAULT_CACHE_DIR = '~/.cache' - - -class _IncompatibleKeys( - namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): - - def __repr__(self): - if not self.missing_keys and not self.unexpected_keys: - return '' - return super().__repr__() - - __str__ = __repr__ - - -def _get_mmengine_home(): - mmengine_home = os.path.expanduser( - os.getenv( - ENV_MMENGINE_HOME, - os.path.join( - os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmengine'))) - - mkdir_or_exist(mmengine_home) - return mmengine_home - - -def load_state_dict(module, state_dict, strict=False, logger=None): - """Load state_dict to a module. - - This method is modified from :meth:`torch.nn.Module.load_state_dict`. - Default value for ``strict`` is set to ``False`` and the message for - param mismatch will be shown even if strict is False. - - Args: - module (Module): Module that receives the state_dict. - state_dict (OrderedDict): Weights. - strict (bool): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. Defaults to False. - logger (:obj:`logging.Logger`, optional): Logger to log the error - message. If not specified, print function will be used. - """ - unexpected_keys = [] - missing_keys = [] - err_msg = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - # use _load_from_state_dict to enable checkpoint version control - def load(module, local_state_dict, prefix=''): - # recursively check parallel module in case that the model has a - # complicated structure, e.g., nn.Module(nn.Module(DDP)) - if is_model_wrapper(module) or isinstance(module, BaseTTAModel): - module = module.module - local_metadata = {} if metadata is None else metadata.get( - prefix[:-1], {}) - module._load_from_state_dict(local_state_dict, prefix, local_metadata, - True, missing_keys, unexpected_keys, - err_msg) - for name, child in module._modules.items(): - if child is not None: - child_prefix = prefix + name + '.' - child_state_dict = { - k: v - for k, v in local_state_dict.items() - if k.startswith(child_prefix) - } - load(child, child_state_dict, child_prefix) - - # Note that the hook can modify missing_keys and unexpected_keys. - incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) - if hasattr(module, '_load_state_dict_post_hooks'): - for hook in module._load_state_dict_post_hooks.values(): - out = hook(module, incompatible_keys) - assert out is None, ( - 'Hooks registered with ' - '``register_load_state_dict_post_hook`` are not expected ' - 'to return new values, if incompatible_keys need to be ' - 'modified, it should be done inplace.') - - load(module, state_dict) - load = None # break load->load reference cycle - - # ignore "num_batches_tracked" of BN layers - missing_keys = [ - key for key in missing_keys if 'num_batches_tracked' not in key - ] - - if unexpected_keys: - err_msg.append('unexpected key in source ' - f'state_dict: {", ".join(unexpected_keys)}\n') - if missing_keys: - err_msg.append( - f'missing keys in source state_dict: {", ".join(missing_keys)}\n') - - rank, _ = get_dist_info() - if len(err_msg) > 0 and rank == 0: - err_msg.insert( - 0, 'The model and loaded state dict do not match exactly\n') - err_msg = '\n'.join(err_msg) - if strict: - raise RuntimeError(err_msg) - else: - print_log(err_msg, logger=logger, level=logging.WARNING) - - -def get_torchvision_models(): - import torchvision - if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): - model_urls = dict() - # When the version of torchvision is lower than 0.13, the model url is - # not declared in `torchvision.model.__init__.py`, so we need to - # iterate through `torchvision.models.__path__` to get the url for each - # model. - for _, name, ispkg in pkgutil.walk_packages( - torchvision.models.__path__): - if ispkg: - continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') - model_urls.update(_urls) - else: - # Since torchvision bumps to v0.13, the weight loading logic, - # model keys and model urls have been changed. Here the URLs of old - # version is loaded to avoid breaking back compatibility. If the - # torchvision version>=0.13.0, new URLs will be added. Users can get - # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', - # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. - json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json') - model_urls = mmengine.load(json_path) - if digit_version(torchvision.__version__) < digit_version('0.14.0a0'): - weights_list = [ - cls for cls_name, cls in torchvision.models.__dict__.items() - if cls_name.endswith('_Weights') - ] - else: - weights_list = [ - torchvision.models.get_model_weights(model) - for model in torchvision.models.list_models(torchvision.models) - ] - - for cls in weights_list: - # The name of torchvision model weights classes ends with - # `_Weights` such as `ResNet18_Weights`. However, some model weight - # classes, such as `MNASNet0_75_Weights` does not have any urls in - # torchvision 0.13.0 and cannot be iterated. Here we simply check - # `DEFAULT` attribute to ensure the class is not empty. - if not hasattr(cls, 'DEFAULT'): - continue - # Since `cls.DEFAULT` can not be accessed by iterating cls, we set - # default urls explicitly. - cls_name = cls.__name__ - cls_key = cls_name.replace('_Weights', '').lower() - model_urls[f'{cls_key}.default'] = cls.DEFAULT.url - for weight_enum in cls: - cls_key = cls_name.replace('_Weights', '').lower() - cls_key = f'{cls_key}.{weight_enum.name.lower()}' - model_urls[cls_key] = weight_enum.url - - return model_urls - - -def get_external_models(): - mmengine_home = _get_mmengine_home() - default_json_path = osp.join(mmengine.__path__[0], 'hub/openmmlab.json') - default_urls = load_file(default_json_path) - assert isinstance(default_urls, dict) - external_json_path = osp.join(mmengine_home, 'open_mmlab.json') - if osp.exists(external_json_path): - external_urls = load_file(external_json_path) - assert isinstance(external_urls, dict) - default_urls.update(external_urls) - - return default_urls - - -def get_mmcls_models(): - mmcls_json_path = osp.join(mmengine.__path__[0], 'hub/mmcls.json') - mmcls_urls = load_file(mmcls_json_path) - - return mmcls_urls - - -def get_deprecated_model_names(): - deprecate_json_path = osp.join(mmengine.__path__[0], 'hub/deprecated.json') - deprecate_urls = load_file(deprecate_json_path) - assert isinstance(deprecate_urls, dict) - - return deprecate_urls - - -def _process_mmcls_checkpoint(checkpoint): - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - # Some checkpoints converted from 3rd-party repo don't - # have the "state_dict" key. - state_dict = checkpoint - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - if k.startswith('backbone.'): - new_state_dict[k[9:]] = v - new_checkpoint = dict(state_dict=new_state_dict) - - return new_checkpoint - - -class CheckpointLoader: - """A general checkpoint loader to manage all schemes.""" - - _schemes: Dict[str, Callable] = {} - - @classmethod - def _register_scheme(cls, prefixes, loader, force=False): - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - for prefix in prefixes: - if (prefix not in cls._schemes) or force: - cls._schemes[prefix] = loader - else: - raise KeyError( - f'{prefix} is already registered as a loader backend, ' - 'add "force=True" if you want to override it') - # sort, longer prefixes take priority - cls._schemes = OrderedDict( - sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True)) - - @classmethod - def register_scheme(cls, prefixes, loader=None, force=False): - """Register a loader to CheckpointLoader. - - This method can be used as a normal class method or a decorator. - - Args: - prefixes (str or list[str] or tuple[str]): - The prefix of the registered loader. - loader (function, optional): The loader function to be registered. - When this method is used as a decorator, loader is None. - Defaults to None. - force (bool, optional): Whether to override the loader - if the prefix has already been registered. Defaults to False. - """ - - if loader is not None: - cls._register_scheme(prefixes, loader, force=force) - return - - def _register(loader_cls): - cls._register_scheme(prefixes, loader_cls, force=force) - return loader_cls - - return _register - - @classmethod - def _get_checkpoint_loader(cls, path): - """Finds a loader that supports the given path. Falls back to the local - loader if no other loader is found. - - Args: - path (str): checkpoint path - - Returns: - callable: checkpoint loader - """ - for p in cls._schemes: - # use regular match to handle some cases that where the prefix of - # loader has a prefix. For example, both 's3://path' and - # 'open-mmlab:s3://path' should return `load_from_ceph` - if re.match(p, path) is not None: - return cls._schemes[p] - - @classmethod - def load_checkpoint(cls, filename, map_location=None, logger='current'): - """load checkpoint through URL scheme path. - - Args: - filename (str): checkpoint file name with given prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - logger (str): The logger for message. Defaults to 'current'. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - checkpoint_loader = cls._get_checkpoint_loader(filename) - class_name = checkpoint_loader.__name__ - print_log( - f'Loads checkpoint by {class_name[10:]} backend from path: ' - f'{filename}', - logger=logger) - return checkpoint_loader(filename, map_location) - - -@CheckpointLoader.register_scheme(prefixes='') -def load_from_local(filename, map_location): - """load checkpoint by local file path. - - Args: - filename (str): local checkpoint file path - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - filename = osp.expanduser(filename) - if not osp.isfile(filename): - raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes=('http://', 'https://')) -def load_from_http(filename, - map_location=None, - model_dir=None, - progress=os.isatty(0)): - """load checkpoint through HTTP or HTTPS scheme path. In distributed - setting, this function only download checkpoint at local rank 0. - - Args: - filename (str): checkpoint file path with modelzoo or - torchvision prefix - map_location (str, optional): Same as :func:`torch.load`. - model_dir (string, optional): directory in which to save the object, - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - rank, world_size = get_dist_info() - if rank == 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) - if world_size > 1: - torch.distributed.barrier() - if rank > 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes='pavi://') -def load_from_pavi(filename, map_location=None): - """load checkpoint through the file path prefixed with pavi. In distributed - setting, this function download ckpt at all ranks to different temporary - directories. - - Args: - filename (str): checkpoint file path with pavi prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - assert filename.startswith('pavi://'), \ - f'Expected filename startswith `pavi://`, but get {filename}' - model_path = filename[7:] - - try: - from pavi import modelcloud - except ImportError: - raise ImportError( - 'Please install pavi to load checkpoint from modelcloud.') - - model = modelcloud.get(model_path) - with TemporaryDirectory() as tmp_dir: - downloaded_file = osp.join(tmp_dir, model.name) - model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme( - prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) -def load_from_ceph(filename, map_location=None, backend='petrel'): - """load checkpoint through the file path prefixed with s3. In distributed - setting, this function download ckpt at all ranks to different temporary - directories. - - Args: - filename (str): checkpoint file path with s3 prefix - map_location (str, optional): Same as :func:`torch.load`. - backend (str, optional): The storage backend type. - Defaults to 'petrel'. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - file_backend = get_file_backend( - filename, backend_args={'backend': backend}) - with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) -def load_from_torchvision(filename, map_location=None): - """load checkpoint through the file path prefixed with modelzoo or - torchvision. - - Args: - filename (str): checkpoint file path with modelzoo or - torchvision prefix - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - model_urls = get_torchvision_models() - if filename.startswith('modelzoo://'): - print_log( - 'The URL scheme of "modelzoo://" is deprecated, please ' - 'use "torchvision://" instead', - logger='current', - level=logging.WARNING) - model_name = filename[11:] - else: - model_name = filename[14:] - return load_from_http(model_urls[model_name], map_location=map_location) - - -@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) -def load_from_openmmlab(filename, map_location=None): - """load checkpoint through the file path prefixed with open-mmlab or - openmmlab. - - Args: - filename (str): checkpoint file path with open-mmlab or - openmmlab prefix - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - model_urls = get_external_models() - prefix_str = 'open-mmlab://' - if filename.startswith(prefix_str): - model_name = filename[13:] - else: - model_name = filename[12:] - prefix_str = 'openmmlab://' - - deprecated_urls = get_deprecated_model_names() - if model_name in deprecated_urls: - print_log( - f'{prefix_str}{model_name} is deprecated in favor ' - f'of {prefix_str}{deprecated_urls[model_name]}', - logger='current', - level=logging.WARNING) - model_name = deprecated_urls[model_name] - model_url = model_urls[model_name] - # check if is url - if model_url.startswith(('http://', 'https://')): - checkpoint = load_from_http(model_url, map_location=map_location) - else: - filename = osp.join(_get_mmengine_home(), model_url) - if not osp.isfile(filename): - raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) - return checkpoint - - -@CheckpointLoader.register_scheme(prefixes='mmcls://') -def load_from_mmcls(filename, map_location=None): - """load checkpoint through the file path prefixed with mmcls. - - Args: - filename (str): checkpoint file path with mmcls prefix - map_location (str, optional): Same as :func:`torch.load`. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - model_urls = get_mmcls_models() - model_name = filename[8:] - checkpoint = load_from_http( - model_urls[model_name], map_location=map_location) - checkpoint = _process_mmcls_checkpoint(checkpoint) - return checkpoint - - -def _load_checkpoint(filename, map_location=None, logger=None): - """Load checkpoint from somewhere (modelzoo, file, url). - - Args: - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str, optional): Same as :func:`torch.load`. - Defaults to None. - logger (:mod:`logging.Logger`, optional): The logger for error message. - Defaults to None - - Returns: - dict or OrderedDict: The loaded checkpoint. It can be either an - OrderedDict storing model weights or a dict containing other - information, which depends on the checkpoint. - """ - return CheckpointLoader.load_checkpoint(filename, map_location, logger) - - -def _load_checkpoint_with_prefix(prefix, filename, map_location=None): - """Load partial pretrained model with specific prefix. - - Args: - prefix (str): The prefix of sub-module. - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str | None): Same as :func:`torch.load`. - Defaults to None. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - - checkpoint = _load_checkpoint(filename, map_location=map_location) - - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - if not prefix.endswith('.'): - prefix += '.' - prefix_len = len(prefix) - - state_dict = { - k[prefix_len:]: v - for k, v in state_dict.items() if k.startswith(prefix) - } - - assert state_dict, f'{prefix} is not in the pretrained model' - return state_dict - - -def _load_checkpoint_to_model(model, - checkpoint, - strict=False, - logger=None, - revise_keys=[(r'^module\.', '')]): - - # get state_dict from checkpoint - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - - # strip prefix of state_dict - metadata = getattr(state_dict, '_metadata', OrderedDict()) - for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) - # Keep metadata in state_dict - state_dict._metadata = metadata - - # load state_dict - load_state_dict(model, state_dict, strict, logger) - return checkpoint - - -def load_checkpoint(model, - filename, - map_location=None, - strict=False, - logger=None, - revise_keys=[(r'^module\.', '')]): - """Load checkpoint from a file or URI. - - Args: - model (Module): Module to load checkpoint. - filename (str): Accept local filepath, URL, ``torchvision://xxx``, - ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for - details. - map_location (str): Same as :func:`torch.load`. - strict (bool): Whether to allow different params for the model and - checkpoint. - logger (:mod:`logging.Logger` or None): The logger for error message. - revise_keys (list): A list of customized keywords to modify the - state_dict in checkpoint. Each item is a (pattern, replacement) - pair of the regular expression operations. Defaults to strip - the prefix 'module.' by [(r'^module\\.', '')]. - - Returns: - dict or OrderedDict: The loaded checkpoint. - """ - checkpoint = _load_checkpoint(filename, map_location, logger) - # OrderedDict is a subclass of dict - if not isinstance(checkpoint, dict): - raise RuntimeError( - f'No state_dict found in checkpoint file {filename}') - - return _load_checkpoint_to_model(model, checkpoint, strict, logger, - revise_keys) - - -def weights_to_cpu(state_dict): - """Copy a model state_dict to cpu. - - Args: - state_dict (OrderedDict): Model weights on GPU. - - Returns: - OrderedDict: Model weights on GPU. - """ - # stash metadata to put in state_dict later - metadata = getattr(state_dict, '_metadata', OrderedDict()) - state_dict = apply_to(state_dict, lambda x: hasattr(x, 'cpu'), - lambda x: x.cpu()) - state_dict._metadata = metadata - return state_dict - - -@deprecated_function( - since='0.3.0', - removed_in='0.5.0', - instructions='`_save_to_state_dict` will be deprecated in the future, ' - 'please use `nn.Module._save_to_state_dict` directly.') -def _save_to_state_dict(module, destination, prefix, keep_vars): - """Saves module state to `destination` dictionary. - - This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. - - Args: - module (nn.Module): The module to generate state_dict. - destination (dict): A dict where state will be stored. - prefix (str): The prefix for parameters and buffers used in this - module. - keep_vars (bool): Whether to keep the variable property of the - parameters. - """ - for name, param in module._parameters.items(): - if param is not None: - destination[prefix + name] = param if keep_vars else param.detach() - for name, buf in module._buffers.items(): - if buf is not None and name not in module._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() - - -def get_state_dict(module, destination=None, prefix='', keep_vars=False): - """Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - This method is modified from :meth:`torch.nn.Module.state_dict` to - recursively check parallel module in case that the model has a complicated - structure, e.g., nn.Module(nn.Module(DDP)). - - Args: - module (nn.Module): The module to generate state_dict. - destination (OrderedDict): Returned dict for the state of the - module. - prefix (str): Prefix of the key. - keep_vars (bool): Whether to keep the variable property of the - parameters. Defaults to False. - - Returns: - dict: A dictionary containing a whole state of the module. - """ - # recursively check parallel module in case that the model has a - # complicated structure, e.g., nn.Module(nn.Module(DDP)) - if is_model_wrapper(module): - module = module.module - - # below is the same as torch.nn.Module.state_dict() - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - destination._metadata[prefix[:-1]] = local_metadata = dict( - version=module._version) - module._save_to_state_dict(destination, prefix, keep_vars) - for name, child in module._modules.items(): - if child is not None: - get_state_dict( - child, destination, prefix + name + '.', keep_vars=keep_vars) - for hook in module._state_dict_hooks.values(): - hook_result = hook(module, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - -def save_checkpoint(checkpoint, - filename, - file_client_args=None, - backend_args=None): - """Save checkpoint to file. - - Args: - checkpoint (dict): Module whose params are to be saved. - filename (str): Checkpoint filename. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - `backend_args` instead. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - """ - if file_client_args is not None: - print_log( - '"file_client_args" will be deprecated in future. ' - 'Please use "backend_args" instead', - logger='current', - level=logging.WARNING) - if backend_args is not None: - raise ValueError( - '"file_client_args" and "backend_args" cannot be set ' - 'at the same time.') - - if filename.startswith('pavi://'): - if file_client_args is not None or backend_args is not None: - raise ValueError( - '"file_client_args" or "backend_args" should be "None" if ' - 'filename starts with "pavi://"') - try: - from pavi import exception, modelcloud - except ImportError: - raise ImportError( - 'Please install pavi to load checkpoint from modelcloud.') - model_path = filename[7:] - root = modelcloud.Folder() - model_dir, model_name = osp.split(model_path) - try: - model = modelcloud.get(model_dir) - except exception.NodeNotFoundError: - model = root.create_training_model(model_dir) - with TemporaryDirectory() as tmp_dir: - checkpoint_file = osp.join(tmp_dir, model_name) - with open(checkpoint_file, 'wb') as f: - torch.save(checkpoint, f) - f.flush() - model.create_file(checkpoint_file, name=model_name) - else: - file_client = FileClient.infer_client(file_client_args, filename) - if file_client_args is None: - file_backend = get_file_backend( - filename, backend_args=backend_args) - else: - file_backend = file_client - - with io.BytesIO() as f: - torch.save(checkpoint, f) - file_backend.put(f.getvalue(), filename) - - -def find_latest_checkpoint(path: str) -> Optional[str]: - """Find the latest checkpoint from the given path. - - Refer to https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/checkpoint.py # noqa: E501 - - Args: - path(str): The path to find checkpoints. - - Returns: - str or None: File path of the latest checkpoint. - """ - save_file = osp.join(path, 'last_checkpoint') - last_saved: Optional[str] - if os.path.exists(save_file): - with open(save_file) as f: - last_saved = f.read().strip() - else: - print_log('Did not find last_checkpoint to be resumed.') - last_saved = None - return last_saved +# All functions and classes in this file have been moved to mmengine.checkpoint +# Import them here to avoid BC +# flake8: noqa +from mmengine.checkpoint.io import (_IncompatibleKeys, _load_checkpoint, + _load_checkpoint_to_model, + _load_checkpoint_with_prefix, + _save_to_state_dict, get_state_dict, + load_checkpoint, load_state_dict, + save_checkpoint, weights_to_cpu) +from mmengine.checkpoint.loader import (CheckpointLoader, + _process_mmcls_checkpoint, + load_from_ceph, load_from_http, + load_from_local, load_from_mmcls, + load_from_openmmlab, load_from_pavi, + load_from_torchvision) +from mmengine.checkpoint.utils import (DEFAULT_CACHE_DIR, ENV_MMENGINE_HOME, + ENV_XDG_CACHE_HOME, _get_mmengine_home, + find_latest_checkpoint, + get_deprecated_model_names, + get_external_models, get_mmcls_models, + get_torchvision_models) diff --git a/mmengine/testing/__init__.py b/mmengine/testing/__init__.py index a7e4da3543..29be6ac8bf 100644 --- a/mmengine/testing/__init__.py +++ b/mmengine/testing/__init__.py @@ -2,11 +2,13 @@ from .compare import (assert_allclose, assert_attrs_equal, assert_dict_contains_subset, assert_dict_has_keys, assert_is_norm_layer, assert_keys_equal, - assert_params_all_zeros, check_python_script) + assert_params_all_zeros, assert_tensor_equal, + check_python_script) from .runner_test_case import RunnerTestCase __all__ = [ - 'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal', - 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer', - 'assert_params_all_zeros', 'check_python_script', 'RunnerTestCase' + 'assert_allclose', 'assert_tensor_equal', 'assert_dict_contains_subset', + 'assert_keys_equal', 'assert_attrs_equal', 'assert_dict_has_keys', + 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', + 'RunnerTestCase' ] diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py index 14c7a97ba7..e617eed8a7 100644 --- a/mmengine/testing/compare.py +++ b/mmengine/testing/compare.py @@ -56,6 +56,10 @@ def assert_allclose( actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) +def assert_tensor_equal(tensor_a, tensor_b): + assert tensor_a.eq(tensor_b).all() + + def check_python_script(cmd): """Run the python cmd script with `__main__`. The difference between `os.system` is that, this function exectues code in the current process, so diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_checkpoint/test_io.py similarity index 74% rename from tests/test_runner/test_checkpoint.py rename to tests/test_checkpoint/test_io.py index b846616428..54704f81a0 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_checkpoint/test_io.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -import sys +import re import tempfile from collections import OrderedDict from tempfile import TemporaryDirectory -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest import torch @@ -12,13 +12,12 @@ import torch.optim as optim from torch.nn.parallel import DataParallel +from mmengine.checkpoint.io import (_load_checkpoint_with_prefix, + get_state_dict, load_checkpoint, + load_state_dict, save_checkpoint) from mmengine.fileio.file_client import PetrelBackend from mmengine.registry import MODEL_WRAPPERS -from mmengine.runner.checkpoint import (CheckpointLoader, - _load_checkpoint_with_prefix, - get_state_dict, load_checkpoint, - load_from_local, load_from_pavi, - load_state_dict, save_checkpoint) +from mmengine.testing import assert_tensor_equal @MODEL_WRAPPERS.register_module() @@ -44,19 +43,6 @@ def __init__(self): self.conv = nn.Conv2d(3, 3, 1) -class Mockpavimodel: - - def __init__(self, name='fakename'): - self.name = name - - def download(self, file): - pass - - -def assert_tensor_equal(tensor_a, tensor_b): - assert tensor_a.eq(tensor_b).all() - - def test_get_state_dict(): if torch.__version__ == 'parrots': state_dict_keys = { @@ -147,57 +133,7 @@ def test_get_state_dict(): wrapped_model.module.conv.module.bias) -@patch.dict(sys.modules, {'pavi': MagicMock()}) -def test_load_pavimodel_dist(): - pavimodel = Mockpavimodel() - import pavi - pavi.modelcloud.get = MagicMock(return_value=pavimodel) - with pytest.raises(AssertionError): - # test pavi prefix - _ = load_from_pavi('MyPaviFolder/checkpoint.pth') - - with pytest.raises(FileNotFoundError): - # there is not such checkpoint for us to load - _ = load_from_pavi('pavi://checkpoint.pth') - - -def test_load_checkpoint_with_prefix(): - - class FooModule(nn.Module): - - def __init__(self): - super().__init__() - self.linear = nn.Linear(1, 2) - self.conv2d = nn.Conv2d(3, 1, 3) - self.conv2d_2 = nn.Conv2d(3, 2, 3) - - model = FooModule() - nn.init.constant_(model.linear.weight, 1) - nn.init.constant_(model.linear.bias, 2) - nn.init.constant_(model.conv2d.weight, 3) - nn.init.constant_(model.conv2d.bias, 4) - nn.init.constant_(model.conv2d_2.weight, 5) - nn.init.constant_(model.conv2d_2.bias, 6) - - with TemporaryDirectory(): - torch.save(model.state_dict(), 'model.pth') - prefix = 'conv2d' - state_dict = _load_checkpoint_with_prefix(prefix, 'model.pth') - assert torch.equal(model.conv2d.state_dict()['weight'], - state_dict['weight']) - assert torch.equal(model.conv2d.state_dict()['bias'], - state_dict['bias']) - - # test whether prefix is in pretrained model - with pytest.raises(AssertionError): - prefix = 'back' - _load_checkpoint_with_prefix(prefix, 'model.pth') - - def test_load_checkpoint(): - import os - import re - import tempfile class PrefixModel(nn.Module): @@ -292,66 +228,37 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight) -@patch.dict(sys.modules, {'petrel_client': MagicMock()}) -def test_checkpoint_loader(): - filenames = [ - 'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth', - 'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth', - 'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth', - 'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth', - 'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth', - 'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth', - 'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path' - ] - fn_names = [ - 'load_from_http', 'load_from_http', 'load_from_torchvision', - 'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab', - 'load_from_mmcls', 'load_from_pavi', 'load_from_ceph', - 'load_from_local', 'load_from_local', 'load_from_ceph', - 'load_from_ceph', 'load_from_local', 'load_from_local' - ] - - for filename, fn_name in zip(filenames, fn_names): - loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == fn_name - - @CheckpointLoader.register_scheme(prefixes='ftp://') - def load_from_ftp(filename, map_location): - return dict(filename=filename) - - # test register_loader - filename = 'ftp://xx.xx/xx.pth' - loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == 'load_from_ftp' - - def load_from_ftp1(filename, map_location): - return dict(filename=filename) - - # test duplicate registered error - with pytest.raises(KeyError): - CheckpointLoader.register_scheme('ftp://', load_from_ftp1) - - # test force param - CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True) - checkpoint = CheckpointLoader.load_checkpoint(filename) - assert checkpoint['filename'] == filename - - # test print function name - loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == 'load_from_ftp1' - - # test sort - @CheckpointLoader.register_scheme(prefixes='a/b') - def load_from_ab(filename, map_location): - return dict(filename=filename) - - @CheckpointLoader.register_scheme(prefixes='a/b/c') - def load_from_abc(filename, map_location): - return dict(filename=filename) - - filename = 'a/b/c/d' - loader = CheckpointLoader._get_checkpoint_loader(filename) - assert loader.__name__ == 'load_from_abc' +def test_load_checkpoint_with_prefix(): + + class FooModule(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 2) + self.conv2d = nn.Conv2d(3, 1, 3) + self.conv2d_2 = nn.Conv2d(3, 2, 3) + + model = FooModule() + nn.init.constant_(model.linear.weight, 1) + nn.init.constant_(model.linear.bias, 2) + nn.init.constant_(model.conv2d.weight, 3) + nn.init.constant_(model.conv2d.bias, 4) + nn.init.constant_(model.conv2d_2.weight, 5) + nn.init.constant_(model.conv2d_2.bias, 6) + + with TemporaryDirectory(): + torch.save(model.state_dict(), 'model.pth') + prefix = 'conv2d' + state_dict = _load_checkpoint_with_prefix(prefix, 'model.pth') + assert torch.equal(model.conv2d.state_dict()['weight'], + state_dict['weight']) + assert torch.equal(model.conv2d.state_dict()['bias'], + state_dict['bias']) + + # test whether prefix is in pretrained model + with pytest.raises(AssertionError): + prefix = 'back' + _load_checkpoint_with_prefix(prefix, 'model.pth') def test_save_checkpoint(tmp_path): @@ -393,21 +300,6 @@ def test_save_checkpoint(tmp_path): mock_method.assert_called() -def test_load_from_local(): - import os - home_path = os.path.expanduser('~') - checkpoint_path = os.path.join( - home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth') - model = Model() - save_checkpoint(model.state_dict(), checkpoint_path) - checkpoint = load_from_local( - '~/dummy_checkpoint_used_to_test_load_from_local.pth', - map_location=None) - assert_tensor_equal(checkpoint['block.conv.weight'], - model.block.conv.weight) - os.remove(checkpoint_path) - - def test_load_state_dict_post_hooks(): module = Block() @@ -421,7 +313,7 @@ def test_load_state_dict_post_hooks(): } state_dict.pop('norm.running_var') - with patch('mmengine.runner.checkpoint.print_log') as mock: + with patch('mmengine.checkpoint.io.print_log') as mock: load_state_dict(module, state_dict, strict=False) mock.assert_called_once() @@ -430,6 +322,6 @@ def post_hook(_, incompatible_keys): module._load_state_dict_post_hooks = {0: post_hook} - with patch('mmengine.runner.checkpoint.print_log') as mock: + with patch('mmengine.checkpoint.io.print_log') as mock: load_state_dict(module, state_dict, strict=False) mock.assert_not_called() diff --git a/tests/test_checkpoint/test_loader.py b/tests/test_checkpoint/test_loader.py new file mode 100644 index 0000000000..2d4f89e032 --- /dev/null +++ b/tests/test_checkpoint/test_loader.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +from unittest.mock import MagicMock, patch + +import pytest +import torch.nn as nn + +from mmengine.checkpoint.io import save_checkpoint +from mmengine.checkpoint.loader import (CheckpointLoader, load_from_local, + load_from_pavi) +from mmengine.testing import assert_tensor_equal + + +class Block(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + self.norm = nn.BatchNorm2d(3) + + +class Model(nn.Module): + + def __init__(self): + super().__init__() + self.block = Block() + self.conv = nn.Conv2d(3, 3, 1) + + +class Mockpavimodel: + + def __init__(self, name='fakename'): + self.name = name + + def download(self, file): + pass + + +@patch.dict(sys.modules, {'pavi': MagicMock()}) +def test_load_pavimodel_dist(): + pavimodel = Mockpavimodel() + import pavi + pavi.modelcloud.get = MagicMock(return_value=pavimodel) + with pytest.raises(AssertionError): + # test pavi prefix + _ = load_from_pavi('MyPaviFolder/checkpoint.pth') + + with pytest.raises(FileNotFoundError): + # there is not such checkpoint for us to load + _ = load_from_pavi('pavi://checkpoint.pth') + + +def test_load_from_local(): + import os + home_path = os.path.expanduser('~') + checkpoint_path = os.path.join( + home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth') + model = Model() + save_checkpoint(model.state_dict(), checkpoint_path) + checkpoint = load_from_local( + '~/dummy_checkpoint_used_to_test_load_from_local.pth', + map_location=None) + assert_tensor_equal(checkpoint['block.conv.weight'], + model.block.conv.weight) + os.remove(checkpoint_path) + + +@patch.dict(sys.modules, {'petrel_client': MagicMock()}) +def test_checkpoint_loader(): + filenames = [ + 'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth', + 'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth', + 'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth', + 'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth', + 'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth', + 'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth', + 'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path' + ] + fn_names = [ + 'load_from_http', 'load_from_http', 'load_from_torchvision', + 'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab', + 'load_from_mmcls', 'load_from_pavi', 'load_from_ceph', + 'load_from_local', 'load_from_local', 'load_from_ceph', + 'load_from_ceph', 'load_from_local', 'load_from_local' + ] + + for filename, fn_name in zip(filenames, fn_names): + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == fn_name + + @CheckpointLoader.register_scheme(prefixes='ftp://') + def load_from_ftp(filename, map_location): + return dict(filename=filename) + + # test register_loader + filename = 'ftp://xx.xx/xx.pth' + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == 'load_from_ftp' + + def load_from_ftp1(filename, map_location): + return dict(filename=filename) + + # test duplicate registered error + with pytest.raises(KeyError): + CheckpointLoader.register_scheme('ftp://', load_from_ftp1) + + # test force param + CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True) + checkpoint = CheckpointLoader.load_checkpoint(filename) + assert checkpoint['filename'] == filename + + # test print function name + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == 'load_from_ftp1' + + # test sort + @CheckpointLoader.register_scheme(prefixes='a/b') + def load_from_ab(filename, map_location): + return dict(filename=filename) + + @CheckpointLoader.register_scheme(prefixes='a/b/c') + def load_from_abc(filename, map_location): + return dict(filename=filename) + + filename = 'a/b/c/d' + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == 'load_from_abc'