Skip to content

Commit

Permalink
Fix NVLS support (#258)
Browse files Browse the repository at this point in the history
* Do not compile nvls_test with ROCm
* Fix multi-node tests
  • Loading branch information
chhwang authored Feb 6, 2024
1 parent d34e097 commit 6a19b19
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
4 changes: 2 additions & 2 deletions include/mscclpp/nvls_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct DeviceMulticastPointerDeviceHandle {
void* mcPtr;
size_t bufferSize;

#if defined(MSCCLPP_DEVICE_COMPILE)
#if defined(MSCCLPP_DEVICE_CUDA)
template <int NElemPerThread = 4, typename TVaule = float4, typename T = float>
MSCCLPP_DEVICE_INLINE void multimemLoad(TVaule& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
Expand Down Expand Up @@ -54,7 +54,7 @@ struct DeviceMulticastPointerDeviceHandle {
static_assert(dependentFalse<T>, "Not supported type");
}
};
#endif
#endif // defined(MSCCLPP_DEVICE_CUDA)
};

} // namespace mscclpp
Expand Down
9 changes: 6 additions & 3 deletions python/mscclpp/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,18 @@ def make_connection(
) -> dict[int, Connection]:
if type(endpoints) is Transport:
endpoints = EndpointConfig(endpoints)
if endpoints.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoints)
elif type(endpoints) is dict:
endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()}
connections = {}
for rank in all_ranks:
if type(endpoints) is dict:
endpoint = endpoints[rank]
else:
endpoint = endpoints
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
if endpoint.transport == Transport.Nvls:
connections[rank] = self.communicator.connct_nvls_collective(all_ranks, endpoint)
else:
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
self.communicator.setup()
connections = {rank: connections[rank].get() for rank in connections}
return connections
Expand Down
16 changes: 13 additions & 3 deletions test/nvls_test.cu
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <stdio.h>

#if (USE_NVLS)
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <mpi.h>
#include <stdio.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
Expand Down Expand Up @@ -71,7 +73,6 @@ __global__ void testing(float* mc_ptr, int size, int myrank, int nranks) {
}

int main() {
#if (USE_NVLS)
int myrank, nranks;
MPI_Init(NULL, NULL);
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
Expand Down Expand Up @@ -199,5 +200,14 @@ int main() {
}
MPI_Barrier(MPI_COMM_WORLD);
MPI_Finalize();
#endif // (USE_NVLS)
return 0;
}

#else // !(USE_NVLS)

int main() {
printf("This test requires NVLS to be enabled\n");
return 0;
}

#endif // !(USE_NVLS)

0 comments on commit 6a19b19

Please sign in to comment.