diff --git a/src/executor/executor.cc b/src/executor/executor.cc
index 9de8a58be..929fcd007 100644
--- a/src/executor/executor.cc
+++ b/src/executor/executor.cc
@@ -22,21 +22,62 @@ struct ExecutionContextKey {
            recvBuffSize == other.recvBuffSize && plan == other.plan;
   }
 };
+
+struct ExecutionDevicePlanKey {
+  size_t inputMessageSize;
+  size_t outputMessageSize;
+  size_t contsSrcOffset;
+  size_t constDstOffset;
+
+  bool operator==(const ExecutionDevicePlanKey& other) const {
+    return inputMessageSize == other.inputMessageSize && outputMessageSize == other.outputMessageSize &&
+           contsSrcOffset == other.contsSrcOffset && constDstOffset == other.constDstOffset;
+  }
+};
+
 }  // namespace mscclpp
 
 namespace std {
+
+// Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine
+template <typename T>
+inline void hash_combine(std::size_t& seed, const T& value) {
+  std::hash<T> hasher;
+  seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
 template <>
 struct hash<std::pair<mscclpp::BufferType, int>> {
   std::size_t operator()(const std::pair<mscclpp::BufferType, int>& key) const {
-    return std::hash<int>()(key.second) ^ std::hash<int>()(static_cast<int>(key.first));
+    std::size_t seed = 0;
+    hash_combine(seed, static_cast<int>(key.first));
+    hash_combine(seed, key.second);
+    return seed;
   }
 };
 
 template <>
 struct hash<mscclpp::ExecutionContextKey> {
   std::size_t operator()(const mscclpp::ExecutionContextKey& key) const {
-    return std::hash<void*>()(key.sendBuff) ^ std::hash<void*>()(key.recvBuff) ^ std::hash<size_t>()(key.sendBuffSize) ^
-           std::hash<size_t>()(key.recvBuffSize) ^ std::hash<std::string>()(key.plan);
+    size_t seed = 0;
+    hash_combine(seed, key.sendBuff);
+    hash_combine(seed, key.recvBuff);
+    hash_combine(seed, key.sendBuffSize);
+    hash_combine(seed, key.recvBuffSize);
+    hash_combine(seed, key.plan);
+    return seed;
+  }
+};
+
+template <>
+struct hash<mscclpp::ExecutionDevicePlanKey> {
+  std::size_t operator()(const mscclpp::ExecutionDevicePlanKey& key) const {
+    std::size_t seed = 0;
+    hash_combine(seed, key.inputMessageSize);
+    hash_combine(seed, key.outputMessageSize);
+    hash_combine(seed, key.contsSrcOffset);
+    hash_combine(seed, key.constDstOffset);
+    return seed;
   }
 };
 }  // namespace std
@@ -61,11 +102,12 @@ struct ExecutionContext {
   std::vector<mscclpp::SemaphoreId> proxySemaphores;
   std::vector<mscclpp::SmChannel> smChannels;
   std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
-  std::vector<DeviceExecutionPlan> deviceExecutionPlans;
+  std::unordered_map<ExecutionDevicePlanKey, std::vector<DeviceExecutionPlan>> deviceExecutionPlans;
+  std::unordered_map<ExecutionDevicePlanKey, std::shared_ptr<char>> deviceExecutionPlansBuffers;
   std::shared_ptr<char> scratchBuffer;
   size_t scratchBufferSize;
-  std::shared_ptr<char> deviceExecutionPlansBuffer;
   int nthreadsPerBlock;
+  ExecutionDevicePlanKey currentDevicePlan;
 };
 
 struct Executor::Impl {
@@ -84,15 +126,21 @@ struct Executor::Impl {
                                          size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset,
                                          size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
     ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
+    ExecutionDevicePlanKey devicePlanKey = {inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset};
     if (this->contexts.find(key) != this->contexts.end()) {
+      if (this->contexts[key].currentDevicePlan == devicePlanKey) {
+        return this->contexts[key];
+      }
       plan.impl_->operationsReset();
       plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
-      this->setupDeviceExecutionPlan(this->contexts[key], rank, plan);
-      this->contexts[key].deviceExecutionPlansBuffer =
-          allocExtSharedCuda<char>(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
-      memcpyCuda(this->contexts[key].deviceExecutionPlansBuffer.get(),
-                 (char*)this->contexts[key].deviceExecutionPlans.data(),
-                 this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
+      this->setupDeviceExecutionPlan(this->contexts[key], rank, plan, devicePlanKey);
+      this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] =
+          allocExtSharedCuda<char>(this->contexts[key].deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
+      memcpyCuda(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(),
+                 (char*)this->contexts[key].deviceExecutionPlans[devicePlanKey].data(),
+                 this->contexts[key].deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan),
+                 cudaMemcpyHostToDevice);
+      this->contexts[key].currentDevicePlan = devicePlanKey;
       return this->contexts[key];
     }
 
@@ -109,11 +157,14 @@ struct Executor::Impl {
     this->setupConnections(context, rank, plan);
     this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
     this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
-    this->setupDeviceExecutionPlan(context, rank, plan);
-    context.deviceExecutionPlansBuffer =
-        allocExtSharedCuda<char>(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
-    memcpyCuda(context.deviceExecutionPlansBuffer.get(), (char*)context.deviceExecutionPlans.data(),
-               context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
+    this->setupDeviceExecutionPlan(context, rank, plan, devicePlanKey);
+    context.deviceExecutionPlansBuffers[devicePlanKey] =
+        allocExtSharedCuda<char>(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
+    memcpyCuda(context.deviceExecutionPlansBuffers[devicePlanKey].get(),
+               (char*)context.deviceExecutionPlans[devicePlanKey].data(),
+               context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan),
+               cudaMemcpyHostToDevice);
+    context.currentDevicePlan = devicePlanKey;
     context.proxyService->startProxy();
     this->contexts.insert({key, context});
     return context;
@@ -278,7 +329,8 @@ struct Executor::Impl {
     }
   }
 
-  void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
+  void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan,
+                                const ExecutionDevicePlanKey& key) {
     std::vector<DeviceExecutionPlan> deviceExecutionPlans;
     for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) {
       DeviceExecutionPlan deviceExecutionPlan = {};
@@ -299,12 +351,13 @@ struct Executor::Impl {
       }
       deviceExecutionPlans.push_back(deviceExecutionPlan);
     }
-    context.deviceExecutionPlans = std::move(deviceExecutionPlans);
+    context.deviceExecutionPlans[key] = std::move(deviceExecutionPlans);
   }
 
   void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
                     cudaStream_t stream, PacketType packetType) {
     static uint32_t flag = 0;
-    int nthreadblocks = context.deviceExecutionPlans.size();
+    ExecutionDevicePlanKey key = context.currentDevicePlan;
+    int nthreadblocks = context.deviceExecutionPlans[key].size();
 #if defined(ENABLE_NPKIT)
 #if defined(__HIP_PLATFORM_AMD__)
@@ -323,13 +376,13 @@ struct Executor::Impl {
       case PacketType::LL16:
         ExecutionKernel::launchKernel(
             rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
-            context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
+            context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
             sharedMemSize, stream, ++flag);
         break;
       case PacketType::LL8:
         ExecutionKernel::launchKernel(
             rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
-            context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
+            context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
             sharedMemSize, stream, ++flag);
         break;
       default:
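
For reference, the caching scheme above keys each DeviceExecutionPlan array and its device-side buffer by the message sizes and offsets, so a repeated call with an already-seen size reuses the existing ExecutionContext entry instead of rebuilding it; the std::hash specializations follow the Boost hash_combine recipe linked in the patch. Below is a minimal standalone sketch of the same custom-key hashing and caching pattern, with hypothetical names (PlanKey, hashCombine, planCache) that are not part of mscclpp:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical key: two sizes identify a cached plan, like the message sizes above.
struct PlanKey {
  std::size_t inputSize;
  std::size_t outputSize;

  bool operator==(const PlanKey& other) const {
    return inputSize == other.inputSize && outputSize == other.outputSize;
  }
};

// Boost-style combine: fold each field into a running seed.
template <typename T>
inline void hashCombine(std::size_t& seed, const T& value) {
  std::hash<T> hasher;
  seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

namespace std {
template <>
struct hash<PlanKey> {
  std::size_t operator()(const PlanKey& key) const {
    std::size_t seed = 0;
    hashCombine(seed, key.inputSize);
    hashCombine(seed, key.outputSize);
    return seed;
  }
};
}  // namespace std

int main() {
  // Cache keyed by the sizes, analogous to deviceExecutionPlansBuffers in the diff.
  std::unordered_map<PlanKey, std::string> planCache;
  planCache[{1 << 20, 1 << 20}] = "plan for 1 MiB";
  planCache[{1 << 22, 1 << 22}] = "plan for 4 MiB";

  PlanKey lookup{1 << 20, 1 << 20};
  auto it = planCache.find(lookup);
  if (it != planCache.end()) {
    std::cout << it->second << std::endl;  // cache hit: the stored entry is reused
  }
  return 0;
}

Folding each field into a running seed this way also avoids the weakness of the XOR-only hashes the patch removes, where equal fields cancel out and operand order is ignored.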