[Enhance] Enable exclude_frozen_parameters for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` (open-mmlab#1517)
LZHgrla authored Apr 12, 2024
1 parent e258c84 commit 39ed23f
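
The core of the change: rather than branching on the DeepSpeed version at every call site, `save_checkpoint` now builds a `state_dict_kwargs` dict once, fills it only when the installed DeepSpeed is new enough, and unpacks it into each state-dict call. A minimal sketch of the pattern, assuming only `deepspeed` and mmengine's `digit_version`; the `dump_weights` helper and `engine` argument are hypothetical stand-ins for the strategy's call sites:

    import deepspeed
    from mmengine.utils import digit_version

    # Fill the kwargs only when the installed DeepSpeed understands them; on
    # older releases the dict stays empty and the call sites are unchanged.
    state_dict_kwargs = {}
    if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
        state_dict_kwargs['exclude_frozen_parameters'] = True

    def dump_weights(engine, dirname, tag):
        # `**state_dict_kwargs` expands to nothing on old DeepSpeed and to
        # `exclude_frozen_parameters=True` on 0.13.2+.
        engine.save_checkpoint(dirname, tag=tag, save_latest=False,
                               **state_dict_kwargs)
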
Showing 1 changed file with 19 additions and 22 deletions.
mmengine/_strategy/deepspeed.py
@@ -311,8 +311,8 @@ def __init__(
         self.config['steps_per_print'] = steps_per_print
         self._inputs_to_half = inputs_to_half
         assert (exclude_frozen_parameters is None or
-                digit_version(deepspeed.__version__) >= digit_version('0.10.1')
-                ), ('DeepSpeed >= 0.10.1 is required to enable '
+                digit_version(deepspeed.__version__) >= digit_version('0.13.2')
+                ), ('DeepSpeed >= 0.13.2 is required to enable '
                     'exclude_frozen_parameters')
         self.exclude_frozen_parameters = exclude_frozen_parameters
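
For reference, `digit_version` (from `mmengine.utils`) parses a version string into a tuple of integers, so the gate above is a numeric comparison rather than a string comparison; a small illustration of why that matters:

    from mmengine.utils import digit_version

    # A plain string compare would rank '0.9.0' above '0.13.2' ('9' > '1');
    # the parsed tuples compare component-wise, as versions should.
    assert digit_version('0.13.2') > digit_version('0.9.0')
    assert '0.13.2' < '0.9.0'  # lexicographic order gets it backwards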

@@ -430,7 +430,7 @@ def load_checkpoint(
         self.logger.info(f'Load checkpoint from {filename}')

         dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+        if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
             _, extra_ckpt = self.model.load_checkpoint(
                 dirname,
                 tag=basename,
@@ -468,7 +468,7 @@ def resume(
         self.logger.info(f'Resume checkpoint from {filename}')

         dirname, basename = osp.split(filename)
-        if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
+        if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
             _, extra_ckpt = self.model.load_checkpoint(
                 dirname,
                 tag=basename,
@@ -551,34 +551,31 @@ def save_checkpoint(
                 level=logging.WARNING)
             save_optimizer = True

+        state_dict_kwargs = {}
+        if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
+            state_dict_kwargs[
+                'exclude_frozen_parameters'] = self.exclude_frozen_parameters
+
         if save_optimizer:
             if hasattr(self, 'optim_wrapper'):
                 # The key can not be 'optimizer', otherwise error will be
                 # thrown when loading or resuming checkpoint.
                 extra_ckpt['optim_wrapper'] = self.optim_state_dict()

             dirname, basename = osp.split(filename)
-            if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
-                self.model.save_checkpoint(
-                    dirname,
-                    tag=basename,
-                    client_state=extra_ckpt,
-                    save_latest=False,
-                    exclude_frozen_parameters=self.exclude_frozen_parameters)
-            else:
-                self.model.save_checkpoint(
-                    dirname,
-                    tag=basename,
-                    client_state=extra_ckpt,
-                    save_latest=False)
+            self.model.save_checkpoint(
+                dirname,
+                tag=basename,
+                client_state=extra_ckpt,
+                save_latest=False,
+                **state_dict_kwargs)
         else:
             if self.model.zero_optimization_partition_weights():
-                # TODO: `_zero3_consolidated_16bit_state_dict` doesn't support
-                # `exclude_frozen_parameters`.
-                state_dict = self.model._zero3_consolidated_16bit_state_dict()
+                state_dict = self.model._zero3_consolidated_16bit_state_dict(
+                    **state_dict_kwargs)
             else:
-                state_dict = self.model.module_state_dict(
-                    exclude_frozen_parameters=self.exclude_frozen_parameters)
+                state_dict = self.model.module_state_dict(**state_dict_kwargs)

         if is_main_process():
             ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
             save_checkpoint(ckpt, filename)
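
With this in place, a ZeRO-3 run can drop frozen weights from the consolidated 16-bit state dict as well, not only from the DeepSpeed-native checkpoint path. A hedged usage sketch; every constructor argument except `exclude_frozen_parameters` is an assumed example config, not taken from this diff:

    from mmengine._strategy import DeepSpeedStrategy

    # Requires DeepSpeed >= 0.13.2, enforced by the assert in `__init__` above.
    strategy = DeepSpeedStrategy(
        fp16=dict(enabled=True),          # assumed mixed-precision config
        zero_optimization=dict(stage=3),  # stage 3 partitions weights, so the
                                          # consolidated 16-bit path is taken
        exclude_frozen_parameters=True)   # frozen params omitted from checkpoints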
