// Patch summary (src/executor/executor.cc):
//  * Adds DeviceExecutionPlanKey {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset}
//    with operator== and a std::hash specialization.
//  * Adds a boost-style hash_combine helper and switches the existing std::hash specializations
//    (the std::pair one and the ExecutionContextKey one) from XOR-ing member hashes to hash_combine.
//  * ExecutionContext: deviceExecutionPlans and the new deviceExecutionPlansBuffers become
//    unordered_maps indexed by DeviceExecutionPlanKey; the single deviceExecutionPlansBuffer is
//    removed; currentDevicePlan records which key is active.
//  * setupExecutionContext: for an already-known ExecutionContextKey it returns early when the
//    requested DeviceExecutionPlanKey equals currentDevicePlan or is already present in
//    deviceExecutionPlans (updating currentDevicePlan in the latter case); otherwise it reloads the
//    plan, builds the device plans for that key, allocates and fills the device buffer for that key,
//    and sets currentDevicePlan. Also renames the misspelled parameter contsSrcOffset to
//    constSrcOffset.
//  * setupDeviceExecutionPlan takes the key and stores the built vector under
//    context.deviceExecutionPlans[key].
//  * launchKernel selects the plan vector and device buffer via context.currentDevicePlan.
//
// NOTE(review): template argument lists appear to have been stripped from this view of the patch
//   (e.g. "template inline void hash_combine", "struct hash> {"); consult the real file for types.
// NOTE(review): hash_combine is declared as a new function template inside `namespace std`. The
//   standard only permits adding specializations there, so this looks like it should live in
//   mscclpp or an unnamed namespace -- confirm and move.
// NOTE(review): launchKernel indexes deviceExecutionPlans / deviceExecutionPlansBuffers with
//   operator[], which inserts a default entry when the key is absent. Presumably
//   currentDevicePlan is always populated by setupExecutionContext first -- confirm, or use find/at.
// NOTE(review): setupExecutionContext is declared to return ExecutionContext by value, so each
//   `return this->contexts[key];` copies the context, which now includes the per-key maps --
//   check whether that copy cost matters on the cached path.
diff --git a/src/executor/executor.cc b/src/executor/executor.cc index 9de8a58be..1fcb61865 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -22,21 +22,62 @@ struct ExecutionContextKey { recvBuffSize == other.recvBuffSize && plan == other.plan; } }; + +struct DeviceExecutionPlanKey { + size_t inputMessageSize; + size_t outputMessageSize; + size_t constSrcOffset; + size_t constDstOffset; + + bool operator==(const DeviceExecutionPlanKey& other) const { + return inputMessageSize == other.inputMessageSize && outputMessageSize == other.outputMessageSize && + constSrcOffset == other.constSrcOffset && constDstOffset == other.constDstOffset; + } +}; + } // namespace mscclpp namespace std { + +// Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine +template +inline void hash_combine(std::size_t& seed, const T& value) { + std::hash hasher; + seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + template <> struct hash> { std::size_t operator()(const std::pair& key) const { - return std::hash()(key.second) ^ std::hash()(static_cast(key.first)); + std::size_t seed = 0; + hash_combine(seed, static_cast(key.first)); + hash_combine(seed, key.second); + return seed; } }; template <> struct hash { std::size_t operator()(const mscclpp::ExecutionContextKey& key) const { - return std::hash()(key.sendBuff) ^ std::hash()(key.recvBuff) ^ std::hash()(key.sendBuffSize) ^ - std::hash()(key.recvBuffSize) ^ std::hash()(key.plan); + size_t seed = 0; + hash_combine(seed, key.sendBuff); + hash_combine(seed, key.recvBuff); + hash_combine(seed, key.sendBuffSize); + hash_combine(seed, key.recvBuffSize); + hash_combine(seed, key.plan); + return seed; + } +}; + +template <> +struct hash { + std::size_t operator()(const mscclpp::DeviceExecutionPlanKey& key) const { + std::size_t seed = 0; + hash_combine(seed, key.inputMessageSize); + hash_combine(seed, key.outputMessageSize); + hash_combine(seed, key.constSrcOffset); + 
hash_combine(seed, key.constDstOffset); + return seed; } }; } // namespace std @@ -61,11 +102,12 @@ struct ExecutionContext { std::vector proxySemaphores; std::vector smChannels; std::vector proxyChannels; - std::vector deviceExecutionPlans; + std::unordered_map> deviceExecutionPlans; + std::unordered_map> deviceExecutionPlansBuffers; std::shared_ptr scratchBuffer; size_t scratchBufferSize; - std::shared_ptr deviceExecutionPlansBuffer; int nthreadsPerBlock; + DeviceExecutionPlanKey currentDevicePlan; }; struct Executor::Impl { @@ -81,23 +123,32 @@ struct Executor::Impl { ~Impl() = default; ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize, - size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset, + size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset, size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) { ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name}; + DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset}; if (this->contexts.find(key) != this->contexts.end()) { + auto& devicePlans = this->contexts[key].deviceExecutionPlans; + if (this->contexts[key].currentDevicePlan == devicePlanKey) { + return this->contexts[key]; + } else if (devicePlans.find(devicePlanKey) != devicePlans.end()) { + this->contexts[key].currentDevicePlan = devicePlanKey; + return this->contexts[key]; + } plan.impl_->operationsReset(); - plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset); - this->setupDeviceExecutionPlan(this->contexts[key], rank, plan); - this->contexts[key].deviceExecutionPlansBuffer = - allocExtSharedCuda(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); - memcpyCuda(this->contexts[key].deviceExecutionPlansBuffer.get(), - (char*)this->contexts[key].deviceExecutionPlans.data(), - 
this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset); + this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, rank, plan); + this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] = + allocExtSharedCuda(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); + memcpyCuda(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(), + (char*)devicePlans[devicePlanKey].data(), + devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + this->contexts[key].currentDevicePlan = devicePlanKey; return this->contexts[key]; } plan.impl_->reset(); - plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset); + plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset); ExecutionContext context; size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize); @@ -109,11 +160,14 @@ struct Executor::Impl { this->setupConnections(context, rank, plan); this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); - this->setupDeviceExecutionPlan(context, rank, plan); - context.deviceExecutionPlansBuffer = - allocExtSharedCuda(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); - memcpyCuda(context.deviceExecutionPlansBuffer.get(), (char*)context.deviceExecutionPlans.data(), - context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan); + context.deviceExecutionPlansBuffers[devicePlanKey] = + allocExtSharedCuda(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); + 
memcpyCuda(context.deviceExecutionPlansBuffers[devicePlanKey].get(), + (char*)context.deviceExecutionPlans[devicePlanKey].data(), + context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), + cudaMemcpyHostToDevice); + context.currentDevicePlan = devicePlanKey; context.proxyService->startProxy(); this->contexts.insert({key, context}); return context; @@ -278,7 +332,8 @@ struct Executor::Impl { } } - void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan) { + void setupDeviceExecutionPlan(ExecutionContext& context, const DeviceExecutionPlanKey& key, int rank, + const ExecutionPlan& plan) { std::vector deviceExecutionPlans; for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) { DeviceExecutionPlan deviceExecutionPlan = {}; @@ -299,13 +354,14 @@ struct Executor::Impl { } deviceExecutionPlans.push_back(deviceExecutionPlan); } - context.deviceExecutionPlans = std::move(deviceExecutionPlans); + context.deviceExecutionPlans[key] = std::move(deviceExecutionPlans); } void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType, cudaStream_t stream, PacketType packetType) { static uint32_t flag = 0; - int nthreadblocks = context.deviceExecutionPlans.size(); + DeviceExecutionPlanKey key = context.currentDevicePlan; + int nthreadblocks = context.deviceExecutionPlans[key].size(); #if defined(ENABLE_NPKIT) #if defined(__HIP_PLATFORM_AMD__) if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) { @@ -323,13 +379,13 @@ struct Executor::Impl { case PacketType::LL16: ExecutionKernel::launchKernel( rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), - context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(), sharedMemSize, stream, ++flag); 
break; case PacketType::LL8: ExecutionKernel::launchKernel( rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(), - context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(), + context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(), sharedMemSize, stream, ++flag); break; default: