Skip to content

Commit

Permalink
Multiple zero stage 3 related fixes (microsoft#3886)
Browse files Browse the repository at this point in the history
* Option to override module apply

* Removing early partitioning in override

* Unit tests

* Cleanup

* Adapt unit test to succeed

* Handle missed params

* Add accelerate

* Code cleanup

* Add doc

* Add doc

* Add doc
  • Loading branch information
tjruwase authored Jul 28, 2023
1 parent 7f26bb6 commit 7f90ef4
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 39 deletions.
11 changes: 5 additions & 6 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,23 +1041,22 @@ def _set_client_model(self, model):

def _configure_distributed_model(self, model):
self._set_client_model(model)

is_zero3_model = self.zero_optimization_partition_weights() and any(
is_zero_init_model = self.zero_optimization_partition_weights() and any(
[hasattr(param, "ds_id") for param in self.module.parameters()])

if self.fp16_enabled():
if is_zero3_model:
if is_zero_init_model:
self.__check_params(self.module, torch.half)
self.module.half()
elif self.bfloat16_enabled():
if is_zero3_model:
if is_zero_init_model:
self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16()
else:
self.__check_params(self.module, torch.float)

# zero.Init() handles device placement of model
if not self.dont_change_device:
if not (self.dont_change_device or is_zero_init_model):
self.module.to(self.device)

# MoE related initialization
Expand Down Expand Up @@ -1097,7 +1096,7 @@ def _configure_distributed_model(self, model):
self.expert_parallel_group = groups._get_expert_parallel_group_dict()
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()

if not self.amp_enabled():
if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()

# check if parameters are duplicated in optimizer param_groups
Expand Down
9 changes: 8 additions & 1 deletion deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"zero_hpz_partition_size": 1,
"zero_quantized_weights": [true|false],
"zero_quantized_gradients": [true|false],
"memory_efficient_linear": [true|false]
"memory_efficient_linear": [true|false],
"override_module_apply": [true|false],
}
}
"""
Expand Down Expand Up @@ -269,11 +270,17 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
mics_shard_size: int = Field(-1, new_param="mics_shard_size")

mics_hierarchical_params_gather: bool = False

memory_efficient_linear: bool = True
"""
Use memory efficient linear implementation, for Stage 3.
"""

override_module_apply: bool = True
"""
Override nn.Module apply function, for Stage 3.
"""

# Validators
@validator("overlap_comm")
def overlap_comm_valid(cls, field_value, values):
Expand Down
81 changes: 49 additions & 32 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from deepspeed.accelerator import get_accelerator
from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus

param_count = 0
partitioned_param_data_shape = [0]
zero_init_context = 0
top_level_context = None
Expand Down Expand Up @@ -217,12 +216,14 @@ class ZeroParamStatus(Enum):
INFLIGHT = 3


_orig_torch_tensor = torch.tensor
_orig_torch_empty = torch.empty
_orig_torch_zeros = torch.zeros
_orig_torch_ones = torch.ones
_orig_torch_full = torch.full
_orig_torch_arange = torch.arange
_orig_torch_eye = torch.eye
_orig_torch_randn = torch.randn


def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
Expand Down Expand Up @@ -288,6 +289,8 @@ def free_param(param: Parameter) -> None:
# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
num_module_parameters = 0
num_module_elements = 0

def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None):
self.mem_efficient_linear = mem_efficient_linear
Expand Down Expand Up @@ -324,7 +327,10 @@ def __exit__(self, exc_type, exc_value, traceback):
top_level_context = None

if dist.get_rank() == 0:
logger.info("finished initializing model with %.2fB parameters", param_count / 1e9)
billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9
num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters
logger.info(
f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B")

# Now that we cleaned up the metaclass injection, raise the exception.
if exc_type is not None:
Expand Down Expand Up @@ -381,14 +387,16 @@ def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
3. broadcasts root rank's parameters to the other ranks
4. re-partitions the parameters
"""
if not all(is_zero_param(p) for p in module_to_apply_fn_to.parameters(recurse=False)):
raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
f"were zero params, is it possible that the parameters were "
f"overwritten after they were initialized? "
f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")

# TODO Delay error checking for dangling partitioned parameters to post module init
# raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
# f"were zero params, is it possible that the parameters were "
# f"overwritten after they were initialized? "
# f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ")

params_to_apply_fn_to: Iterable[Parameter] = list(
sorted(module_to_apply_fn_to.parameters(recurse=False), key=lambda p: p.ds_id))
sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)],
key=lambda p: p.ds_id))

for param in params_to_apply_fn_to:
param.all_gather()
Expand Down Expand Up @@ -464,7 +472,8 @@ def _init_subclass(cls, **kwargs):

# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
if Init.override_module_apply:
torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)

self._add_tensor_creation_wrappers()

Expand All @@ -489,29 +498,34 @@ def _disable_class(cls):

# putting methods back the way we found them
torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
if Init.override_module_apply:
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

self._remove_tensor_creation_wrappers()

self.patched = False

def _add_tensor_creation_wrappers(self):
torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype)
torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype)
torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype)
torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)
torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)
torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype)
torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype)

def _remove_tensor_creation_wrappers(self):
torch.Tensor.__new__ = torch.Tensor.__old_new__
torch.tensor = _orig_torch_tensor
torch.empty = _orig_torch_empty
torch.zeros = _orig_torch_zeros
torch.ones = _orig_torch_ones
torch.full = _orig_torch_full
torch.arange = _orig_torch_arange
torch.eye = _orig_torch_eye
torch.randn = _orig_torch_randn


