From 0ec2d3e4bfa2d0a5237e9747da1ef9d5e4a4453b Mon Sep 17 00:00:00 2001 From: Yi30 <106061964+yiliu30@users.noreply.github.com> Date: Sat, 18 Nov 2023 05:58:47 +0800 Subject: [PATCH] Add get and set APIs for the ZeRO-3 partitioned parameters (#4681) The DeepSpeed currently supports a set of debugging APIs to [get](https://deepspeed.readthedocs.io/en/latest/zero3.html#debugging) and [set](https://deepspeed.readthedocs.io/en/latest/zero3.html#modifying-partitioned-states) the **full** model states (parameters, gradients, and optimizer states). However, in some scenarios, only **local states** are needed, for example, when pruning some model layers based on a local criterion. After calling `model_engine.step()`, we need to apply the local mask to the partitioned parameters owned by each process. Therefore, I am submitting this PR to introduce some new APIs for `get` and `set` ZeRO-3 partial model states. ### APIs intro ```python def safe_get_local_fp32_param(param): """Get the fp32 partitioned parameter.""" def safe_get_local_grad(param): """Get the fp32 gradient of a partitioned parameter.""" def safe_get_local_optimizer_state(param, optim_state_key): """Get the fp32 optimizer state of a partitioned parameter.""" def safe_set_local_fp32_param(param, value): """Update the partitioned fp32 parameter.""" def safe_set_local_optimizer_state(param, value, optim_state_key): """Update the fp32 optimizer state of a partitioned parameter.""" ``` ### Usage ```python # local API from deepspeed.utils import ( safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state, safe_set_local_fp32_param, safe_set_local_optimizer_state ) ``` ### TODO - [x] Add local APIs - [x] Add UTs - [x] Update Docs @tjruwase --------- Signed-off-by: yliu Co-authored-by: yliu Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/stage3.py | 40 ++++++++ deepspeed/utils/__init__.py | 2 + 
deepspeed/utils/tensor_fragment.py | 66 +++++++++++++ docs/code-docs/source/zero3.rst | 45 +++++++-- .../runtime/zero/test_zero_tensor_fragment.py | 99 +++++++++++++++++-- 5 files changed, 233 insertions(+), 19 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d393e4fd898c..4dfb91676525 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2238,6 +2238,46 @@ def set_full_hp_param(self, value, param, optim_state_key=None): if self._swappable_optimizer_subgroup(group_idx): self._optimizer_states_and_gradient_swap_out(group_idx) + ### Local API START ### + + def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor: + if not param.requires_grad: + return None + fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + return fp32_opt_state + + def get_local_fp32_grad_for_param(self, param) -> Tensor: + if not param.requires_grad: + return None + + if not get_accelerator().is_synchronized_device(): + self.reduce_and_partition_stream.synchronize() + + if self.offload_optimizer: + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) + else: + fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float() + return fp32_grad + + def set_local_hp_param(self, value, param, optim_state_key=None): + if not param.requires_grad: + return + + assert hasattr(param, "ds_tensor"), f" The parameter does not contain the partitioned copy of the tensor." 
+ assert value.numel() == param.ds_tensor.numel( + ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}" + + fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + value_partition = value.flatten() + fp32_opt_state_partition.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_out(group_idx) + logger.info(f"[set_local_hp_param][update the params' value successfully]") + + ### Local API END ### + @instrument_w_nvtx def _partition_all_parameters(self): self.parameter_offload.partition_all_parameters() diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index b6668b5ff5ce..6237d7239682 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -14,6 +14,8 @@ from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state from .tensor_fragment import set_full_hp_param from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state +from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state +from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state from .mixed_precision_linkage import link_hp_params from deepspeed.runtime.dataloader import RepeatingLoader from .numa import get_numactl_cmd diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index 18e373799ab7..5f94070dc4c7 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ -185,6 +185,72 @@ def safe_get_full_grad(param): return None +### Local API START ### +def safe_get_local_grad(param): + """Get the fp32 gradient of a partitioned parameter. 
+ Args: + param (``torch.nn.Parameter``): A model parameter + """ + if param.grad is not None: + return param.grad + + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + return param._z3_optimizer.get_local_fp32_grad_for_param(param) + + return None + + +def safe_get_local_fp32_param(param): + """Get the fp32 partitioned parameter. + Args: + param (``torch.nn.Parameter``): A model parameter + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + return param._z3_optimizer.get_local_fp32_param(param) + + return None + + +def safe_get_local_optimizer_state(param, optim_state_key): + """Get the fp32 optimizer state of a partitioned parameter. + Args: + param (``torch.nn.Parameter``): A model parameter + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + return param._z3_optimizer.get_local_fp32_param(param, optim_state_key) + + return None + + +def safe_set_local_optimizer_state(param, value, optim_state_key): + """Update the fp32 optimizer state of a partitioned parameter. + Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): New value + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + param._z3_optimizer.set_local_hp_param(value, param, optim_state_key) + + +def safe_set_local_fp32_param(param, value): + """Update the partitioned fp32 parameter. 
+ Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): New value + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + param._z3_optimizer.set_local_hp_param(value, param) + + +### Local API END ### + # TODO: Implement API for setting ZeRO partitioned gradients diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst index 56a7987dc496..2a6a48ca91db 100644 --- a/docs/code-docs/source/zero3.rst +++ b/docs/code-docs/source/zero3.rst @@ -341,9 +341,9 @@ parallelism to fit them in limited GPU memory. Debugging --------- -Debugging ZeRO training is complicated by the partitioning of parameters, gradients, and optimizer states. None of these 3 groups of tensors (model states) can be normally accessed because of that. To overcome that DeepSpeed provides the following routines for accessing individual model states in their unpartitioned form. +Debugging ZeRO training is complicated by the partitioning of parameters, gradients, and optimizer states. None of these 3 groups of tensors (model states) can be normally accessed because of that. To overcome that DeepSpeed provides the following routines for accessing individual model states in both their partitioned (local) and unpartitioned (full) forms. -Important: Please note that these utilities must be called by all processes participating in the training, even if you decide to do something with the result only in the main process. If all processes don't participate these utilities will hang waiting for all processes to send their contribution. +Important: Please note that, to access the unpartitioned (full) form, these utilities must be called by all processes participating in the training, even if you decide to do something with the result only in the main process. If all processes don't participate these utilities will hang waiting for all processes to send their contribution. 
Additionally, you must be aware that these routines return correct data only in specific phases of the training. So for examples the gradients are valid after ``backward`` and before ``step``. The optimizer states are updated after ``step``. Same goes for fp32 master weights. @@ -353,6 +353,12 @@ Additionally, you must be aware that these routines return correct data only in .. autofunction:: deepspeed.utils.safe_get_full_optimizer_state +.. autofunction:: deepspeed.utils.safe_get_local_fp32_param + +.. autofunction:: deepspeed.utils.safe_get_local_grad + +.. autofunction:: deepspeed.utils.safe_get_local_optimizer_state + These routines can be used in a training loop as shown in the following snippet. @@ -362,16 +368,26 @@ These routines can be used in a training loop as shown in the following snippet. [...] from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state for n, lp in model.named_parameters(): - # 1. gradient lookup + # 1. Access the full states + # 1) gradient lookup # For zero1 and zero2, gradient lookup must be called after `backward` and before `step` # For zero3, gradient lookup must be called after `backward` hp_grad = safe_get_full_grad(lp) - # 2. fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step` + + # 2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step` hp = safe_get_full_fp32_param(lp) exp_avg = safe_get_full_optimizer_state(lp, "exp_avg") exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq") + # 2. Access the local states (zero3) + # For zero3, all of the parameters, gradients, and optimizer states are partitioned, + # and each process can access its corresponding local state. 
+ local_hp = safe_get_local_fp32_param(lp) + local_hp_grad = safe_get_local_grad(lp) + local_exp_avg = safe_get_local_optimizer_state(lp, "exp_avg") + local_exp_avg_sq = safe_get_local_optimizer_state(lp, "exp_avg_sq") + [...] optimizer.step() @@ -380,12 +396,15 @@ These routines can be used in a training loop as shown in the following snippet. Modifying Partitioned States ---------------------------- -Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following two routines for modifying the fp32 master parameters and the fp32 optimizer states. +Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states. .. autofunction:: deepspeed.utils.safe_set_full_fp32_param .. autofunction:: deepspeed.utils.safe_set_full_optimizer_state +.. autofunction:: deepspeed.utils.safe_set_local_fp32_param + +.. autofunction:: deepspeed.utils.safe_set_local_optimizer_state These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet. @@ -393,14 +412,22 @@ These routines can be used at any point after initialization of the DeepSpeed en [...] from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state + from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state # Here is an example to zero all the fp32 parameters and optimizer states. for n, lp in model.named_parameters(): - # Assume zero stage 1 or 2, since stage 3 requires a gather to assemble lp + # 1. 
For zero stage 1 or 2, set the full fp32 and their full optim states zero_tensor = torch.zeros_like(lp) - hp = safe_set_full_fp32_param(lp, zero_tensor) - exp_avg = safe_get_full_optimizer_state(lp, zero_tensor, "exp_avg") - exp_avg_sq = safe_get_full_optimizer_state(lp, zero_tensor, "exp_avg_sq") + safe_set_full_fp32_param(lp, zero_tensor) + safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg") + safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq") + + # 2. For zero stage 3, each process sets its local fp32 parameters and their local optimizer states individually + zero_tensor_local = torch.zeros(lp.ds_tensor.shape) + + safe_set_local_fp32_param(lp, zero_tensor_local) + safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg") + safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg_sq") [...] diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index 63d05ab6d352..e50b03035bad 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -14,6 +14,8 @@ import deepspeed from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state +from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state +from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.aio import AsyncIOBuilder @@ -35,6 +37,22 @@ def validate_full_tensors(model): assert all([p is None for p in param_list]) +def validate_local_tensors(model): + for _, lp in model.named_parameters(): + hp = safe_get_local_fp32_param(lp) + exp_avg = safe_get_local_optimizer_state(lp, 'exp_avg') + exp_avg_sq = safe_get_local_optimizer_state(lp, 'exp_avg_sq') + hp_grad = 
safe_get_local_grad(lp) + param_list = [hp, hp_grad, exp_avg, exp_avg_sq] + if lp.requires_grad: + assert all([p is not None for p in param_list]) + else: + assert all([p is None for p in param_list]) + + +validate_funcs_mapping = {"full": validate_full_tensors, "local": validate_local_tensors} + + class MyModel(torch.nn.Module): def __init__(self, hidden_dim, frozen_weights): @@ -58,7 +76,7 @@ def forward(self, x, y): return val -def run_fragmented_model(model, config_dict, hidden_dim, dtype): +def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -70,7 +88,7 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype): loss = model(batch[0], batch[1]) loss = loss[1] model.backward(loss) - validate_full_tensors(model) + validate_func(model) model.step() # Needed in ZeRO 3. Not doing so can give memory leak @@ -83,15 +101,19 @@ class TestTensorFragmentGet(DistributedTest): world_size = 2 reuse_dist_env = True + @pytest.mark.parametrize('api_type', ['local', 'full']) @pytest.mark.parametrize('zero_stage', [1, 2, 3]) @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) - def test_zero_fragments(self, tmpdir, zero_stage, offload_device, frozen_weights): + def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, frozen_weights): if offload_device == OffloadDeviceEnum.nvme: if zero_stage != 3: pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: pytest.skip('Skip tests since async-io is not compatible') + if api_type == "local" and zero_stage != 3: + pytest.skip(f"Local APIs only for zero stage 3 but current stage is {zero_stage}") + config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, 
@@ -125,7 +147,9 @@ def test_zero_fragments(self, tmpdir, zero_stage, offload_device, frozen_weights else: model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.float16) + validate_func = validate_funcs_mapping[api_type] + + run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_func) def test_bf16_fragments(self, frozen_weights): if frozen_weights: @@ -154,10 +178,10 @@ def test_bf16_fragments(self, frozen_weights): hidden_dim = 128 model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16) + run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_full_tensors) -def create_random_values(model, key_list, group): +def create_random_values(model, key_list, group, use_cuda=True): param_values = {} for n, lp in model.named_parameters(): param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape @@ -188,21 +212,72 @@ def validate_param_values_with_dict(model, value_dict): assert torch.equal(expected_tensor, actual_tensor) +def create_random_values_for_local(model, key_list, group, use_cuda=True): + param_values = {} + for n, lp in model.named_parameters(): + param_shape = lp.ds_tensor.shape + param_values[n] = {} + for key in key_list: + device = model.device if use_cuda else "cpu" + rand_value = torch.rand(param_shape, dtype=torch.float32, device=device) + # dist.broadcast(rand_value, src=0, group=group) + param_values[n][key] = rand_value + return param_values + + +def set_local_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + + for key, value_tensor in value_dict[n].items(): + if key == WEIGHT_KEY: + safe_set_local_fp32_param(lp, value_tensor) + else: + safe_set_local_optimizer_state(lp, value_tensor, key) + + +def validate_local_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + for key, expected_tensor in value_dict[n].items(): + if key == 
WEIGHT_KEY: + actual_tensor = safe_get_local_fp32_param(lp) + else: + actual_tensor = safe_get_local_optimizer_state(lp, key) + assert torch.equal(expected_tensor, actual_tensor) + + +helper_funcs_mapping = { + "full": { + "create_random_values": create_random_values, + "set_param_values_with_dict": set_param_values_with_dict, + "validate_param_values_with_dict": validate_param_values_with_dict + }, + "local": { + "create_random_values": create_random_values_for_local, + "set_param_values_with_dict": set_local_param_values_with_dict, + "validate_param_values_with_dict": validate_local_param_values_with_dict + } +} + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) class TestTensorFragmentUpdate(DistributedTest): # Need multiple gpus to test possible hanging world_size = 2 reuse_dist_env = True + @pytest.mark.parametrize('api_type', ['local', 'full']) @pytest.mark.parametrize('zero_stage', [1, 2, 3]) @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) - def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype): + def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtype): if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False): pytest.skip( " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) + if api_type == "local" and zero_stage != 3: + pytest.skip(f"Local APIs only for zero stage 3 but current stage is {zero_stage}") + if offload_device == OffloadDeviceEnum.nvme: if zero_stage != 3: pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") @@ -250,9 +325,13 @@ def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype): dist.barrier() optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] - optim_state_values = create_random_values(model, optim_keys, group) - set_param_values_with_dict(model, 
optim_state_values) - validate_param_values_with_dict(model, optim_state_values) + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"](model, + optim_keys, + group, + use_cuda=offload_device == OffloadDeviceEnum.none) + helper_funcs["set_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) # Needed in ZeRO 3. Not doing so can leak memory. model.destroy()