Skip to content

Commit

Permalink
WIP
Browse files (browse the repository at this point in the history)
  • Loading branch information
Binyang2014 committed Apr 7, 2024
1 parent d97f312 commit 64106f1
Show file tree
Hide file tree
Showing 6 changed files with 350 additions and 14 deletions.
4 changes: 2 additions & 2 deletions include/mscclpp/packet_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ union alignas(8) LL8Packet {
/// Default constructor with an empty body; it performs no explicit member assignments.
MSCCLPP_DEVICE_INLINE LL8Packet() {}

/// Builds a packet from a payload word and a flag word.
///
/// @param val  Value stored in the `data` member.
/// @param flag Value stored in the `flag` member.
MSCCLPP_DEVICE_INLINE LL8Packet(uint32_t val, uint32_t flag) {
  // The `flag` parameter shadows the member of the same name, so an unqualified
  // `flag = flag` only assigns the parameter to itself and never reaches the member.
  // Qualify both stores with `this->` so they unambiguously target the members.
  this->data = val;
  this->flag = flag;
}

MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) {
Expand Down
6 changes: 5 additions & 1 deletion src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ auto convertToChannelType = [](const std::string& str) {
namespace mscclpp {
using json = nlohmann::json;

/// Creates the plan implementation for the plan file located at `planPath`.
///
/// `isUsingPacket` is explicitly initialized to false so it never holds an indeterminate value;
/// loadExecutionPlan() switches it to true when the plan's "protocol" field is "LL".
ExecutionPlan::Impl::Impl(std::string planPath) : planPath(planPath), isUsingPacket(false) {}

std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, ChannelType channelType) const {
auto pred = [channelType](const ChannelInfo& info) { return info.channelType == channelType; };
Expand Down Expand Up @@ -111,6 +111,10 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
this->name = obj["name"];
std::string protocol = obj["protocol"];
if (protocol == "LL") {
this->isUsingPacket = true;
}
auto gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand Down
19 changes: 10 additions & 9 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(uint32_t inputOffsetByBytes, DeviceHa
uint8_t* dstChannelIndexes, uint32_t* dstOffsets, int nDstChannels,
uint32_t size, uint32_t flag) {
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(dstOffsets[index], inputOffsetByBytes, size,
threadIdx.x, blockDim.x, flag);
smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
dstOffsets[index] * sizeof(PacketType), inputOffsetByBytes, size, threadIdx.x, blockDim.x, flag);
}
}

Expand All @@ -223,14 +223,14 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* output, uint32_t outputOffs
uint32_t* srcOffsets, int nDstChannels, int nSrcs, size_t size,
uint32_t flag) {
size_t nPackets = size * 2 / sizeof(PacketType);
uint32_t srcOffset = inputOffsetByBytes / sizeof(PacketValType<PacketType>);
uint32_t dstOffset = outputOffsetByBytes / sizeof(PacketValType<PacketType>);
const uint32_t srcOffset = inputOffsetByBytes / sizeof(PacketValType<PacketType>);
const uint32_t dstOffset = outputOffsetByBytes / sizeof(PacketValType<PacketType>);
PacketValType<PacketType>* src = (PacketValType<PacketType>*)input + srcOffset;
PacketValType<PacketType>* dst = (PacketValType<PacketType>*)output + dstOffset;
for (int idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
PacketValType<PacketType> data = {};
for (int index = 0; index < nSrcs; ++index) {
PacketType* pkt = (PacketType*)input + srcOffsets[index] / sizeof(PacketType);
PacketType* pkt = (PacketType*)((char*)input + 2 * srcOffsets[index]);
PacketValType<PacketType> val = pkt[idx].read(flag);
data = add_vectors<T>(data, val);
}
Expand All @@ -239,16 +239,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* output, uint32_t outputOffs

PacketType pkt(data, flag);
for (int index = 0; index < nDstChannels; ++index) {
smChannels[dstChannelIndexes[index]].write(dstOffsets[index] / sizeof(PacketValType<PacketType>) + idx, pkt);
size_t offset = (dstOffsets[index] * 2) / sizeof(PacketType);
smChannels[dstChannelIndexes[index]].write(offset + idx, pkt);
}
}
}

template <typename PacketType>
MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset, size_t size,
uint32_t flag) {
PacketType* srcPackets = (PacketType*)src;
PacketValType<PacketType>* result = (PacketValType<PacketType>*)dst;
PacketType* srcPackets = (PacketType*)((char*)src + 2 * srcOffset);
PacketValType<PacketType>* result = (PacketValType<PacketType>*)((char*)dst + dstOffset);
size_t nPackets = size * 2 / sizeof(PacketType);
for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
PacketValType<PacketType> data = srcPackets[idx].read(flag);
Expand Down
1 change: 1 addition & 0 deletions src/include/execution_plan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct ExecutionPlan::Impl {
void setupOperations(const nlohmann::json& gpus);

std::string planPath;
bool isUsingPacket;
// operations for [rank][threadblock] = [operations]
std::unordered_map<int, std::vector<std::vector<Operation>>> operations;
std::unordered_map<int, std::vector<ChannelInfo>> channelInfos;
Expand Down
Loading

0 comments on commit 64106f1

Please sign in to comment.