Skip to content

Commit

Permalink
TL/MLX5: Addressing Sergey's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jul 16, 2024
1 parent 569af1f commit 75576ee
Show file tree
Hide file tree
Showing 20 changed files with 119 additions and 106 deletions.
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall_mkeys.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall_mkeys.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
struct rdma_cm_id *id;
struct rdma_event_channel *channel;
ucc_mpool_t compl_objects_mp;
ucc_mpool_t mcast_req_mp;
ucc_list_link_t pending_nacks_list;
ucc_rcache_t *rcache;
ucc_tl_mlx5_mcast_ctx_params_t params;
Expand All @@ -179,7 +180,6 @@ typedef struct ucc_tl_mlx5_mcast_context {
ucc_thread_mode_t tm;
ucc_tl_mlx5_mcast_coll_context_t mcast_context;
ucc_tl_mlx5_mcast_context_config_t cfg;
ucc_mpool_t req_mp;
int mcast_enabled;
int mcast_ctx_ready;
ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx;
Expand Down
164 changes: 80 additions & 84 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reset_reliablity(ucc_tl_mlx5_mcast_
static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mcast_coll_req_t *req)
{
ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm;
void *dest;
char *dest;

ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);

Expand Down Expand Up @@ -219,33 +219,34 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_staging_based_allgather(ucc_tl_m

if (MCAST_ALLGATHER_IN_PROGRESS(req, comm)) {
return UCC_INPROGRESS;
} else {
if (ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme) {
if (!req->barrier_req) {
// mcast operations are done and now go to barrier
status = comm->service_coll.barrier_post(comm->p2p_ctx, &req->barrier_req);
if (status != UCC_OK) {
return status;
}
tl_trace(comm->lib, "mcast operations are done and now go to barrier");
return UCC_INPROGRESS;
} else {
status = comm->service_coll.coll_test(req->barrier_req);
if (status == UCC_OK) {
req->barrier_req = NULL;
tl_trace(comm->lib, "barrier at the end of mcast allgather is completed");
} else {
return status;
}
}
}

if (ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme) {
/* mcast operations are all done, now wait until all the processes
* are done with their mcast operations */
if (!req->barrier_req) {
// mcast operations are done and now go to barrier
status = comm->service_coll.barrier_post(comm->p2p_ctx, &req->barrier_req);
if (status != UCC_OK) {
return status;
}
tl_trace(comm->lib, "mcast operations are done and now go to barrier");
}

/* this task is completed */
return UCC_OK;
status = comm->service_coll.coll_test(req->barrier_req);
if (status == UCC_OK) {
req->barrier_req = NULL;
tl_trace(comm->lib, "barrier at the end of mcast allgather is completed");
} else {
return status;
}
}

/* this task is completed */
return UCC_OK;
}

static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
static inline ucc_status_t ucc_tl_mlx5_mcast_allgather_test(ucc_tl_mlx5_mcast_coll_req_t* req)
{
ucc_status_t status;

Expand All @@ -272,7 +273,49 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_co

/**
 * Post callback for an mcast allgather task.
 *
 * This function performs no work of its own: it does not inspect the task
 * and unconditionally reports success.
 *
 * @param coll_task  collective task being posted (not used here)
 *
 * @return UCC_OK always
 */
ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
{
    return UCC_OK;
}

/**
 * Progress callback for an mcast allgather task.
 *
 * A request is only progressed when its ag_counter equals the communicator's
 * allgather under_progress_counter; otherwise the call returns without
 * touching the task. When the request's turn comes, it is tested once:
 *   - UCC_INPROGRESS: nothing changes, the task stays in progress;
 *   - UCC_OK: the task status is set to UCC_OK and under_progress_counter
 *     is advanced;
 *   - any other status: the error is logged, stored in the task status and
 *     any send/receive buffer registrations held by the request are dropped.
 * In both terminal cases the request is returned to its memory pool and the
 * task's req_handle is cleared.
 *
 * @param coll_task  collective task embedded in a ucc_tl_mlx5_task_t whose
 *                   coll_mcast.req_handle must be non-NULL
 */
void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
{
    ucc_tl_mlx5_task_t           *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
    ucc_tl_mlx5_mcast_coll_req_t *req  = task->coll_mcast.req_handle;
    ucc_status_t                  status;

    ucc_assert(req != NULL);

    if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
        /* it is not this task's turn for progress */
        ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
        return;
    }

    status = ucc_tl_mlx5_mcast_allgather_test(req);
    if (UCC_INPROGRESS == status) {
        return;
    }

    if (UCC_OK == status) {
        coll_task->status = UCC_OK;
        req->comm->allgather_comm.under_progress_counter++;
    } else {
        tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", status);
        coll_task->status = status;
        if (req->rreg) {
            ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
            req->rreg = NULL;
        }
        if (req->recv_rreg) {
            ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg);
            req->recv_rreg = NULL;
        }
    }

    /* Release the request exactly once and drop the task's reference to it.
     * The handle must be cleared on the failure path as well:
     * ucc_tl_mlx5_task_finalize() returns a non-NULL request to the pool,
     * so a stale handle here would lead to the same object being put twice. */
    ucc_mpool_put(req);
    task->coll_mcast.req_handle = NULL;
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
ucc_coll_task_t *coll_task = &(task->super);
ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task);
ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast;
ucc_coll_args_t *args = &TASK_ARGS(task);
Expand All @@ -286,23 +329,24 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
ucc_tl_mlx5_mcast_reg_t *reg = NULL;
ucc_tl_mlx5_mcast_coll_req_t *req;


if (!data_size) {
coll_task->status = UCC_OK;
return ucc_task_complete(coll_task);
}

task->coll_mcast.req_handle = NULL;

tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %ld, comm %d, "
tl_trace(comm->lib, "MCAST allgather init, sbuf %p, rbuf %p, size %ld, comm %d, "
"comm_size %d, counter %d",
sbuf, rbuf, data_size, comm->comm_id, comm->commsize, comm->allgather_comm.coll_counter);

req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
req = ucc_mpool_get(&comm->ctx->mcast_req_mp);
if (!req) {
tl_warn(comm->lib, "malloc failed");
tl_error(comm->lib, "failed to get a mcast req");
status = UCC_ERR_NO_MEMORY;
goto failed;
}
memset(req, 0, sizeof(ucc_tl_mlx5_mcast_coll_req_t));

req->comm = comm;
req->ptr = sbuf;
Expand Down Expand Up @@ -331,15 +375,15 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
goto failed;
}

