Skip to content

Commit

Permalink
name change
Browse files Browse the repository at this point in the history
  • Loading branch information
Saeed Maleki committed Aug 4, 2023
1 parent 88b6741 commit b21f068
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 33 deletions.
18 changes: 9 additions & 9 deletions test/allgather_test_cpp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz
CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice));
}

void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm,
mscclpp::ProxyService& channelService, int* data_d, size_t dataSize) {
void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::ProxyService& proxyService,
int* data_d, size_t dataSize) {
int thisNode = rankToNode(rank);
int cudaNum = rankToLocalRank(rank);
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
Expand All @@ -226,7 +226,7 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
transport = ibTransport;
}
// Connect with all other ranks
semaphoreIds.push_back(channelService.buildAndAddSemaphore(comm.connectOnSetup(r, 0, transport)));
semaphoreIds.push_back(proxyService.buildAndAddSemaphore(comm.connectOnSetup(r, 0, transport)));
auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport);
localMemories.push_back(memory);
comm.sendMemoryOnSetup(memory, r, 0);
Expand All @@ -238,8 +238,8 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>> proxyChannels;
for (size_t i = 0; i < semaphoreIds.size(); ++i) {
proxyChannels.push_back(mscclpp::deviceHandle(mscclpp::SimpleProxyChannel(
channelService.proxyChannel(semaphoreIds[i]), channelService.addMemory(remoteMemories[i].get()),
channelService.addMemory(localMemories[i]))));
proxyService.proxyChannel(semaphoreIds[i]), proxyService.addMemory(remoteMemories[i].get()),
proxyService.addMemory(localMemories[i]))));
}

assert(proxyChannels.size() < sizeof(constProxyChans) / sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>));
Expand Down Expand Up @@ -396,16 +396,16 @@ int main(int argc, const char* argv[]) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
bootstrap->initialize(ip_port);
mscclpp::Communicator comm(bootstrap);
mscclpp::ProxyService channelService(comm);
mscclpp::ProxyService proxyService(comm);

if (rank == 0) printf("Initializing data for allgather test\n");
initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d);

if (rank == 0) printf("Setting up the connection in MSCCL++\n");
setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize);
setupMscclppConnections(rank, world_size, comm, proxyService, data_d, dataSize);

if (rank == 0) printf("Launching MSCCL++ proxy threads\n");
channelService.startProxy();
proxyService.startProxy();

if (rank == 0) printf("Testing the correctness of AllGather implementation\n");
cudaStream_t stream;
Expand Down Expand Up @@ -480,7 +480,7 @@ int main(int argc, const char* argv[]) {
bootstrap->allGather(tmp, sizeof(int));

if (rank == 0) printf("Stopping MSCCL++ proxy threads\n");
channelService.stopProxy();
proxyService.stopProxy();

} catch (std::exception& e) {
// todo: throw exceptions in the implementation and process them here
Expand Down
2 changes: 1 addition & 1 deletion test/mp_unit/mp_unit_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase {
void testPacketPingPong(bool useIbOnly);
void testPacketPingPongPerf(bool useIbOnly);

std::shared_ptr<mscclpp::ProxyService> channelService;
std::shared_ptr<mscclpp::ProxyService> proxyService;
};

class SmChannelOneToOneTest : public CommunicatorTestBase {
Expand Down
20 changes: 10 additions & 10 deletions test/mp_unit/proxy_channel_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ void ProxyChannelOneToOneTest::SetUp() {
// Use only two ranks
setNumRanksToUse(2);
CommunicatorTestBase::SetUp();
channelService = std::make_shared<mscclpp::ProxyService>(*communicator.get());
proxyService = std::make_shared<mscclpp::ProxyService>(*communicator.get());
}

void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); }
Expand Down Expand Up @@ -49,11 +49,11 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector<mscclpp::SimpleP

communicator->setup();

mscclpp::SemaphoreId cid = channelService->buildAndAddSemaphore(conn);
mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(conn);
communicator->setup();

proxyChannels.emplace_back(channelService->proxyChannel(cid), channelService->addMemory(remoteMemory.get()),
channelService->addMemory(sendBufRegMem));
proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()),
proxyService->addMemory(sendBufRegMem));
}
}

Expand Down Expand Up @@ -128,7 +128,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::SimpleProxyChannel>)));

channelService->startProxy();
proxyService->startProxy();

std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);

Expand All @@ -152,7 +152,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) {

EXPECT_EQ(*ret, 0);

channelService->stopProxy();
proxyService->stopProxy();
}

__device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer;
Expand Down Expand Up @@ -249,7 +249,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
mscclpp::DeviceSyncer syncer = {};
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));

channelService->startProxy();
proxyService->startProxy();

std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);

Expand Down Expand Up @@ -285,7 +285,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {

communicator->bootstrap()->barrier();

channelService->stopProxy();
proxyService->stopProxy();
}

void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
Expand Down Expand Up @@ -316,7 +316,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
mscclpp::DeviceSyncer syncer = {};
MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestProxyChansSyncer, &syncer, sizeof(mscclpp::DeviceSyncer)));

channelService->startProxy();
proxyService->startProxy();

auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info();
const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name());
Expand All @@ -341,7 +341,7 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
std::cout << testName << ": " << std::setprecision(4) << (float)timer.elapsed() / (float)nTries << " us/iter\n";
}

channelService->stopProxy();
proxyService->stopProxy();
}

TEST_F(ProxyChannelOneToOneTest, PacketPingPong) { testPacketPingPong(false); }
Expand Down
20 changes: 10 additions & 10 deletions test/mscclpp-test/allgather_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
nBlocksForLocalAllGather);
}

class AllGatherChannelService : public mscclpp::BaseProxyService {
class AllGatherproxyService : public mscclpp::BaseProxyService {
public:
AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank, int cudaDevice);
AllGatherproxyService(mscclpp::Communicator& communicator, int worldSize, int rank, int cudaDevice);
void startProxy() override { proxy_.start(); }
void stopProxy() override { proxy_.stop(); }
void setSendBytes(size_t sendBytes) { this->sendBytes_ = sendBytes; }
Expand Down Expand Up @@ -314,8 +314,8 @@ class AllGatherChannelService : public mscclpp::BaseProxyService {
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw);
};

AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communicator, int worldSize, int rank,
int cudaDevice)
AllGatherproxyService::AllGatherproxyService(mscclpp::Communicator& communicator, int worldSize, int rank,
int cudaDevice)
: communicator_(communicator),
worldSize_(worldSize),
sendBytes_(0),
Expand All @@ -327,7 +327,7 @@ AllGatherChannelService::AllGatherChannelService(mscclpp::Communicator& communic
numaBind(deviceNumaNode);
}) {}

mscclpp::ProxyHandlerResult AllGatherChannelService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
mscclpp::ProxyHandlerResult AllGatherproxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
size_t offset = rank_ * sendBytes_;
if (triggerRaw.fst != MAGIC) {
// this is not a valid trigger
Expand Down Expand Up @@ -432,7 +432,7 @@ void AllGatherTestColl::setupCollTest(size_t size) {
paramCount_ = base;
expectedCount_ = recvCount_;
if (isUsingHostOffload(kernelNum_)) {
auto service = std::dynamic_pointer_cast<AllGatherChannelService>(chanService_);
auto service = std::dynamic_pointer_cast<AllGatherproxyService>(chanService_);
service->setSendBytes(sendCount_ * typeSize_);
}
mscclpp::DeviceSyncer syncer = {};
Expand All @@ -459,7 +459,7 @@ class AllGatherTestEngine : public BaseTestEngine {
std::vector<void*> getSendBuff() override;
void* getRecvBuff() override;
void* getScratchBuff() override;
std::shared_ptr<mscclpp::BaseProxyService> createChannelService() override;
std::shared_ptr<mscclpp::BaseProxyService> createproxyService() override;

private:
void* getExpectedBuff() override;
Expand Down Expand Up @@ -492,7 +492,7 @@ void AllGatherTestEngine::setupConnections() {
CUDATHROW(cudaMemcpyToSymbol(constSmChans, smChannelHandles.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * smChannelHandles.size()));
} else {
auto service = std::dynamic_pointer_cast<AllGatherChannelService>(chanService_);
auto service = std::dynamic_pointer_cast<AllGatherproxyService>(chanService_);
setupMeshConnections(devProxyChannels, sendBuff_.get(), args_.maxBytes, nullptr, 0,
[&](std::vector<std::shared_ptr<mscclpp::Connection>> conns,
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>>& remoteMemories,
Expand All @@ -512,9 +512,9 @@ void AllGatherTestEngine::setupConnections() {
}
}

std::shared_ptr<mscclpp::BaseProxyService> AllGatherTestEngine::createChannelService() {
std::shared_ptr<mscclpp::BaseProxyService> AllGatherTestEngine::createproxyService() {
if (isUsingHostOffload(args_.kernelNum)) {
return std::make_shared<AllGatherChannelService>(*comm_, args_.totalRanks, args_.rank, args_.gpuNum);
return std::make_shared<AllGatherproxyService>(*comm_, args_.totalRanks, args_.rank, args_.gpuNum);
} else {
return std::make_shared<mscclpp::ProxyService>(*comm_);
}
Expand Down
4 changes: 2 additions & 2 deletions test/mscclpp-test/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void BaseTestEngine::bootstrap() {
}

void BaseTestEngine::setupTest() {
this->chanService_ = this->createChannelService();
this->chanService_ = this->createproxyService();
this->setupConnections();
this->chanService_->startProxy();
this->coll_->setChanService(this->chanService_);
Expand All @@ -357,7 +357,7 @@ size_t BaseTestEngine::checkData() {
return nErrors;
}

std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createChannelService() {
std::shared_ptr<mscclpp::BaseProxyService> BaseTestEngine::createproxyService() {
return std::make_shared<mscclpp::ProxyService>(*comm_);
}

Expand Down
2 changes: 1 addition & 1 deletion test/mscclpp-test/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class BaseTestEngine {

private:
virtual void setupConnections() = 0;
virtual std::shared_ptr<mscclpp::BaseProxyService> createChannelService();
virtual std::shared_ptr<mscclpp::BaseProxyService> createproxyService();
virtual void* getExpectedBuff() = 0;

double benchTime();
Expand Down

0 comments on commit b21f068

Please sign in to comment.