
Commit

Merge branch 'xccl-bak' into xccl-p2p
Chao1Han committed Oct 21, 2024
2 parents f85a845 + 86b66c3 commit 1e68c30
Showing 9 changed files with 123 additions and 186 deletions.
3 changes: 0 additions & 3 deletions caffe2/CMakeLists.txt
@@ -1376,9 +1376,6 @@ if(USE_DISTRIBUTED)
endif()
if(USE_XPU AND USE_C10D_XCCL)
target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
set_source_files_properties(
${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp
PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_ZE;CCL_ENABLE_SYCL")
endif()
if(USE_MPI AND USE_C10D_MPI)
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
7 changes: 4 additions & 3 deletions cmake/Modules/FindXCCL.cmake
@@ -6,9 +6,10 @@

include(FindPackageHandleStandardArgs)

set(XCCL_ROOT "")
if(DEFINED ENV{CCL_ROOT})
set(XCCL_ROOT $ENV{CCL_ROOT})
set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest")
if (NOT EXISTS "${XCCL_ROOT}")
message(STATUS "Default OneCCL not found, using current environment OneAPI")
set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest)
endif()

string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound)
13 changes: 5 additions & 8 deletions test/distributed/test_c10d_common.py
@@ -31,6 +31,7 @@
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
get_device_count,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -60,17 +61,13 @@
torch.backends.cuda.matmul.allow_tf32 = False


def gpus_for_rank(world_size):
def gpus_for_rank(world_size, backend):
"""Multigpu tests are designed to simulate the multi nodes with multi
GPUs on each node. Nccl backend requires equal #GPUs in each process.
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
device_count = (
torch.xpu.device_count()
if torch.xpu.is_available()
else torch.cuda.device_count()
)
device_count = get_device_count(backend)
visible_devices = list(range(device_count))
gpus_per_process = device_count // world_size
gpus_for_rank = []
@@ -833,7 +830,7 @@ def update_parameters(model):
def _gpu_model_with_ddp_comm_hook(
self, process_group, hook=None, gradient_as_bucket_view=False, state=None
):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
gpu_model = DistributedDataParallel(
ModuleForDdpCommHook().to(device_id),
device_ids=[device_id],
@@ -850,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook(
def _gpu_model_with_builtin_ddp_comm_hook(
self, process_group, hook=None, gradient_as_bucket_view=False
):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
gpu_model = DistributedDataParallel(
ModuleForDdpCommHook().to(device_id),
device_ids=[device_id],
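For context only (not part of this commit): a minimal, self-contained sketch of the even device split described in the gpus_for_rank docstring above, assuming the visible device count divides evenly by the world size. The helper name here is hypothetical.

```python
# Illustrative sketch: mirrors the even split that gpus_for_rank performs
# over the visible devices on a single node.
def split_devices(device_count: int, world_size: int) -> list[list[int]]:
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    return [
        visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]

# Example: 8 visible devices split across 4 ranks -> 2 devices per rank.
assert split_devices(8, 4) == [[0, 1], [2, 3], [4, 5], [6, 7]]
```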
21 changes: 11 additions & 10 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -51,8 +51,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
NCCL = 2,
UCC = 3,
MPI = 4,
CUSTOM = 5,
XCCL = 6,
XCCL = 5,
CUSTOM = 6,
};

static std::string backendTypeToString(const BackendType& type) {
@@ -131,6 +131,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return backendType_;
};

inline bool backendSupportsSequenceNumbers(BackendType backendType) {
if (backendType == BackendType::GLOO || backendType == BackendType::NCCL ||
backendType == BackendType::XCCL || backendType == BackendType::UCC)
return true;
return false;
}

virtual void startCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented startCoalescing so only execute for nccl
// backends
@@ -508,10 +515,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
virtual void setSequenceNumberForGroup() {
auto backendType = getBackendType();
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
if (backendSupportsSequenceNumbers(backendType)) {
getDefaultBackend()->setSequenceNumberForGroup();
} else {
TORCH_CHECK(
@@ -530,10 +534,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
auto backendType = getBackendType();

// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
if (backendSupportsSequenceNumbers(backendType)) {
return getDefaultBackend()->getSequenceNumberForGroup();
} else {
TORCH_CHECK(
131 changes: 49 additions & 82 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -1,40 +1,20 @@
#ifdef USE_C10D_XCCL

#include <comm/XPUGuard.h>
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
#include <exception>
#include <fstream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <tuple>
#include <unordered_set>
#include <utility>

#include <c10/core/DeviceType.h>
#include <c10/util/Optional.h>

namespace c10d {

namespace {

// wait nonblocking implement
AutoXcclGroup::AutoXcclGroup() {
ccl::group_start();
}

AutoXcclGroup::~AutoXcclGroup() noexcept(false) {
ccl::group_end();
}

std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
const std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
{ReduceOp::MIN, ccl::reduction::min},
{ReduceOp::MAX, ccl::reduction::max},
{ReduceOp::SUM, ccl::reduction::sum},
{ReduceOp::PRODUCT, ccl::reduction::prod},
};

std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kByte, ccl::datatype::uint8},
{at::kChar, ccl::datatype::int8},
{at::kInt, ccl::datatype::int32},
@@ -158,19 +138,15 @@ ccl::datatype getXcclDataType(at::ScalarType type) {

ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
try {
if (input.scalar_type() == at::kBool) {
if (reduceOp == ReduceOp::SUM) {
// For bool tensors, map sum to max, which both represent a bitwise or.
// This is to prevent overflow issues with sum, since we use uint8 to
// represent a bool (see xcclDatatypes mapping align with cuda).
return ccl::reduction::max;
}
if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
// Map sum to max for bool tensors to avoid overflow issues with sum.
return ccl::reduction::max;
}
return xcclOps.at(reduceOp);
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
ValueError,
"Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
"Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL");
}
}
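A quick illustration (not from the commit) of why getXcclReduceOp above maps ReduceOp::SUM to max for bool tensors: bools travel as uint8, so a plain sum can produce values other than 0/1 (and can wrap), whereas max reproduces the logical OR exactly.

```python
# Bool values are carried as uint8 on the wire. Summing the flags can yield
# values other than 0/1, while max is exactly a bitwise OR.
flags = [1, 1, 0, 1]          # "True, True, False, True" from four ranks

summed = sum(flags) % 256     # uint8 addition wraps modulo 256 -> 3 here
ored = max(flags)             # what the reduction should mean for bools

print(summed)  # 3  (no longer a valid bool value)
print(ored)    # 1  (logical OR of the inputs)
```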

@@ -196,29 +172,30 @@ bool complexViewAsRealAllowed(const ReduceOp reduceOp) {

} // namespace

static std::mutex xcclCommDevIdxMapMutex;
static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap;
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;

ProcessGroupXCCL::WorkXCCL::WorkXCCL(
at::Device& device,
int rank,
OpType opType,
uint64_t seq,
const char* profilingTitle,
const std::optional<std::vector<at::Tensor>>& inputs)
: Work(rank, opType, "profilingTitle", inputs),
: Work(rank, opType, profilingTitle, inputs),
device_(device),
workStartTime_(std::chrono::steady_clock::now()) {
unsigned char enable_timing = 0;
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enable_timing);
workStartTime_(std::chrono::steady_clock::now()),
seq_(seq) {
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
}

ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
: Work(w.rank_, w.opType_),
device_(w.device_),
xcclEndEvent_(w.xcclEndEvent_),
blockingWait_(w.blockingWait_),
workStartTime_(w.workStartTime_) {}
workStartTime_(w.workStartTime_),
seq_(w.seq_) {}

ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;

@@ -244,12 +221,9 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
currentTimepoint - workStartTime_);
if (timeElapsed >= timeout) {
std::string exceptionMsg = c10::str(
"Work ran for ",
timeElapsed.count(),
" milliseconds before timing out.");
"Work ran time out after ", timeElapsed.count(), " milliseconds.");
TORCH_CHECK(false, exceptionMsg)
}

std::this_thread::sleep_for(
std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}
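For reference, a hedged sketch (not part of the diff) of the poll-and-timeout pattern that synchronizeInternal implements above: check completion, raise once the elapsed time exceeds the timeout, otherwise sleep for the 10 ms busy-wait interval. Names here are illustrative, not the real API.

```python
import time

SYNC_BUSY_WAIT_SECONDS = 0.010  # mirrors kSynchronizeBusyWaitMillis = 10

def wait_until_done(is_completed, timeout_seconds: float) -> None:
    """Illustrative analogue of WorkXCCL::synchronizeInternal's wait loop."""
    start = time.monotonic()
    while not is_completed():
        elapsed = time.monotonic() - start
        if elapsed >= timeout_seconds:
            raise RuntimeError(
                f"Work timed out after {elapsed * 1000:.0f} milliseconds."
            )
        time.sleep(SYNC_BUSY_WAIT_SECONDS)

# Example: a work item that never completes times out after ~0.05 s.
# wait_until_done(lambda: False, timeout_seconds=0.05)
```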
@@ -275,20 +249,6 @@ ProcessGroupXCCL::ProcessGroupXCCL(
: Backend(rank, size), store_(store) {
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
init();

// Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI
// launchers.
if (!with_mpirun()) {
int local_rank = getXCCLEnvVar("LOCAL_RANK");
int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE");
if (local_rank == -1 || local_world_size == -1) {
local_rank = rank;
local_world_size = size;
}
setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none");
setXCCLEnvVar("CCL_LOCAL_RANK", local_rank);
setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size);
}
}

ProcessGroupXCCL::~ProcessGroupXCCL() = default;
@@ -297,10 +257,16 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
at::Device& device,
int rank,
OpType opType,
const char* profilingTitle,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs) {
auto r = c10::make_intrusive<ProcessGroupXCCL::WorkXCCL>(
device, rank, opType, std::optional<std::vector<at::Tensor>>(inputs));
device,
rank,
opType,
seqCollective_,
profilingTitle,
std::optional<std::vector<at::Tensor>>(inputs));
return r;
}

@@ -320,6 +286,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
usedDeviceIdxs_.insert(device.index());

{
// todo: why do we need mutex here?
std::lock_guard<std::mutex> lock(mutex_);
if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
return devXCCLCommMap_[deviceKey];
@@ -352,20 +319,23 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
}

c10::impl::VirtualGuardImpl impl(device.type());
c10::Stream stream = impl.getStream(device);
c10::Stream stream =
impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
sycl::queue& q = c10::xpu::XPUStream(stream).queue();

auto ctx = ccl::create_context(q.get_context());
ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));

auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs);
XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));
auto xccl_kvs = get_kvs(rank_, *store_);
auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
std::shared_ptr<xcclComm_t> XCCLComm =
std::make_shared<xcclComm_t>(std::move(comms[0]));

{
std::lock_guard<std::mutex> lock(mutex_);
inInitializationCommMap_.emplace(deviceKey, XCCLComm);
}
std::lock_guard<std::mutex> lock(mutex_);
devXCCLCommMap_.emplace(deviceKey, XCCLComm);
xcclStreamsMap_.emplace(deviceKey, std::move(stream));
xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());

for (const auto i : c10::irange(xcclActiveGroupCounter_)) {
(void)i;
@@ -510,8 +480,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
at::Tensor& input,
at::Tensor& output,
std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
Fn fn,
PreProcess pre,
PostProcess post,
@@ -563,27 +533,23 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
}
}

auto stream = xcclStreams_.at(key);
syncStream(device, xcclEvents_[key], stream);
auto stream = xcclStreamsMap_.at(key);
syncStream(device, xcclEventsMap_[key], stream);

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
work = initWork(device, rank_, opType);

work = initWork(device, rank_, opType, profilingTitle);
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

at::xpu::OptionalXPUGuard gpuGuard(device);

{
AutoXcclGroup xccl_group_guard;
for (const auto i : c10::irange(inputs.size())) {
c10::xpu::XPUCachingAllocator::recordStream(
inputs[i].storage().data_ptr(), stream);
fn(inputs[i], outputs[i], *comm, stream);
}
pre(stream, work);
for (const auto i : c10::irange(inputs.size())) {
c10::xpu::XPUCachingAllocator::recordStream(
inputs[i].storage().data_ptr(), stream);
fn(inputs[i], outputs[i], *comm, stream);
}
post(stream, work);

work->xcclEndEvent_->record(stream);

std::vector<c10::Stream> streams = {stream.unwrap()};
c10::MultiStreamGuard streamGuard(streams);
std::vector<at::Device> devices{device};
@@ -973,8 +939,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
at::xpu::XPUStream& stream) {
auto xcclDataType = getXcclDataType(input.scalar_type());
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
ccl::event ret_evt;
ret_evt = ccl::allreduce(
auto ccl_stream = ccl::create_stream(stream.queue());
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
(size_t)input.numel(),
@@ -984,7 +950,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
ccl::create_stream(stream.queue()));
return ret_evt;
},
OpType::ALLREDUCE);
OpType::ALLREDUCE,
"xccl:all_reduce");
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
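As a usage-level companion to the allreduce path above, a hedged sketch only: it assumes an XCCL-enabled build on Intel GPUs, that the backend registers under the name "xccl" as this branch intends, and that the process is launched with torchrun so RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set.

```python
import torch
import torch.distributed as dist

# Hedged sketch, not part of the commit: the "xccl" backend name and XPU
# availability are assumptions tied to this branch, not a stock build.
dist.init_process_group(backend="xccl")

rank = dist.get_rank()
device = torch.device(f"xpu:{rank % torch.xpu.device_count()}")

t = torch.ones(4, device=device)
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # dispatches to ProcessGroupXCCL::allreduce

print(f"rank {rank}: {t}")  # each element equals the world size after the sum
dist.destroy_process_group()
```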

