Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add barrier-free collective algorithms #248

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 99 additions & 2 deletions include/mscclpp/packet_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#endif // defined(MSCCLPP_DEVICE_COMPILE)

namespace mscclpp {

/// LL (low latency) protocol packet.
union alignas(16) LLPacket {
// Assume data is written with an atomicity of 8 bytes (IB/RDMA).
Expand Down Expand Up @@ -43,6 +42,8 @@ union alignas(16) LLPacket {
ulonglong2* p = reinterpret_cast<ulonglong2*>(&reg);
atomicStore(&(raw_.x), p->x, memoryOrderRelaxed);
atomicStore(&(raw_.y), p->y, memoryOrderRelaxed);
// __builtin_nontemporal_store(p->x, &(raw_.x));
// __builtin_nontemporal_store(p->y, &(raw_.y));
#endif
}

Expand Down Expand Up @@ -77,9 +78,22 @@ union alignas(16) LLPacket {
/// @param flag The flag to read.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
/// @return The 8-byte data read.
MSCCLPP_DEVICE_INLINE uint2 read(uint32_t flag, int64_t maxSpinCount = 1000000000) const {
  // Spin via readOnce() until the packet's flag matches; POLL_MAYBE_JAILBREAK asserts after
  // maxSpinCount iterations (never asserts if maxSpinCount is negative).
  uint2 data;
  POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount);
  return data;
}

Expand All @@ -88,6 +102,61 @@ union alignas(16) LLPacket {
#endif // defined(MSCCLPP_DEVICE_COMPILE)
};

/// An 8-byte LL (low latency) protocol packet: 4 bytes of data plus a 4-byte validity flag.
/// Assumes the packet is written with an atomicity of 8 bytes (IB/RDMA) so that the data and
/// the flag always arrive together.
union alignas(8) LLPacket2 {
  struct {
    uint32_t data;
    uint32_t flag;
  };
  uint64_t raw_;
#if defined(MSCCLPP_DEVICE_COMPILE)

  MSCCLPP_DEVICE_INLINE LLPacket2() {}

  /// Write the data and the flag as a single 8-byte store.
  /// @param val The 4-byte data to write.
  /// @param flag The flag to write.
  MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) {
#if defined(MSCCLPP_DEVICE_CUDA)
    // Volatile 8-byte vector store keeps data and flag in one memory transaction.
    // (The original CUDA branch was empty, making write() a silent no-op on CUDA.)
    asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag));
#else   // !defined(MSCCLPP_DEVICE_CUDA)
    uint2 reg = make_uint2(val, flag);
    uint64_t* p = reinterpret_cast<uint64_t*>(&reg);
    atomicStore(&(raw_), *p, memoryOrderRelaxed);
#endif
  }

  /// Try to read the data once.
  /// @param flag The expected flag value.
  /// @param data [out] The 4-byte data read.
  /// @return True if the flag does not match yet (i.e., the caller should retry).
  MSCCLPP_DEVICE_INLINE bool readOnce(uint32_t flag, uint32_t& data) const {
#if defined(MSCCLPP_DEVICE_CUDA)
    // Volatile 8-byte vector load; the original CUDA branch was empty, so readOnce() fell off
    // the end without a return value (undefined behavior) when compiled for CUDA.
    uint32_t d, f;
    asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(d), "=r"(f) : "l"(&raw_));
    data = d;
    return (f != flag);
#else   // !defined(MSCCLPP_DEVICE_CUDA)
    uint64_t reg = atomicLoad(&(raw_), memoryOrderRelaxed);
    uint2* ptr = reinterpret_cast<uint2*>(&reg);
    data = ptr->x;
    return (ptr->y != flag);
#endif
  }

  /// Read the data, spinning until the flag matches.
  /// @param flag The expected flag value.
  /// @param maxSpinCount Maximum spin iterations before asserting; never asserts if negative.
  /// @return The 4-byte data read.
  MSCCLPP_DEVICE_INLINE uint32_t read(uint32_t flag, int64_t maxSpinCount = 100000000) const {
    // Use the same POLL_MAYBE_JAILBREAK spin as LLPacket::read; the previous hand-rolled loop
    // silently returned stale data when maxSpinCount was exceeded instead of asserting.
    uint32_t data;
    POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount);
    return data;
  }

  /// Clear the packet.
  MSCCLPP_DEVICE_INLINE void clear() { raw_ = 0; }
#endif  // defined(MSCCLPP_DEVICE_COMPILE)
};

#if defined(MSCCLPP_DEVICE_COMPILE)
/// Read from the origin and write to the target buffer.
MSCCLPP_DEVICE_INLINE void putPackets(void* targetPtr, uint64_t targetOffset, const void* originPtr,
Expand Down Expand Up @@ -116,6 +185,34 @@ MSCCLPP_DEVICE_INLINE void getPackets(const void* targetPtr, uint64_t targetOffs
originBase[i] = pkt->read(flag);
}
}

