Skip to content

Commit

Permalink
Fix: rename ExecutionDevicePlanKey to DeviceExecutionPlanKey
Browse files: browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 8, 2024
1 parent 9ba63c0 commit 179a6a6
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ struct ExecutionContextKey {
}
};

struct ExecutionDevicePlanKey {
struct DeviceExecutionPlanKey {
size_t inputMessageSize;
size_t outputMessageSize;
size_t contsSrcOffset;
size_t constDstOffset;

bool operator==(const ExecutionDevicePlanKey& other) const {
bool operator==(const DeviceExecutionPlanKey& other) const {
return inputMessageSize == other.inputMessageSize && outputMessageSize == other.outputMessageSize &&
contsSrcOffset == other.contsSrcOffset && constDstOffset == other.constDstOffset;
}
Expand Down Expand Up @@ -70,8 +70,8 @@ struct hash<mscclpp::ExecutionContextKey> {
};

template <>
struct hash<mscclpp::ExecutionDevicePlanKey> {
std::size_t operator()(const mscclpp::ExecutionDevicePlanKey& key) const {
struct hash<mscclpp::DeviceExecutionPlanKey> {
std::size_t operator()(const mscclpp::DeviceExecutionPlanKey& key) const {
std::size_t seed = 0;
hash_combine(seed, key.inputMessageSize);
hash_combine(seed, key.outputMessageSize);
Expand Down Expand Up @@ -102,12 +102,12 @@ struct ExecutionContext {
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::SmChannel> smChannels;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
std::unordered_map<ExecutionDevicePlanKey, std::vector<DeviceExecutionPlan>> deviceExecutionPlans;
std::unordered_map<ExecutionDevicePlanKey, std::shared_ptr<char>> deviceExecutionPlansBuffers;
std::unordered_map<DeviceExecutionPlanKey, std::vector<DeviceExecutionPlan>> deviceExecutionPlans;
std::unordered_map<DeviceExecutionPlanKey, std::shared_ptr<char>> deviceExecutionPlansBuffers;
std::shared_ptr<char> scratchBuffer;
size_t scratchBufferSize;
int nthreadsPerBlock;
ExecutionDevicePlanKey currentDevicePlan;
DeviceExecutionPlanKey currentDevicePlan;
};

struct Executor::Impl {
Expand All @@ -126,7 +126,7 @@ struct Executor::Impl {
size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset,
size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
ExecutionDevicePlanKey devicePlanKey = {inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset};
DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset};
if (this->contexts.find(key) != this->contexts.end()) {
if (this->contexts[key].currentDevicePlan == devicePlanKey) {
return this->contexts[key];
Expand Down Expand Up @@ -330,7 +330,7 @@ struct Executor::Impl {
}

void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan,
const ExecutionDevicePlanKey& key) {
const DeviceExecutionPlanKey& key) {
std::vector<DeviceExecutionPlan> deviceExecutionPlans;
for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) {
DeviceExecutionPlan deviceExecutionPlan = {};
Expand All @@ -357,7 +357,7 @@ struct Executor::Impl {
void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
cudaStream_t stream, PacketType packetType) {
static uint32_t flag = 0;
ExecutionDevicePlanKey key = context.currentDevicePlan;
DeviceExecutionPlanKey key = context.currentDevicePlan;
int nthreadblocks = context.deviceExecutionPlans[key].size();
#if defined(ENABLE_NPKIT)
#if defined(__HIP_PLATFORM_AMD__)
Expand Down

0 comments on commit 179a6a6

Please sign in to comment.