[Bug] No backend type associated with device type cpu #2477
Comments
Hi! Thanks for your contribution, great first issue! |
I'm hitting the same bug. |
I also ran into the same bug. |
Hello there, any update on this issue? |
Hi all, thanks for reporting this issue. I am currently looking into what can be done to solve it. The |
Hi @SkafteNicki,

import pickle

import torch
import torch.distributed as dist

# Note: get_world_size() is a helper defined in the same torchvision reference utils module.
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: optional process group to gather over

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object to a byte tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size, group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receive tensors from all ranks; we pad the local tensor because
    # torch all_gather does not support gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor, group)

    # deserialize each rank's payload, trimming the padding
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

The function comes from the torchvision detection training reference: https://github.com/pytorch/vision/blob/main/references/detection/utils.py
When I added the following code to my script, I found that multi-GPU evaluation works:

coco_evaluator = MeanAveragePrecision(iou_type=args.iou_type, backend=args.backend)
coco_evaluator.dist_sync_fn = utils.all_gather if args.evaluate_on_cpu else None |
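For reference, the same override can also be passed when the metric is constructed, through the dist_sync_fn argument of the torchmetrics Metric base class. A sketch, assuming the all_gather helper above is importable from a utils module (module name and iou_type are illustrative):

from torchmetrics.detection import MeanAveragePrecision

import utils  # hypothetical module: the torchvision-style reference utils containing all_gather

coco_evaluator = MeanAveragePrecision(
    iou_type="bbox",
    dist_sync_fn=utils.all_gather,  # custom gather that also handles CPU-resident states
)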
@SkafteNicki let's add a first test for this multi-GPU case so we can reproduce it and prevent regressions in the future? |
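A rough sketch of what such a first multi-process test could look like (not the test that was actually added to the repository). It uses the gloo backend so it can run on CPU-only machines; reproducing the reported crash itself still needs nccl and multiple GPUs. The metric class and port below are illustrative:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchmetrics import Metric


class DummyCatMetric(Metric):
    """Toy metric with a list state reduced by concatenation and computed on CPU."""

    def __init__(self, **kwargs):
        super().__init__(compute_on_cpu=True, **kwargs)
        self.add_state("values", default=[], dist_reduce_fx="cat")

    def update(self, x: torch.Tensor) -> None:
        self.values.append(x)

    def compute(self) -> torch.Tensor:
        vals = self.values
        return vals if isinstance(vals, torch.Tensor) else torch.cat(vals)


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    metric = DummyCatMetric()
    metric.update(torch.arange(3, dtype=torch.float32) + rank)
    result = metric.compute()  # triggers cross-rank synchronization of the "cat" state

    assert result.numel() == 3 * world_size, result.shape
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)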
I've looked into the problem and found out that the main reason for this error is that the default distributed backend for Lightning is nccl. If |
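The comment above is truncated, but one workaround consistent with that diagnosis, sketched here under the assumption that Lightning's DDPStrategy is in use (not necessarily what the comment went on to propose), is to select a process-group backend that can gather CPU tensors:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# gloo can perform all_gather on CPU tensors, unlike nccl; note that routing all
# collectives through gloo can slow down multi-GPU training considerably.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPStrategy(process_group_backend="gloo"),
)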
I have also encountered the same issue when trying to use
But this did not resolve the problem. The root cause seems to be that a duplicated instance of the metric class
This seems to be where the |
This is a copy-paste of my reply to this issue: Lightning-AI/pytorch-lightning#18803
I was having the same error message when using
For me, it worked after adding the following three kwargs when the metric was initialized:
All three arguments were needed to solve it in my case. My code now looks like:
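(The code block from this comment was not captured with the page. The sketch below is a hypothetical reconstruction using three keyword arguments accepted by the torchmetrics Metric base class; it is not the commenter's original snippet, and the metric class is assumed.)

# Hypothetical reconstruction -- the original snippet is not recoverable here.
from torchmetrics.detection import MeanAveragePrecision

metric = MeanAveragePrecision(
    iou_type="bbox",
    compute_on_cpu=False,     # keep states on the accelerator so the nccl group can gather them
    sync_on_compute=True,     # synchronize states across ranks when compute() is called
    dist_sync_on_step=False,  # do not force a synchronization on every update step
)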
|
Thank you @Holer90 for sharing. Unfortunately your solution doesn't seem to work for me. It'd be useful to know a bit more about your configuration. In particular, what's your Trainer flags configuration, e.g. |
🐛 Bug
Metrics (both predefined in the library and custom implementations) that use concatenation (dist_reduce_fx="cat") together with CPU computation (compute_on_cpu=True) raise an error when training on multiple GPUs (ddp). The concrete error is: RuntimeError: No backend type associated with device type cpu

To Reproduce
Code sample:
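(The original code sample was not preserved in this capture. Below is a minimal sketch of the setup described above: a custom metric with a dist_reduce_fx="cat" state, compute_on_cpu=True, and the ddp strategy on two GPUs. Class names and sizes are illustrative only.)

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Metric


class CatStateMetric(Metric):
    """Toy metric with a list state that is reduced by concatenation."""

    def __init__(self, **kwargs):
        super().__init__(compute_on_cpu=True, **kwargs)  # states are moved to CPU after each update
        self.add_state("preds", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor) -> None:
        self.preds.append(preds.detach())

    def compute(self) -> torch.Tensor:
        preds = self.preds
        return preds if isinstance(preds, torch.Tensor) else torch.cat(preds)


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self.metric = CatStateMetric()

    def training_step(self, batch, batch_idx):
        x = batch[0]
        preds = self.layer(x).squeeze(-1)
        self.metric.update(preds)
        return preds.mean()

    def on_train_epoch_end(self):
        # With strategy="ddp" on multiple GPUs this is where the error appears:
        # the CPU-resident "cat" state is gathered through the nccl process group,
        # raising "RuntimeError: No backend type associated with device type cpu".
        self.metric.compute()
        self.metric.reset()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(256, 32))
    trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices=2, strategy="ddp")
    trainer.fit(LitModel(), DataLoader(dataset, batch_size=32))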
Stacktrace
Expected behavior
The metric is computed properly by merging the lists from the different processes in multi-GPU training scenarios.
Environment
pip
Additional context
Related bug in PyTorch Lightning
Lightning-AI/pytorch-lightning#18803