diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp index 52ade275d..edaf2e256 100644 --- a/include/mscclpp/nvls_device.hpp +++ b/include/mscclpp/nvls_device.hpp @@ -20,7 +20,7 @@ struct DeviceMulticastPointerDeviceHandle { void* mcPtr; size_t bufferSize; -#if defined(MSCCLPP_DEVICE_COMPILE) +#if defined(MSCCLPP_DEVICE_CUDA) template MSCCLPP_DEVICE_INLINE void multimemLoad(TVaule& val, T* ptr) { static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4"); @@ -54,7 +54,7 @@ struct DeviceMulticastPointerDeviceHandle { static_assert(dependentFalse, "Not supported type"); } }; -#endif +#endif // defined(MSCCLPP_DEVICE_CUDA) }; } // namespace mscclpp diff --git a/python/mscclpp/comm.py b/python/mscclpp/comm.py index d84410668..918cd97fb 100644 --- a/python/mscclpp/comm.py +++ b/python/mscclpp/comm.py @@ -86,15 +86,18 @@ def make_connection( ) -> dict[int, Connection]: if type(endpoints) is Transport: endpoints = EndpointConfig(endpoints) - if endpoints.transport == Transport.Nvls: - return self.communicator.connct_nvls_collective(all_ranks, endpoints) + elif type(endpoints) is dict: + endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()} connections = {} for rank in all_ranks: if type(endpoints) is dict: endpoint = endpoints[rank] else: endpoint = endpoints - connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint) + if endpoint.transport == Transport.Nvls: + connections[rank] = self.communicator.connct_nvls_collective(all_ranks, endpoint) + else: + connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint) self.communicator.setup() connections = {rank: connections[rank].get() for rank in connections} return connections diff --git a/test/nvls_test.cu b/test/nvls_test.cu index e01b4d790..55ece3fcf 100644 --- a/test/nvls_test.cu +++ b/test/nvls_test.cu @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include + +#if (USE_NVLS) #include #include #include #include -#include #include #include #include @@ -71,7 +73,6 @@ __global__ void testing(float* mc_ptr, int size, int myrank, int nranks) { } int main() { -#if (USE_NVLS) int myrank, nranks; MPI_Init(NULL, NULL); MPI_Comm_rank(MPI_COMM_WORLD, &myrank); @@ -199,5 +200,14 @@ int main() { } MPI_Barrier(MPI_COMM_WORLD); MPI_Finalize(); -#endif // (USE_NVLS) + return 0; } + +#else // !(USE_NVLS) + +int main() { + printf("This test requires NVLS to be enabled\n"); + return 0; +} + +#endif // !(USE_NVLS)