Add get and set APIs for the ZeRO-3 partitioned parameters (microsoft#4681)

DeepSpeed currently provides a set of debugging APIs to
[get](https://deepspeed.readthedocs.io/en/latest/zero3.html#debugging)
and
[set](https://deepspeed.readthedocs.io/en/latest/zero3.html#modifying-partitioned-states)
the **full** model states (parameters, gradients, and optimizer states).
However, in some scenarios only the **local** (partitioned) states are needed, for
example when pruning model layers based on a local criterion: after calling
`model_engine.step()`, each process needs to apply its local mask to the
partitioned parameters it owns (a sketch of this usage follows the Usage section
below). This PR therefore introduces new APIs to `get` and `set` the ZeRO-3
partitioned model states.

### APIs intro
```python
def safe_get_local_fp32_param(param):
    """Get the fp32 partitioned parameter."""

def safe_get_local_grad(param):
    """Get the fp32 gradient of a partitioned parameter."""

def safe_get_local_optimizer_state(param, optim_state_key):
    """Get the fp32 optimizer state of a partitioned parameter."""

def safe_set_local_fp32_param(param, value):
    """Update the partitioned fp32 parameter."""

def safe_set_local_optimizer_state(param, value, optim_state_key):
    """Update the fp32 optimizer state of a partitioned parameter."""
```
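For context, here is a short, hypothetical comparison with the existing full-state getters; it is not part of this PR. The `safe_get_full_*` routines return the unpartitioned tensor and must be called by all processes, whereas the `safe_get_local_*` routines return only the shard owned by the calling process. `model_engine` is assumed to be the engine returned by `deepspeed.initialize()`.

```python
# Hypothetical illustration of full vs. local getters under ZeRO-3.
from deepspeed.utils import safe_get_full_fp32_param, safe_get_local_fp32_param

for name, lp in model_engine.module.named_parameters():
    full_hp = safe_get_full_fp32_param(lp)    # full, unpartitioned fp32 weight (collective call)
    local_hp = safe_get_local_fp32_param(lp)  # only this rank's fp32 shard
    # With ZeRO-3, local_hp.numel() is roughly full_hp.numel() / world_size, plus any padding.
```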

### Usage
```python
# Import the local (ZeRO-3 partitioned) state APIs
from deepspeed.utils import (
    safe_get_local_fp32_param,
    safe_get_local_grad,
    safe_get_local_optimizer_state,
    safe_set_local_fp32_param,
    safe_set_local_optimizer_state,
)
```
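To make the pruning use case from the description concrete, here is a minimal sketch of how the local APIs could be combined. It is not part of this PR; `apply_local_pruning_masks` and `compute_local_mask` are hypothetical names, and `model_engine` is assumed to be the engine returned by `deepspeed.initialize()`.

```python
# Hypothetical sketch: after `model_engine.step()`, each rank applies a locally
# computed 0/1 mask to the fp32 shard of the parameters it owns.
from deepspeed.utils import safe_get_local_fp32_param, safe_set_local_fp32_param

def apply_local_pruning_masks(model_engine, compute_local_mask):
    for name, lp in model_engine.module.named_parameters():
        local_hp = safe_get_local_fp32_param(lp)  # this rank's fp32 shard; None if not a ZeRO-3 partitioned param
        if local_hp is None:
            continue
        # `compute_local_mask` is a user-supplied criterion evaluated on the local shard only,
        # so no cross-process communication is needed here.
        mask = compute_local_mask(name, local_hp)
        safe_set_local_fp32_param(lp, local_hp * mask)

# Typical call site in the training loop:
#   loss = model_engine(batch)
#   model_engine.backward(loss)
#   model_engine.step()
#   apply_local_pruning_masks(model_engine, compute_local_mask)
```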
### TODO
- [x] Add local APIs
- [x] Add UTs
- [x] Update Docs

@tjruwase

---------

Signed-off-by: yliu <test@[email protected]>
Co-authored-by: yliu <test@[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
4 people authored Nov 17, 2023
1 parent a3926bb commit 0ec2d3e
Showing 5 changed files with 233 additions and 19 deletions.
40 changes: 40 additions & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -2238,6 +2238,46 @@ def set_full_hp_param(self, value, param, optim_state_key=None):
        if self._swappable_optimizer_subgroup(group_idx):
            self._optimizer_states_and_gradient_swap_out(group_idx)

    ### Local API START ###

    def get_local_fp32_param(self, param, optim_state_key=None) -> Tensor:
        if not param.requires_grad:
            return None
        fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
        return fp32_opt_state

    def get_local_fp32_grad_for_param(self, param) -> Tensor:
        if not param.requires_grad:
            return None

        if not get_accelerator().is_synchronized_device():
            self.reduce_and_partition_stream.synchronize()

        if self.offload_optimizer:
            group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
            fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
        else:
            fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
        return fp32_grad

    def set_local_hp_param(self, value, param, optim_state_key=None):
        if not param.requires_grad:
            return

        assert hasattr(param, "ds_tensor"), f" The parameter does not contain the partitioned copy of the tensor."
        assert value.numel() == param.ds_tensor.numel(
        ), f" Number of elements do not match: {value.numel()} != {param.ds_tensor.ds_numel}"

        fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
        value_partition = value.flatten()
        fp32_opt_state_partition.data.copy_(value_partition.data)

        if self._swappable_optimizer_subgroup(group_idx):
            self._optimizer_states_and_gradient_swap_out(group_idx)
        logger.info(f"[set_local_hp_param][update the params' value successfully]")

    ### Local API END ###

    @instrument_w_nvtx
    def _partition_all_parameters(self):
        self.parameter_offload.partition_all_parameters()
2 changes: 2 additions & 0 deletions deepspeed/utils/__init__.py
@@ -14,6 +14,8 @@
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
66 changes: 66 additions & 0 deletions deepspeed/utils/tensor_fragment.py
@@ -185,6 +185,72 @@ def safe_get_full_grad(param):
    return None


### Local API START ###
def safe_get_local_grad(param):
    """Get the fp32 gradient of a partitioned parameter.
    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_local_fp32_grad_for_param(param)

    return None


def safe_get_local_fp32_param(param):
    """Get the fp32 partitioned parameter.
    Args:
        param (``torch.nn.Parameter``): A model parameter
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_local_fp32_param(param)

    return None


def safe_get_local_optimizer_state(param, optim_state_key):
    """Get the fp32 optimizer state of a partitioned parameter.
    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)

    return None


def safe_set_local_optimizer_state(param, value, optim_state_key):
    """Update the fp32 optimizer state of a partitioned parameter.
    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)


def safe_set_local_fp32_param(param, value):
    """Update the partitioned fp32 parameter.
    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_local_hp_param(value, param)


### Local API END ###

# TODO: Implement API for setting ZeRO partitioned gradients

45 changes: 36 additions & 9 deletions docs/code-docs/source/zero3.rst
@@ -341,9 +341,9 @@ parallelism to fit them in limited GPU memory.
Debugging
---------

Debugging ZeRO training is complicated by the partitioning of parameters, gradients, and optimizer states. None of these 3 groups of tensors (model states) can be normally accessed because of that. To overcome that DeepSpeed provides the following routines for accessing individual model states in both their partitioned (local) and unpartitioned (full) forms.

Important: Please note that, to access the unpartitioned (full) form, these utilities must be called by all processes participating in the training, even if you decide to do something with the result only in the main process. If all processes don't participate these utilities will hang waiting for all processes to send their contribution.

Additionally, you must be aware that these routines return correct data only in specific phases of the training. So, for example, the gradients are valid after ``backward`` and before ``step``. The optimizer states are updated after ``step``. The same goes for the fp32 master weights.

@@ -353,6 +353,12 @@

.. autofunction:: deepspeed.utils.safe_get_full_optimizer_state

.. autofunction:: deepspeed.utils.safe_get_local_fp32_param

.. autofunction:: deepspeed.utils.safe_get_local_grad

.. autofunction:: deepspeed.utils.safe_get_local_optimizer_state


These routines can be used in a training loop as shown in the following snippet.

@@ -362,16 +368,26 @@ These routines can be used in a training loop as shown in the following snippet.
    [...]
    from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
    for n, lp in model.named_parameters():
        # 1. Access the full states
        # 1) gradient lookup
        # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
        # For zero3, gradient lookup must be called after `backward`
        hp_grad = safe_get_full_grad(lp)

        # 2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
        hp = safe_get_full_fp32_param(lp)
        exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
        exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")

        # 2. Access the local states (zero3)
        # For zero3, all of the parameters, gradients, and optimizer states are partitioned,
        # and each process can access its corresponding local state.
        local_hp = safe_get_local_fp32_param(lp)
        local_hp_grad = safe_get_local_grad(lp)
        local_exp_avg = safe_get_local_optimizer_state(lp, "exp_avg")
        local_exp_avg_sq = safe_get_local_optimizer_state(lp, "exp_avg_sq")

    [...]
    optimizer.step()
@@ -380,27 +396,38 @@ These routines can be used in a training loop as shown in the following snippet.
Modifying Partitioned States
----------------------------

Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.

.. autofunction:: deepspeed.utils.safe_set_full_fp32_param

.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state

.. autofunction:: deepspeed.utils.safe_set_local_fp32_param

.. autofunction:: deepspeed.utils.safe_set_local_optimizer_state

These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.

.. code-block:: python

    [...]
    from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
    from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
    # Here is an example to zero all the fp32 parameters and optimizer states.
    for n, lp in model.named_parameters():
        # 1. For zero stage 1 or 2, set the full fp32 and their full optim states
        zero_tensor = torch.zeros_like(lp)

        safe_set_full_fp32_param(lp, zero_tensor)
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

        # 2. For zero stage 3, each process sets its local fp32 parameters and their local optimizer states individually
        zero_tensor_local = torch.zeros_like(lp.ds_tensor)

        safe_set_local_fp32_param(lp, zero_tensor_local)
        safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg")
        safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg_sq")

    [...]