Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Oct 31, 2024
1 parent ea47314 commit 4bd884e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 23 deletions.
14 changes: 12 additions & 2 deletions include/mscclpp/gpu_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ T* cudaPhysicalCallocPtr(size_t nelem, size_t gran) {
// TODO: revisit when HIP fixes this typo in the field name
prop.requestedHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#else
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
#endif

CUmemGenericAllocationHandle memHandle;
Expand Down Expand Up @@ -263,7 +263,6 @@ struct CudaPhysicalPtrDeleter {
CUmemGenericAllocationHandle handle;
size_t size = 0;
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr));
MSCCLPP_CUTHROW(cuMemRelease(handle));
MSCCLPP_CUTHROW(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, size));
MSCCLPP_CUTHROW(cuMemRelease(handle));
Expand Down Expand Up @@ -451,6 +450,17 @@ void memcpyCuda(T* dst, const T* src, size_t count, cudaMemcpyKind kind = cudaMe
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));
}

bool inline isCuMemMapAllocated(void* ptr) {
CUmemGenericAllocationHandle handle;
CUresult result = cuMemRetainAllocationHandle(&handle, ptr);
if (result != CUDA_SUCCESS) {
printf("ptr is %p\n", ptr);
return false;
}
cuMemRelease(handle);
return true;
}

} // namespace mscclpp

#endif // MSCCLPP_GPU_UTILS_HPP_
5 changes: 5 additions & 0 deletions src/include/registered_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ struct TransportInfo {
const IbMr* ibMr;
IbMrInfo ibMrInfo;
};
// Payload stored for a Transport::CudaIpc entry when RegisteredMemory::Impl::isCuMemMap is
// set; it is serialized and deserialized in place of the cudaIpc* fields in that case.
struct {
// Opaque handle bytes filled in by cuMemExportToShareableHandle() with
// CU_MEM_HANDLE_TYPE_FABRIC and later handed to cuMemImportFromShareableHandle().
char shareableHandle[64];
// Byte distance from the base address reported by cuMemGetAddressRange() to the
// registered data pointer.
size_t offsetFromBase;
};
};
};

