diff --git a/src/communicator.cc b/src/communicator.cc index d5b49fa1..b87388b7 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -94,12 +94,9 @@ MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSe return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport, - int ibMaxCqSize /*=1024*/, - int ibMaxCqPollNum /*=1*/, - int ibMaxSendWr /*=8192*/, - int ibMaxWrPerSend /*=64*/, - int ibMaxNumSgesPerWr /*=16*/) { +MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup( + int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/, + int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=16*/) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node diff --git a/src/ib.cc b/src/ib.cc index 0dcb33fd..63eb1040 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -16,7 +16,7 @@ #include "api.h" #include "debug.h" -static ibv_device_attr getDeviceAttr(ibv_context *ctx) { +static ibv_device_attr getDeviceAttr(ibv_context* ctx) { ibv_device_attr devAttr; if (ibv_query_device(ctx, &devAttr) != 0) { std::stringstream err; @@ -26,6 +26,12 @@ static ibv_device_attr getDeviceAttr(ibv_context *ctx) { return devAttr; } +static ibv_qp_attr createQpAttr() { + ibv_qp_attr qpAttr; + std::memset(&qpAttr, 0, sizeof(qpAttr)); + return qpAttr; +} + namespace mscclpp { IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { @@ -115,8 +121,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN this->info.iid = gid.global.interface_id; } - struct ibv_qp_attr qpAttr; - memset(&qpAttr, 0, sizeof(qpAttr)); + ibv_qp_attr qpAttr = createQpAttr(); qpAttr.qp_state = IBV_QPS_INIT; qpAttr.pkey_index = 0; qpAttr.port_num = port; @@ -140,30 +145,29 @@ IbQp::~IbQp() { } void IbQp::rtr(const IbQpInfo& info) { - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_RTR; - qp_attr.path_mtu = static_cast(info.mtu); - qp_attr.dest_qp_num = info.qpn; - qp_attr.rq_psn = 0; - qp_attr.max_dest_rd_atomic = 1; - qp_attr.min_rnr_timer = 0x12; + ibv_qp_attr qpAttr = createQpAttr(); + qpAttr.qp_state = IBV_QPS_RTR; + qpAttr.path_mtu = static_cast(info.mtu); + qpAttr.dest_qp_num = info.qpn; + qpAttr.rq_psn = 0; + qpAttr.max_dest_rd_atomic = 1; + qpAttr.min_rnr_timer = 0x12; if (info.linkLayer == IBV_LINK_LAYER_ETHERNET || info.is_grh) { - qp_attr.ah_attr.is_global = 1; - qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; - qp_attr.ah_attr.grh.dgid.global.interface_id = info.iid; - qp_attr.ah_attr.grh.flow_label = 0; - qp_attr.ah_attr.grh.sgid_index = 0; - qp_attr.ah_attr.grh.hop_limit = 255; - qp_attr.ah_attr.grh.traffic_class = 0; + qpAttr.ah_attr.is_global = 1; + qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; + qpAttr.ah_attr.grh.dgid.global.interface_id = info.iid; + qpAttr.ah_attr.grh.flow_label = 0; + qpAttr.ah_attr.grh.sgid_index = 0; + qpAttr.ah_attr.grh.hop_limit = 255; + qpAttr.ah_attr.grh.traffic_class = 0; } else { - qp_attr.ah_attr.is_global = 0; + qpAttr.ah_attr.is_global = 0; } - qp_attr.ah_attr.dlid = info.lid; - qp_attr.ah_attr.sl = 0; - qp_attr.ah_attr.src_path_bits = 0; - qp_attr.ah_attr.port_num = info.port; - int ret = ibv_modify_qp(this->qp, &qp_attr, + qpAttr.ah_attr.dlid = info.lid; + qpAttr.ah_attr.sl = 0; + qpAttr.ah_attr.src_path_bits = 0; + qpAttr.ah_attr.port_num = info.port; + int ret = ibv_modify_qp(this->qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); if (ret != 0) { @@ -174,16 +178,15 @@ void IbQp::rtr(const IbQpInfo& info) { } void IbQp::rts() { - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_RTS; - qp_attr.timeout = 18; - qp_attr.retry_cnt = 7; - qp_attr.rnr_retry = 7; - qp_attr.sq_psn = 0; - qp_attr.max_rd_atomic = 1; + ibv_qp_attr qpAttr = createQpAttr(); + qpAttr.qp_state = IBV_QPS_RTS; + qpAttr.timeout = 18; + qpAttr.retry_cnt = 7; + qpAttr.rnr_retry = 7; + qpAttr.sq_psn = 0; + qpAttr.max_rd_atomic = 1; int ret = ibv_modify_qp( - this->qp, &qp_attr, + this->qp, &qpAttr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); if (ret != 0) { std::stringstream err; @@ -257,8 +260,9 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes, - uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { +void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, + const std::vector& srcSizes, uint64_t wrId, + const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { size_t numSrcs = srcMrs.size(); if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) { std::stringstream err; @@ -272,9 +276,11 @@ void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dst wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset; wrInfo.wr->wr.rdma.rkey = dstInfo.rkey; - // wrInfo.sge->addr = (uint64_t)(mr->getBuff()) + srcOffset; - // wrInfo.sge->length = size; - // wrInfo.sge->lkey = mr->getLkey(); + for (size_t i = 0; i < numSrcs; ++i) { + wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i]; + wrInfo.sge[i].length = srcSizes[i]; + wrInfo.sge[i].lkey = srcMrs[i]->getLkey(); + } } void IbQp::postSend() { @@ -367,8 +373,8 @@ int IbCtx::getAnyActivePort() const { return -1; } -void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, - int port) const { +void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const { if (!this->isPortUsable(port)) { throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage); } @@ -393,16 +399,17 @@ void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int m } } -IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, - int port /*=-1*/) { +IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port /*=-1*/) { if (port == -1) { port = this->getAnyActivePort(); if (port == -1) { throw mscclpp::Error("No active port found", ErrorCode::InternalError); } } - validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); - qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr)); + this->validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); + qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, + maxNumSgesPerWr)); return qps.back().get(); } diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 0126ef89..cb909111 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -7,6 +7,7 @@ #include #include #include +#include // Forward declarations of IB structures struct ibv_context; @@ -107,7 +108,8 @@ class IbCtx { IbCtx(const std::string& devName); ~IbCtx(); - IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1); + IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, + int port = -1); const IbMr* registerMr(void* buff, std::size_t size); const std::string& getDevName() const; @@ -115,7 +117,8 @@ class IbCtx { private: bool isPortUsable(int port) const; int getAnyActivePort() const; - void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const; + void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const; const std::string devName; ibv_context* ctx;