Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 6, 2023
1 parent ad13693 commit 89cad56
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 52 deletions.
9 changes: 3 additions & 6 deletions src/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,9 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
}

MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport,
int ibMaxCqSize /*=1024*/,
int ibMaxCqPollNum /*=1*/,
int ibMaxSendWr /*=8192*/,
int ibMaxWrPerSend /*=64*/,
int ibMaxNumSgesPerWr /*=16*/) {
MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(
int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/,
int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=16*/) {
std::shared_ptr<ConnectionBase> conn;
if (transport == Transport::CudaIpc) {
// sanity check: make sure the IPC connection is being made within a node
Expand Down
95 changes: 51 additions & 44 deletions src/ib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "api.h"
#include "debug.h"

static ibv_device_attr getDeviceAttr(ibv_context *ctx) {
static ibv_device_attr getDeviceAttr(ibv_context* ctx) {
ibv_device_attr devAttr;
if (ibv_query_device(ctx, &devAttr) != 0) {
std::stringstream err;
Expand All @@ -26,6 +26,12 @@ static ibv_device_attr getDeviceAttr(ibv_context *ctx) {
return devAttr;
}

// Builds an ibv_qp_attr whose storage has been cleared to zero bytes.
// Callers receive the struct by value and then assign the fields they need
// before handing it to the verbs API.
static ibv_qp_attr createQpAttr() {
  ibv_qp_attr zeroedAttr;
  // Clear the whole struct so no field starts out with an indeterminate value.
  std::memset(&zeroedAttr, 0, sizeof(ibv_qp_attr));
  return zeroedAttr;
}

namespace mscclpp {

IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
Expand Down Expand Up @@ -115,8 +121,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
this->info.iid = gid.global.interface_id;
}

struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(qpAttr));
ibv_qp_attr qpAttr = createQpAttr();
qpAttr.qp_state = IBV_QPS_INIT;
qpAttr.pkey_index = 0;
qpAttr.port_num = port;
Expand All @@ -140,30 +145,29 @@ IbQp::~IbQp() {
}

void IbQp::rtr(const IbQpInfo& info) {
struct ibv_qp_attr qp_attr;
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
qp_attr.qp_state = IBV_QPS_RTR;
qp_attr.path_mtu = static_cast<ibv_mtu>(info.mtu);
qp_attr.dest_qp_num = info.qpn;
qp_attr.rq_psn = 0;
qp_attr.max_dest_rd_atomic = 1;
qp_attr.min_rnr_timer = 0x12;
ibv_qp_attr qpAttr = createQpAttr();
qpAttr.qp_state = IBV_QPS_RTR;
qpAttr.path_mtu = static_cast<ibv_mtu>(info.mtu);
qpAttr.dest_qp_num = info.qpn;
qpAttr.rq_psn = 0;
qpAttr.max_dest_rd_atomic = 1;
qpAttr.min_rnr_timer = 0x12;
if (info.linkLayer == IBV_LINK_LAYER_ETHERNET || info.is_grh) {
qp_attr.ah_attr.is_global = 1;
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
qp_attr.ah_attr.grh.dgid.global.interface_id = info.iid;
qp_attr.ah_attr.grh.flow_label = 0;
qp_attr.ah_attr.grh.sgid_index = 0;
qp_attr.ah_attr.grh.hop_limit = 255;
qp_attr.ah_attr.grh.traffic_class = 0;
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
qpAttr.ah_attr.grh.dgid.global.interface_id = info.iid;
qpAttr.ah_attr.grh.flow_label = 0;
qpAttr.ah_attr.grh.sgid_index = 0;
qpAttr.ah_attr.grh.hop_limit = 255;
qpAttr.ah_attr.grh.traffic_class = 0;
} else {
qp_attr.ah_attr.is_global = 0;
qpAttr.ah_attr.is_global = 0;
}
qp_attr.ah_attr.dlid = info.lid;
qp_attr.ah_attr.sl = 0;
qp_attr.ah_attr.src_path_bits = 0;
qp_attr.ah_attr.port_num = info.port;
int ret = ibv_modify_qp(this->qp, &qp_attr,
qpAttr.ah_attr.dlid = info.lid;
qpAttr.ah_attr.sl = 0;
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info.port;
int ret = ibv_modify_qp(this->qp, &qpAttr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
if (ret != 0) {
Expand All @@ -174,16 +178,15 @@ void IbQp::rtr(const IbQpInfo& info) {
}

void IbQp::rts() {
struct ibv_qp_attr qp_attr;
std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
qp_attr.qp_state = IBV_QPS_RTS;
qp_attr.timeout = 18;
qp_attr.retry_cnt = 7;
qp_attr.rnr_retry = 7;
qp_attr.sq_psn = 0;
qp_attr.max_rd_atomic = 1;
ibv_qp_attr qpAttr = createQpAttr();
qpAttr.qp_state = IBV_QPS_RTS;
qpAttr.timeout = 18;
qpAttr.retry_cnt = 7;
qpAttr.rnr_retry = 7;
qpAttr.sq_psn = 0;
qpAttr.max_rd_atomic = 1;
int ret = ibv_modify_qp(
this->qp, &qp_attr,
this->qp, &qpAttr,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC);
if (ret != 0) {
std::stringstream err;
Expand Down Expand Up @@ -257,8 +260,9 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
wrInfo.sge->lkey = mr->getLkey();
}

void IbQp::stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo, const std::vector<uint32_t>& srcSizes,
uint64_t wrId, const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled) {
void IbQp::stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dstInfo,
const std::vector<uint32_t>& srcSizes, uint64_t wrId,
const std::vector<uint64_t>& srcOffsets, uint64_t dstOffset, bool signaled) {
size_t numSrcs = srcMrs.size();
if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) {
std::stringstream err;
Expand All @@ -272,9 +276,11 @@ void IbQp::stageSendGather(const std::vector<IbMr*>& srcMrs, const IbMrInfo& dst
wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset;
wrInfo.wr->wr.rdma.rkey = dstInfo.rkey;
// wrInfo.sge->addr = (uint64_t)(mr->getBuff()) + srcOffset;
// wrInfo.sge->length = size;
// wrInfo.sge->lkey = mr->getLkey();
for (size_t i = 0; i < numSrcs; ++i) {
wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i];
wrInfo.sge[i].length = srcSizes[i];
wrInfo.sge[i].lkey = srcMrs[i]->getLkey();
}
}

void IbQp::postSend() {
Expand Down Expand Up @@ -367,8 +373,8 @@ int IbCtx::getAnyActivePort() const {
return -1;
}

void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
int port) const {
void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port) const {
if (!this->isPortUsable(port)) {
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
}
Expand All @@ -393,16 +399,17 @@ void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int m
}
}

IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
int port /*=-1*/) {
IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port /*=-1*/) {
if (port == -1) {
port = this->getAnyActivePort();
if (port == -1) {
throw mscclpp::Error("No active port found", ErrorCode::InternalError);
}
}
validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr));
this->validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend,
maxNumSgesPerWr));
return qps.back().get();
}

Expand Down
7 changes: 5 additions & 2 deletions src/include/ib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <list>
#include <memory>
#include <string>
#include <vector>

// Forward declarations of IB structures
struct ibv_context;
Expand Down Expand Up @@ -107,15 +108,17 @@ class IbCtx {
IbCtx(const std::string& devName);
~IbCtx();

IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1);
IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
int port = -1);
const IbMr* registerMr(void* buff, std::size_t size);

const std::string& getDevName() const;

private:
bool isPortUsable(int port) const;
int getAnyActivePort() const;
void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const;
void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
int maxNumSgesPerWr, int port) const;

const std::string devName;
ibv_context* ctx;
Expand Down

0 comments on commit 89cad56

Please sign in to comment.