
Commit: update

Chao1Han committed Dec 2, 2024
1 parent fc3f2c7 commit 0560c8b
Showing 2 changed files with 3 additions and 8 deletions.
1 change: 0 additions & 1 deletion torch/_C/_distributed_c10d.pyi
@@ -687,4 +687,3 @@ class ProcessGroupXCCL(Backend):
         rank: int,
         size: int,
     ): ...
-
10 changes: 3 additions & 7 deletions torch/distributed/distributed_c10d.py
@@ -354,7 +354,7 @@ def register_backend(
                 "`cuda`. Please specify it via the `devices` argument of "
                 "`register_backend`."
             )
-            Backend.backend_capability[name.lower()] = ["cpu", "cuda", "xpu"]
+            Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
         elif isinstance(devices, str):
             # Single device string specified. Simply convert to list.
             Backend.backend_capability[name.lower()] = [devices]
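For context, the capability list edited above is only a fallback: it is what register_backend assumes when a third-party backend is registered without a devices argument. A minimal sketch of sidestepping that fallback by declaring the supported device explicitly; the backend name and constructor below are hypothetical and are not part of this commit:

    import torch.distributed as dist

    def _create_my_pg(store, rank, size, timeout):
        # Hypothetical constructor that would return a Backend implementation.
        raise NotImplementedError

    # Passing devices= records the capability directly instead of falling back
    # to the assumed ["cpu", "cuda"] list and emitting the warning shown above.
    dist.Backend.register_backend("my_xccl", _create_my_pg, devices=["xpu"])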
@@ -1910,10 +1910,6 @@ def _new_process_group_helper(
         elif backend_str == Backend.XCCL:
             if not is_xccl_available():
                 raise RuntimeError("Distributed package doesn't have XCCL built in")
-            if backend_options is not None:
-                assert isinstance(
-                    backend_options, ProcessGroupXCCL.Options
-                ), "Expected backend_options argument to be of type ProcessGroupXCCL.Options"
             backend_class = ProcessGroupXCCL(
                 backend_prefix_store, group_rank, group_size
             )
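With the backend_options check removed, ProcessGroupXCCL is constructed from just the prefix store, group rank, and group size, matching the stub updated in torch/_C/_distributed_c10d.pyi above. A minimal usage sketch, assuming an XPU build where is_xccl_available() returns True and the usual rendezvous environment variables are set by a launcher such as torchrun:

    import os
    import torch.distributed as dist

    # "xccl" resolves to Backend.XCCL and reaches the branch shown above.
    dist.init_process_group(
        backend="xccl",
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )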
@@ -2726,7 +2722,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
             return _IllegalWork()
         else:
             return None
 
     work = group.allreduce([tensor], opts)
 
     if async_op:
@@ -4091,7 +4087,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
             return _IllegalWork()
         else:
             return None
 
     work = group._reduce_scatter_base(output, input, opts)
 
     if async_op:
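Both of the last two hunks sit on the same async_op pattern: when async_op=True the collective returns a work handle that the caller waits on, and when it is False the call blocks and returns None. A small sketch of that calling convention, assuming a process group has already been initialized as in the sketch above and that the tensor lives on a device the backend supports:

    import torch
    import torch.distributed as dist

    tensor = torch.ones(4, device="xpu")  # device assumed to match the backend
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()  # block until the allreduce issued above completes
    # With async_op=False (the default), all_reduce waits internally and returns None.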