req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;
req->last_pkt_len = req->length % comm->max_per_packet;

ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet);

if (req->proto == MCAST_PROTO_ZCOPY) {
/* register the send buffer */
status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, &reg);
if (UCC_OK != status) {
ucc_free(req);
tl_error(comm->lib, "sendbuf registeration failed");
goto failed;
}
req->rreg = reg;
Expand All @@ -362,63 +406,15 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)

task->coll_mcast.req_handle = req;
coll_task->status = UCC_INPROGRESS;
task->super.post = ucc_tl_mlx5_mcast_allgather_start;
task->super.progress = ucc_tl_mlx5_mcast_allgather_progress;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);

failed:
tl_warn(UCC_TASK_LIB(task), "mcast start allgather failed:%d", status);
coll_task->status = status;
return ucc_task_complete(coll_task);
}

void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle;
ucc_status_t status;

if (req != NULL) {
if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
/* it is not this task's turn for progress */
ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
return;
}

status = ucc_tl_mlx5_mcast_test_allgather(task->coll_mcast.req_handle);
if (UCC_INPROGRESS == status) {
return;
} else if (UCC_OK == status) {
coll_task->status = UCC_OK;
req->comm->allgather_comm.under_progress_counter++;
ucc_free(req);
task->coll_mcast.req_handle = NULL;
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", status);
coll_task->status = status;
if (req->rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
req->rreg = NULL;
}
if (req->recv_rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg);
req->recv_rreg = NULL;
}
ucc_free(req);
ucc_task_complete(coll_task);
}
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed, mcast coll not initialized");
coll_task->status = UCC_ERR_NO_RESOURCE;
ucc_task_complete(coll_task);
tl_warn(UCC_TASK_LIB(task), "mcast init allgather failed:%d", status);
if (req) {
ucc_mpool_put(req);
}

