Skip to content

Commit

Permalink
Fixed issues found on AMD (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
pash-msft authored Mar 8, 2024
1 parent d72f09f commit 96e31bf
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
6 changes: 2 additions & 4 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,8 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannel
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> ptr =
mscclpp::allocSharedCuda<mscclpp::DeviceHandle<mscclpp::SmChannel>>(smChannelDeviceHandles.size());
mscclpp::AvoidCudaGraphCaptureGuard guard;
CUDACHECK(cudaMemcpy(ptr.get(), smChannelDeviceHandles.data(),
sizeof(mscclpp::DeviceHandle<mscclpp::SmChannel>) * smChannelDeviceHandles.size(),
cudaMemcpyHostToDevice));
mscclpp::memcpyCuda<mscclpp::DeviceHandle<mscclpp::SmChannel>>(ptr.get(), smChannelDeviceHandles.data(),
smChannelDeviceHandles.size(), cudaMemcpyHostToDevice);
return ptr;
}

Expand Down
2 changes: 1 addition & 1 deletion include/mscclpp/gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#if defined(__HIP_PLATFORM_AMD__)

#include <hip/hip_bf16.h>
// #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

Expand Down
1 change: 1 addition & 0 deletions src/fifo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ MSCCLPP_API_CPP void Fifo::pop() {
MSCCLPP_API_CPP void Fifo::flushTail(bool sync) {
// Flush the tail to device memory. This is either triggered every ProxyFlushPeriod to make sure that the fifo can
// make progress even if there is no request mscclppSync. However, mscclppSync type is for flush request.
AvoidCudaGraphCaptureGuard cgcGuard;
MSCCLPP_CUDATHROW(cudaMemcpyAsync(pimpl->tailReplica.get(), &pimpl->hostTail, sizeof(uint64_t),
cudaMemcpyHostToDevice, pimpl->stream));
if (sync) {
Expand Down

0 comments on commit 96e31bf

Please sign in to comment.