diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index a54c847..59ab5b0 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -14,6 +14,7 @@ from argparse import ArgumentParser from math import ceil from typing import List, Any, Union, Optional, Callable +import datetime import os import pickle import signal @@ -92,6 +93,7 @@ def initialize_distributed_compute(dist_backend: str = "ccl"): init_method="env://", world_size=int(os.environ.get("WORLD_SIZE")), rank=int(os.environ.get("RANK")), + timeout=datetime.timedelta(minutes=5), ) my_rank = dist.get_rank() # Global! @@ -226,11 +228,17 @@ def gather_distributed_data( # First gather sizes to allocate correct buffer sizes. local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device) - size_list = [ - torch.tensor([0], device=local_tensor.device) for _ in range(world_size) - ] + if rank == 0: + size_list = [ + torch.tensor([0], device=local_tensor.device) for _ in range(world_size) + ] + else: + size_list = None print("+ all gather of local_size={} to size_list".format(local_size)) - dist.all_gather(size_list, local_size) + dist.gather(local_size, gather_list=size_list, dst=0) + # 44 [0] Process 0: Caught error: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.[0] + dist.barrier() # Add synchronization + # Pad local tensor to maximum size. print("+ padding local tensor") @@ -245,11 +253,15 @@ def gather_distributed_data( # Gather all tensors. print("+ gathering all tensors from world_size={}".format(world_size)) - tensor_list = [ - torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device) - for _ in range(world_size) - ] - dist.all_gather(tensor_list, local_tensor) + if rank == 0: + tensor_list = [ + torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device) + for _ in range(world_size) + ] + else: + tensor_list = None + dist.gather(local_tensor, gather_list=tensor_list, dst=0) + dist.barrier() # Add synchronization # Trim padding and deserialize if necessary. result = []