def shutdown_init_context():
Expand Down Expand Up @@ -687,6 +701,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
num_persisted_parameters = 0
num_persisted_elements = 0
apply_param_persistence = False
override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")

def __init__(self,
module=None,
Expand Down Expand Up @@ -845,9 +860,12 @@ def __init__(self,
self.quantizer_module = CUDAQuantizer()
print_rank_0(f'Using quantizer: {self.quantizer_module.__class__.__name__}', force=True)

if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
remote_device = _ds_config.zero_config.offload_param.device
pin_memory = _ds_config.zero_config.offload_param.pin_memory
if _ds_config is not None:
Init.override_module_apply = _ds_config.zero_config.override_module_apply

if _ds_config.zero_config.offload_param is not None:
remote_device = _ds_config.zero_config.offload_param.device
pin_memory = _ds_config.zero_config.offload_param.pin_memory

self._validate_remote_device(remote_device, _ds_config)

Expand Down Expand Up @@ -877,12 +895,21 @@ def _update_persist_config(self, ds_config):
Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions

def _zero_init_param(self, param):
self._convert_to_deepspeed_param(param)
if dist.get_world_group() == self.get_dp_process_group():
dist.broadcast(param, 0, self.get_dp_process_group())
else:
dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group())
param.partition()

def _convert_to_zero_parameters(self, param_list):
for param in param_list:
if is_zero_param(param):
continue
self._convert_to_deepspeed_param(param)
param.partition()

param.data = param.data.to(self.local_device)
self._zero_init_param(param)

def _validate_remote_device(self, remote_device, ds_config):
if ds_config is not None:
Expand All @@ -904,28 +931,19 @@ def _post_init_method(self, module):
print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False)

global param_count
for name, param in module.named_parameters(recurse=False):
param_count += param.numel()
print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False)
InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1
InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel()
if not is_zero_param(param):
self._convert_to_deepspeed_param(param)
if not get_accelerator().on_accelerator(param):
param.data = param.data.to(self.local_device)
self._zero_init_param(param)
print_rank_0(
f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}")

if get_accelerator().on_accelerator(param):
if dist.get_world_group() == self.get_dp_process_group():
dist.broadcast(param, 0, self.get_dp_process_group())
else:
dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0),
self.get_dp_process_group())
else:
if dist.get_rank() == 0:
logger.warn(f"param `{name}` in {module.__class__.__name__} "
f"not on GPU so was not broadcasted from rank 0")

param.partition()
see_memory_usage(
f"Param count {param_count}. After converting and partitioning params in {module.__class__.__name__}",
f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}",
force=False)

def _convert_to_deepspeed_param(self, param):
Expand Down Expand Up @@ -1342,7 +1360,6 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):

tensor_size = self._aligned_size(param)
partition_size = tensor_size // self.num_partitions

if param.ds_tensor is None:
final_location = None
if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor(
Expand Down
11 changes: 11 additions & 0 deletions docs/code-docs/source/zero3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,17 @@ DeepSpeed can automatically detect the following external parameter scenarios:
.. autofunction:: deepspeed.zero.unregister_external_parameter


.. `Module.apply <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=module+apply#torch.nn.Module.apply>`_
Overriding Module.apply
===============================
A convenient mechanism for customizing model initialization is `Module.apply <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=module+apply#torch.nn.Module.apply>`_.
With ZeRO stage 3, ``Module.apply`` implementations must account for parameter partitioning by ``zero.Init`` during model initialization. The default behavior of ZeRO stage 3 is to automatically
handle this issue by overriding ``Module.apply`` to ensure that parameters are gathered before access by ``Module.apply``. The benefit of this approach is development convenience, since
users are saved the burden of manual parameter coordination in ``Module.apply``. However, the downside is slow model initialization, since all the model parameters (e.g., billions) are gathered
even though the common usage of ``Module.apply`` is to customize a few parameters. Developers can disable this default behavior by setting the ``override_module_apply`` configuration knob to ``False``,
for faster model initialization at the cost of manually handling partitioned parameters in their ``Module.apply`` implementations.


Memory-Centric Tiling
---------------------

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
clang-format==16.0.2
coverage
docutils<0.18
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/runtime/zero/test_zero_nesting_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from unit.common import DistributedTest

from transformers import VisionEncoderDecoderModel
from transformers.deepspeed import HfDeepSpeedConfig

import deepspeed


Expand Down Expand Up @@ -44,3 +47,26 @@ def test_shutdown_in_nesting_init(self):
# ensure that zero3 processed the parameter
assert hasattr(model2.weight, "ds_id")
deepspeed_engine2, *_ = deepspeed.initialize(model=model2, config_params=ds_config)


class TestNestedParallelInit(DistributedTest):
world_size = 1

# Testing a model with composed and nested zero.Inits, with 3 zero.Init contexts, 1 parent and 2 children.
# The skeleton of the model is like so
#
# class VisionEncoderDecoderModel(...)::
# def __init__(self):
# encoder = AutoModel.from_config(config.encoder)
# decoder = AutoModelForCausalLM.from_config(config.decoder)
#
# And the user calls like below:
# VisionEncoderDecoderModel.from_pretrained(...)
# which calls this constructor inside zero.Init

def test_nested_parallel_init(self):
ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
dschf = HfDeepSpeedConfig(ds_config) # keep this object alive
model = VisionEncoderDecoderModel.from_pretrained(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
assert all([hasattr(p, 'ds_id') for p in model.parameters()])

0 comments on commit 7f90ef4

Please sign in to comment.