Commit
[Fix] Fix get optimizer_cls (#1324)
HAOCHENYE authored Aug 28, 2023
1 parent 714c8ee · commit 170758a
Showing 2 changed files with 19 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mmengine/optim/optimizer/default_constructor.py
@@ -299,7 +299,8 @@ def __call__(self, model: nn.Module) -> OptimWrapper:
         # `model_params` rather than `params`. Here we get the first argument
         # name and fill it with the model parameters.
         if isinstance(optimizer_cls, str):
-            optimizer_cls = OPTIMIZERS.get(self.optimizer_cfg['type'])
+            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
+                optimizer_cls = registry.get(self.optimizer_cfg['type'])
         fisrt_arg_name = next(
             iter(inspect.signature(optimizer_cls).parameters))
         # if no paramwise option is specified, just use the global setting
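The hunk above stops resolving a string optimizer type directly on the root OPTIMIZERS registry and instead asks the registry that matches the currently active default scope. Below is a self-contained toy sketch of that idea; it is not mmengine code, and SimpleRegistry plus _default_scope are made-up stand-ins for mmengine's Registry and DefaultScope (the names OPTIMIZERS and CUSTOM_OPTIMIZERS are reused only to mirror the diff).

from contextlib import contextmanager

_default_scope = None  # stand-in for mmengine's DefaultScope


class SimpleRegistry:

    def __init__(self, name, scope=None, parent=None):
        self.name, self.scope = name, scope
        self.children, self._modules = {}, {}
        if parent is not None:
            parent.children[scope] = self

    def register(self, cls):
        self._modules[cls.__name__] = cls
        return cls

    def get(self, key):
        # A real registry also falls back to parent/child registries; the toy
        # version only checks its own modules to keep the point visible.
        return self._modules[key]

    @contextmanager
    def switch_scope_and_registry(self, scope):
        # scope=None means "use whatever default scope is currently active",
        # which is the behaviour the fixed line relies on.
        scope = scope if scope is not None else _default_scope
        yield self.children.get(scope, self)


OPTIMIZERS = SimpleRegistry('optimizer')
CUSTOM_OPTIMIZERS = SimpleRegistry('custom', scope='custom', parent=OPTIMIZERS)


@CUSTOM_OPTIMIZERS.register
class CustomSGD:

    def __init__(self, model_params, lr):
        self.model_params, self.lr = model_params, lr


_default_scope = 'custom'
# A plain OPTIMIZERS.get('CustomSGD') would raise KeyError here: the class is
# only registered in the child registry bound to the 'custom' scope.
with OPTIMIZERS.switch_scope_and_registry(None) as registry:
    optimizer_cls = registry.get('CustomSGD')
assert optimizer_cls is CustomSGD

In this sketch the scope-aware lookup succeeds where a root-only lookup would fail, which is the situation the commit message describes as "Fix get optimizer_cls".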
18 changes: 17 additions & 1 deletion tests/test_optim/test_optimizer/test_optimizer.py
@@ -17,7 +17,7 @@
 from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
                                               LION_OPTIMIZERS,
                                               TORCH_OPTIMIZERS)
-from mmengine.registry import build_from_cfg
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
 from mmengine.utils.version_utils import digit_version
@@ -391,6 +391,22 @@ def test_default_optimizer_constructor(self):
         optim_wrapper = optim_constructor(self.model)
         self._check_default_optimizer(optim_wrapper.optimizer, self.model)
 
+        # Support building custom optimizers
+        CUSTOM_OPTIMIZERS = Registry(
+            'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS)
+
+        class CustomOptimizer(torch.optim.SGD):
+
+            def __init__(self, model_params, *args, **kwargs):
+                super().__init__(params=model_params, *args, **kwargs)
+
+        CUSTOM_OPTIMIZERS.register_module()(CustomOptimizer)
+        optimizer_cfg = dict(optimizer=dict(type='CustomOptimizer', lr=0.1), )
+        with DefaultScope.overwrite_default_scope('custom optimizer'):
+            optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
+            optim_wrapper = optim_constructor(self.model)
+        OPTIMIZERS.children.pop('custom optimizer')
+
     def test_default_optimizer_constructor_with_model_wrapper(self):
         # basic config with pseudo data parallel
         model = PseudoDataParallel()
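The new test registers a CustomOptimizer whose first constructor argument is model_params rather than params, which also exercises the signature-inspection path shown in the first hunk: the string type must be resolved to a class before inspect.signature can report its first parameter name. A minimal sketch of that path, assuming a plain PyTorch environment rather than the mmengine constructor itself:

import inspect

import torch
from torch import nn


class CustomOptimizer(torch.optim.SGD):
    """Same pattern as the test above: first argument is `model_params`."""

    def __init__(self, model_params, *args, **kwargs):
        super().__init__(params=model_params, *args, **kwargs)


model = nn.Linear(2, 2)
optimizer_cls = CustomOptimizer
optimizer_cfg = dict(lr=0.1)

# inspect.signature on the class reports the __init__ parameters without
# `self`, so the first name is 'model_params' rather than 'params'.
first_arg_name = next(iter(inspect.signature(optimizer_cls).parameters))
optimizer_cfg[first_arg_name] = model.parameters()
optimizer = optimizer_cls(**optimizer_cfg)
print(first_arg_name)  # model_params

This also explains why the lookup has to happen before the signature check: inspect.signature cannot be applied to the string 'CustomOptimizer', only to the resolved class.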
