Skip to content

Commit

Permalink
TL/MLX5: enhance reliability protocol in Mcast (#957)
Browse files Browse the repository at this point in the history
Co-authored-by: Manjunath Gorentla Venkata <[email protected]>
  • Loading branch information
MamziB and manjugv authored Apr 11, 2024
1 parent 666160f commit be42729
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 87 deletions.
36 changes: 18 additions & 18 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLIN

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
len, ucc_rank_t my_team_rank, ucc_rank_t dest,
ucc_team_h team, ucc_context_h ctx,
ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send)
ucc_team_h team, ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send,
ucc_base_lib_t *lib)
{
ucc_status_t status = UCC_OK;
ucc_coll_req_h req = NULL;
Expand All @@ -47,19 +47,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
args.cb.data = callback->data;
args.active_set.size = 2;
args.active_set.start = my_team_rank;
args.active_set.stride = dest - my_team_rank;
args.active_set.stride = (int)dest - (int)my_team_rank;

status = ucc_collective_init(&args, &req, team);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "nonblocking p2p init failed");
tl_error(lib, "nonblocking p2p init failed");
return status;
}

((ucc_tl_mlx5_mcast_p2p_completion_obj_t *)args.cb.data)->req = req;

status = ucc_collective_post(req);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(ctx->lib, "nonblocking p2p post failed");
tl_error(lib, "nonblocking p2p post failed");
return status;
}

Expand All @@ -70,20 +70,20 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
team, ctx, callback, req, 1);
team, callback, req, 1, lib);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
ucc_context_h ctx, ucc_coll_callback_t
*callback, ucc_coll_req_h *req)
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
team, ctx, callback, req, 0);
team, callback, req, 0, lib);
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
Expand All @@ -97,16 +97,16 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
ucc_team_h team = oob_p2p_ctx->base_team;
ucc_context_h ctx = oob_p2p_ctx->base_ctx;
ucc_coll_callback_t callback;

callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_send_nb(src, size, my_team_rank, rank, team, ctx, &callback, &req);
tl_trace(oob_p2p_ctx->lib, "P2P: SEND to %d Msg Size %ld", rank, size);
status = do_send_nb(src, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(ctx->lib, "nonblocking p2p send failed");
tl_error(oob_p2p_ctx->lib, "nonblocking p2p send failed");
return status;
}

Expand All @@ -124,16 +124,16 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
ucc_coll_req_h req = NULL;
ucc_rank_t my_team_rank = oob_p2p_ctx->my_team_rank;
ucc_team_h team = oob_p2p_ctx->base_team;
ucc_context_h ctx = oob_p2p_ctx->base_ctx;
ucc_coll_callback_t callback;

callback.cb = ucc_tl_mlx5_mcast_completion_cb;
callback.data = obj;

status = do_recv_nb(dst, size, my_team_rank, rank, team, ctx, &callback, &req);
tl_trace(oob_p2p_ctx->lib, "P2P: RECV to %d Msg Size %ld", rank, size);
status = do_recv_nb(dst, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(ctx->lib, "nonblocking p2p recv failed");
tl_error(oob_p2p_ctx->lib, "nonblocking p2p recv failed");
return status;
}

Expand Down
17 changes: 9 additions & 8 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#define DEF_SL 0
#define DEF_SRC_PATH_BITS 0
#define GRH_LENGTH 40
#define DROP_THRESHOLD 1000000
#define DROP_THRESHOLD 1000
#define MAX_COMM_POW2 32

