TL/MLX5: fix memtype in bcast reliability #1022

Merged: 1 commit merged on Oct 24, 2024
21 changes: 12 additions & 9 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.c
@@ -27,6 +27,7 @@ void ucc_tl_mlx5_mcast_completion_cb(void* context, ucc_status_t status) //NOLIN

static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
len, ucc_rank_t my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type,
ucc_team_h team, ucc_coll_callback_t *callback,
ucc_coll_req_h *p2p_req, int is_send,
ucc_base_lib_t *lib)
@@ -41,7 +42,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
args.src.info.buffer = buf;
args.src.info.count = len;
args.src.info.datatype = UCC_DT_INT8;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.src.info.mem_type = mem_type;
args.root = is_send ? my_team_rank : dest;
args.cb.cb = callback->cb;
args.cb.data = callback->data;
@@ -69,25 +70,27 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_p2p_bcast_nb(void *buf, size_t
}

static inline ucc_status_t do_send_nb(void *sbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type, ucc_team_h team,
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest,
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(sbuf, len, my_team_rank, dest, mem_type,
team, callback, req, 1, lib);
}

static inline ucc_status_t do_recv_nb(void *rbuf, size_t len, ucc_rank_t
my_team_rank, ucc_rank_t dest, ucc_team_h team,
my_team_rank, ucc_rank_t dest,
ucc_memory_type_t mem_type, ucc_team_h team,
ucc_coll_callback_t *callback,
ucc_coll_req_h *req, ucc_base_lib_t *lib)
{
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest,
return ucc_tl_mlx5_mcast_do_p2p_bcast_nb(rbuf, len, my_team_rank, dest, mem_type,
team, callback, req, 0, lib);
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
@@ -103,7 +106,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
callback.data = obj;

tl_trace(oob_p2p_ctx->lib, "P2P: SEND to %d Msg Size %ld", rank, size);
status = do_send_nb(src, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
status = do_send_nb(src, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(oob_p2p_ctx->lib, "nonblocking p2p send failed");
@@ -114,7 +117,7 @@ }
}

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj)
{
@@ -130,7 +133,7 @@ ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void *dst, size_t size, ucc_rank_t
callback.data = obj;

tl_trace(oob_p2p_ctx->lib, "P2P: RECV to %d Msg Size %ld", rank, size);
status = do_recv_nb(dst, size, my_team_rank, rank, team, &callback, &req, oob_p2p_ctx->lib);
status = do_recv_nb(dst, size, my_team_rank, rank, mem_type, team, &callback, &req, oob_p2p_ctx->lib);

if (status < 0) {
tl_error(oob_p2p_ctx->lib, "nonblocking p2p recv failed");
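Note (reviewer sketch, not part of the diff): with the added parameter, callers of the p2p API now state the buffer's memory type explicitly instead of relying on the hard-coded host type that was removed above. A minimal hypothetical call is shown below; `send_cuda_chunk`, `staging_buf`, `msg_len`, `peer`, `oob_ctx`, and `compl_obj` are placeholder names, and the include path follows the repo's convention of including relative to `src/`.

```c
#include "components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h"

/* Hypothetical caller: send a CUDA buffer through the reliability p2p path.
 * The memory type is now supplied by the caller instead of being fixed to
 * UCC_MEMORY_TYPE_HOST inside ucc_tl_mlx5_mcast_do_p2p_bcast_nb(). */
static ucc_status_t send_cuda_chunk(void *staging_buf, size_t msg_len,
                                    ucc_rank_t peer, void *oob_ctx,
                                    ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj)
{
    return ucc_tl_mlx5_mcast_p2p_send_nb(staging_buf, msg_len, peer,
                                         UCC_MEMORY_TYPE_CUDA, oob_ctx,
                                         compl_obj);
}
```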
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/mcast/p2p/ucc_tl_mlx5_mcast_p2p.h
@@ -8,12 +8,12 @@
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

ucc_status_t ucc_tl_mlx5_mcast_p2p_send_nb(void* src, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

ucc_status_t ucc_tl_mlx5_mcast_p2p_recv_nb(void* dst, size_t size, ucc_rank_t
rank, void *context,
rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t
*obj);

8 changes: 5 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -31,7 +31,7 @@
#define DEF_SL 0
#define DEF_SRC_PATH_BITS 0
#define GRH_LENGTH 40
#define DROP_THRESHOLD 1000
#define DROP_THRESHOLD 10000
#define MAX_COMM_POW2 32

/* Allgather RDMA-based reliability designs */
@@ -42,6 +42,8 @@
#define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots during async design */
#define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size of metadata prepended to each slots in bytes */

#define CUDA_MEM_MCAST_BCAST_MAX_MSG 4000

enum {
MCAST_PROTO_EAGER, /* Internal staging buffers */
MCAST_PROTO_ZCOPY
@@ -75,12 +77,12 @@ typedef struct ucc_tl_mlx5_mcast_p2p_completion_obj {
typedef int (*ucc_tl_mlx5_mcast_p2p_wait_cb_fn_t)(void *wait_arg);

typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_send_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);


typedef ucc_status_t (*ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t)(void* src, size_t size,
ucc_rank_t rank, void *context,
ucc_rank_t rank, ucc_memory_type_t mem_type, void *context,
ucc_tl_mlx5_mcast_p2p_completion_obj_t *compl_obj);

typedef struct ucc_tl_mlx5_mcast_p2p_interface {
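For orientation (not part of this change): the function-pointer typedefs above must stay in sync with the concrete p2p functions, so wiring the interface would look roughly like the sketch below. It assumes the struct's typedef alias is `ucc_tl_mlx5_mcast_p2p_interface_t` and that its members are named `send_nb`/`recv_nb`, matching the `comm->params.p2p_iface.send_nb`/`recv_nb` calls later in this diff.

```c
/* Sketch: wiring the updated p2p interface. The concrete functions already
 * carry the new ucc_memory_type_t parameter, so these assignments only
 * compile once both sides agree on the signature. */
ucc_tl_mlx5_mcast_p2p_interface_t p2p_iface;
p2p_iface.send_nb = ucc_tl_mlx5_mcast_p2p_send_nb;
p2p_iface.recv_nb = ucc_tl_mlx5_mcast_p2p_recv_nb;
```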
35 changes: 32 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -182,7 +182,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz
req->am_root = (root == comm->rank);
req->mr = comm->pp_mr;
req->rreg = NULL;
req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;
/* cost of copy is too high in cuda buffers */
req->proto = (req->length < comm->max_eager && !comm->cuda_mem_enabled) ?
MCAST_PROTO_EAGER : MCAST_PROTO_ZCOPY;

status = ucc_tl_mlx5_mcast_prepare_reliable(comm, req, req->root);
if (ucc_unlikely(UCC_OK != status)) {
@@ -283,8 +285,8 @@
}
}

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
static inline ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_mcast_coll_comm_t *comm = mlx5_team->mcast->mcast_comm;
@@ -300,6 +302,33 @@ ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_
return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_coll_args_t *args = &coll_args->args;
int buf_size = ucc_dt_size(args->src.info.datatype) * args->src.info.count;

if (UCC_COLL_ARGS_ACTIVE_SET(args)) {
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
return UCC_ERR_NOT_SUPPORTED;
}

if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA &&
buf_size > CUDA_MEM_MCAST_BCAST_MAX_MSG) {
/* for large messages (more than one mtu) we need zero-copy design which
* is not implemented yet */
tl_trace(team->context->lib, "mcast cuda bcast not supported for large messages");
return UCC_ERR_NOT_IMPLEMENTED;
}

return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
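A worked example of the new size gate (illustrative only, not code from the PR): the cap and comparison below mirror `ucc_tl_mlx5_mcast_check_support` above, with `cuda_bcast_msg_size_ok` being a hypothetical standalone helper for demonstration.

```c
#include <stddef.h>

/* Standalone sketch of the CUDA bcast size gate added above, with
 * CUDA_MEM_MCAST_BCAST_MAX_MSG == 4000 bytes (roughly one MTU of payload). */
#define CUDA_MEM_MCAST_BCAST_MAX_MSG 4000

static int cuda_bcast_msg_size_ok(size_t dt_size, size_t count)
{
    size_t buf_size = dt_size * count;
    /* buf_size > cap -> check_support returns UCC_ERR_NOT_IMPLEMENTED */
    return buf_size <= CUDA_MEM_MCAST_BCAST_MAX_MSG;
}

/* cuda_bcast_msg_size_ok(4, 1000) == 1: 4000 bytes fit within the cap.
 * cuda_bcast_msg_size_ok(4, 1024) == 0: 4096 bytes would need the
 * zero-copy design that is not implemented yet, so the TL declines. */
```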
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -14,6 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
#endif
45 changes: 37 additions & 8 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -30,6 +30,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc
comm->nack_requests--;
status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id],
sizeof(struct packet), comm->p2p_pkt[pkt_id].from,
UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL));
if (status < 0) {
Expand All @@ -48,6 +49,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
uint32_t psn = comm->p2p_pkt[p2p_pkt_id].psn;
struct pp_packet *pp = comm->r_window[psn % comm->wsize];
ucc_status_t status;
ucc_memory_type_t mem_type;

ucc_assert(pp->psn == psn);
ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);
@@ -58,8 +60,14 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_
comm->comm_id, comm->rank,
comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
pp->length, comm->p2p_pkt[p2p_pkt_id].from,
pp->length, comm->p2p_pkt[p2p_pkt_id].from, mem_type,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id));
if (status < 0) {
Expand Down Expand Up @@ -138,11 +146,25 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
struct pp_packet *pp = (struct pp_packet *)obj->data[1];
ucc_tl_mlx5_mcast_coll_req_t *req = (ucc_tl_mlx5_mcast_coll_req_t *)obj->data[2];
void *dest;
ucc_memory_type_t mem_type;

tl_trace(comm->lib, "[comm %d, rank %d] Recved data psn %d", comm->comm_id, comm->rank, pp->psn);

dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
memcpy(dest, (void*) pp->buf, pp->length);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}

req->to_recv--;
comm->r_window[pp->psn % comm->wsize] = pp;

@@ -165,6 +187,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
struct pp_packet *pp;
ucc_rank_t parent;
struct packet *p;
ucc_memory_type_t mem_type;

p = ucc_calloc(1, sizeof(struct packet));
p->type = MCAST_P2P_NACK;
Expand All @@ -176,7 +199,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas

comm->nacks_counter++;

status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent,
status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_reliability_send_completion, p, UINT_MAX));
if (status < 0) {
Expand All @@ -193,8 +216,14 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas

comm->recv_drop_packet_in_progress = true;

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = comm->params.p2p_iface.recv_nb((void*) pp->buf,
pp->length, parent,
pp->length, parent, mem_type,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_data_completion, pp, req));
if (status < 0) {
@@ -225,7 +254,7 @@ ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm
comm->rank, parent, comm->parent_n, comm->psn);

status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i],
sizeof(struct packet), parent,
sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_send_completion, i, NULL));
if (status < 0) {
@@ -325,7 +354,7 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c
comm->rank, child, comm->child_n, comm->psn);

status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1],
sizeof(struct packet), child,
sizeof(struct packet), child, UCC_MEMORY_TYPE_HOST,
comm->p2p_ctx, GET_COMPL_OBJ(comm,
ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req));
if (status < 0) {
@@ -369,8 +398,8 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
ucc_tl_mlx5_mcast_coll_req_t *req,
struct pp_packet* pp)
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);
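One observation on the hunks above (a suggestion sketch, not code from this PR): the `cuda_mem_enabled` to memory-type mapping now appears in several reliability paths (resend, data-receive completion, NACK receive). A small inline helper would keep them consistent; the helper name below is hypothetical.

```c
#include "components/tl/mlx5/mcast/tl_mlx5_mcast.h"

/* Hypothetical helper consolidating the repeated selection between
 * UCC_MEMORY_TYPE_CUDA and UCC_MEMORY_TYPE_HOST based on the
 * communicator's cuda_mem_enabled flag. */
static inline ucc_memory_type_t
ucc_tl_mlx5_mcast_buf_mem_type(const ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
    return comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
                                  : UCC_MEMORY_TYPE_HOST;
}
```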
12 changes: 3 additions & 9 deletions src/components/tl/mlx5/tl_mlx5_coll.c
@@ -14,19 +14,13 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
{
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
return UCC_ERR_NOT_SUPPORTED;
status = ucc_tl_mlx5_mcast_check_support(coll_args, team);
if (UCC_OK != status) {
return status;
}

task = ucc_tl_mlx5_get_task(coll_args, team);

if (ucc_unlikely(!task)) {
return UCC_ERR_NO_MEMORY;
}