From 179a6a6452f6544af25384bc1facb6353884e1cc Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 8 Nov 2024 06:20:53 +0000 Subject: [PATCH] Fix --- src/executor/executor.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/executor/executor.cc b/src/executor/executor.cc index c51f3302a..92dcba35e 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -23,13 +23,13 @@ struct ExecutionContextKey { } }; -struct ExecutionDevicePlanKey { +struct DeviceExecutionPlanKey { size_t inputMessageSize; size_t outputMessageSize; size_t contsSrcOffset; size_t constDstOffset; - bool operator==(const ExecutionDevicePlanKey& other) const { + bool operator==(const DeviceExecutionPlanKey& other) const { return inputMessageSize == other.inputMessageSize && outputMessageSize == other.outputMessageSize && contsSrcOffset == other.contsSrcOffset && constDstOffset == other.constDstOffset; } @@ -70,8 +70,8 @@ struct hash { }; template <> -struct hash { - std::size_t operator()(const mscclpp::ExecutionDevicePlanKey& key) const { +struct hash { + std::size_t operator()(const mscclpp::DeviceExecutionPlanKey& key) const { std::size_t seed = 0; hash_combine(seed, key.inputMessageSize); hash_combine(seed, key.outputMessageSize); @@ -102,12 +102,12 @@ struct ExecutionContext { std::vector proxySemaphores; std::vector smChannels; std::vector proxyChannels; - std::unordered_map> deviceExecutionPlans; - std::unordered_map> deviceExecutionPlansBuffers; + std::unordered_map> deviceExecutionPlans; + std::unordered_map> deviceExecutionPlansBuffers; std::shared_ptr scratchBuffer; size_t scratchBufferSize; int nthreadsPerBlock; - ExecutionDevicePlanKey currentDevicePlan; + DeviceExecutionPlanKey currentDevicePlan; }; struct Executor::Impl { @@ -126,7 +126,7 @@ struct Executor::Impl { size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset, size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) { ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name}; - ExecutionDevicePlanKey devicePlanKey = {inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset}; + DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset}; if (this->contexts.find(key) != this->contexts.end()) { if (this->contexts[key].currentDevicePlan == devicePlanKey) { return this->contexts[key]; @@ -330,7 +330,7 @@ struct Executor::Impl { } void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan, - const ExecutionDevicePlanKey& key) { + const DeviceExecutionPlanKey& key) { std::vector deviceExecutionPlans; for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) { DeviceExecutionPlan deviceExecutionPlan = {}; @@ -357,7 +357,7 @@ struct Executor::Impl { void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType, cudaStream_t stream, PacketType packetType) { static uint32_t flag = 0; - ExecutionDevicePlanKey key = context.currentDevicePlan; + DeviceExecutionPlanKey key = context.currentDevicePlan; int nthreadblocks = context.deviceExecutionPlans[key].size(); #if defined(ENABLE_NPKIT) #if defined(__HIP_PLATFORM_AMD__)