/// Read 4-byte elements from the origin buffer and write them to the target buffer as LLPacket2s.
/// The target offset should be 8-byte aligned (one LLPacket2 per element); the origin offset
/// should be 4-byte aligned and originBytes a multiple of 4 bytes.
MSCCLPP_DEVICE_INLINE void putPackets2(void* targetPtr, uint64_t targetOffset, const void* originPtr,
                                       uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                       uint32_t numThreads, uint32_t flag) {
  const uint32_t* srcElems = reinterpret_cast<const uint32_t*>(reinterpret_cast<const char*>(originPtr) + originOffset);
  LLPacket2* dstPkts = reinterpret_cast<LLPacket2*>(reinterpret_cast<char*>(targetPtr) + targetOffset);
  const size_t numElems = originBytes / sizeof(uint32_t);
  // Threads cooperatively stride over the elements, one packet write per element.
  for (size_t idx = threadId; idx < numElems; idx += numThreads) {
    dstPkts[idx].write(srcElems[idx], flag);
  }
}

/// Read LLPacket2s from the target buffer and write their 4-byte payloads to the origin buffer.
/// The target offset should be 8-byte aligned (one LLPacket2 per element); the origin offset
/// should be 4-byte aligned and originBytes a multiple of 4 bytes.
MSCCLPP_DEVICE_INLINE void getPackets2(const void* targetPtr, uint64_t targetOffset, void* originPtr,
                                       uint64_t originOffset, uint64_t originBytes, uint32_t threadId,
                                       uint32_t numThreads, uint32_t flag) {
  const LLPacket2* srcPkts = reinterpret_cast<const LLPacket2*>(reinterpret_cast<const char*>(targetPtr) + targetOffset);
  uint32_t* dstElems = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(originPtr) + originOffset);
  const size_t numElems = originBytes / sizeof(uint32_t);
  // Threads cooperatively stride over the elements; read() spins until each packet's flag matches.
  for (size_t idx = threadId; idx < numElems; idx += numThreads) {
    dstElems[idx] = srcPkts[idx].read(flag);
  }
}
#endif // defined(MSCCLPP_DEVICE_COMPILE)

}; // namespace mscclpp
Expand Down
10 changes: 10 additions & 0 deletions include/mscclpp/sm_channel_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ struct SmChannelDeviceHandle {
mscclpp::getPackets(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
}

/// Write 4-byte elements from the local origin buffer (src_) into the remote buffer (dst_)
/// as LLPacket2s (forwards to mscclpp::putPackets2).
/// @param targetOffset Byte offset into the remote packet buffer.
/// @param originOffset Byte offset into the local origin buffer.
/// @param originBytes Number of bytes to copy from the origin buffer.
/// @param threadId Id of this thread among the cooperating threads — presumably block-local; confirm with callers.
/// @param numThreads Total number of threads cooperating on the copy.
/// @param flag The flag written alongside each data element.
MSCCLPP_DEVICE_INLINE void putPackets2(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
mscclpp::putPackets2(dst_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
}

/// Read LLPacket2s from the local packet buffer (getPacketBuffer_) and unpack their 4-byte
/// payloads into the origin buffer (src_) (forwards to mscclpp::getPackets2).
/// @param targetOffset Byte offset into the local packet buffer.
/// @param originOffset Byte offset into the origin buffer receiving the payloads.
/// @param originBytes Number of payload bytes to unpack.
/// @param threadId Id of this thread among the cooperating threads — presumably block-local; confirm with callers.
/// @param numThreads Total number of threads cooperating on the copy.
/// @param flag The flag value each packet must carry before its data is consumed.
MSCCLPP_DEVICE_INLINE void getPackets2(uint64_t targetOffset, uint64_t originOffset, uint64_t originBytes,
uint32_t threadId, uint32_t numThreads, uint32_t flag) {
mscclpp::getPackets2(getPacketBuffer_, targetOffset, src_, originOffset, originBytes, threadId, numThreads, flag);
}

/// Signal the remote semaphore.
///
/// This function guarantees that all the memory operation before this function is completed before the remote
Expand Down
18 changes: 14 additions & 4 deletions test/mp_unit/sm_channel_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,11 @@ TEST_F(SmChannelOneToOneTest, GetPingPong) {
EXPECT_EQ(*ret, 0);
}

__global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret) {
__global__ void kernelSmPacketPingPong(int* buff, int rank, int nElem, int* ret, int nTries = 1000) {
if (rank > 1) return;

DeviceHandle<mscclpp::SmChannel>& smChan = gChannelOneToOneTestConstSmChans;
volatile int* sendBuff = (volatile int*)buff;
int nTries = 1000;
int putOffset = (rank == 0) ? 0 : 10000000;
int getOffset = (rank == 0) ? 10000000 : 0;
for (int i = 0; i < nTries; i++) {
Expand Down Expand Up @@ -305,8 +304,6 @@ TEST_F(SmChannelOneToOneTest, PacketPingPong) {
// The least nelem is 2 for packet ping pong
kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 2, ret.get());
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());

EXPECT_EQ(*ret, 0);
*ret = 0;

kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get());
Expand All @@ -325,4 +322,17 @@ TEST_F(SmChannelOneToOneTest, PacketPingPong) {
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());

EXPECT_EQ(*ret, 0);
*ret = 0;

int nTries = 1000000;
communicator->bootstrap()->barrier();
mscclpp::Timer timer;
kernelSmPacketPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1024, ret.get(), nTries);
MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
communicator->bootstrap()->barrier();

if (gEnv->rank == 0) {
std::cout << "smPacketPingPong"
<< ": " << std::setprecision(4) << (float)timer.elapsed() / (float)(nTries) << " us/iter\n";
}
}
Loading
Loading