diff --git a/mmengine/config/new_config.py b/mmengine/config/new_config.py index c8f774017f..e73387e105 100644 --- a/mmengine/config/new_config.py +++ b/mmengine/config/new_config.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import builtins import importlib import inspect import platform import sys +from importlib.abc import Loader, MetaPathFinder +from importlib.machinery import PathFinder +from importlib.util import spec_from_loader from pathlib import Path from types import BuiltinFunctionType, FunctionType, ModuleType from typing import Optional, Tuple, Union @@ -14,7 +16,6 @@ from .lazy import LazyImportContext, LazyObject RESERVED_KEYS = ['filename', 'text', 'pretty_text'] -_CFG_UID = 0 if platform.system() == 'Windows': import regex as re @@ -114,7 +115,8 @@ class ConfigV2(Config): .. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html """ # noqa: E501 - _pkg_name = '_mmengine_cfg' + _max_parent_depth = 4 + _parent_pkg = '_cfg_parent' def __init__(self, cfg_dict: dict = None, @@ -159,7 +161,7 @@ def _sanity_check(cfg): for v in cfg: ConfigV2._sanity_check(v) elif isinstance(cfg, (type, FunctionType)): - if (ConfigV2._pkg_name in cfg.__module__ + if (ConfigV2._parent_pkg in cfg.__module__ or '__main__' in cfg.__module__): msg = ('You cannot use temporary functions ' 'as the value of a field.\n\n') @@ -209,28 +211,22 @@ def fromfile(filename: Union[str, Path], format_python_code=format_python_code) finally: ConfigDict.lazy = False - global _CFG_UID - _CFG_UID = 0 - for mod in list(sys.modules): - if mod.startswith(ConfigV2._pkg_name): - del sys.modules[mod] return cfg @staticmethod - def _get_config_module(filename: Union[str, Path]): + def _get_config_module(filename: Union[str, Path], level=0): file = Path(filename).absolute() module_name = re.sub(r'\W|^(?=\d)', '_', file.stem) - global _CFG_UID - fullname = ConfigV2._pkg_name + str(_CFG_UID) + '.' + module_name - _CFG_UID += 1 + parent_pkg = ConfigV2._parent_pkg + str(level) + fullname = '.'.join([parent_pkg] * ConfigV2._max_parent_depth + + [module_name]) # import config file as a module with LazyImportContext(): spec = importlib.util.spec_from_file_location(fullname, file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - sys.modules[fullname] = module return module @@ -342,16 +338,14 @@ def _format_basic_types(input_): return text - def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], bool]: - return (self._cfg_dict, self._filename, self._text, - self._format_python_code) + def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str]]: + return (self._cfg_dict, self._filename, self._text) - def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], - bool]): - super(Config, self).__setattr__('_cfg_dict', state[0]) - super(Config, self).__setattr__('_filename', state[1]) - super(Config, self).__setattr__('_text', state[2]) - super(Config, self).__setattr__('_format_python_code', state[3]) + def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__('_cfg_dict', _cfg_dict) + super(Config, self).__setattr__('_filename', _filename) + super(Config, self).__setattr__('_text', _text) def _to_lazy_dict(self, keep_imported: bool = False) -> dict: """Convert config object to dictionary and filter the imported @@ -389,9 +383,99 @@ def lazy2string(cfg_dict): return lazy2string(_cfg_dict) -class BaseImportContext(): +class BaseConfigLoader(Loader): + + def __init__(self, filepath, level) -> None: + self.filepath = filepath + self.level = level + + def create_module(self, spec): + file = self.filepath + return ConfigV2._get_config_module(file, level=self.level) + + def exec_module(self, module): + for k in dir(module): + module.__dict__[k] = ConfigV2._dict_to_config_dict_lazy( + getattr(module, k)) + + +class ParentFolderLoader(Loader): + + @staticmethod + def create_module(spec): + return ModuleType(spec.name) + + @staticmethod + def exec_module(module): + pass + + +class BaseImportContext(MetaPathFinder): + + def find_spec(self, fullname, path=None, target=None): + """Try to find a spec for 'fullname' on sys.path or 'path'. + + The search is based on sys.path_hooks and sys.path_importer_cache. + """ + parent_pkg = ConfigV2._parent_pkg + str(self.level) + names = fullname.split('.') + + if names[-1] == parent_pkg: + self.base_modules.append(fullname) + # Create parent package + return spec_from_loader( + fullname, loader=ParentFolderLoader, is_package=True) + elif names[0] == parent_pkg: + self.base_modules.append(fullname) + # relative imported base package + filepath = self.root_path + for name in names: + if name == parent_pkg: + # Use parent to remove `..` at the end of the root path + filepath = filepath.parent + else: + filepath = filepath / name + if filepath.is_dir(): + # If a dir, create a package. + return spec_from_loader( + fullname, loader=ParentFolderLoader, is_package=True) + + pypath = filepath.with_suffix('.py') + + if not pypath.exists(): + raise ImportError(f'Not found base path {filepath.resolve()}') + return importlib.util.spec_from_loader( + fullname, BaseConfigLoader(pypath, self.level + 1)) + else: + # Absolute import + pkg = PathFinder.find_spec(names[0]) + if pkg and pkg.submodule_search_locations: + self.base_modules.append(fullname) + path = Path(pkg.submodule_search_locations[0]) + for name in names[1:]: + path = path / name + if path.is_dir(): + return spec_from_loader( + fullname, loader=ParentFolderLoader, is_package=True) + pypath = path.with_suffix('.py') + if not pypath.exists(): + raise ImportError(f'Not found base path {path.resolve()}') + return importlib.util.spec_from_loader( + fullname, BaseConfigLoader(pypath, self.level + 1)) + return None def __enter__(self): + # call from which file + stack = inspect.stack()[1] + file = inspect.getfile(stack[0]) + folder = Path(file).parent + self.root_path = folder.joinpath(*(['..'] * + ConfigV2._max_parent_depth)) + + self.base_modules = [] + self.level = len( + [p for p in sys.meta_path if isinstance(p, BaseImportContext)]) + # Disable enabled lazy loader during parsing base self.lazy_importers = [] for p in sys.meta_path: @@ -399,51 +483,18 @@ def __enter__(self): self.lazy_importers.append(p) p.enable = False - old_import = builtins.__import__ - - def new_import(name, globals=None, locals=None, fromlist=(), level=0): - # Only deal with relative imports inside config files - if (level != 0 and globals is not None - and globals.get('__package__') is not None and - globals.get('__package__').startswith(ConfigV2._pkg_name)): - cur_file = self.find_relative_file(globals['__file__'], name, - level) - mod = ConfigV2._get_config_module(cur_file) - - for k in dir(mod): - mod.__dict__[k] = ConfigV2._dict_to_config_dict_lazy( - getattr(mod, k)) - return mod - return old_import( - name, globals, locals, fromlist=fromlist, level=level) - - builtins.__import__ = new_import - self.old_import = old_import + index = sys.meta_path.index(importlib.machinery.FrozenImporter) + sys.meta_path.insert(index + 1, self) def __exit__(self, exc_type, exc_val, exc_tb): + sys.meta_path.remove(self) + for name in self.base_modules: + sys.modules.pop(name, None) for p in self.lazy_importers: p.enable = True - builtins.__import__ = self.old_import - @staticmethod - def find_relative_file(original_file, relative_import_path, level): - cur_file = Path(original_file).parents[level - 1] - names = relative_import_path.lstrip('.').split('.') - - for name in names: - cur_file = cur_file / name - - if cur_file.is_dir(): - raise ImportError( - f'Cannot import name {relative_import_path} from ' - f'{original_file}: {cur_file} is a directory.') - - cur_file = cur_file.with_suffix('.py') - if not cur_file.exists(): - raise ImportError( - f'Cannot import name {relative_import_path} from ' - f'{original_file}: {cur_file} does not exist.') - return cur_file + def __repr__(self): + return f'' read_base = BaseImportContext diff --git a/mmengine/config/old_config.py b/mmengine/config/old_config.py index 8bdb7d9baa..9a568794d8 100644 --- a/mmengine/config/old_config.py +++ b/mmengine/config/old_config.py @@ -666,16 +666,14 @@ def env_variables(self) -> dict: """get used environment variables.""" return self._env_variables - def __getstate__( - self) -> Tuple[dict, Optional[str], Optional[str], dict, bool]: - state = (self._cfg_dict, self._filename, self._text, - self._env_variables, self._format_python_code) - return state + def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]: + return (self._cfg_dict, self._filename, self._text, + self._env_variables) def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], - dict, bool]): - super(Config, self).__setattr__('_cfg_dict', state[0]) - super(Config, self).__setattr__('_filename', state[1]) - super(Config, self).__setattr__('_text', state[2]) - super(Config, self).__setattr__('_env_variables', state[3]) - super(Config, self).__setattr__('_format_python_code', state[4]) + dict]): + _cfg_dict, _filename, _text, _env_variables = state + super(Config, self).__setattr__('_cfg_dict', _cfg_dict) + super(Config, self).__setattr__('_filename', _filename) + super(Config, self).__setattr__('_text', _text) + super(Config, self).__setattr__('_text', _env_variables) diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index a389c23e0c..94588c30a8 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -10,8 +10,8 @@ from rich.console import Console from rich.table import Table -from mmengine.config.lazy import LazyObject from mmengine.config.utils import MODULE2PACKAGE +from mmengine.config.lazy import LazyObject from mmengine.utils import is_seq_of from .default_scope import DefaultScope diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index b8d8318a8c..697af422c9 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -986,9 +986,8 @@ def test_lazy_import(self, tmp_path): cfg_dict = cfg.to_dict() assert (cfg_dict['train_dataloader']['dataset']['type'] == '') - assert (cfg_dict['custom_hooks'][0]['type'] - in ('', - '')) + assert ( + cfg_dict['custom_hooks'][0]['type'] == '') # Dumped config dumped_cfg_path = tmp_path / 'test_dump_lazy.py' cfg.dump(dumped_cfg_path) @@ -1061,6 +1060,12 @@ def _compare_dict(a, b): osp.join(self.data_path, 'config/lazy_module_config/error_mix_using1.py')) + # Force to import in non-lazy-import mode + Config.fromfile( + osp.join(self.data_path, + 'config/lazy_module_config/error_mix_using1.py'), + lazy_import=False) + # current lazy-import config, base text config with pytest.raises(AttributeError, match='item2'): Config.fromfile( @@ -1083,7 +1088,7 @@ def _compare_dict(a, b): dumped_cfg = Config.fromfile(dumped_cfg_path) assert set(dumped_cfg.keys()) == { - 'path', 'name', 'suffix', 'chained', 'existed', 'cfgname', 'ex' + 'path', 'name', 'suffix', 'chained', 'existed', 'cfgname' } assert dumped_cfg.to_dict() == cfg.to_dict()