diff --git a/apps/nccl/src/broadcast.hpp b/apps/nccl/src/broadcast.hpp index d0c10a87a..2f9695ef9 100644 --- a/apps/nccl/src/broadcast.hpp +++ b/apps/nccl/src/broadcast.hpp @@ -19,12 +19,17 @@ __global__ void __launch_bounds__(1024, 1) const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; const size_t lid = tid % WARP_SIZE; const size_t wid = tid / WARP_SIZE; - + const size_t nPeer = nRanksPerNode - 1; + const size_t nBlocksPerPeer = gridDim.x / nPeer; const size_t nThread = blockDim.x * gridDim.x; const size_t nWarp = nThread / WARP_SIZE; - const size_t nPeer = nRanksPerNode - 1; const size_t chanOffset = nPeer * blockIdx.x; - auto smChans = smChannels + chanOffset; + //auto smChans = smChannels + chanOffset; + + __shared__ mscclpp::DeviceHandle smChans[NRANKS_PER_NODE - 1]; + if(threadIdx.x < nPeer) + smChans[threadIdx.x] = smChannels[chanOffset+threadIdx.x]; + __syncthreads(); if (threadIdx.x < nPeer) { smChans[threadIdx.x].relaxedSignal();