move multimem instruction to source code
Binyang2014 committed Jan 24, 2024
1 parent fa0565f commit 41b703f
Showing 3 changed files with 47 additions and 12 deletions.
1 change: 1 addition & 0 deletions include/mscclpp/gpu.hpp
@@ -88,6 +88,7 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri
#else

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#endif
44 changes: 44 additions & 0 deletions include/mscclpp/nvls_device.hpp
@@ -4,13 +4,57 @@
#ifndef MSCCLPP_NVLS_DEVICE_HPP_
#define MSCCLPP_NVLS_DEVICE_HPP_

#include <mscclpp/gpu.hpp>
#include <type_traits>

#include "device.hpp"

namespace mscclpp {

template <class>
constexpr bool dependentFalse = false; // workaround before CWG2518/P2593R1

/// Device-side handle for a multicast (NVLS) pointer.
struct DeviceMulticastPointerDeviceHandle {
void* devicePtr;
void* mcPtr;
size_t bufferSize;

#if defined(MSCCLPP_DEVICE_COMPILE)
template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
MSCCLPP_DEVICE_INLINE void multimemLoad(TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
asm("multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<T, half2>::value) {
asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};

template <int NElemPerThread = 4, typename TValue, typename T>
MSCCLPP_DEVICE_INLINE void multimemStore(const TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z),
"r"(val.w)
: "memory");
} else if constexpr (std::is_same<T, half2>::value) {
asm volatile("multimem.st.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z),
"r"(val.w)
: "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};
#endif
};

} // namespace mscclpp
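For orientation, here is a minimal sketch of how a kernel could use the new handle. Only DeviceMulticastPointerDeviceHandle and its multimemLoad/multimemStore methods come from the header above; the kernel name, element count, and grid-stride arithmetic are illustrative assumptions, and the code needs sm_90 plus a buffer that is actually mapped through an NVLS multicast pointer.

#include <mscclpp/nvls_device.hpp>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// Sketch only: every thread loads the sum of all ranks' values for 4 floats and
// broadcasts the result back, mirroring the allreduce6 loop changed below.
__global__ void multimemAllReduceSketch(mscclpp::DeviceMulticastPointerDeviceHandle handle, size_t nElem) {
  float* mcPtr = reinterpret_cast<float*>(handle.mcPtr);  // multicast mapping of the buffer
  size_t tid = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  size_t stride = (size_t)gridDim.x * blockDim.x;
  for (size_t idx = tid * 4; idx + 3 < nElem; idx += stride * 4) {
    uint4 val;
    handle.multimemLoad(val, mcPtr + idx);   // multimem.ld_reduce.global.add.v4.f32
    handle.multimemStore(val, mcPtr + idx);  // multimem.st.global.v4.f32
  }
}
#endif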
14 changes: 2 additions & 12 deletions python/mscclpp_benchmark/allreduce.cu
@@ -783,16 +783,6 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
// -------------------------------------------

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#define MULTIMEM_ST(val, ptr) \
asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), \
"r"(val.w) \
: "memory");
// specific PTX for fp16 reduction. bf16 would be multimem.ld_reduce.global.add.v4.bf16x2 etc
#define MULTIMEM_LD(val, ptr) \
asm("multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];" \
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) \
: "l"(ptr) \
: "memory");

extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
@@ -822,8 +812,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)

for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val;
MULTIMEM_LD(val, mc_ptr + idx);
MULTIMEM_ST(val, mc_ptr + idx);
nvlsPtrs.multimemLoad(val, mc_ptr + idx);
nvlsPtrs.multimemStore(val, mc_ptr + idx);
}

deviceSyncer.sync(gridDim.x);
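The deleted macros only emitted the f32 instructions and left fp16 as a comment; the new templates already include an f16x2 branch, selected when the pointer type is half2. A rough sketch of that instantiation inside the same loop (the cast and indexing are assumptions; nvlsPtrs, val, and idx are the names used in allreduce6 above):

// Sketch only: each thread reduces and rebroadcasts four half2 pairs (8 half values) per iteration.
half2* mcHalf = reinterpret_cast<half2*>(nvlsPtrs.mcPtr);
uint4 val;                                  // four 32-bit lanes, each holding one half2 pair
nvlsPtrs.multimemLoad(val, mcHalf + idx);   // selects the multimem.ld_reduce.global.add.v4.f16x2 branch
nvlsPtrs.multimemStore(val, mcHalf + idx);  // selects the multimem.st.global.v4.f16x2 branch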
