added block -- causes infinite hang
josephdviviano committed Dec 13, 2024
1 parent b133740 commit ac9afe1
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions tutorials/examples/train_hypergrid.py
@@ -14,6 +14,7 @@
 from argparse import ArgumentParser
 from math import ceil
 from typing import List, Any, Union, Optional, Callable
+import datetime
 import os
 import pickle
 import signal
@@ -92,6 +93,7 @@ def initialize_distributed_compute(dist_backend: str = "ccl"):
         init_method="env://",
         world_size=int(os.environ.get("WORLD_SIZE")),
         rank=int(os.environ.get("RANK")),
+        timeout=datetime.timedelta(minutes=5),
     )

     my_rank = dist.get_rank()  # Global!
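
A note on the timeout added above: `dist.init_process_group` accepts a `timeout` (a `datetime.timedelta`) after which a collective that never completes raises an error instead of blocking forever, so the five-minute window turns a silent deadlock into a visible failure. A minimal sketch of the call in isolation, assuming a gloo backend and placeholder rendezvous values (neither taken from the tutorial):

# Hedged sketch: initialize with a collective timeout so deadlocks surface as errors.
import datetime
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # placeholder rendezvous address
os.environ.setdefault("MASTER_PORT", "29500")      # placeholder rendezvous port

dist.init_process_group(
    backend="gloo",  # assumption; the tutorial defaults to "ccl"
    init_method="env://",
    world_size=int(os.environ.get("WORLD_SIZE", "1")),
    rank=int(os.environ.get("RANK", "0")),
    timeout=datetime.timedelta(minutes=5),  # stuck collectives error out after 5 minutes
)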
@@ -226,11 +228,17 @@ def gather_distributed_data(

     # First gather sizes to allocate correct buffer sizes.
     local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device)
-    size_list = [
-        torch.tensor([0], device=local_tensor.device) for _ in range(world_size)
-    ]
+    if rank == 0:
+        size_list = [
+            torch.tensor([0], device=local_tensor.device) for _ in range(world_size)
+        ]
+    else:
+        size_list = None
     print("+ all gather of local_size={} to size_list".format(local_size))
-    dist.all_gather(size_list, local_size)
+    dist.gather(local_size, gather_list=size_list, dst=0)
+    # 44 [0] Process 0: Caught error: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.[0]
+    dist.barrier()  # Add synchronization
+

     # Pad local tensor to maximum size.
     print("+ padding local tensor")
@@ -245,11 +253,15 @@

     # Gather all tensors.
     print("+ gathering all tensors from world_size={}".format(world_size))
-    tensor_list = [
-        torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device)
-        for _ in range(world_size)
-    ]
-    dist.all_gather(tensor_list, local_tensor)
+    if rank == 0:
+        tensor_list = [
+            torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device)
+            for _ in range(world_size)
+        ]
+    else:
+        tensor_list = None
+    dist.gather(local_tensor, gather_list=tensor_list, dst=0)
+    dist.barrier()  # Add synchronization

     # Trim padding and deserialize if necessary.
     result = []
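For reference, the gather pattern this commit switches to requires every rank to enter `dist.gather`, with only the destination rank supplying `gather_list`; if any rank skips the call or enters a different collective (say, the `dist.barrier` above) while the others are still in the gather, the group deadlocks, which is one way to produce the infinite hang named in the commit title. Below is a minimal, self-contained sketch of the pattern; the gloo backend, two-process world, and the `worker` helper are illustrative assumptions, not code from `train_hypergrid.py`:

# Hypothetical standalone sketch of the rank-0 gather pattern from the diff.
import datetime
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # placeholder rendezvous address
    os.environ["MASTER_PORT"] = "29500"      # placeholder rendezvous port
    dist.init_process_group(
        backend="gloo",  # assumption; the tutorial defaults to "ccl"
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(minutes=5),  # deadlocks fail after 5 minutes
    )

    # Each rank contributes one value; only rank 0 allocates receive buffers.
    local_size = torch.tensor([rank + 1])
    if rank == 0:
        size_list = [torch.zeros(1, dtype=local_size.dtype) for _ in range(world_size)]
    else:
        size_list = None  # non-destination ranks must pass gather_list=None

    # Every rank must make this call, or the rest block until the timeout.
    dist.gather(local_size, gather_list=size_list, dst=0)

    if rank == 0:
        print("gathered sizes:", [int(t.item()) for t in size_list])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

For variable-size payloads like the pickled objects this function handles, `dist.gather_object` performs the size exchange and serialization internally and may be the simpler route.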
