diff --git a/mmengine/config/new_config.py b/mmengine/config/new_config.py index e73387e105..8ed95a2eab 100644 --- a/mmengine/config/new_config.py +++ b/mmengine/config/new_config.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import builtins import importlib import inspect +import os import platform import sys -from importlib.abc import Loader, MetaPathFinder from importlib.machinery import PathFinder -from importlib.util import spec_from_loader from pathlib import Path from types import BuiltinFunctionType, FunctionType, ModuleType from typing import Optional, Tuple, Union @@ -16,6 +16,7 @@ from .lazy import LazyImportContext, LazyObject RESERVED_KEYS = ['filename', 'text', 'pretty_text'] +_CFG_UID = 0 if platform.system() == 'Windows': import regex as re @@ -115,8 +116,7 @@ class ConfigV2(Config): .. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html """ # noqa: E501 - _max_parent_depth = 4 - _parent_pkg = '_cfg_parent' + _pkg_prefix = '_mmengine_cfg' def __init__(self, cfg_dict: dict = None, @@ -161,7 +161,7 @@ def _sanity_check(cfg): for v in cfg: ConfigV2._sanity_check(v) elif isinstance(cfg, (type, FunctionType)): - if (ConfigV2._parent_pkg in cfg.__module__ + if (ConfigV2._pkg_prefix in cfg.__module__ or '__main__' in cfg.__module__): msg = ('You cannot use temporary functions ' 'as the value of a field.\n\n') @@ -211,22 +211,29 @@ def fromfile(filename: Union[str, Path], format_python_code=format_python_code) finally: ConfigDict.lazy = False + global _CFG_UID + _CFG_UID = 0 + for mod in list(sys.modules): + if mod.startswith(ConfigV2._pkg_prefix): + del sys.modules[mod] return cfg @staticmethod - def _get_config_module(filename: Union[str, Path], level=0): + def _get_config_module(filename: Union[str, Path]): file = Path(filename).absolute() module_name = re.sub(r'\W|^(?=\d)', '_', file.stem) - parent_pkg = ConfigV2._parent_pkg + str(level) - fullname = '.'.join([parent_pkg] * ConfigV2._max_parent_depth + - [module_name]) + global _CFG_UID + # Build a unique module name to avoid conflict. + fullname = f'{ConfigV2._pkg_prefix}{_CFG_UID}_{module_name}' + _CFG_UID += 1 # import config file as a module with LazyImportContext(): spec = importlib.util.spec_from_file_location(fullname, file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) + sys.modules[fullname] = module return module @@ -338,14 +345,16 @@ def _format_basic_types(input_): return text - def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str]]: - return (self._cfg_dict, self._filename, self._text) + def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], bool]: + return (self._cfg_dict, self._filename, self._text, + self._format_python_code) - def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]): - _cfg_dict, _filename, _text = state - super(Config, self).__setattr__('_cfg_dict', _cfg_dict) - super(Config, self).__setattr__('_filename', _filename) - super(Config, self).__setattr__('_text', _text) + def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], + bool]): + super(Config, self).__setattr__('_cfg_dict', state[0]) + super(Config, self).__setattr__('_filename', state[1]) + super(Config, self).__setattr__('_text', state[2]) + super(Config, self).__setattr__('_format_python_code', state[3]) def _to_lazy_dict(self, keep_imported: bool = False) -> dict: """Convert config object to dictionary and filter the imported @@ -383,99 +392,9 @@ def lazy2string(cfg_dict): return lazy2string(_cfg_dict) -class BaseConfigLoader(Loader): - - def __init__(self, filepath, level) -> None: - self.filepath = filepath - self.level = level - - def create_module(self, spec): - file = self.filepath - return ConfigV2._get_config_module(file, level=self.level) - - def exec_module(self, module): - for k in dir(module): - module.__dict__[k] = ConfigV2._dict_to_config_dict_lazy( - getattr(module, k)) - - -class ParentFolderLoader(Loader): - - @staticmethod - def create_module(spec): - return ModuleType(spec.name) - - @staticmethod - def exec_module(module): - pass - - -class BaseImportContext(MetaPathFinder): - - def find_spec(self, fullname, path=None, target=None): - """Try to find a spec for 'fullname' on sys.path or 'path'. - - The search is based on sys.path_hooks and sys.path_importer_cache. - """ - parent_pkg = ConfigV2._parent_pkg + str(self.level) - names = fullname.split('.') - - if names[-1] == parent_pkg: - self.base_modules.append(fullname) - # Create parent package - return spec_from_loader( - fullname, loader=ParentFolderLoader, is_package=True) - elif names[0] == parent_pkg: - self.base_modules.append(fullname) - # relative imported base package - filepath = self.root_path - for name in names: - if name == parent_pkg: - # Use parent to remove `..` at the end of the root path - filepath = filepath.parent - else: - filepath = filepath / name - if filepath.is_dir(): - # If a dir, create a package. - return spec_from_loader( - fullname, loader=ParentFolderLoader, is_package=True) - - pypath = filepath.with_suffix('.py') - - if not pypath.exists(): - raise ImportError(f'Not found base path {filepath.resolve()}') - return importlib.util.spec_from_loader( - fullname, BaseConfigLoader(pypath, self.level + 1)) - else: - # Absolute import - pkg = PathFinder.find_spec(names[0]) - if pkg and pkg.submodule_search_locations: - self.base_modules.append(fullname) - path = Path(pkg.submodule_search_locations[0]) - for name in names[1:]: - path = path / name - if path.is_dir(): - return spec_from_loader( - fullname, loader=ParentFolderLoader, is_package=True) - pypath = path.with_suffix('.py') - if not pypath.exists(): - raise ImportError(f'Not found base path {path.resolve()}') - return importlib.util.spec_from_loader( - fullname, BaseConfigLoader(pypath, self.level + 1)) - return None +class BaseImportContext(): def __enter__(self): - # call from which file - stack = inspect.stack()[1] - file = inspect.getfile(stack[0]) - folder = Path(file).parent - self.root_path = folder.joinpath(*(['..'] * - ConfigV2._max_parent_depth)) - - self.base_modules = [] - self.level = len( - [p for p in sys.meta_path if isinstance(p, BaseImportContext)]) - # Disable enabled lazy loader during parsing base self.lazy_importers = [] for p in sys.meta_path: @@ -483,18 +402,68 @@ def __enter__(self): self.lazy_importers.append(p) p.enable = False - index = sys.meta_path.index(importlib.machinery.FrozenImporter) - sys.meta_path.insert(index + 1, self) + old_import = builtins.__import__ + + def new_import(name, globals=None, locals=None, fromlist=(), level=0): + cur_file = None + + # Try to import the base config source file + if level != 0 and globals is not None: + # For relative import path + if '__file__' in globals: + loc = Path(globals['__file__']).parent + else: + loc = Path(os.getcwd()) + cur_file = self.find_relative_file(loc, name, level - 1) + if not cur_file.exists(): + raise ImportError(f'Cannot import name "{name}" from ' + f'{loc}: {cur_file} does not exist.') + elif level == 0: + # For absolute import path + pkg, _, mod = name.partition('.') + pkg = PathFinder.find_spec(pkg) + if mod and pkg.submodule_search_locations: + loc = Path(pkg.submodule_search_locations[0]) + cur_file = self.find_relative_file(loc, mod) + if not cur_file.exists(): + raise ImportError(f'Cannot import name "{name}": ' + f'{cur_file} does not exist.') + + # Recover the original import during handle the base config file. + builtins.__import__ = old_import + + if cur_file is not None: + mod = ConfigV2._get_config_module(cur_file) + + for k in dir(mod): + mod.__dict__[k] = ConfigV2._dict_to_config_dict_lazy( + getattr(mod, k)) + else: + mod = old_import( + name, globals, locals, fromlist=fromlist, level=level) + + builtins.__import__ = new_import + + return mod + + self.old_import = old_import + builtins.__import__ = new_import def __exit__(self, exc_type, exc_val, exc_tb): - sys.meta_path.remove(self) - for name in self.base_modules: - sys.modules.pop(name, None) + builtins.__import__ = self.old_import for p in self.lazy_importers: p.enable = True - def __repr__(self): - return f'' + @staticmethod + def find_relative_file(loc: Path, relative_import_path, level=0): + if level > 0: + loc = loc.parents[level - 1] + names = relative_import_path.lstrip('.').split('.') + + for name in names: + loc = loc / name + + return loc.with_suffix('.py') read_base = BaseImportContext diff --git a/mmengine/config/old_config.py b/mmengine/config/old_config.py index 9a568794d8..8bdb7d9baa 100644 --- a/mmengine/config/old_config.py +++ b/mmengine/config/old_config.py @@ -666,14 +666,16 @@ def env_variables(self) -> dict: """get used environment variables.""" return self._env_variables - def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]: - return (self._cfg_dict, self._filename, self._text, - self._env_variables) + def __getstate__( + self) -> Tuple[dict, Optional[str], Optional[str], dict, bool]: + state = (self._cfg_dict, self._filename, self._text, + self._env_variables, self._format_python_code) + return state def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], - dict]): - _cfg_dict, _filename, _text, _env_variables = state - super(Config, self).__setattr__('_cfg_dict', _cfg_dict) - super(Config, self).__setattr__('_filename', _filename) - super(Config, self).__setattr__('_text', _text) - super(Config, self).__setattr__('_text', _env_variables) + dict, bool]): + super(Config, self).__setattr__('_cfg_dict', state[0]) + super(Config, self).__setattr__('_filename', state[1]) + super(Config, self).__setattr__('_text', state[2]) + super(Config, self).__setattr__('_env_variables', state[3]) + super(Config, self).__setattr__('_format_python_code', state[4]) diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 94588c30a8..a389c23e0c 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -10,8 +10,8 @@ from rich.console import Console from rich.table import Table -from mmengine.config.utils import MODULE2PACKAGE from mmengine.config.lazy import LazyObject +from mmengine.config.utils import MODULE2PACKAGE from mmengine.utils import is_seq_of from .default_scope import DefaultScope diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 697af422c9..b8d8318a8c 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -986,8 +986,9 @@ def test_lazy_import(self, tmp_path): cfg_dict = cfg.to_dict() assert (cfg_dict['train_dataloader']['dataset']['type'] == '') - assert ( - cfg_dict['custom_hooks'][0]['type'] == '') + assert (cfg_dict['custom_hooks'][0]['type'] + in ('', + '')) # Dumped config dumped_cfg_path = tmp_path / 'test_dump_lazy.py' cfg.dump(dumped_cfg_path) @@ -1060,12 +1061,6 @@ def _compare_dict(a, b): osp.join(self.data_path, 'config/lazy_module_config/error_mix_using1.py')) - # Force to import in non-lazy-import mode - Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/error_mix_using1.py'), - lazy_import=False) - # current lazy-import config, base text config with pytest.raises(AttributeError, match='item2'): Config.fromfile( @@ -1088,7 +1083,7 @@ def _compare_dict(a, b): dumped_cfg = Config.fromfile(dumped_cfg_path) assert set(dumped_cfg.keys()) == { - 'path', 'name', 'suffix', 'chained', 'existed', 'cfgname' + 'path', 'name', 'suffix', 'chained', 'existed', 'cfgname', 'ex' } assert dumped_cfg.to_dict() == cfg.to_dict()