Expand All @@ -39,6 +43,7 @@ struct RegisteredMemory::Impl {
size_t size;
uint64_t hostHash;
uint64_t pidHash;
uint32_t isCuMemMap;
TransportFlags transports;
std::vector<TransportInfo> transportInfos;

Expand Down
100 changes: 79 additions & 21 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,33 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports,
size(size),
hostHash(getHostHash()),
pidHash(getPidHash()),
isCuMemMap(0),
transports(transports) {
if (transports.has(Transport::CudaIpc)) {
TransportInfo transportInfo;
transportInfo.transport = Transport::CudaIpc;
cudaIpcMemHandle_t handle;

void* baseDataPtr;
size_t baseDataSize; // dummy
MSCCLPP_CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data));
MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
// TODO: bug with offset of base?
transportInfo.cudaIpcBaseHandle = handle;
transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr;
if (isCuMemMapAllocated(data)) {
this->isCuMemMap = 1;
}
if (this->isCuMemMap) {
CUmemGenericAllocationHandle handle;
printf("baseDataPtr: %p\n", baseDataPtr);
printf("data: %p\n", data);
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr));
MSCCLPP_CUTHROW(
cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr;
} else {
cudaIpcMemHandle_t handle;
MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr));
// TODO: bug with offset of base?
transportInfo.cudaIpcBaseHandle = handle;
transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr;
}
this->transportInfos.push_back(transportInfo);
}
if ((transports & AllIBTransports).any()) {
Expand Down Expand Up @@ -75,6 +89,7 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
std::copy_n(reinterpret_cast<char*>(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->isCuMemMap), sizeof(pimpl_->isCuMemMap), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result));
if (pimpl_->transportInfos.size() > static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError);
Expand All @@ -84,10 +99,17 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
for (auto& entry : pimpl_->transportInfos) {
std::copy_n(reinterpret_cast<char*>(&entry.transport), sizeof(entry.transport), std::back_inserter(result));
if (entry.transport == Transport::CudaIpc) {
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase),
std::back_inserter(result));
if (pimpl_->isCuMemMap) {
std::copy_n(reinterpret_cast<char*>(&entry.shareableHandle), sizeof(entry.shareableHandle),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&entry.offsetFromBase), sizeof(entry.offsetFromBase),
std::back_inserter(result));
} else {
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle),
std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase),
std::back_inserter(result));
}
} else if (AllIBTransports.has(entry.transport)) {
std::copy_n(reinterpret_cast<char*>(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result));
} else {
Expand All @@ -111,6 +133,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
it += sizeof(this->hostHash);
std::copy_n(it, sizeof(this->pidHash), reinterpret_cast<char*>(&this->pidHash));
it += sizeof(this->pidHash);
std::copy_n(it, sizeof(this->isCuMemMap), reinterpret_cast<char*>(&this->isCuMemMap));
it += sizeof(this->isCuMemMap);
std::copy_n(it, sizeof(this->transports), reinterpret_cast<char*>(&this->transports));
it += sizeof(this->transports);
int8_t transportCount;
Expand All @@ -121,12 +145,19 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast<char*>(&transportInfo.transport));
it += sizeof(transportInfo.transport);
if (transportInfo.transport == Transport::CudaIpc) {
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle),
reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
it += sizeof(transportInfo.cudaIpcBaseHandle);
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase),
reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
if (this->isCuMemMap) {
std::copy_n(it, sizeof(transportInfo.shareableHandle), reinterpret_cast<char*>(&transportInfo.shareableHandle));
it += sizeof(transportInfo.shareableHandle);
std::copy_n(it, sizeof(transportInfo.offsetFromBase), reinterpret_cast<char*>(&transportInfo.offsetFromBase));
it += sizeof(transportInfo.offsetFromBase);
} else {
std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle),
reinterpret_cast<char*>(&transportInfo.cudaIpcBaseHandle));
it += sizeof(transportInfo.cudaIpcBaseHandle);
std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase),
reinterpret_cast<char*>(&transportInfo.cudaIpcOffsetFromBase));
it += sizeof(transportInfo.cudaIpcOffsetFromBase);
}
} else if (AllIBTransports.has(transportInfo.transport)) {
std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast<char*>(&transportInfo.ibMrInfo));
it += sizeof(transportInfo.ibMrInfo);
Expand All @@ -148,8 +179,24 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
// The memory is local to the machine but not to the process, so we need to open the CUDA IPC handle
auto entry = getTransportInfo(Transport::CudaIpc);
void* base;
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
this->data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
if (this->isCuMemMap) {
CUmemGenericAllocationHandle handle;
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, CU_MEM_HANDLE_TYPE_FABRIC));
CUmemAccessDesc accessDesc = {};
int deviceId;
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = deviceId;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
size_t gran = 2 * 1024 * 1024;
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0));
MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, this->size, &accessDesc, 1));
this->data = static_cast<char*>(base) + entry.offsetFromBase;
} else {
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
this->data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
}
INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", this->data);
} else {
// No valid data pointer can be set
Expand All @@ -161,11 +208,22 @@ RegisteredMemory::Impl::~Impl() {
// Close the CUDA IPC handle if it was opened during deserialization
if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) {
void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase;
cudaError_t err = cudaIpcCloseMemHandle(base);
if (err != cudaSuccess) {
WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err));
if (this->isCuMemMap) {
CUmemGenericAllocationHandle handle;
size_t size = 0;
cuMemRetainAllocationHandle(&handle, base);
cuMemRelease(handle);
cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base);
cuMemUnmap((CUdeviceptr)base, size);
cuMemRelease(handle);
cuMemAddressFree((CUdeviceptr)base, size);
} else {
INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base);
cudaError_t err = cudaIpcCloseMemHandle(base);
if (err != cudaSuccess) {
WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err));
} else {
INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base);
}
}
data = nullptr;
}
Expand Down

0 comments on commit 4bd884e

Please sign in to comment.