TL/MLX5: fix memtype in bcast reliability
MamziB committed Oct 21, 2024
1 parent 699b658 commit e81d9a8
Showing 7 changed files with 93 additions and 36 deletions.
21 changes: 12 additions & 9 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
@@ -27,6 +27,7 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLIN

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
len, ucc_rank_t my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type,
ucc_team_h team, ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send,
ucc_base_lib_t *lib)
@@ -41,7 +42,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
args.src.info.buffer = buf;
args.src.info.count = len;
args.src.info.datatype = UCC_DT_INT8;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.src.info.mem_type = mem_type;
args.root = is_send ? my_team_rank : dest;
args.cb.cb = callback->cb;
args.cb.data = callback->data;
@@ -69,25 +70,27 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
}

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type, ucc_team_h team,
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest, mem_type,
team, callback, req, 1, lib);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type, ucc_team_h team,
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest, mem_type,
team, callback, req, 0, lib);
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
@@ -103,7 +106,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
callback.data = obj;

tl_trace(oob_p2p_ctx->lib, "P2P: SEND to %d Msg Size %ld", rank, size);
status = do_send_nb(src, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
status = do_send_nb(src, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(oob_p2p_ctx->lib, "nonblocking p2p send failed");
@@ -114,7 +117,7 @@
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
@@ -130,7 +133,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
callback.data = obj;

tl_trace(oob_p2p_ctx->lib, "P2P: RECV to %d Msg Size %ld", rank, size);
status = do_recv_nb(dst, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
status = do_recv_nb(dst, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(oob_p2p_ctx->lib, "nonblocking p2p recv failed");
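The changes in this file simply thread one extra argument through: the caller now states where the payload lives, and that value lands in args.src.info.mem_type instead of a hard-coded UCC_MEMORY_TYPE_HOST. A minimal, hedged usage sketch of the updated entry point follows; dest_rank and compl_obj are placeholders, not taken from the commit.

/* Sketch only: send a reliability payload whose location depends on how the
 * communicator was set up. comm, dest_rank and compl_obj are assumed to exist. */
ucc_memory_type_t mtype = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
                                                 : UCC_MEMORY_TYPE_HOST;
status = ucc_tl_mlx5_mcast_p2p_send_nb((void *)pp->buf, pp->length, dest_rank,
                                       mtype, comm->p2p_ctx, compl_obj);
if (status < 0) {
    tl_error(comm->lib, "nonblocking p2p send failed");
}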
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
@@ -8,12 +8,12 @@
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

8 changes: 5 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -31,7 +31,7 @@
#define DEF_SL 0
#define DEF_SRC_PATH_BITS 0
#define GRH_LENGTH 40
#define DROP_THRESHOLD 1000
#define DROP_THRESHOLD 10000
#define MAX_COMM_POW2 32

/* Allgather RDMA-based reliability designs */
@@ -42,6 +42,8 @@
#define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots during async design */
#define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size of metadata prepended to each slots in bytes */

#define CUDA_MEM_MCAST_BCAST_MAX_MSG 4000

enum {
MCAST_PROTO_EAGER, /* Internal staging buffers */
MCAST_PROTO_ZCOPY
@@ -75,12 +77,12 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);


typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_p2p_interface {
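Both function-pointer typedefs above now carry the memory type, so any implementation plugged into the p2p interface has to match the new signatures. A hedged wiring sketch: the interface typedef name and the setup function are assumptions, only the send_nb/recv_nb field names are taken from their use as comm->params.p2p_iface.send_nb/recv_nb elsewhere in this commit.

#include "components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h"

/* Hypothetical setup routine: register the mem_type-aware callbacks. */
static void setup_mcast_p2p_iface(ucc_tl_mlx5_mcast_p2p_interface_t *iface)
{
    iface->send_nb = ucc_tl_mlx5_mcast_p2p_send_nb;
    iface->recv_nb = ucc_tl_mlx5_mcast_p2p_recv_nb;
}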
35 changes: 32 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -182,7 +182,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz
req->am_root = (root == comm->rank);
req->mr = comm->pp_mr;
req->rreg = NULL;
req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;
/* cost of copy is too high in cuda buffers */
req->proto = (req->length < comm->max_eager && !comm->cuda_mem_enabled) ?
MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;

status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
if (ucc_unlikely(UCC_OK != status)) {
@@ -283,8 +285,8 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
}
}

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
static inline ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_mcast_coll_comm_t *comm = mlx5_team->mcast->mcast_comm;
@@ -300,6 +302,33 @@ ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_
return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_coll_args_t *args = &coll_args->args;
int buf_size = ucc_dt_size(args->src.info.datatype) * args->src.info.count;

if (UCC_COLL_ARGS_ACTIVE_SET(args)) {
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
return UCC_ERR_NOT_SUPPORTED;
}

if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA &&
buf_size > CUDA_MEM_MCAST_BCAST_MAX_MSG) {
/* for large messages (more than one mtu) we need zero-copy design which
* is not implemented yet */
tl_trace(team->context->lib, "mcast cuda bcast not supported for large messages");
return UCC_ERR_NOT_IMPLEMENTED;
}

return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
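The protocol choice in ucc_tl_mlx5_mcast_prepare_bcast() is the other half of the CUDA handling: staging (eager) copies only pay off for small host buffers. Restated as a sketch, using the same req and comm fields the diff touches:

/* Illustration of the eager/zero-copy decision added above. */
if (req->length < comm->max_eager && !comm->cuda_mem_enabled) {
    req->proto = MCAST_PROTO_EAGER; /* small host message: copy through staging buffers */
} else {
    req->proto = MCAST_PROTO_ZCOPY; /* CUDA buffer or large message: avoid the extra copy */
}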
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -14,6 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
#endif
45 changes: 37 additions & 8 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -30,6 +30,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc
comm->nack_requests--;
status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id],
sizeof(struct packet), comm->p2p_pkt[pkt_id].from,
UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL));
if (status < 0) {
@@ -48,6 +49,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
uint32_t psn = comm->p2p_pkt[p2p_pkt_id].psn;
struct pp_packet *pp = comm->r_window[psn % comm->wsize];
ucc_status_t status;
ucc_memory_type_t mem_type;

ucc_assert(pp->psn == psn);
ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);
@@ -58,8 +60,14 @@
comm->comm_id, comm->rank,
comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
pp->length, comm->p2p_pkt[p2p_pkt_id].from,
pp->length, comm->p2p_pkt[p2p_pkt_id].from, mem_type,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id));
if (status < 0) {
@@ -138,11 +146,25 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
struct pp_packet *pp = (struct pp_packet *)obj->data[1];
ucc_tl_mlx5_mcast_coll_req_t *req = (ucc_tl_mlx5_mcast_coll_req_t *)obj->data[2];
void *dest;
ucc_memory_type_t mem_type;

tl_trace(comm->lib, "[comm %d, rank %d] Recved data psn %d", comm->comm_id, comm->rank, pp->psn);

dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
memcpy(dest, (void*) pp->buf, pp->length);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}

req->to_recv--;
comm->r_window[pp->psn % comm->wsize] = pp;

@@ -165,6 +187,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
struct pp_packet *pp;
ucc_rank_t parent;
struct packet *p;
ucc_memory_type_t mem_type;

p = ucc_calloc(1, sizeof(struct packet));
p->type = MCAST_P2P_NACK;
@@ -176,7 +199,7 @@

comm->nacks_counter++;

status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent,
status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, p, UINT_MAX));
if (status < 0) {
@@ -193,8 +216,14 @@

comm->recv_drop_packet_in_progress = true;

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = comm->params.p2p_iface.recv_nb((void*) pp->buf,
pp->length, parent,
pp->length, parent, mem_type,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_data_completion, pp, req));
if (status < 0) {
@@ -225,7 +254,7 @@ ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm
comm->rank, parent, comm->parent_n, comm->psn);

status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i],
sizeof(struct packet), parent,
sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_send_completion, i, NULL));
if (status < 0) {
@@ -325,7 +354,7 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c
comm->rank, child, comm->child_n, comm->psn);

status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1],
sizeof(struct packet), child,
sizeof(struct packet), child, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req));
if (status < 0) {
@@ -369,8 +398,8 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
ucc_tl_mlx5_mcast_coll_req_t *req,
struct pp_packet* pp)
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);
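The cuda_mem_enabled-to-mem_type selection now repeats across the reliability paths in this file (packet resend, NACK-driven re-receive, and the receive-data completion), while control packets (struct packet) always stay in host memory. A small helper like the one below, hypothetical and not part of the commit, would capture the pattern:

/* Hypothetical helper (not in the commit): memory type of the user payload
 * moved by the reliability protocol; control packets remain host memory. */
static inline ucc_memory_type_t
mcast_payload_mem_type(const ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    return comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA : UCC_MEMORY_TYPE_HOST;
}

Note also that the receive-data completion replaces the plain memcpy with ucc_mc_memcpy(dest, (void *)pp->buf, pp->length, mem_type, mem_type), passing the same memory type for both source and destination.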
12 changes: 3 additions & 9 deletions src/components/tl/mlx5/tl_mlx5_coll.c
@@ -14,19 +14,13 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
{
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
return UCC_ERR_NOT_SUPPORTED;
status = ucc_tl_mlx5_mcast_check_support(coll_args, team);
if (UCC_OK != status) {
return status;
}

task = ucc_tl_mlx5_get_task(coll_args, team);

if (ucc_unlikely(!task)) {
return UCC_ERR_NO_MEMORY;
}
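With the checks folded into ucc_tl_mlx5_mcast_check_support(), the init path now propagates the precise reason a bcast cannot go over mcast rather than collapsing everything to UCC_ERR_NOT_SUPPORTED. A hedged caller-side sketch; the task-handle argument is assumed from the surrounding code and is not shown in this diff:

/* Sketch: reacting to the statuses surfaced through the new check. */
status = ucc_tl_mlx5_bcast_mcast_init(coll_args, team, &task_h);
switch (status) {
case UCC_OK:
    break;                      /* mcast bcast task created */
case UCC_ERR_NOT_SUPPORTED:     /* active set or incompatible memory type */
case UCC_ERR_NOT_IMPLEMENTED:   /* CUDA message above CUDA_MEM_MCAST_BCAST_MAX_MSG */
    /* fall back to another bcast algorithm */
    break;
default:
    break;                      /* hard failure, e.g. UCC_ERR_NO_MEMORY */
}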
