From be427296f217ab3011864985f78e1cdbd28a4185 Mon Sep 17 00:00:00 2001 From: Mamzi Bayatpour <77160721+MamziB@users.noreply.github.com> Date: Wed, 10 Apr 2024 17:14:19 -0700 Subject: [PATCH] TL/MLX5: enhance reliablity protocol in Mcast (#957) Co-authored-by: Manjunath Gorentla Venkata --- .../tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c | 36 +++---- src/components/tl/mlx5/mcast/tl_mlx5_mcast.h | 17 ++-- .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c | 23 +++-- .../tl/mlx5/mcast/tl_mlx5_mcast_progress.c | 93 +++++++++++-------- .../tl/mlx5/mcast/tl_mlx5_mcast_progress.h | 12 --- .../tl/mlx5/mcast/tl_mlx5_mcast_team.c | 1 + 6 files changed, 95 insertions(+), 87 deletions(-) diff --git a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c index d5c5d9dfb4..11be1473d2 100644 --- a/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c +++ b/src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c @@ -27,9 +27,9 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLIN static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t len, ucc_rank_t my_team_rank, ucc_rank_t dest, - ucc_team_h team, ucc_context_h ctx, - ucc_coll_callback_t *callback, - ucc_coll_req_h *p2p_req, int is_send) + ucc_team_h team, ucc_coll_callback_t *callback, + ucc_coll_req_h *p2p_req, int is_send, + ucc_base_lib_t *lib) { ucc_status_t status = UCC_OK; ucc_coll_req_h req = NULL; @@ -47,11 +47,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t args.cb.data = callback->data; args.active_set.size = 2; args.active_set.start = my_team_rank; - args.active_set.stride = dest - my_team_rank; + args.active_set.stride = (int)dest - (int)my_team_rank; status = ucc_collective_init(&args, &req, team); if (ucc_unlikely(UCC_OK != status)) { - tl_error(ctx->lib, "nonblocking p2p init failed"); + tl_error(lib, "nonblocking p2p init failed"); return status; } @@ -59,7 +59,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t status = ucc_collective_post(req); if (ucc_unlikely(UCC_OK != status)) { - tl_error(ctx->lib, "nonblocking p2p post failed"); + tl_error(lib, "nonblocking p2p post failed"); return status; } @@ -70,20 +70,20 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t my_team_rank, ucc_rank_t dest, ucc_team_h team, - ucc_context_h ctx, ucc_coll_callback_t - *callback, ucc_coll_req_h *req) + ucc_coll_callback_t *callback, + ucc_coll_req_h *req, ucc_base_lib_t *lib) { return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest, - team, ctx, callback, req, 1); + team, callback, req, 1, lib); } static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t my_team_rank, ucc_rank_t dest, ucc_team_h team, - ucc_context_h ctx, ucc_coll_callback_t - *callback, ucc_coll_req_h *req) + ucc_coll_callback_t *callback, + ucc_coll_req_h *req, ucc_base_lib_t *lib) { return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest, - team, ctx, callback, req, 0); + team, callback, req, 0, lib); } ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t @@ -97,16 +97,16 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t ucc_coll_req_h req = NULL; ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank; ucc_team_h team = oob_p2p_ctx->base_team; - ucc_context_h ctx = oob_p2p_ctx->base_ctx; ucc_coll_callback_t callback; callback.cb = ucc_tl_mlx5_mcast_completion_cb; callback.data = obj; - status = do_send_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req); + tl_trace(oob_p2p_ctx->lib, "P2P: SEND to %d Msg Size %ld", rank, size); + status = do_send_nb(src, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib); if (status < 0) { - tl_error(ctx->lib, "nonblocking p2p send failed"); + tl_error(oob_p2p_ctx->lib, "nonblocking p2p send failed"); return status; } @@ -124,16 +124,16 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t ucc_coll_req_h req = NULL; ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank; ucc_team_h team = oob_p2p_ctx->base_team; - ucc_context_h ctx = oob_p2p_ctx->base_ctx; ucc_coll_callback_t callback; callback.cb = ucc_tl_mlx5_mcast_completion_cb; callback.data = obj; - status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req); + tl_trace(oob_p2p_ctx->lib, "P2P: RECV to %d Msg Size %ld", rank, size); + status = do_recv_nb(dst, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib); if (status < 0) { - tl_error(ctx->lib, "nonblocking p2p recv failed"); + tl_error(oob_p2p_ctx->lib, "nonblocking p2p recv failed"); return status; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 28626b7d13..8c261d830c 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -30,7 +30,7 @@ #define DEF_SL 0 #define DEF_SRC_PATH_BITS 0 #define GRH_LENGTH 40 -#define DROP_THRESHOLD 1000000 +#define DROP_THRESHOLD 1000 #define MAX_COMM_POW2 32 enum { @@ -41,7 +41,8 @@ enum { enum { MCAST_P2P_NACK, MCAST_P2P_ACK, - MCAST_P2P_NEED_NACK_SEND + MCAST_P2P_NEED_NACK_SEND, + MCAST_P2P_NACK_SEND_PENDING }; enum { @@ -138,7 +139,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { struct rdma_cm_id *id; struct rdma_event_channel *channel; ucc_mpool_t compl_objects_mp; - ucc_mpool_t nack_reqs_mp; ucc_list_link_t pending_nacks_list; ucc_rcache_t *rcache; ucc_tl_mlx5_mcast_ctx_params_t params; @@ -308,10 +308,11 @@ typedef struct ucc_tl_mlx5_mcast_coll_req { } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { - ucc_context_h base_ctx; - ucc_team_h base_team; - ucc_rank_t my_team_rank; - ucc_subset_t subset; + ucc_context_h base_ctx; + ucc_team_h base_team; + ucc_rank_t my_team_rank; + ucc_subset_t subset; + ucc_base_lib_t *lib; } ucc_tl_mlx5_mcast_oob_p2p_context_t; static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm) @@ -333,7 +334,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast struct pp_packet *pp = NULL; int count = comm->params.rx_depth - comm->pending_recv; int i; - + if (count <= comm->params.post_recv_thresh) { return UCC_OK; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index 4669c88640..9696ba8c82 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -68,18 +68,17 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req int to_recv_left; int pending_q_size; + + status = ucc_tl_mlx5_mcast_check_nack_requests(comm, UINT32_MAX); + if (status < 0) { + return status; + } + if (ucc_unlikely(comm->recv_drop_packet_in_progress)) { /* wait till parent resend the dropped packet */ return UCC_INPROGRESS; } - if (comm->reliable_in_progress) { - /* wait till all the children send their ACK for current window */ - status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); - if (UCC_OK != status) { - return status; - } - } if (req->to_send || req->to_recv) { num_free_win = wsize - (comm->psn - comm->last_acked); @@ -137,12 +136,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req comm->timer = 0; } } + } - /* This function will check if we have to do a round of reliability protocol */ - status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); - if (UCC_OK != status) { - return status; - } + /* This function will check if we have to do a round of reliability protocol */ + status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req); + if (UCC_OK != status) { + return status; } if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index a201944ecf..4522097973 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -10,15 +10,38 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj); -static ucc_status_t ucc_tl_mlx5_mcast_dummy_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) // NOLINT +static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *comp_obj) { + ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)comp_obj->data[0]; + unsigned int pkt_id = comp_obj->data[2]; + struct packet *p = (struct packet *)comp_obj->data[1]; + ucc_status_t status; + + if (p != NULL) { + /* it was a nack packet to our parent */ + ucc_free(p); + } + + if (pkt_id != UINT_MAX) { + /* we sent the real data to our child so reduce the nack reqs */ + ucc_assert(comm->nack_requests > 0); + ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK_SEND_PENDING); + comm->p2p_pkt[pkt_id].type = MCAST_P2P_ACK; + comm->nack_requests--; + status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id], + sizeof(struct packet), comm->p2p_pkt[pkt_id].from, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL)); + if (status < 0) { + return status; + } + } + + ucc_mpool_put(comp_obj); + return UCC_OK; } -static ucc_tl_mlx5_mcast_p2p_completion_obj_t dummy_completion_obj = { - .compl_cb = ucc_tl_mlx5_mcast_dummy_completion, -}; - static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, int p2p_pkt_id) { @@ -27,26 +50,22 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_ ucc_status_t status; ucc_assert(pp->psn == psn); + ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND); + + comm->p2p_pkt[p2p_pkt_id].type = MCAST_P2P_NACK_SEND_PENDING; - tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld\n", + tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld nack_requests %d \n", comm->comm_id, comm->rank, - comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context); + comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests); status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf), pp->length, comm->p2p_pkt[p2p_pkt_id].from, - comm->p2p_ctx, &dummy_completion_obj); - if (status < 0) { - return status; - } - - status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[p2p_pkt_id], - sizeof(struct packet), comm->p2p_pkt[p2p_pkt_id].from, comm->p2p_ctx, GET_COMPL_OBJ(comm, - ucc_tl_mlx5_mcast_recv_completion, p2p_pkt_id, NULL)); + ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id)); if (status < 0) { return status; } - + return UCC_OK; } @@ -68,8 +87,6 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t if (status != UCC_OK) { break; } - comm->p2p_pkt[i].type = MCAST_P2P_ACK; - comm->nack_requests--; } } } else { @@ -82,8 +99,6 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t if (status < 0) { break; } - comm->p2p_pkt[i].type = MCAST_P2P_ACK; - comm->nack_requests--; } } } @@ -145,33 +160,35 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcast_coll_comm_t* comm, ucc_tl_mlx5_mcast_coll_req_t *req) { + ucc_status_t status = UCC_OK; + uint32_t psn = ucc_tl_mlx5_mcast_find_nack_psn(comm, req); struct pp_packet *pp; ucc_rank_t parent; - ucc_status_t status; + struct packet *p; - struct packet p = { - .type = MCAST_P2P_NACK, - .psn = ucc_tl_mlx5_mcast_find_nack_psn(comm, req), - .from = comm->rank, - .comm_id = comm->comm_id, - }; + p = ucc_calloc(1, sizeof(struct packet)); + p->type = MCAST_P2P_NACK; + p->psn = psn; + p->from = comm->rank; + p->comm_id = comm->comm_id; parent = ucc_tl_mlx5_mcast_get_nack_parent(req); comm->nacks_counter++; - status = comm->params.p2p_iface.send_nb(&p, sizeof(struct packet), parent, - comm->p2p_ctx, &dummy_completion_obj); + status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent, + comm->p2p_ctx, GET_COMPL_OBJ(comm, + ucc_tl_mlx5_mcast_reliability_send_completion, p, UINT_MAX)); if (status < 0) { return status; } tl_trace(comm->lib, "[comm %d, rank %d] Sent NAK : parent %d, psn %d", - comm->comm_id, comm->rank, parent, p.psn); + comm->comm_id, comm->rank, parent, psn); // Prepare to obtain the data. pp = ucc_tl_mlx5_mcast_buf_get_free(comm); - pp->psn = p.psn; + pp->psn = psn; pp->length = PSN_TO_RECV_LEN(pp->psn, req, comm); comm->recv_drop_packet_in_progress = true; @@ -234,21 +251,23 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp psn = comm->p2p_pkt[pkt_id].psn; pp = comm->r_window[psn % comm->wsize]; - tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d", + tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d pkt_id %d", comm->comm_id, comm->rank, - comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn); + comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn, pkt_id); + + comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND; + comm->nack_requests++; if (pp->psn == psn) { + /* parent already has this packet so it is ready to forward it to its child */ status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, pkt_id); - if (status < 0) { + if (status != UCC_OK) { return status; } - } else { - comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND; - comm->nack_requests++; } } else { + ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_ACK); comm->racks_n++; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h index 1bceb89976..b1e2b38526 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h @@ -30,18 +30,6 @@ obj; \ }) -#define GET_NACK_REQ(_comm, _pkt_id) \ - ({ \ - void* item; \ - ucc_tl_mlx5_mcast_nack_req_t *_req; \ - item = ucc_mpool_get(&(_comm)->ctx->nack_reqs_mp); \ - \ - _req = (ucc_tl_mlx5_mcast_nack_req_t *)item; \ - _req->comm = _comm; \ - _req->pkt_id = _pkt_id; \ - _req; \ - }) - ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, ucc_tl_mlx5_mcast_coll_req_t *req, ucc_rank_t root); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index d6b6c763c0..6823abaa08 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -91,6 +91,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, oob_p2p_ctx->base_ctx = context; oob_p2p_ctx->base_team = team_params->team; oob_p2p_ctx->my_team_rank = team_params->rank; + oob_p2p_ctx->lib = mcast_context->lib; set.myrank = team_params->rank; set.map = team_params->map; oob_p2p_ctx->subset = set;