Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: mcast multi-group support part 1 #1060

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ struct pp_packet {
};

struct mcast_ctx {
struct ibv_qp *qp;
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;

Expand Down Expand Up @@ -310,8 +308,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
uint16_t *lid_list;
union ibv_gid *mgid_list;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
Expand All @@ -334,7 +332,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
struct sockaddr_in6 *mcast_addr_list;
int cuda_mem_enabled;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
Expand Down Expand Up @@ -490,7 +488,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
}
if (i != 0) {
rwr[i-1].next = NULL;
if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) {
if (ibv_post_recv(comm->mcast.qp_list[0], &rwr[0], &bad_wr)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hardcoded index to be replaced in part2

tl_error(comm->lib, "failed to prepost recvs: errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
Expand Down
201 changes: 137 additions & 64 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,19 @@ ucc_status_t ucc_tl_mlx5_setup_mcast_group_join_post(ucc_tl_mlx5_mcast_coll_comm
ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
struct ibv_qp_init_attr qp_init_attr = {0};
int max_inline = INT_MAX;
struct ibv_qp_init_attr qp_init_attr = {0};
int i;
int j;

comm->mcast.qp_list = ucc_malloc(comm->mcast_group_count * sizeof(struct ibv_qp *), "ibv_qp* list");
if (!comm->mcast.qp_list) {
tl_error(ctx->lib, "failed to allocate memory for ibv_qp*");
return UCC_ERR_NO_MEMORY;
}

qp_init_attr.qp_type = IBV_QPT_UD;
qp_init_attr.send_cq = comm->scq;
qp_init_attr.send_cq = comm->scq; //cq can be shared between multiple QPs
qp_init_attr.recv_cq = comm->rcq;
qp_init_attr.sq_sig_all = 0;
qp_init_attr.cap.max_send_wr = comm->params.sx_depth;
Expand All @@ -294,41 +303,78 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
qp_init_attr.cap.max_send_sge = comm->params.sx_sge;
qp_init_attr.cap.max_recv_sge = comm->params.rx_sge;

comm->mcast.qp = ibv_create_qp(ctx->pd, &qp_init_attr);
if (!comm->mcast.qp) {
tl_warn(ctx->lib, "failed to create mcast qp, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
for (i = 0; i < comm->mcast_group_count; i++) {
comm->mcast.qp_list[i] = ibv_create_qp(ctx->pd, &qp_init_attr);
if (!comm->mcast.qp_list[i]) {
tl_error(ctx->lib, "Failed to create mcast UD qp index %d, errno %d", i, errno);
goto error;
}
if (qp_init_attr.cap.max_inline_data < max_inline) {
max_inline = qp_init_attr.cap.max_inline_data;
}
}

if (comm->cuda_mem_enabled) {
/* max inline send otherwise it segfault during ibv send */
comm->max_inline = 0;
} else {
comm->max_inline = qp_init_attr.cap.max_inline_data;
comm->max_inline = max_inline;
}

return UCC_OK;

error:
for (j = 0; j < i; j++) {
ibv_destroy_qp(comm->mcast.qp_list[j]);
}
ucc_free(comm->mcast.qp_list);
comm->mcast.qp_list = NULL;

return UCC_ERR_NO_RESOURCE;
}

static ucc_status_t ucc_tl_mlx5_mcast_create_ah(ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
int i, j, ret;
struct ibv_ah_attr ah_attr = {
.is_global = 1,
.grh = {.sgid_index = 0},
.dlid = comm->mcast_lid,
.sl = DEF_SL,
.src_path_bits = DEF_SRC_PATH_BITS,
.port_num = comm->ctx->ib_port
};

memcpy(ah_attr.grh.dgid.raw, &comm->mgid, sizeof(ah_attr.grh.dgid.raw));
comm->mcast.ah_list = ucc_malloc(comm->mcast_group_count * sizeof(struct ibv_ah *), "ibv_ah array");
if (!comm->mcast.ah_list) {
tl_error(comm->lib, "failed to allocate memory for mcast address handle of size %lu",
comm->mcast_group_count * sizeof(struct ibv_ah *));
return UCC_ERR_NO_MEMORY;
}

comm->mcast.ah = ibv_create_ah(comm->ctx->pd, &ah_attr);
if (!comm->mcast.ah) {
tl_warn(comm->lib, "failed to create AH");
return UCC_ERR_NO_RESOURCE;
for (i = 0; i < comm->mcast_group_count; i ++) {
ah_attr.dlid = comm->lid_list[i];
memcpy(ah_attr.grh.dgid.raw, &comm->mgid_list[i], sizeof(ah_attr.grh.dgid.raw));

comm->mcast.ah_list[i] = ibv_create_ah(comm->ctx->pd, &ah_attr);
if (!comm->mcast.ah_list[i]) {
tl_error(comm->lib, "failed to create AH index %d", i);
goto error;
}
}

return UCC_OK;

error:
for (j = 0; j < i; j++) {
ret = ibv_destroy_ah(comm->mcast.ah_list[j]);
if (ret) {
tl_error(comm->lib, "couldn't destroy ah");
return UCC_ERR_NO_RESOURCE;
}
}
ucc_free(comm->mcast.ah_list);
comm->mcast.ah_list = NULL;
return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
Expand All @@ -337,16 +383,15 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
struct ibv_port_attr port_attr;
struct ibv_qp_attr attr;
uint16_t pkey;
int i;

ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr);

for (ctx->pkey_index = 0; ctx->pkey_index < port_attr.pkey_tbl_len;
++ctx->pkey_index) {
ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey);
if (pkey == DEF_PKEY)
break;
}

if (ctx->pkey_index >= port_attr.pkey_tbl_len) {
ctx->pkey_index = 0;
ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey);
Expand All @@ -359,43 +404,53 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
"index 0 pkey:0x%04x", DEF_PKEY, ctx->ib_port, pkey);
}

attr.qp_state = IBV_QPS_INIT;
attr.pkey_index = ctx->pkey_index;
attr.port_num = ctx->ib_port;
attr.qkey = DEF_QKEY;
for (i = 0; i < comm->mcast_group_count; i++) {
attr.qp_state = IBV_QPS_INIT;
attr.pkey_index = ctx->pkey_index;
attr.port_num = ctx->ib_port;
attr.qkey = DEF_QKEY;

if (ibv_modify_qp(comm->mcast.qp, &attr,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) {
tl_warn(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
if (ibv_modify_qp(comm->mcast.qp_list[i], &attr,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) {
tl_error(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno);
goto error;
}

if (ibv_attach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid)) {
tl_warn(ctx->lib, "failed to attach QP to the mcast group, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
if (ibv_attach_mcast(comm->mcast.qp_list[i], &comm->mgid_list[i], comm->lid_list[i])) {
tl_error(ctx->lib, "failed to attach QP to the mcast group with mcast_lid %d , errno %d", errno, comm->lid_list[i]);
goto error;
}

/* Ok, now cycle to RTR on everyone */
attr.qp_state = IBV_QPS_RTR;
if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE)) {
tl_warn(ctx->lib, "failed to modify QP to RTR, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
attr.qp_state = IBV_QPS_RTR;
if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, IBV_QP_STATE)) {
tl_error(ctx->lib, "failed to modify QP to RTR, errno %d", errno);
goto error;
}

attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = DEF_PSN;
if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) {
tl_warn(ctx->lib, "failed to modify QP to RTS, errno %d", errno);
return UCC_ERR_NO_RESOURCE;
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = DEF_PSN;
if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) {
tl_error(ctx->lib, "failed to modify QP to RTS, errno %d", errno);
goto error;
}
}

/* Create the address handle */
/* create the address handle */
if (UCC_OK != ucc_tl_mlx5_mcast_create_ah(comm)) {
tl_warn(ctx->lib, "failed to create adress handle");
return UCC_ERR_NO_RESOURCE;
goto error;
}

return UCC_OK;

error:
for (i=0; i < comm->mcast_group_count; i++) {
ibv_destroy_qp(comm->mcast.qp_list[i]);
}
ucc_free(comm->mcast.qp_list);
comm->mcast.qp_list = NULL;

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
Expand Down Expand Up @@ -538,15 +593,15 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx,
char buf[40];
const char *dst;

dst = inet_ntop(AF_INET6, &comm->mcast_addr, buf, 40);
dst = inet_ntop(AF_INET6, &comm->mcast_addr_list[0], buf, 40);
if (NULL == dst) {
tl_error(comm->lib, "inet_ntop failed");
return UCC_ERR_NO_RESOURCE;
}

tl_debug(ctx->lib, "mcast leave: ctx %p, comm %p, dgid: %s", ctx, comm, buf);

if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr)) {
if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr_list[0])) {
tl_error(comm->lib, "mcast rmda_leave_multicast failed");
return UCC_ERR_NO_RESOURCE;
}
Expand All @@ -559,11 +614,10 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context);
ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast);
ucc_context_h context = mlx5_ctx->super.super.ucc_context;
int ret;
int ret, i;
ucc_status_t status;

tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x",
comm, comm->comm_id, comm->mcast_lid);
tl_debug(comm->lib, "cleaning mcast comm: %p, id %d", comm, comm->comm_id);

while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) {
ucc_context_progress(context);
Expand All @@ -575,20 +629,26 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
return status;
}

if (comm->mcast.qp) {
ret = ibv_detach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid);
if (ret) {
tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno);
return UCC_ERR_NO_RESOURCE;
}
}
if (comm->mcast.qp_list) {
for (i = 0; i < comm->mcast_group_count; i++) {
if (comm->mcast.qp_list[i]) {
ret = ibv_detach_mcast(comm->mcast.qp_list[i], &(comm->mgid_list[i]), comm->lid_list[i]);
if (ret) {
tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno);
return UCC_ERR_NO_RESOURCE;
}

if (comm->mcast.qp) {
ret = ibv_destroy_qp(comm->mcast.qp);
if (ret) {
tl_error(comm->lib, "failed to destroy QP %d", ret);
return UCC_ERR_NO_RESOURCE;
ret = ibv_destroy_qp(comm->mcast.qp_list[i]);
if (ret) {
tl_error(comm->lib, "failed to destroy QP %d", ret);
return UCC_ERR_NO_RESOURCE;
}

comm->mcast.qp_list[i] = NULL;
}
}
ucc_free(comm->mcast.qp_list);
comm->mcast.qp_list = NULL;
}

if (comm->rcq) {
Expand Down Expand Up @@ -643,20 +703,33 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
ucc_free(comm->call_rsgs);
}

if (comm->mcast.ah) {
ret = ibv_destroy_ah(comm->mcast.ah);
if (ret) {
tl_error(comm->lib, "couldn't destroy ah");
return UCC_ERR_NO_RESOURCE;
if (comm->mcast.ah_list) {
for (i = 0; i < comm->mcast_group_count; i++) {
if (comm->mcast.ah_list[i]) {
ret = ibv_destroy_ah(comm->mcast.ah_list[i]);
if (ret) {
tl_error(comm->lib, "couldn't destroy ah");
return UCC_ERR_NO_RESOURCE;
}
comm->mcast.ah_list[i] = NULL;
}
}
ucc_free(comm->mcast.ah_list);
comm->mcast.ah_list = NULL;
}

if (comm->mcast_lid) {
if (comm->lid_list) {
status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm);
if (status) {
tl_error(comm->lib, "couldn't leave mcast group");
return status;
}
ucc_free(comm->lid_list);
ucc_free(comm->mgid_list);
ucc_free(comm->mcast_addr_list);
comm->lid_list = NULL;
comm->lid_list = NULL;
comm->mcast_addr_list = NULL;
}

if (comm->ctx->params.print_nack_stats) {
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
tl_trace(comm->lib, "post_send, psn %d, length %d, zcopy %d, signaled %d",
pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED);

if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) {
if (0 != (rc = ibv_post_send(comm->mcast.qp_list[0], &swr[0], &bad_wr))) {
tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, "
"to_recv %d, length %d, psn %d, inline %d",
rc, req->start_psn, req->to_send, req->to_recv,
Expand Down
Loading
Loading