From 0560c8babdaec0561e9cdb935802ce15d3e784a2 Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 2 Dec 2024 05:22:00 +0000 Subject: [PATCH] update --- torch/_C/_distributed_c10d.pyi | 1 - torch/distributed/distributed_c10d.py | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index ab59afc88b610b..df065a68b8ef8b 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -687,4 +687,3 @@ class ProcessGroupXCCL(Backend): rank: int, size: int, ): ... - \ No newline at end of file diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 8d18edee42059e..18d1f418180f83 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -354,7 +354,7 @@ def register_backend( "`cuda`. Please specify it via the `devices` argument of " "`register_backend`." ) - Backend.backend_capability[name.lower()] = ["cpu", "cuda", "xpu"] + Backend.backend_capability[name.lower()] = ["cpu", "cuda"] elif isinstance(devices, str): # Single device string specified. Simply convert to list. Backend.backend_capability[name.lower()] = [devices] @@ -1910,10 +1910,6 @@ def _new_process_group_helper( elif backend_str == Backend.XCCL: if not is_xccl_available(): raise RuntimeError("Distributed package doesn't have XCCL built in") - if backend_options is not None: - assert isinstance( - backend_options, ProcessGroupXCCL.Options - ), "Expected backend_options argument to be of type ProcessGroupXCCL.Options" backend_class = ProcessGroupXCCL( backend_prefix_store, group_rank, group_size ) @@ -2726,7 +2722,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): return _IllegalWork() else: return None - + work = group.allreduce([tensor], opts) if async_op: @@ -4091,7 +4087,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F return _IllegalWork() else: return None - + work = group._reduce_scatter_base(output, input, opts) if async_op: