Skip to content

Commit

Permalink
atomic for semaphores instead of fences
Browse files (browse the repository at this point in the history)
  • Loading branch information
Mutinifni committed Oct 5, 2023
1 parent b3d0fdb commit efb2520
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions include/mscclpp/semaphore_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#ifndef MSCCLPP_SEMAPHORE_DEVICE_HPP_
#define MSCCLPP_SEMAPHORE_DEVICE_HPP_

#include <cuda/atomic>

#include "poll.hpp"

namespace mscclpp {
Expand Down Expand Up @@ -36,15 +38,18 @@ struct SmDevice2DeviceSemaphoreDeviceHandle {
/// Poll if the remote device has signaled.
/// @return true if the remote device has signaled.
__forceinline__ __device__ bool poll() {
  // Read the inbound counter with a system-scope acquire load. Acquire ordering keeps later memory
  // operations of this thread from being reordered before the load, so data published ahead of a
  // matching release store is visible once the new counter value is observed.
  bool signaled = (cuda::atomic_ref<uint64_t, cuda::thread_scope_system>{*inboundSemaphoreId}.load(
                       cuda::memory_order_acquire) > (*expectedInboundSemaphoreId));
  // Consume exactly one signal per successful poll by advancing the expected value.
  if (signaled) (*expectedInboundSemaphoreId) += 1;
  return signaled;
}

/// Wait for the remote device to signal.
__forceinline__ __device__ void wait(int64_t maxSpinCount = 10000000) {
  // Advance the expected value first: this call waits for the next signal not yet consumed.
  (*expectedInboundSemaphoreId) += 1;
  // Poll the inbound counter with a system-scope acquire load only. A plain read through the
  // non-volatile pointer could be hoisted out of the loop by the compiler and would race with the
  // atomic store performed by the signaler.
  // NOTE(review): POLL_MAYBE_JAILBREAK comes from poll.hpp; presumably it spins while the condition
  // holds, bounded by `maxSpinCount` -- confirm against that header.
  POLL_MAYBE_JAILBREAK((cuda::atomic_ref<uint64_t, cuda::thread_scope_system>{*inboundSemaphoreId}.load(
                            cuda::memory_order_acquire) < (*expectedInboundSemaphoreId)),
                       maxSpinCount);
}

/// Signal the remote device.
Expand All @@ -55,9 +60,9 @@ struct SmDevice2DeviceSemaphoreDeviceHandle {
__forceinline__ __device__ void signal() {
  // Publish the local semaphore value to the peer with a single system-scope release store.
  // Release ordering makes this thread's preceding writes visible before the new counter value is,
  // so no separate `__threadfence_system()` is required, and the counter must not also be written
  // with a plain store (mixing atomic and non-atomic accesses to the same location is a data race).
  // NOTE(review): semaphoreIncrement() is defined outside this view; presumably it advances the
  // local outbound counter that semaphoreGetLocal() reads -- confirm.
  semaphoreIncrement();
  cuda::atomic_ref<uint64_t, cuda::thread_scope_system>{*remoteInboundSemaphoreId}.store(semaphoreGetLocal(),
                                                                                         cuda::memory_order_release);
}

/// Signal the remote device for copied packets.
Expand All @@ -78,9 +83,9 @@ struct SmDevice2DeviceSemaphoreDeviceHandle {
__forceinline__ __device__ uint64_t semaphoreGetLocal() const {
  // Plain read of the value currently stored behind `outboundSemaphoreId`.
  const uint64_t localValue = *outboundSemaphoreId;
  return localValue;
}
#endif // __CUDACC__

/// Counter read by poll() and wait() through a system-scope atomic acquire load. Declared without
/// `volatile` so it can be bound to `cuda::atomic_ref<uint64_t, ...>`.
uint64_t* inboundSemaphoreId;
/// Counter whose current value semaphoreGetLocal() returns.
uint64_t* outboundSemaphoreId;
/// Counter written by signal() through a system-scope atomic release store. Declared without
/// `volatile` so it can be bound to `cuda::atomic_ref<uint64_t, ...>`.
uint64_t* remoteInboundSemaphoreId;
/// Value that poll() and wait() compare the inbound counter against and advance by one per signal.
uint64_t* expectedInboundSemaphoreId;
};

Expand Down

0 comments on commit efb2520

Please sign in to comment.