diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index d9f7f0b27..e860c7b93 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -98,9 +98,9 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri // NVLS #if !defined(__HIP_PLATFORM_AMD__) #include -#define USE_NVLS ((CUDART_VERSION >= 12040) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) +#define CUDA_NVLS_SUPPORTED ((CUDART_VERSION >= 12040) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) #else // !defined(__HIP_PLATFORM_AMD__) -#define USE_NVLS 0 +#define CUDA_NVLS_SUPPORTED 0 #endif // !defined(__HIP_PLATFORM_AMD__) // GPU sync threads diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index c32ec8ab7..80cc435bf 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -260,8 +260,10 @@ static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity #if defined(__HIP_PLATFORM_AMD__) // TODO: revisit when HIP fixes this typo in the field name prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; -#else +#elif (CUDA_NVLS_SUPPORTED) prop.handleTypes = (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); +#else + prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; #endif prop.flags = 0; MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, granFlag)); @@ -278,6 +280,9 @@ std::shared_ptr allocSharedPhysicalCuda(size_t count, size_t gran = 0) { if (!isFabricSupported()) { throw Error("Only suupport GPU with Fabric support", ErrorCode::InvalidUsage); } + if (count == 0) { + return nullptr; + } if (gran == 0) { gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); @@ -384,7 +389,17 @@ UniqueCudaHostPtr makeUniqueCudaHost(size_t count) { /// @param gran the granularity of the allocation. /// @return A std::unique_ptr to the allocated memory. template -std::unique_ptr allocUniquePhysicalCuda(size_t count, size_t gran) { +std::unique_ptr allocUniquePhysicalCuda(size_t count, size_t gran = 0) { + if (!isFabricSupported()) { + throw Error("Only suupport GPU with Fabric support", ErrorCode::InvalidUsage); + } + if (count == 0) { + return nullptr; + } + + if (gran == 0) { + gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); + } return detail::safeAlloc, CudaPhysicalDeleter, std::unique_ptr, CudaDeleter>>>(count, gran); } diff --git a/src/nvls.cc b/src/nvls.cc index d52cad6bd..bbb21a7c3 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -15,7 +15,7 @@ namespace mscclpp { -#if (USE_NVLS) +#if (CUDA_NVLS_SUPPORTED) class NvlsConnection::Impl : public std::enable_shared_from_this { public: // use this only for the root of the NVLS @@ -229,7 +229,7 @@ std::shared_ptr NvlsConnection::Impl::bindMemoryToMulticastHandle(size_t o return std::shared_ptr(mcPtr, deleter); } -#else // !(USE_NVLS) +#else // !(CUDA_NVLS_SUPPORTED) class NvlsConnection::Impl { public: // use this only for the root of the NVLS @@ -251,7 +251,7 @@ class NvlsConnection::Impl { Error notSupportedError = Error("NVLS is not supported on this CUDA version (< 12.1) or kernel version (< 5.6.0)", ErrorCode::InvalidUsage); }; -#endif // !(USE_NVLS) +#endif // !(CUDA_NVLS_SUPPORTED) const int NvlsConnection::DefaultNvlsBufferSize = (1 << 29); diff --git a/test/executor_test.cc b/test/executor_test.cc index 53925e7fc..3fc0b1e21 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -131,7 +131,7 @@ int main(int argc, char* argv[]) { } mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath); -#if (USE_NVLS) +#if (CUDA_NVLS_SUPPORTED) std::shared_ptr sendbuff = mscclpp::allocSharedPhysicalCuda(bufferSize); #else std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); diff --git a/test/nvls_test.cu b/test/nvls_test.cu index bc90aa795..42aefdc2d 100644 --- a/test/nvls_test.cu +++ b/test/nvls_test.cu @@ -11,7 +11,7 @@ #include #include -#if (USE_NVLS) +#if (CUDA_NVLS_SUPPORTED) #define CUCHECK(cmd) \ do { \ @@ -202,11 +202,11 @@ int main() { return 0; } -#else // !(USE_NVLS) +#else // !(CUDA_NVLS_SUPPORTED) int main() { printf("This test requires NVLS to be enabled\n"); return 0; } -#endif // !(USE_NVLS) +#endif // !(CUDA_NVLS_SUPPORTED)