diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 663ee636ed..3772f55616 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -441,6 +441,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req { ucc_service_coll_req_t *allgather_rkeys_req; ucc_service_coll_req_t *barrier_req; void *recv_rreg; + ucc_ee_executor_task_t *exec_task; + ucc_coll_task_t *coll_task; } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -555,6 +557,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_ return UCC_OK; } +#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \ + if (_etask != NULL) { \ + status = ucc_ee_executor_task_test(_etask); \ + if (status > 0) { \ + return status; \ + } \ + ucc_ee_executor_task_finalize(_etask); \ + _etask = NULL; \ + if (ucc_unlikely(status < 0)) { \ + tl_error(_lib, _errmsg); \ + return status; \ + } \ + } \ +} while(0) + ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index b6fbe84e3d..c89996dc4e 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -33,6 +33,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ return status; } + if (comm->cuda_mem_enabled) { + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } + } comm->bcast_comm.n_mcast_reliable++; for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) { @@ -267,7 +272,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) return ucc_task_complete(coll_task); } - coll_task->status = status; + ucc_assert(task->coll_mcast.req_handle != NULL); + + coll_task->status = status; + task->coll_mcast.req_handle->coll_task = coll_task; return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); } @@ -329,10 +337,34 @@ ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, return UCC_OK; } -ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) +ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task, + ucc_base_coll_args_t *coll_args) { + ucc_coll_args_t *args = &coll_args->args; + task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; + if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) { + task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR; + } return UCC_OK; } + +ucc_status_t ucc_tl_mlx5_mcast_schedule_start(ucc_coll_task_t *coll_task) +{ + return ucc_schedule_start(coll_task); +} + +ucc_status_t ucc_tl_mlx5_mcast_schedule_finalize(ucc_coll_task_t *coll_task) +{ + ucc_status_t status; + ucc_tl_mlx5_schedule_t *schedule = + ucc_derived_of(coll_task, ucc_tl_mlx5_schedule_t); + + status = ucc_schedule_finalize(coll_task); + + ucc_tl_mlx5_put_schedule(schedule); + return status; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index f34e8827f4..4f52c6d6b0 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -10,10 +10,15 @@ #include "tl_mlx5_mcast.h" #include "tl_mlx5_coll.h" -ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task); +ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task, + ucc_base_coll_args_t *coll_args); ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req); ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team); + +ucc_status_t ucc_tl_mlx5_mcast_schedule_start(ucc_coll_task_t *coll_task); + +ucc_status_t ucc_tl_mlx5_mcast_schedule_finalize(ucc_coll_task_t *coll_task); #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index 3620cf629f..5bf9739084 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -391,9 +391,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com ucc_tl_mlx5_mcast_coll_req_t *req, struct pp_packet* pp) { - ucc_status_t status = UCC_OK; - void *dest; - ucc_memory_type_t mem_type; + ucc_status_t status = UCC_OK; + void *dest; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; ucc_assert(pp->psn >= req->start_psn && pp->psn < req->start_psn + req->num_packets); @@ -402,18 +403,29 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com if (pp->length > 0 ) { dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); - if (comm->cuda_mem_enabled) { - mem_type = UCC_MEMORY_TYPE_CUDA; - } else { - mem_type = UCC_MEMORY_TYPE_HOST; - } + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } - status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length, - mem_type, mem_type); - if (ucc_unlikely(status != UCC_OK)) { - tl_error(comm->lib, "failed to copy buffer"); - return status; + /* for cuda memcpy use nonblocking copy */ + status = ucc_coll_task_get_executor(req->coll_task, &exec); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = (void*) pp->buf; + eargs.copy.dst = dest; + eargs.copy.len = pp->length; + + assert(req->exec_task == NULL); + status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + } else { + memcpy(dest, (void*) pp->buf, pp->length); } } diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 94d336ba6e..39cfa54d74 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -14,8 +14,12 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) { - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_task_t *task = NULL; + ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_task_t *task = NULL; + ucc_coll_task_t *bcast_task; + ucc_schedule_t *schedule; + status = ucc_tl_mlx5_mcast_check_support(coll_args, team); if (UCC_OK != status) { @@ -31,16 +35,39 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, switch (coll_args->args.coll_type) { case UCC_COLL_TYPE_BCAST: - status = ucc_tl_mlx5_mcast_bcast_init(task); + status = ucc_tl_mlx5_get_schedule(tl_team, coll_args, + (ucc_tl_mlx5_schedule_t **)&schedule); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + status = ucc_tl_mlx5_mcast_bcast_init(task, coll_args); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + bcast_task = &(task->super); + + status = ucc_schedule_add_task(schedule, bcast_task); + if (ucc_unlikely(UCC_OK != status)) { + goto free_task; + } + status = ucc_event_manager_subscribe(&schedule->super, + UCC_EVENT_SCHEDULE_STARTED, + bcast_task, + ucc_task_start_handler); + if (ucc_unlikely(UCC_OK != status)) { + goto free_task; + } + schedule->super.post = ucc_tl_mlx5_mcast_schedule_start; + schedule->super.progress = NULL; + schedule->super.finalize = ucc_tl_mlx5_mcast_schedule_finalize; + *task_h = &schedule->super; break; case UCC_COLL_TYPE_ALLGATHER: status = ucc_tl_mlx5_mcast_allgather_init(task); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; default: status = UCC_ERR_NOT_SUPPORTED; @@ -48,8 +75,6 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, goto free_task; } - *task_h = &(task->super); - tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task); return UCC_OK;