diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 663ee636ed..3772f55616 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -441,6 +441,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req { ucc_service_coll_req_t *allgather_rkeys_req; ucc_service_coll_req_t *barrier_req; void *recv_rreg; + ucc_ee_executor_task_t *exec_task; + ucc_coll_task_t *coll_task; } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -555,6 +557,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_ return UCC_OK; } +#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \ + if (_etask != NULL) { \ + status = ucc_ee_executor_task_test(_etask); \ + if (status > 0) { \ + return status; \ + } \ + ucc_ee_executor_task_finalize(_etask); \ + _etask = NULL; \ + if (ucc_unlikely(status < 0)) { \ + tl_error(_lib, _errmsg); \ + return status; \ + } \ + } \ +} while(0) + ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index b6fbe84e3d..cf813fd5af 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -33,6 +33,10 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ return status; } + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } + comm->bcast_comm.n_mcast_reliable++; for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) { @@ -267,7 +271,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) return ucc_task_complete(coll_task); } - coll_task->status = status; + ucc_assert(task->coll_mcast.req_handle != NULL); + + coll_task->status = status; + task->coll_mcast.req_handle->coll_task = coll_task; return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); } @@ -333,6 +340,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) { task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; + task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR; return UCC_OK; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index f34e8827f4..ccc563ecc7 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -16,4 +16,5 @@ ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req); ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team); + #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index 3620cf629f..8031af6dc0 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -391,9 +391,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com ucc_tl_mlx5_mcast_coll_req_t *req, struct pp_packet* pp) { - ucc_status_t status = UCC_OK; - void *dest; - ucc_memory_type_t mem_type; + ucc_status_t status = UCC_OK; + void *dest; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; ucc_assert(pp->psn >= req->start_psn && pp->psn < req->start_psn + req->num_packets); @@ -402,19 +403,30 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com if (pp->length > 0 ) { dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); - - if (comm->cuda_mem_enabled) { - mem_type = UCC_MEMORY_TYPE_CUDA; - } else { - mem_type = UCC_MEMORY_TYPE_HOST; + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); } - status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length, - mem_type, mem_type); + /* for cuda copy, exec is nonblocking but for host copy it is blocking */ + status = ucc_coll_task_get_executor(req->coll_task, &exec); if (ucc_unlikely(status != UCC_OK)) { - tl_error(comm->lib, "failed to copy buffer"); return status; } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = (void*) pp->buf; + eargs.copy.dst = dest; + eargs.copy.len = pp->length; + + assert(req->exec_task == NULL); + status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + if (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to progress the memcpy", req->exec_task, comm->lib); + } } comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp; diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 94d336ba6e..aabdbf8010 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -14,8 +14,8 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) { - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_task_t *task = NULL; + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_task_t *task = NULL; status = ucc_tl_mlx5_mcast_check_support(coll_args, team); if (UCC_OK != status) { @@ -35,12 +35,14 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; case UCC_COLL_TYPE_ALLGATHER: status = ucc_tl_mlx5_mcast_allgather_init(task); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; default: status = UCC_ERR_NOT_SUPPORTED; @@ -48,8 +50,6 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, goto free_task; } - *task_h = &(task->super); - tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task); return UCC_OK;