diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 46d7328e5..70c1e2592 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -35,23 +35,8 @@ // mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, // mscclpp::Transport::IB6, mscclpp::Transport::IB7}; -// Declare the global map and mutex to store associations between ncclComm and the shared pointers. -std::unordered_map communicatorMap; -std::mutex communicatorMapMutex; - -// Thread-local variable for current communicator -thread_local ncclComm* currentNcclComm = nullptr; - -// Function to set the current communicator -void setCurrentNcclComm(ncclUniqueId commId, ncclComm* comm) { - std::lock_guard lock(communicatorMapMutex); - communicatorMap[commId] = comm; -} - -// Function to get the current communicator -ncclComm* getCurrentNcclComm() { - return currentNcclComm; -} +// Declare the global map to store associations between raw pointer and shared pointer +std::unordered_map> ptrMap; struct channelKey { const void* buff; @@ -97,8 +82,6 @@ struct ncclComm { std::shared_ptr scratchBuff; std::vector remoteScratchRegMemories; - ncclUniqueId commId; // Add commId to ncclComm - uint32_t numScratchBuff; uint32_t buffFlag; }; @@ -434,9 +417,6 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI } } - // Register the communicator - setCurrentNcclComm(commId, commPtr); - *comm = commPtr; return ncclSuccess; } @@ -670,55 +650,27 @@ NCCL_API ncclResult_t ncclGroupEnd() { return ncclSuccess; } -// Custom deleter for CUDA memory -void cudaDeleter(char* ptr) { - if (ptr) { - cudaFree(ptr); - } -} - ncclResult_t ncclMemAlloc(void** ptr, size_t size) { // Allocate memory using mscclpp::allocSharedPhysicalCuda - char* rawPtr = mscclpp::allocSharedPhysicalCuda(size); + auto rawPtr = mscclpp::allocSharedPhysicalCuda(size); if (rawPtr == nullptr) { - return ncclError; - } - - // Obtain current communicator - ncclComm* currentComm = getCurrentNcclComm(); - if (currentComm == nullptr) { - return ncclError; + return ncclInternalError; } - // Create a shared pointer and store it in the global map - std::shared_ptr sharedPtr(rawPtr, cudaDeleter); - { - std::lock_guard lock(communicatorMapMutex); - communicatorMap[currentComm->commId][rawPtr] = sharedPtr; - } + ptrMap[rawPtr.get()] = rawPtr; - // Return the raw pointer - *ptr = rawPtr; + // Return the pointer + *ptr = rawPtr.get(); return ncclSuccess; } ncclResult_t ncclMemFree(void* ptr) { - // Obtain current communicator - ncclComm* currentComm = getCurrentNcclComm(); - if (currentComm == nullptr) { - return ncclError; + auto ptrIt = ptrMap.find(ptr); + if (ptrIt != ptrMap.end()) { + ptrMap.erase(ptrIt); + return ncclSuccess; } - std::lock_guard lock(communicatorMapMutex); - auto commIt = communicatorMap.find(currentComm->commId); - if (commIt != communicatorMap.end()) { - auto ptrIt = commIt->second.find(ptr); - if (ptrIt != commIt->second.end()) { - // Shared pointer is automatically deleted - commIt->second.erase(ptrIt); - return ncclSuccess; - } - } // Pointer not found - return ncclError; + return ncclInternalError; }