diff --git a/CMakeLists.txt b/CMakeLists.txt
index 302febab7..b77c7e1b6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,6 +18,7 @@ option(ENABLE_TRACE "Enable tracing" OFF)
 option(USE_NPKIT "Use NPKIT" ON)
 option(BUILD_TESTS "Build tests" ON)
 option(BUILD_PYTHON_BINDINGS "Build Python bindings" ON)
+option(BUILD_APPS_NCCL "Build NCCL interfaces" OFF)
 option(USE_CUDA "Use NVIDIA/CUDA." OFF)
 option(USE_ROCM "Use AMD/ROCm." OFF)
 option(BYPASS_GPU_CHECK "Bypass GPU check." OFF)
@@ -101,7 +102,7 @@ find_package(Threads REQUIRED)
 
 add_library(mscclpp_obj OBJECT)
 target_include_directories(mscclpp_obj
-  PRIVATE
+  SYSTEM PRIVATE
   ${GPU_INCLUDE_DIRS}
   ${IBVERBS_INCLUDE_DIRS}
   ${NUMA_INCLUDE_DIRS})
@@ -151,3 +152,8 @@ endif()
 if(BUILD_PYTHON_BINDINGS)
   add_subdirectory(python)
 endif()
+
+# NCCL interfaces
+if(BUILD_APPS_NCCL)
+  add_subdirectory(apps/nccl)
+endif()
diff --git a/apps/nccl/CMakeLists.txt b/apps/nccl/CMakeLists.txt
new file mode 100644
index 000000000..33f385da0
--- /dev/null
+++ b/apps/nccl/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS src/*)
+file(GLOB_RECURSE HEADERS CONFIGURE_DEPENDS include/nccl.h)
+
+if(USE_ROCM)
+  set_source_files_properties(${SOURCES} PROPERTIES LANGUAGE CXX)
+endif()
+
+add_library(mscclpp_nccl_obj OBJECT)
+target_sources(mscclpp_nccl_obj PRIVATE ${SOURCES})
+target_sources(mscclpp_nccl_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
+target_include_directories(mscclpp_nccl_obj PRIVATE include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
+target_link_libraries(mscclpp_nccl_obj PRIVATE ${GPU_LIBRARIES} PUBLIC mscclpp_obj)
+set_target_properties(mscclpp_nccl_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
+if(USE_CUDA)
+  target_compile_definitions(mscclpp_nccl_obj PRIVATE USE_CUDA)
+elseif(USE_ROCM)
+  target_compile_definitions(mscclpp_nccl_obj PRIVATE USE_ROCM)
+endif()
+
+add_library(mscclpp_nccl SHARED)
+target_link_libraries(mscclpp_nccl PUBLIC mscclpp_obj mscclpp_nccl_obj)
+set_target_properties(mscclpp_nccl PROPERTIES VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
+add_library(mscclpp_nccl_static STATIC)
+target_link_libraries(mscclpp_nccl_static PUBLIC mscclpp_obj mscclpp_nccl_obj)
+set_target_properties(mscclpp_nccl_static PROPERTIES VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
+
+install(TARGETS mscclpp_nccl_obj
+    FILE_SET HEADERS DESTINATION ${INSTALL_PREFIX}/include)
+install(TARGETS mscclpp_nccl
+    LIBRARY DESTINATION ${INSTALL_PREFIX}/lib)
+install(TARGETS mscclpp_nccl_static
+    ARCHIVE DESTINATION ${INSTALL_PREFIX}/lib)
+
+if(BUILD_TESTS)
+  add_subdirectory(test)
+endif()
diff --git a/apps/nccl/README.md b/apps/nccl/README.md
new file mode 100644
index 000000000..62a24d9e2
--- /dev/null
+++ b/apps/nccl/README.md
@@ -0,0 +1,14 @@
+# NCCL Interfaces of MSCCL++
+
+Compile
+
+```bash
+CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_APPS_NCCL=ON ..
+make -j +``` + +Run rccl-tests + +```bash +mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=$MSCCLPP_BUILD/apps/nccl/libmscclpp_nccl.so -x HIP_FORCE_DEV_KERNARG=1 -x HSA_ENABLE_IPC_MODE_LEGACY=1 -x MSCCLPP_DEBUG=WARN -x MSCCLPP_DEBUG_SUBSYS=ALL -x NCCL_DEBUG=WARN ./build/all_reduce_perf -b 1K -e 256M -f 2 -d half -G 20 -w 10 -n 50 +``` diff --git a/apps/nccl/include/nccl.h b/apps/nccl/include/nccl.h new file mode 100644 index 000000000..f94173ef8 --- /dev/null +++ b/apps/nccl/include/nccl.h @@ -0,0 +1,468 @@ +/************************************************************************* + * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_H_ +#define NCCL_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#include +/* Opaque handle to communicator */ +typedef struct ncclComm* ncclComm_t; +#define NCCL_COMM_NULL NULL + +#define NCCL_UNIQUE_ID_BYTES 128 +typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; + +/* Error type */ +typedef enum { ncclSuccess = 0, + ncclUnhandledCudaError = 1, + ncclSystemError = 2, + ncclInternalError = 3, + ncclInvalidArgument = 4, + ncclInvalidUsage = 5, + ncclRemoteError = 6, + ncclInProgress = 7, + ncclNumResults = 8 } ncclResult_t; + +#define NCCL_CONFIG_UNDEF_INT INT_MIN +#define NCCL_CONFIG_UNDEF_PTR NULL +#define NCCL_SPLIT_NOCOLOR -1 + +/* Communicator configuration. Users can assign value to attributes to specify the + * behavior of a communicator. */ +typedef struct ncclConfig_v21700 { + /* attributes that users should never touch. */ + size_t size; + unsigned int magic; + unsigned int version; + /* attributes that users are able to customize. */ + int blocking; + int cgaClusterSize; + int minCTAs; + int maxCTAs; + const char *netName; + int splitShare; +} ncclConfig_t; + +/* Config initializer must be assigned to initialize config structure when it is created. + * Not initialized config will result in NCCL error. */ +#define NCCL_CONFIG_INITIALIZER { \ + sizeof(ncclConfig_t), /* size */ \ + 0xcafebeef, /* magic */ \ + NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH), /* version */ \ + NCCL_CONFIG_UNDEF_INT, /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* cgaClusterSize */ \ + NCCL_CONFIG_UNDEF_INT, /* minCTAs */ \ + NCCL_CONFIG_UNDEF_INT, /* maxCTAs */ \ + NCCL_CONFIG_UNDEF_PTR, /* netName */ \ + NCCL_CONFIG_UNDEF_INT /* splitShare */ \ +} + +/* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. + * This integer is coded with the MAJOR, MINOR and PATCH level of the + * NCCL library + */ +ncclResult_t ncclGetVersion(int *version); +ncclResult_t pncclGetVersion(int *version); + +/* Generates an Id to be used in ncclCommInitRank. ncclGetUniqueId should be + * called once and the Id should be distributed to all ranks in the + * communicator before calling ncclCommInitRank. */ +ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); +ncclResult_t pncclGetUniqueId(ncclUniqueId* uniqueId); + +/* Create a new communicator (multi thread/process version) with a configuration + * set by users. 
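A minimal sketch of how the config structure above is meant to be used: it must start from NCCL_CONFIG_INITIALIZER, and only the fields of interest are overridden. Note that in this port ncclCommInitRankConfig is still a stub that returns ncclInternalError (see nccl.cu below), so the working path today is ncclCommInitRank; the field choice and the fallback helper here are illustrative, not part of the patch.

```cpp
#include "nccl.h"

// Illustrative config-based initialization; uniqueId, rank and nranks come from the
// usual out-of-band bootstrap (see the sketch further below).
ncclResult_t initWithConfig(ncclComm_t* comm, int nranks, ncclUniqueId uniqueId, int rank) {
  ncclConfig_t config = NCCL_CONFIG_INITIALIZER;  // required; an uninitialized config is an error
  config.blocking = 1;                            // override only what you need
  ncclResult_t res = ncclCommInitRankConfig(comm, nranks, uniqueId, rank, &config);
  if (res != ncclSuccess) {
    // This port does not implement the config path yet; fall back to plain init.
    res = ncclCommInitRank(comm, nranks, uniqueId, rank);
  }
  return res;
}
```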
*/ +ncclResult_t ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank, ncclConfig_t* config); +ncclResult_t pncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank, ncclConfig_t* config); + +/* Creates a new communicator (multi thread/process version). + * rank must be between 0 and nranks-1 and unique within a communicator clique. + * Each rank is associated to a CUDA device, which has to be set before calling + * ncclCommInitRank. + * ncclCommInitRank implicitly syncronizes with other ranks, so it must be + * called by different threads/processes or use ncclGroupStart/ncclGroupEnd. */ +ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); +ncclResult_t pncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + +/* Creates a clique of communicators (single process version). + * This is a convenience function to create a single-process communicator clique. + * Returns an array of ndev newly initialized communicators in comm. + * comm should be pre-allocated with size at least ndev*sizeof(ncclComm_t). + * If devlist is NULL, the first ndev CUDA devices are used. + * Order of devlist defines user-order of processors within the communicator. */ +ncclResult_t ncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist); +ncclResult_t pncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist); + +/* Finalize a communicator. ncclCommFinalize flushes all issued communications, + * and marks communicator state as ncclInProgress. The state will change to ncclSuccess + * when the communicator is globally quiescent and related resources are freed; then, + * calling ncclCommDestroy can locally free the rest of the resources (e.g. communicator + * itself) without blocking. */ +ncclResult_t ncclCommFinalize(ncclComm_t comm); +ncclResult_t pncclCommFinalize(ncclComm_t comm); + +/* Frees local resources associated with communicator object. */ +ncclResult_t ncclCommDestroy(ncclComm_t comm); +ncclResult_t pncclCommDestroy(ncclComm_t comm); + +/* Frees resources associated with communicator object and aborts any operations + * that might still be running on the device. */ +ncclResult_t ncclCommAbort(ncclComm_t comm); +ncclResult_t pncclCommAbort(ncclComm_t comm); + +/* Creates one or more communicators from an existing one. + * Ranks with the same color will end up in the same communicator. + * Within the new communicator, key will be used to order ranks. + * NCCL_SPLIT_NOCOLOR as color will indicate the rank will not be part of any group + * and will therefore return a NULL communicator. + * If config is NULL, the new communicator will inherit the original communicator's + * configuration*/ +ncclResult_t ncclCommSplit(ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t* config); +ncclResult_t pncclCommSplit(ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t* config); + +/* Returns a string for each error code. */ +const char* ncclGetErrorString(ncclResult_t result); +const char* pncclGetErrorString(ncclResult_t result); + +/* Returns a human-readable message of the last error that occurred. 
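A sketch of the bootstrap flow the declarations above describe: rank 0 generates the unique ID, distributes it out-of-band, and every rank then initializes its communicator on its own GPU. MPI is used here only as an example transport for the ID exchange and is not required by this interface; the device-selection policy is an assumption.

```cpp
#include <mpi.h>
#include <cuda_runtime.h>
#include "nccl.h"

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank = 0, nranks = 1;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);

  // One process per GPU; the device must be selected before ncclCommInitRank.
  cudaSetDevice(rank % 8);  // assumption: 8 GPUs per node, matching NRANKS_PER_NODE below

  ncclUniqueId id;
  if (rank == 0) ncclGetUniqueId(&id);
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);

  ncclComm_t comm;
  ncclCommInitRank(&comm, nranks, id, rank);

  // ... collectives are issued here ...

  ncclCommFinalize(comm);  // in this port: a bootstrap barrier
  ncclCommDestroy(comm);   // frees the local communicator object
  MPI_Finalize();
  return 0;
}
```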
+ * comm is currently unused and can be set to NULL + */ +const char* ncclGetLastError(ncclComm_t comm); +const char* pncclGetLastError(ncclComm_t comm); + +/* Checks whether the comm has encountered any asynchronous errors */ +ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); +ncclResult_t pncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); + +/* Gets the number of ranks in the communicator clique. */ +ncclResult_t ncclCommCount(const ncclComm_t comm, int* count); +ncclResult_t pncclCommCount(const ncclComm_t comm, int* count); + +/* Returns the cuda device number associated with the communicator. */ +ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* device); +ncclResult_t pncclCommCuDevice(const ncclComm_t comm, int* device); + +/* Returns the user-ordered "rank" associated with the communicator. */ +ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank); +ncclResult_t pncclCommUserRank(const ncclComm_t comm, int* rank); + +/* Reduction operation selector */ +typedef enum { ncclNumOps_dummy = 5 } ncclRedOp_dummy_t; +typedef enum { ncclSum = 0, + ncclProd = 1, + ncclMax = 2, + ncclMin = 3, + ncclAvg = 4, + /* ncclNumOps: The number of built-in ncclRedOp_t values. Also + * serves as the least possible value for dynamic ncclRedOp_t's + * as constructed by ncclRedOpCreate*** functions. */ + ncclNumOps = 5, + /* ncclMaxRedOp: The largest valid value for ncclRedOp_t. + * It is defined to be the largest signed value (since compilers + * are permitted to use signed enums) that won't grow + * sizeof(ncclRedOp_t) when compared to previous NCCL versions to + * maintain ABI compatibility. */ + ncclMaxRedOp = 0x7fffffff>>(32-8*sizeof(ncclRedOp_dummy_t)) + } ncclRedOp_t; + +/* Data types */ +typedef enum { ncclInt8 = 0, ncclChar = 0, + ncclUint8 = 1, + ncclInt32 = 2, ncclInt = 2, + ncclUint32 = 3, + ncclInt64 = 4, + ncclUint64 = 5, + ncclFloat16 = 6, ncclHalf = 6, + ncclFloat32 = 7, ncclFloat = 7, + ncclFloat64 = 8, ncclDouble = 8, +#if defined(__CUDA_BF16_TYPES_EXIST__) && defined(__CUDA_FP8_TYPES_EXIST__) + ncclBfloat16 = 9, + ncclFp8E4M3 = 10, + ncclFp8E5M2 = 11, + ncclNumTypes = 12 +#elif defined(__CUDA_BF16_TYPES_EXIST__) + ncclBfloat16 = 9, + ncclNumTypes = 10 +#else + ncclNumTypes = 9 +#endif +} ncclDataType_t; + +/* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */ +typedef enum { + /* ncclScalarDevice: The scalar is in device-visible memory and will be + * dereferenced while the collective is running. */ + ncclScalarDevice = 0, + + /* ncclScalarHostImmediate: The scalar is in host-visible memory and will be + * dereferenced before the ncclRedOpCreate***() function returns. */ + ncclScalarHostImmediate = 1 +} ncclScalarResidence_t; + +/* + * ncclRedOpCreatePreMulSum + * + * Creates a new reduction operator which pre-multiplies input values by a given + * scalar locally before reducing them with peer values via summation. For use + * only with collectives launched against *comm* and *datatype*. The + * *residence* argument indicates how/when the memory pointed to by *scalar* + * will be dereferenced. Upon return, the newly created operator's handle + * is stored in *op*. 
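Since every query above reports errors through ncclResult_t, callers usually wrap calls in a checking macro built on ncclGetErrorString. A small sketch follows; the NCCLCHECK name and the helper function are illustrative, not part of this header. Note that in this port ncclCommCuDevice currently reports the bootstrap rank rather than the CUDA device ordinal (see nccl.cu below).

```cpp
#include <cstdio>
#include <cstdlib>
#include "nccl.h"

// Illustrative helper: abort with a readable message on any non-success result.
#define NCCLCHECK(cmd)                                                  \
  do {                                                                  \
    ncclResult_t r = (cmd);                                             \
    if (r != ncclSuccess) {                                             \
      fprintf(stderr, "NCCL failure %s:%d '%s'\n", __FILE__, __LINE__,  \
              ncclGetErrorString(r));                                   \
      exit(EXIT_FAILURE);                                               \
    }                                                                   \
  } while (0)

void printCommInfo(ncclComm_t comm) {
  int nranks = 0, rank = 0, dev = 0;
  NCCLCHECK(ncclCommCount(comm, &nranks));
  NCCLCHECK(ncclCommUserRank(comm, &rank));
  NCCLCHECK(ncclCommCuDevice(comm, &dev));  // in this port: returns the rank, not the device
  printf("rank %d/%d, reported device %d\n", rank, nranks, dev);
}
```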
+ */ +ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); +ncclResult_t pncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); + +/* + * ncclRedOpDestroy + * + * Destroys the reduction operator *op*. The operator must have been created by + * ncclRedOpCreatePreMul with the matching communicator *comm*. An operator may be + * destroyed as soon as the last NCCL function which is given that operator returns. + */ +ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm); +ncclResult_t pncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm); + +/* + * Collective communication operations + * + * Collective communication operations must be called separately for each + * communicator in a communicator clique. + * + * They return when operations have been enqueued on the CUDA stream. + * + * Since they may perform inter-CPU synchronization, each call has to be done + * from a different thread or process, or need to use Group Semantics (see + * below). + */ + +/* + * Reduce + * + * Reduces data arrays of length count in sendbuff into recvbuff using op + * operation. + * recvbuff may be NULL on all calls except for root device. + * root is the rank (not the CUDA device) where data will reside after the + * operation is complete. + * + * In-place operation will happen if sendbuff == recvbuff. + */ +ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); + +/* + * (deprecated) Broadcast (in-place) + * + * Copies count values from root to all other devices. + * root is the rank (not the CUDA device) where data resides before the + * operation is started. + * + * This operation is implicitly in place. + */ +ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, + ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, + ncclComm_t comm, cudaStream_t stream); + +/* + * Broadcast + * + * Copies count values from root to all other devices. + * root is the rank (not the CUDA device) where data resides before the + * operation is started. + * + * In-place operation will happen if sendbuff == recvbuff. + */ +ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root, + ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root, + ncclComm_t comm, cudaStream_t stream); + +/* + * All-Reduce + * + * Reduces data arrays of length count in sendbuff using op operation, and + * leaves identical copies of result on each recvbuff. + * + * In-place operation will happen if sendbuff == recvbuff. 
+ */ +ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); + +/* + * Reduce-Scatter + * + * Reduces data in sendbuff using op operation and leaves reduced result + * scattered over the devices so that recvbuff on rank i will contain the i-th + * block of the result. + * Assumes sendcount is equal to nranks*recvcount, which means that sendbuff + * should have a size of at least nranks*recvcount elements. + * + * In-place operations will happen if recvbuff == sendbuff + rank * recvcount. + */ +ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, + size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + cudaStream_t stream); +ncclResult_t pncclReduceScatter(const void* sendbuff, void* recvbuff, + size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + cudaStream_t stream); + +/* + * All-Gather + * + * Each device gathers sendcount values from other GPUs into recvbuff, + * receiving data from rank i at offset i*sendcount. + * Assumes recvcount is equal to nranks*sendcount, which means that recvbuff + * should have a size of at least nranks*sendcount elements. + * + * In-place operations will happen if sendbuff == recvbuff + rank * sendcount. + */ +ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); + +/* + * Send + * + * Send data from sendbuff to rank peer. + * + * Rank peer needs to call ncclRecv with the same datatype and the same count from this + * rank. + * + * This operation is blocking for the GPU. If multiple ncclSend and ncclRecv operations + * need to progress concurrently to complete, they must be fused within a ncclGroupStart/ + * ncclGroupEnd section. + */ +ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, + ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, + ncclComm_t comm, cudaStream_t stream); + +/* + * Receive + * + * Receive data from rank peer into recvbuff. + * + * Rank peer needs to call ncclSend with the same datatype and the same count to this + * rank. + * + * This operation is blocking for the GPU. If multiple ncclSend and ncclRecv operations + * need to progress concurrently to complete, they must be fused within a ncclGroupStart/ + * ncclGroupEnd section. + */ +ncclResult_t pncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, + ncclComm_t comm, cudaStream_t stream); +ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, + ncclComm_t comm, cudaStream_t stream); + +/* All-To-All + * + * Device (i) send (j)th block of data to device (j) and be placed as (i)th + * block. Each block for sending/receiving has count elements, which means + * that recvbuff and sendbuff should have a size of nranks*count elements. + * + * In-place operation will happen if sendbuff == recvbuff. 
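A usage sketch for the collectives declared above, restricted to the ones this port currently backs with MSCCL++ kernels (ncclAllReduce and ncclAllGather; ncclReduce, ncclBroadcast and ncclReduceScatter are still stubs in nccl.cu). Buffer sizes and the fp16 datatype are illustrative, and error handling is omitted for brevity.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "nccl.h"

// comm and stream are assumed to have been created as in the bootstrap sketch above.
void runCollectives(ncclComm_t comm, cudaStream_t stream, int nranks, size_t count) {
  half *sendbuf = nullptr, *recvbuf = nullptr, *gatherbuf = nullptr;
  cudaMalloc(&sendbuf, count * sizeof(half));
  cudaMalloc(&recvbuf, count * sizeof(half));
  cudaMalloc(&gatherbuf, count * nranks * sizeof(half));

  // Out-of-place allreduce; passing the same pointer twice would request the in-place form.
  ncclAllReduce(sendbuf, recvbuf, count, ncclFloat16, ncclSum, comm, stream);

  // Allgather: rank i's `count` elements land at offset i*count of gatherbuf on every rank.
  ncclAllGather(sendbuf, gatherbuf, count, ncclFloat16, comm, stream);

  cudaStreamSynchronize(stream);  // collectives are stream-ordered; synchronize before reading
  cudaFree(sendbuf);
  cudaFree(recvbuf);
  cudaFree(gatherbuf);
}
```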
+ */ +ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); +/*! @brief Opaque handle to MSCCL algorithm */ +typedef int mscclAlgoHandle_t; + +/*! @brief MSCCL Load Algorithm + * + * @details Load MSCCL algorithm file specified in mscclAlgoFilePath and return + * its handle via mscclAlgoHandle. This API is expected to be called by MSCCL + * scheduler instead of end users. + */ +ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank); +ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank); + +/*! @brief MSCCL Run Algorithm + * + * @details Run MSCCL algorithm specified by mscclAlgoHandle. The parameter + * list merges all possible parameters required by different operations as this + * is a general-purposed API. This API is expected to be called by MSCCL + * scheduler instead of end users. + */ +ncclResult_t mscclRunAlgo( + const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[], + void* recvBuff, const size_t recvCounts[], const size_t rDisPls[], + size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op, + mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pmscclRunAlgo( + const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[], + void* recvBuff, const size_t recvCounts[], const size_t rDisPls[], + size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op, + mscclAlgoHandle_t mscclAlgoHandle, ncclComm_t comm, cudaStream_t stream); + +/*! @brief MSCCL Load Algorithm + * + * @details Unload MSCCL algorithm previous loaded using its handle. This API + * is expected to be called by MSCCL scheduler instead of end users. + */ +ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle); +ncclResult_t pmscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle); + +/* + * Group semantics + * + * When managing multiple GPUs from a single thread, and since NCCL collective + * calls may perform inter-CPU synchronization, we need to "group" calls for + * different ranks/devices into a single call. + * + * Grouping NCCL calls as being part of the same collective operation is done + * using ncclGroupStart and ncclGroupEnd. ncclGroupStart will enqueue all + * collective calls until the ncclGroupEnd call, which will wait for all calls + * to be complete. Note that for collective communication, ncclGroupEnd only + * guarantees that the operations are enqueued on the streams, not that + * the operation is effectively done. + * + * Both collective communication and ncclCommInitRank can be used in conjunction + * of ncclGroupStart/ncclGroupEnd, but not together. + * + * Group semantics also allow to fuse multiple operations on the same device + * to improve performance (for aggregated collective calls), or to permit + * concurrent progress of multiple send/receive operations. + */ + +/* + * Group Start + * + * Start a group call. All calls to NCCL until ncclGroupEnd will be fused into + * a single NCCL operation. Nothing will be started on the CUDA stream until + * ncclGroupEnd. + */ +ncclResult_t ncclGroupStart(); +ncclResult_t pncclGroupStart(); + +/* + * Group End + * + * End a group call. Start a fused NCCL operation consisting of all calls since + * ncclGroupStart. 
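The group semantics described above are the standard NCCL contract. In this port ncclGroupStart/ncclGroupEnd are currently no-ops and ncclSend/ncclRecv return ncclInternalError (see nccl.cu), so the following ring-exchange sketch shows the upstream usage pattern the interface is declared for rather than something that runs today; the peer selection is illustrative.

```cpp
#include <cuda_fp16.h>
#include "nccl.h"

// Upstream-NCCL-style fused send/recv ring exchange (not yet functional in this port).
void ringExchange(ncclComm_t comm, cudaStream_t stream, half* sendbuf, half* recvbuf,
                  size_t count, int rank, int nranks) {
  const int sendPeer = (rank + 1) % nranks;
  const int recvPeer = (rank - 1 + nranks) % nranks;
  ncclGroupStart();  // fuse the two point-to-point ops so they can progress concurrently
  ncclSend(sendbuf, count, ncclFloat16, sendPeer, comm, stream);
  ncclRecv(recvbuf, count, ncclFloat16, recvPeer, comm, stream);
  ncclGroupEnd();
}
```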
Operations on the CUDA stream depending on the NCCL operations + * need to be called after ncclGroupEnd. + */ +ncclResult_t ncclGroupEnd(); +ncclResult_t pncclGroupEnd(); + +#ifdef __cplusplus +} // end extern "C" +#endif + +#endif // end include guard diff --git a/apps/nccl/rccl_test.py b/apps/nccl/rccl_test.py new file mode 100644 index 000000000..bca7af815 --- /dev/null +++ b/apps/nccl/rccl_test.py @@ -0,0 +1,67 @@ +import os +from mpi4py import MPI +import torch +from cupy.cuda import nccl + +ROOT_RANK = 0 +comm = MPI.COMM_WORLD +rank = comm.Get_rank() + +is_group_root = rank == ROOT_RANK + +world_size = comm.Get_size() + +os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) + +device_type = "cuda" +torch.cuda.set_device(0) +device_index = 0 +device = torch.device(type=device_type, index=device_index) + +if is_group_root: + id_ = nccl.get_unique_id() +else: + id_ = None + +ranks = range(world_size) +id_, ranks = comm.bcast((id_, ranks), root=0) +group = nccl.NcclCommunicator(len(ranks), id_, rank) +print(f"{rank=}, {device=}, {group=}") + +M = 1024 +N = 4096 +K = 2048 +shape_a = (M, K) +shape_b = (K, N) +shape_c = (M, N) + +a = torch.ones(shape_a, device="cuda") +b = torch.ones(shape_b, device="cuda") +c = torch.mm(a, b) + +print(c) + +# nccl_op = nccl.NCCL_SUM +# group.allReduce( +# sendbuf=c.data_ptr(), +# recvbuf=c.data_ptr(), +# count=c.nelement(), +# datatype=nccl.NCCL_FLOAT, +# op=nccl_op, +# stream=torch.cuda.current_stream().cuda_stream) + +# print(c) + +d = torch.ones((1024 * 1024,), device="cuda") +e = torch.zeros((8 * 1024 * 1024,), device="cuda") +e[rank * 1024 * 1024 : (rank + 1) * 1024 * 1024] = d + +group.allGather( + sendbuf=d.data_ptr(), + recvbuf=e.data_ptr(), + count=d.nelement(), + datatype=nccl.NCCL_FLOAT, + stream=torch.cuda.current_stream().cuda_stream, +) + +print(e) diff --git a/apps/nccl/src/allgather.hpp b/apps/nccl/src/allgather.hpp new file mode 100644 index 000000000..9d8e1fbc7 --- /dev/null +++ b/apps/nccl/src/allgather.hpp @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#ifndef ALLGATHER_HPP_ +#define ALLGATHER_HPP_ + +#include +#include +#include +#include +#include + +#include "common.hpp" + +template +__global__ void __launch_bounds__(1024, 1) + allgather6(void* sendbuff, mscclpp::DeviceHandle* smChannels, size_t rank, + [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) { + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const size_t lid = tid % WARP_SIZE; + const size_t wid = tid / WARP_SIZE; + + const size_t nThread = blockDim.x * gridDim.x; + const size_t nWarp = nThread / WARP_SIZE; + const size_t nPeer = nRanksPerNode - 1; + const size_t chanOffset = nPeer * blockIdx.x; + auto smChans = smChannels + chanOffset; + + if (threadIdx.x < nPeer) { + smChans[threadIdx.x].relaxedSignal(); + smChans[threadIdx.x].wait(); + } + __syncthreads(); + + const size_t bytesPerGPU = nelemsPerGPU * sizeof(int); + const size_t bytes = bytesPerGPU * nPeer; + size_t unitBytesPerThread; + if (bytes >= nThread * 64) { + unitBytesPerThread = 64; + } else { + unitBytesPerThread = 16; + } + const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE; + const size_t unitBytes = unitBytesPerWarp * nWarp; + const size_t nLoop = bytes / unitBytes; + + if (nLoop > 0) { + // First loop unrolling + const size_t peerIdx = wid % nPeer; + const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp; + if constexpr (IsOutOfPlace) { + char* dst = reinterpret_cast(smChans[peerIdx].dst_); + char* src = reinterpret_cast(smChans[peerIdx].src_); + char* buff = reinterpret_cast(sendbuff); + const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].copy<16, false>(src + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE); + smChans[peerIdx].copy<16, false>(dst + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE); + } else { + smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE); + } + } + + for (size_t i = 1; i < nLoop; ++i) { + const size_t gWid = wid + i * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp; + if constexpr (IsOutOfPlace) { + char* dst = reinterpret_cast(smChans[peerIdx].dst_); + char* src = reinterpret_cast(smChans[peerIdx].src_); + char* buff = reinterpret_cast(sendbuff); + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + smChans[peerIdx].copy<16, false>(src + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE); + smChans[peerIdx].copy<16, false>(dst + offset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE); + } else { + smChans[peerIdx].put<16, false>(offset, unitBytesPerWarp, lid, WARP_SIZE); + } + } + + if (bytes % unitBytes > 0) { + const size_t gWid = wid + nLoop * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + const size_t offset = bytesPerGPU * rank + offsetWithinRank; + const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) + ? ((bytesPerGPU > offsetWithinRank) ? 
(bytesPerGPU - offsetWithinRank) : 0) + : unitBytesPerWarp; + if (remainBytes > 0) { + if constexpr (IsOutOfPlace) { + char* dst = reinterpret_cast(smChans[peerIdx].dst_); + char* src = reinterpret_cast(smChans[peerIdx].src_); + char* buff = reinterpret_cast(sendbuff); + smChans[peerIdx].copy<16, true>(src + offset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE); + smChans[peerIdx].copy<16, true>(dst + offset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE); + } else { + smChans[peerIdx].put<16, true>(offset, remainBytes, lid, WARP_SIZE); + } + } + } +} + +template +cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T* resultBuff, + mscclpp::DeviceHandle* smChannels, int rank, int nRanksPerNode, int worldSize, + size_t nelems, cudaStream_t stream) { + allgather6<<<28, 1024, 0, stream>>>((void*)buff, smChannels, rank, worldSize, nRanksPerNode, + nelems * sizeof(T) / sizeof(int)); + return cudaGetLastError(); +} + +#endif // ALLGATHER_HPP_ diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp new file mode 100644 index 000000000..437a79557 --- /dev/null +++ b/apps/nccl/src/allreduce.hpp @@ -0,0 +1,487 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ALLREDUCE_HPP_ +#define ALLREDUCE_HPP_ + +#include +#include +#include +#include +#include +#include + +#include "common.hpp" + +extern __device__ mscclpp::DeviceSyncer deviceSyncer; + +template +__forceinline__ __device__ To bit_cast(const From& src) { + static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast"); + + union { + From f; + To t; + } u; + u.f = src; + return u.t; +} + +template +__forceinline__ __device__ T add_elements(T a, T b) { + return a + b; +} + +template <> +__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) { + return __hadd2(a, b); +} + +template +__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) { + int4 ret; + ret.w = bit_cast(add_elements(bit_cast(a.w), bit_cast(b.w))); + ret.x = bit_cast(add_elements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(add_elements(bit_cast(a.y), bit_cast(b.y))); + ret.z = bit_cast(add_elements(bit_cast(a.z), bit_cast(b.z))); + return ret; +} + +template +__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) { + return add_vectors_helper(a, b); +} + +template <> +__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) { + return add_vectors_helper<__half2>(a, b); +} + +template +__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) { + uint2 ret; + ret.x = bit_cast(add_elements(bit_cast(a.x), bit_cast(b.x))); + ret.y = bit_cast(add_elements(bit_cast(a.y), bit_cast(b.y))); + return ret; +} + +template +__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) { + return add_vectors_helper(a, b); +} + +template <> +__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) { + return add_vectors_helper<__half2>(a, b); +} + +template +__forceinline__ __device__ int add_vectors_helper(int a, int b) { + return bit_cast(add_elements(bit_cast(a), bit_cast(b))); +} + +template +__forceinline__ __device__ int add_vectors(int a, int b) { + return add_vectors_helper(a, b); +} + +template <> +__forceinline__ __device__ int add_vectors<__half>(int a, int b) { + return add_vectors_helper<__half2>(a, b); +} + +template +__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) { + return bit_cast(add_elements(bit_cast(a), bit_cast(b))); +} + +template +__forceinline__ __device__ uint32_t 
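To make the loop structure of allgather6 above easier to follow: this rank's contribution is pushed to each peer in warp-sized units, the first full pass over the grid's warps is peeled out of the loop, and a final bounds-checked pass handles whatever is left. Below is a host-side sketch of the same bookkeeping, illustrative only; the <<<28, 1024>>> launch shape and the 64-byte/16-byte unit choice are copied from the code above, and WARP_SIZE is a parameter because it differs between CUDA and ROCm.

```cpp
#include <cstddef>
#include <cstdio>

// Mirror of the index arithmetic in allgather6 for one <<<28, 1024>>> launch.
void allgatherPlan(size_t nelemsPerGPU, size_t nRanksPerNode, size_t warpSize = 64) {
  const size_t nThread = 28 * 1024;  // gridDim.x * blockDim.x
  const size_t nWarp = nThread / warpSize;
  const size_t nPeer = nRanksPerNode - 1;
  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t bytes = bytesPerGPU * nPeer;

  // Large transfers use 64 bytes per thread per iteration, small ones 16 bytes.
  const size_t unitBytesPerThread = (bytes >= nThread * 64) ? 64 : 16;
  const size_t unitBytesPerWarp = unitBytesPerThread * warpSize;
  const size_t unitBytes = unitBytesPerWarp * nWarp;

  const size_t nLoop = bytes / unitBytes;      // full iterations; the first one is peeled
  const size_t tailBytes = bytes % unitBytes;  // handled warp-by-warp with bounds checks
  printf("warps=%zu loops=%zu tail=%zu bytes\n", nWarp, nLoop, tailBytes);
}
```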
add_vectors(uint32_t a, uint32_t b) { + return add_vectors_helper(a, b); +} + +template <> +__forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) { + return add_vectors_helper<__half2>(a, b); +} + +template +__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) { + size_t nInt4 = nElem / 4; + size_t nLastInts = nElem % 4; + int4* dst4 = (int4*)dst; + int4* src4 = (int4*)src; + for (size_t i = threadIdx.x + blockId * blockDim.x; i < nInt4; i += blockDim.x * nBlocks) { + dst4[i] = add_vectors(dst4[i], src4[i]); + } + if (nLastInts > 0) { + int* dstLast = ((int*)dst) + nInt4 * 4; + int* srcLast = ((int*)src) + nInt4 * 4; + for (size_t i = threadIdx.x + blockId * blockDim.x; i < nLastInts; i += blockDim.x * nBlocks) { + dstLast[i] = add_vectors(dstLast[i], srcLast[i]); + } + } +} + +template +__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem) { + vectorSum(dst, src, nElem, blockIdx.x, gridDim.x); +} + +template +__global__ void __launch_bounds__(1024, 1) + allreduce6(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, int rank, + int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) { + // This version of allreduce only works for single nodes + if (worldSize != nRanksPerNode) return; + nelems = nelems / (sizeof(int) / sizeof(T)); + const int nPeers = nRanksPerNode - 1; + const int nPkts = nelems / 2; + const int nelemsPerRank = nelems / worldSize; + const int nPktsPerRank = nelemsPerRank / 2; + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int peerIdx = blockIdx.x / nBlocksPerPeer; + const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; + mscclpp::SmChannelDeviceHandle smChan = smChannels[peerIdx]; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; + // double buffering + size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket); + void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset); + size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket); + size_t scratchResultOffset = + (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket); + size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int); + uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int)); + uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); + + // step 1: write to scratch buffer + smChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag); + // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { + uint2 data = make_uint2(0, 0); + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? 
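The add_vectors family above works on 32-bit lanes: for most types a lane is a single element, while for fp16 each lane is reinterpreted as a packed __half2 so two values are added per instruction. A standalone device-side sketch of that idiom for one lane (illustrative; it mirrors the bit_cast/add_elements helpers rather than adding anything to the library):

```cpp
#include <cuda_fp16.h>
#include <cstdint>

// One 32-bit lane holds two fp16 values: reinterpret, add pairwise with __hadd2, pack back.
__device__ uint32_t addHalf2Lane(uint32_t a, uint32_t b) {
  union Lane {
    uint32_t u;
    __half2 h;
  };
  Lane la{a}, lb{b}, lc;
  lc.h = __hadd2(la.h, lb.h);
  return lc.u;
}
```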
index : index + 1; + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank; + uint2 val = dstPkt[idx].read(flag, -1); + data = add_vectors(val, data); + } + data = add_vectors(data, src[idx]); + dst[idx].x = data.x; + dst[idx].y = data.y; + for (int index = 0; index < nPeers; index++) { + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)smChannels[index].dst_ + scratchResultOffset); + dstPkt[idx + rank * nPktsPerRank].write(data.x, data.y, flag); + } + } + // step 3: get data result from scratch buffer + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRank * nPktsPerRank; + uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); + for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) { + uint2 data = dstPkt[idx + dstOffset].read(flag, -1); + result[idx].x = data.x; + result[idx].y = data.y; + } +} + +template +__global__ void __launch_bounds__(1024, 1) + allreduce1(T* src, T* dst, mscclpp::DeviceHandle* smChannels, + mscclpp::DeviceHandle* smOutChannels, int rank, int nranks, size_t nelems) { + const size_t chunkSize = nelems / nranks; + if (nranks == 1) return; + const int nPeer = nranks - 1; + const size_t indexOffset = rank * chunkSize; + const size_t vectorSize = sizeof(int4) / sizeof(T); + const size_t indexOffset4 = indexOffset / vectorSize; + int4* src4 = (int4*)src; + int4* dst4 = (int4*)dst; + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + + // synchronize everyone + if (tid == 0) { + __threadfence_system(); + } + __syncthreads(); + if (tid < nPeer) { + smChannels[tid].relaxedSignal(); + } + if (tid >= nPeer && tid < nPeer * 2) { + smChannels[tid - nPeer].wait(); + } + deviceSyncer.sync(gridDim.x); + + // use int4 as much as possible + const size_t nInt4 = chunkSize / vectorSize; + for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * gridDim.x) { + int4 tmp = src4[indexOffset4 + idx]; + for (int index = 0; index < nPeer; ++index) { + int4 val; + int peerIdx = (index + rank); + if (peerIdx >= nPeer) peerIdx -= nPeer; + val = smChannels[peerIdx].read(indexOffset4 + idx); + tmp = add_vectors(tmp, val); + } + dst4[indexOffset4 + idx] = tmp; + } + + // use the given TYPE for the rest + size_t processed = nInt4 * vectorSize * nranks; + const size_t nRemElems = nelems - processed; + const size_t startIdx = processed + (nRemElems * rank) / nranks; + const size_t endIdx = processed + (nRemElems * (rank + 1)) / nranks; + for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x + startIdx; idx < endIdx; idx += blockDim.x * gridDim.x) { + T tmp = src[idx]; + for (int index = 0; index < nPeer; ++index) { + int peerIdx = (index + rank); + if (peerIdx >= nPeer) peerIdx -= nPeer; + T val = smChannels[peerIdx].read(idx); + tmp += val; + } + dst[idx] = tmp; + } + + // synchronize everyone again + deviceSyncer.sync(gridDim.x); + if (tid == 0) { + __threadfence_system(); + } + __syncthreads(); + if (tid < nPeer) { + smChannels[tid].relaxedSignal(); + } + if (tid >= nPeer && tid < nPeer * 2) { + smChannels[tid - nPeer].wait(); + } + + deviceSyncer.sync(gridDim.x); + for (int i = 0; i < nPeer; ++i) { + int peerIdx = (i + rank); + if (peerIdx >= nPeer) peerIdx -= nPeer; + const int remoteRank = (peerIdx < rank ? 
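allreduce6 above double-buffers its LL-packet scratch space on the parity of flag: the first two nPkts-sized regions alternate as the landing area for peers' packets, and the next two alternate as the area where reduced results are written back. A host-side sketch of that offset arithmetic (illustrative; it mirrors the kernel, with pktBytes standing in for sizeof(mscclpp::LLPacket)):

```cpp
#include <cstddef>

struct LLScratchOffsets {
  size_t base;    // start of this round's incoming-packet region
  size_t self;    // where this rank's packets land inside that region
  size_t result;  // start of this round's reduced-result region
};

// flag increments once per allreduce call, so consecutive calls use disjoint halves.
LLScratchOffsets llOffsets(size_t nPkts, size_t nPktsPerRank, size_t pktBytes,
                           unsigned flag, int rank) {
  LLScratchOffsets o;
  o.base = (flag & 1) ? 0 : nPkts * pktBytes;
  o.self = o.base + rank * nPktsPerRank * pktBytes;
  o.result = (flag & 1) ? 2 * nPkts * pktBytes : 3 * nPkts * pktBytes;
  return o;
}
```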
peerIdx : peerIdx + 1); + size_t offset = chunkSize * remoteRank * sizeof(T); + smOutChannels[peerIdx].get(offset, chunkSize * sizeof(T), tid, blockDim.x * gridDim.x); + } +} + +template +__global__ void __launch_bounds__(1024, 1) + allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, int rank, + int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) { + // This version of allreduce only works for single nodes + if (worldSize != nRanksPerNode) return; + nelems = nelems / (sizeof(int) / sizeof(T)); + const int nPeers = nRanksPerNode - 1; + const size_t nPkts = nelems; + const int nelemsPerRank = nelems / worldSize; + const int nPktsPerRank = nelemsPerRank; + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int peerIdx = blockIdx.x / nBlocksPerPeer; + const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; + // double buffering + size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LL8Packet); + void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset); + size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet); + size_t scratchResultOffset = + (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LL8Packet) : 3 * nPkts * sizeof(mscclpp::LL8Packet); + size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int); + uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int)); + uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); + + // Put channels into shared memory, read channel info from global memory is unexpectable slow. + __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; + const int lid = tid % WARP_SIZE; + if (lid < nPeers) { + channels[lid] = smChannels[lid]; + } + __syncwarp(); + + // step 1: write to scratch buffer + channels[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, + blockDim.x * nBlocksPerPeer, flag); + // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer + for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) { + uint32_t data = 0; + for (int index = 0; index < nPeers; index++) { + const int remoteRank = index < rank ? 
index : index + 1; + mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nPktsPerRank; + uint32_t val = dstPkt[idx].read(flag, -1); + data = add_vectors(val, data); + } + data = add_vectors(data, src[idx]); + dst[idx] = data; + + mscclpp::LL8Packet packet; + packet.data = data; + packet.flag = flag; + size_t offset = scratchResultOffset / sizeof(mscclpp::LL8Packet) + (idx + rank * nPktsPerRank); + for (int index = 0; index < nPeers; index++) { + channels[index].write(offset, packet); + } + } + // step 3: get data result from scratch buffer + mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRank * nPktsPerRank; + uint32_t* result = (uint32_t*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int)); + for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) { + uint32_t data = dstPkt[idx + dstOffset].read(flag, -1); + result[idx] = data; + } +} + +template +__global__ void __launch_bounds__(512, 1) + allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, + mscclpp::DeviceHandle* smOutChannels, int rank, int nRanksPerNode, int worldSize, + size_t nelems) { + const int nPeer = nRanksPerNode - 1; + const size_t chanOffset = nPeer * blockIdx.x; + // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) + const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); + const size_t nInt4PerRank = nInt4 / worldSize; + auto smChans = smChannels + chanOffset; + auto smOutChans = smOutChannels + chanOffset; + + int4* buff4 = reinterpret_cast(buff); + int4* scratch4 = reinterpret_cast(scratch); + int4* resultBuff4 = reinterpret_cast(resultBuff); + + // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` + constexpr size_t unitNInt4 = 512; + const size_t maxNInt4PerBlock = (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4; + size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x; + size_t nInt4OfThisBlock = maxNInt4PerBlock; + size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock; + constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB + if (blockIdx.x >= nNeededBlocks) { + nInt4OfThisBlock = 0; + } else if (blockIdx.x == nNeededBlocks - 1) { + nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1); + } + const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk; + const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk; + const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk; + const size_t blockOffset = nInt4PerChunk * blockIdx.x; + const size_t scratchChunkRankOffset = chunkSizePerRank * rank; + + __shared__ mscclpp::DeviceHandle channels[NRANKS_PER_NODE - 1]; + __shared__ mscclpp::DeviceHandle outChannels[NRANKS_PER_NODE - 1]; + const int lid = threadIdx.x % WARP_SIZE; + if (lid < nPeer) { + channels[lid] = smChans[lid]; + outChannels[lid] = smOutChans[lid]; + } + __syncwarp(); + + // we can use double buffering to hide synchronization overhead + for (size_t itr = 0; itr < nItrs; itr++) { + if (threadIdx.x < static_cast(nPeer)) { + outChannels[threadIdx.x].signal(); + outChannels[threadIdx.x].wait(); + } + __syncthreads(); + // Starts allgather + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + for (int i = 0; i < nPeer; i++) { + const int peerIdx = (i + blockIdx.x) % nPeer; + const int remoteRank = (peerIdx < rank) ? 
peerIdx : peerIdx + 1; + int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; + channels[peerIdx].write(scratchChunkRankOffset + blockOffset + idx, val); + } + } + + /// Starts reduce-scatter + if (threadIdx.x < static_cast(nPeer)) { + outChannels[threadIdx.x].signal(); + outChannels[threadIdx.x].wait(); + } + __syncthreads(); + + for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; + for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock, data); + } + } + offsetOfThisBlock += nInt4PerChunk; + } + if (restNInt4 > 0) { + if (threadIdx.x < static_cast(nPeer)) { + outChannels[threadIdx.x].signal(); + outChannels[threadIdx.x].wait(); + } + __syncthreads(); + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + for (int i = 0; i < nPeer; i++) { + const int peerIdx = (i + blockIdx.x) % nPeer; + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock]; + channels[peerIdx].write(scratchChunkRankOffset + blockOffset + idx, val); + } + } + + if (threadIdx.x < static_cast(nPeer)) { + outChannels[threadIdx.x].signal(); + outChannels[threadIdx.x].wait(); + } + __syncthreads(); + + for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) { + int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock]; + for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1; + int4 val = scratch4[chunkSizePerRank * remoteRank + blockOffset + idx]; + data = add_vectors(val, data); + } + resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data; + for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) { + outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock, data); + } + } + } +} + +template +cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, + mscclpp::DeviceHandle* smOutChannels, int rank, int nRanksPerNode, + int worldSize, size_t nelems, cudaStream_t stream) { + static uint32_t flag = 1; +#if defined(__HIP_PLATFORM_AMD__) + if (sizeof(T) * nelems <= (1 << 20)) { + int nBlocks = 28; + int nThreadsPerBlock = 1024; + if (nelems >= 8192) { + nBlocks = 56; + nThreadsPerBlock = (nelems <= 76800) ? 
512 : 1024; + } + allreduce7<<>>(buff, scratch, resultBuff, smChannels, rank, nRanksPerNode, + worldSize, nelems, flag++); + } else { + int nBlocks = 35; + int nThreadsPerBlock = 512; + allreduce8<<>>(buff, scratch, resultBuff, smChannels, smOutChannels, rank, nRanksPerNode, + worldSize, nelems); + } +#else + if (sizeof(T) * nelems <= (1 << 20)) { + allreduce6<<<21, 512, 0, stream>>>(buff, scratch, resultBuff, smChannels, rank, nRanksPerNode, worldSize, nelems, + flag++); + } else { + allreduce1<<<24, 1024, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels, rank, worldSize, nelems); + } +#endif + return cudaGetLastError(); +} + +#endif // ALLREDUCE_KERNEL_H diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp new file mode 100644 index 000000000..cddc69625 --- /dev/null +++ b/apps/nccl/src/common.hpp @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef NCCL_COMMON_HPP_ +#define NCCL_COMMON_HPP_ + +#if defined(__HIP_PLATFORM_AMD__) +#define WARP_SIZE 64 +#define __syncwarp() __builtin_amdgcn_wave_barrier() +#else +#define WARP_SIZE 32 +#endif + +constexpr int NRANKS_PER_NODE = 8; +constexpr int SCRATCH_SIZE = 1024 * 1024 * 70; // 35 thread-blocks * 8 ranks * 256KB = 70MB + +#endif // NCCL_COMMON_HPP_ diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu new file mode 100644 index 000000000..b1fabdf21 --- /dev/null +++ b/apps/nccl/src/nccl.cu @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include +#include + +#include "allgather.hpp" +#include "allreduce.hpp" +#include "nccl.h" + +#define NCCL_API extern "C" __attribute__((visibility("default"))) + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define NUM_CHANNELS_PER_CONNECTION 64 +__device__ mscclpp::DeviceSyncer deviceSyncer; + +// static const mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, +// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, +// mscclpp::Transport::IB6, mscclpp::Transport::IB7}; + +struct channelKey { + const void* sendbuff; + const void* recvbuff; + size_t bytes; + bool operator==(const channelKey& other) const { + return sendbuff == other.sendbuff && recvbuff == other.recvbuff && bytes == other.bytes; + } +}; + +namespace std { +template <> +struct hash { + std::size_t operator()(const channelKey& k) const { + return std::hash()(k.sendbuff) ^ std::hash()(k.recvbuff) ^ std::hash()(k.bytes); + } +}; +} // namespace std + +struct ChannelInfo { + std::vector smChannels; + std::vector smOutChannels; + std::shared_ptr> smChannelDeviceHandles; + std::shared_ptr> smOutChannelDeviceHandles; +}; + +struct ncclComm { + std::shared_ptr comm; + std::vector> connections; + std::vector> smSemaphores; + + std::unordered_map channelInfos; + std::shared_ptr scratchBuff; + std::vector remoteScratchRegMemories; +}; + +static size_t ncclTypeSize(ncclDataType_t type) { + switch (type) { + case ncclInt8: + case ncclUint8: + return 1; + case ncclFloat16: + return 2; + case ncclInt32: + case ncclUint32: + return 4; + case ncclInt64: + case ncclUint64: + return 8; + case ncclFloat32: + return 4; + case ncclFloat64: + return 8; +#if defined(__CUDA_BF16_TYPES_EXIST__) + case ncclBfloat16: + return 2; +#endif // 
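The dispatcher above routes messages of at most 1 MiB to the LL-packet kernels (allreduce7 on ROCm, allreduce6 on CUDA) and larger ones to the chunked allreduce8 / allreduce1 paths, and common.hpp sizes the shared scratch buffer for allreduce8's worst case. A compile-time sanity sketch of that sizing arithmetic (constants copied from the code above; the names are illustrative):

```cpp
#include <cstddef>

constexpr size_t kScratchSize = 1024 * 1024 * 70;  // SCRATCH_SIZE in common.hpp
constexpr size_t kMaxBlocks = 35;                  // blocks used for the large-message allreduce8 launch
constexpr size_t kRanksPerNode = 8;                // NRANKS_PER_NODE
constexpr size_t kChunkBytes = 1024 * 256;         // nInt4PerChunk * sizeof(int4) = 256 KB per block per rank

static_assert(kMaxBlocks * kRanksPerNode * kChunkBytes == kScratchSize,
              "scratch buffer covers 35 blocks x 8 ranks x 256 KB chunks");
```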
defined(__CUDA_BF16_TYPES_EXIST__) +#if defined(__CUDA_FP8_TYPES_EXIST__) + case ncclFp8E4M3: + case ncclFp8E5M2: + return 1; +#endif // defined(__CUDA_FP8_TYPES_EXIST__) + case ncclNumTypes: + return 0; + } + return 0; +} + +static mscclpp::Transport getTransport(int, int) { + // if (rank / nRanksPerNode == peerRank / nRanksPerNode) { + // return mscclpp::Transport::CudaIpc; + // } else { + // return IBs[rank % nRanksPerNode]; + // } + return mscclpp::Transport::CudaIpc; +} + +static std::vector setupRemoteMemories(std::shared_ptr comm, int rank, + void* buff, size_t bytes, + mscclpp::TransportFlags transport) { + std::vector remoteMemories; + mscclpp::RegisteredMemory memory = comm->registerMemory(buff, bytes, transport); + std::vector> remoteRegMemoryFutures; + for (int i = 0; i < comm->bootstrap()->getNranks(); i++) { + if (i == rank) continue; + remoteRegMemoryFutures.push_back(comm->recvMemoryOnSetup(i, 0)); + comm->sendMemoryOnSetup(memory, i, 0); + } + comm->setup(); + std::transform(remoteRegMemoryFutures.begin(), remoteRegMemoryFutures.end(), std::back_inserter(remoteMemories), + [](const auto& future) { return future.get(); }); + return remoteMemories; +} + +static std::vector setupSmChannels(ncclComm_t comm, + const std::vector& remoteMemories, + void* src) { + std::vector channels; + std::vector>& smSemaphores = comm->smSemaphores; + size_t nConnections = comm->connections.size(); + for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) { + for (size_t cid = 0; cid < nConnections; ++cid) { + if (comm->connections[cid]->transport() == mscclpp::Transport::CudaIpc) { + channels.emplace_back(smSemaphores[idx * nConnections + cid], remoteMemories[cid], src, nullptr); + } + } + } + return channels; +} + +static std::shared_ptr> setupSmChannelDeviceHandles( + const std::vector& smChannels) { + std::vector> smChannelDeviceHandles; + std::transform(smChannels.begin(), smChannels.end(), std::back_inserter(smChannelDeviceHandles), + [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); + std::shared_ptr> ptr = + mscclpp::allocSharedCuda>(smChannelDeviceHandles.size()); + mscclpp::memcpyCuda>(ptr.get(), smChannelDeviceHandles.data(), + smChannelDeviceHandles.size(), cudaMemcpyHostToDevice); + return ptr; +} + +NCCL_API ncclResult_t ncclGetVersion(int* version) { + if (version == nullptr) return ncclInvalidArgument; + *version = MSCCLPP_VERSION; + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) { + if (uniqueId == nullptr) return ncclInvalidArgument; + if (MSCCLPP_UNIQUE_ID_BYTES != NCCL_UNIQUE_ID_BYTES) return ncclInternalError; + mscclpp::UniqueId id = mscclpp::TcpBootstrap::createUniqueId(); + memcpy(uniqueId, &id, sizeof(ncclUniqueId)); + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t*, int, ncclUniqueId, int, + ncclConfig_t*) { + // TODO: implement this function + return ncclInternalError; +} + +NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) { + if (comm == nullptr) return ncclInvalidArgument; + if (nranks < 0 || rank < 0 || rank >= nranks) return ncclInvalidArgument; + std::shared_ptr bootstrap = std::make_shared(rank, nranks); + mscclpp::UniqueId id; + memcpy(id.data(), &commId, sizeof(ncclUniqueId)); + bootstrap->initialize(id); + std::shared_ptr mscclppComm = std::make_shared(bootstrap); + std::vector>> connectionFutures; + + for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) { + if (i == rank) 
continue; + mscclpp::Transport transport = getTransport(rank, i); + connectionFutures.push_back(mscclppComm->connectOnSetup(i, 0, transport)); + } + mscclppComm->setup(); + + std::vector> connections; + std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections), + [](const auto& future) { return future.get(); }); + + std::vector> smSemaphores; + for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) { + for (size_t cid = 0; cid < connections.size(); ++cid) { + if (connections[cid]->transport() == mscclpp::Transport::CudaIpc) { + smSemaphores.emplace_back( + std::make_shared(*(mscclppComm), connections[cid])); + } + } + } + mscclppComm->setup(); + + ncclComm* commPtr = new ncclComm(); + commPtr->comm = mscclppComm; + commPtr->connections = std::move(connections); + commPtr->smSemaphores = std::move(smSemaphores); + commPtr->scratchBuff = mscclpp::allocExtSharedCuda(SCRATCH_SIZE); + commPtr->remoteScratchRegMemories = + setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); + + *comm = commPtr; + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommInitAll(ncclComm_t*, int, const int*) { + // TODO: implement this function + return ncclInternalError; +} + +NCCL_API ncclResult_t ncclCommFinalize(ncclComm_t comm) { + comm->comm->bootstrap()->barrier(); + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) { + if (comm == nullptr) return ncclInvalidArgument; + delete comm; + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommAbort(ncclComm_t) { + // TODO: implement this function + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommSplit(ncclComm_t, int, int, ncclComm_t*, ncclConfig_t*) { + // TODO: implement this function + return ncclInternalError; +} + +NCCL_API const char* ncclGetErrorString(ncclResult_t result) { + switch (result) { + case ncclSuccess: + return "no error"; + case ncclUnhandledCudaError: + return "unhandled cuda error (run with NCCL_DEBUG=INFO for details)"; + case ncclSystemError: + return "unhandled system error (run with NCCL_DEBUG=INFO for details)"; + case ncclInternalError: + return "internal error - please report this issue to the NCCL developers"; + case ncclInvalidArgument: + return "invalid argument (run with NCCL_DEBUG=WARN for details)"; + case ncclInvalidUsage: + return "invalid usage (run with NCCL_DEBUG=WARN for details)"; + case ncclRemoteError: + return "remote process exited or there was a network error"; + case ncclInProgress: + return "NCCL operation in progress"; + default: + return "unknown result code"; + } +} + +NCCL_API const char* ncclGetLastError(ncclComm_t) { + // TODO: implement this function + return nullptr; +} + +NCCL_API ncclResult_t ncclCommGetAsyncError(ncclComm_t, ncclResult_t* asyncError) { + if (asyncError == nullptr) return ncclInvalidArgument; + *asyncError = ncclSuccess; + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommCount(const ncclComm_t comm, int* count) { + if (comm == nullptr || count == nullptr) return ncclInvalidArgument; + *count = comm->comm->bootstrap()->getNranks(); + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* device) { + if (comm == nullptr || device == nullptr) return ncclInvalidArgument; + *device = comm->comm->bootstrap()->getRank(); + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank) { + if (comm == nullptr || rank == nullptr) return ncclInvalidArgument; + *rank = 
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t*, void*, ncclDataType_t,
+                                               ncclScalarResidence_t, ncclComm_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclRedOpDestroy(ncclRedOp_t, ncclComm_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t,
+                                 ncclRedOp_t, int, ncclComm_t, cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclBcast(void*, size_t, ncclDataType_t, int, ncclComm_t,
+                                cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t,
+                                    int, ncclComm_t, cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
+                                    ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = count * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+  int rank = comm->comm->bootstrap()->getRank();
+  channelKey key{sendbuff, recvbuff, bytes};
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
+
+  auto it = comm->channelInfos.find(key);
+  if (it == comm->channelInfos.end()) {
+    // setup smChannels (src: sendbuff, dst: remote scratch buff)
+    std::vector<mscclpp::SmChannel> channels =
+        setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>(sendbuff));
+    ChannelInfo channelInfo{channels, {}, setupSmChannelDeviceHandles(channels), nullptr};
+    it = comm->channelInfos.emplace(key, channelInfo).first;
+
+    // setup smOutChannels (src: recvbuff, dst: remote recvbuff)
+    if (bytes > (1 << 20)) {
+      std::vector<mscclpp::RegisteredMemory> remoteMemories =
+          setupRemoteMemories(comm->comm, rank, recvbuff, bytes, mscclpp::Transport::CudaIpc);
+      std::vector<mscclpp::SmChannel> outChannels = setupSmChannels(comm, remoteMemories, recvbuff);
+      it->second.smOutChannels = outChannels;
+      it->second.smOutChannelDeviceHandles = setupSmChannelDeviceHandles(outChannels);
+    }
+  }
+
+  smChannels = it->second.smChannelDeviceHandles.get();
+  smOutChannels = it->second.smOutChannelDeviceHandles.get();
+
+  switch (datatype) {
+    case ncclFloat16:
+      CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
+                          rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+      break;
+    case ncclFloat32:
+      CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
+                          smOutChannels, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
+                          comm->comm->bootstrap()->getNranks(), count, stream));
+      break;
+    case ncclInt32:
+    case ncclUint32:
+      CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
+                          comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(),
+                          count, stream));
+      break;
+    default:
+      return ncclInvalidArgument;
+  }
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclReduceScatter(const void*, void*, size_t, ncclDataType_t,
+                                        ncclRedOp_t, ncclComm_t, cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
+                                    ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = sendcount * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+  channelKey key{sendbuff, recvbuff, bytes};
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
+
+  auto it = comm->channelInfos.find(key);
+  if (it == comm->channelInfos.end()) {
+    std::vector<mscclpp::RegisteredMemory> remoteMemories =
+        setupRemoteMemories(comm->comm, rank, const_cast<void*>(recvbuff), bytes * nRank,
+                            mscclpp::Transport::CudaIpc);
+    std::vector<mscclpp::SmChannel> channels =
+        setupSmChannels(comm, remoteMemories, const_cast<void*>(recvbuff));
+    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
+                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
+    ChannelInfo channelInfo{channels, {}, setupSmChannelDeviceHandles(channels), nullptr};
+    it = comm->channelInfos.emplace(key, channelInfo).first;
+  }
+  smChannels = it->second.smChannelDeviceHandles.get();
+  if ((char*)sendbuff == (char*)recvbuff + rank * sendcount) {
+    CUDACHECK(allgather<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels,
+                               rank, NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
+  } else {
+    CUDACHECK(allgather<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels,
+                              rank, NRANKS_PER_NODE, nRank, bytes / sizeof(int), stream));
+  }
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclSend(const void*, size_t, ncclDataType_t, int, ncclComm_t,
+                               cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclRecv(void*, size_t, ncclDataType_t, int, ncclComm_t,
+                               cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclAllToAll(const void*, void*, size_t, ncclDataType_t,
+                                   ncclComm_t, cudaStream_t) {
+  // TODO: implement this function
+  return ncclInternalError;
+}
+
+NCCL_API ncclResult_t ncclGroupStart() {
+  // Do nothing
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclGroupEnd() {
+  // Do nothing
+  return ncclSuccess;
+}
diff --git a/apps/nccl/test/CMakeLists.txt b/apps/nccl/test/CMakeLists.txt
new file mode 100644
index 000000000..025d2db79
--- /dev/null
+++ b/apps/nccl/test/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+ +find_package(MPI) + +add_executable(nccl_api_test nccl_api_test.cc) +target_link_libraries(nccl_api_test mscclpp mscclpp_nccl ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads MPI::MPI_CXX) +target_include_directories(nccl_api_test PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/apps/nccl/include) diff --git a/apps/nccl/test/nccl_api_test.cc b/apps/nccl/test/nccl_api_test.cc new file mode 100644 index 000000000..4e23a217b --- /dev/null +++ b/apps/nccl/test/nccl_api_test.cc @@ -0,0 +1,116 @@ +// Code borrowed from https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html + +#include +#include +#include +#include + +#include + +#include "mpi.h" +#include "nccl.h" + +#define MPICHECK(cmd) \ + do { \ + int e = cmd; \ + if (e != MPI_SUCCESS) { \ + printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +static uint64_t getHostHash(const char* string) { + // Based on DJB2a, result = result * 33 ^ char + uint64_t result = 5381; + for (int c = 0; string[c] != '\0'; c++) { + result = ((result << 5) + result) ^ string[c]; + } + return result; +} + +static void getHostName(char* hostname, int maxlen) { + gethostname(hostname, maxlen); + for (int i = 0; i < maxlen; i++) { + if (hostname[i] == '.') { + hostname[i] = '\0'; + return; + } + } +} + +int main(int argc, char* argv[]) { + int size = 32 * 1024 * 1024; + + int myRank, nRanks, localRank = 0; + + // initializing MPI + MPICHECK(MPI_Init(&argc, &argv)); + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); + MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks)); + + // calculating localRank based on hostname which is used in selecting a GPU + uint64_t hostHashs[nRanks]; + char hostname[1024]; + getHostName(hostname, 1024); + hostHashs[myRank] = getHostHash(hostname); + MPICHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD)); + for (int p = 0; p < nRanks; p++) { + if (p == myRank) break; + if (hostHashs[p] == hostHashs[myRank]) localRank++; + } + + ncclUniqueId id; + ncclComm_t comm; + float *sendbuff, *recvbuff; + cudaStream_t s; + + // get NCCL unique ID at rank 0 and broadcast it to all others + if (myRank == 0) ncclGetUniqueId(&id); + MPICHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + + // picking a GPU based on localRank, allocate device buffers + CUDACHECK(cudaSetDevice(localRank)); + CUDACHECK(cudaMalloc(&sendbuff, size * sizeof(float))); + CUDACHECK(cudaMalloc(&recvbuff, size * sizeof(float))); + CUDACHECK(cudaStreamCreate(&s)); + + // initializing NCCL + NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + + // communicating using NCCL + NCCLCHECK(ncclAllReduce((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, ncclSum, comm, s)); + + // completing NCCL operation by synchronizing on the CUDA stream + CUDACHECK(cudaStreamSynchronize(s)); + + // free device buffers + CUDACHECK(cudaFree(sendbuff)); + CUDACHECK(cudaFree(recvbuff)); + + // finalizing NCCL + ncclCommDestroy(comm); + + // finalizing MPI + 
MPICHECK(MPI_Finalize()); + + printf("[MPI Rank %d] Success \n", myRank); + return 0; +} diff --git a/docs/quickstart.md b/docs/quickstart.md index af1bbe5f3..a30c42032 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -59,7 +59,10 @@ $ sudo make install/fast Python 3.8 or later is required. ```bash +# For NVIDIA platforms $ python -m pip install . +# For AMD platforms +$ CXX=/path/to/hipcc python -m pip install . ``` ## Docker Images diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 02c277a3e..c2a4dff44 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -51,6 +51,10 @@ class Bootstrap { /// A native implementation of the bootstrap using TCP sockets. class TcpBootstrap : public Bootstrap { public: + /// Create a random unique ID. + /// @return The created unique ID. + static UniqueId createUniqueId(); + /// Constructor. /// @param rank The rank of the process. /// @param nRanks The total number of ranks. @@ -59,10 +63,6 @@ class TcpBootstrap : public Bootstrap { /// Destructor. ~TcpBootstrap(); - /// Create a random unique ID and store it in the @ref TcpBootstrap. - /// @return The created unique ID. - UniqueId createUniqueId(); - /// Return the unique ID stored in the @ref TcpBootstrap. /// @return The unique ID stored in the @ref TcpBootstrap. UniqueId getUniqueId() const; diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index f560a655c..8e9e17ab5 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -6,6 +6,8 @@ #if defined(__HIP_PLATFORM_AMD__) +// #include +#include #include using cudaError_t = hipError_t; @@ -61,6 +63,8 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #define cudaMemcpy(...) hipMemcpy(__VA_ARGS__) #define cudaMemcpyAsync(...) hipMemcpyAsync(__VA_ARGS__) #define cudaMemcpyToSymbol(...) hipMemcpyToSymbol(__VA_ARGS__) +#define cudaMemcpyToSymbolAsync(...) hipMemcpyToSymbolAsync(__VA_ARGS__) +#define cudaStreamCreate(...) hipStreamCreate(__VA_ARGS__) #define cudaStreamCreateWithFlags(...) hipStreamCreateWithFlags(__VA_ARGS__) #define cudaStreamSynchronize(...) hipStreamSynchronize(__VA_ARGS__) #define cudaStreamBeginCapture(...) hipStreamBeginCapture(__VA_ARGS__) @@ -90,6 +94,12 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #include #include #include +#if (CUDART_VERSION >= 11000) +#include +#endif +#if (CUDART_VERSION >= 11080) +#include +#endif #endif diff --git a/python/mscclpp/CMakeLists.txt b/python/mscclpp/CMakeLists.txt index 0fe510c80..bb9eadf32 100644 --- a/python/mscclpp/CMakeLists.txt +++ b/python/mscclpp/CMakeLists.txt @@ -9,6 +9,6 @@ FetchContent_MakeAvailable(nanobind) file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp) nanobind_add_module(mscclpp_py ${SOURCES}) set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp) -target_link_libraries(mscclpp_py PRIVATE ${GPU_LIBRARIES} mscclpp_static) -target_include_directories(mscclpp_py PRIVATE ${GPU_INCLUDE_DIRS}) +target_link_libraries(mscclpp_py PRIVATE mscclpp_static ${GPU_LIBRARIES}) +target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) install(TARGETS mscclpp_py LIBRARY DESTINATION .) 
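The `core.hpp` change above turns `TcpBootstrap::createUniqueId()` into a static factory, so a unique ID can be produced before any bootstrap instance exists. Below is a minimal sketch of the resulting call pattern; the MPI broadcast is only an illustrative out-of-band channel and is not part of the library, while `createUniqueId()`, `initialize()`, `UniqueId`, and `Communicator` are the APIs shown elsewhere in this diff:

```cpp
#include <mpi.h>

#include <memory>
#include <mscclpp/core.hpp>

int main(int argc, char* argv[]) {
  MPI_Init(&argc, &argv);
  int rank, nRanks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nRanks);

  // The ID is now created without a TcpBootstrap instance; rank 0 makes one and
  // broadcasts the raw bytes, which is what ncclGetUniqueId()/ncclCommInitRank() do.
  mscclpp::UniqueId id;
  if (rank == 0) id = mscclpp::TcpBootstrap::createUniqueId();
  MPI_Bcast(id.data(), static_cast<int>(id.size()), MPI_BYTE, 0, MPI_COMM_WORLD);

  auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nRanks);
  bootstrap->initialize(id);  // rank 0 starts the bootstrap root thread here, not at ID creation
  mscclpp::Communicator comm(bootstrap);

  MPI_Finalize();
  return 0;
}
```

With this refactor the ID creator only reserves an address and port; the actual root listener is created later inside `initialize()` on rank 0, as the `bootstrap.cc` changes below show.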
diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 5fd4bd317..1a1cd2780 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -63,7 +63,7 @@ void register_core(nb::module_& m) { .def_static( "create", [](int rank, int nRanks) { return std::make_shared(rank, nRanks); }, nb::arg("rank"), nb::arg("nRanks")) - .def("create_unique_id", &TcpBootstrap::createUniqueId) + .def_static("create_unique_id", &TcpBootstrap::createUniqueId) .def("get_unique_id", &TcpBootstrap::getUniqueId) .def("initialize", static_cast(&TcpBootstrap::initialize), nb::call_guard(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30) diff --git a/python/test/CMakeLists.txt b/python/test/CMakeLists.txt index cf705841c..be62aea99 100644 --- a/python/test/CMakeLists.txt +++ b/python/test/CMakeLists.txt @@ -9,5 +9,5 @@ FetchContent_MakeAvailable(nanobind) file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp) nanobind_add_module(mscclpp_py_test ${SOURCES}) set_target_properties(mscclpp_py_test PROPERTIES OUTPUT_NAME _ext) -target_link_libraries(mscclpp_py_test PRIVATE ${GPU_LIBRARIES} mscclpp_static) -target_include_directories(mscclpp_py_test PRIVATE ${GPU_INCLUDE_DIRS}) +target_link_libraries(mscclpp_py_test PRIVATE mscclpp_static ${GPU_LIBRARIES}) +target_include_directories(mscclpp_py_test SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index 00a58b992..c9cea10f4 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -70,12 +70,14 @@ static_assert(sizeof(UniqueIdInternal) <= sizeof(UniqueId), "UniqueIdInternal is class TcpBootstrap::Impl { public: + static UniqueId createUniqueId(); + static UniqueId getUniqueId(const UniqueIdInternal& uniqueId); + Impl(int rank, int nRanks); ~Impl(); void initialize(const UniqueId& uniqueId, int64_t timeoutSec); void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec); void establishConnections(int64_t timeoutSec); - UniqueId createUniqueId(); UniqueId getUniqueId() const; int getRank(); int getNranks(); @@ -99,7 +101,6 @@ class TcpBootstrap::Impl { std::unique_ptr abortFlagStorage_; volatile uint32_t* abortFlag_; std::thread rootThread_; - char netIfName_[MAX_IF_NAME_SIZE + 1]; SocketAddress netIfAddr_; std::unordered_map, std::shared_ptr, PairHash> peerSendSockets_; std::unordered_map, std::shared_ptr, PairHash> peerRecvSockets_; @@ -110,15 +111,33 @@ class TcpBootstrap::Impl { std::shared_ptr getPeerSendSocket(int peer, int tag); std::shared_ptr getPeerRecvSocket(int peer, int tag); + static void assignPortToUniqueId(UniqueIdInternal& uniqueId); + static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr); + void bootstrapCreateRoot(); void bootstrapRoot(); void getRemoteAddresses(Socket* listenSock, std::vector& rankAddresses, std::vector& rankAddressesRoot, int& rank); void sendHandleToPeer(int peer, const std::vector& rankAddresses, const std::vector& rankAddressesRoot); - void netInit(std::string ipPortPair, std::string interface); }; +UniqueId TcpBootstrap::Impl::createUniqueId() { + UniqueIdInternal uniqueId; + SocketAddress netIfAddr; + netInit("", "", netIfAddr); + getRandomData(&uniqueId.magic, sizeof(uniqueId_.magic)); + std::memcpy(&uniqueId.addr, &netIfAddr, sizeof(SocketAddress)); + assignPortToUniqueId(uniqueId); + return getUniqueId(uniqueId); +} + +UniqueId TcpBootstrap::Impl::getUniqueId(const UniqueIdInternal& uniqueId) { + UniqueId ret; + std::memcpy(&ret, &uniqueId, sizeof(uniqueId)); + return ret; +} + 
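The new `assignPortToUniqueId()` helper together with the `Socket::bind()` / `bindAndListen()` split (further down in this diff) separates port reservation from listening: the unique ID records a concrete address and port at creation time, and only the root rank later starts accepting connections during `initialize()`. The following self-contained POSIX sketch illustrates that pattern; it is purely illustrative and not mscclpp code:

```cpp
// "Bind early, listen later": reserve an ephemeral port so its address can be
// advertised (e.g. baked into a unique ID) before anything is accepting connections.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstdio>

int main() {
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_ANY);
  addr.sin_port = 0;  // 0 lets the kernel pick an ephemeral port

  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0 || bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) != 0) return 1;

  // After bind() the port is fixed; read it back and advertise it.
  socklen_t len = sizeof(addr);
  getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &len);
  std::printf("reserved port %d\n", ntohs(addr.sin_port));

  // listen() can be deferred until the owner of the port is ready to accept
  // connections (what bindAndListen() does for the bootstrap root).
  if (listen(fd, 16) != 0) return 1;

  close(fd);
  return 0;
}
```

Binding with port 0 and reading the result back via `getsockname()` mirrors what `Socket::bind()` does in `socket.cc` below, which is what allows `createUniqueId()` to embed a usable port without starting the root thread.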
TcpBootstrap::Impl::Impl(int rank, int nRanks) : rank_(rank), nRanks_(nRanks), @@ -128,29 +147,26 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks) abortFlagStorage_(new uint32_t(0)), abortFlag_(abortFlagStorage_.get()) {} -UniqueId TcpBootstrap::Impl::getUniqueId() const { - UniqueId ret; - std::memcpy(&ret, &uniqueId_, sizeof(uniqueId_)); - return ret; -} - -UniqueId TcpBootstrap::Impl::createUniqueId() { - netInit("", ""); - getRandomData(&uniqueId_.magic, sizeof(uniqueId_.magic)); - std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress)); - bootstrapCreateRoot(); - return getUniqueId(); -} +UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); } int TcpBootstrap::Impl::getRank() { return rank_; } int TcpBootstrap::Impl::getNranks() { return nRanks_; } void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec) { - netInit("", ""); + if (!netInitialized) { + netInit("", "", netIfAddr_); + netInitialized = true; + } std::memcpy(&uniqueId_, &uniqueId, sizeof(uniqueId_)); + if (rank_ == 0) { + bootstrapCreateRoot(); + } + char line[MAX_IF_NAME_SIZE + 1]; + SocketToString(&uniqueId_.addr, line); + INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line); establishConnections(timeoutSec); } @@ -170,7 +186,10 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim ipPortPair = ifIpPortTrio.substr(ipPortPair.find_first_of(':') + 1); } - netInit(ipPortPair, interface); + if (!netInitialized) { + netInit(ipPortPair, interface, netIfAddr_); + netInitialized = true; + } uniqueId_.magic = 0xdeadbeef; std::memcpy(&uniqueId_.addr, &netIfAddr_, sizeof(SocketAddress)); @@ -230,9 +249,15 @@ void TcpBootstrap::Impl::sendHandleToPeer(int peer, const std::vector socket = std::make_unique(&uniqueId.addr, uniqueId.magic, SocketTypeBootstrap); + socket->bind(); + uniqueId.addr = socket->getAddr(); +} + void TcpBootstrap::Impl::bootstrapCreateRoot() { listenSockRoot_ = std::make_unique(&uniqueId_.addr, uniqueId_.magic, SocketTypeBootstrap, abortFlag_, 0); - listenSockRoot_->listen(); + listenSockRoot_->bindAndListen(); uniqueId_.addr = listenSockRoot_->getAddr(); rootThread_ = std::thread([this]() { @@ -279,34 +304,33 @@ void TcpBootstrap::Impl::bootstrapRoot() { TRACE(MSCCLPP_INIT, "DONE"); } -void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface) { - if (netInitialized) return; +void TcpBootstrap::Impl::netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr) { + char netIfName[MAX_IF_NAME_SIZE + 1]; if (!ipPortPair.empty()) { if (interface != "") { // we know the - int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1, interface.c_str()); + int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1, interface.c_str()); if (ret <= 0) throw Error("NET/Socket : No interface named " + interface + " found.", ErrorCode::InternalError); } else { // we do not know the try to match it next SocketAddress remoteAddr; SocketGetAddrFromString(&remoteAddr, ipPortPair.c_str()); - if (FindInterfaceMatchSubnet(netIfName_, &netIfAddr_, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { + if (FindInterfaceMatchSubnet(netIfName, &netIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) { throw Error("NET/Socket : No usable listening interface found", ErrorCode::InternalError); } } } else { - int ret = FindInterfaces(netIfName_, &netIfAddr_, MAX_IF_NAME_SIZE, 1); + int ret = FindInterfaces(netIfName, &netIfAddr, MAX_IF_NAME_SIZE, 1); if (ret <= 0) { throw 
Error("TcpBootstrap : no socket interface found", ErrorCode::InternalError); } } char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2]; - std::sprintf(line, " %s:", netIfName_); - SocketToString(&netIfAddr_, line + strlen(line)); + std::sprintf(line, " %s:", netIfName); + SocketToString(&netIfAddr, line + strlen(line)); INFO(MSCCLPP_INIT, "TcpBootstrap : Using%s", line); - netInitialized = true; } #define TIMEOUT(__exp) \ @@ -345,13 +369,13 @@ void TcpBootstrap::Impl::establishConnections(int64_t timeoutSec) { uint64_t magic = uniqueId_.magic; // Create socket for other ranks to contact me listenSock_ = std::make_unique(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); - listenSock_->listen(); + listenSock_->bindAndListen(); info.extAddressListen = listenSock_->getAddr(); { // Create socket for root to contact me Socket lsock(&netIfAddr_, magic, SocketTypeBootstrap, abortFlag_); - lsock.listen(); + lsock.bindAndListen(); info.extAddressListenRoot = lsock.getAddr(); // stagger connection times to avoid an overload of the root @@ -486,9 +510,9 @@ void TcpBootstrap::Impl::close() { peerRecvSockets_.clear(); } -MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique(rank, nRanks); } +MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); } -MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return pimpl_->createUniqueId(); } +MSCCLPP_API_CPP TcpBootstrap::TcpBootstrap(int rank, int nRanks) { pimpl_ = std::make_unique(rank, nRanks); } MSCCLPP_API_CPP UniqueId TcpBootstrap::getUniqueId() const { return pimpl_->getUniqueId(); } diff --git a/src/bootstrap/socket.cc b/src/bootstrap/socket.cc index 2267af9b3..a79821f1b 100644 --- a/src/bootstrap/socket.cc +++ b/src/bootstrap/socket.cc @@ -390,7 +390,7 @@ Socket::Socket(const SocketAddress* addr, uint64_t magic, enum SocketType type, Socket::~Socket() { close(); } -void Socket::listen() { +void Socket::bind() { if (fd_ == -1) { throw Error("file descriptor is -1", ErrorCode::InvalidUsage); } @@ -433,7 +433,11 @@ void Socket::listen() { if (::getsockname(fd_, &addr_.sa, &size) != 0) { throw SysError("getsockname failed", errno); } + state_ = SocketStateBound; +} +void Socket::bindAndListen() { + bind(); #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN + 1]; TRACE(MSCCLPP_INIT | MSCCLPP_NET, "Listening on socket %s", SocketToString(&addr_, line)); diff --git a/src/fifo.cc b/src/fifo.cc index 4255bcdcd..592bf7d00 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -56,6 +56,7 @@ MSCCLPP_API_CPP void Fifo::pop() { MSCCLPP_API_CPP void Fifo::flushTail(bool sync) { // Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can // make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request. 
+ AvoidCudaGraphCaptureGuard cgcGuard; MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream)); if (sync) { diff --git a/src/include/socket.h b/src/include/socket.h index 9f043414e..ed125c990 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -35,10 +35,11 @@ enum SocketState { SocketStateConnecting = 4, SocketStateConnectPolling = 5, SocketStateConnected = 6, - SocketStateReady = 7, - SocketStateClosed = 8, - SocketStateError = 9, - SocketStateNum = 10 + SocketStateBound = 7, + SocketStateReady = 8, + SocketStateClosed = 9, + SocketStateError = 10, + SocketStateNum = 11 }; enum SocketType { @@ -62,7 +63,8 @@ class Socket { enum SocketType type = SocketTypeUnknown, volatile uint32_t* abortFlag = nullptr, int asyncFlag = 0); ~Socket(); - void listen(); + void bind(); + void bindAndListen(); void connect(int64_t timeout = -1); void accept(const Socket* listenSocket, int64_t timeout = -1); void send(void* ptr, int size); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0268af1c6..da47066ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -5,7 +5,7 @@ find_package(MPI) set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads) set(TEST_LIBS_GTEST GTest::gtest_main GTest::gmock_main) -set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include ${GPU_INCLUDE_DIRS}) +set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS}) set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include) if(USE_ROCM) diff --git a/test/mscclpp-test/allreduce_test.cu b/test/mscclpp-test/allreduce_test.cu index 84eb694b1..07021a551 100644 --- a/test/mscclpp-test/allreduce_test.cu +++ b/test/mscclpp-test/allreduce_test.cu @@ -1142,9 +1142,15 @@ void AllReduceTestColl::runColl(const TestArgs& args, cudaStream_t stream) { tmpBuff = scratchPacketBuff; nThreadsPerBlock = 512; } else if (kernelNum == 7) { - nBlocks = 28; tmpBuff = scratchPacketBuff; - nThreadsPerBlock = 1024; + // tune the #blocks and #threads for MI300X + if (paramCount_ < 8192) { + nBlocks = 28; + nThreadsPerBlock = 1024; + } else { + nBlocks = 56; + nThreadsPerBlock = (paramCount_ <= 76800) ? 512 : 1024; + } } else { nBlocks = std::max(args.nRanksPerNode - 1, 1) * BLOCKS_PER_PEER; tmpBuff = scratchPacketBuff; diff --git a/test/unit/socket_tests.cc b/test/unit/socket_tests.cc index fe0a063e5..4fa8d3915 100644 --- a/test/unit/socket_tests.cc +++ b/test/unit/socket_tests.cc @@ -17,7 +17,7 @@ TEST(Socket, ListenAndConnect) { ASSERT_NO_THROW(mscclpp::SocketGetAddrFromString(&listenAddr, ipPortPair.c_str())); mscclpp::Socket listenSock(&listenAddr); - listenSock.listen(); + listenSock.bindAndListen(); std::thread clientThread([&listenAddr]() { mscclpp::Socket sock(&listenAddr);