Skip to content

Commit

Permalink
Fix for ROCm 6.0 (#347)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Sep 2, 2024
1 parent 4eca6f1 commit 72b99a4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ target_include_directories(mscclpp_obj
SYSTEM PRIVATE
${GPU_INCLUDE_DIRS}
${NUMA_INCLUDE_DIRS})
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads)
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads dl)
if(IBVERBS_FOUND)
target_include_directories(mscclpp_obj SYSTEM PRIVATE ${IBVERBS_INCLUDE_DIRS})
target_link_libraries(mscclpp_obj PRIVATE ${IBVERBS_LIBRARIES})
Expand Down
14 changes: 12 additions & 2 deletions src/include/execution_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ MSCCLPP_DEVICE_INLINE __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}

// Explicit specialization of add_elements for __bfloat16 operands.
// Takes two __bfloat16 values and returns whatever __hadd(a, b) yields.
// NOTE(review): __hadd is presumably the toolkit's scalar bfloat16 addition
// intrinsic (CUDA / HIP bf16 headers) — confirm for the targeted toolchains.
template <>
MSCCLPP_DEVICE_INLINE __bfloat16 add_elements(__bfloat16 a, __bfloat16 b) {
return __hadd(a, b);
}

// Explicit specialization of add_elements for __bfloat162 operands.
// Takes two __bfloat162 values and returns whatever __hadd2(a, b) yields.
// NOTE(review): __bfloat162 presumably packs two bfloat16 lanes and __hadd2
// presumably adds them lane-wise — confirm against the CUDA / HIP bf16 headers.
template <>
MSCCLPP_DEVICE_INLINE __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
return __hadd2(a, b);
}

template <typename T>
MSCCLPP_DEVICE_INLINE int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
Expand Down Expand Up @@ -239,7 +249,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
T tmp = input[idx];
for (int index = 0; index < nSrcChannels; ++index) {
size_t srcOffset = srcOffsets[index] / sizeof(T);
tmp += smChannels[srcChannelIndexes[index]].read<T>(srcOffset + idx);
tmp = add_elements(tmp, smChannels[srcChannelIndexes[index]].read<T>(srcOffset + idx));
}
output[idx] = tmp;
if (sendToRemote) {
Expand Down Expand Up @@ -360,7 +370,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
T tmp = src[idx];
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = inputOffsets[index] / sizeof(T);
tmp += input[offset + idx];
tmp = add_elements(tmp, input[offset + idx]);
}
dst[idx] = tmp;
for (int index = 0; index < nOutChannels; ++index) {
Expand Down

0 comments on commit 72b99a4

Please sign in to comment.