diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp
index 60494a099..02c277a3e 100644
--- a/include/mscclpp/core.hpp
+++ b/include/mscclpp/core.hpp
@@ -13,6 +13,9 @@
 #include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
@@ -40,6 +43,7 @@ class Bootstrap {
   virtual void allGather(void* allData, int size) = 0;
   virtual void barrier() = 0;

+  void groupBarrier(const std::vector<int>& ranks);
   void send(const std::vector<char>& data, int peer, int tag);
   void recv(std::vector<char>& data, int peer, int tag);
 };
@@ -125,6 +129,7 @@ class TcpBootstrap : public Bootstrap {
 enum class Transport {
   Unknown,       // Unknown transport type.
   CudaIpc,       // CUDA IPC transport type.
+  Nvls,          // NVLS transport type.
   IB0,           // InfiniBand device 0 transport type.
   IB1,           // InfiniBand device 1 transport type.
   IB2,           // InfiniBand device 2 transport type.
@@ -136,10 +141,11 @@ enum class Transport {
   NumTransports  // The number of transports.
 };

-const std::string TransportNames[] = {"UNK", "IPC", "IB0", "IB1", "IB2", "IB3", "IB4", "IB5", "IB6", "IB7", "NUM"};
+const std::string TransportNames[] = {"UNK", "IPC", "NVLS", "IB0", "IB1", "IB2",
+                                      "IB3", "IB4", "IB5", "IB6", "IB7", "NUM"};

 namespace detail {
-const size_t TransportFlagsSize = 10;
+const size_t TransportFlagsSize = 11;
 static_assert(TransportFlagsSize == static_cast<size_t>(Transport::NumTransports),
               "TransportFlagsSize must match the number of transports");
 /// Bitset for storing transport flags.
@@ -445,12 +451,50 @@ class Connection {
   static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
 };

+class NvlsConnection {
+ public:
+  NvlsConnection(size_t bufferSize, int numDevices);
+  NvlsConnection(const std::vector<char>& data);
+  NvlsConnection() = delete;
+  std::vector<char> serialize();
+
+  // Everyone needs to synchronize after creating an NVLS connection before adding devices.
+  void addDevice();
+  void addDevice(int cudaDeviceId);
+
+  struct DeviceMulticastPointer {
+   private:
+    std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
+    std::shared_ptr<char> mcPtr_;
+    size_t bufferSize_;
+
+   public:
+    using DeviceHandle = DeviceMulticastPointerDeviceHandle;
+    DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
+                           size_t bufferSize)
+        : deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
+    DeviceHandle deviceHandle();
+    char* getDevicePtr();
+
+    friend class NvlsConnection;
+  };
+
+  std::shared_ptr<DeviceMulticastPointer> allocateAndBindCuda(size_t size);
+  size_t getMultiCastMinGranularity();
+
+ private:
+  class Impl;
+  std::shared_ptr<Impl> pimpl_;
+};
+
 /// Used to configure an endpoint.
 struct EndpointConfig {
   static const int DefaultMaxCqSize = 1024;
   static const int DefaultMaxCqPollNum = 1;
   static const int DefaultMaxSendWr = 8192;
   static const int DefaultMaxWrPerSend = 64;
+  // The default NVLS buffer size (512 MiB); a multiple of the recommended granularity
+  // returned by cuMulticastGetGranularity.
+  static const int DefaultNvlsBufferSize = (1 << 29);

   Transport transport;
   int ibMaxCqSize = DefaultMaxCqSize;
@@ -458,6 +502,8 @@ struct EndpointConfig {
   int ibMaxSendWr = DefaultMaxSendWr;
   int ibMaxWrPerSend = DefaultMaxWrPerSend;

+  size_t nvlsBufferSize = DefaultNvlsBufferSize;
+
   /// Default constructor. Sets transport to Transport::Unknown.
   EndpointConfig() : transport(Transport::Unknown) {}

   /// Constructor.
   ///
   /// @param transport The transport to use.
   EndpointConfig(Transport transport) : transport(transport) {}
+
+  /// Explicit constructor for NVLS endpoints.
+  /// @param transport Must be Transport::Nvls.
+  /// @param nvlsBufferSize The size of the buffer to be allocated on each device.
+  EndpointConfig(Transport transport, size_t nvlsBufferSize) : transport(transport), nvlsBufferSize(nvlsBufferSize) {}
 };

 /// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
@@ -648,6 +699,16 @@ class Communicator {
   /// to the connection.
   NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);

+  /// Connect to NVLS on setup.
+  ///
+  /// This function is used to connect to NVLS on setup. NVLS collectives use multicast operations to send/recv data,
+  /// so all involved ranks must be put into the collective group.
+  ///
+  /// @param allRanks The ranks of all processes involved in the collective.
+  /// @param config The configuration for the local endpoint.
+  /// @return std::shared_ptr<NvlsConnection> A shared pointer to the NVLS connection.
+  std::shared_ptr<NvlsConnection> connectNvlsCollective(std::vector<int> allRanks, EndpointConfig config);
+
   /// Get the remote rank a connection is connected to.
   ///
   /// @param connection The connection to get the remote rank for.
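A minimal host-side sketch of how these new pieces fit together (not part of the diff; `comm` is an already-bootstrapped Communicator, the function name and group layout are illustrative):

#include <memory>
#include <vector>

#include <mscclpp/core.hpp>

std::shared_ptr<mscclpp::NvlsConnection::DeviceMulticastPointer> setupNvls(mscclpp::Communicator& comm, int nRanks,
                                                                           size_t bufferBytes) {
  std::vector<int> allRanks(nRanks);
  for (int r = 0; r < nRanks; ++r) allRanks[r] = r;
  // Collective call: every rank in allRanks must call this with the same group.
  mscclpp::EndpointConfig config(mscclpp::Transport::Nvls, bufferBytes);
  std::shared_ptr<mscclpp::NvlsConnection> nvls = comm.connectNvlsCollective(allRanks, config);
  // Bind requests must respect the multicast minimum granularity.
  size_t gran = nvls->getMultiCastMinGranularity();
  size_t aligned = ((bufferBytes + gran - 1) / gran) * gran;
  return nvls->allocateAndBindCuda(aligned);
}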
diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp
index d3d48ce1f..f560a655c 100644
--- a/include/mscclpp/gpu.hpp
+++ b/include/mscclpp/gpu.hpp
@@ -19,6 +19,9 @@
 using cudaIpcMemHandle_t = hipIpcMemHandle_t;

 using CUresult = hipError_t;
 using CUdeviceptr = hipDeviceptr_t;
+using CUmemGenericAllocationHandle = hipMemGenericAllocationHandle_t;
+using CUmemAllocationProp = hipMemAllocationProp;
+using CUmemAccessDesc = hipMemAccessDesc;

 constexpr auto cudaSuccess = hipSuccess;
 constexpr auto cudaStreamNonBlocking = hipStreamNonBlocking;
@@ -32,6 +35,11 @@
 constexpr auto cudaMemcpyHostToDevice = hipMemcpyHostToDevice;
 constexpr auto cudaMemcpyDeviceToHost = hipMemcpyDeviceToHost;
 constexpr auto cudaIpcMemLazyEnablePeerAccess = hipIpcMemLazyEnablePeerAccess;

+constexpr auto CU_MEM_ALLOCATION_TYPE_PINNED = hipMemAllocationTypePinned;
+constexpr auto CU_MEM_LOCATION_TYPE_DEVICE = hipMemLocationTypeDevice;
+constexpr auto CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = hipMemHandleTypePosixFileDescriptor;
+constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWrite;
+
 #ifndef CUDA_SUCCESS
 #define CUDA_SUCCESS hipSuccess
 #endif  // CUDA_SUCCESS
@@ -68,13 +76,29 @@
 #define cudaIpcCloseMemHandle(...) hipIpcCloseMemHandle(__VA_ARGS__)

 #define cuGetErrorString(...) hipDrvGetErrorString(__VA_ARGS__)
+#define cuMemAddressReserve(...) hipMemAddressReserve(__VA_ARGS__)
+#define cuMemAddressFree(...) hipMemAddressFree(__VA_ARGS__)
 #define cuMemGetAddressRange(...) hipMemGetAddressRange(__VA_ARGS__)
+#define cuMemCreate(...) hipMemCreate(__VA_ARGS__)
+#define cuMemRelease(...) hipMemRelease(__VA_ARGS__)
+#define cuMemSetAccess(...) hipMemSetAccess(__VA_ARGS__)
+#define cuMemMap(...) hipMemMap(__VA_ARGS__)
+#define cuMemUnmap(...) hipMemUnmap(__VA_ARGS__)

 #else

 #include
+#include
 #include

 #endif

+// NVLS
+#if !defined(__HIP_PLATFORM_AMD__)
+#include <linux/version.h>
+#define USE_NVLS ((CUDART_VERSION >= 12010) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0)))
+#else  // !defined(__HIP_PLATFORM_AMD__)
+#define USE_NVLS 0
+#endif  // !defined(__HIP_PLATFORM_AMD__)
+
 #endif  // MSCCLPP_GPU_HPP_
diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp
index e0cd7c3da..9be6a7d16 100644
--- a/include/mscclpp/gpu_utils.hpp
+++ b/include/mscclpp/gpu_utils.hpp
@@ -50,6 +50,18 @@ struct CudaStreamWithFlags {
   cudaStream_t stream_;
 };

+template <class T>
+struct CudaDeleter;
+
+template <class T>
+struct PhysicalCudaMemory {
+  CUmemGenericAllocationHandle memHandle_;
+  T* devicePtr_;
+  size_t size_;
+  PhysicalCudaMemory(CUmemGenericAllocationHandle memHandle, T* devicePtr, size_t size)
+      : memHandle_(memHandle), devicePtr_(devicePtr), size_(size) {}
+};
+
 namespace detail {

 /// A wrapper of cudaMalloc that sets the allocated memory to zero.
@@ -67,6 +79,47 @@ T* cudaCalloc(size_t nelem) {
   return ptr;
 }

+template <class T>
+PhysicalCudaMemory<T>* cudaPhysicalCalloc(size_t nelem, size_t gran) {
+  AvoidCudaGraphCaptureGuard cgcGuard;
+
+  int deviceId = -1;
+  MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
+
+  CUmemAllocationProp prop = {};
+  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  prop.location.id = deviceId;
+#if defined(__HIP_PLATFORM_AMD__)
+  // TODO: revisit when HIP fixes this typo in the field name
+  prop.requestedHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+#else
+  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+#endif
+
+  CUmemGenericAllocationHandle memHandle;
+  size_t bufferSize = sizeof(T) * nelem;
+  // allocate physical memory
+  MSCCLPP_CUTHROW(cuMemCreate(&memHandle, bufferSize, &prop, 0 /*flags*/));
+
+  CUmemAccessDesc accessDesc = {};
+  accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  accessDesc.location.id = deviceId;
+  accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+
+  T* devicePtr = nullptr;
+  // Map the device pointer
+  MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, bufferSize, gran, 0U, 0));
+  MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, bufferSize, 0, memHandle, 0));
+  MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)devicePtr, bufferSize, &accessDesc, 1));
+
+  CudaStreamWithFlags stream(cudaStreamNonBlocking);
+  MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, bufferSize, stream));
+  MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
+
+  return new PhysicalCudaMemory<T>(memHandle, devicePtr, bufferSize);
+}
+
 template <class T>
 T* cudaExtCalloc(size_t nelem) {
   AvoidCudaGraphCaptureGuard cgcGuard;
@@ -118,6 +171,25 @@ Memory safeAlloc(size_t nelem) {
   return Memory(ptr, Deleter());
 }

+template <class T, T* (*alloc)(size_t, size_t), class Deleter, class Memory>
+Memory safeAlloc(size_t nelem, size_t gran) {
+  if ((nelem * sizeof(T)) % gran) {
+    throw Error("The requested allocation size is not divisible by the required granularity: " +
+                    std::to_string(nelem * sizeof(T)) + " vs " + std::to_string(gran),
+                ErrorCode::InvalidUsage);
+  }
+  T* ptr = nullptr;
+  try {
+    ptr = alloc(nelem, gran);
+  } catch (...) {
+    if (ptr) {
+      Deleter()(ptr);
+    }
+    throw;
+  }
+  return Memory(ptr, Deleter());
+}
+
 }  // namespace detail

 /// A deleter that calls cudaFree for use with std::unique_ptr or std::shared_ptr.
@@ -131,6 +203,17 @@ struct CudaDeleter {
   }
 };

+template <class T>
+struct CudaPhysicalDeleter {
+  static_assert(!std::is_array_v<T>, "T must not be an array");
+  void operator()(PhysicalCudaMemory<T>* ptr) {
+    AvoidCudaGraphCaptureGuard cgcGuard;
+    MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr->devicePtr_, ptr->size_));
+    MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr->devicePtr_, ptr->size_));
+    MSCCLPP_CUTHROW(cuMemRelease(ptr->memHandle_));
+  }
+};
+
 /// A deleter that calls cudaFreeHost for use with std::unique_ptr or std::shared_ptr.
 /// @tparam T Type of each element in the allocated memory.
 template <class T>
@@ -151,6 +234,18 @@ std::shared_ptr<T> allocSharedCuda(size_t count = 1) {
   return detail::safeAlloc<T, detail::cudaCalloc<T>, CudaDeleter<T>, std::shared_ptr<T>>(count);
 }

+/// Allocates physical memory on the device and returns a PhysicalCudaMemory object, which wraps the memory
+/// handle along with a device pointer for it. Deallocation happens only when the PhysicalCudaMemory goes
+/// out of scope.
+/// @tparam T Type of each element in the allocated memory.
+/// @param count Number of elements to allocate.
+/// @param gran The granularity of the allocation.
+/// @return A std::shared_ptr to the PhysicalCudaMemory allocated.
+template <class T>
+std::shared_ptr<PhysicalCudaMemory<T>> allocSharedPhysicalCuda(size_t count, size_t gran) {
+  return detail::safeAlloc<PhysicalCudaMemory<T>, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
+                           std::shared_ptr<PhysicalCudaMemory<T>>>(count, gran);
+}
+
 /// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
 /// @tparam T Type of each element in the allocated memory.
 /// @param count Number of elements to allocate.
@@ -174,6 +269,18 @@ UniqueCudaPtr<T> allocUniqueCuda(size_t count = 1) {
   return detail::safeAlloc<T, detail::cudaCalloc<T>, CudaDeleter<T>, UniqueCudaPtr<T>>(count);
 }

+/// Allocates physical memory on the device and returns a unique PhysicalCudaMemory object, which wraps the
+/// memory handle along with a device pointer for it. The memory is zeroed out.
+/// @tparam T Type of each element in the allocated memory.
+/// @param count Number of elements to allocate.
+/// @param gran The granularity of the allocation.
+/// @return A std::unique_ptr to the PhysicalCudaMemory allocated.
+template <class T>
+std::unique_ptr<PhysicalCudaMemory<T>> allocUniquePhysicalCuda(size_t count, size_t gran) {
+  return detail::safeAlloc<PhysicalCudaMemory<T>, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
+                           std::unique_ptr<PhysicalCudaMemory<T>, CudaDeleter<PhysicalCudaMemory<T>>>>(count, gran);
+}
+
 /// Allocates memory on the device and returns a std::unique_ptr to it. The memory is zeroed out.
 /// @tparam T Type of each element in the allocated memory.
 /// @param count Number of elements to allocate.
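A hedged usage sketch of the low-level allocator above (illustrative names; `gran` would normally come from `NvlsConnection::getMultiCastMinGranularity()`). The byte count passed in must be divisible by the granularity, or the new `safeAlloc` overload throws `ErrorCode::InvalidUsage`:

#include <mscclpp/gpu_utils.hpp>

std::shared_ptr<mscclpp::PhysicalCudaMemory<char>> allocAligned(size_t bytes, size_t gran) {
  size_t aligned = ((bytes + gran - 1) / gran) * gran;  // round up to the granularity
  auto mem = mscclpp::allocSharedPhysicalCuda<char>(aligned, gran);
  // mem->devicePtr_ is the VA mapped onto the physical allocation;
  // mem->memHandle_ is what cuMulticastBindMem consumes in NvlsConnection::bindMemory.
  return mem;
}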
diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp
new file mode 100644
index 000000000..52ade275d
--- /dev/null
+++ b/include/mscclpp/nvls_device.hpp
@@ -0,0 +1,62 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#ifndef MSCCLPP_NVLS_DEVICE_HPP_
+#define MSCCLPP_NVLS_DEVICE_HPP_
+
+#include
+#include
+
+#include "device.hpp"
+
+namespace mscclpp {
+
+template <class>
+constexpr bool dependentFalse = false;  // workaround before CWG2518/P2593R1
+
+/// Device-side handle for a multicast (NVLS) pointer.
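To make the PTX in the handle below concrete, here is a minimal sketch of a device-side multicast pass (illustrative; requires sm_90 and CUDA 12.1+, and mirrors the `allreduce6`/`nvls_test` kernels further down). Each `multimem.ld_reduce` returns the element-wise sum of the address across every device bound to the multicast object, and `multimem.st` writes the result back to all of them, so a single pass both reduces and broadcasts:

#include <mscclpp/nvls_device.hpp>

__global__ void reduceBroadcast(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, size_t nelem) {
  float* mcPtr = reinterpret_cast<float*>(nvlsPtrs.mcPtr);
  size_t offset = ((size_t)blockIdx.x * blockDim.x + threadIdx.x) * 4;  // 4 floats = 16B per thread
  size_t step = (size_t)blockDim.x * gridDim.x * 4;
  for (size_t idx = offset; idx + 4 <= nelem; idx += step) {
    uint4 val;
    nvlsPtrs.multimemLoad(val, mcPtr + idx);   // sum across all bound devices
    nvlsPtrs.multimemStore(val, mcPtr + idx);  // broadcast the sum back
  }
}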
+struct DeviceMulticastPointerDeviceHandle {
+  void* devicePtr;
+  void* mcPtr;
+  size_t bufferSize;
+
+#if defined(MSCCLPP_DEVICE_COMPILE)
+  template <int NElemPerThread = 4, typename TValue = uint4, typename T = float>
+  MSCCLPP_DEVICE_INLINE void multimemLoad(TValue& val, T* ptr) {
+    static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
+    if constexpr (std::is_same<T, float>::value) {
+      asm("multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
+          : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
+          : "l"(ptr)
+          : "memory");
+    } else if constexpr (std::is_same<T, __half2>::value) {
+      asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
+          : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
+          : "l"(ptr)
+          : "memory");
+    } else {
+      static_assert(dependentFalse<T>, "Not supported type");
+    }
+  };
+
+  template <int NElemPerThread = 4, typename TValue = uint4, typename T = float>
+  MSCCLPP_DEVICE_INLINE void multimemStore(const TValue& val, T* ptr) {
+    static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
+    if constexpr (std::is_same<T, float>::value) {
+      asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z),
+                   "r"(val.w)
+                   : "memory");
+    } else if constexpr (std::is_same<T, __half2>::value) {
+      asm volatile("multimem.st.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z),
+                   "r"(val.w)
+                   : "memory");
+    } else {
+      static_assert(dependentFalse<T>, "Not supported type");
+    }
+  };
+#endif
+};
+
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_NVLS_DEVICE_HPP_
diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp
index c8ef3d271..80b3bf39d 100644
--- a/include/mscclpp/utils.hpp
+++ b/include/mscclpp/utils.hpp
@@ -37,6 +37,8 @@ struct ScopedTimer : public Timer {

 std::string getHostName(int maxlen, const char delim);

+bool isNvlsSupported();
+
 }  // namespace mscclpp

 #endif  // MSCCLPP_UTILS_HPP_
diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py
index d411bc1b0..8f013e080 100644
--- a/python/mscclpp/__init__.py
+++ b/python/mscclpp/__init__.py
@@ -6,6 +6,7 @@
 from ._mscclpp import (
     Communicator,
     Connection,
+    EndpointConfig,
     Fifo,
     Host2DeviceSemaphore,
     Host2HostSemaphore,
@@ -19,6 +20,7 @@
     Transport,
     TransportFlags,
     version,
+    is_nvls_supported,
 )

 __version__ = version()
diff --git a/python/mscclpp/comm.py b/python/mscclpp/comm.py
index c01c04a2d..d84410668 100644
--- a/python/mscclpp/comm.py
+++ b/python/mscclpp/comm.py
@@ -8,6 +8,7 @@
 from ._mscclpp import (
     Communicator,
     Connection,
+    EndpointConfig,
     Host2DeviceSemaphore,
     Host2HostSemaphore,
     ProxyService,
@@ -79,15 +80,21 @@ def my_ib_device(self, local_rank: int) -> Transport:
         assert False  # only 8 IBs are supported

     def make_connection(
-        self, remote_ranks: list[int], transports: Transport | dict[int, Transport]
+        self,
+        all_ranks: list[int],
+        endpoints: EndpointConfig | Transport | dict[int, EndpointConfig] | dict[int, Transport],
     ) -> dict[int, Connection]:
+        if type(endpoints) is Transport:
+            endpoints = EndpointConfig(endpoints)
+        if type(endpoints) is EndpointConfig and endpoints.transport == Transport.Nvls:
+            return self.communicator.connect_nvls_collective(all_ranks, endpoints)
         connections = {}
-        for rank in remote_ranks:
-            if type(transports) is dict:
-                transport = transports[rank]
+        for rank in all_ranks:
+            if type(endpoints) is dict:
+                endpoint = endpoints[rank]
             else:
-                transport = transports
-            connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
+                endpoint = endpoints
+            connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
         self.communicator.setup()
         connections = {rank: connections[rank].get() for rank in connections}
         return connections
diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp
index 4e92f8841..5fd4bd317 100644
--- a/python/mscclpp/core_py.cpp
+++ b/python/mscclpp/core_py.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 #include
@@ -72,6 +73,7 @@ void register_core(nb::module_& m) {
   nb::enum_<Transport>(m, "Transport")
       .value("Unknown", Transport::Unknown)
       .value("CudaIpc", Transport::CudaIpc)
+      .value("Nvls", Transport::Nvls)
       .value("IB0", Transport::IB0)
       .value("IB1", Transport::IB1)
       .value("IB2", Transport::IB2)
@@ -124,6 +126,24 @@ void register_core(nb::module_& m) {
       .def("transport", &Connection::transport)
       .def("remote_transport", &Connection::remoteTransport);

+  nb::class_<NvlsConnection::DeviceMulticastPointer>(m, "DeviceMulticastPointer")
+      .def("get_device_ptr",
+           [](NvlsConnection::DeviceMulticastPointer* self) { return (uintptr_t)self->getDevicePtr(); })
+      .def("device_handle", &NvlsConnection::DeviceMulticastPointer::deviceHandle);
+
+  nb::class_<NvlsConnection::DeviceMulticastPointer::DeviceHandle>(m, "DeviceHandle")
+      .def(nb::init<>())
+      .def_rw("devicePtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::devicePtr)
+      .def_rw("mcPtr", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::mcPtr)
+      .def_rw("size", &NvlsConnection::DeviceMulticastPointer::DeviceHandle::bufferSize)
+      .def_prop_ro("raw", [](const NvlsConnection::DeviceMulticastPointer::DeviceHandle& self) -> nb::bytes {
+        return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
+      });
+
+  nb::class_<NvlsConnection>(m, "NvlsConnection")
+      .def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda)
+      .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
+
   nb::class_<Endpoint>(m, "Endpoint")
       .def("transport", &Endpoint::transport)
       .def("serialize", &Endpoint::serialize)
@@ -132,6 +152,7 @@ void register_core(nb::module_& m) {
   nb::class_<EndpointConfig>(m, "EndpointConfig")
       .def(nb::init<>())
       .def(nb::init_implicit<Transport>(), nb::arg("transport"))
+      .def(nb::init<Transport, size_t>(), nb::arg("transport"), nb::arg("nvlsBufferSize"))
       .def_rw("transport", &EndpointConfig::transport)
       .def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
       .def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
@@ -168,6 +189,7 @@ void register_core(nb::module_& m) {
       .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
       .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
            nb::arg("localConfig"))
+      .def("connect_nvls_collective", &Communicator::connectNvlsCollective, nb::arg("allRanks"), nb::arg("config"))
       .def("remote_rank_of", &Communicator::remoteRankOf)
       .def("tag_of", &Communicator::tagOf)
       .def("setup", &Communicator::setup);
diff --git a/python/mscclpp/utils_py.cpp b/python/mscclpp/utils_py.cpp
index 16800a752..e9e847ee8 100644
--- a/python/mscclpp/utils_py.cpp
+++ b/python/mscclpp/utils_py.cpp
@@ -20,4 +20,5 @@ void register_utils(nb::module_& m) {
   nb::class_<ScopedTimer>(m, "ScopedTimer").def(nb::init<const std::string&>(), nb::arg("name"));

   m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim"));
+  m.def("is_nvls_supported", &isNvlsSupported);
 }
diff --git a/python/mscclpp_benchmark/allreduce.cu b/python/mscclpp_benchmark/allreduce.cu
index e86047283..69aa3919f 100644
--- a/python/mscclpp_benchmark/allreduce.cu
+++ b/python/mscclpp_benchmark/allreduce.cu
@@ -4,6 +4,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -775,3 +776,57 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
     globalFlag += 1;
   }
 }
+
+// -------------------------------------------
+// AllReduce6
+// NVLS
+// -------------------------------------------
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
+               mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
+               size_t nelem) {
+  float* dev_ptr = (float*)nvlsPtrs.devicePtr;
+  float* mc_ptr = (float*)nvlsPtrs.mcPtr;
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+
+  if (tid == 0 && bid == 0) {
+    __threadfence_system();
+  }
+  if (bid == 0) {
+    if (tid < nranks - 1) {
+      semaphores[tid].signal();
+      semaphores[tid].wait();
+    }
+  }
+  deviceSyncer.sync(gridDim.x);
+
+  int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
+  int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
+
+  int my_offset = (tid + bid * blockDim.x) * 4;
+  int my_step = blockDim.x * gridDim.x * 4;
+
+  for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
+    uint4 val;
+    nvlsPtrs.multimemLoad(val, mc_ptr + idx);
+    nvlsPtrs.multimemStore(val, mc_ptr + idx);
+  }
+
+  deviceSyncer.sync(gridDim.x);
+  if (tid == 0 && bid == 0) {
+    __threadfence_system();
+  }
+
+  if (bid == 0) {
+    if (tid < nranks - 1) {
+      semaphores[tid].signal();
+      semaphores[tid].wait();
+    }
+  }
+  deviceSyncer.sync(gridDim.x);
+}
+#endif
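In `allreduce6`, rank r only touches the slice [nelem*r/nranks, nelem*(r+1)/nranks): since the multimem load already sums the element across all ranks and the multimem store broadcasts it back, each element needs exactly one pass by exactly one rank. A quick host-side sanity check of that partition (illustrative only, assumed names):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t nelem = 1 << 20;
  const int nranks = 8;
  int64_t covered = 0;
  for (int r = 0; r < nranks; r++) {
    int64_t st = nelem * r / nranks;        // my_st for rank r
    int64_t en = nelem * (r + 1) / nranks;  // my_en for rank r
    covered += en - st;
  }
  assert(covered == nelem);  // the slices tile the buffer with no gaps or overlap
  return 0;
}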
diff --git a/python/mscclpp_benchmark/allreduce_bench.py b/python/mscclpp_benchmark/allreduce_bench.py
index 80aa5e93a..47d622c2e 100644
--- a/python/mscclpp_benchmark/allreduce_bench.py
+++ b/python/mscclpp_benchmark/allreduce_bench.py
@@ -2,12 +2,19 @@
 # Licensed under the MIT license.

 import cupy as cp
-from mscclpp_op import MscclppAllReduce1, MscclppAllReduce2, MscclppAllReduce3, MscclppAllReduce4, MscclppAllReduce5
+from mscclpp_op import (
+    MscclppAllReduce1,
+    MscclppAllReduce2,
+    MscclppAllReduce3,
+    MscclppAllReduce4,
+    MscclppAllReduce5,
+    MscclppAllReduce6,
+)
 from nccl_op import NcclAllReduce
 from mpi4py import MPI
 import cupy.cuda.nccl as nccl
 import mscclpp.comm as mscclpp_comm
-from mscclpp import ProxyService
+from mscclpp import ProxyService, is_nvls_supported
 from prettytable import PrettyTable
 import netifaces as ni

@@ -121,6 +128,21 @@ def bench_time(niter: int, func):
     return cp.cuda.get_elapsed_time(start, end) / niter * 1000.0


+def find_best_algo(mscclpp_algos, niter):
+    assert len(mscclpp_algos) > 0
+    best_time = 10000000.0
+    best_algo = None
+    for algo in mscclpp_algos:
+        config, cur_time = find_best_config(algo, niter)
+        if cur_time < best_time:
+            best_time = cur_time
+            best_algo = algo
+            algo.set_params(*config)
+    if MPI.COMM_WORLD.rank == 0:
+        print(best_algo, end="", flush=True)
+    return best_algo
+
+
 def find_best_config(mscclpp_call, niter):
     best_time = 10000000.0
     for config in mscclpp_call.auto_tune():
@@ -133,7 +155,7 @@ def find_best_config(mscclpp_call, niter):
     best_config = MPI.COMM_WORLD.bcast(best_config, root=0)
     if MPI.COMM_WORLD.rank == 0:
         print(best_config, end="", flush=True)
-    return best_config
+    return best_config, best_time


 def run_benchmark(
@@ -145,26 +167,27 @@
-    proxy_service = None
+    proxy_service = ProxyService()
     if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
         if memory.nbytes < 2**20:
-            mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
-        elif memory.nbytes < 2**29:
-            mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
+            mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
         else:
-            proxy_service = ProxyService()
-            mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
-            proxy_service.start_proxy()
+            mscclpp_algos = [
+                MscclppAllReduce1(mscclpp_group, memory),
+                MscclppAllReduce3(mscclpp_group, memory, proxy_service),
+            ]
+            if is_nvls_supported():
+                mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
     else:
         if memory.nbytes < 2**22:
-            proxy_service = ProxyService()
-            mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
-            proxy_service.start_proxy()
+            mscclpp_algos = [MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)]
         else:
-            proxy_service = ProxyService()
-            mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
-            proxy_service.start_proxy()
+            mscclpp_algos = [MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)]

-    best_config = find_best_config(mscclpp_call, 20)
-    mscclpp_call.set_params(*best_config)
+    proxy_service.start_proxy()
+    MPI.COMM_WORLD.barrier()
+    mscclpp_call = find_best_algo(mscclpp_algos, 20)
+    if isinstance(mscclpp_call, MscclppAllReduce6):
+        memory = mscclpp_call.get_memory()

     nccl_call = NcclAllReduce(nccl_op, memory)

@@ -177,13 +200,8 @@
     nccl_algBw = memory_nbytes / nccl_time / 1e3
     nccl_check = "PASS" if check_correctness(memory, nccl_call) else "FAIL"

-    if (
-        isinstance(mscclpp_call, MscclppAllReduce3)
-        or isinstance(mscclpp_call, MscclppAllReduce5)
-        or isinstance(mscclpp_call, MscclppAllReduce4)
-    ):
-        MPI.COMM_WORLD.barrier()
-        proxy_service.stop_proxy()
+    MPI.COMM_WORLD.barrier()
+    proxy_service.stop_proxy()

     speed_up = nccl_time / mscclpp_time
     if MPI.COMM_WORLD.rank == 0:
@@ -247,7 +265,8 @@
     mscclpp_algbw = []
     nccl_algbw = []
     speed_ups = []
-    for i in range(10, 29):
+    end_range = 28 if is_nvls_supported() else 29
+    for i in range(10, end_range):
         if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
             nelems = 2**i
         elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
diff --git a/python/mscclpp_benchmark/mscclpp_op.py b/python/mscclpp_benchmark/mscclpp_op.py
index ab51f7c84..f36cc5f61 100644
--- a/python/mscclpp_benchmark/mscclpp_op.py
+++ b/python/mscclpp_benchmark/mscclpp_op.py
@@ -1,7 +1,7 @@
 import os
 import cupy as cp
 import ctypes
-from mscclpp import Transport, ProxyService
+from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore
 import mscclpp.comm as mscclpp_comm
 from mscclpp.utils import KernelBuilder, pack

@@ -418,3 +418,82 @@ def auto_tune(self):
         for block_size in block_size_to_try:
             self.set_params(nblocks, block_size)
             yield nblocks, block_size
+
+
+class MscclppAllReduce6:
+    def __init__(
+        self,
+        group: mscclpp_comm.CommGroup,
+        nelem: int,
+        memory_dtype: cp.dtype,
+        block_size: int = 1024,
+        nblocks: int = 32,
+    ):
+        self.group = group
+        datatype_size = memory_dtype().itemsize
+        buffer_size = nelem * datatype_size
+        type_str = type_to_str(memory_dtype)
+        all_ranks = list(range(group.nranks))
+        remote_nghrs = all_ranks.copy()
+        remote_nghrs.remove(self.group.my_rank)
+
+        self.group.barrier()
+        # create a connection for each remote neighbor
+        self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
+        self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls)
+        min_gran = self.nvls_connection.get_multicast_min_granularity()
+        aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran)
+        self.nvls_mem_handle = self.nvls_connection.allocate_bind_memory(
+            aligned_buffer_size
+        )  # rounded up to the minimum multicast granularity
+        self.memory_ptr = self.nvls_mem_handle.get_device_ptr()
+        self.cp_memory_ptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, None), 0)
+        self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr)
+
+        # create a semaphore for each remote neighbor
+        self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore)
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        self.kernel = KernelBuilder(
+            file="allreduce.cu",
+            kernel_name="allreduce6",
+            file_dir=file_dir,
+            macro_dict={"TYPE": type_str},
+        ).get_compiled_kernel()
+        self.device_handles = []
+        for rank in range(self.group.nranks):
+            if rank != self.group.my_rank:
+                self.device_handles.append(self.semaphores[rank].device_handle().raw)
+
+        self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
+        self.nvls_handle = self.nvls_mem_handle.device_handle().raw
+
+        self.set_params(nblocks, block_size)
+
+    def get_memory(self):
+        return self.memory
+
+    def __call__(self, stream_ptr):
+        self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
+        return self.memory
+
+    def set_params(self, nblocks, block_size):
+        self.nblocks = nblocks
+        self.block_size = block_size
+        self.params = b""
+        self.params += pack(
+            self.device_handles_cp,
+            self.nvls_handle,
+            self.memory,
+            self.group.my_rank,
+            self.group.nranks,
+            ctypes.c_size_t(self.memory.size),
+        )
+
+    def auto_tune(self):
+        nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
+        block_size_to_try = [256, 512, 1024]
+        for nblocks in nblocks_to_try:
+            for block_size in block_size_to_try:
+                self.set_params(nblocks, block_size)
+                yield nblocks, block_size
diff --git a/python/test/nvls_test.cu b/python/test/nvls_test.cu
new file mode 100644
index 000000000..022b4d6ca
--- /dev/null
+++ b/python/test/nvls_test.cu
@@ -0,0 +1,66 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include
+#include
+#include
+#include
+
+__device__ mscclpp::DeviceSyncer deviceSyncer;
+
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    nvls_test(mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs,
+              mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
+  int nelem = nbytes / sizeof(float);
+  float* dev_ptr = (float*)nvlsPtrs.devicePtr;
+  float* mc_ptr = (float*)nvlsPtrs.mcPtr;
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+
+  for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
+    dev_ptr[idx] = my_rank;
+  }
+  deviceSyncer.sync(gridDim.x);
+  if (tid == 0 && bid == 0) {
+    __threadfence_system();
+  }
+
+  if (bid == 0) {
+    if (tid < nranks - 1) {
+      semaphores[tid].signal();
+      semaphores[tid].wait();
+    }
+  }
+  deviceSyncer.sync(gridDim.x);
+
+  int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
+  int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
+
+  int my_offset = (tid + bid * blockDim.x) * 4;
+  int my_step = blockDim.x * gridDim.x * 4;
+
+  for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
+    uint4 val;
+    nvlsPtrs.multimemLoad(val, mc_ptr + idx);
+    nvlsPtrs.multimemStore(val, mc_ptr + idx);
+  }
+
+  deviceSyncer.sync(gridDim.x);
+  if (tid == 0 && bid == 0) {
+    __threadfence_system();
+  }
+
+  if (bid == 0) {
+    if (tid < nranks - 1) {
+      semaphores[tid].signal();
+      semaphores[tid].wait();
+    }
+  }
+  deviceSyncer.sync(gridDim.x);
+
+  for (int idx = bid * blockDim.x + tid; idx < nelem; idx += blockDim.x * gridDim.x) {
+    if (dev_ptr[idx] != ((nranks * (nranks - 1)) / 2)) {
+      __assert_fail("dev_ptr[idx] != (nranks * (nranks - 1)) / 2", __FILE__, __LINE__, __PRETTY_FUNCTION__);
+    }
+  }
+}
diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py
index f3a7f9dd6..4b3cb6ebf 100644
--- a/python/test/test_mscclpp.py
+++ b/python/test/test_mscclpp.py
@@ -12,6 +12,7 @@
 import pytest

 from mscclpp import (
+    EndpointConfig,
     Fifo,
     Host2DeviceSemaphore,
     Host2HostSemaphore,
@@ -19,6 +20,7 @@
     SmDevice2DeviceSemaphore,
     TcpBootstrap,
     Transport,
+    is_nvls_supported,
 )
 import mscclpp.comm as mscclpp_comm
 from mscclpp.utils import KernelBuilder, pack
@@ -117,13 +119,15 @@ def init_target():
     mpi_group.comm.barrier()


-def create_and_connect(mpi_group: MpiGroup, transport: str):
-    if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
-        pytest.skip("cannot use nvlink for cross node")
-    group = mscclpp_comm.CommGroup(mpi_group.comm)
+def create_connection(group: mscclpp_comm.CommGroup, transport: str):
+    if transport == "NVLS":
+        all_ranks = list(range(group.nranks))
+        tran = Transport.Nvls
+        connection = group.make_connection(all_ranks, tran)
+        return connection

-    remote_nghrs = list(range(mpi_group.comm.size))
-    remote_nghrs.remove(mpi_group.comm.rank)
+    remote_nghrs = list(range(group.nranks))
+    remote_nghrs.remove(group.my_rank)
     if transport == "NVLink":
         tran = Transport.CudaIpc
     elif transport == "IB":
@@ -131,20 +135,28 @@
     else:
         assert False
     connections = group.make_connection(remote_nghrs, tran)
-    return group, connections
+    return connections
+
+
+def create_group_and_connection(mpi_group: MpiGroup, transport: str):
+    if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
+        pytest.skip("cannot use nvlink/nvls for cross node")
+    group = mscclpp_comm.CommGroup(mpi_group.comm)
+    connection = create_connection(group, transport)
+    return group, connection


 @parametrize_mpi_groups(2, 4, 8, 16)
 @pytest.mark.parametrize("transport", ["IB", "NVLink"])
 def test_group_with_connections(mpi_group: MpiGroup, transport: str):
-    create_and_connect(mpi_group, transport)
+    create_group_and_connection(mpi_group, transport)


 @parametrize_mpi_groups(2, 4, 8, 16)
 @pytest.mark.parametrize("transport", ["IB", "NVLink"])
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
 def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
-    group, connections = create_and_connect(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, transport)
     memory = cp.zeros(nelem, dtype=cp.int32)
     nelemPerRank = nelem // group.nranks
     sizePerRank = nelemPerRank * memory.itemsize
@@ -185,7 +197,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
     if device == "cpu" and transport == "NVLink":
         pytest.skip("nvlink doesn't work with host allocated memory")

-    group, connections = create_and_connect(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, transport)
     xp = cp if device == "cuda" else np
     if group.my_rank == 0:
         memory = xp.random.randn(nelem)
@@ -229,7 +241,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,

 @parametrize_mpi_groups(2, 4, 8, 16)
 def test_h2h_semaphores(mpi_group: MpiGroup):
-    group, connections = create_and_connect(mpi_group, "IB")
+    group, connections = create_group_and_connection(mpi_group, "IB")

     semaphores = group.make_semaphore(connections, Host2HostSemaphore)
     for rank in connections:
@@ -242,7 +254,7 @@

 @parametrize_mpi_groups(2, 4, 8, 16)
 def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
-    group, connections = create_and_connect(mpi_group, "IB")
+    group, connections = create_group_and_connection(mpi_group, "IB")

     semaphores = group.make_semaphore(connections, Host2HostSemaphore)

@@ -267,6 +279,24 @@ def target_signal(sems, conns):
     group.barrier()


+@parametrize_mpi_groups(8)
+@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported")
+def test_nvls_connection(mpi_group: MpiGroup):
+    if all_ranks_on_the_same_node(mpi_group) is False:
+        pytest.skip("cannot use nvls for cross node")
+    group = mscclpp_comm.CommGroup(mpi_group.comm)
+    all_ranks = list(range(group.nranks))
+    endpoint = EndpointConfig(Transport.Nvls, 2**22)
+    nvls_connection = group.make_connection(all_ranks, endpoint)
+    mem_handle1 = nvls_connection.allocate_bind_memory(2**21)
+    mem_handle2 = nvls_connection.allocate_bind_memory(2**21)
+    with pytest.raises(Exception):
+        mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
+    # the memory is freed by the destructor of mem_handle2
+    mem_handle2 = None
+    mem_handle3 = nvls_connection.allocate_bind_memory(2**21)
+
+
 class MscclppKernel:
     def __init__(
         self,
@@ -278,6 +308,8 @@ def __init__(
         use_packet=False,
         scratch=None,
         fifo=None,
+        nvls_mem_handle=None,
+        nvls_buffer_size=None,
     ):
         file_dir = os.path.dirname(os.path.abspath(__file__))
         if test_name == "h2d_semaphore":
@@ -316,11 +348,17 @@ def __init__(
             ).get_compiled_kernel()
             self.nblocks = 1
             self.nthreads = 1024
+        elif test_name == "nvls":
+            self._kernel = KernelBuilder(
+                file="nvls_test.cu", kernel_name="nvls_test", file_dir=file_dir
+            ).get_compiled_kernel()
+            self.nblocks = 64
+            self.nthreads = 1024
         else:
             assert False

         self.params = b""
-        if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
+        if semaphore_or_channels is not None:
             first_arg = next(iter(semaphore_or_channels.values()))
             size_of_semaphore_or_channels = len(first_arg.device_handle().raw)
             device_handles = []
@@ -333,6 +371,8 @@ def __init__(
                 device_handles.append(semaphore_or_channels[rank].device_handle().raw)
             # keep a reference to the device handles so that they don't get garbage collected
             self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
+
+        if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
             self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
             if test_name == "sm_channel":
                 self.params += pack(tensor.size, use_packet)
@@ -341,9 +381,13 @@ def __init__(
         elif test_name == "fifo":
             self.params = fifo.device_handle().raw
         elif test_name == "proxy":
-            semaphore_device_handles = [semaphore.device_handle().raw for semaphore in semaphore_or_channels]
-            self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(semaphore_device_handles)), dtype=cp.uint8)
             self.params = pack(my_rank, nranks) + fifo.raw + pack(self._d_semaphore_or_channels)
+        elif test_name == "nvls":
+            self.params = (
+                nvls_mem_handle.device_handle().raw
+                + pack(self._d_semaphore_or_channels)
+                + pack(my_rank, nranks, nvls_buffer_size)
+            )

     def __call__(self):
         return self._kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, None)
@@ -356,7 +400,7 @@ def signal(semaphores):
         for rank in semaphores:
             semaphores[rank].signal()

-    group, connections = create_and_connect(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, transport)

     semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
     kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
@@ -372,7 +416,7 @@ def signal(semaphores):

 @parametrize_mpi_groups(2, 4, 8, 16)
 def test_d2d_semaphores(mpi_group: MpiGroup):
-    group, connections = create_and_connect(mpi_group, "NVLink")
+    group, connections = create_group_and_connection(mpi_group, "NVLink")

     semaphores = group.make_semaphore(connections, SmDevice2DeviceSemaphore)
     group.barrier()
@@ -386,7 +430,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
 @pytest.mark.parametrize("use_packet", [False, True])
 def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
-    group, connections = create_and_connect(mpi_group, "NVLink")
+    group, connections = create_group_and_connection(mpi_group, "NVLink")

     memory = cp.zeros(nelem, dtype=cp.int32)
     if use_packet:
@@ -434,7 +478,7 @@ def test_fifo(
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
 @pytest.mark.parametrize("transport", ["IB", "NVLink"])
 def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
-    group, connections = create_and_connect(mpi_group, transport)
+    group, connections = create_group_and_connection(mpi_group, transport)

     memory = cp.zeros(nelem, dtype=cp.int32)
     nelemPerRank = nelem // group.nranks
@@ -468,7 +512,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
     fifo_device_handle = proxy.fifo_device_handle()

     kernel = MscclppKernel(
-        "proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=list_sem, fifo=fifo_device_handle
+        "proxy", my_rank=group.my_rank, nranks=group.nranks, semaphore_or_channels=semaphores, fifo=fifo_device_handle
     )
     proxy.start()
     group.barrier()
@@ -484,7 +528,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
 @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
 @pytest.mark.parametrize("transport", ["NVLink", "IB"])
"IB"]) @pytest.mark.parametrize("use_packet", [False, True]) def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool): - group, connections = create_and_connect(mpi_group, transport) + group, connections = create_group_and_connection(mpi_group, transport) memory = cp.zeros(nelem, dtype=cp.int32) if use_packet: @@ -522,3 +566,27 @@ def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, u proxy_service.stop_proxy() group.barrier() assert cp.array_equal(memory, memory_expected) + + +@parametrize_mpi_groups(4, 8) +@pytest.mark.skipif(is_nvls_supported() is False, reason="NVLS is not supported") +def test_nvls(mpi_group: MpiGroup): + group, nvls_connection = create_group_and_connection(mpi_group, "NVLS") + nbytes = 2**21 + mem_handle = nvls_connection.allocate_bind_memory(nbytes) + + nvlinks_connections = create_connection(group, "NVLink") + semaphores = group.make_semaphore(nvlinks_connections, SmDevice2DeviceSemaphore) + + kernel = MscclppKernel( + "nvls", + my_rank=group.my_rank, + nranks=group.nranks, + nvls_mem_handle=mem_handle, + nvls_buffer_size=nbytes, + semaphore_or_channels=semaphores, + ) + + kernel() + cp.cuda.runtime.deviceSynchronize() + group.barrier() diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 649a1f62e..00a58b992 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -35,6 +35,20 @@ struct ExtInfo { SocketAddress extAddressListen; }; +MSCCLPP_API_CPP void Bootstrap::groupBarrier(const std::vector& ranks) { + int dummy = 0; + for (auto rank : ranks) { + if (rank != this->getRank()) { + this->send(static_cast(&dummy), sizeof(dummy), rank, 0); + } + } + for (auto rank : ranks) { + if (rank != this->getRank()) { + this->recv(static_cast(&dummy), sizeof(dummy), rank, 0); + } + } +} + MSCCLPP_API_CPP void Bootstrap::send(const std::vector& data, int peer, int tag) { size_t size = data.size(); send((void*)&size, sizeof(size_t), peer, tag); diff --git a/src/communicator.cc b/src/communicator.cc index d2f0e6172..d0fb07a23 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -105,6 +105,45 @@ MSCCLPP_API_CPP NonblockingFuture> Communicator::con return NonblockingFuture>(connector->connectionPromise_.get_future()); } +MSCCLPP_API_CPP std::shared_ptr Communicator::connctNvlsCollective(std::vector allRanks, + EndpointConfig config) { + auto bootstrap = this->bootstrap(); + int rank = bootstrap->getRank(); + bool isRoot = false; + bool amongAllRanks = false; + int rootRank = allRanks[0]; + for (auto nvlsRank : allRanks) { + if (nvlsRank == rank) amongAllRanks = true; + rootRank = std::min(rootRank, nvlsRank); + } + if (amongAllRanks == false) { + throw Error("rank is not among allRanks", ErrorCode::InvalidUsage); + } + if (rootRank == rank) isRoot = true; + + std::shared_ptr conn; + if (isRoot) { + conn = std::make_shared(config.nvlsBufferSize, allRanks.size()); + auto serialized = conn->serialize(); + for (auto nvlsRank : allRanks) { + if (nvlsRank != rank) bootstrap->send(serialized, nvlsRank, 0); + } + } else { + std::vector data; + bootstrap->recv(data, rootRank, 0); + conn = std::make_shared(data); + } + + // Now let's synchronize all ranks + bootstrap->groupBarrier(allRanks); + // now it is safe to add my device + conn->addDevice(); + + // sync here to make sure all ranks have added their devices + bootstrap->groupBarrier(allRanks); + return conn; +} + MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) { return 
diff --git a/src/connection.cc b/src/connection.cc
index 834a1456c..4d719f3b2 100644
--- a/src/connection.cc
+++ b/src/connection.cc
@@ -10,7 +10,6 @@
 #include "endpoint.hpp"
 #include "infiniband/verbs.h"
 #include "npkit/npkit.h"
-#include "registered_memory.hpp"

 namespace mscclpp {

diff --git a/src/include/context.hpp b/src/include/context.hpp
index 39a699560..abb95b27d 100644
--- a/src/include/context.hpp
+++ b/src/include/context.hpp
@@ -17,6 +17,7 @@ struct Context::Impl {
   std::vector<std::shared_ptr<Connection>> connections_;
   std::unordered_map<std::string, std::shared_ptr<IbCtx>> ibContexts_;
   CudaStreamWithFlags ipcStream_;
+  CUmemGenericAllocationHandle mcHandle_;

   Impl();

diff --git a/src/nvls_connection.cc b/src/nvls_connection.cc
new file mode 100644
index 000000000..78f3e52d5
--- /dev/null
+++ b/src/nvls_connection.cc
@@ -0,0 +1,272 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "debug.h"
+#include "endpoint.hpp"
+
+namespace mscclpp {
+
+#if (USE_NVLS)
+class NvlsConnection::Impl : public std::enable_shared_from_this<Impl> {
+ public:
+  // use this only for the root of the NVLS
+  Impl(size_t bufferSize, int numDevices);
+  Impl(const std::vector<char>& data);
+  ~Impl();
+
+  Impl(const Impl&) = delete;
+  Impl& operator=(const Impl&) = delete;
+
+  size_t getMinMcGran() { return minMcGran_; }
+  std::vector<char> serialize();
+  void addDevice(int cudaDeviceId);
+  size_t allocateBuffer(size_t size);
+  void freeBuffer(size_t offset, size_t size) noexcept;
+  std::shared_ptr<char> bindMemory(std::shared_ptr<PhysicalCudaMemory<char>> physicalMem, size_t devBuffSize);
+
+ private:
+  friend class NvlsConnection;
+  CUmemGenericAllocationHandle mcHandle_;
+  CUmulticastObjectProp mcProp_;
+  size_t bufferSize_;
+  size_t minMcGran_;
+  size_t mcGran_;
+  // These are only defined for multicast (NVLS) capability
+  pid_t rootPid_;
+  int mcFileDesc_;
+
+  std::list<std::pair<size_t, size_t>> allocatedRanges_;
+  std::list<std::pair<size_t, size_t>> freeRanges_;
+};
+
+NvlsConnection::Impl::Impl(size_t bufferSize, int numDevices) {
+  minMcGran_ = 0;
+  mcGran_ = 0;
+  mcProp_ = {};
+  mcProp_.size = bufferSize;
+  mcProp_.numDevices = numDevices;
+  mcProp_.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+  MSCCLPP_CUTHROW(cuMulticastGetGranularity(&minMcGran_, &mcProp_, CU_MULTICAST_GRANULARITY_MINIMUM));
+  MSCCLPP_CUTHROW(cuMulticastGetGranularity(&mcGran_, &mcProp_, CU_MULTICAST_GRANULARITY_RECOMMENDED));
+  mcProp_.size = ((mcProp_.size + minMcGran_ - 1) / minMcGran_) * minMcGran_;
+  bufferSize_ = mcProp_.size;
+  MSCCLPP_CUTHROW(cuMulticastCreate(&mcHandle_, &mcProp_));
+  mcFileDesc_ = 0;
+  MSCCLPP_CUTHROW(
+      cuMemExportToShareableHandle(&mcFileDesc_, mcHandle_, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
+  freeRanges_.emplace_back(0, bufferSize_);
+
+  rootPid_ = getpid();
+  if (rootPid_ < 0) {
+    throw mscclpp::SysError("getpid() failed", errno);
+  }
+
+  INFO(MSCCLPP_COLL, "NVLS handle created on root with size %ld, minGranularity %ld and recommendedGranularity %ld",
+       mcProp_.size, minMcGran_, mcGran_);
+}
+
+NvlsConnection::Impl::Impl(const std::vector<char>& data) {
+  auto it = data.begin();
+  std::copy_n(it, sizeof(this->mcHandle_), reinterpret_cast<char*>(&this->mcHandle_));
+  it += sizeof(this->mcHandle_);
+  std::copy_n(it, sizeof(this->bufferSize_), reinterpret_cast<char*>(&this->bufferSize_));
+  it += sizeof(this->bufferSize_);
+  std::copy_n(it, sizeof(this->minMcGran_), reinterpret_cast<char*>(&this->minMcGran_));
+  it += sizeof(this->minMcGran_);
+  std::copy_n(it, sizeof(this->mcGran_), reinterpret_cast<char*>(&this->mcGran_));
+  it += sizeof(this->mcGran_);
+  std::copy_n(it, sizeof(this->rootPid_), reinterpret_cast<char*>(&this->rootPid_));
+  it += sizeof(this->rootPid_);
+  std::copy_n(it, sizeof(this->mcFileDesc_), reinterpret_cast<char*>(&this->mcFileDesc_));
+
+  freeRanges_.emplace_back(0, bufferSize_);
+  int rootPidFd = syscall(SYS_pidfd_open, rootPid_, 0);
+  if (rootPidFd < 0) {
+    throw mscclpp::SysError("pidfd_open() failed", errno);
+  }
+  int mcRootFileDescFd = syscall(SYS_pidfd_getfd, rootPidFd, mcFileDesc_, 0);
+  if (mcRootFileDescFd < 0) {
+    throw mscclpp::SysError("pidfd_getfd() failed", errno);
+  }
+  MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&mcHandle_, reinterpret_cast<void*>(mcRootFileDescFd),
+                                                 CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
+  close(rootPidFd);
+  close(mcRootFileDescFd);
+
+  INFO(MSCCLPP_COLL, "NVLS handle was imported from root");
+}
+
+NvlsConnection::Impl::~Impl() {
+  // we don't need to free the multicast handle object according to NCCL.
+  if (rootPid_ == getpid()) {
+    close(mcFileDesc_);
+  }
+}
+
+std::vector<char> NvlsConnection::Impl::serialize() {
+  std::vector<char> result;
+  std::copy_n(reinterpret_cast<char*>(&mcHandle_), sizeof(mcHandle_), std::back_inserter(result));
+  std::copy_n(reinterpret_cast<char*>(&bufferSize_), sizeof(bufferSize_), std::back_inserter(result));
+  std::copy_n(reinterpret_cast<char*>(&minMcGran_), sizeof(minMcGran_), std::back_inserter(result));
+  std::copy_n(reinterpret_cast<char*>(&mcGran_), sizeof(mcGran_), std::back_inserter(result));
+  std::copy_n(reinterpret_cast<char*>(&rootPid_), sizeof(rootPid_), std::back_inserter(result));
+  std::copy_n(reinterpret_cast<char*>(&mcFileDesc_), sizeof(mcFileDesc_), std::back_inserter(result));
+  return result;
+}
+
+void NvlsConnection::Impl::addDevice(int cudaDeviceId) {
+  MSCCLPP_CUTHROW(cuMulticastAddDevice(mcHandle_, cudaDeviceId));
+  INFO(MSCCLPP_COLL, "NVLS connection created");
+}
+
+size_t NvlsConnection::Impl::allocateBuffer(size_t size) {
+  if (freeRanges_.empty()) {
+    throw Error("This NVLS connection mapped more than it was supposed to", ErrorCode::InvalidUsage);
+  }
+  auto it = std::find_if(freeRanges_.begin(), freeRanges_.end(),
+                         [size](const std::pair<size_t, size_t>& range) { return range.second >= size; });
+  if (it != freeRanges_.end()) {
+    size_t offset = it->first;
+    size_t rangeSize = it->second;
+    if (rangeSize == size) {
+      freeRanges_.erase(it);
+    } else {
+      it->first += size;
+      it->second -= size;
+    }
+    allocatedRanges_.emplace_back(offset, size);
+    INFO(MSCCLPP_COLL, "NVLS connection allocated %ld bytes at offset %ld", size, offset);
+    return offset;
+  }
+  throw Error("This NVLS connection cannot map the requested devBuffSize", ErrorCode::InvalidUsage);
+}
+
+void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept {
+  auto it = std::find_if(
+      allocatedRanges_.begin(), allocatedRanges_.end(),
+      [offset, size](const std::pair<size_t, size_t>& range) { return range.first == offset && range.second == size; });
+  if (it == allocatedRanges_.end()) {
+    WARN("NVLS connection tried to free a buffer that was not allocated");
+    return;
+  }
+  allocatedRanges_.erase(it);
+  it = std::find_if(freeRanges_.begin(), freeRanges_.end(), [offset, size](const std::pair<size_t, size_t>& range) {
+    return range.first + range.second >= offset;
+  });
+  if (it == freeRanges_.end()) {
+    freeRanges_.emplace_back(offset, size);
+    return;
+  }
+  if (it->first + it->second == offset) {
+    // merge with the previous free range if possible
+    it->second += size;
+    // merge with the next free range if possible
+    auto nextItr = std::next(it);
+    if (nextItr != freeRanges_.end() && it->first + it->second == nextItr->first) {
+      it->second += nextItr->second;
+      freeRanges_.erase(nextItr);
+    }
+    return;
+  } else if (it->first == offset + size) {
+    // merge with the next free range if possible
+    it->first -= size;
+    it->second += size;
+    return;
+  } else {
+    freeRanges_.emplace(it, offset, size);
+    return;
+  }
+}
+
+std::shared_ptr<char> NvlsConnection::Impl::bindMemory(std::shared_ptr<PhysicalCudaMemory<char>> physicalMem,
+                                                       size_t devBuffSize) {
+  size_t offset = allocateBuffer(devBuffSize);
+  MSCCLPP_CUTHROW(
+      cuMulticastBindMem(mcHandle_, offset /*mcOffset*/, physicalMem->memHandle_, 0 /*memOffset*/, devBuffSize, 0));
+
+  char* mcPtr;
+
+  CUmemAccessDesc accessDesc = {};
+  accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  int deviceId = -1;
+  MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
+  accessDesc.location.id = deviceId;
+  accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+  MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), devBuffSize, minMcGran_, 0U, 0));
+  MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), devBuffSize, 0, mcHandle_, 0));
+  MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)(mcPtr), devBuffSize, &accessDesc, 1));
+
+  auto deleter = [=, self = shared_from_this()](char* ptr) {
+    CUdevice device;
+    MSCCLPP_CUTHROW(cuDeviceGet(&device, deviceId));
+    MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, devBuffSize));
+    MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, devBuffSize));
+    MSCCLPP_CUTHROW(cuMulticastUnbind(mcHandle_, device, offset, devBuffSize));
+    self->freeBuffer(offset, devBuffSize);
+  };
+
+  return std::shared_ptr<char>(mcPtr, deleter);
+}
+#else   // !(USE_NVLS)
+class NvlsConnection::Impl {
+ public:
+  // use this only for the root of the NVLS
+  Impl(size_t, int) { throw notSupportedError; }
+  Impl(const std::vector<char>&) { throw notSupportedError; }
+
+  Impl(const Impl&) = delete;
+  Impl& operator=(const Impl&) = delete;
+
+  std::vector<char> serialize() { throw notSupportedError; }
+  size_t allocateBuffer(size_t) { throw notSupportedError; }
+  void freeBuffer(size_t, size_t) { throw notSupportedError; }
+  std::shared_ptr<char> bindMemory(std::shared_ptr<PhysicalCudaMemory<char>>, size_t) { throw notSupportedError; }
+  void addDevice(int) { throw notSupportedError; }
+  size_t getMinMcGran() { throw notSupportedError; }
+
+ private:
+  Error notSupportedError = Error("NVLS is not supported on this CUDA version", ErrorCode::InvalidUsage);
+};
+#endif  // !(USE_NVLS)
+
+NvlsConnection::NvlsConnection(size_t bufferSize, int numDevices)
+    : pimpl_(std::make_shared<Impl>(bufferSize, numDevices)) {}
+
+void NvlsConnection::addDevice() {
+  int cudaDeviceId;
+  MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDeviceId));
+  this->addDevice(cudaDeviceId);
+}
+
+void NvlsConnection::addDevice(int cudaDeviceId) { pimpl_->addDevice(cudaDeviceId); }
+
+NvlsConnection::NvlsConnection(const std::vector<char>& data) : pimpl_(std::make_shared<Impl>(data)) {}
+
+std::vector<char> NvlsConnection::serialize() { return pimpl_->serialize(); }
+
+std::shared_ptr<NvlsConnection::DeviceMulticastPointer> NvlsConnection::allocateAndBindCuda(size_t size) {
+  auto mem = allocSharedPhysicalCuda<char>(size, pimpl_->getMinMcGran());
+  auto mcPtr = pimpl_->bindMemory(mem, size);
+  return std::make_shared<DeviceMulticastPointer>(mem, mcPtr, size);
+}
+
+NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {
+  NvlsConnection::DeviceMulticastPointer::DeviceHandle device;
+  device.devicePtr = this->deviceMem_->devicePtr_;
+  device.mcPtr = this->mcPtr_.get();
+  device.bufferSize = this->bufferSize_;
+  return device;
+};
+
+char* NvlsConnection::DeviceMulticastPointer::getDevicePtr() { return deviceMem_->devicePtr_; };
+
+size_t NvlsConnection::getMultiCastMinGranularity() { return pimpl_->getMinMcGran(); }
+
+}  // namespace mscclpp
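A hedged sketch of the free-list behavior implemented above (mirrors the `test_nvls_connection` Python test later in this diff): with a 4 MiB NVLS buffer, two 2 MiB bindings exhaust it; releasing one returns its range so a later request of the same size succeeds, because the `shared_ptr` deleter unmaps, unbinds, and hands the range back to `freeRanges_`:

#include <memory>

#include <mscclpp/core.hpp>

void exerciseFreeList(std::shared_ptr<mscclpp::NvlsConnection> conn) {
  auto a = conn->allocateAndBindCuda(1 << 21);
  auto b = conn->allocateAndBindCuda(1 << 21);
  // A third 2 MiB request here would throw ErrorCode::InvalidUsage.
  b.reset();                                    // unbind and free the second range
  auto c = conn->allocateAndBindCuda(1 << 21);  // reuses the freed range
}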
diff --git a/src/utils.cc b/src/utils.cc
index 7153d55c5..8475f2f60 100644
--- a/src/utils.cc
+++ b/src/utils.cc
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -66,4 +67,15 @@ std::string getHostName(int maxlen, const char delim) {
   return hostname.substr(0, i);
 }

+bool isNvlsSupported() {
+#if (CUDART_VERSION >= 12010)
+  CUdevice dev;
+  int isMulticastSupported;
+  MSCCLPP_CUTHROW(cuCtxGetDevice(&dev));
+  MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isMulticastSupported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev));
+  return isMulticastSupported == 1;
+#endif
+  return false;
+}
+
 }  // namespace mscclpp
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 087fdeb86..ef85cde5a 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -23,6 +23,7 @@ endfunction()

 add_test_executable(allgather_test_cpp allgather_test_cpp.cu)
 add_test_executable(allgather_test_host_offloading allgather_test_host_offloading.cu)
+add_test_executable(nvls_test nvls_test.cu)

 configure_file(run_mpi_test.sh.in run_mpi_test.sh)
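A small illustrative guard built on the new utility (assumed names; not part of the diff). `isNvlsSupported()` only returns true when the CUDA runtime, the Linux kernel (for pidfd syscalls), and the GPU all support multicast, so callers can fall back to plain CUDA IPC:

#include <mscclpp/core.hpp>
#include <mscclpp/utils.hpp>

mscclpp::Transport pickTransport() {
  return mscclpp::isNvlsSupported() ? mscclpp::Transport::Nvls : mscclpp::Transport::CudaIpc;
}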
diff --git a/test/nvls_test.cu b/test/nvls_test.cu
new file mode 100644
index 000000000..e01b4d790
--- /dev/null
+++ b/test/nvls_test.cu
@@ -0,0 +1,203 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#define CUCHECK(cmd)                                        \
+  do {                                                      \
+    auto err = cmd;                                         \
+    if (err != 0) {                                         \
+      printf("Cuda failure %d: Line %d", err, __LINE__);    \
+      exit(-1);                                             \
+    }                                                       \
+  } while (false)
+
+// AR kernel snippet for sm_90 only
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+#define MULTIMEM_ST(val, ptr)                                                                                   \
+  asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), \
+               "r"(val.w)                                                                                       \
+               : "memory");
+// f32 reduction; fp16 would be multimem.ld_reduce.global.add.v4.f16x2, bf16 would be ...v4.bf16x2, etc.
+#define MULTIMEM_LD(val, ptr)                                     \
+  asm("multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];" \
+      : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)        \
+      : "l"(ptr)                                                  \
+      : "memory");
+#else
+#define MULTIMEM_ST(val, ptr)
+#define MULTIMEM_LD(val, ptr)
+#endif
+
+__global__ void init_kernel(float* uc_ptr, int size, int myrank, int nranks) {
+  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < size; idx += blockDim.x * gridDim.x) {
+    uc_ptr[idx] = myrank + idx;
+  }
+}
+
+__global__ void check_correctness(float* uc_ptr, int size, int myrank, int nranks) {
+  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < size; idx += blockDim.x * gridDim.x) {
+    float expected = (float)((nranks * (nranks - 1)) / 2 + nranks * idx);
+    if (abs(uc_ptr[idx] - expected) > 0.01 * expected) {
+      printf("error! idx %d: %f != %f\n", idx, uc_ptr[idx], expected);
+    }
+  }
+}
+
+__global__ void testing(float* mc_ptr, int size, int myrank, int nranks) {
+  // for allreduce we don't even need a UC pointer; just use the same mc_ptr for in-place reduction
+  // a line is assumed to be 16B: four 32-bit words, i.e. 4 floats or 8 halves
+  int my_st = ((int64_t)size * (int64_t)myrank) / (int64_t)nranks;
+  int my_en = ((int64_t)size * (int64_t)(myrank + 1)) / (int64_t)nranks;
+
+  int my_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
+  int my_step = blockDim.x * gridDim.x * 4;
+
+  for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
+    [[maybe_unused]] uint4 val;
+    MULTIMEM_LD(val, mc_ptr + idx);
+    MULTIMEM_ST(val, mc_ptr + idx);
+  }
+}
+
+int main() {
+#if (USE_NVLS)
+  int myrank, nranks;
+  MPI_Init(NULL, NULL);
+  MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
+  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
+
+  cudaSetDevice(myrank);
+
+  size_t size = 1024 * 1024 * 512;
+  CUmemAllocationHandleType handleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+
+  CUmulticastObjectProp mcProp = {};
+  mcProp.numDevices = nranks;
+  mcProp.size = size;
+  mcProp.handleTypes = handleType;
+
+  size_t minGran, gran;
+  gran = 0;
+  minGran = 0;
+  CUCHECK(cuMulticastGetGranularity(&minGran, &mcProp, CU_MULTICAST_GRANULARITY_MINIMUM));
+  CUCHECK(cuMulticastGetGranularity(&gran, &mcProp, CU_MULTICAST_GRANULARITY_RECOMMENDED));
+
+  if (!myrank) printf("nvls multicast granularity: gran = %lu, minGran = %lu\n", gran, minGran);
+  size_t mcSize = ((size + gran - 1) / gran) * gran;
+  mcProp.size = mcSize;
+
+  CUmemGenericAllocationHandle handle;
+  // only one rank creates the multicast object
+  if (!myrank) CUCHECK(cuMulticastCreate(&handle, &mcProp));
+
+  int fd = 0;
+  if (!myrank) CUCHECK(cuMemExportToShareableHandle(&fd, handle, handleType, 0 /*flags*/));
+
+  // some ugly UDS business
+  // Borrow ipcsocket.{c,h} from the NCCL code.
+  // In CUDA 12.4 a new fabric handle type is available, so instead it would be possible to use MPI_Allgather for
+  // the exported handles; moreover, that would be the only way to do it on Grace Hopper systems, since UDS is
+  // limited to a single Unix node.
+
+  pid_t currentPid = getpid();
+  MPI_Bcast(&fd, sizeof(fd), MPI_CHAR, 0, MPI_COMM_WORLD);
+  MPI_Bcast(&currentPid, sizeof(currentPid), MPI_CHAR, 0, MPI_COMM_WORLD);
+  int pidFd = syscall(SYS_pidfd_open, currentPid, 0);
+
+  // everyone else would now have the same multicast object
+  int peerFd = 0;
+  peerFd = syscall(SYS_pidfd_getfd, pidFd, fd, 0);
+  if (myrank) CUCHECK(cuMemImportFromShareableHandle(&handle, reinterpret_cast<void*>(peerFd), handleType));
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  close(fd);
+  // end of ugly UDS business
+  // everyone adds device(s); no syncs required, just need to ensure bindmem happens after all of this is called
+  int mydev = myrank;
+  CUCHECK(cuMulticastAddDevice(handle, mydev));
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  CUmemGenericAllocationHandle memhandle;
+  CUmemAllocationProp prop = {};
+  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  prop.location.id = mydev;
+  prop.requestedHandleTypes = handleType;
+
+  // allocate physical memory (data buffer)
+  CUCHECK(cuMemCreate(&memhandle, size, &prop, 0 /*flags*/));
+
+  void* uc_va;
+  void* mc_va;
+  CUmemAccessDesc accessDesc = {};
+  accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+  accessDesc.location.id = mydev;
+  accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+
+  // Map a VA to UC space
+  CUCHECK(cuMemAddressReserve((CUdeviceptr*)&uc_va, size, minGran, 0U, 0));
+  CUCHECK(cuMemMap((CUdeviceptr)uc_va, size, 0, memhandle, 0));
+  // set access on UC address
+  CUCHECK(cuMemSetAccess((CUdeviceptr)uc_va, size, &accessDesc, 1));
+  cudaMemset(uc_va, 0, size);
+
+  // everyone binds memory to the multicast
+  CUCHECK(cuMulticastBindMem(handle, 0 /*mcOffset*/, memhandle, 0 /*memOffset*/, size, 0));
+  MPI_Barrier(MPI_COMM_WORLD);
+  // usual VA business: map both MC and PA to two different VA addresses
+
+  // Map a VA to MC space
+  CUCHECK(cuMemAddressReserve((CUdeviceptr*)&mc_va, mcSize, minGran, 0U, 0));
+  CUCHECK(cuMemMap((CUdeviceptr)mc_va, mcSize, 0, handle, 0));
+  // set access on MC address
+  CUCHECK(cuMemSetAccess((CUdeviceptr)mc_va, mcSize, &accessDesc, 1));
+
+  int rept = 10;
+  int block_size = 1024;
+  int nblocks = 16;
+
+  cudaDeviceSynchronize();
+  MPI_Barrier(MPI_COMM_WORLD);
+  init_kernel<<<nblocks, block_size>>>((float*)uc_va, size / sizeof(float), myrank, nranks);
+  cudaDeviceSynchronize();
+  MPI_Barrier(MPI_COMM_WORLD);
+  testing<<<nblocks, block_size>>>((float*)mc_va, size / sizeof(float), myrank, nranks);
+  cudaDeviceSynchronize();
+  MPI_Barrier(MPI_COMM_WORLD);
+  check_correctness<<<nblocks, block_size>>>((float*)uc_va, size / sizeof(float), myrank, nranks);
+  cudaDeviceSynchronize();
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  for (size_t input_size = 1024; input_size <= size; input_size *= 2) {
+    // warmup
+    for (int i = 0; i < rept; i++) {
+      testing<<<nblocks, block_size>>>((float*)mc_va, input_size / sizeof(float), myrank, nranks);
+    }
+    cudaDeviceSynchronize();
+    MPI_Barrier(MPI_COMM_WORLD);
+    double st = MPI_Wtime();
+    for (int i = 0; i < rept; i++) {
+      testing<<<nblocks, block_size>>>((float*)mc_va, input_size / sizeof(float), myrank, nranks);
+    }
+    cudaDeviceSynchronize();
+    double en = MPI_Wtime();
+    double time = (en - st) / rept;
+    if (!myrank)
+      printf("input_size %ld | Time = %f us, alg_bw = %f (GBps)\n", input_size, time * 1e6, input_size / 1e9 / time);
+  }
+  MPI_Barrier(MPI_COMM_WORLD);
+  MPI_Finalize();
+#endif  // (USE_NVLS)
+}