diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 663ee636ed..f04d9d159a 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -200,8 +200,6 @@ struct pp_packet { }; struct mcast_ctx { - struct ibv_qp *qp; - struct ibv_ah *ah; struct ibv_send_wr swr; struct ibv_sge ssg; @@ -310,8 +308,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { ucc_rank_t commsize; char *grh_buf; struct ibv_mr *grh_mr; - uint16_t mcast_lid; - union ibv_gid mgid; + uint16_t *lid_list; + union ibv_gid *mgid_list; unsigned max_inline; size_t max_eager; int max_per_packet; @@ -334,7 +332,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { int comm_id; void *p2p_ctx; ucc_base_lib_t *lib; - struct sockaddr_in6 mcast_addr; + struct sockaddr_in6 *mcast_addr_list; int cuda_mem_enabled; ucc_tl_mlx5_mcast_join_info_t *group_setup_info; ucc_service_coll_req_t *group_setup_info_req; @@ -490,7 +488,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast } if (i != 0) { rwr[i-1].next = NULL; - if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) { + if (ibv_post_recv(comm->mcast.qp_list[0], &rwr[0], &bad_wr)) { tl_error(comm->lib, "failed to prepost recvs: errno %d", errno); return UCC_ERR_NO_RESOURCE; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index a116c08cf8..a3016a786b 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -282,10 +282,19 @@ ucc_status_t ucc_tl_mlx5_setup_mcast_group_join_post(ucc_tl_mlx5_mcast_coll_comm ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_coll_comm_t *comm) { - struct ibv_qp_init_attr qp_init_attr = {0}; + int max_inline = INT_MAX; + struct ibv_qp_init_attr qp_init_attr = {0}; + int i; + int j; + + comm->mcast.qp_list = 
ucc_malloc(comm->mcast_group_count * sizeof(struct ibv_qp *), "ibv_qp* list"); + if (!comm->mcast.qp_list) { + tl_error(ctx->lib, "failed to allocate memory for ibv_qp*"); + return UCC_ERR_NO_MEMORY; + } qp_init_attr.qp_type = IBV_QPT_UD; - qp_init_attr.send_cq = comm->scq; + qp_init_attr.send_cq = comm->scq; //cq can be shared between multiple QPs qp_init_attr.recv_cq = comm->rcq; qp_init_attr.sq_sig_all = 0; qp_init_attr.cap.max_send_wr = comm->params.sx_depth; @@ -294,41 +303,78 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, qp_init_attr.cap.max_send_sge = comm->params.sx_sge; qp_init_attr.cap.max_recv_sge = comm->params.rx_sge; - comm->mcast.qp = ibv_create_qp(ctx->pd, &qp_init_attr); - if (!comm->mcast.qp) { - tl_warn(ctx->lib, "failed to create mcast qp, errno %d", errno); - return UCC_ERR_NO_RESOURCE; + for (i = 0; i < comm->mcast_group_count; i++) { + comm->mcast.qp_list[i] = ibv_create_qp(ctx->pd, &qp_init_attr); + if (!comm->mcast.qp_list[i]) { + tl_error(ctx->lib, "Failed to create mcast UD qp index %d, errno %d", i, errno); + goto error; + } + if (qp_init_attr.cap.max_inline_data < max_inline) { + max_inline = qp_init_attr.cap.max_inline_data; + } } if (comm->cuda_mem_enabled) { /* max inline send otherwise it segfault during ibv send */ comm->max_inline = 0; } else { - comm->max_inline = qp_init_attr.cap.max_inline_data; + comm->max_inline = max_inline; } return UCC_OK; + +error: + for (j = 0; j < i; j++) { + ibv_destroy_qp(comm->mcast.qp_list[j]); + } + ucc_free(comm->mcast.qp_list); + comm->mcast.qp_list = NULL; + + return UCC_ERR_NO_RESOURCE; } static ucc_status_t ucc_tl_mlx5_mcast_create_ah(ucc_tl_mlx5_mcast_coll_comm_t *comm) { + int i, j, ret; struct ibv_ah_attr ah_attr = { .is_global = 1, .grh = {.sgid_index = 0}, - .dlid = comm->mcast_lid, .sl = DEF_SL, .src_path_bits = DEF_SRC_PATH_BITS, .port_num = comm->ctx->ib_port }; - memcpy(ah_attr.grh.dgid.raw, &comm->mgid, sizeof(ah_attr.grh.dgid.raw)); + 
comm->mcast.ah_list = ucc_malloc(comm->mcast_group_count * sizeof(struct ibv_ah *), "ibv_ah array"); + if (!comm->mcast.ah_list) { + tl_error(comm->lib, "failed to allocate memory for mcast address handle of size %zu", + comm->mcast_group_count * sizeof(struct ibv_ah *)); + return UCC_ERR_NO_MEMORY; + } - comm->mcast.ah = ibv_create_ah(comm->ctx->pd, &ah_attr); - if (!comm->mcast.ah) { - tl_warn(comm->lib, "failed to create AH"); - return UCC_ERR_NO_RESOURCE; + for (i = 0; i < comm->mcast_group_count; i ++) { + ah_attr.dlid = comm->lid_list[i]; + memcpy(ah_attr.grh.dgid.raw, &comm->mgid_list[i], sizeof(ah_attr.grh.dgid.raw)); + + comm->mcast.ah_list[i] = ibv_create_ah(comm->ctx->pd, &ah_attr); + if (!comm->mcast.ah_list[i]) { + tl_error(comm->lib, "failed to create AH index %d", i); + goto error; + } } + return UCC_OK; + +error: + for (j = 0; j < i; j++) { + ret = ibv_destroy_ah(comm->mcast.ah_list[j]); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + /* keep unwinding: the remaining handles and the array must still be released */ + } + } + ucc_free(comm->mcast.ah_list); + comm->mcast.ah_list = NULL; + return UCC_ERR_NO_RESOURCE; } ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, @@ -337,16 +383,15 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, struct ibv_port_attr port_attr; struct ibv_qp_attr attr; uint16_t pkey; + int i; ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr); - for (ctx->pkey_index = 0; ctx->pkey_index < port_attr.pkey_tbl_len; ++ctx->pkey_index) { ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); if (pkey == DEF_PKEY) break; } - if (ctx->pkey_index >= port_attr.pkey_tbl_len) { ctx->pkey_index = 0; ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); @@ -359,43 +404,53 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, "index 0 pkey:0x%04x", DEF_PKEY, ctx->ib_port, pkey); } - attr.qp_state = IBV_QPS_INIT; - attr.pkey_index = ctx->pkey_index; - attr.port_num = 
ctx->ib_port; - attr.qkey = DEF_QKEY; + for (i = 0; i < comm->mcast_group_count; i++) { + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = ctx->pkey_index; + attr.port_num = ctx->ib_port; + attr.qkey = DEF_QKEY; - if (ibv_modify_qp(comm->mcast.qp, &attr, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { - tl_warn(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { + tl_error(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); + goto error; + } - if (ibv_attach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid)) { - tl_warn(ctx->lib, "failed to attach QP to the mcast group, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + if (ibv_attach_mcast(comm->mcast.qp_list[i], &comm->mgid_list[i], comm->lid_list[i])) { + tl_error(ctx->lib, "failed to attach QP to the mcast group with mcast_lid %d , errno %d", comm->lid_list[i], errno); + goto error; + } - /* Ok, now cycle to RTR on everyone */ - attr.qp_state = IBV_QPS_RTR; - if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE)) { - tl_warn(ctx->lib, "failed to modify QP to RTR, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + attr.qp_state = IBV_QPS_RTR; + if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, IBV_QP_STATE)) { + tl_error(ctx->lib, "failed to modify QP to RTR, errno %d", errno); + goto error; + } - attr.qp_state = IBV_QPS_RTS; - attr.sq_psn = DEF_PSN; - if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) { - tl_warn(ctx->lib, "failed to modify QP to RTS, errno %d", errno); - return UCC_ERR_NO_RESOURCE; + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = DEF_PSN; + if (ibv_modify_qp(comm->mcast.qp_list[i], &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) { + tl_error(ctx->lib, "failed to modify QP to RTS, errno %d", errno); + goto error; + } } - /* Create the address handle */ + /* create the address handle */ 
if (UCC_OK != ucc_tl_mlx5_mcast_create_ah(comm)) { tl_warn(ctx->lib, "failed to create adress handle"); - return UCC_ERR_NO_RESOURCE; + goto error; } return UCC_OK; + +error: + for (i=0; i < comm->mcast_group_count; i++) { + ibv_destroy_qp(comm->mcast.qp_list[i]); + } + ucc_free(comm->mcast.qp_list); + comm->mcast.qp_list = NULL; + + return UCC_ERR_NO_RESOURCE; } ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, @@ -538,7 +593,7 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, char buf[40]; const char *dst; - dst = inet_ntop(AF_INET6, &comm->mcast_addr, buf, 40); + dst = inet_ntop(AF_INET6, &comm->mcast_addr_list[0].sin6_addr, buf, 40); if (NULL == dst) { tl_error(comm->lib, "inet_ntop failed"); return UCC_ERR_NO_RESOURCE; @@ -546,7 +601,7 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, tl_debug(ctx->lib, "mcast leave: ctx %p, comm %p, dgid: %s", ctx, comm, buf); - if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr)) { + if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr_list[0])) { tl_error(comm->lib, "mcast rmda_leave_multicast failed"); return UCC_ERR_NO_RESOURCE; } @@ -559,11 +614,10 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context); ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast); ucc_context_h context = mlx5_ctx->super.super.ucc_context; - int ret; + int ret, i; ucc_status_t status; - tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x", - comm, comm->comm_id, comm->mcast_lid); + tl_debug(comm->lib, "cleaning mcast comm: %p, id %d", comm, comm->comm_id); while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) { ucc_context_progress(context); @@ -575,20 +629,26 @@ ucc_status_t 
ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) return status; } - if (comm->mcast.qp) { - ret = ibv_detach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid); - if (ret) { - tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); - return UCC_ERR_NO_RESOURCE; - } - } + if (comm->mcast.qp_list) { + for (i = 0; i < comm->mcast_group_count; i++) { + if (comm->mcast.qp_list[i]) { + ret = ibv_detach_mcast(comm->mcast.qp_list[i], &(comm->mgid_list[i]), comm->lid_list[i]); + if (ret) { + tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); + return UCC_ERR_NO_RESOURCE; + } - if (comm->mcast.qp) { - ret = ibv_destroy_qp(comm->mcast.qp); - if (ret) { - tl_error(comm->lib, "failed to destroy QP %d", ret); - return UCC_ERR_NO_RESOURCE; + ret = ibv_destroy_qp(comm->mcast.qp_list[i]); + if (ret) { + tl_error(comm->lib, "failed to destroy QP %d", ret); + return UCC_ERR_NO_RESOURCE; + } + + comm->mcast.qp_list[i] = NULL; + } } + ucc_free(comm->mcast.qp_list); + comm->mcast.qp_list = NULL; } if (comm->rcq) { @@ -643,20 +703,33 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) ucc_free(comm->call_rsgs); } - if (comm->mcast.ah) { - ret = ibv_destroy_ah(comm->mcast.ah); - if (ret) { - tl_error(comm->lib, "couldn't destroy ah"); - return UCC_ERR_NO_RESOURCE; + if (comm->mcast.ah_list) { + for (i = 0; i < comm->mcast_group_count; i++) { + if (comm->mcast.ah_list[i]) { + ret = ibv_destroy_ah(comm->mcast.ah_list[i]); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + return UCC_ERR_NO_RESOURCE; + } + comm->mcast.ah_list[i] = NULL; + } } + ucc_free(comm->mcast.ah_list); + comm->mcast.ah_list = NULL; } - if (comm->mcast_lid) { + if (comm->lid_list) { status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm); if (status) { tl_error(comm->lib, "couldn't leave mcast group"); return status; } + ucc_free(comm->lid_list); + ucc_free(comm->mgid_list); + ucc_free(comm->mcast_addr_list); + comm->lid_list 
= NULL; + comm->mgid_list = NULL; + comm->mcast_addr_list = NULL; } if (comm->ctx->params.print_nack_stats) { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index d0b1a1ddd3..138a3f57ad 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -108,7 +108,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t tl_trace(comm->lib, "post_send, psn %d, length %d, zcopy %d, signaled %d", pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED); - if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) { + if (0 != (rc = ibv_post_send(comm->mcast.qp_list[0], &swr[0], &bad_wr))) { tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " "to_recv %d, length %d, psn %d, inline %d", rc, req->start_psn, req->to_send, req->to_recv, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index b70ca6e2f6..23948c92e4 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -150,6 +150,30 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, comm->r_window[i] = &comm->dummy_packet; } + comm->mcast_addr_list = ucc_calloc(comm->mcast_group_count, sizeof(struct sockaddr_in6), + "mcast_addr_list"); + if (!comm->mcast_addr_list) { + tl_error(mcast_context->lib, "unable to allocate memory for mcast net address"); + status = UCC_ERR_NO_MEMORY; + goto cleanup; + } + + comm->lid_list = ucc_calloc(comm->mcast_group_count, sizeof(uint16_t), + "lid_list"); + if (!comm->lid_list) { + tl_error(mcast_context->lib, "unable to allocate memory for mcast lid"); + status = UCC_ERR_NO_MEMORY; + goto cleanup; + } + + comm->mgid_list = ucc_calloc(comm->mcast_group_count, sizeof(union ibv_gid), + "mgid_list"); + if (!comm->mgid_list) { + 
tl_error(mcast_context->lib, "unable to allocate memory for mcast mgid"); + status = UCC_ERR_NO_MEMORY; + goto cleanup; + } + comm->lib = base_context->lib; new_mcast_team->mcast_comm = comm; *mcast_team = new_mcast_team; @@ -159,6 +183,9 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, return UCC_OK; cleanup: + ucc_free(comm->mcast_addr_list); + ucc_free(comm->lid_list); + ucc_free(comm->mgid_list); ucc_free(comm); ucc_free(new_mcast_team); ucc_free(oob_p2p_ctx); @@ -263,7 +290,7 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_ ucc_list_add_tail(&comm->bpool, &comm->pp[i].super); } - comm->mcast.swr.wr.ud.ah = comm->mcast.ah; + comm->mcast.swr.wr.ud.ah = comm->mcast.ah_list[0]; comm->mcast.swr.num_sge = 1; comm->mcast.swr.sg_list = &comm->mcast.ssg; comm->mcast.swr.opcode = IBV_WR_SEND_WITH_IMM; @@ -325,8 +352,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return UCC_INPROGRESS; } - comm->mcast_addr = net_addr; - tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + comm->mcast_addr_list[0] = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; return UCC_INPROGRESS; } @@ -373,11 +400,11 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY) { /* rank 0 bcast the lid/gid to other processes */ - data->status = UCC_OK; - data->dgid = comm->event->param.ud.ah_attr.grh.dgid; - data->dlid = comm->event->param.ud.ah_attr.dlid; - comm->mcast_lid = data->dlid; - comm->mgid = data->dgid; + data->status = UCC_OK; + data->dgid = comm->event->param.ud.ah_attr.grh.dgid; + data->dlid = comm->event->param.ud.ah_attr.dlid; + comm->lid_list[0] = data->dlid; + comm->mgid_list[0] = data->dgid; } else { /* rank 0 bcast the failed status to other processes so others do not hang */ data->status = UCC_ERR_NO_RESOURCE; @@ -522,8 +549,8 @@ ucc_status_t 
ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return status; } - comm->mcast_addr = net_addr; - tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + comm->mcast_addr_list[0] = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; return UCC_INPROGRESS; } @@ -549,8 +576,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) ucc_assert(comm->event != NULL); - comm->mcast_lid = comm->group_setup_info->dlid; - comm->mgid = comm->group_setup_info->dgid; + comm->lid_list[0] = comm->group_setup_info->dlid; + comm->mgid_list[0] = comm->group_setup_info->dgid; ucc_free(comm->group_setup_info); if (comm->event) {