From 43dd1d76ee7be15b259f14f8f59132ac131723ff Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Aug 2024 13:54:51 +0300 Subject: [PATCH] TL/MLX5: clean and fix after rebase lintrunner cleaning --- src/components/tl/mlx5/alltoall/alltoall.h | 6 +- .../tl/mlx5/alltoall/alltoall_coll.c | 255 ++++++++++-------- .../tl/mlx5/alltoall/alltoall_mkeys.c | 12 +- src/components/tl/mlx5/tl_mlx5.c | 27 +- src/components/tl/mlx5/tl_mlx5.h | 6 +- src/components/tl/mlx5/tl_mlx5_dm.c | 39 ++- src/components/tl/mlx5/tl_mlx5_wqe.c | 3 +- test/mpi/buffer.cc | 7 - 8 files changed, 194 insertions(+), 161 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall.h b/src/components/tl/mlx5/alltoall/alltoall.h index 92b69e2e6f..c49a0ff9cf 100644 --- a/src/components/tl/mlx5/alltoall/alltoall.h +++ b/src/components/tl/mlx5/alltoall/alltoall.h @@ -58,9 +58,9 @@ typedef struct ucc_tl_mlx5_alltoall_node { struct mlx5dv_mkey *team_recv_mkey; void *umr_entries_buf; struct ibv_mr *umr_entries_mr; - int fanin_index; - int fanin_dist; - int fanin_max_dist; + int fanin_index; + int fanin_dist; + int fanin_max_dist; } ucc_tl_mlx5_alltoall_node_t; typedef struct alltoall_net_ctrl { diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 324e9909b8..10a08989a9 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -83,7 +83,6 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) ucc_tl_mlx5_dm_chunk_t *dm = (ucc_tl_mlx5_dm_chunk_t *)wcs[i].wr_id; dm->task->alltoall.blocks_completed++; dm->completed_sends++; - /* printf("returning dm %p to pool\n", (void*)team->work_completion[i].wr_id); */ if (dm->posted_all && dm->completed_sends == dm->posted_sends) { ucc_mpool_put(dm); } @@ -95,20 +94,19 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, ucc_tl_mlx5_schedule_t *task) { - ucc_tl_mlx5_alltoall_t * a2a = team->a2a; + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int seq_index = task->alltoall.seq_index; int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; - int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; - int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; - int *dist = &a2a->node.fanin_dist; - int size = a2a->node.sbgp->group_size; - int seq_num = task->alltoall.seq_num; - int c_flag = 0; - int polls; - int peer, vpeer, pos, i; + int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls; + int peer, vpeer, pos, i; ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; - while (*dist <= a2a->node.fanin_max_dist) { if (vrank % *dist == 0) { pos = (vrank / *dist) % radix; @@ -119,7 +117,7 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, a2a->node.fanin_index = radix; break; } - peer = (vpeer + a2a->node.asr_rank) % size; + peer = (vpeer + a2a->node.asr_rank) % size; ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, peer); for (polls = 0; polls < npolls; polls++) { if (ctrl_v->seq_num == seq_num) { @@ -133,22 +131,26 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, } } else { ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = seq_num; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_fanin_done", 0); return UCC_OK; } } *dist *= radix; a2a->node.fanin_index = 1; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done", + 0); } UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); for (i = 0; i < size; i++) { ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); - ucc_assert(i == a2a->node.sbgp->group_rank || ctrl_v->seq_num == seq_num); + ucc_assert(i == a2a->node.sbgp->group_rank || + ctrl_v->seq_num == seq_num); c_flag |= ctrl_v->mkey_cache_flag; } ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = c_flag; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_retrieve_cache_flags_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_retrieve_cache_flags_done", 0); return UCC_OK; } @@ -224,7 +226,9 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) a2a->node.fanin_dist = 1; for (a2a->node.fanin_max_dist = 1; a2a->node.fanin_max_dist < a2a->node.sbgp->group_size; - a2a->node.fanin_max_dist *= UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix){} + a2a->node.fanin_max_dist *= + UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix) { + } if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) { tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete"); @@ -279,9 +283,11 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_wait-on-data_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_start", 0); } else { - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", + 0); } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); @@ -305,8 +311,10 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task) coll_task->status = UCC_INPROGRESS; return; } - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_wait-on-data_complete", 0); - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_complete", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", + 0); } if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) { @@ -331,20 +339,21 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) coll_task->super.status = UCC_INPROGRESS; task->alltoall.started = 0; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); - // despite while statement in poll_umr_cq, non blocking because have independent cq, // will be finished in a finite time - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_start", + 0); ucc_tl_mlx5_populate_send_recv_mkeys(team, task); - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_end", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_end", + 0); UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); //Reset atomic notification counter to 0 #if ATOMIC_IN_MEMIC tl_mlx5_atomic_t zero = 0; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_start", + 0); if (0 != ibv_memcpy_to_dm(a2a->net.atomic.counters, task->alltoall.seq_index * sizeof(tl_mlx5_atomic_t), @@ -394,10 +403,12 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); return status; } - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_send_posted", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_barrier_send_posted", 0); } coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", + 0); return ucc_task_complete(coll_task); } return UCC_OK; @@ -420,11 +431,12 @@ static void ucc_tl_mlx5_asr_barrier_progress(ucc_coll_task_t *coll_task) } } -ucc_tl_mlx5_dm_chunk_t* ucc_tl_mlx5_a2a_wait_for_dm_chunk (ucc_tl_mlx5_schedule_t *task) +ucc_tl_mlx5_dm_chunk_t * +ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task) { - ucc_base_lib_t * lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t * team = TASK_TEAM(task); - ucc_tl_mlx5_dm_chunk_t *dm = NULL; + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(&task->super); + ucc_tl_mlx5_dm_chunk_t *dm = NULL; dm = ucc_mpool_get(&team->dm_pool); if (!dm) { @@ -444,30 +456,33 @@ ucc_tl_mlx5_dm_chunk_t* ucc_tl_mlx5_a2a_wait_for_dm_chunk (ucc_tl_mlx5_schedule_ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_base_lib_t * lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t * team = TASK_TEAM(task); - ucc_tl_mlx5_a2a_t * a2a = team->a2a; + ucc_base_lib_t *lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super); + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int node_size = a2a->node.sbgp->group_size; int net_size = a2a->net.sbgp->group_size; int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; - int block_h = task->alltoall.block_height; - int block_w = task->alltoall.block_width; - int col_msgsize = task->alltoall.msg_size * block_w * node_size; - int line_msgsize = task->alltoall.msg_size * block_h * node_size; - int block_msgsize = block_h * block_w * task->alltoall.msg_size; - ucc_status_t status = UCC_OK; - int node_grid_w = node_size / block_w; - int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); - int seq_index = task->alltoall.seq_index; - int node_idx, block_row = 0, block_col = 0, block_idx, rank, dest_rank, cyc_rank; - uint64_t src_addr, remote_addr = 0; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); + int seq_index = task->alltoall.seq_index; + int block_row = 0, block_col = 0; + uint64_t remote_addr = 0; ucc_tl_mlx5_dm_chunk_t *dm = NULL; - int i,j, k, send_to_self; int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; - int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int nbr_serialized_batches = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; + uint64_t src_addr; coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -481,32 +496,38 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) for (j = 0; j < nbr_batches_per_passage; j++) { for (node_idx = 0; node_idx < net_size; node_idx++) { - cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; - dest_rank = a2a->net.rank_map[cyc_rank]; + cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; + dest_rank = a2a->net.rank_map[cyc_rank]; send_to_self = (cyc_rank == a2a->net.sbgp->group_rank); - if (tl_mlx5_barrier_flag(task, cyc_rank) != task->alltoall.seq_num) { + if (tl_mlx5_barrier_flag(task, cyc_rank) != + task->alltoall.seq_num) { continue; } dm = NULL; - if (!send_to_self && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { - dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk (task); + if (!send_to_self && + task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { + dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk(task); if (status != UCC_OK) { return status; } } - send_start(team, cyc_rank); for (i = 0; i < nbr_serialized_batches; i++) { + send_start(team, cyc_rank); + for (i = 0; i < nbr_serialized_batches; i++) { for (k = 0; - k < batch_size && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks; - k++, task->alltoall.op->blocks_sent[cyc_rank]++) { + k < batch_size && + task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks; + k++, task->alltoall.op->blocks_sent[cyc_rank]++) { block_idx = task->alltoall.op->blocks_sent[cyc_rank]; block_col = block_idx % node_grid_w; block_row = block_idx / node_grid_w; - src_addr = (uintptr_t)(op_msgsize * seq_index + node_msgsize * dest_rank + - col_msgsize * block_col + block_msgsize * block_row); + src_addr = (uintptr_t)( + op_msgsize * seq_index + node_msgsize * dest_rank + + col_msgsize * block_col + block_msgsize * block_row); if (send_to_self || !k) { - remote_addr = - (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * block_col + line_msgsize * block_row); + remote_addr = (uintptr_t)(op_msgsize * seq_index + + node_msgsize * rank + + block_msgsize * block_col + + line_msgsize * block_row); } if (send_to_self) { status = ucc_tl_mlx5_post_transpose( @@ -514,7 +535,9 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) a2a->node.ops[seq_index].send_mkeys[0]->lkey, a2a->net.rkeys[cyc_rank], src_addr, remote_addr, task->alltoall.msg_size, block_w, block_h, - (block_row == 0 && block_col == 0) ? IBV_SEND_SIGNALED : 0); + (block_row == 0 && block_col == 0) + ? IBV_SEND_SIGNALED + : 0); if (UCC_OK != status) { return status; } @@ -522,7 +545,8 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) status = ucc_tl_mlx5_post_transpose( tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, - team->dm_mr->rkey, src_addr, dm->addr + k * block_msgsize, + team->dm_mr->rkey, src_addr, + dm->addr + k * block_msgsize, task->alltoall.msg_size, block_w, block_h, 0); if (UCC_OK != status) { return status; @@ -535,8 +559,8 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) team->dm_mr->lkey, remote_addr, a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); if (status != UCC_OK) { - tl_error(lib, "failed sending block [%d,%d,%d]", node_idx, block_row, - block_col); + tl_error(lib, "failed sending block [%d,%d,%d]", + node_idx, block_row, block_col); return status; } dm->posted_sends++; @@ -551,18 +575,23 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) } if (task->alltoall.op->blocks_sent[cyc_rank] == node_nbr_blocks) { send_start(team, cyc_rank); - status = send_atomic(a2a, cyc_rank, tl_mlx5_atomic_addr(task, cyc_rank), - tl_mlx5_atomic_rkey(task, cyc_rank)); + status = send_atomic(a2a, cyc_rank, + tl_mlx5_atomic_addr(task, cyc_rank), + tl_mlx5_atomic_rkey(task, cyc_rank)); - if (UCC_OK == status) { - status = send_done(team, cyc_rank); - } - task->alltoall.op->blocks_sent[cyc_rank] = 1; - task->alltoall.started++; - if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "Failed sending atomic to node [%d]", - cyc_rank); - return status; + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "Failed sending atomic to node [%d]", cyc_rank); + return status; + } + status = send_done(team, cyc_rank); + if (UCC_OK != status) { + tl_error(lib, "Failed sending atomic to node %d", node_idx); + return status; + } + task->alltoall.op->blocks_sent[cyc_rank]++; + task->alltoall.started++; + } } } if (!task->alltoall.send_blocks_enqueued) { @@ -602,7 +631,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; int mkey_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team); - int block_size = task->alltoall.block_height; //ucc_assert(task->alltoall.block_height == task->alltoall.block_width); + int block_size = task->alltoall.block_height; int col_msgsize = msg_size * block_size * node_size; int block_msgsize = SQUARED(block_size) * msg_size; int block_size_leftovers_side = node_size % block_size; @@ -684,7 +713,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) dm = ucc_mpool_get(&team->dm_pool); send_start(team, cyc_rank); } - dm_addr = dm->addr; + dm_addr = dm->addr; dm->task = task; status = ucc_tl_mlx5_post_transpose( @@ -817,28 +846,30 @@ static inline int block_size_fits(size_t msgsize, int height, int width) size_t t1 = power2(ucc_max(msgsize, 8)); size_t tsize = height * ucc_max(power2(width) * t1, MAX_MSG_SIZE); - return tsize <= MAX_TRANSPOSE_SIZE && msgsize <= 128 && height <= 64 && width <= 64; + return tsize <= MAX_TRANSPOSE_SIZE && msgsize <= 128 && height <= 64 && + width <= 64; } -static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, int force_longer, int force_wider, int* block_height, int* block_width) +static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, + int force_longer, int force_wider, + int *block_height, int *block_width) { int h_best = 1; int w_best = 1; - int h,w; + int h, w; for (h = 1; h <= 64; h++) { if (force_regular && (ppn % h)) { continue; } for (w = 1; w <= 64; w++) { - if ((force_regular && (ppn % w)) - || (force_wider && (w < h)) - || (force_longer && (w > h))) { + if ((force_regular && (ppn % w)) || (force_wider && (w < h)) || + (force_longer && (w > h))) { continue; } - if (block_size_fits(msgsize, h, w) - && h*w >= h_best*w_best) { - if ( h*w > h_best*w_best || abs(h/w-1) < abs(h_best/w_best-1)) { + if (block_size_fits(msgsize, h, w) && h * w >= h_best * w_best) { + if (h * w > h_best * w_best || + abs(h / w - 1) < abs(h_best / w_best - 1)) { h_best = h; w_best = w; } @@ -850,7 +881,6 @@ static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, *block_width = w_best; } - UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, (coll_args, team, task_h), ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, @@ -863,15 +893,15 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, == a2a->node.asr_rank); int n_tasks = is_asr ? 5 : 3; int curr_task = 0; - int ppn = tl_team->a2a->node.sbgp->group_size; - ucc_tl_mlx5_lib_config_t* cfg = &UCC_TL_MLX5_TEAM_LIB(tl_team)->cfg; - int i; + int ppn = tl_team->a2a->node.sbgp->group_size; + ucc_tl_mlx5_lib_config_t *cfg = &UCC_TL_MLX5_TEAM_LIB(tl_team)->cfg; ucc_schedule_t *schedule; ucc_tl_mlx5_schedule_t *task; size_t msg_size; int block_size, i; ucc_coll_task_t *tasks[5]; ucc_status_t status; + size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; if (UCC_IS_INPLACE(coll_args->args)) { return UCC_ERR_NOT_SUPPORTED; @@ -916,36 +946,40 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, a2a->sequence_number += 1; if (a2a->requested_block_size) { - task->alltoall.block_height = task->alltoall.block_width = a2a->requested_block_size; - if (cfg->force_regular - && ((ppn % task->alltoall.block_height) - || (ppn % task->alltoall.block_width))){ - tl_debug(UCC_TL_TEAM_LIB(tl_team), "the requested block size implies irregular case" - "consider changing the block size or turn off the config FORCE_REGULAR"); + task->alltoall.block_height = task->alltoall.block_width = + a2a->requested_block_size; + if (cfg->force_regular && ((ppn % task->alltoall.block_height) || + (ppn % task->alltoall.block_width))) { + tl_debug(UCC_TL_TEAM_LIB(tl_team), + "the requested block size implies irregular case" + "consider changing the block size or turn off the config " + "FORCE_REGULAR"); return UCC_ERR_INVALID_PARAM; } } else { if (!cfg->force_regular) { if (!(cfg->force_longer && cfg->force_wider)) { - tl_debug(UCC_TL_TEAM_LIB(tl_team), "turning off FORCE_REGULAR automatically forces the blocks to be square"); + tl_debug(UCC_TL_TEAM_LIB(tl_team), + "turning off FORCE_REGULAR automatically forces the " + "blocks to be square"); cfg->force_longer = 1; - cfg->force_wider = 1; + cfg->force_wider = 1; } } - get_block_dimensions(ppn, task->alltoall.msg_size, - cfg->force_regular, - cfg->force_longer, - cfg->force_wider, - &task->alltoall.block_height, &task->alltoall.block_width); + get_block_dimensions(ppn, task->alltoall.msg_size, cfg->force_regular, + cfg->force_longer, cfg->force_wider, + &task->alltoall.block_height, + &task->alltoall.block_width); } - tl_debug(UCC_TL_TEAM_LIB(tl_team), "block dimensions: [%d,%d]", task->alltoall.block_height, task->alltoall.block_width); + tl_debug(UCC_TL_TEAM_LIB(tl_team), "block dimensions: [%d,%d]", + task->alltoall.block_height, task->alltoall.block_width); //todo following section correct assuming homogenous PPN across all nodes task->alltoall.num_of_blocks_columns = (a2a->node.sbgp->group_size % task->alltoall.block_height) - ? ucc_div_round_up(a2a->node.sbgp->group_size, task->alltoall.block_height) + ? ucc_div_round_up(a2a->node.sbgp->group_size, + task->alltoall.block_height) : 0; - task->alltoall.block_size = block_size; // TODO remove for connectX-7 - this is mkey_entry->stride (count+skip) limitation - only 16 bits if (task->alltoall @@ -955,8 +989,7 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, (1ULL << 16); // TODO We need to query this from device (or device type) and not user hardcoded values ucc_assert(task->alltoall.block_height == task->alltoall.block_width); - int block_size = task->alltoall.block_height; - size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; + block_size = task->alltoall.block_height; ucc_assert(task->alltoall.block_height == task->alltoall.block_width); diff --git a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c index 3500aed52c..468068024c 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c +++ b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c @@ -299,7 +299,7 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_SEND_MKEY_UPDATE) { repeat_count = nbc ? a2a->net.sbgp->group_size - : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width; + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, send_mem_access_flags, node->ops[seq_index].send_mkeys[i], @@ -314,8 +314,9 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, } if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_RECV_MKEY_UPDATE) { - repeat_count = nbc ? a2a->net.sbgp->group_size - : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height; + repeat_count = + nbc ? a2a->net.sbgp->group_size + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, recv_mem_access_flags, node->ops[seq_index].recv_mkeys[i], @@ -336,7 +337,7 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, { ucc_tl_mlx5_alltoall_node_t *node = &a2a->node; int block_height = req->alltoall.block_height; - int block_width = req->alltoall.block_width; + int block_width = req->alltoall.block_width; size_t msg_size = req->alltoall.msg_size; int nbc = req->alltoall.num_of_blocks_columns; struct ibv_mr *buff = direction_send @@ -349,7 +350,8 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, 0) : MY_RECV_UMR_DATA(req, a2a, 0)); mkey_entry->addr = (uintptr_t)buff->addr; - mkey_entry->bytes_count = (direction_send? block_width : block_height) * msg_size; + mkey_entry->bytes_count = + (direction_send ? block_width : block_height) * msg_size; mkey_entry->bytes_skip = 0; mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey; } else { diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index b8830d4613..d9082fa833 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -28,12 +28,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, dm_buf_num), UCC_CONFIG_TYPE_ULUNITS}, - {"FORCE_REGULAR", "y", "Force the regular case where the block dimensions " - "divide ppn. Requires BLOCK_SIZE=0", - ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular), UCC_CONFIG_TYPE_BOOL}, + {"FORCE_REGULAR", "y", + "Force the regular case where the block dimensions " + "divide ppn. Requires BLOCK_SIZE=0", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular), + UCC_CONFIG_TYPE_BOOL}, {"FORCE_LONGER", "y", "Force the blocks to have more height than width", - ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer), UCC_CONFIG_TYPE_BOOL}, + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer), + UCC_CONFIG_TYPE_BOOL}, {"FORCE_WIDER", "n", "Force the blocks to have more width than height", ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_wider), UCC_CONFIG_TYPE_BOOL}, @@ -118,18 +121,21 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), UCC_CONFIG_TYPE_UINT}, - {"SEND_BATCH_SIZE", "1", "number of blocks that are transposed " - "on the NIC before being sent as a batch to a remote peer", + {"SEND_BATCH_SIZE", "1", + "number of blocks that are transposed " + "on the NIC before being sent as a batch to a remote peer", ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size), UCC_CONFIG_TYPE_UINT}, - {"NBR_SERIALIZED_BATCHES", "1", "number of block batches " - "(within the set of blocks to be sent to a given remote peer) " + {"NBR_SERIALIZED_BATCHES", "1", + "number of block batches " + "(within the set of blocks to be sent to a given remote peer) " "serialized on the same device memory chunk", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches), UCC_CONFIG_TYPE_UINT}, - {"NBR_BATCHES_PER_PASSAGE", "32", "", + {"NBR_BATCHES_PER_PASSAGE", "32", + "number of batches of blocks sent to one remote node before enqueing", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage), UCC_CONFIG_TYPE_UINT}, @@ -155,7 +161,8 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, {"FANIN_NPOLLS", "1000", - "Number of shared memory polling before returning UCC_INPROGRESS during internode FANIN", + "Number of shared memory polling before returning UCC_INPROGRESS during " + "internode FANIN", ucc_offsetof(ucc_tl_mlx5_context_config_t, npolls), UCC_CONFIG_TYPE_UINT}, {NULL}}; diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 20f71d67e8..fe44921004 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -65,8 +65,8 @@ typedef struct ucc_tl_mlx5_lib_config { int nbr_batches_per_passage; int block_batch_size; int force_regular; - int force_longer; - int force_wider; + int force_longer; + int force_wider; } ucc_tl_mlx5_lib_config_t; typedef struct ucc_tl_mlx5_context_config { @@ -103,7 +103,7 @@ typedef struct ucc_tl_mlx5_task ucc_tl_mlx5_task_t; typedef struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t; typedef struct ucc_tl_mlx5_dm_chunk_t { uintptr_t addr; // 0 based offset from the beginning of - // memic_mr (obtained with ibv_reg_dm_mr) + // memic_mr (obtained with ibv_reg_dm_mr) ucc_tl_mlx5_schedule_t *task; int posted_sends; int posted_all; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 23d9aa723d..65ac46fe69 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -77,21 +77,16 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT ucc_tl_mlx5_dm_chunk_t *c = (ucc_tl_mlx5_dm_chunk_t *)obj; ucc_tl_mlx5_team_t *team = ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool); - c->addr = (uintptr_t)PTR_OFFSET( - (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host)? - team->dm_ptr : NULL, - team->dm_offset); - team->dm_offset = team->dm_offset + - UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size - * UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - c->posted_sends = 0; - c->posted_all=0; - c->completed_sends = 0; -} - c->offset = (ptrdiff_t)team->dm_offset; - team->dm_offset = PTR_OFFSET(team->dm_offset, - UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size); + c->addr = (uintptr_t)PTR_OFFSET( + (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host) ? team->dm_ptr : NULL, + team->dm_offset); + c->posted_sends = 0; + c->posted_all = 0; + c->completed_sends = 0; + team->dm_offset = + team->dm_offset + UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size * + UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; } static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = { @@ -230,17 +225,19 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) } status = ucc_tl_mlx5_dm_alloc_reg( - ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size * cfg->block_batch_size, - &cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); + ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, + cfg->dm_buf_size * cfg->block_batch_size, &cfg->dm_buf_num, + &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); if (status != UCC_OK) { goto err_dm_alloc; } team->dm_offset = 0; - // TODO: fix case dm_host=true - status = ucc_mpool_init(&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), - 0, UCC_CACHE_LINE_SIZE, 1, - cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, - ctx->super.super.ucc_context->thread_mode, "mlx5 dm pool"); + // TODO: fix/check the case dm_host=true + ucc_assert(!cfg->dm_host); + status = ucc_mpool_init( + &team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0, + UCC_CACHE_LINE_SIZE, 1, cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, + ctx->super.super.ucc_context->thread_mode, "mlx5 dm pool"); if (status != UCC_OK) { tl_debug(UCC_TL_TEAM_LIB(team), "failed to init dm pool"); goto err_mpool_init; diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index 399b5626ca..e783abf791 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -39,7 +39,8 @@ static inline uint8_t get_umr_mr_flags(uint32_t acc) typedef struct transpose_seg { __be32 element_size; /* 8 bit value */ - __be16 num_cols; /* 7 bit value */ //TODO: from PRM we should have the rows first and then the colls... is this a bug ? + //From PRM we should have the rows first and then the colls. This is probably a naming error + __be16 num_cols; /* 7 bit value */ __be16 num_rows; /* 7 bit value */ __be64 padding; } transpose_seg_t; diff --git a/test/mpi/buffer.cc b/test/mpi/buffer.cc index 91cf1c311b..f31f42c553 100644 --- a/test/mpi/buffer.cc +++ b/test/mpi/buffer.cc @@ -182,13 +182,6 @@ ucc_status_t compare_buffers(void *_rst, void *expected, size_t count, } else { status = memcmp(rst, expected, count*ucc_dt_size(dt)) ? UCC_ERR_NO_MESSAGE : UCC_OK; - // uint8_t* a = (uint8_t*)rst; - // uint8_t* b = (uint8_t*)expected; - // for (int i=0; i