Skip to content

Commit

Permalink
Merge branch 'main' into binyli/ci-nccl
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 authored Dec 7, 2024
2 parents 0e6af86 + 756f24c commit 1403e39
Show file tree
Hide file tree
Showing 25 changed files with 239 additions and 234 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ MSCCL++ provides peer-to-peer communication methods between GPUs. A peer-to-peer

```cpp
// `ProxyChannel` will be explained in the following section.
__device__ mscclpp::DeviceHandle<mscclpp::SimpleProxyChannel> channel;
__device__ mscclpp::DeviceHandle<mscclpp::ProxyChannel> channel;
__global__ void gpuKernel() {
...
// Only one thread is needed for this method.
Expand Down
6 changes: 3 additions & 3 deletions docs/design/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ In this section, we will discuss several use cases that demonstrate the capabili

MSCCL++ enables the offloading of communication logic from the GPU to the CPU, facilitating the overlapping of communication and computation processes. The code snippet provided illustrates this overlapping technique. In the depicted scenario, the GPU emits a signal to the CPU indicating readiness for data transfer. Subsequently, while the GPU continues to execute computation tasks, the CPU initiates the data transfer to the designated target device.
```cpp
__device__ void gpuKernel(mscclpp::SimpleProxyChannelDeviceHandle* proxyChannel) {
__device__ void gpuKernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
// Send a trigger to the CPU
if (tid == 0) {
Expand All @@ -138,11 +138,11 @@ Traditional communication libraries enforce a separation between communication a
MCSCL++ offers a low-level communication API, allowing users to design customized collective communication algorithms. The following code demonstrates how to implement a customized All2All algorithm using MSCCL++.
```cpp
using DeviceHandle = mscclpp::DeviceHandle<T>;
__device__ void localAlltoall(DeviceHandle<mscclpp::SimpleProxyChannel>* proxyChans, int rank,
__device__ void localAlltoall(DeviceHandle<mscclpp::ProxyChannel>* proxyChans, int rank,
int nRanksPerNode, size_t nElements) {
int remoteRank = ((int)blockIdx.x < rank) ? blockIdx.x : blockIdx.x + 1;
for (int i = 1; i < nRanksPerNode; i++) {
DeviceHandle<mscclpp::SimpleProxyChannel> proxyChan = proxyChans[blockIdx.x];
DeviceHandle<mscclpp::ProxyChannel> proxyChan = proxyChans[blockIdx.x];
if (threadIdx.x == 0 && remoteRank % nRanksPerNode == (rank + i) % nRanksPerNode) {
proxyChan.putWithSignalAndFlush(rank * nElements * sizeof(int), remoteRank * nElements * sizeof(int),
nElements * sizeof(int));
Expand Down
10 changes: 5 additions & 5 deletions docs/getting-started/tutorials/initialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ We will setup a mesh topology with eight GPUs. Each GPU will be connected to its

template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::SimpleProxyChannel> constProxyChans[8];
__constant__ DeviceHandle<mscclpp::ProxyChannel> constProxyChans[8];

void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {
std::string ip_port = "10.0.0.4:50000";
Expand Down Expand Up @@ -55,17 +55,17 @@ void setupMeshTopology(int rank, int worldsize, void* data, size_t dataSize) {

comm.setup();

std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
std::vector<DeviceHandle<mscclpp::ProxyChannel>> proxyChannels;
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::ProxyChannel(
proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()),
proxyService.addMemory(localMemories[i]))));
}

if (proxyChannels.size() > sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)) {
if (proxyChannels.size() > sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::ProxyChannel>)) {
std::runtime_error("unexpected error");
}
CUDACHECK(cudaMemcpyToSymbol(constProxyChans, proxyChannels.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>) * proxyChannels.size()));
sizeof(DeviceHandle<mscclpp::ProxyChannel>) * proxyChannels.size()));
}
```
4 changes: 2 additions & 2 deletions docs/getting-started/tutorials/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ We provide some Python utils to help you launch kernel via python. Here is a exa
```python
from mscclpp.utils import KernelBuilder, pack

def launch_kernel(my_rank: int, nranks: int, simple_channels: List[SimpleProxyChannel], memory: cp.ndarray):
def launch_kernel(my_rank: int, nranks: int, simple_channels: List[ProxyChannel], memory: cp.ndarray):
file_dir = os.path.dirname(os.path.abspath(__file__))
kernel = KernelBuilder(file="test.cu", kernel_name="test", file_dir=file_dir).get_compiled_kernel()
params = b""
Expand Down Expand Up @@ -77,7 +77,7 @@ The test kernel is defined in `test.cu` as follows:
// be careful about using channels[my_rank] as it is inavlie and it is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks,
proxy_channel(mscclpp::ProxyChannelDeviceHandle* channels, int my_rank, int nranks,
int num_elements) {
int tid = threadIdx.x;
int nthreads = blockDim.x;
Expand Down
59 changes: 33 additions & 26 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace mscclpp {

struct BaseProxyChannel;
struct ProxyChannel;

/// Base class for proxy services. Proxy services are used to proxy data between devices.
Expand Down Expand Up @@ -48,10 +49,17 @@ class ProxyService : public BaseProxyService {
/// @return The semaphore.
std::shared_ptr<Host2DeviceSemaphore> semaphore(SemaphoreId id) const;

/// Get a proxy channel by semaphore ID.
/// Get a base proxy channel by semaphore ID.
/// @param id The ID of the semaphore.
/// @return The base proxy channel.
BaseProxyChannel baseProxyChannel(SemaphoreId id);

/// Get a proxy channel by semaphore ID and memory regions.
/// @param id The ID of the semaphore.
/// @param dst The destination memory region.
/// @param src The source memory region.
/// @return The proxy channel.
ProxyChannel proxyChannel(SemaphoreId id);
ProxyChannel proxyChannel(SemaphoreId id, MemoryId dst, MemoryId src);

/// Start the proxy service.
void startProxy();
Expand All @@ -71,66 +79,65 @@ class ProxyService : public BaseProxyService {
};

/// Proxy channel.
struct ProxyChannel {
private:
struct BaseProxyChannel {
protected:
SemaphoreId semaphoreId_;

std::shared_ptr<Host2DeviceSemaphore> semaphore_;

std::shared_ptr<Proxy> proxy_;

public:
ProxyChannel() = default;
BaseProxyChannel() = default;

ProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy);
BaseProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
std::shared_ptr<Proxy> proxy);

ProxyChannel(const ProxyChannel& other) = default;
BaseProxyChannel(const BaseProxyChannel& other) = default;

ProxyChannel& operator=(ProxyChannel& other) = default;
BaseProxyChannel& operator=(BaseProxyChannel& other) = default;

/// Device-side handle for @ref ProxyChannel.
using DeviceHandle = ProxyChannelDeviceHandle;
/// Device-side handle for @ref BaseProxyChannel.
using DeviceHandle = BaseProxyChannelDeviceHandle;

/// Returns the device-side handle.
///
/// User should make sure the ProxyChannel is not released when using the returned handle.
/// User should make sure the BaseProxyChannel is not released when using the returned handle.
///
DeviceHandle deviceHandle() const;
};

/// Simple proxy channel with a single destination and source memory region.
struct SimpleProxyChannel {
/// A common form of proxy channel with a single destination and source memory region.
struct ProxyChannel : public BaseProxyChannel {
private:
ProxyChannel proxyChan_;
MemoryId dst_;
MemoryId src_;

public:
/// Default constructor.
SimpleProxyChannel() = default;
ProxyChannel() = default;

/// Constructor.
/// @param proxyChan The proxy channel.
/// @param semaphoreId The ID of the semaphore.
/// @param semaphore The semaphore.
/// @param proxy The proxy.
/// @param dst The destination memory region.
/// @param src The source memory region.
SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src);

/// Constructor.
/// @param proxyChan The proxy channel.
SimpleProxyChannel(ProxyChannel proxyChan) : proxyChan_(proxyChan) {}
ProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy,
MemoryId dst, MemoryId src);

/// Copy constructor.
SimpleProxyChannel(const SimpleProxyChannel& other) = default;
ProxyChannel(const ProxyChannel& other) = default;

/// Assignment operator.
SimpleProxyChannel& operator=(SimpleProxyChannel& other) = default;
ProxyChannel& operator=(ProxyChannel& other) = default;

/// Device-side handle for @ref SimpleProxyChannel.
using DeviceHandle = SimpleProxyChannelDeviceHandle;
/// Device-side handle for @ref ProxyChannel.
using DeviceHandle = ProxyChannelDeviceHandle;

/// Returns the device-side handle.
///
/// User should make sure the SimpleProxyChannel is not released when using the returned handle.
/// User should make sure the ProxyChannel is not released when using the returned handle.
///
DeviceHandle deviceHandle() const;
};
Expand Down
37 changes: 17 additions & 20 deletions include/mscclpp/proxy_channel_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ union ChannelTrigger {
#endif // defined(MSCCLPP_DEVICE_COMPILE)
};

struct ProxyChannelDeviceHandle {
struct BaseProxyChannelDeviceHandle {
SemaphoreId semaphoreId_;

Host2DeviceSemaphoreDeviceHandle semaphore_;
Expand All @@ -92,6 +92,12 @@ struct ProxyChannelDeviceHandle {
// can produce for and the sole proxy thread consumes it.
FifoDeviceHandle fifo_;

BaseProxyChannelDeviceHandle() {}

BaseProxyChannelDeviceHandle(SemaphoreId semaphoreId, Host2DeviceSemaphoreDeviceHandle semaphore,
FifoDeviceHandle fifo)
: semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {}

#if defined(MSCCLPP_DEVICE_COMPILE)
/// Push a @ref TriggerData to the FIFO.
/// @param dst The destination memory region.
Expand Down Expand Up @@ -175,34 +181,36 @@ struct ProxyChannelDeviceHandle {
#endif // defined(MSCCLPP_DEVICE_COMPILE)
};

struct SimpleProxyChannelDeviceHandle {
ProxyChannelDeviceHandle proxyChan_;
struct ProxyChannelDeviceHandle : public BaseProxyChannelDeviceHandle {
MemoryId dst_;
MemoryId src_;

ProxyChannelDeviceHandle(){};

ProxyChannelDeviceHandle(SemaphoreId semaphoreId, Host2DeviceSemaphoreDeviceHandle semaphore, FifoDeviceHandle fifo,
MemoryId dst, MemoryId src)
: BaseProxyChannelDeviceHandle(semaphoreId, semaphore, fifo), dst_(dst), src_(src) {}

#if defined(MSCCLPP_DEVICE_COMPILE)
/// Push a @ref TriggerData to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
MSCCLPP_DEVICE_INLINE void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.put(dst_, dstOffset, src_, srcOffset, size);
BaseProxyChannelDeviceHandle::put(dst_, dstOffset, src_, srcOffset, size);
}

/// Push a @ref TriggerData to the FIFO.
/// @param offset The common offset into the destination and source memory regions.
/// @param size The size of the transfer.
MSCCLPP_DEVICE_INLINE void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }

/// Push a @ref TriggerFlag to the FIFO.
MSCCLPP_DEVICE_INLINE void signal() { proxyChan_.signal(); }

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dstOffset The offset into the destination memory region.
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
MSCCLPP_DEVICE_INLINE void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size);
BaseProxyChannelDeviceHandle::putWithSignal(dst_, dstOffset, src_, srcOffset, size);
}

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
Expand All @@ -215,7 +223,7 @@ struct SimpleProxyChannelDeviceHandle {
/// @param srcOffset The offset into the source memory region.
/// @param size The size of the transfer.
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) {
proxyChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
BaseProxyChannelDeviceHandle::putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size);
}

/// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO.
Expand All @@ -224,17 +232,6 @@ struct SimpleProxyChannelDeviceHandle {
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(uint64_t offset, uint64_t size) {
putWithSignalAndFlush(offset, offset, size);
}

/// Push a @ref TriggerSync to the FIFO.
MSCCLPP_DEVICE_INLINE void flush() { proxyChan_.flush(); }

/// Check if the proxy channel has been signaled.
/// @return true if the proxy channel has been signaled.
MSCCLPP_DEVICE_INLINE bool poll() { return proxyChan_.poll(); }

/// Wait for the proxy channel to be signaled.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
MSCCLPP_DEVICE_INLINE void wait(int64_t maxSpinCount = 10000000) { proxyChan_.wait(maxSpinCount); }
#endif // defined(MSCCLPP_DEVICE_COMPILE)
};

Expand Down
2 changes: 1 addition & 1 deletion python/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
numa,
ProxyService,
RegisteredMemory,
SimpleProxyChannel,
ProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
TcpBootstrap,
Expand Down
12 changes: 6 additions & 6 deletions python/mscclpp/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Host2HostSemaphore,
ProxyService,
RegisteredMemory,
SimpleProxyChannel,
ProxyChannel,
SmChannel,
SmDevice2DeviceSemaphore,
TcpBootstrap,
Expand Down Expand Up @@ -180,8 +180,8 @@ def make_proxy_channels(
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = SimpleProxyChannel(
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
channels[rank] = proxy_service.proxy_channel(
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
)
return channels

Expand Down Expand Up @@ -218,8 +218,8 @@ def make_proxy_channels_with_scratch(
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = SimpleProxyChannel(
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
channels[rank] = proxy_service.proxy_channel(
semaphore_ids[rank], memory_ids[rank], memory_ids[self.my_rank]
)
return channels

Expand All @@ -232,7 +232,7 @@ def register_semaphore_with_proxy(
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.proxy_channel(semaphore_ids[rank])
channels[rank] = proxy_service.base_proxy_channel(semaphore_ids[rank])
return channels

def register_memory_with_proxy(
Expand Down
35 changes: 19 additions & 16 deletions python/mscclpp/proxy_channel_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,36 @@ void register_proxy_channel(nb::module_& m) {
.def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore"))
.def("add_memory", &ProxyService::addMemory, nb::arg("memory"))
.def("semaphore", &ProxyService::semaphore, nb::arg("id"))
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
.def("base_proxy_channel", &ProxyService::baseProxyChannel, nb::arg("id"))
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"), nb::arg("dst"), nb::arg("src"));

nb::class_<ProxyChannel>(m, "ProxyChannel")
nb::class_<BaseProxyChannel>(m, "BaseProxyChannel")
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
.def("device_handle", &BaseProxyChannel::deviceHandle);

nb::class_<BaseProxyChannel::DeviceHandle>(m, "BaseProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphoreId_", &BaseProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &BaseProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &BaseProxyChannel::DeviceHandle::fifo_)
.def_prop_ro("raw", [](const BaseProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

nb::class_<ProxyChannel>(m, "ProxyChannel")
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
.def("device_handle", &ProxyChannel::deviceHandle);

nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("semaphoreId_", &ProxyChannel::DeviceHandle::semaphoreId_)
.def_rw("semaphore_", &ProxyChannel::DeviceHandle::semaphore_)
.def_rw("fifo_", &ProxyChannel::DeviceHandle::fifo_)
.def_rw("src_", &ProxyChannel::DeviceHandle::src_)
.def_rw("dst_", &ProxyChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const ProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

nb::class_<SimpleProxyChannel>(m, "SimpleProxyChannel")
.def(nb::init<ProxyChannel, MemoryId, MemoryId>(), nb::arg("proxyChan"), nb::arg("dst"), nb::arg("src"))
.def(nb::init<SimpleProxyChannel>(), nb::arg("proxyChan"))
.def("device_handle", &SimpleProxyChannel::deviceHandle);

nb::class_<SimpleProxyChannel::DeviceHandle>(m, "SimpleProxyChannelDeviceHandle")
.def(nb::init<>())
.def_rw("proxyChan_", &SimpleProxyChannel::DeviceHandle::proxyChan_)
.def_rw("src_", &SimpleProxyChannel::DeviceHandle::src_)
.def_rw("dst_", &SimpleProxyChannel::DeviceHandle::dst_)
.def_prop_ro("raw", [](const SimpleProxyChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});
};
Loading

0 comments on commit 1403e39

Please sign in to comment.