Removes unnecessary cloning (#6761)
`clone_tensors_for_torch_save()` function:

When `item.device` differs from the target `device`, the explicit
`tensor.clone()` is not required, because `to()` already returns a copy
of the original tensor when it moves it to another device.


Additionally, I observed memory bloat under the following conditions:
* Training a Whisper model with the `transformers` framework under `ZeRO-0`
and `ZeRO-1` configurations.
* Memory bloat occurred every time the model `state_dict` was cloned via
`clone_tensors_for_torch_save()`.

After removing the unnecessary `clone()`, the problem appears to be
resolved.
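The reasoning above rests on the copy semantics of `Tensor.to()`: when the tensor is already on the requested device (and dtype), `to()` is a no-op that returns the same object, while any actual conversion materializes a new tensor. A minimal sketch of both behaviors (using a dtype conversion as a stand-in for a device transfer, so it runs on a CPU-only machine):

```python
import torch

t = torch.zeros(3)

# No-op case: the tensor is already on CPU, so to() returns the
# *same* object. This is why the same-device branch of the patched
# function still needs an explicit clone().
assert t.to(torch.device('cpu')) is t

# Conversion case: to() allocates and returns a new tensor, so a
# clone() chained before it would copy the data twice.
u = t.to(torch.float64)
assert u is not t
```

This is exactly the asymmetry the patch exploits: clone only when `to()` would be a no-op, and let `to()` do the copying otherwise.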

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
3 people authored Nov 21, 2024
1 parent b5709cc commit f515104
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion deepspeed/checkpoint/utils.py
@@ -51,7 +51,12 @@ def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
         - copy of ``item`` with cloned tensors on target device
     """
     if torch.is_tensor(item):
-        return item.detach().clone().to(device)
+        if type(device) is str:
+            device = torch.device(device)
+        if device == item.device:
+            return item.detach().clone()
+        else:
+            return item.detach().to(device)
     elif isinstance(item, list):
         return [clone_tensors_for_torch_save(v, device) for v in item]
     elif isinstance(item, tuple):
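For illustration, here is a runnable sketch of the function after the patch. The tensor, list, and tuple branches follow the diff; the tuple body, dict branch, and fallthrough are assumed continuations (the diff is truncated there), modeled on the recursive pattern the visible branches establish:

```python
import torch

def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
    # Sketch reconstructed from the diff; see deepspeed/checkpoint/utils.py
    # for the authoritative version.
    if torch.is_tensor(item):
        if type(device) is str:
            device = torch.device(device)
        if device == item.device:
            # to() would be a no-op here, so copy explicitly.
            return item.detach().clone()
        else:
            # to() already produces a copy on the target device;
            # an extra clone() would allocate twice (the memory bloat).
            return item.detach().to(device)
    elif isinstance(item, list):
        return [clone_tensors_for_torch_save(v, device) for v in item]
    elif isinstance(item, tuple):
        # Assumed continuation: mirror the list branch for tuples.
        return tuple(clone_tensors_for_torch_save(v, device) for v in item)
    elif isinstance(item, dict):
        # Assumed continuation: recurse over dict values (e.g. a state_dict).
        return {k: clone_tensors_for_torch_save(v, device) for k, v in item.items()}
    else:
        return item
```

Calling it on a `state_dict`-like mapping returns detached copies, so the saved tensors no longer share storage with the live model parameters.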
