enable coalesce
Chao1Han committed Sep 10, 2024
1 parent 227e98d commit 86cd619
Showing 2 changed files with 160 additions and 3 deletions.
137 changes: 137 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -26,6 +26,22 @@
namespace c10d {

namespace {

// RAII guard pairing ccl::group_start() with ccl::group_end().
// TODO: nonblocking communicator support is not yet implemented.
AutoXcclGroup::AutoXcclGroup() {
  comm_ = nullptr;
  ccl::group_start();
}

AutoXcclGroup::AutoXcclGroup(std::shared_ptr<xcclComm_t> comm) {
  comm_ = std::move(comm);
  ccl::group_start();
}

AutoXcclGroup::~AutoXcclGroup() noexcept(false) {
  ccl::group_end();
}

std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
    {ReduceOp::MIN, ccl::reduction::min},
    {ReduceOp::MAX, ccl::reduction::max},
@@ -85,6 +101,34 @@ void check_xpu_single_tensor(const at::Tensor& tensor) {
  }
}

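// Checks that every tensor in the list is a dense, non-overlapping XPU tensor
// with the same dtype, on the same device; returns the total element count.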
int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
  if (tensors.empty()) {
    C10_THROW_ERROR(ValueError, "Tensor list must be nonempty");
  }

  const auto& first = tensors.front();

  int64_t total_numel = 0;
  for (const auto& t : tensors) {
    if (!t.is_xpu() || t.is_sparse()) {
      C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
    }
    if (t.scalar_type() != first.scalar_type()) {
      C10_THROW_ERROR(TypeError, "Tensors must have identical type");
    }
    if (!t.is_non_overlapping_and_dense()) {
      C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
    }
    TORCH_CHECK_WITH(
        ValueError,
        t.get_device() == first.get_device(),
        "Expected list of tensors on the same device");
    total_numel += t.numel();
  }

  return total_numel;
}

ccl::datatype getXcclDataType(at::ScalarType type) {
  auto it = xcclDatatypes.find(type);
  TORCH_CHECK_WITH(
@@ -132,6 +176,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
static std::mutex xcclCommDevIdxMapMutex;
static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap;

// Until send/recv are implemented, xcclActiveGroupCounter_ has no effect.
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;

ProcessGroupXCCL::WorkXCCL::WorkXCCL(
    at::Device& device,
    int rank,
Expand Down Expand Up @@ -240,6 +287,16 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
  return it->second;
}

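// Explicit group begin/end, mirroring the NCCL backend;
// xcclActiveGroupCounter_ tracks the nesting depth.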
void ProcessGroupXCCL::groupStart() {
  ccl::group_start();
  ++xcclActiveGroupCounter_;
}

void ProcessGroupXCCL::groupEnd() {
  ccl::group_end();
  --xcclActiveGroupCounter_;
}

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
    at::Tensor& input,
@@ -302,6 +359,53 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
      opType);
}

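// Issues one collective per input/output pair inside a single
// group_start()/group_end() region so the XCCL runtime can coalesce them into
// one submission. A single end-of-stream event and future cover the batch.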
template <typename Fn>
c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
    std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    Fn fn,
    OpType opType) {
  using traits = function_traits<Fn>;
  using attr_t = typename traits::template arg<2>::type;
  attr_t attr = ccl::create_operation_attr<attr_t>();

  auto device = inputs[0].device();
  const auto key = std::to_string(device.index());
  auto comm = getXCCLComm(key, device);

  auto stream = xcclStreams_.at(key);
  auto ccl_stream = ccl::create_stream(stream.queue());

  auto work = initWork(device, rank_, opType);

  // Copy rather than move: `outputs` may alias `inputs` (e.g. the in-place
  // allreduce_coalesced), and the tensors are still read in the loop below.
  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

  {
    AutoXcclGroup xccl_group_guard(comm);
    for (const auto i : c10::irange(inputs.size())) {
      c10::xpu::XPUCachingAllocator::recordStream(
          inputs[i].storage().data_ptr(), stream);
      fn(inputs[i], outputs[i], attr, *comm, ccl_stream);
    }
  }

  work->xcclEndEvent_->record(stream);

  std::vector<c10::Stream> streams = {stream.unwrap()};
  c10::MultiStreamGuard streamGuard(streams);
  std::vector<at::Device> devices{device};
  work->future_ = c10::make_intrusive<at::ivalue::Future>(
      c10::ListType::create(c10::TensorType::get()), devices);
  work->future_->markCompleted(at::IValue(*work->outputs_));

  return work;
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
@@ -334,6 +438,39 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
      OpType::ALLREDUCE);
}

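// In-place coalesced allreduce: every tensor in `tensors` is reduced across
// ranks, with all ops batched through collectiveCoalesced above.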
c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  check_xpu_tensors_same_device(tensors);
  TORCH_CHECK(
      !isFloat8Type(tensors.back().scalar_type()),
      "Float8 dtypes are not currently supported for XCCL reductions");

  return collectiveCoalesced(
      tensors,
      tensors,
      [&](at::Tensor& input,
          at::Tensor& output,
          ccl::allreduce_attr attr,
          xcclComm_t& comm,
          ccl::stream& stream) {
        auto xcclDataType = getXcclDataType(input.scalar_type());
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
        ccl::event ret_evt;
        ret_evt = ccl::allreduce(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
            xcclDataType,
            xcclReduceOp,
            comm,
            stream,
            attr);
        return ret_evt;
      },
      OpType::COALESCED);
}

} // namespace c10d

#endif // USE_C10D_XCCL
26 changes: 23 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -35,6 +35,15 @@ using xcclComm_t = ccl::communicator;
using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
constexpr const char* XCCL_BACKEND_NAME = "xccl";

namespace {
struct AutoXcclGroup {
  AutoXcclGroup();
  AutoXcclGroup(std::shared_ptr<xcclComm_t> comm);
  ~AutoXcclGroup() noexcept(false);
  std::shared_ptr<xcclComm_t> comm_;
};
} // namespace

class TORCH_API ProcessGroupXCCL : public Backend {
 public:
  class WorkXCCL : public Work {
@@ -126,16 +135,21 @@ class TORCH_API ProcessGroupXCCL : public Backend {
      PostProcess post,
      OpType opType);

  template <typename Fn>
  c10::intrusive_ptr<Work> collectiveCoalesced(
      std::vector<at::Tensor>& inputs,
      std::vector<at::Tensor>& outputs,
      Fn fn,
      OpType opType);

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
@@ -238,6 +252,10 @@ class TORCH_API ProcessGroupXCCL : public Backend {
    TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented");
  }

  void groupStart();

  void groupEnd();

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
@@ -259,6 +277,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
  std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
  c10::intrusive_ptr<Store> store_;
  std::mutex mutex_;
  static thread_local uint64_t xcclActiveGroupCounter_;
  uint64_t seqCollective_{0};
};
} // namespace c10d


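For context, a minimal sketch of how the coalesced path added here might be exercised. It is illustrative only: `pg` is assumed to be an already-initialized ProcessGroupXCCL, and `t0`/`t1` dense tensors on the same XPU device; none of these names are part of this commit.

// Hypothetical usage sketch -- assumes an initialized ProcessGroupXCCL `pg`
// and two dense XPU tensors t0, t1 on the same device.
std::vector<at::Tensor> tensors = {t0, t1};
c10d::AllreduceCoalescedOptions opts;
opts.reduceOp = c10d::ReduceOp::SUM;
// Both allreduces are issued inside one group_start()/group_end() region
// and complete under a single Work object.
auto work = pg->allreduce_coalesced(tensors, opts);
work->wait(); // each tensors[i] now holds the elementwise sum across ranks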