return;
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_allgather_start;
task->super.progress = ucc_tl_mlx5_mcast_allgather_progress;

return UCC_OK;
return status;
}

6 changes: 4 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,16 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t
buf, size, root, comm->comm_id, comm->commsize, comm->rank ==
root, comm->psn );

req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
req = ucc_mpool_get(&comm->ctx->mcast_req_mp);
if (!req) {
tl_error(comm->lib, "failed to get mcast req");
return UCC_ERR_NO_MEMORY;
}
memset(req, 0, sizeof(ucc_tl_mlx5_mcast_coll_req_t));

status = ucc_tl_mlx5_mcast_prepare_bcast(buf, size, root, comm, req);
if (UCC_OK != status) {
ucc_free(req);
ucc_mpool_put(req);
return status;
}

Expand Down
12 changes: 12 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
device_attr.max_cq, device_attr.max_cqe);

ctx->max_qp_wr = device_attr.max_qp_wr;

status = ucc_mpool_init(&ctx->compl_objects_mp, 0, sizeof(ucc_tl_mlx5_mcast_p2p_completion_obj_t), 0,
UCC_CACHE_LINE_SIZE, 8, UINT_MAX,
&ucc_coll_task_mpool_ops,
Expand All @@ -249,6 +250,17 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont
goto error;
}

status = ucc_mpool_init(&ctx->mcast_req_mp, 0, sizeof(ucc_tl_mlx5_mcast_coll_req_t), 0,
UCC_CACHE_LINE_SIZE, 8, UINT_MAX,
&ucc_coll_task_mpool_ops,
UCC_THREAD_SINGLE,
"ucc_tl_mlx5_mcast_coll_req_t");
if (ucc_unlikely(UCC_OK != status)) {
tl_warn(lib, "failed to initialize mcast_req_mp mpool");
status = UCC_ERR_NO_MEMORY;
goto error;
}

ctx->rcache = NULL;
status = ucc_tl_mlx5_mcast_setup_rcache(ctx);
if (UCC_OK != status) {
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -98,7 +98,7 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {

{"MCAST_ONE_SIDED_RELIABILITY_ENABLE", "1", "Enable one sided reliability for mcast",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_INT},
UCC_CONFIG_TYPE_BOOL},

{NULL}};

Expand Down
9 changes: 6 additions & 3 deletions src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args,
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
tl_trace(team->context->lib, "mcast collective not supported for active sets");
if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args) ||
UCC_IS_INPLACE(coll_args->args)) {
tl_trace(team->context->lib, "mcast collective not supported");
return UCC_ERR_NOT_SUPPORTED;
}

Expand All @@ -44,6 +45,8 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args,
break;
default:
status = UCC_ERR_NOT_SUPPORTED;
tl_trace(team->context->lib, "mcast not supported for this collective type");
goto free_task;
}

*task_h = &(task->super);
Expand All @@ -64,7 +67,7 @@ ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task)

if (req != NULL) {
ucc_assert(coll_task->status != UCC_INPROGRESS);
ucc_free(req);
ucc_mpool_put(req);
tl_trace(UCC_TASK_LIB(task), "finalizing an mcast task %p", task);
task->coll_mcast.req_handle = NULL;
}
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_context.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_dm.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_dm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
Loading

0 comments on commit 75576ee

Please sign in to comment.