code refine
zhangxiaoli73 committed Oct 13, 2024
1 parent fd44abe commit d12b922
Showing 4 changed files with 33 additions and 111 deletions.
9 changes: 5 additions & 4 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -180,12 +180,13 @@ ncclRedOpRAII getNcclReduceOp(
NCCL_MINOR));
break;
case ReduceOp::BAND:
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL");
break;
case ReduceOp::BOR:
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL");
break;
case ReduceOp::BXOR:
C10_THROW_ERROR(
ValueError,
"Cannot use ReduceOp." + reduce_op_to_string(reduceOp) +
" with NCCL");
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL");
break;
default:
C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
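
For context, this hunk toggles between one consolidated error message built from reduce_op_to_string (renamed reduceOpToString in Utils.hpp below) and per-op string literals. A standalone sketch of both styles; the enum and helper names here are illustrative stand-ins, not the c10d definitions:

#include <stdexcept>
#include <string>

enum class Op { Sum, Band, Bor, Bxor }; // illustrative stand-in for c10d::ReduceOp

std::string opToString(Op op) {
  switch (op) {
    case Op::Sum:  return "SUM";
    case Op::Band: return "BAND";
    case Op::Bor:  return "BOR";
    case Op::Bxor: return "BXOR";
  }
  return "UNKNOWN";
}

void checkSupported(Op op) {
  switch (op) {
    // Per-op literals, as in the refined code above:
    case Op::Band: throw std::invalid_argument("Cannot use ReduceOp.BAND with NCCL");
    case Op::Bor:  throw std::invalid_argument("Cannot use ReduceOp.BOR with NCCL");
    case Op::Bxor: throw std::invalid_argument("Cannot use ReduceOp.BXOR with NCCL");
    // The removed variant built the same message once from the helper:
    //   throw std::invalid_argument("Cannot use ReduceOp." + opToString(op) + " with NCCL");
    default: break; // SUM and the other reductions are supported
  }
}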
72 changes: 19 additions & 53 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -16,14 +16,14 @@
namespace c10d {

namespace {
std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
const std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
{ReduceOp::MIN, ccl::reduction::min},
{ReduceOp::MAX, ccl::reduction::max},
{ReduceOp::SUM, ccl::reduction::sum},
{ReduceOp::PRODUCT, ccl::reduction::prod},
};

std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kByte, ccl::datatype::uint8},
{at::kChar, ccl::datatype::int8},
{at::kInt, ccl::datatype::int32},
@@ -35,7 +35,7 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kBool, ccl::datatype::uint8},
};
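
A small standalone illustration of why these lookup tables can be declared const: read-only access goes through std::map::at, and an unsupported key surfaces as std::out_of_range, which the callers below convert into a ValueError. Names here are generic, not the XCCL ones:

#include <map>
#include <stdexcept>
#include <string>

const std::map<std::string, int> kTable = {{"sum", 0}, {"prod", 1}};

int lookup(const std::string& key) {
  try {
    return kTable.at(key);             // const-qualified lookup, no insertion
  } catch (const std::out_of_range&) { // unknown key -> report a clear error
    throw std::invalid_argument("unsupported key: " + key);
  }
}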

void check_xpu_single_tensor(const at::Tensor& tensor) {
void checkXPUTensor(at::Tensor& tensor) {
if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
C10_THROW_ERROR(
ValueError, "Tensors must be XPU and dense and non-complex");
@@ -62,19 +62,15 @@ ccl::datatype getXcclDataType(

ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
try {
if (input.scalar_type() == at::kBool) {
if (reduceOp == ReduceOp::SUM) {
// For bool tensors, map sum to max, which both represent a bitwise
// or. This is to prevent overflow issues with sum, since we use uint8
// to represent a bool (see xcclDatatypes mapping align with cuda).
return ccl::reduction::max;
}
if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
// Map sum to max for bool tensors to avoid overflow issues with sum.
return ccl::reduction::max;
}
return xcclOps.at(reduceOp);
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
ValueError,
"Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
"Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL");
}
}
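
To illustrate the comment above: bool tensors are transported as uint8, so a SUM across many ranks can wrap around, while MAX behaves like a logical OR. A self-contained demonstration in plain C++, independent of the XCCL code:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  std::uint8_t sum = 0;
  std::uint8_t max = 0;
  for (int rank = 0; rank < 256; ++rank) {
    std::uint8_t contribution = 1;                        // every rank holds "true"
    sum = static_cast<std::uint8_t>(sum + contribution);  // wraps: 256 % 256 == 0
    max = std::max<std::uint8_t>(max, contribution);      // stays 1, i.e. logical OR
  }
  std::cout << int(sum) << " " << int(max) << "\n";       // prints "0 1"
  return 0;
}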

@@ -85,7 +81,6 @@ void syncStream(
xcclEvent.record(at::xpu::getCurrentXPUStream(device.index()));
xcclEvent.block(xcclStream);
}

} // namespace

constexpr int64_t kSynchronizeBusyWaitMillis = 10;
@@ -98,8 +93,7 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
: Work(rank, opType, "profilingTitle", inputs),
device_(device),
workStartTime_(std::chrono::steady_clock::now()) {
unsigned char enable_timing = 0;
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enable_timing);
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
}

ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
@@ -133,12 +127,11 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
currentTimepoint - workStartTime_);
if (timeElapsed >= timeout) {
std::string exceptionMsg = c10::str(
"Work ran for ",
"Work ran time out after ",
timeElapsed.count(),
" milliseconds before timing out.");
" milliseconds.");
TORCH_CHECK(false, exceptionMsg);
}

std::this_thread::sleep_for(
std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}
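
The loop above polls for completion and enforces the timeout. A stripped-down sketch of the same busy-wait pattern with generic names; the real code checks an XPU event and reports the failure via TORCH_CHECK:

#include <chrono>
#include <functional>
#include <stdexcept>
#include <string>
#include <thread>

void waitWithTimeout(const std::function<bool()>& isCompleted,
                     std::chrono::milliseconds timeout) {
  const auto start = std::chrono::steady_clock::now();
  while (!isCompleted()) {
    const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start);
    if (elapsed >= timeout) {
      throw std::runtime_error(
          "Work timed out after " + std::to_string(elapsed.count()) + " milliseconds");
    }
    // Sleep briefly between polls, mirroring kSynchronizeBusyWaitMillis.
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}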
@@ -157,20 +150,6 @@ ProcessGroupXCCL::ProcessGroupXCCL(
: Backend(rank, size), store_(store) {
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
init();

// Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI
// launchers.
if (!with_mpirun()) {
int local_rank = getXCCLEnvVar("LOCAL_RANK");
int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE");
if (local_rank == -1 || local_world_size == -1) {
local_rank = rank;
local_world_size = size;
}
setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none");
setXCCLEnvVar("CCL_LOCAL_RANK", local_rank);
setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size);
}
}

ProcessGroupXCCL::~ProcessGroupXCCL() = default;
@@ -189,24 +168,16 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
const std::string& deviceKey,
at::Device& device) {
if (deviceKey.empty()) {
C10_THROW_ERROR(
DistBackendError,
"Not able to create/get the XCCL Communicator since "
"the devices are empty ");
}

TORCH_CHECK_WITH(
    DistBackendError,
    !deviceKey.empty(),
    "Not able to create/get the XCCL Communicator since the devices are empty");
{
// todo: why do we need mutex here?
std::lock_guard<std::mutex> lock(mutex_);
if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
return devXCCLCommMap_[deviceKey];
}
}

std::shared_ptr<xcclComm_t> XCCLComm;

XCCL_KVS kvs = get_kvs(rank_, *store_);

int numRanks, rank;
numRanks = getSize();
rank = getRank();
@@ -219,14 +190,12 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));

auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs);
XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));

{
std::lock_guard<std::mutex> lock(mutex_);
devXCCLCommMap_.emplace(deviceKey, XCCLComm);
}
auto xccl_kvs = get_kvs(rank_, *store_);
auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
std::shared_ptr<xcclComm_t> XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));

std::lock_guard<std::mutex> lock(mutex_);
devXCCLCommMap_.emplace(deviceKey, XCCLComm);
xcclStreamsMap_.emplace(deviceKey, std::move(stream));
xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());

@@ -250,11 +219,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
work = initWork(device, rank_, opType);

work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

at::xpu::OptionalXPUGuard gpuGuard(device);

pre(stream, work);
for (const auto i : c10::irange(inputs.size())) {
c10::xpu::XPUCachingAllocator::recordStream(
Expand All @@ -264,7 +231,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
post(stream, work);

work->xcclEndEvent_->record(stream);

std::vector<c10::Stream> streams = {stream.unwrap()};
c10::MultiStreamGuard streamGuard(streams);
std::vector<at::Device> devices{device};
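
The wrapper above runs a pre hook, launches fn once per input tensor, runs a post hook, and then records the end event on the stream. A compact sketch of that control flow with placeholder types, not the real signatures:

#include <cstddef>
#include <vector>

template <typename Tensor, typename Stream, typename Fn, typename Pre, typename Post>
void runCollective(std::vector<Tensor>& inputs,
                   std::vector<Tensor>& outputs,
                   Fn fn, Pre pre, Post post, Stream& stream) {
  pre(stream);                          // e.g. fence against the current stream
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    fn(inputs[i], outputs[i], stream);  // the actual XCCL call, once per tensor
  }
  post(stream);                         // e.g. bookkeeping before recording the end event
}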
@@ -282,7 +248,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
TORCH_CHECK(
tensors.size() == 1, "Expecting one tensor only but got multiple");
auto tensor = tensors.back();
check_xpu_single_tensor(tensor);
checkXPUTensor(tensor);
return collective(
tensor,
tensor,
61 changes: 8 additions & 53 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -35,7 +35,6 @@ static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
"XCCL_BLOCKING_WAIT"};

using xcclComm_t = ccl::communicator;
using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
constexpr const char* XCCL_BACKEND_NAME = "xccl";

class TORCH_API ProcessGroupXCCL : public Backend {
@@ -113,9 +112,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
at::Tensor& output,
Fn fn,
OpType opType) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
return collective<Fn>(
input,
output,
inputs,
outputs,
fn,
[](at::xpu::XPUStream&,
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
Expand All @@ -124,19 +125,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
opType);
}

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
at::Tensor& input,
at::Tensor& output,
Fn fn,
PreProcess pre,
PostProcess post,
OpType opType) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
return collective(inputs, outputs, fn, pre, post, opType);
}

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
std::vector<at::Tensor>& inputs,
@@ -159,9 +147,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
bool blockingWait_ = false;

private:
XCCL_KVS kvs;
std::mutex kvs_mutex;
XCCL_KVS get_kvs(int rank, c10d::Store& store) {
ccl::shared_ptr_class<ccl::kvs> kvs;

ccl::shared_ptr_class<ccl::kvs> get_kvs(int rank, c10d::Store& store) {
// todo: why do we need the mutex here?
std::lock_guard<std::mutex> lock(kvs_mutex);
if (kvs)
return kvs;
@@ -186,41 +176,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
return kvs;
}
};
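
get_kvs above (partly collapsed in the diff view) caches a ccl::kvs guarded by kvs_mutex. For orientation, the usual oneCCL rendezvous looks roughly like the sketch below: rank 0 creates the main KVS and publishes its address through the store, and the other ranks read it back. The key name and the exact store usage are assumptions, not taken from this file.

// Hedged sketch (assumptions, not this file's code) of a store-based oneCCL
// KVS rendezvous. Assumes <algorithm>, <cstdint> and the oneCCL / c10d headers
// already pulled in by this header; "xccl_kvs_addr" is a hypothetical key name.
ccl::shared_ptr_class<ccl::kvs> createKvsSketch(int rank, c10d::Store& store) {
  const std::string storeKey = "xccl_kvs_addr";
  if (rank == 0) {
    auto kvs = ccl::create_main_kvs();
    ccl::kvs::address_type addr = kvs->get_address();
    store.set(storeKey, std::vector<uint8_t>(addr.begin(), addr.end()));
    return kvs;
  }
  std::vector<uint8_t> raw = store.get(storeKey);
  ccl::kvs::address_type addr{};
  std::copy_n(raw.begin(), std::min(raw.size(), addr.size()), addr.begin());
  return ccl::create_kvs(addr);
}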

namespace {
int getXCCLEnvVar(std::string envVarName) {
char* stringValue = std::getenv(envVarName.c_str());
if (stringValue != nullptr) {
try {
int val = std::stoi(stringValue);
return val;
} catch (std::exception& e) {
TORCH_CHECK(
false,
"Invalid value for environment variable: " + std::string(envVarName));
}
} else {
return -1;
}
}

template <typename T>
void setXCCLEnvVar(const std::string& envVarName, T val) {
if constexpr (std::is_same_v<T, int>) {
setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
} else if constexpr (std::is_same_v<T, std::string>) {
setenv(envVarName.c_str(), val.c_str(), 1);
}
}

bool with_mpirun() {
return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
? true
: false;
}

} // namespace
} // namespace c10d

#endif // USE_C10D_XCCL
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/Utils.hpp
@@ -557,7 +557,7 @@ size_t computeLengthsAndOffsets(
return offset;
}

inline std::string reduce_op_to_string(c10d::ReduceOp op) {
inline std::string reduceOpToString(c10d::ReduceOp op) {
switch (op) {
case c10d::ReduceOp::SUM:
return "SUM";
