Remove an unnecessary `clone()` in `clone_tensors_for_torch_save()`

When `item.device` differs from the `device` argument, an explicit `tensor.clone()` is not required, because `to()` already returns a copy of the original tensor when moving it to a different device.

I also observed memory bloat under the following conditions:

* Training a Whisper model with the `transformers` framework under the `ZeRO-0` and `ZeRO-1` configurations.
* The bloat occurred every time the model `state_dict` was cloned via `clone_tensors_for_torch_save()`.

After removing the unnecessary `clone()`, the problem appears to be resolved.

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
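The fix described above can be sketched as follows. This is an illustrative reconstruction, not the exact DeepSpeed implementation: the function name matches, but the body and signature here are assumptions. The key point is that `Tensor.to(device)` returns a new tensor whenever the source lives on a different device, so `.clone()` is only needed when the tensor is already on the target device (where `to()` is a no-op returning the same storage).

```python
import torch

def clone_tensors_for_torch_save(item, device=torch.device("cpu")):
    """Return a copy of `item` with all tensors detached and placed on `device`.

    Illustrative sketch: `to(device)` already copies when the tensor is on a
    different device, so calling `.clone()` on top of it would briefly hold
    two copies and bloat memory. `.clone()` is reserved for the case where the
    tensor is already on `device` and `to()` would alias the original storage.
    """
    if torch.is_tensor(item):
        if item.device == device:
            # to() would be a no-op here, so clone explicitly
            return item.detach().clone()
        # to() produces a fresh copy on the target device; no clone needed
        return item.detach().to(device)
    elif isinstance(item, list):
        return [clone_tensors_for_torch_save(v, device) for v in item]
    elif isinstance(item, tuple):
        return tuple(clone_tensors_for_torch_save(v, device) for v in item)
    elif isinstance(item, dict):
        return {k: clone_tensors_for_torch_save(v, device) for k, v in item.items()}
    else:
        return item
```

With this structure, saving a model's `state_dict` (typically a dict of CPU or GPU tensors) walks the container once and materializes exactly one copy of each tensor, rather than two when source and target devices differ.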