Skip to content

Commit

Permalink
hipGraph work
Browse files (browse the repository at this point in the history)
  • Loading branch information
Binyang2014 committed Feb 22, 2024
1 parent c50f7b5 commit cf91552
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
10 changes: 7 additions & 3 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ struct ncclComm {
std::unordered_map<channelKey, std::vector<mscclpp::SmChannel>> smOutChannels;
std::shared_ptr<char> scratchBuff;
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
+ std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smDeviceHandles;
+ std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smOutDeviceHandles;
};

cudaError_t allreduce(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
Expand Down Expand Up @@ -739,12 +741,14 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>(sendbuff));
it = comm->smChannels.emplace(key, channels).first;
}
- std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+ std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>>& smChannelDeviceHandles = comm->smDeviceHandles;
+ smChannelDeviceHandles.clear();
  std::transform(it->second.begin(), it->second.end(), std::back_inserter(smChannelDeviceHandles),
  [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
  // TODO: if sendbuff and recvbuff don't change, we can avoid copying smChannelDeviceHandles to device
- CUDACHECK(cudaMemcpyToSymbol(constSmChannels, smChannelDeviceHandles.data(),
- sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size()));
+ CUDACHECK(cudaMemcpyToSymbolAsync(constSmChannels, smChannelDeviceHandles.data(),
+ sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size(),
+ 0, cudaMemcpyHostToDevice, stream));
} else {
auto it = comm->smChannels.find(key);
auto outIt = comm->smOutChannels.find(key);
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri
#define cudaMemcpy(...) hipMemcpy(__VA_ARGS__)
#define cudaMemcpyAsync(...) hipMemcpyAsync(__VA_ARGS__)
#define cudaMemcpyToSymbol(...) hipMemcpyToSymbol(__VA_ARGS__)
+ #define cudaMemcpyToSymbolAsync(...) hipMemcpyToSymbolAsync(__VA_ARGS__)
#define cudaStreamCreate(...) hipStreamCreate(__VA_ARGS__)
#define cudaStreamCreateWithFlags(...) hipStreamCreateWithFlags(__VA_ARGS__)
#define cudaStreamSynchronize(...) hipStreamSynchronize(__VA_ARGS__)
Expand Down

0 comments on commit cf91552

Please sign in to comment.