Commit 9d6e51e: specialized

Saeed Maleki committed Aug 1, 2023 · 1 parent 7547c52
Showing 4 changed files with 124 additions and 100 deletions.
include/mscclpp/sm_channel.hpp (1 change: 0 additions & 1 deletion)
@@ -5,7 +5,6 @@
#define MSCCLPP_SM_CHANNEL_HPP_

#include <mscclpp/core.hpp>
#include <mscclpp/packet.hpp>
#include <mscclpp/semaphore.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <type_traits>
include/mscclpp/sm_channel_device.hpp (203 changes: 112 additions & 91 deletions)
@@ -4,108 +4,129 @@
#ifndef MSCCLPP_SM_CHANNEL_DEVICE_HPP_
#define MSCCLPP_SM_CHANNEL_DEVICE_HPP_

#include "semaphore_device.hpp"
#include "poll.hpp"
#include "packet.hpp"

namespace mscclpp {

#ifdef __CUDACC__
/// Helper for aligned data type access.
/// @tparam T The data type.
template <typename T>
struct Element {
/// Load an element from DRAM.
///
/// This is a wrapper of the ld.volatile.global.* PTX instruction. Address alignment is the caller's
/// responsibility, not this function's.
///
/// @param v The value to be loaded.
/// @param p The address of the value to be loaded.
///
static __forceinline__ __device__ void load(T& v, const T* p) {
// We should only use the specialized functions.
__assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
}

/// Write an element to DRAM.
///
/// This is a wrapper of the st.volatile.global.* PTX instruction. Address alignment is the caller's
/// responsibility, not this function's.
///
/// @param p The address of the value to be written.
/// @param v The value to be written.
///
static __forceinline__ __device__ void store(T* p, const T& v) {
// We should only use the specialized functions.
__assert_fail("Unsupported type", __FILE__, __LINE__, __PRETTY_FUNCTION__);
}

/// Copy aligned elements from the source memory to the destination memory.
///
/// This function is intended to be called collectively by multiple threads. Each thread copies a part of
/// the elements.
///
/// @param dst The destination address.
/// @param src The source address.
/// @param numElems The number of elements to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different
/// from the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
static __forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId,
uint32_t numThreads) {
T reg;
for (size_t i = threadId; i < numElems; i += numThreads) {
// Load to register first.
load(reg, src + i);
store(dst + i, reg);
}
}
};

template <>
struct Element<unsigned long long> {
using T = unsigned long long;
static __forceinline__ __device__ void load(T& v, const T* p) {
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
}

static __forceinline__ __device__ void store(T* p, const T& v) {
asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
}
};

template <>
struct Element<uint> {
using T = uint;
static __forceinline__ __device__ void load(T& v, const T* p) {
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
}

static __forceinline__ __device__ void store(T* p, const T& v) {
asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
}
};

template <>
struct Element<ulonglong2> {
using T = ulonglong2;
static __forceinline__ __device__ void load(T& v, const T* p) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
}

static __forceinline__ __device__ void store(T* p, const T& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
};

template <>
struct Element<uint4> {
using T = uint4;
static __forceinline__ __device__ void load(T& v, const T* p) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z)
: "l"(p)
: "memory");
}

static __forceinline__ __device__ void store(T* p, const T& v) {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
:
: "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z)
: "memory");
}
};
#endif // __CUDACC__

/// Channel for accessing peer memory directly from SM.
struct SmChannelDeviceHandle {
SmDevice2DeviceSemaphore::DeviceHandle semaphore_;
SmDevice2DeviceSemaphoreDeviceHandle semaphore_;
void* src_;
void* dst_;
void* getPacketBuffer_;

#ifdef __CUDACC__
/// Helper for aligned data type access.
/// @tparam T The data type.
template <typename T>
struct Element {
static constexpr bool is4B = (sizeof(T) == 4);
static constexpr bool is8B = (sizeof(T) == 8);
static constexpr bool is4Bx2 =
(std::is_same<T, int2>::value || std::is_same<T, uint2>::value || std::is_same<T, float2>::value);
static constexpr bool is4Bx4 =
(std::is_same<T, int4>::value || std::is_same<T, uint4>::value || std::is_same<T, float4>::value);
static constexpr bool is8Bx2 =
(std::is_same<T, longlong2>::value || std::is_same<T, ulonglong2>::value || std::is_same<T, double2>::value);
// Note: we do not support long2 and ulong2 as their size may differ on different platforms.
static constexpr bool isValid = (is4B || is8B || is4Bx2 || is4Bx4 || is8Bx2);

/// Load an element from DRAM.
///
/// This is a wrapper of the ld.volatile.global.* PTX instruction. Address alignment is the caller's
/// responsibility, not this function's.
///
/// @param v The value to be loaded.
/// @param p The address of the value to be loaded.
///
static __forceinline__ __device__ void load(T& v, const T* p) {
if constexpr (is4B) {
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(v) : "l"(p) : "memory");
} else if constexpr (is8B) {
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory");
} else if constexpr (is4Bx2) {
asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];" : "=r"(v.x), "=r"(v.y) : "l"(p) : "memory");
} else if constexpr (is4Bx4) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(v.w), "=r"(v.x), "=r"(v.y), "=r"(v.z)
: "l"(p)
: "memory");
} else if constexpr (is8Bx2) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
}
static_assert(isValid, "Unsupported type T");
}

/// Write an element to DRAM.
///
/// This is a wrapper of the st.volatile.global.* PTX instruction. Address alignment is the caller's
/// responsibility, not this function's.
///
/// @param p The address of the value to be written.
/// @param v The value to be written.
///
static __forceinline__ __device__ void store(T* p, const T& v) {
if constexpr (is4B) {
asm volatile("st.volatile.global.u32 [%0], %1;" : : "l"(p), "r"(v) : "memory");
} else if constexpr (is8B) {
asm volatile("st.volatile.global.u64 [%0], %1;" : : "l"(p), "l"(v) : "memory");
} else if constexpr (is4Bx2) {
asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" : : "l"(p), "r"(v.x), "r"(v.y) : "memory");
} else if constexpr (is4Bx4) {
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};"
:
: "l"(p), "r"(v.w), "r"(v.x), "r"(v.y), "r"(v.z)
: "memory");
} else if constexpr (is8Bx2) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" : : "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
static_assert(isValid, "Unsupported type T");
}

/// Copy aligned elements from the source memory to the destination memory.
///
/// This function is intended to be called collectively by multiple threads. Each thread copies a part of
/// the elements.
///
/// @param dst The destination address.
/// @param src The source address.
/// @param numElems The number of elements to be copied.
/// @param threadId The index of the current thread among all threads running this function. This is different
/// from the `threadIdx` in CUDA.
/// @param numThreads The total number of threads that run this function.
///
static __forceinline__ __device__ void copy(T* dst, T* src, uint64_t numElems, uint32_t threadId,
uint32_t numThreads) {
T reg;
for (size_t i = threadId; i < numElems; i += numThreads) {
// Load to register first.
load(reg, src + i);
store(dst + i, reg);
}
}
};
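
As a usage illustration (not part of this commit): a kernel can invoke the vectorized copy collectively, deriving threadId and numThreads from the launch configuration rather than using threadIdx directly. The kernel below is a minimal hypothetical sketch, assuming both buffers are 16-byte aligned and the byte count is a multiple of sizeof(uint4).

// Hypothetical usage sketch -- not part of this commit.
// All launched threads cooperate; each copies a strided subset of uint4 elements.
__global__ void copyKernel(uint4* dst, uint4* src, uint64_t bytes) {
  uint32_t threadId = threadIdx.x + blockIdx.x * blockDim.x;  // flat index across the whole grid
  uint32_t numThreads = blockDim.x * gridDim.x;               // total cooperating threads
  mscclpp::SmChannelDeviceHandle::Element<uint4>::copy(dst, src, bytes / sizeof(uint4), threadId, numThreads);
}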

/// Load a value from the remote memory.
/// @tparam T The type of the value to be loaded.
/// @param index The index of the value to be loaded. The offset in bytes is calculated as index * sizeof(T).
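The accessor that follows is truncated in this diff view. Assuming it is a read<T>(index)-style method matching the doc comment above (the name and signature here are assumptions, not confirmed by the diff), device code might use the handle like this:

// Hypothetical sketch -- the read<T> accessor name/signature is assumed from the
// doc comment above, since the diff is truncated at this point.
__global__ void readOne(mscclpp::SmChannelDeviceHandle chan, int* out) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    out[0] = chan.read<int>(0);  // loads the int at byte offset 0 * sizeof(int) from the peer buffer
  }
}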
python/semaphore_py.cpp (7 changes: 1 addition & 6 deletions)
@@ -33,12 +33,7 @@ void register_semaphore(nb::module_& m) {

nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
smDevice2DeviceSemaphore
.def_static(
"create",
[](Communicator& comm, std::shared_ptr<Connection> conn) {
return std::make_shared<SmDevice2DeviceSemaphore>(comm, conn);
},
nb::arg("communicator"), nb::arg("connection"))
.def(nb::init<Communicator&, std::shared_ptr<Connection>>(), nb::arg("communicator"), nb::arg("connection"))
.def("device_handle", &SmDevice2DeviceSemaphore::deviceHandle);

nb::class_<SmDevice2DeviceSemaphore::DeviceHandle>(smDevice2DeviceSemaphore, "DeviceHandle")
python/sm_channel_py.cpp (13 changes: 11 additions & 2 deletions)
@@ -18,7 +18,16 @@ void register_sm_channel(nb::module_& m) {
uintptr_t src) { new (smChannel) SmChannel(semaphore, dst, (void*)src); })
.def("device_handle", &SmChannel::deviceHandle);

nb::class_<SmChannel::DeviceHandle>(smChannel, "DeviceHandle");

m.def("device_handle", &deviceHandle<SmChannel>, nb::arg("smChannel"));

nb::class_<SmChannel::DeviceHandle>(smChannel, "DeviceHandle")
.def(nb::init<>())
.def_rw("semaphore_", &SmChannel::DeviceHandle::semaphore_)
.def_rw("src_", &SmChannel::DeviceHandle::src_)
.def_rw("dst_", &SmChannel::DeviceHandle::dst_)
.def_rw("getPacketBuffer_", &SmChannel::DeviceHandle::getPacketBuffer_)
.def_prop_ro("raw", [](const SmChannel::DeviceHandle& self) -> nb::bytes {
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
});

};
