diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 1aea4c83f539..b43d14498c95 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1041,23 +1041,22 @@ def _set_client_model(self, model): def _configure_distributed_model(self, model): self._set_client_model(model) - - is_zero3_model = self.zero_optimization_partition_weights() and any( + is_zero_init_model = self.zero_optimization_partition_weights() and any( [hasattr(param, "ds_id") for param in self.module.parameters()]) if self.fp16_enabled(): - if is_zero3_model: + if is_zero_init_model: self.__check_params(self.module, torch.half) self.module.half() elif self.bfloat16_enabled(): - if is_zero3_model: + if is_zero_init_model: self.__check_params(self.module, torch.bfloat16) self.module.bfloat16() else: self.__check_params(self.module, torch.float) # zero.Init() handles device placement of model - if not self.dont_change_device: + if not (self.dont_change_device or is_zero_init_model): self.module.to(self.device) # MoE related initialization @@ -1097,7 +1096,7 @@ def _configure_distributed_model(self, model): self.expert_parallel_group = groups._get_expert_parallel_group_dict() self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() - if not self.amp_enabled(): + if not (self.amp_enabled() or is_zero_init_model): self._broadcast_model() # check if parameters are duplicated in optimizer param_groups diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 55f933e78983..b19e956d2d70 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -38,7 +38,8 @@ "zero_hpz_partition_size": 1, "zero_quantized_weights": [true|false], "zero_quantized_gradients": [true|false], - "memory_efficient_linear": [true|false] + "memory_efficient_linear": [true|false], + "override_module_apply": [true|false], } } """ @@ -269,11 +270,17 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): mics_shard_size: int = Field(-1, new_param="mics_shard_size") mics_hierarchical_params_gather: bool = False + memory_efficient_linear: bool = True """ Use memory efficient linear implementation, for Stage 3. """ + override_module_apply: bool = True + """ + Override nn.Module apply function, for Stage 3. 
+ """ + # Validators @validator("overlap_comm") def overlap_comm_valid(cls, field_value, values): diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index a7ad2ce32823..f0f9d6e8bf6e 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -34,7 +34,6 @@ from deepspeed.accelerator import get_accelerator from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus -param_count = 0 partitioned_param_data_shape = [0] zero_init_context = 0 top_level_context = None @@ -217,12 +216,14 @@ class ZeroParamStatus(Enum): INFLIGHT = 3 +_orig_torch_tensor = torch.tensor _orig_torch_empty = torch.empty _orig_torch_zeros = torch.zeros _orig_torch_ones = torch.ones _orig_torch_full = torch.full _orig_torch_arange = torch.arange _orig_torch_eye = torch.eye +_orig_torch_randn = torch.randn def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable: @@ -288,6 +289,8 @@ def free_param(param: Parameter) -> None: # Inserts _post_init_method at the end of init method # for all sub classes of torch.nn.Module class InsertPostInitMethodToModuleSubClasses(object): + num_module_parameters = 0 + num_module_elements = 0 def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None): self.mem_efficient_linear = mem_efficient_linear @@ -324,7 +327,10 @@ def __exit__(self, exc_type, exc_value, traceback): top_level_context = None if dist.get_rank() == 0: - logger.info("finished initializing model with %.2fB parameters", param_count / 1e9) + billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9 + num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters + logger.info( + f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B") # Now that we cleaned up the metaclass injection, raise the exception. if exc_type is not None: @@ -381,14 +387,16 @@ def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None: 3. broadcasts root rank's parameters to the other ranks 4. re-partitions the parameters """ - if not all(is_zero_param(p) for p in module_to_apply_fn_to.parameters(recurse=False)): - raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " - f"were zero params, is it possible that the parameters were " - f"overwritten after they were initialized? " - f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ") + + # TODO Delay error checking for dangling partitioned parameters to post module init + # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " + # f"were zero params, is it possible that the parameters were " + # f"overwritten after they were initialized? 
" + # f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ") params_to_apply_fn_to: Iterable[Parameter] = list( - sorted(module_to_apply_fn_to.parameters(recurse=False), key=lambda p: p.ds_id)) + sorted([p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)], + key=lambda p: p.ds_id)) for param in params_to_apply_fn_to: param.all_gather() @@ -464,7 +472,8 @@ def _init_subclass(cls, **kwargs): # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) self._add_tensor_creation_wrappers() @@ -489,7 +498,8 @@ def _disable_class(cls): # putting methods back the way we found them torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass - torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply self._remove_tensor_creation_wrappers() @@ -497,21 +507,25 @@ def _disable_class(cls): def _add_tensor_creation_wrappers(self): torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) + torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype) torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype) torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype) torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype) torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype) + torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype) def _remove_tensor_creation_wrappers(self): torch.Tensor.__new__ = torch.Tensor.__old_new__ + torch.tensor = _orig_torch_tensor torch.empty = _orig_torch_empty torch.zeros = _orig_torch_zeros torch.ones = _orig_torch_ones torch.full = _orig_torch_full torch.arange = _orig_torch_arange torch.eye = _orig_torch_eye + torch.randn = _orig_torch_randn def shutdown_init_context(): @@ -687,6 +701,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): num_persisted_parameters = 0 num_persisted_elements = 0 apply_param_persistence = False + override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply") def __init__(self, module=None, @@ -845,9 +860,12 @@ def __init__(self, self.quantizer_module = CUDAQuantizer() print_rank_0(f'Using quantizer: {self.quantizer_module.__class__.__name__}', force=True) - if _ds_config is not None and _ds_config.zero_config.offload_param is not None: - remote_device = _ds_config.zero_config.offload_param.device - pin_memory = _ds_config.zero_config.offload_param.pin_memory + if _ds_config is not None: + Init.override_module_apply = _ds_config.zero_config.override_module_apply + + if _ds_config.zero_config.offload_param is not None: + remote_device = _ds_config.zero_config.offload_param.device + pin_memory = _ds_config.zero_config.offload_param.pin_memory self._validate_remote_device(remote_device, _ds_config) @@ -877,12 +895,21 @@ def _update_persist_config(self, 
ds_config): Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions + def _zero_init_param(self, param): + self._convert_to_deepspeed_param(param) + if dist.get_world_group() == self.get_dp_process_group(): + dist.broadcast(param, 0, self.get_dp_process_group()) + else: + dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group()) + param.partition() + def _convert_to_zero_parameters(self, param_list): for param in param_list: if is_zero_param(param): continue - self._convert_to_deepspeed_param(param) - param.partition() + + param.data = param.data.to(self.local_device) + self._zero_init_param(param) def _validate_remote_device(self, remote_device, ds_config): if ds_config is not None: @@ -904,28 +931,19 @@ def _post_init_method(self, module): print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False) see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False) - global param_count for name, param in module.named_parameters(recurse=False): - param_count += param.numel() + print_rank_0(f'Analyzing param {name} in {module.__class__.__name__}', force=False) + InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1 + InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel() if not is_zero_param(param): - self._convert_to_deepspeed_param(param) + if not get_accelerator().on_accelerator(param): + param.data = param.data.to(self.local_device) + self._zero_init_param(param) print_rank_0( f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}") - if get_accelerator().on_accelerator(param): - if dist.get_world_group() == self.get_dp_process_group(): - dist.broadcast(param, 0, self.get_dp_process_group()) - else: - dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), - self.get_dp_process_group()) - else: - if dist.get_rank() == 0: - logger.warn(f"param `{name}` in {module.__class__.__name__} " - f"not on GPU so was not broadcasted from rank 0") - - param.partition() see_memory_usage( - f"Param count {param_count}. After converting and partitioning params in {module.__class__.__name__}", + f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. After converting and partitioning params in {module.__class__.__name__}", force=False) def _convert_to_deepspeed_param(self, param): @@ -1342,7 +1360,6 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): tensor_size = self._aligned_size(param) partition_size = tensor_size // self.num_partitions - if param.ds_tensor is None: final_location = None if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor( diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst index b70dc022ce4b..333b29ed98d8 100644 --- a/docs/code-docs/source/zero3.rst +++ b/docs/code-docs/source/zero3.rst @@ -309,6 +309,17 @@ DeepSpeed can automatically detect the following external parameter scenarios: .. autofunction:: deepspeed.zero.unregister_external_parameter +.. `Module.apply `_ +Overriding Module.apply +=============================== +A convenient mechanism for customizing model initialization is `Module.apply `_. +With ZeRO stage 3, ``Module.apply`` implementations must account for parameter partitioning by ``zero.Init`` during model initialization. 
The default behavior of ZeRO stage 3 is to automatically +handle this issue by overriding ``Module.apply`` to ensure that parameters are gathered before access by ``Module.apply``. The benefit of this approach is development convenience, since +users are saved the burden of manual parameter coordination in ``Module.apply``. However, the downside is slow model initialization, since all the model parameters (e.g., billions) are gathered +even though the common usage of ``Module.apply`` is to customize a few parameters. Developers can disable this default behavior by setting the ``override_module_apply`` configuration knob to ``False``, +for faster model initialization at the cost of manually handling partitioned parameters in their ``Module.apply`` implementations. + + Memory-Centric Tiling --------------------- diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 0f0cb337fc51..6295a75e67c8 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,3 +1,4 @@ +accelerate clang-format==16.0.2 coverage docutils<0.18 diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py index 162916d1c22a..143e7e997b13 100644 --- a/tests/unit/runtime/zero/test_zero_nesting_init.py +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -7,6 +7,9 @@ from unit.common import DistributedTest +from transformers import VisionEncoderDecoderModel +from transformers.deepspeed import HfDeepSpeedConfig + import deepspeed @@ -44,3 +47,26 @@ def test_shutdown_in_nesting_init(self): # ensure that zero3 processed the parameter assert hasattr(model2.weight, "ds_id") deepspeed_engine2, *_ = deepspeed.initialize(model=model2, config_params=ds_config) + + +class TestNestedParallelInit(DistributedTest): + world_size = 1 + + # Testing a model with composed and nested zero.Inits, with 3 zero.Init contexts, 1 parent and 2 children. + # The skeleton of the model is like so + # + # class VisionEncoderDecoderModel(...):: + # def __init__(self): + # encoder = AutoModel.from_config(config.encoder) + # decoder = AutoModelForCausalLM.from_config(config.decoder) + # + # And the user calls like below: + # VisionEncoderDecoderModel.from_pretrained(...) + # which calls this constructor inside zero.Init + + def test_nested_parallel_init(self): + ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3)) + dschf = HfDeepSpeedConfig(ds_config) # keep this object alive + model = VisionEncoderDecoderModel.from_pretrained( + "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2") + assert all([hasattr(p, 'ds_id') for p in model.parameters()])
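
For context on the new ``override_module_apply`` knob documented in the zero3.rst hunk above, here is a minimal sketch of the opt-out path it enables. This is illustrative only: it assumes a distributed environment provided by the ``deepspeed`` launcher, and the config values, the toy model, and the ``init_weights`` helper are made up for the example. Only ``override_module_apply``, ``deepspeed.zero.Init``, and ``deepspeed.zero.GatheredParameters`` are the actual knobs/APIs involved.

```python
import torch
import deepspeed

# Illustrative config: stage 3 with the automatic Module.apply override disabled,
# trading convenience for faster initialization (see the zero3.rst addition above).
ds_config = {
    "train_batch_size": 1,
    "zero_optimization": {
        "stage": 3,
        "override_module_apply": False,
    },
}


def init_weights(module):
    # With the override disabled, partitioned parameters must be gathered
    # explicitly before Module.apply-style customization can touch them.
    params = list(module.parameters(recurse=False))
    if not params:
        return
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)


with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
    model.apply(init_weights)  # only the Linear parameters are gathered, not the whole model
```

With the default ``override_module_apply: true``, the wrapped ``Module.apply`` would instead gather, broadcast, and re-partition every parameter in ``model`` before running ``init_weights``, which is the slower but more convenient path described in the documentation change.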
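
The change also adds ``torch.tensor`` and ``torch.randn`` to the set of patched tensor constructors in ``partition_parameters.py``. A short sketch of the resulting behavior, again only illustrative (the dtype, shapes, and surrounding setup are assumptions) and intended to run under the ``deepspeed`` launcher:

```python
import torch
import deepspeed

deepspeed.init_distributed()  # assumes launcher-provided environment variables

with deepspeed.zero.Init(dtype=torch.half):
    # While the wrappers are installed, floating-point constructors allocate on the
    # local accelerator in the configured dtype; torch.randn and torch.tensor are
    # newly covered by this change.
    noise = torch.randn(8, 8)
    scale = torch.tensor([0.5, 2.0])
    assert noise.dtype == torch.half and scale.dtype == torch.half

# Exiting the top-level context restores the original constructors.
assert torch.randn(2, 2).dtype == torch.float32
```

Modules created under the context are still partitioned exactly as before; the new wrappers only widen the set of constructors that respect the target dtype and device during initialization.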