Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 1, 2024
1 parent 3d0be17 commit f7d153f
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 149 deletions.
128 changes: 35 additions & 93 deletions include/mscclpp/gpu_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "errors.hpp"
#include "gpu.hpp"
#include "utils.hpp"

/// Throw @ref mscclpp::CudaError if @p cmd does not return cudaSuccess.
/// @param cmd The command to execute.
Expand Down Expand Up @@ -78,15 +79,6 @@ struct CudaStreamWithFlags {
template <class T>
struct CudaDeleter;

template <class T>
struct PhysicalCudaMemory {
CUmemGenericAllocationHandle memHandle_;
T* devicePtr_;
size_t size_;
PhysicalCudaMemory(CUmemGenericAllocationHandle memHandle, T* devicePtr, size_t size)
: memHandle_(memHandle), devicePtr_(devicePtr), size_(size) {}
};

namespace detail {

/// A wrapper of cudaMalloc that sets the allocated memory to zero.
Expand All @@ -105,43 +97,7 @@ T* cudaCalloc(size_t nelem) {
}

template <class T>
PhysicalCudaMemory<T>* cudaPhysicalCalloc(size_t nelem, size_t gran) {
AvoidCudaGraphCaptureGuard cgcGuard;

int deviceId = -1;
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));

CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = deviceId;
#if defined(__HIP_PLATFORM_AMD__)
// TODO: revisit when HIP fixes this typo in the field name
prop.requestedHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#else
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#endif

CUmemGenericAllocationHandle memHandle;
size_t bufferSize = sizeof(T) * nelem;
// allocate physical memory
MSCCLPP_CUTHROW(cuMemCreate(&memHandle, bufferSize, &prop, 0 /*flags*/));

T* devicePtr = nullptr;
// Map the device pointer
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, bufferSize, gran, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, bufferSize, 0, memHandle, 0));
setReadWriteMemoryAccess(devicePtr, bufferSize);
CudaStreamWithFlags stream(cudaStreamNonBlocking);
MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, bufferSize, stream));

MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));

return new PhysicalCudaMemory<T>(memHandle, devicePtr, bufferSize);
}

template <class T>
T* cudaPhysicalCallocPtr(size_t nbytes, size_t gran) {
T* cudaPhysicalCalloc(size_t nbytes, size_t gran) {
AvoidCudaGraphCaptureGuard cgcGuard;
int deviceId = -1;
CUdevice currentDevice;
Expand Down Expand Up @@ -259,17 +215,6 @@ struct CudaDeleter {

template <class T>
struct CudaPhysicalDeleter {
static_assert(!std::is_array_v<T>, "T must not be an array");
void operator()(PhysicalCudaMemory<T>* ptr) {
AvoidCudaGraphCaptureGuard cgcGuard;
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr->devicePtr_, ptr->size_));
MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr->devicePtr_, ptr->size_));
MSCCLPP_CUTHROW(cuMemRelease(ptr->memHandle_));
}
};

