Skip to content

Commit

Permalink
Directly store the shared pointer to a map without the communicator and commId; Remove the cudaDeleter
Browse files Browse the repository at this point in the history
  • Loading branch information
seagater committed Dec 9, 2024
1 parent 4298abe commit 129d31b
Showing 1 changed file with 12 additions and 60 deletions.
72 changes: 12 additions & 60 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,8 @@
// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
// mscclpp::Transport::IB6, mscclpp::Transport::IB7};

// Global registry mapping a communicator's unique id to its ncclComm instance
// (stored as a raw, non-owning pointer).
std::unordered_map<ncclUniqueId, ncclComm*> communicatorMap;
// Mutex taken by code that reads or writes communicatorMap.
std::mutex communicatorMapMutex;

// Per-thread pointer to the communicator considered "current"; starts out null
// and is the value handed back by getCurrentNcclComm().
thread_local ncclComm* currentNcclComm = nullptr;

// Registers `comm` under `commId` in the global communicator map and makes it
// the calling thread's current communicator.
//
// @param commId  Unique id used as the key in communicatorMap.
// @param comm    Communicator to register; kept as a non-owning raw pointer.
void setCurrentNcclComm(ncclUniqueId commId, ncclComm* comm) {
  {
    // Hold the lock only while the shared map is being updated.
    std::lock_guard<std::mutex> lock(communicatorMapMutex);
    communicatorMap[commId] = comm;
  }
  // The thread-local needs no lock. Without this assignment
  // getCurrentNcclComm() would keep returning nullptr on every thread.
  currentNcclComm = comm;
}

// Returns the calling thread's value of the thread-local currentNcclComm
// (nullptr while nothing has been assigned to it on this thread).
ncclComm* getCurrentNcclComm() {
  ncclComm* const active = currentNcclComm;
  return active;
}
// Global map from a raw pointer handed out by ncclMemAlloc to the shared_ptr
// that owns the allocation. ncclMemAlloc inserts an entry; ncclMemFree erases
// it, which drops this map's reference to the memory.
// NOTE(review): the map is read and written without a lock in this file —
// confirm that ncclMemAlloc/ncclMemFree are never called concurrently.
std::unordered_map<void*, std::shared_ptr<char>> ptrMap;

struct channelKey {
const void* buff;
Expand Down Expand Up @@ -97,8 +82,6 @@ struct ncclComm {
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;

ncclUniqueId commId; // Add commId to ncclComm

uint32_t numScratchBuff;
uint32_t buffFlag;
};
Expand Down Expand Up @@ -434,9 +417,6 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
}
}

// Register the communicator
setCurrentNcclComm(commId, commPtr);

*comm = commPtr;
return ncclSuccess;
}
Expand Down Expand Up @@ -670,55 +650,27 @@ NCCL_API ncclResult_t ncclGroupEnd() {
return ncclSuccess;
}

// Deleter suitable for a std::shared_ptr<char>: passes a non-null pointer to
// cudaFree and does nothing for a null one.
// NOTE(review): assumes the pointer came from an allocator whose memory
// cudaFree can release — confirm against the allocation path.
void cudaDeleter(char* ptr) {
  if (ptr == nullptr) {
    return;
  }
  cudaFree(ptr);
}

// Allocates `size` bytes through mscclpp::allocSharedPhysicalCuda and returns
// the raw pointer to the caller via `*ptr`.
//
// The owning smart pointer returned by the allocator is parked in the global
// ptrMap, keyed by the raw pointer, so the memory stays alive until the
// matching ncclMemFree erases the entry.
//
// @param ptr   Out-parameter receiving the raw pointer on success.
// @param size  Number of bytes requested.
// @return ncclSuccess on success; ncclInternalError if the allocator returned
//         a null pointer.
// NOTE(review): ptrMap is updated without a lock — confirm callers do not
// allocate/free from multiple threads concurrently.
ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
  // Allocate memory using mscclpp::allocSharedPhysicalCuda
  auto rawPtr = mscclpp::allocSharedPhysicalCuda<char>(size);
  if (rawPtr == nullptr) {
    return ncclInternalError;
  }

  // Keep the owner alive in the global map; the caller only sees the raw pointer.
  char* const devPtr = rawPtr.get();
  ptrMap[devPtr] = rawPtr;

  // Return the pointer
  *ptr = devPtr;
  return ncclSuccess;
}

// Releases memory previously handed out by ncclMemAlloc.
//
// Looks `ptr` up in the global ptrMap and erases the entry; erasing drops the
// map's shared_ptr reference to the allocation.
//
// @param ptr  Raw pointer previously returned through ncclMemAlloc's out-parameter.
// @return ncclSuccess if the pointer was found and its entry erased;
//         ncclInternalError if the pointer is not tracked in ptrMap.
// NOTE(review): ptrMap is accessed without a lock — confirm callers do not
// allocate/free from multiple threads concurrently.
ncclResult_t ncclMemFree(void* ptr) {
  auto ptrIt = ptrMap.find(ptr);
  if (ptrIt == ptrMap.end()) {
    // Pointer not found
    return ncclInternalError;
  }

  // Erasing the entry destroys the stored shared_ptr.
  ptrMap.erase(ptrIt);
  return ncclSuccess;
}

0 comments on commit 129d31b

Please sign in to comment.