enum {
Expand All @@ -41,7 +41,8 @@ enum {
enum {
MCAST_P2P_NACK,
MCAST_P2P_ACK,
MCAST_P2P_NEED_NACK_SEND
MCAST_P2P_NEED_NACK_SEND,
MCAST_P2P_NACK_SEND_PENDING
};

enum {
Expand Down Expand Up @@ -138,7 +139,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
struct rdma_cm_id *id;
struct rdma_event_channel *channel;
ucc_mpool_t compl_objects_mp;
ucc_mpool_t nack_reqs_mp;
ucc_list_link_t pending_nacks_list;
ucc_rcache_t *rcache;
ucc_tl_mlx5_mcast_ctx_params_t params;
Expand Down Expand Up @@ -308,10 +308,11 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
ucc_context_h base_ctx;
ucc_team_h base_team;
ucc_rank_t my_team_rank;
ucc_subset_t subset;
ucc_context_h base_ctx;
ucc_team_h base_team;
ucc_rank_t my_team_rank;
ucc_subset_t subset;
ucc_base_lib_t *lib;
} ucc_tl_mlx5_mcast_oob_p2p_context_t;

static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast_coll_comm_t* comm)
Expand All @@ -333,7 +334,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
struct pp_packet *pp = NULL;
int count = comm->params.rx_depth - comm->pending_recv;
int i;

if (count <= comm->params.post_recv_thresh) {
return UCC_OK;
}
Expand Down
23 changes: 11 additions & 12 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,17 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
int to_recv_left;
int pending_q_size;


status = ucc_tl_mlx5_mcast_check_nack_requests(comm, UINT32_MAX);
if (status < 0) {
return status;
}

if (ucc_unlikely(comm->recv_drop_packet_in_progress)) {
/* wait till parent resend the dropped packet */
return UCC_INPROGRESS;
}

if (comm->reliable_in_progress) {
/* wait till all the children send their ACK for current window */
status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req);
if (UCC_OK != status) {
return status;
}
}

if (req->to_send || req->to_recv) {
num_free_win = wsize - (comm->psn - comm->last_acked);
Expand Down Expand Up @@ -137,12 +136,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
comm->timer = 0;
}
}
}

/* This function will check if we have to do a round of reliability protocol */
status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req);
if (UCC_OK != status) {
return status;
}
/* This function will check if we have to do a round of reliability protocol */
status = ucc_tl_mlx5_mcast_r_window_recycle(comm, req);
if (UCC_OK != status) {
return status;
}

if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) {
Expand Down
93 changes: 56 additions & 37 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,38 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp

static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj);

static ucc_status_t ucc_tl_mlx5_mcast_dummy_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *obj) // NOLINT
/*
 * Completion callback for nonblocking p2p sends issued by the mcast
 * reliability protocol. Two send flavors share this callback, and the
 * completion object's data[] slots encode which one finished:
 *   data[0] = comm handle (always set)
 *   data[1] = heap-allocated NACK packet sent to the parent, or NULL
 *   data[2] = index into comm->p2p_pkt[] for a data resend to a child,
 *             or UINT_MAX when this completion is for a NACK send
 * Frees the NACK packet (if any); for a finished data resend it flips the
 * slot back to ACK, decrements the outstanding-NACK counter, and re-posts
 * the control-packet receive for that child. Returns UCC_OK or the error
 * from the re-posted recv.
 */
static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mcast_p2p_completion_obj_t *comp_obj)
{
ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)comp_obj->data[0];
unsigned int pkt_id = comp_obj->data[2];
struct packet *p = (struct packet *)comp_obj->data[1];
ucc_status_t status;

if (p != NULL) {
/* it was a nack packet to our parent */
ucc_free(p);
}

if (pkt_id != UINT_MAX) {
/* we sent the real data to our child so reduce the nack reqs */
ucc_assert(comm->nack_requests > 0);
ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK_SEND_PENDING);
comm->p2p_pkt[pkt_id].type = MCAST_P2P_ACK;
comm->nack_requests--;
/* re-arm the receive for the next control packet from this child;
 * on failure, propagate without recycling the completion object
 * (NOTE(review): comp_obj appears to leak on this error path —
 * confirm against the mpool ownership convention) */
status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id],
sizeof(struct packet), comm->p2p_pkt[pkt_id].from,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL));
if (status < 0) {
return status;
}
}

/* return the completion object to its memory pool */
ucc_mpool_put(comp_obj);

return UCC_OK;
}

static ucc_tl_mlx5_mcast_p2p_completion_obj_t dummy_completion_obj = {
.compl_cb = ucc_tl_mlx5_mcast_dummy_completion,
};

static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm,
int p2p_pkt_id)
{
Expand All @@ -27,26 +50,22 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
ucc_status_t status;

ucc_assert(pp->psn == psn);
ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);

comm->p2p_pkt[p2p_pkt_id].type = MCAST_P2P_NACK_SEND_PENDING;

tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld\n",
tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld nack_requests %d \n",
comm->comm_id, comm->rank,
comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context);
comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests);

status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
pp->length, comm->p2p_pkt[p2p_pkt_id].from,
comm->p2p_ctx, &dummy_completion_obj);
if (status < 0) {
return status;
}

status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[p2p_pkt_id],
sizeof(struct packet), comm->p2p_pkt[p2p_pkt_id].from,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, p2p_pkt_id, NULL));
ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id));
if (status < 0) {
return status;
}

return UCC_OK;
}

Expand All @@ -68,8 +87,6 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t
if (status != UCC_OK) {
break;
}
comm->p2p_pkt[i].type = MCAST_P2P_ACK;
comm->nack_requests--;
}
}
} else {
Expand All @@ -82,8 +99,6 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t
if (status < 0) {
break;
}
comm->p2p_pkt[i].type = MCAST_P2P_ACK;
comm->nack_requests--;
}
}
}
Expand Down Expand Up @@ -145,33 +160,35 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcast_coll_comm_t* comm,
ucc_tl_mlx5_mcast_coll_req_t *req)
{
ucc_status_t status = UCC_OK;
uint32_t psn = ucc_tl_mlx5_mcast_find_nack_psn(comm, req);
struct pp_packet *pp;
ucc_rank_t parent;
ucc_status_t status;
struct packet *p;

struct packet p = {
.type = MCAST_P2P_NACK,
.psn = ucc_tl_mlx5_mcast_find_nack_psn(comm, req),
.from = comm->rank,
.comm_id = comm->comm_id,
};
p = ucc_calloc(1, sizeof(struct packet));
p->type = MCAST_P2P_NACK;
p->psn = psn;
p->from = comm->rank;
p->comm_id = comm->comm_id;

parent = ucc_tl_mlx5_mcast_get_nack_parent(req);

comm->nacks_counter++;

status = comm->params.p2p_iface.send_nb(&p, sizeof(struct packet), parent,
comm->p2p_ctx, &dummy_completion_obj);
status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, p, UINT_MAX));
if (status < 0) {
return status;
}

tl_trace(comm->lib, "[comm %d, rank %d] Sent NAK : parent %d, psn %d",
comm->comm_id, comm->rank, parent, p.psn);
comm->comm_id, comm->rank, parent, psn);

// Prepare to obtain the data.
pp = ucc_tl_mlx5_mcast_buf_get_free(comm);
pp->psn = p.psn;
pp->psn = psn;
pp->length = PSN_TO_RECV_LEN(pp->psn, req, comm);

comm->recv_drop_packet_in_progress = true;
Expand Down Expand Up @@ -234,21 +251,23 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp
psn = comm->p2p_pkt[pkt_id].psn;
pp = comm->r_window[psn % comm->wsize];

tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d",
tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d pkt_id %d",
comm->comm_id, comm->rank,
comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn);
comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn, pkt_id);

comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND;
comm->nack_requests++;

if (pp->psn == psn) {
/* parent already has this packet so it is ready to forward it to its child */
status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, pkt_id);
if (status < 0) {
if (status != UCC_OK) {
return status;
}
} else {
comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND;
comm->nack_requests++;
}

} else {
ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_ACK);
comm->racks_n++;
}

Expand Down
12 changes: 0 additions & 12 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,6 @@
obj; \
})

#define GET_NACK_REQ(_comm, _pkt_id) \
({ \
void* item; \
ucc_tl_mlx5_mcast_nack_req_t *_req; \
item = ucc_mpool_get(&(_comm)->ctx->nack_reqs_mp); \
\
_req = (ucc_tl_mlx5_mcast_nack_req_t *)item; \
_req->comm = _comm; \
_req->pkt_id = _pkt_id; \
_req; \
})

ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t *req,
ucc_rank_t root);
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
oob_p2p_ctx->base_ctx = context;
oob_p2p_ctx->base_team = team_params->team;
oob_p2p_ctx->my_team_rank = team_params->rank;
oob_p2p_ctx->lib = mcast_context->lib;
set.myrank = team_params->rank;
set.map = team_params->map;
oob_p2p_ctx->subset = set;
Expand Down

0 comments on commit be42729

Please sign in to comment.