diff --git a/test/mscclpp-test/common.cc b/test/mscclpp-test/common.cc index 9c52f9f4a..899823f7d 100644 --- a/test/mscclpp-test/common.cc +++ b/test/mscclpp-test/common.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -399,7 +400,8 @@ void BaseTestEngine::setupMeshConnectionsInternal( void BaseTestEngine::setupMeshConnections(std::vector>& proxyChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes, SetupChannelFunc setupChannel) { - const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; + mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc; + if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory outputBufRegMem; if (outputBuff) { @@ -429,7 +431,8 @@ void BaseTestEngine::setupMeshConnections(std::vector& smChannels, void* inputBuff, size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes, ChannelSemantic semantic, size_t nChannelPerConnection) { - const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; + mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc; + if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory getPacketBufRegMem; mscclpp::RegisteredMemory outputBufRegMem; @@ -469,7 +472,8 @@ void BaseTestEngine::setupMeshConnections(std::vector& smCha void* inputBuff, size_t inputBuffBytes, void* putPacketBuff, size_t putPacketBuffBytes, void* getPacketBuff, size_t getPacketBuffBytes, void* outputBuff, size_t outputBuffBytes) { - const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum]; + mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc; + if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum]; mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports); mscclpp::RegisteredMemory putPacketBufRegMem; mscclpp::RegisteredMemory getPacketBufRegMem;