Skip to content

Commit

Permalink
Support broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Sep 11, 2024
1 parent 4342c58 commit 22e8004
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
36 changes: 36 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,42 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
OpType::COALESCED);
}

// Runs a ccl::broadcast on the single tensor in `tensors`.
//
// tensors: must hold exactly one tensor (enforced with TORCH_CHECK). A complex
//          tensor is replaced by its at::view_as_real view before the call.
// opts:    opts.rootRank + opts.rootTensor is passed to ccl::broadcast as the
//          root argument.
// Returns the Work handle produced by collective() for OpType::BROADCAST.
c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  TORCH_CHECK(
      tensors.size() == 1, "Expecting one tensor only but got multiple");
  auto tensor = tensors.back();
  if (tensor.is_complex()) {
    // NOTE(review): presumably complex dtypes have no XCCL datatype mapping,
    // so the real-valued view is broadcast instead — confirm against
    // getXcclDataType.
    tensor = at::view_as_real(tensor);
  }
  check_xpu_single_tensor(tensor);

  const auto root = opts.rootRank + opts.rootTensor;

  // The same tensor is handed to collective() as both input and output; the
  // callback only passes `input`'s buffer to ccl::broadcast.
  return collective(
      tensor,
      tensor,
      // Capture `root` by value so the callback holds no reference into this
      // stack frame.
      [root](
          at::Tensor& input,
          at::Tensor& /*output*/,
          ccl::broadcast_attr attr,
          xcclComm_t& comm,
          ccl::stream& stream) {
        const auto xcclDataType = getXcclDataType(input.scalar_type());
        return ccl::broadcast(
            input.data_ptr(),
            static_cast<size_t>(input.numel()),
            xcclDataType,
            root,
            comm,
            stream,
            attr);
      },
      OpType::BROADCAST);
}

} // namespace c10d

#endif // USE_C10D_XCCL
4 changes: 1 addition & 3 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {

c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override {
TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented");
}
const BroadcastOptions& opts = BroadcastOptions()) override;

c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
Expand Down

0 comments on commit 22e8004

Please sign in to comment.