Skip to content

Commit

Permalink
TL/MLX5: add nonblocking cudaMemcpy support
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Dec 16, 2024
1 parent 73651ea commit 4104cc9
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 16 deletions.
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_service_coll_req_t *allgather_rkeys_req;
ucc_service_coll_req_t *barrier_req;
void *recv_rreg;
ucc_ee_executor_task_t *exec_task;
ucc_coll_task_t *coll_task;
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
Expand Down Expand Up @@ -555,6 +557,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_
return UCC_OK;
}

#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \
if (_etask != NULL) { \
status = ucc_ee_executor_task_test(_etask); \
if (status > 0) { \
return status; \
} \
ucc_ee_executor_task_finalize(_etask); \
_etask = NULL; \
if (ucc_unlikely(status < 0)) { \
tl_error(_lib, _errmsg); \
return status; \
} \
} \
} while(0)

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
Expand Down
10 changes: 9 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_
return status;
}

while (req->exec_task != NULL) {
EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib);
}

comm->bcast_comm.n_mcast_reliable++;

for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) {
Expand Down Expand Up @@ -267,7 +271,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
return ucc_task_complete(coll_task);
}

coll_task->status = status;
ucc_assert(task->coll_mcast.req_handle != NULL);

coll_task->status = status;
task->coll_mcast.req_handle->coll_task = coll_task;

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
}
Expand Down Expand Up @@ -333,6 +340,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
task->super.progress = ucc_tl_mlx5_mcast_collective_progress;
task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR;

return UCC_OK;
}
1 change: 1 addition & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);

#endif
34 changes: 23 additions & 11 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
ucc_tl_mlx5_mcast_coll_req_t *req,
struct pp_packet* pp)
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_status_t status = UCC_OK;
void *dest;
ucc_ee_executor_task_args_t eargs;
ucc_ee_executor_t *exec;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);

Expand All @@ -402,19 +403,30 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com

if (pp->length > 0 ) {
dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
while (req->exec_task != NULL) {
EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib);
}

status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
/* for cuda copy, exec is nonblocking but for host copy it is blocking */
status = ucc_coll_task_get_executor(req->coll_task, &exec);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = (void*) pp->buf;
eargs.copy.dst = dest;
eargs.copy.len = pp->length;

assert(req->exec_task == NULL);
status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

if (req->exec_task != NULL) {
EXEC_TASK_TEST("failed to progress the memcpy", req->exec_task, comm->lib);
}
}

comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp;
Expand Down
8 changes: 4 additions & 4 deletions src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;

status = ucc_tl_mlx5_mcast_check_support(coll_args, team);
if (UCC_OK != status) {
Expand All @@ -35,21 +35,21 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args,
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}
*task_h = &(task->super);
break;
case UCC_COLL_TYPE_ALLGATHER:
status = ucc_tl_mlx5_mcast_allgather_init(task);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}
*task_h = &(task->super);
break;
default:
status = UCC_ERR_NOT_SUPPORTED;
tl_trace(team->context->lib, "mcast not supported for this collective type");
goto free_task;
}

*task_h = &(task->super);

tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task);

return UCC_OK;
Expand Down

0 comments on commit 4104cc9

Please sign in to comment.