Stage3: Use new torch grad accumulation hooks API
* This commit addresses an issue reported in:
  #6718
* The existing code used a hook on the grad_acc autograd node to reduce each param's grads.
  Constructs such as param.data = replicated_tensor.data in allgather_params(..)
  are compiled into param.set(), which prevents the hook attached to the
  grad_acc node from ever being called.
* This is a known torch issue: pytorch/pytorch#139742.
* This caused accuracy issues, which could previously only be worked around by
  disabling torch compile whenever activation checkpointing is used.
* This commit provides a clean solution by replacing the hook on the grad_acc node
  with a hook registered through the new, robust API on the param itself:
  param.register_post_accumulate_grad_hook(..)
  (a minimal sketch contrasting the two registration styles follows below)
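
A minimal, self-contained sketch (not the DeepSpeed code) contrasting the two
registration styles, assuming PyTorch >= 2.1 where
register_post_accumulate_grad_hook is available:

import torch

param = torch.nn.Parameter(torch.randn(4))

# Old style: dig the AccumulateGrad node out of the autograd graph and hook it.
# A reference to the node must be kept alive, and, per the issues linked above,
# a param.data swap compiled into param.set() can leave this hook never firing.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
grad_acc.register_hook(lambda *notneeded: print("grad_acc node hook fired"))

# New style: the hook is registered on the parameter itself.
param.register_post_accumulate_grad_hook(
    lambda p: print("post-accumulate hook fired, grad shape:", tuple(p.grad.shape)))

(param * 2).sum().backward()  # in eager mode, both hooks fire after param.grad is accumulated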
deepcharm committed Nov 21, 2024
1 parent 83e4364 commit 5687088
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions deepspeed/runtime/zero/stage3.py
@@ -1156,7 +1156,6 @@ def overlapping_partition_gradients_reduce_epilogue(self):
 
     def create_reduce_and_remove_grad_hooks(self):
         print_rank_0(f'[Begin] Create gradient reduction hooks')
-        self.grad_accs = []
         self.leaf_parameters = defaultdict(list)
         for i, param_group in enumerate(self.fp16_groups):
             for param in param_group:
@@ -1169,15 +1168,13 @@ def create_reduce_and_remove_grad_hooks(self):
 
                     #print(f"After all gather {param.device}, {param.shape}")
                     def wrapper(param):
-                        param_tmp = param.expand_as(param)
-                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
                         @instrument_w_nvtx
                         def reduce_partition_and_remove_grads(*notneeded):
                             self.reduce_ready_partitions_and_remove_grads(param)
 
-                        self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
-                        self.grad_accs.append(grad_acc)
+                        self._grad_acc_hooks.append(
+                            param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads))
 
                     #print(f"param grad fn {param.expand_as(param).grad_fn}")
                     if z3_leaf_parameter(param):
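
For reference, a toy sketch of the per-parameter closure pattern the patched
method keeps (a wrapper capturing each param), with a stand-in print instead
of DeepSpeed's real partition-and-reduce logic:

import torch

model = torch.nn.Linear(8, 8)
hooks = []

def install_hook(param):
    # The closure captures this specific param, mirroring wrapper(param) above.
    def reduce_partition_and_remove_grads(*notneeded):
        print("reduce grad of", tuple(param.shape))  # stand-in for the real ZeRO reduction
    hooks.append(param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads))

for p in model.parameters():
    if p.requires_grad:
        install_hook(p)

model(torch.randn(2, 8)).sum().backward()  # the hook fires once per parameter

for h in hooks:  # handles can be removed later, as collecting them in self._grad_acc_hooks allows
    h.remove()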
