support reduce_scatter
Chao1Han committed Sep 13, 2024
1 parent 0ad5677 commit 009e334
Showing 2 changed files with 206 additions and 10 deletions.
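
This commit replaces the reduce_scatter stub in ProcessGroupXCCL (previously a TORCH_CHECK(false, ...) in the header, see the second file below) with a working implementation, adding the coalescing machinery it needs along the way. As a minimal sketch of how the new entry point might be driven (the process-group handle, world size, and tensor shapes here are illustrative assumptions, not code from this commit):

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>

// Hypothetical caller; `pg` must already be constructed with a store/rank/size.
void demo_reduce_scatter(c10d::ProcessGroupXCCL& pg, int world_size) {
  // One output tensor (this rank's reduced shard) and, per output, one input
  // tensor contributed for every rank in the group.
  std::vector<at::Tensor> outputs = {
      at::empty({1024}, at::device(at::kXPU).dtype(at::kFloat))};
  std::vector<std::vector<at::Tensor>> inputs(1);
  for (int r = 0; r < world_size; ++r) {
    inputs[0].push_back(
        at::ones({1024}, at::device(at::kXPU).dtype(at::kFloat)));
  }
  c10d::ReduceScatterOptions opts; // reduceOp defaults to ReduceOp::SUM
  auto work = pg.reduce_scatter(outputs, inputs, opts);
  work->wait(); // block until the collective completes
}
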
198 changes: 191 additions & 7 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -272,6 +272,13 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
return true;
}

constexpr const char* MULTI_DEVICE_ERROR_MSG =
"Expecting one tensor only but got multiple. You are probably using multiple "
"devices under one thread. The support for such usage has been deprecated. "
"For details, please refer to "
"https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. "
"ProcessGroupXCCL continues supporting multi-process and multi-thread modes.";

ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
@@ -376,7 +383,50 @@ void ProcessGroupXCCL::groupEnd() {
--xcclActiveGroupCounter_;
}

// align with good design single-device style, input_t and output_t due to
// TODO: wait for P2P to be enabled
static constexpr int CoalActive = 0x01, CoalColl = 0x02;
void ProcessGroupXCCL::startCoalescing() {
coalescedDevice_.set_index(-1);
coalescedComm_ = nullptr;
coalescing_state_ |= CoalActive;
groupStart();
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
if (coalescedComm_ == nullptr) {
// There is no actual work being coalesced, return here
groupEnd();
coalescing_state_ = 0;
return nullptr;
}
TORCH_CHECK(
coalescedDevice_.index() >= 0,
"Somthing went wrong. Did you call end_coalescing before start_coalescing?");

auto comm = coalescedComm_;
auto device = coalescedDevice_;

const auto key = std::to_string(device.index());
auto stream = xcclStreams_.at(key);

auto work = initWork(device, rank_, optype);
work->blockingWait_ = blockingWait_;

groupEnd();

work->xcclEndEvent_->record(stream);

coalescing_state_ = 0;
coalescedComm_ = nullptr;
return work;
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing() {
// Default OpType to COALESCED if not specified
return endCoalescing(OpType::COALESCED);
}
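
The coalescing support above mirrors the NCCL backend's pattern: startCoalescing() opens an XCCL group and sets CoalActive, each collective issued while the flag is set latches its communicator and device, and endCoalescing() closes the group and returns a single Work that fences the whole batch. An illustrative calling sequence follows (pg, t1, and t2 are assumed to exist; this is a sketch, not code from the commit):

// Driving the coalescing state machine defined above.
pg.startCoalescing();            // coalescing_state_ |= CoalActive; groupStart()
std::vector<at::Tensor> v1{t1};
std::vector<at::Tensor> v2{t2};
pg.allreduce(v1);                // latches coalescedComm_ / coalescedDevice_
pg.allreduce(v2);                // must use the same device and communicator
auto work = pg.endCoalescing(); // groupEnd(); one Work covers both collectives
work->wait();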

// align with the single-device style; input_t and output_t are separate
// because allgatherv needs a vector output
template <
typename Fn,
@@ -399,6 +449,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
const auto key = std::to_string(device.index());
auto comm = getXCCLComm(key, device);

if (coalescing_state_ & CoalActive) {
coalescing_state_ |= CoalColl;
if (coalescedDevice_.index() < 0) {
coalescedDevice_ = device;
} else {
TORCH_CHECK(
coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
}
if (coalescedComm_ == nullptr) {
coalescedComm_ = comm;
} else {
TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
}
}

auto stream = xcclStreams_.at(key);

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
@@ -502,6 +567,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
const auto key = std::to_string(device.index());
auto comm = getXCCLComm(key, device);

if (coalescing_state_ & CoalActive) {
coalescing_state_ |= CoalColl;
if (coalescedDevice_.index() < 0) {
coalescedDevice_ = device;
} else {
TORCH_CHECK(
coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
}
if (coalescedComm_ == nullptr) {
coalescedComm_ = comm;
} else {
TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
}
}

auto stream = xcclStreams_.at(key);

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
@@ -535,8 +615,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
TORCH_CHECK(
tensors.size() == 1, "Expecting one tensor only but got multiple");
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
auto tensor = tensors.back();
check_xpu_single_tensor(tensor);
return collective(
@@ -600,8 +679,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
TORCH_CHECK(
tensors.size() == 1, "Expecting one tensor only but got multiple");
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
auto tensor = tensors.back();
if (tensor.is_complex()) {
tensor = at::view_as_real(tensor);
@@ -633,12 +711,46 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
OpType::BROADCAST);
}

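// _reduce_oop: out-of-place reduce, i.e. the result may land in a tensor other
// than the input. reduce_scatter below falls back to a coalesced series of
// these when input sizes differ across ranks.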
c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
const ReduceOptions& opts) {
if (outputTensor.numel() != inputTensor.numel()) {
C10_THROW_ERROR(
ValueError,
"Tensor input and output of _reduce_oop must have the same number of elements ");
}
return collective(
inputTensor,
outputTensor,
[&](at::Tensor& input,
at::Tensor& output,
ccl::reduce_attr attr,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type());
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
ccl::event ret_evt;
ret_evt = ccl::reduce(
input.data_ptr(),
output.data_ptr(),
(size_t)input.numel(),
xcclDataType,
xcclReduceOp,
root,
comm,
ccl::create_stream(stream.queue()));
return ret_evt;
},
OpType::REDUCE);
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
TORCH_CHECK(
inputTensors.size() == 1, "Expecting one tensor only but got multiple");
TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto inputTensor = inputTensors.back();
check_xpu_single_tensor(inputTensor);
@@ -802,6 +914,78 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
OpType::COALESCED);
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto outputTensor = outputTensors.back();
check_xpu_single_tensor(outputTensor);
// @lint-ignore CLANGTIDY
auto inputTensors_ = inputTensors.back();
TORCH_CHECK(
!isFloat8Type(outputTensor.scalar_type()),
"Float8 dtypes are not currently supported for XCCL reductions");

bool same_size = check_same_size(inputTensors_);
if (same_size) {
// Flatten a vector of tensors into a single, stacked tensor.
at::Tensor inputFlattened = newLikeFlat(inputTensors_);
return collective(
inputFlattened,
outputTensor,
[&](at::Tensor& input,
at::Tensor& output,
ccl::reduce_attr attr,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
ccl::event ret_evt;
ret_evt = ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
(size_t)output.numel(),
xcclDataType,
xcclReduceOp,
comm,
ccl::create_stream(stream.queue()));
return ret_evt;
},
[&](at::xpu::XPUStream& Stream,
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {
// Copy the input tensors to the flattened inputs.
c10::StreamGuard guard(Stream);
for (const auto j : c10::irange(inputTensors_.size())) {
c10::xpu::XPUCachingAllocator::recordStream(
inputTensors_[j].storage().data_ptr(), Stream);
inputFlattened[j].copy_(inputTensors_[j], true);
}
},
[&](at::xpu::XPUStream&,
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
OpType::REDUCE_SCATTER);
} else {
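// Input sizes differ across ranks, so this cannot be expressed as a single
// ccl::reduce_scatter. Issue one out-of-place reduce per rank (reduce i is
// rooted at rank i) and batch them in one coalescing sequence.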
const auto num_reduces = inputTensors_.size();
startCoalescing();
for (const int i : c10::irange(num_reduces)) {
auto& input = inputTensors_[i];
auto& output = (i == rank_) ? outputTensor : input;
auto reduceOpts = ReduceOptions{
opts.reduceOp,
static_cast<int64_t>(i),
static_cast<int64_t>(0),
opts.timeout};
_reduce_oop(output, input, reduceOpts);
}
auto work = endCoalescing(OpType::REDUCE_SCATTER);
return work;
}
}

} // namespace c10d

#endif // USE_C10D_XCCL
18 changes: 15 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -149,6 +149,12 @@ class TORCH_API ProcessGroupXCCL : public Backend {
return std::string(XCCL_BACKEND_NAME);
}

void startCoalescing() override;

c10::intrusive_ptr<Work> endCoalescing() override;

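// Overload that lets composite collectives tag the coalesced Work with their
// own op type; reduce_scatter passes OpType::REDUCE_SCATTER.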
c10::intrusive_ptr<Work> endCoalescing(OpType optype);

std::shared_ptr<xcclComm_t> getXCCLComm(
const std::string& deviceKey,
at::Device& device);
@@ -221,6 +227,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;

c10::intrusive_ptr<Work> _reduce_oop(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
const ReduceOptions& opts = ReduceOptions());

c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
@@ -246,9 +257,7 @@
c10::intrusive_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented");
}
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

c10::intrusive_ptr<Work> _reduce_scatter_base(
at::Tensor& outputTensor,
@@ -327,6 +336,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
c10::intrusive_ptr<Store> store_;
std::mutex mutex_;
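// Coalescing bookkeeping: a CoalActive/CoalColl bitmask plus the device and
// communicator shared by every collective in the current coalescing sequence.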
int coalescing_state_ = 0;
at::Device coalescedDevice_ = at::Device("xpu");
std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr;
bool blockingWait_ = false;
static thread_local uint64_t xcclActiveGroupCounter_;
};
