diff --git a/python/mscclpp/comm.py b/python/mscclpp/comm.py index ca3620924..f370b9441 100644 --- a/python/mscclpp/comm.py +++ b/python/mscclpp/comm.py @@ -99,7 +99,7 @@ def make_connection( else: endpoint = endpoints if endpoint.transport == Transport.Nvls: - return connect_nvls_collective(self.communicator, all_ranks) + return connect_nvls_collective(self.communicator, all_ranks, 2**30) else: connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint) self.communicator.setup() diff --git a/python/mscclpp_benchmark/allreduce_bench.py b/python/mscclpp_benchmark/allreduce_bench.py index 69e4f3adc..e93c0479e 100644 --- a/python/mscclpp_benchmark/allreduce_bench.py +++ b/python/mscclpp_benchmark/allreduce_bench.py @@ -289,7 +289,7 @@ def get_netinterface_info(): mscclpp_algbw = [] nccl_algbw = [] speed_ups = [] - end_range = 28 if is_nvls_supported() else 29 + end_range = 29 for i in range(10, end_range): if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1: nelems = 2**i