Fix light load bug #379

Merged 7 commits on Nov 13, 2024
src/executor/executor.cc (104 changes: 80 additions, 24 deletions)
@@ -22,21 +22,62 @@ struct ExecutionContextKey {
recvBuffSize == other.recvBuffSize && plan == other.plan;
}
};

+ struct DeviceExecutionPlanKey {
+ size_t inputMessageSize;
+ size_t outputMessageSize;
+ size_t constSrcOffset;
+ size_t constDstOffset;
+
+ bool operator==(const DeviceExecutionPlanKey& other) const {
+ return inputMessageSize == other.inputMessageSize && outputMessageSize == other.outputMessageSize &&
+ constSrcOffset == other.constSrcOffset && constDstOffset == other.constDstOffset;
+ }
+ };

} // namespace mscclpp

namespace std {

+ // Refer https://www.boost.org/doc/libs/1_86_0/libs/container_hash/doc/html/hash.html#combine
+ template <typename T>
+ inline void hash_combine(std::size_t& seed, const T& value) {
+ std::hash<T> hasher;
+ seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+ }

template <>
struct hash<std::pair<mscclpp::BufferType, int>> {
std::size_t operator()(const std::pair<mscclpp::BufferType, int>& key) const {
- return std::hash<int>()(key.second) ^ std::hash<int>()(static_cast<int>(key.first));
+ std::size_t seed = 0;
+ hash_combine(seed, static_cast<int>(key.first));
+ hash_combine(seed, key.second);
+ return seed;
}
};

template <>
struct hash<mscclpp::ExecutionContextKey> {
std::size_t operator()(const mscclpp::ExecutionContextKey& key) const {
- return std::hash<void*>()(key.sendBuff) ^ std::hash<void*>()(key.recvBuff) ^ std::hash<size_t>()(key.sendBuffSize) ^
- std::hash<size_t>()(key.recvBuffSize) ^ std::hash<std::string>()(key.plan);
+ size_t seed = 0;
+ hash_combine(seed, key.sendBuff);
+ hash_combine(seed, key.recvBuff);
+ hash_combine(seed, key.sendBuffSize);
+ hash_combine(seed, key.recvBuffSize);
+ hash_combine(seed, key.plan);
+ return seed;
}
};

+ template <>
+ struct hash<mscclpp::DeviceExecutionPlanKey> {
+ std::size_t operator()(const mscclpp::DeviceExecutionPlanKey& key) const {
+ std::size_t seed = 0;
+ hash_combine(seed, key.inputMessageSize);
+ hash_combine(seed, key.outputMessageSize);
+ hash_combine(seed, key.constSrcOffset);
+ hash_combine(seed, key.constDstOffset);
+ return seed;
+ }
+ };
} // namespace std
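Note on the hashing change: the old specializations XOR'd the field hashes together, which collides whenever two fields swap values or hash equally; the diff replaces that with the boost-style hash_combine. Below is a minimal, self-contained sketch of the same pattern, using a hypothetical PlanKey struct rather than the types in this file.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Boost-style combiner, same formula the diff adds to executor.cc.
template <typename T>
inline void hash_combine(std::size_t& seed, const T& value) {
  seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Hypothetical key type, for illustration only.
struct PlanKey {
  std::size_t inputSize;
  std::size_t outputSize;
  bool operator==(const PlanKey& o) const { return inputSize == o.inputSize && outputSize == o.outputSize; }
};

namespace std {
template <>
struct hash<PlanKey> {
  std::size_t operator()(const PlanKey& k) const {
    std::size_t seed = 0;
    hash_combine(seed, k.inputSize);   // order-sensitive, unlike a plain XOR of the fields
    hash_combine(seed, k.outputSize);
    return seed;
  }
};
}  // namespace std

int main() {
  std::unordered_map<PlanKey, std::string> cache;
  cache[{1 << 20, 1 << 10}] = "plan-a";
  cache[{1 << 10, 1 << 20}] = "plan-b";  // swapped fields: a plain XOR hash would collide here
  return cache.count({1 << 20, 1 << 10}) == 1 ? 0 : 1;
}
```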
@@ -61,11 +102,12 @@ struct ExecutionContext {
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::SmChannel> smChannels;
std::vector<mscclpp::SimpleProxyChannel> proxyChannels;
- std::vector<DeviceExecutionPlan> deviceExecutionPlans;
+ std::unordered_map<DeviceExecutionPlanKey, std::vector<DeviceExecutionPlan>> deviceExecutionPlans;
+ std::unordered_map<DeviceExecutionPlanKey, std::shared_ptr<char>> deviceExecutionPlansBuffers;
std::shared_ptr<char> scratchBuffer;
size_t scratchBufferSize;
- std::shared_ptr<char> deviceExecutionPlansBuffer;
int nthreadsPerBlock;
+ DeviceExecutionPlanKey currentDevicePlan;
};

struct Executor::Impl {
@@ -81,23 +123,32 @@ struct Executor::Impl {
~Impl() = default;

ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize,
- size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset,
+ size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset,
size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
+ DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset};
if (this->contexts.find(key) != this->contexts.end()) {
+ auto& devicePlans = this->contexts[key].deviceExecutionPlans;
+ if (this->contexts[key].currentDevicePlan == devicePlanKey) {
+ return this->contexts[key];
+ } else if (devicePlans.find(devicePlanKey) != devicePlans.end()) {
+ this->contexts[key].currentDevicePlan = devicePlanKey;
+ return this->contexts[key];
+ }
plan.impl_->operationsReset();
- plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
- this->setupDeviceExecutionPlan(this->contexts[key], rank, plan);
- this->contexts[key].deviceExecutionPlansBuffer =
- allocExtSharedCuda<char>(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
- memcpyCuda(this->contexts[key].deviceExecutionPlansBuffer.get(),
- (char*)this->contexts[key].deviceExecutionPlans.data(),
- this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
+ plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
+ this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, rank, plan);
+ this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] =
+ allocExtSharedCuda<char>(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
+ memcpyCuda(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(),
+ (char*)devicePlans[devicePlanKey].data(),
+ devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
+ this->contexts[key].currentDevicePlan = devicePlanKey;
return this->contexts[key];
}

plan.impl_->reset();
- plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
+ plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);

ExecutionContext context;
size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize);
@@ -109,11 +160,14 @@ struct Executor::Impl {
this->setupConnections(context, rank, plan);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
- this->setupDeviceExecutionPlan(context, rank, plan);
- context.deviceExecutionPlansBuffer =
- allocExtSharedCuda<char>(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
- memcpyCuda(context.deviceExecutionPlansBuffer.get(), (char*)context.deviceExecutionPlans.data(),
- context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice);
+ this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
+ context.deviceExecutionPlansBuffers[devicePlanKey] =
+ allocExtSharedCuda<char>(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
+ memcpyCuda(context.deviceExecutionPlansBuffers[devicePlanKey].get(),
+ (char*)context.deviceExecutionPlans[devicePlanKey].data(),
+ context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan),
+ cudaMemcpyHostToDevice);
+ context.currentDevicePlan = devicePlanKey;
context.proxyService->startProxy();
this->contexts.insert({key, context});
return context;
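This setup path is the heart of the light-load fix: the context is still cached by buffers and plan name, but device execution plans and their device-side buffers are now kept per DeviceExecutionPlanKey, and currentDevicePlan records which one the next launch should use. Below is a simplified sketch of that two-level caching, written with hypothetical stand-in types rather than the real ExecutionContext machinery.

```cpp
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for the real ExecutionContextKey / DeviceExecutionPlanKey /
// DeviceExecutionPlan machinery in executor.cc.
struct PlanKey {
  std::size_t inSize, outSize;
  bool operator==(const PlanKey& o) const { return inSize == o.inSize && outSize == o.outSize; }
};
struct PlanKeyHash {
  std::size_t operator()(const PlanKey& k) const { return k.inSize * 1000003u + k.outSize; }
};
struct DevicePlan { /* per-threadblock payload */ };

struct Context {
  // One host-side plan vector and one device-side buffer per message-size/offset key.
  std::unordered_map<PlanKey, std::vector<DevicePlan>, PlanKeyHash> plans;
  std::unordered_map<PlanKey, std::shared_ptr<char>, PlanKeyHash> planBuffers;
  PlanKey current{};
};

// Reuse a cached plan when the key was seen before; otherwise build and cache a new one.
const std::vector<DevicePlan>& setupPlan(Context& ctx, const PlanKey& key) {
  auto it = ctx.plans.find(key);
  if (it == ctx.plans.end()) {
    it = ctx.plans.emplace(key, std::vector<DevicePlan>(8)).first;  // placeholder "build" step
    ctx.planBuffers[key] = std::shared_ptr<char>(new char[it->second.size() * sizeof(DevicePlan)],
                                                 std::default_delete<char[]>());
    // the real code copies the freshly built plans into the device buffer here
  }
  ctx.current = key;  // the launch path later reads ctx.current to pick the matching buffer
  return it->second;
}
```

On a repeat call with a key that was already cached, the real setupExecutionContext only flips currentDevicePlan; the light plan reload and host-to-device copy run only for keys it has never seen.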
@@ -278,7 +332,8 @@ struct Executor::Impl {
}
}

- void setupDeviceExecutionPlan(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
+ void setupDeviceExecutionPlan(ExecutionContext& context, const DeviceExecutionPlanKey& key, int rank,
+ const ExecutionPlan& plan) {
std::vector<DeviceExecutionPlan> deviceExecutionPlans;
for (int threadblock = 0; threadblock < plan.impl_->getThreadblockCount(rank); threadblock++) {
DeviceExecutionPlan deviceExecutionPlan = {};
@@ -299,13 +354,14 @@
}
deviceExecutionPlans.push_back(deviceExecutionPlan);
}
- context.deviceExecutionPlans = std::move(deviceExecutionPlans);
+ context.deviceExecutionPlans[key] = std::move(deviceExecutionPlans);
}

void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
cudaStream_t stream, PacketType packetType) {
static uint32_t flag = 0;
- int nthreadblocks = context.deviceExecutionPlans.size();
+ DeviceExecutionPlanKey key = context.currentDevicePlan;
+ int nthreadblocks = context.deviceExecutionPlans[key].size();
#if defined(ENABLE_NPKIT)
#if defined(__HIP_PLATFORM_AMD__)
if (nthreadblocks > NPKIT_MAX_NUM_GPU_THREADBLOCKS) {
@@ -323,13 +379,13 @@
case PacketType::LL16:
ExecutionKernel::launchKernel<LL16Packet>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
- context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
+ context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
sharedMemSize, stream, ++flag);
break;
case PacketType::LL8:
ExecutionKernel::launchKernel<LL8Packet>(
rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
- context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffer.get(),
+ context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
sharedMemSize, stream, ++flag);
break;
default:
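launchKernel now resolves the plan buffer through context.currentDevicePlan instead of a single deviceExecutionPlansBuffer, so it always launches with whatever plan the most recent setup call selected. A trimmed-down illustration of that selection step follows; the names and types are illustrative stand-ins (reusing the hypothetical PlanKey from the earlier sketch), not the actual launch signature.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>

// Illustrative stand-ins for ExecutionContext's per-key plan bookkeeping.
struct PlanKey {
  std::size_t inSize, outSize;
  bool operator==(const PlanKey& o) const { return inSize == o.inSize && outSize == o.outSize; }
};
struct PlanKeyHash {
  std::size_t operator()(const PlanKey& k) const { return k.inSize * 1000003u + k.outSize; }
};

struct Context {
  std::unordered_map<PlanKey, std::shared_ptr<char>, PlanKeyHash> planBuffers;
  std::unordered_map<PlanKey, std::size_t, PlanKeyHash> threadblockCounts;
  PlanKey current{};
};

// Pick the plan buffer for whichever key the most recent setup call selected.
void launch(const Context& ctx) {
  const PlanKey key = ctx.current;
  assert(ctx.planBuffers.count(key) && "setup must run before launch");
  const char* planBuffer = ctx.planBuffers.at(key).get();     // handed to the kernel in the real code
  std::size_t nthreadblocks = ctx.threadblockCounts.at(key);  // one DeviceExecutionPlan per threadblock
  (void)planBuffer;
  (void)nthreadblocks;
}
```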