Skip to content

Commit

Permalink
Use IB transport flags only when an IB device exists (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Sep 19, 2024
1 parent 5c4e105 commit 74130c7
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions test/mscclpp-test/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mscclpp/core.hpp>
#include <mscclpp/utils.hpp>
#include <nlohmann/json.hpp>
#include <sstream>
Expand Down Expand Up @@ -399,7 +400,8 @@ void BaseTestEngine::setupMeshConnectionsInternal(
void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels,
void* inputBuff, size_t inputBuffBytes, void* outputBuff,
size_t outputBuffBytes, SetupChannelFunc setupChannel) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory outputBufRegMem;
if (outputBuff) {
Expand Down Expand Up @@ -429,7 +431,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<DeviceHandle<mscclpp::Simp
void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smChannels, void* inputBuff,
size_t inputBuffBytes, void* outputBuff, size_t outputBuffBytes,
ChannelSemantic semantic, size_t nChannelPerConnection) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory getPacketBufRegMem;
mscclpp::RegisteredMemory outputBufRegMem;
Expand Down Expand Up @@ -469,7 +472,8 @@ void BaseTestEngine::setupMeshConnections(std::vector<mscclpp::SmChannel>& smCha
void* inputBuff, size_t inputBuffBytes, void* putPacketBuff,
size_t putPacketBuffBytes, void* getPacketBuff, size_t getPacketBuffBytes,
void* outputBuff, size_t outputBuffBytes) {
const mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc | IBs[args_.gpuNum];
mscclpp::TransportFlags allTransports = mscclpp::Transport::CudaIpc;
if (mscclpp::getIBDeviceCount() > 0) allTransports |= IBs[args_.gpuNum];
mscclpp::RegisteredMemory inputBufRegMem = comm_->registerMemory(inputBuff, inputBuffBytes, allTransports);
mscclpp::RegisteredMemory putPacketBufRegMem;
mscclpp::RegisteredMemory getPacketBufRegMem;
Expand Down

0 comments on commit 74130c7

Please sign in to comment.