template <class T>
struct CudaPhysicalPtrDeleter {
static_assert(!std::is_array_v<T>, "T must not be an array");
void operator()(T* ptr) {
AvoidCudaGraphCaptureGuard cgcGuard;
Expand Down Expand Up @@ -303,31 +248,13 @@ std::shared_ptr<T> allocSharedCuda(size_t count = 1) {
return detail::safeAlloc<T, detail::cudaCalloc<T>, CudaDeleter<T>, std::shared_ptr<T>>(count);
}

/// Allocated physical memory on the device and returns a memory handle along with a memory handle for it.
/// The deallocation only happens PhysicalCudaMemory goes out of scope.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::shared_ptr to the memory handle and a device pointer for that memory.
template <class T>
std::shared_ptr<PhysicalCudaMemory<T>> allocSharedPhysicalCuda(size_t count, size_t gran) {
return detail::safeAlloc<PhysicalCudaMemory<T>, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
std::shared_ptr<PhysicalCudaMemory<T>>>(count, gran);
}

#if (USE_NVLS)
/// Allocates physical memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @return A std::shared_ptr to the allocated memory.
template <class T>
std::shared_ptr<T> allocSharedPhysicalCudaPtr(size_t count) {
static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag) {
size_t gran = 0;
int numDevices = 0;
MSCCLPP_CUDATHROW(cudaGetDeviceCount(&numDevices));

CUmulticastObjectProp prop = {};
prop.size = count * sizeof(T);
prop.size = size;
// This is a dummy value, it might affect the granularity in the future
prop.numDevices = numDevices;
#if defined(__HIP_PLATFORM_AMD__)
Expand All @@ -337,12 +264,27 @@ std::shared_ptr<T> allocSharedPhysicalCudaPtr(size_t count) {
prop.handleTypes = (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC);
#endif
prop.flags = 0;
MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED));
MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, granFlag));
return gran;
}

/// Allocates physical memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::shared_ptr to the allocated memory.
template <class T>
std::shared_ptr<T> allocSharedPhysicalCuda(size_t count, size_t gran = 0) {
if (!isFabricSupported()) {
throw Error("Only suupport GPU with Fabric support", ErrorCode::InvalidUsage);
}

if (gran == 0) {
gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED);
}
size_t nbytes = (count * sizeof(T) + gran - 1) / gran * gran;
return detail::safeAlloc<T, detail::cudaPhysicalCallocPtr<T>, CudaPhysicalPtrDeleter<T>, std::shared_ptr<T>>(nbytes,
gran);
return detail::safeAlloc<T, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>, std::shared_ptr<T>>(nbytes, gran);
}
#endif

/// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
Expand All @@ -367,18 +309,6 @@ UniqueCudaPtr<T> allocUniqueCuda(size_t count = 1) {
return detail::safeAlloc<T, detail::cudaCalloc<T>, CudaDeleter<T>, UniqueCudaPtr<T>>(count);
}

/// Allocated physical memory on the device and returns a memory handle along with a virtual memory handle for it.
/// The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::unique_ptr to the memory handle and a device pointer for that memory.
template <class T>
std::unique_ptr<PhysicalCudaMemory<T>> allocUniquePhysicalCuda(size_t count, size_t gran) {
return detail::safeAlloc<PhysicalCudaMemory<T>, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
std::unique_ptr<CudaPhysicalDeleter<T>, CudaDeleter<CudaPhysicalDeleter<T>>>>(count, gran);
}

/// Allocates memory on the device and returns a std::unique_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
Expand Down Expand Up @@ -447,6 +377,18 @@ UniqueCudaHostPtr<T> makeUniqueCudaHost(size_t count) {
return ptr;
}

/// Allocated physical memory on the device and returns a memory handle along with a virtual memory handle for it.
/// The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
/// @param gran the granularity of the allocation.
/// @return A std::unique_ptr to the allocated memory.
template <class T>
std::unique_ptr<T> allocUniquePhysicalCuda(size_t count, size_t gran) {
return detail::safeAlloc<T, detail::cudaPhysicalCalloc<T>, CudaPhysicalDeleter<T>,
std::unique_ptr<CudaPhysicalDeleter<T>, CudaDeleter<CudaPhysicalDeleter<T>>>>(count, gran);
}

