diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 734acf1f30..5e54b69c9c 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -368,6 +368,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req { int last_pkt_len; int offset; ucc_memory_type_t buf_mem_type; + ucc_ee_executor_task_t *exec_task; + ucc_coll_task_t *coll_task; } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -427,6 +429,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast return UCC_OK; } +#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \ + if (_etask != NULL) { \ + status = ucc_ee_executor_task_test(_etask); \ + if (status > 0) { \ + return status; \ + } \ + ucc_ee_executor_task_finalize(_etask); \ + _etask = NULL; \ + if (ucc_unlikely(status < 0)) { \ + tl_error(_lib, _errmsg); \ + return status; \ + } \ + } \ +} while(0) + ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index 28c4bbce61..5288e46727 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -33,6 +33,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ return status; } + if (comm->cuda_mem_enabled) { + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } + } + comm->n_mcast_reliable++; for (;comm->last_acked < comm->psn; comm->last_acked++) { @@ -270,7 +276,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) return ucc_task_complete(coll_task); } - coll_task->status = status; + ucc_assert(task->bcast_mcast.req_handle != NULL); + + coll_task->status = status; + task->bcast_mcast.req_handle->coll_task = coll_task; return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); } @@ -329,10 +338,34 @@ ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, return UCC_OK; } -ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) +ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task, + ucc_base_coll_args_t *coll_args) { + ucc_coll_args_t *args = &coll_args->args; + task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; + if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) { + task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR; + } return UCC_OK; } + +ucc_status_t ucc_tl_mlx5_mcast_schedule_start(ucc_coll_task_t *coll_task) +{ + return ucc_schedule_start(coll_task); +} + +ucc_status_t ucc_tl_mlx5_mcast_schedule_finalize(ucc_coll_task_t *coll_task) +{ + ucc_status_t status; + ucc_tl_mlx5_schedule_t *schedule = + ucc_derived_of(coll_task, ucc_tl_mlx5_schedule_t); + + status = ucc_schedule_finalize(coll_task); + + ucc_tl_mlx5_put_schedule(schedule); + return status; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index f34e8827f4..4f52c6d6b0 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -10,10 +10,15 @@ #include "tl_mlx5_mcast.h" #include "tl_mlx5_coll.h" -ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task); +ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task, + ucc_base_coll_args_t *coll_args); ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req); ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team); + +ucc_status_t ucc_tl_mlx5_mcast_schedule_start(ucc_coll_task_t *coll_task); + +ucc_status_t ucc_tl_mlx5_mcast_schedule_finalize(ucc_coll_task_t *coll_task); #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index 41b6ca14f9..75505a9e1e 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -398,9 +398,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com ucc_tl_mlx5_mcast_coll_req_t *req, struct pp_packet* pp) { - ucc_status_t status = UCC_OK; - void *dest; - ucc_memory_type_t mem_type; + ucc_status_t status = UCC_OK; + void *dest; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; ucc_assert(pp->psn >= req->start_psn && pp->psn < req->start_psn + req->num_packets); @@ -409,18 +410,29 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com if (pp->length > 0 ) { dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); - if (comm->cuda_mem_enabled) { - mem_type = UCC_MEMORY_TYPE_CUDA; - } else { - mem_type = UCC_MEMORY_TYPE_HOST; - } + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } - status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length, - mem_type, mem_type); - if (ucc_unlikely(status != UCC_OK)) { - tl_error(comm->lib, "failed to copy buffer"); - return status; + /* for cuda memcpy use nonblocking copy */ + status = ucc_coll_task_get_executor(req->coll_task, &exec); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = (void*) pp->buf; + eargs.copy.dst = dest; + eargs.copy.len = pp->length; + + assert(req->exec_task == NULL); + status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + } else { + memcpy(dest, (void*) pp->buf, pp->length); } } diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 909b457325..da993e34d7 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -12,8 +12,12 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) { - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_task_t *task = NULL; + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_task_t *task = NULL; + ucc_coll_task_t *bcast_task; + ucc_schedule_t *schedule; + status = ucc_tl_mlx5_mcast_check_support(coll_args, team); if (UCC_OK != status) { @@ -27,12 +31,36 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, task->super.finalize = ucc_tl_mlx5_task_finalize; - status = ucc_tl_mlx5_mcast_bcast_init(task); + status = ucc_tl_mlx5_get_schedule(tl_team, coll_args, + (ucc_tl_mlx5_schedule_t **)&schedule); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + + status = ucc_tl_mlx5_mcast_bcast_init(task, coll_args); + if (ucc_unlikely(UCC_OK != status)) { + goto free_task; + } + + bcast_task = &(task->super); + + status = ucc_schedule_add_task(schedule, bcast_task); + if (ucc_unlikely(UCC_OK != status)) { + goto free_task; + } + + status = ucc_event_manager_subscribe(&schedule->super, + UCC_EVENT_SCHEDULE_STARTED, + bcast_task, + ucc_task_start_handler); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } - *task_h = &(task->super); + schedule->super.post = ucc_tl_mlx5_mcast_schedule_start; + schedule->super.progress = NULL; + schedule->super.finalize = ucc_tl_mlx5_mcast_schedule_finalize; + *task_h = &schedule->super; tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);