Skip to content

Commit

Permalink
Imporve according to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Aug 29, 2023
1 parent b23516b commit 5d79f39
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
51 changes: 33 additions & 18 deletions mmengine/config/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import importlib
import re
import sys
from importlib.util import spec_from_loader
from typing import Any
from importlib.util import find_spec, spec_from_loader
from typing import Any, Optional


class LazyObject:
Expand All @@ -16,30 +16,27 @@ class LazyObject:
>>> import torch.nn as nn
>>> from mmdet.models import RetinaNet
>>> import mmcls.models
>>> import mmcls.datasets
>>> import mmcls
Will be parsed as:
Examples:
>>> # import torch.nn as nn
>>> nn = lazyObject('torch.nn')
>>> nn = LazyObject('torch.nn')
>>> # from mmdet.models import RetinaNet
>>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet')
>>> # import mmcls.models; import mmcls.datasets; import mmcls
>>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models'])
>>> RetinaNet = LazyObject('RetinaNet', LazyObject('mmdet.models'))
>>> # import mmcls.models
>>> mmcls = LazyObject('mmcls.models')
``LazyObject`` records all module information and will be further
referenced by the configuration file.
Args:
module (str or list or tuple): The module name to be imported.
imported (str, optional): The imported module name. Defaults to None.
location (str, optional): The filename and line number of the imported
module statement happened.
name (str): The name of a module or attribution.
source (LazyObject, optional): The source of the lazy object.
Defaults to None.
"""

def __init__(self, name: str, source: 'LazyObject' = None):
def __init__(self, name: str, source: Optional['LazyObject'] = None):
self.name = name
self.source = source

Expand All @@ -58,9 +55,20 @@ def build(self) -> Any:
f'Failed to import {self.name} from {self.source}')
else:
try:
return importlib.import_module(self.name)
except Exception as e:
raise type(e)(f'Failed to import {self.name} for {e}')
for idx in range(self.name.count('.') + 1):
module, *attrs = self.name.rsplit('.', idx)
try:
spec = find_spec(module)
except ImportError:
spec = None
if spec is not None:
res = importlib.import_module(module)
for attr in attrs:
res = getattr(res, attr)
return res
raise ImportError(f'No module named `{module}`.')
except (ImportError, AttributeError) as e:
raise ImportError(f'Failed to import {self.name} for {e}')

def __deepcopy__(self, memo):
return LazyObject(self.name, self.source)
Expand All @@ -74,10 +82,13 @@ def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
return f"<Lazy '{str(self)}'>"
arg = f'name={repr(self.name)}'
if self.source is not None:
arg += f', source={repr(self.source)}'
return f'LazyObject({arg})'

@property
def dump_str(self):
def dump_str(self) -> str:
return f'<{str(self)}>'

@classmethod
Expand Down Expand Up @@ -130,6 +141,10 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
sys.meta_path.remove(self)
for name in self.lazy_modules:
if '.' in name:
parent_module, _, child_name = name.rpartition('.')
if parent_module in sys.modules:
delattr(sys.modules[parent_module], child_name)
sys.modules.pop(name, None)

def __repr__(self):
Expand Down
13 changes: 10 additions & 3 deletions mmengine/config/new_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def dump_extra_type(value):
if isinstance(value, LazyObject):
return value.dump_str
if isinstance(value, (type, FunctionType, BuiltinFunctionType)):
return '<' + value.__module__ + '.' + value.__name__ + '>'
return LazyObject(value.__name__, value.__module__).dump_str
if isinstance(value, ModuleType):
return f'<{value.__name__}>'
return LazyObject(value.__name__).dump_str

typename = type(value).__module__ + type(value).__name__
if typename == 'torch.dtype':
return '<' + str(value) + '>'
return LazyObject(str(value)).dump_str

return None

Expand Down Expand Up @@ -393,6 +393,13 @@ def __enter__(self):
old_import = builtins.__import__

def new_import(name, globals=None, locals=None, fromlist=(), level=0):
# For relative import, the new import allows import from files
# which are not in a package.
# For absolute import, the new import will try to find the python
# file according to the module name literally, it's used to handle
# importing from installed packages, like
# `mmpretrain.configs.resnet.resnet18_8xb32_in1k`.

cur_file = None

# Try to import the base config source file
Expand Down

0 comments on commit 5d79f39

Please sign in to comment.