/// Asynchronous cudaMemcpy without capture into a CUDA graph.
/// @tparam T Type of each element in the allocated memory.
/// @param dst Destination pointer.
Expand Down
11 changes: 5 additions & 6 deletions include/mscclpp/nvls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,19 @@ class NvlsConnection {

struct DeviceMulticastPointer {
private:
std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
void* devicePtr_;
std::shared_ptr<char> ptr_;
std::shared_ptr<char> mcPtr_;
size_t bufferSize_;

public:
using DeviceHandle = DeviceMulticastPointerDeviceHandle;
DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
size_t bufferSize)
: deviceMem_(deviceMem), devicePtr_(nullptr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceMulticastPointer(std::shared_ptr<char> ptr, std::shared_ptr<char> mcPtr, size_t bufferSize)
: devicePtr_(ptr.get()), ptr_(ptr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceMulticastPointer(void* devicePtr, std::shared_ptr<char> mcPtr, size_t bufferSize)
: deviceMem_(nullptr), devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
: devicePtr_(devicePtr), ptr_(nullptr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceHandle deviceHandle();
char* getDevicePtr();
void* getDevicePtr();

friend class NvlsConnection;
};
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct ScopedTimer : public Timer {
std::string getHostName(int maxlen, const char delim);

bool isNvlsSupported();
bool isFabricSupported();

} // namespace mscclpp

Expand Down
2 changes: 1 addition & 1 deletion python/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
PacketType,
version,
is_nvls_supported,
alloc_shared_physical_cuda_ptr,
alloc_shared_physical_cuda,
npkit,
)

Expand Down
8 changes: 3 additions & 5 deletions python/mscclpp/gpu_utils_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ using namespace mscclpp;

class PyCudaMemory {
public:
PyCudaMemory(size_t size, DataType dtype) : size_(size), dtype_(dtype) {
ptr_ = allocSharedPhysicalCudaPtr<char>(size);
}
PyCudaMemory(size_t size, DataType dtype) : size_(size), dtype_(dtype) { ptr_ = allocSharedPhysicalCuda<char>(size); }

uintptr_t getPtr() const { return (uintptr_t)(ptr_.get()); }
size_t size() const { return size_; }
Expand All @@ -24,7 +22,7 @@ class PyCudaMemory {
DataType dtype_;
};

std::shared_ptr<PyCudaMemory> allocSharedPhysicalCudaPtrDispatcher(size_t count, DataType dtype) {
std::shared_ptr<PyCudaMemory> allocSharedPhysicalCudaDispatcher(size_t count, DataType dtype) {
size_t size = 0;
switch (dtype) {
case DataType::FLOAT32:
Expand Down Expand Up @@ -52,5 +50,5 @@ void register_gpu_utils(nb::module_& m) {
.def("get_ptr", &PyCudaMemory::getPtr, "Get the raw pointer")
.def("size", &PyCudaMemory::size, "Get the size of the allocated memory")
.def("dtype", &PyCudaMemory::dtype, "Get the data type of the memory");
m.def("alloc_shared_physical_cuda_ptr", &allocSharedPhysicalCudaPtrDispatcher, nb::arg("count"), nb::arg("dtype"));
m.def("alloc_shared_physical_cuda", &allocSharedPhysicalCudaDispatcher, nb::arg("count"), nb::arg("dtype"));
}
4 changes: 2 additions & 2 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ExecutionPlan,
PacketType,
npkit,
alloc_shared_physical_cuda_ptr,
alloc_shared_physical_cuda,
is_nvls_supported,
)
import mscclpp.comm as mscclpp_comm
Expand Down Expand Up @@ -97,7 +97,7 @@ def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):

def allocate_buffer(nelems, dtype):
if is_nvls_supported:
buffer_raw = alloc_shared_physical_cuda_ptr(nelems, dtype=dtype_to_mscclpp_dtype(dtype))
buffer_raw = alloc_shared_physical_cuda(nelems, dtype=dtype_to_mscclpp_dtype(dtype))
buffer_ptr = cp.cuda.MemoryPointer(
cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0
)
Expand Down
16 changes: 7 additions & 9 deletions src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,11 @@ NvlsConnection::NvlsConnection(const std::vector<char>& data) : pimpl_(std::make
std::vector<char> NvlsConnection::serialize() { return pimpl_->serialize(); }

std::shared_ptr<NvlsConnection::DeviceMulticastPointer> NvlsConnection::allocateAndBindCuda(size_t size) {
auto mem = allocSharedPhysicalCuda<char>(size, pimpl_->getMinMcGran());
auto mcPtr = pimpl_->bindMemory(mem->memHandle_, size);
return std::make_shared<DeviceMulticastPointer>(mem, mcPtr, size);
CUmemGenericAllocationHandle handle;
auto ptr = allocSharedPhysicalCuda<char>(size, pimpl_->getMinMcGran());
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr.get()));
auto mcPtr = pimpl_->bindMemory(handle, size);
return std::make_shared<DeviceMulticastPointer>(ptr, mcPtr, size);
}

std::shared_ptr<char> NvlsConnection::bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size) {
Expand All @@ -286,17 +288,13 @@ std::shared_ptr<char> NvlsConnection::bindAllocatedCudaWithPtr(CUdeviceptr devic

NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {
NvlsConnection::DeviceMulticastPointer::DeviceHandle device;
if (this->deviceMem_ != nullptr) {
device.devicePtr = this->deviceMem_->devicePtr_;
} else {
device.devicePtr = this->devicePtr_;
}
device.devicePtr = this->devicePtr_;
device.mcPtr = this->mcPtr_.get();
device.bufferSize = this->bufferSize_;
return device;
};

char* NvlsConnection::DeviceMulticastPointer::getDevicePtr() { return deviceMem_->devicePtr_; };
void* NvlsConnection::DeviceMulticastPointer::getDevicePtr() { return devicePtr_; };

size_t NvlsConnection::getMultiCastMinGranularity() { return pimpl_->getMinMcGran(); }

Expand Down
58 changes: 32 additions & 26 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports,
this->isCuMemMapAlloc = true;
}
if (this->isCuMemMapAlloc) {
#if (USE_NVLS)
CUmemGenericAllocationHandle handle;
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr));
MSCCLPP_CUTHROW(
cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr;
#endif
if (isFabricSupported()) {
CUmemGenericAllocationHandle handle;
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr));
MSCCLPP_CUTHROW(
cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr;
} else {
throw Error("Fabric is not supported", ErrorCode::InvalidUsage);
}
} else {
cudaIpcMemHandle_t handle;
MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
Expand Down Expand Up @@ -216,15 +218,17 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
auto entry = getTransportInfo(Transport::CudaIpc);
void* base;
if (this->isCuMemMapAlloc) {
#if (USE_NVLS)
CUmemGenericAllocationHandle handle;
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC));
size_t gran = getRecommendedGranularity();
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0));
setReadWriteMemoryAccess(base, this->size);
this->data = static_cast<char*>(base) + entry.offsetFromBase;
#endif
if (isFabricSupported()) {
CUmemGenericAllocationHandle handle;
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC));
size_t gran = getRecommendedGranularity();
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0));
setReadWriteMemoryAccess(base, this->size);
this->data = static_cast<char*>(base) + entry.offsetFromBase;
} else {
throw Error("Fabric is not supported", ErrorCode::InvalidUsage);
}
} else {
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
this->data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
Expand All @@ -241,16 +245,18 @@ RegisteredMemory::Impl::~Impl() {
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {
void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase;
if (this->isCuMemMapAlloc) {
#if (USE_NVLS)
CUmemGenericAllocationHandle handle;
size_t size = 0;
MSCCLPP_CULOG_WARN(cuMemRetainAllocationHandle(&handle, base));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base));
MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
#endif
if (isFabricSupported()) {
CUmemGenericAllocationHandle handle;
size_t size = 0;
MSCCLPP_CULOG_WARN(cuMemRetainAllocationHandle(&handle, base));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base));
MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
} else {
WARN("Fabric is not supported, skipping cleanup");
}
} else {
cudaError_t err = cudaIpcCloseMemHandle(base);
if (err != cudaSuccess) {
Expand Down
Loading

0 comments on commit f7d153f

Please sign in to comment.