From 20fdcf643db61bf7e20973a50b179af7d172e5f8 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 31 Oct 2024 11:24:04 +0000 Subject: [PATCH] clean some code --- include/mscclpp/gpu_utils.hpp | 41 +++++++++++++---------------------- src/nvls.cc | 10 +++------ src/registered_memory.cc | 8 +------ 3 files changed, 19 insertions(+), 40 deletions(-) diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index a96201d75..a6d830330 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -46,6 +46,19 @@ namespace mscclpp { +/// set memory access permission to read-write +/// @param base Base memory pointer. +/// @param size Size of the memory. +inline void setReadWriteMemoryAccess(void* base, size_t size) { + CUmemAccessDesc accessDesc = {}; + int deviceId; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = deviceId; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, size, &accessDesc, 1)); +} + /// A RAII guard that will cudaThreadExchangeStreamCaptureMode to cudaStreamCaptureModeRelaxed on construction and /// restore the previous mode on destruction. This is helpful when we want to avoid CUDA graph capture. struct AvoidCudaGraphCaptureGuard { @@ -114,16 +127,11 @@ PhysicalCudaMemory* cudaPhysicalCalloc(size_t nelem, size_t gran) { // allocate physical memory MSCCLPP_CUTHROW(cuMemCreate(&memHandle, bufferSize, &prop, 0 /*flags*/)); - CUmemAccessDesc accessDesc = {}; - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - T* devicePtr = nullptr; // Map the device pointer MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, bufferSize, gran, 0U, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, bufferSize, 0, memHandle, 0)); - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)devicePtr, bufferSize, &accessDesc, 1)); + setReadWriteMemoryAccess(devicePtr, bufferSize); CudaStreamWithFlags stream(cudaStreamNonBlocking); MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, bufferSize, stream)); @@ -155,18 +163,13 @@ T* cudaPhysicalCallocPtr(size_t nbytes) { CUmemGenericAllocationHandle memHandle; MSCCLPP_CUTHROW(cuMemCreate(&memHandle, nbytes, &prop, 0 /*flags*/)); - CUmemAccessDesc accessDesc = {}; - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - T* devicePtr = nullptr; size_t gran = 0; // Map the device pointer MSCCLPP_CUTHROW(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, nbytes, gran, 0U, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, nbytes, 0, memHandle, 0)); - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)devicePtr, nbytes, &accessDesc, 1)); + setReadWriteMemoryAccess(devicePtr, nbytes); CudaStreamWithFlags stream(cudaStreamNonBlocking); MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, nbytes, stream)); MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); @@ -473,20 +476,6 @@ void memcpyCuda(T* dst, const T* src, size_t count, cudaMemcpyKind kind = cudaMe MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); } -/// Check if ptr is allocaed by cuMemMap -/// @param ptr Pointer to check -/// @return true if ptr is allocated by cuMemMap, false otherwise -bool inline isCuMemMapAllocated(void* ptr) { - CUmemGenericAllocationHandle handle; - CUresult result = cuMemRetainAllocationHandle(&handle, ptr); - if (result != CUDA_SUCCESS) { - printf("ptr is %p\n", ptr); - return false; - } - cuMemRelease(handle); - return true; -} - } // namespace mscclpp #endif // MSCCLPP_GPU_UTILS_HPP_ diff --git a/src/nvls.cc b/src/nvls.cc index 29fdb710d..80f8cb809 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -208,19 +208,15 @@ std::shared_ptr NvlsConnection::Impl::bindMemoryWithPtr(CUdeviceptr device std::shared_ptr NvlsConnection::Impl::bindMemoryToMulticastHandle(size_t offset, size_t bufferSize) { char* mcPtr; - CUmemAccessDesc accessDesc = {}; - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - int deviceId = -1; - MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), bufferSize, minMcGran_, 0U, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), bufferSize, 0, mcHandle_, 0)); - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)(mcPtr), bufferSize, &accessDesc, 1)); + setReadWriteMemoryAccess(mcPtr, bufferSize); INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, bufferSize); auto deleter = [=, self = shared_from_this()](char* ptr) { + int deviceId; CUdevice device; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); MSCCLPP_CUTHROW(cuDeviceGet(&device, deviceId)); MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, bufferSize)); MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, bufferSize)); diff --git a/src/registered_memory.cc b/src/registered_memory.cc index ee4a3ec50..64e79beb3 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -220,16 +220,10 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { #if (USE_NVLS) CUmemGenericAllocationHandle handle; MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC)); - CUmemAccessDesc accessDesc = {}; - int deviceId; - MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; size_t gran = getRecommendedGranularity(); MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0)); - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, this->size, &accessDesc, 1)); + setReadWriteMemoryAccess(base, this->size); this->data = static_cast(base) + entry.offsetFromBase; #endif } else {