
Commit

fix
Binyang2014 committed Nov 14, 2024
1 parent 755629a commit 5b78c68
Showing 2 changed files with 39 additions and 8 deletions.
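In short: RegisteredMemory previously chose between the cuMemMap/fabric-handle path and the legacy cudaIpc path by calling isNvlsSupported() at each site. This commit instead records how the buffer was actually allocated in a new per-memory flag, isCuMemMapAlloc, detected at registration time via cuMemRetainAllocationHandle, carried through serialization, and branched on during export, import, and cleanup.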
1 change: 1 addition & 0 deletions src/include/registered_memory.hpp
@@ -43,6 +43,7 @@ struct RegisteredMemory::Impl {
size_t size;
uint64_t hostHash;
uint64_t pidHash;
+  bool isCuMemMapAlloc;
TransportFlags transports;
std::vector<TransportInfo> transportInfos;

46 changes: 38 additions & 8 deletions src/registered_memory.cc
@@ -12,9 +12,9 @@
#include "utils_internal.hpp"

namespace {
-#if (CUDA_NVLS_SUPPORTED)
// Get the recommended granularity for cuMemAddressReserve
size_t getRecommendedGranularity() {
+#if (CUDA_NVLS_SUPPORTED)
size_t gran = 0;
int deviceId = -1;
int currentDevice = -1;
@@ -29,8 +29,33 @@ size_t getRecommendedGranularity() {
prop.location.id = currentDevice;
MSCCLPP_CUTHROW(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
return gran;
+#else
+  throw mscclpp::Error("Only support GPU with NVLS support", mscclpp::ErrorCode::InvalidUsage);
+#endif
}

+CUmemAllocationHandleType getNvlsCompatibleMemHandleType() {
+#if (CUDA_NVLS_SUPPORTED)
+  return CU_MEM_HANDLE_TYPE_FABRIC;
+#else
+  throw mscclpp::Error("Only support GPU with NVLS support", mscclpp::ErrorCode::InvalidUsage);
+#endif
+}

+// Check if ptr is allocated by cuMemMap
+bool isCuMemMapAllocated(void* ptr) {
+  CUmemGenericAllocationHandle handle;
+  CUresult result = cuMemRetainAllocationHandle(&handle, ptr);
+  if (result != CUDA_SUCCESS) {
+    return false;
+  }
+  MSCCLPP_CUTHROW(cuMemRelease(handle));
+  if (!mscclpp::isNvlsSupported()) {
+    throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
+  }
+  return true;
+}

} // namespace

namespace mscclpp {
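For reference, isCuMemMapAllocated above returns true exactly for buffers created through CUDA's virtual memory management (VMM) API, since cuMemRetainAllocationHandle fails on cudaMalloc allocations. A minimal sketch of an allocation this detection would recognize, using only standard CUDA driver calls; the helper name is hypothetical and error checking is omitted, so this is illustrative rather than mscclpp code:

// Illustrative only: allocate size bytes with the CUDA VMM API so that
// cuMemRetainAllocationHandle (and hence isCuMemMapAllocated) succeeds.
void* allocWithCuMemMap(size_t size, int deviceId) {
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = deviceId;
  // For cuMemExportToShareableHandle with CU_MEM_HANDLE_TYPE_FABRIC, the
  // allocation would also set prop.requestedHandleTypes accordingly.

  size_t gran = 0;
  cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED);
  size = ((size + gran - 1) / gran) * gran;     // round up to the granularity

  CUmemGenericAllocationHandle handle;
  cuMemCreate(&handle, size, &prop, 0);         // physical memory

  CUdeviceptr ptr;
  cuMemAddressReserve(&ptr, size, gran, 0, 0);  // virtual address range
  cuMemMap(ptr, size, 0, handle, 0);            // map physical into virtual

  CUmemAccessDesc access = {};
  access.location = prop.location;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  cuMemSetAccess(ptr, size, &access, 1);        // make the mapping accessible
  return reinterpret_cast<void*>(ptr);
}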
Expand All @@ -49,11 +74,12 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports,
void* baseDataPtr;
size_t baseDataSize; // dummy
MSCCLPP_CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data));
-  if (isNvlsSupported()) {
+  this->isCuMemMapAlloc = isCuMemMapAllocated(baseDataPtr);
+  if (this->isCuMemMapAlloc) {
CUmemGenericAllocationHandle handle;
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr));
MSCCLPP_CUTHROW(
-      cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
+      cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, getNvlsCompatibleMemHandleType(), 0));
transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr;
} else {
cudaIpcMemHandle_t handle;
@@ -105,6 +131,8 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
std::copy_n(reinterpret_cast<char*>(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result));
+  std::copy_n(reinterpret_cast<char*>(&pimpl_->isCuMemMapAlloc), sizeof(pimpl_->isCuMemMapAlloc),
+              std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result));
if (pimpl_->transportInfos.size() > static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError);
@@ -114,7 +142,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
for (auto& entry : pimpl_->transportInfos) {
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
if (entry.transport == Transport::CudaIpc) {
-      if (isNvlsSupported()) {
+      if (pimpl_->isCuMemMapAlloc) {
std::copy_n(reinterpret_cast<char*>(&entry.shareableHandle), sizeof(entry.shareableHandle),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&entry.offsetFromBase), sizeof(entry.offsetFromBase),
@@ -148,6 +176,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
it += sizeof(this->hostHash);
std::copy_n(it, sizeof(this->pidHash), reinterpret_cast<char*>(&this->pidHash));
it += sizeof(this->pidHash);
+  std::copy_n(it, sizeof(this->isCuMemMapAlloc), reinterpret_cast<char*>(&this->isCuMemMapAlloc));
+  it += sizeof(this->isCuMemMapAlloc);
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
it += sizeof(this->transports);
int8_t transportCount;
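Serialization and deserialization must stay in lockstep: the new flag is written between pidHash and transports in serialize() and read back at the same position here. The layout, as far as the visible hunks show (collapsed leading fields and the legacy cudaIpc entry fields are not visible in this diff), sketched as a comment block:

// size_t         size
// uint64_t       hostHash
// uint64_t       pidHash
// bool           isCuMemMapAlloc     <-- added by this commit
// TransportFlags transports
// int8_t         transportCount
// per entry: Transport transport, then for CudaIpc either
//   { shareableHandle, offsetFromBase }   when isCuMemMapAlloc is set
//   { legacy cudaIpc handle fields }      otherwise (collapsed in this view)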
@@ -158,7 +188,7 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
it += sizeof(transportInfo.transport);
if (transportInfo.transport == Transport::CudaIpc) {
-      if (isNvlsSupported()) {
+      if (this->isCuMemMapAlloc) {
std::copy_n(it, sizeof(transportInfo.shareableHandle), reinterpret_cast<char*>(&transportInfo.shareableHandle));
it += sizeof(transportInfo.shareableHandle);
std::copy_n(it, sizeof(transportInfo.offsetFromBase), reinterpret_cast<char*>(&transportInfo.offsetFromBase));
@@ -192,9 +222,9 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
// The memory is local to the machine but not to the process, so we need to open the CUDA IPC handle
auto entry = getTransportInfo(Transport::CudaIpc);
void* base;
-    if (isNvlsSupported()) {
+    if (this->isCuMemMapAlloc) {
CUmemGenericAllocationHandle handle;
-      MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC));
+      MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsCompatibleMemHandleType()));
size_t gran = getRecommendedGranularity();
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0));
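The visible hunk ends at cuMemMap; before the mapping is usable, the CUDA VMM API additionally requires cuMemSetAccess, presumably done in the collapsed lines that follow. For orientation, the full importer-side pattern looks like the sketch below (illustrative, not the file's exact code; deviceId and the omitted error handling are assumptions):

// Importer-side sketch for a cuMemMap-backed buffer, error handling omitted.
CUmemGenericAllocationHandle handle;
cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsCompatibleMemHandleType());

void* base;
size_t gran = getRecommendedGranularity();
cuMemAddressReserve((CUdeviceptr*)&base, size, gran, 0, 0);
cuMemMap((CUdeviceptr)base, size, 0, handle, 0);

CUmemAccessDesc access = {};                        // grant the local device access
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = deviceId;                      // assumed: the current device
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
cuMemSetAccess((CUdeviceptr)base, size, &access, 1);

void* data = static_cast<char*>(base) + entry.offsetFromBase;  // final data pointer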
@@ -215,7 +245,7 @@ RegisteredMemory::Impl::~Impl() {
// Close the CUDA IPC handle if it was opened during deserialization
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {
void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase;
-    if (isNvlsSupported()) {
+    if (this->isCuMemMapAlloc) {
CUmemGenericAllocationHandle handle;
size_t size = 0;
MSCCLPP_CULOG_WARN(cuMemRetainAllocationHandle(&handle, base));
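The rest of the destructor is collapsed in this view. For a cuMemMap-backed import, the standard VMM teardown order is sketched below; this is the usual sequence implied by the API, not necessarily the exact collapsed code:

// Teardown sketch: unmap before freeing the reserved VA range.
CUmemGenericAllocationHandle handle;
size_t size = 0;
cuMemRetainAllocationHandle(&handle, base);            // as in the visible line above
cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base);  // recover the mapping size
cuMemUnmap((CUdeviceptr)base, size);
cuMemAddressFree((CUdeviceptr)base, size);
cuMemRelease(handle);                                  // release reference(s) on the handle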
