Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Jul 28, 2023
1 parent b9ec5a6 commit b82a86a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
9 changes: 4 additions & 5 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <mscclpp/fifo.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/semaphore.hpp>
#include <unordered_map>

namespace mscclpp {

Expand Down Expand Up @@ -41,10 +40,10 @@ class ProxyService : public BaseProxyService {
/// @return The ID of the semaphore.
SemaphoreId addSemaphore(std::shared_ptr<Connection> connection);

/// Add a pitch pair to the proxy service.
/// @param id The ID of the semaphore.
/// Add a 2D channel to the proxy service.
/// @param connection The connection associated with the channel.
/// @param pitch The pitch pair.
void addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch);
SemaphoreId add2DChannel(std::shared_ptr<Connection> connection, std::pair<uint64_t, uint64_t> pitch);

/// Register a memory region with the proxy service.
/// @param memory The memory region to register.
Expand All @@ -71,7 +70,7 @@ class ProxyService : public BaseProxyService {
Communicator& communicator_;
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
std::unordered_map<SemaphoreId, std::pair<uint64_t, uint64_t>> pitches_;
std::vector<std::pair<uint64_t, uint64_t>> pitches_;
Proxy proxy_;
int deviceNumaNode;

Expand Down
7 changes: 6 additions & 1 deletion src/proxy_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Connectio
return semaphores_.size() - 1;
}

MSCCLPP_API_CPP void ProxyService::addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch) {
MSCCLPP_API_CPP SemaphoreId ProxyService::add2DChannel(std::shared_ptr<Connection> connection,
std::pair<uint64_t, uint64_t> pitch) {
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator_, connection));
SemaphoreId id = semaphores_.size() - 1;
if (id >= pitches_.size()) pitches_.resize(id + 1, std::pair<uint64_t, uint64_t>(0, 0));
pitches_[id] = pitch;
return id;
}

MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
Expand Down
3 changes: 1 addition & 2 deletions test/mp_unit/proxy_channel_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections(

communicator->setup();

mscclpp::SemaphoreId cid = channelService->addSemaphore(conn);
channelService->addPitch(cid, std::pair<size_t, size_t>(pitch, pitch));
mscclpp::SemaphoreId cid = channelService->add2DChannel(conn, std::pair<size_t, size_t>(pitch, pitch));
communicator->setup();

proxyChannels.emplace_back(mscclpp::deviceHandle(
Expand Down

0 comments on commit b82a86a

Please sign in to comment.