From 29416ff830eadc36351e0019121dbd738e8188a2 Mon Sep 17 00:00:00 2001 From: Samuel Nordmann Date: Sun, 23 Apr 2023 18:53:12 +0300 Subject: [PATCH 1/9] TL/MLX5: implement knomial fanin TL/MLX5: add npolls cfg for FANIN TL/MLX5: knomial fanin TL/MLX5: add prints and profile events TL/MLX5: remove debug prints --- src/components/tl/mlx5/alltoall/alltoall.h | 3 + .../tl/mlx5/alltoall/alltoall_coll.c | 111 ++++++++++++------ src/components/tl/mlx5/tl_mlx5.c | 7 ++ src/components/tl/mlx5/tl_mlx5.h | 2 + src/components/tl/mlx5/tl_mlx5_dm.c | 1 + 5 files changed, 91 insertions(+), 33 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall.h b/src/components/tl/mlx5/alltoall/alltoall.h index 9fd9d787cc..92b69e2e6f 100644 --- a/src/components/tl/mlx5/alltoall/alltoall.h +++ b/src/components/tl/mlx5/alltoall/alltoall.h @@ -58,6 +58,9 @@ typedef struct ucc_tl_mlx5_alltoall_node { struct mlx5dv_mkey *team_recv_mkey; void *umr_entries_buf; struct ibv_mr *umr_entries_mr; + int fanin_index; + int fanin_dist; + int fanin_max_dist; } ucc_tl_mlx5_alltoall_node_t; typedef struct alltoall_net_ctrl { diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 70439263fb..6f4de27e72 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -91,40 +91,66 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, ucc_tl_mlx5_schedule_t *task) { - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int seq_index = task->alltoall.seq_index; - int i; - ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; - - if (a2a->node.sbgp->group_rank != a2a->node.asr_rank) { - ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = - task->alltoall.seq_num; - } else { - for (i = 0; i < a2a->node.sbgp->group_size; i++) { - if (i == a2a->node.sbgp->group_rank) { - continue; - } - ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); - if (ctrl_v->seq_num != task->alltoall.seq_num) { - return UCC_INPROGRESS; - } - } - for (i = 0; i < a2a->node.sbgp->group_size; i++) { - if (i == a2a->node.sbgp->group_rank) { - continue; + ucc_tl_mlx5_a2a_t * a2a = team->a2a; + int seq_index = task->alltoall.seq_index; + int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; + int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls; + int peer, vpeer, pos, i; + ucc_tl_mlx5_a2a_ctrl_t *ctrl_v; + + + while (*dist <= a2a->node.fanin_max_dist) { + if (vrank % *dist == 0) { + pos = (vrank / *dist) % radix; + if (pos == 0) { + while (a2a->node.fanin_index < radix) { + vpeer = vrank + a2a->node.fanin_index * *dist; + if (vpeer >= size) { + a2a->node.fanin_index = radix; + break; + } + peer = (vpeer + a2a->node.asr_rank) % size; + ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, peer); + for (polls = 0; polls < npolls; polls++) { + if (ctrl_v->seq_num == seq_num) { + a2a->node.fanin_index++; + break; + } + } + if (polls == npolls) { + return UCC_INPROGRESS; + } + } + } else { + ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = seq_num; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); + return UCC_OK; } - ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); - ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag |= - ctrl_v->mkey_cache_flag; 
} + *dist *= radix; + a2a->node.fanin_index = 1; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done", 0); } UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); + for (i = 0; i < size; i++) { + ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); + ucc_assert(i == a2a->node.sbgp->group_rank || ctrl_v->seq_num == seq_num); + c_flag |= ctrl_v->mkey_cache_flag; + } + ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = c_flag; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_retrieve_cache_flags_done", 0); return UCC_OK; } /* Each rank registers sbuf and rbuf and place the registration data - in the shared memory location. Next, all rank in node nitify the - ASR the registration data is ready using SHM Fanin */ + in the shared memory location. Next, all rank in node notify the + ASR that the registration data is ready using SHM Fanin */ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) { ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); @@ -137,7 +163,7 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) ucc_tl_mlx5_rcache_region_t *send_ptr; ucc_tl_mlx5_rcache_region_t *recv_ptr; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_reg_start", 0); tl_debug(UCC_TASK_LIB(task), "register memory buffers"); coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -187,11 +213,18 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) /* Start fanin */ ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = flag; ucc_tl_mlx5_update_mkeys_entries(a2a, task, flag); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_reg_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_start", 0); + + a2a->node.fanin_index = 1; + a2a->node.fanin_dist = 1; + for (a2a->node.fanin_max_dist = 1; + a2a->node.fanin_max_dist < a2a->node.sbgp->group_size; + a2a->node.fanin_max_dist *= UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix){} if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) { tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete"); coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); return ucc_task_complete(coll_task); } @@ -204,7 +237,6 @@ void ucc_tl_mlx5_reg_fanin_progress(ucc_coll_task_t *coll_task) ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task); - ucc_assert(team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank); if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) { tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete"); coll_task->status = UCC_OK; @@ -242,7 +274,11 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_wait-on-data_start", 0); + } else { + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); return UCC_OK; @@ -265,6 +301,8 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task) coll_task->status = UCC_INPROGRESS; return; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, 
"mlx5_alltoall_wait-on-data_complete", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); } if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) { @@ -289,14 +327,20 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) coll_task->super.status = UCC_INPROGRESS; task->alltoall.started = 0; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); + // despite while statement in poll_umr_cq, non blocking because have independent cq, // will be finished in a finite time + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_start", 0); ucc_tl_mlx5_populate_send_recv_mkeys(team, task); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_end", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); //Reset atomic notification counter to 0 #if ATOMIC_IN_MEMIC tl_mlx5_atomic_t zero = 0; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_start", 0); if (0 != ibv_memcpy_to_dm(a2a->net.atomic.counters, task->alltoall.seq_index * sizeof(tl_mlx5_atomic_t), @@ -304,6 +348,7 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) tl_error(UCC_TASK_LIB(task), "failed to reset atomic in memic"); return UCC_ERR_NO_MESSAGE; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_done", 0); #else a2a->net.atomic.counters[task->alltoall.seq_index] = 0; #endif @@ -342,13 +387,13 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) status = send_done(team, i); } if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); + tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); return status; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_send_posted", 0); } coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barreir_done", - 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", 0); return ucc_task_complete(coll_task); } return UCC_OK; diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 5cdd6c51a1..790df70ce3 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -104,6 +104,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable), UCC_CONFIG_TYPE_BOOL}, + {"FANIN_KN_RADIX", "4", "Radix of the knomial tree fanin algorithm", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), + UCC_CONFIG_TYPE_UINT}, + {NULL}}; static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { @@ -125,6 +129,9 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { {"MCAST_NET_DEVICE", "", "Specifies which network device to use for Mcast", ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, + {"FANIN_NPOLLS", "1000", + "Number of shared memory polling before returning UCC_INPROGRESS during internode FANIN", + ucc_offsetof(ucc_tl_mlx5_context_config_t, npolls), UCC_CONFIG_TYPE_UINT}, {NULL}}; diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 159ecda8ed..6b2568d152 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -60,12 +60,14 @@ typedef struct ucc_tl_mlx5_lib_config { int dm_host; ucc_tl_mlx5_ib_qp_conf_t qp_conf; ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf; + int fanin_kn_radix; } ucc_tl_mlx5_lib_config_t; 
typedef struct ucc_tl_mlx5_context_config { ucc_tl_context_config_t super; ucs_config_names_array_t devices; ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf; + int npolls; } ucc_tl_mlx5_context_config_t; typedef struct ucc_tl_mlx5_lib { diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 2a0c474a39..7725402a01 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -224,6 +224,7 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) if (status != UCC_OK) { goto err_dm_alloc; } + team->dm_offset = NULL; status = ucc_mpool_init( From 4bbd39a41e12c17fdc6ceee2dd7435f0066eb497 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 23 May 2023 19:49:54 +0300 Subject: [PATCH 2/9] TL/MLX5: configurable batch_size, ser, pollings tiny bit more robust print blocks dimensions fully working configurable batch_size, serialization, and pollings --- .../tl/mlx5/alltoall/alltoall_coll.c | 219 +++++++++++------- src/components/tl/mlx5/tl_mlx5.c | 15 ++ src/components/tl/mlx5/tl_mlx5.h | 12 +- src/components/tl/mlx5/tl_mlx5_dm.c | 27 ++- src/components/tl/mlx5/tl_mlx5_team.c | 4 +- test/mpi/buffer.cc | 7 + 6 files changed, 182 insertions(+), 102 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 6f4de27e72..1901497e90 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -82,7 +82,13 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) } else { ucc_tl_mlx5_dm_chunk_t *dm = (ucc_tl_mlx5_dm_chunk_t *)wcs[i].wr_id; dm->task->alltoall.blocks_completed++; - ucc_mpool_put(dm); + ucc_assert(dm->completed_jobs < dm->posted_jobs); + dm->completed_jobs++; + // printf("!!!!!!PID %i, inside poll cq, dm=%p, dm->addr=%lu, dm->counter=%i, dm->nbr_jobs=%i, dm->posted_jobs=%i\n", getpid(), dm, dm->addr, dm->counter, dm->nbr_jobs, dm->posted_jobs); + /* printf("returning dm %p to pool\n", (void*)team->work_completion[i].wr_id); */ + if (dm->posted_all && dm->completed_jobs == dm->posted_jobs) { + ucc_mpool_put(dm); + } } } return UCC_OK; @@ -91,7 +97,7 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, ucc_tl_mlx5_schedule_t *task) { - ucc_tl_mlx5_a2a_t * a2a = team->a2a; + ucc_tl_mlx5_alltoall_t * a2a = team->a2a; int seq_index = task->alltoall.seq_index; int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; @@ -102,7 +108,7 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, int c_flag = 0; int polls; int peer, vpeer, pos, i; - ucc_tl_mlx5_a2a_ctrl_t *ctrl_v; + ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; while (*dist <= a2a->node.fanin_max_dist) { @@ -416,33 +422,63 @@ static void ucc_tl_mlx5_asr_barrier_progress(ucc_coll_task_t *coll_task) } } +ucc_tl_mlx5_dm_chunk_t* ucc_tl_mlx5_a2a_wait_for_dm_chunk (ucc_tl_mlx5_schedule_t *task) +{ + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(task); + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + + dm = ucc_mpool_get(&team->dm_pool); + if (!dm) { + while (!dm) { + if (UCC_OK != ucc_tl_mlx5_poll_cq(team->a2a->net.cq, lib)) { + return NULL; + } + dm = ucc_mpool_get(&team->dm_pool); + } + } + dm->task = task; + + return dm; +} + // add polling mechanism for blocks in order to maintain const qp tx rx static ucc_status_t 
ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task); - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int node_size = a2a->node.sbgp->group_size; - int net_size = a2a->net.sbgp->group_size; - int op_msgsize = node_size * a2a->max_msg_size - * UCC_TL_TEAM_SIZE(team) - * a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) - * task->alltoall.msg_size; - int block_size = task->alltoall.block_size; - int col_msgsize = task->alltoall.msg_size - * block_size * node_size; - int block_msgsize = SQUARED(block_size) - * task->alltoall.msg_size; - int dm_host = UCC_TL_MLX5_TEAM_LIB(team) - ->cfg.dm_host; - ucc_status_t status = UCC_OK; - ucc_base_lib_t *lib = UCC_TASK_LIB(task); - int seq_index = task->alltoall.seq_index; - int i, j, k, rank, dest_rank, cyc_rank; + // printf("ucc_tl_mlx5_send_blocks_start\n"); + ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(task); + ucc_tl_mlx5_a2a_t * a2a = team->a2a; + int node_size = a2a->node.sbgp->group_size; + int net_size = a2a->net.sbgp->group_size; + int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * + a2a->max_num_of_columns; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_size = task->alltoall.block_size; + int col_msgsize = task->alltoall.msg_size * block_size * node_size; + int block_msgsize = SQUARED(block_size) * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_dim = node_size / block_size; + int seq_index = task->alltoall.seq_index; + int node_idx, block_row, block_col, block_idx, rank, dest_rank, cyc_rank; uint64_t src_addr, remote_addr; - ucc_tl_mlx5_dm_chunk_t *dm; - uintptr_t dm_addr; + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + int remaining_blocks; + int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int nbr_blocks_to_handle = nbr_serialized_batches * batch_size; + + batch_size = ucc_min(batch_size, nbr_batches_per_passage); + + while (nbr_batches_per_passage % batch_size) { + batch_size--; + } + + ucc_assert(nbr_batches_per_passage % batch_size == 0); + + // printf("block_msgsize=%d, node_grid_dim=%d\n", block_msgsize, node_grid_dim); coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -454,84 +490,95 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) "mlx5_alltoall_block_send_start", 0); } - for (i = 0; i < net_size; i++) { - cyc_rank = (i + a2a->net.sbgp->group_rank) % net_size; - dest_rank = a2a->net.rank_map[cyc_rank]; - if (task->alltoall.op->blocks_sent[cyc_rank] || - tl_mlx5_barrier_flag(task, cyc_rank) != task->alltoall.seq_num) { - continue; - } - - //send all blocks from curr node to some ARR - for (j = 0; j < (node_size / block_size); j++) { - for (k = 0; k < (node_size / block_size); k++) { + for (j = 0; j < nbr_batches_per_passage; j++) { + for (node_idx = 0; node_idx < net_size; node_idx++) { + cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; + dest_rank = a2a->net.rank_map[cyc_rank]; + if (tl_mlx5_barrier_flag(task, cyc_rank) != task->alltoall.seq_num) { + continue; + } + remaining_blocks = SQUARED(node_grid_dim) - task->alltoall.op->blocks_sent[cyc_rank]; + dm=NULL; + if (cyc_rank 
!= a2a->net.sbgp->group_rank && remaining_blocks > 0) { + dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk (task); + if (status != UCC_OK) { + return status; + } + // dm->nbr_jobs = ucc_div_round_up(ucc_min(nbr_serialized_batches, remaining_blocks), batch_size); + } + send_start(team, cyc_rank); + remote_addr = 0; + for (i = 0; + i < nbr_blocks_to_handle && task->alltoall.op->blocks_sent[cyc_rank] < SQUARED(node_grid_dim); + i++, task->alltoall.op->blocks_sent[cyc_rank]++) { + block_idx = task->alltoall.op->blocks_sent[cyc_rank]; + block_row = block_idx / node_grid_dim; + block_col = block_idx % node_grid_dim; src_addr = (uintptr_t)(node_msgsize * dest_rank + - col_msgsize * j + block_msgsize * k); - remote_addr = + col_msgsize * block_col + block_msgsize * block_row); + uintptr_t remote_addr_i = (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * j + col_msgsize * k); - - send_start(team, cyc_rank); + block_msgsize * block_col + col_msgsize * block_row); + if (i % batch_size == 0) { + // src_addr = (uintptr_t)(node_msgsize * dest_rank + + // col_msgsize * block_row + block_msgsize * block_col); + remote_addr = + (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + + block_msgsize * block_col + col_msgsize * block_row); + } if (cyc_rank == a2a->net.sbgp->group_rank) { status = ucc_tl_mlx5_post_transpose( tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, - a2a->net.rkeys[cyc_rank], src_addr, remote_addr, + a2a->net.rkeys[cyc_rank], src_addr, remote_addr_i, task->alltoall.msg_size, block_size, block_size, (j == 0 && k == 0) ? IBV_SEND_SIGNALED : 0); if (UCC_OK != status) { return status; } } else { - dm = ucc_mpool_get(&team->dm_pool); - while (!dm) { - status = send_done(team, cyc_rank); - if (UCC_OK != status) { - return status; - } - - status = ucc_tl_mlx5_poll_cq(a2a->net.cq, lib); - if (UCC_OK != status) { - return status; - } - dm = ucc_mpool_get(&team->dm_pool); - send_start(team, cyc_rank); - } - if (dm_host) { - dm_addr = - (uintptr_t)PTR_OFFSET(team->dm_ptr, dm->offset); - } else { - dm_addr = dm->offset; // dm reg mr 0 based - } - dm->task = task; - + // printf("PID: %d, at rank %d, cyc_rank %d, i = %d, batch_size=%d, block_idx=%d, block_msgsize=%d transpose from src_addr %ld to dm_addr_i=%ld (%ld)\n", + // getpid(), a2a->net.sbgp->group_rank, cyc_rank, i, batch_size, block_idx, block_msgsize, src_addr, dm_addr_i, dm_addr_i / block_msgsize); + uintptr_t dm_addr_i = dm->addr + (i % batch_size) * block_msgsize; status = ucc_tl_mlx5_post_transpose( tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, - team->dm_mr->rkey, src_addr, dm_addr, + team->dm_mr->rkey, src_addr, dm_addr_i, task->alltoall.msg_size, block_size, block_size, 0); if (UCC_OK != status) { return status; } - status = send_block_data( - a2a, cyc_rank, dm_addr, block_msgsize, - team->dm_mr->lkey, remote_addr, - a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); - if (status != UCC_OK) { - tl_error(lib, "failed sending block [%d,%d,%d]", i, j, - k); - return status; + if (((i+1) % batch_size) == 0 + || (i+1) == nbr_blocks_to_handle + || (task->alltoall.op->blocks_sent[cyc_rank]+1) == SQUARED(node_grid_dim)) { + // if (!a2a->net.sbgp->group_rank) { + // printf("!PID: %d, at rank %d, cyc_rank %d, i = %d, batch_size=%d, block_idx=%d, block_msgsize=%d SEND from dm_addr %ld to remote_addr=%ld\n dm = %p, dm->counter = %d, dm->nbr_jobs = %d, dm->posted_jobs=%i\n", + // getpid(), a2a->net.sbgp->group_rank, cyc_rank, i, batch_size, block_idx, block_msgsize, 
dm->addr, remote_addr, dm, dm->counter, dm->nbr_jobs, dm->posted_jobs); + // } + status = send_block_data( + a2a, cyc_rank, dm->addr, block_msgsize * batch_size, + team->dm_mr->lkey, remote_addr, + a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); + if (status != UCC_OK) { + tl_error(lib, "failed sending block [%d,%d,%d]", node_idx, block_row, + block_col); + return status; + } + dm->posted_jobs++; } } - status = send_done(team, cyc_rank); - if (status != UCC_OK) { - return status; - } } - } - send_start(team, cyc_rank); - status = send_atomic(a2a, cyc_rank, tl_mlx5_atomic_addr(task, cyc_rank), - tl_mlx5_atomic_rkey(task, cyc_rank)); + status = send_done(team, cyc_rank); + if (status != UCC_OK) { + return status; + } + if (dm) { + dm->posted_all=1; + } + if (task->alltoall.op->blocks_sent[cyc_rank] == SQUARED(node_grid_dim)) { + send_start(team, cyc_rank); + status = send_atomic(a2a, cyc_rank, tl_mlx5_atomic_addr(task, cyc_rank), + tl_mlx5_atomic_rkey(task, cyc_rank)); if (UCC_OK == status) { status = send_done(team, cyc_rank); @@ -590,7 +637,6 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) int block_msgsize_leftovers = block_size_leftovers_side * block_size * msg_size; int corner_msgsize = SQUARED(block_size_leftovers_side) * msg_size; - int dm_host = UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host; ucc_status_t status = UCC_OK; ucc_base_lib_t *lib = UCC_TASK_LIB(coll_task); int nbc = task->alltoall.num_of_blocks_columns; @@ -664,12 +710,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) dm = ucc_mpool_get(&team->dm_pool); send_start(team, cyc_rank); } - if (dm_host) { - dm_addr = - (uintptr_t)PTR_OFFSET(team->dm_ptr, dm->offset); - } else { - dm_addr = dm->offset; // dm reg mr 0 based - } + dm_addr = dm->addr; dm->task = task; status = ucc_tl_mlx5_post_transpose( diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 790df70ce3..6a2d8ae02d 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -108,6 +108,21 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), UCC_CONFIG_TYPE_UINT}, + {"SEND_BATCH_SIZE", "1", "number of blocks that are transposed " + "on the NIC before being sent as a batch to a remote peer", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size), + UCC_CONFIG_TYPE_UINT}, + + {"NBR_SERIALIZED_BATCHES", "8", "number of block batches " + "(within the set of blocks to be sent to a given remote peer)" + "serialized on the same device memory chunk", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches), + UCC_CONFIG_TYPE_UINT}, + + {"NBR_BATCHES_PER_PASSAGE", "1024", "", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage), + UCC_CONFIG_TYPE_UINT}, + {NULL}}; static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 6b2568d152..0d56b45e1f 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -61,6 +61,9 @@ typedef struct ucc_tl_mlx5_lib_config { ucc_tl_mlx5_ib_qp_conf_t qp_conf; ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf; int fanin_kn_radix; + int nbr_serialized_batches; + int nbr_batches_per_passage; + int block_batch_size; } ucc_tl_mlx5_lib_config_t; typedef struct ucc_tl_mlx5_context_config { @@ -95,10 +98,13 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t*, typedef struct ucc_tl_mlx5_task ucc_tl_mlx5_task_t; typedef 
struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t; -typedef struct ucc_tl_mlx5_dm_chunk { - ptrdiff_t offset; /* 0 based offset from the beginning of - memic_mr (obtained with ibv_reg_dm_mr) */ +typedef struct ucc_tl_mlx5_dm_chunk_t { + uintptr_t addr; // 0 based offset from the beginning of + // memic_mr (obtained with ibv_reg_dm_mr) ucc_tl_mlx5_schedule_t *task; + int posted_jobs; + int posted_all; + int completed_jobs; } ucc_tl_mlx5_dm_chunk_t; typedef struct ucc_tl_mlx5_alltoall ucc_tl_mlx5_alltoall_t; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 7725402a01..6a77b775db 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -77,6 +77,17 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT ucc_tl_mlx5_dm_chunk_t *c = (ucc_tl_mlx5_dm_chunk_t *)obj; ucc_tl_mlx5_team_t *team = ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool); + c->addr = (uintptr_t)PTR_OFFSET( + (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host)? + team->dm_ptr : NULL, + team->dm_offset); + team->dm_offset = team->dm_offset + + UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size + * UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + c->posted_jobs = 0; + c->posted_all=0; + c->completed_jobs = 0; +} c->offset = (ptrdiff_t)team->dm_offset; team->dm_offset = PTR_OFFSET(team->dm_offset, @@ -219,18 +230,18 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) } status = ucc_tl_mlx5_dm_alloc_reg( - ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size, + ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size * cfg->block_batch_size, &cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); if (status != UCC_OK) { goto err_dm_alloc; } - - team->dm_offset = NULL; - - status = ucc_mpool_init( - &team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0, - UCC_CACHE_LINE_SIZE, 1, cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, - ctx->super.super.ucc_context->thread_mode, "mlx5 dm pool"); + team->dm_offset = 0; + printf("Number of MEMIC chunks = %ld, chunck size = %ld\n", cfg->dm_buf_num, cfg->dm_buf_size * cfg->block_batch_size); + // TODO: fix case dm_host=true + status = ucc_mpool_init(&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), + 0, UCC_CACHE_LINE_SIZE, 1, + cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, + ctx->super.super.ucc_context->thread_mode, "mlx5 dm pool"); if (status != UCC_OK) { tl_debug(UCC_TL_TEAM_LIB(team), "failed to init dm pool"); goto err_mpool_init; diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 1e5f6ddf56..6a65274d1d 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -117,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) +static inline ucc_status_t ucc_tl_mlx5_alltoall_team_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); @@ -253,7 +253,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) goto initial_sync_post; } - a2a_status = ucc_tl_mlx5_a2a_team_test(team); + a2a_status = ucc_tl_mlx5_alltoall_team_test(team); if (a2a_status < 0) { tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d", team, a2a_status); diff --git a/test/mpi/buffer.cc b/test/mpi/buffer.cc index f31f42c553..91cf1c311b 100644 --- a/test/mpi/buffer.cc +++ b/test/mpi/buffer.cc @@ -182,6 +182,13 @@ ucc_status_t compare_buffers(void *_rst, 
void *expected, size_t count, } else { status = memcmp(rst, expected, count*ucc_dt_size(dt)) ? UCC_ERR_NO_MESSAGE : UCC_OK; + // uint8_t* a = (uint8_t*)rst; + // uint8_t* b = (uint8_t*)expected; + // for (int i=0; i Date: Wed, 14 Jun 2023 20:23:04 +0300 Subject: [PATCH 3/9] TL/MLX5: BFS iter, visit nodes before blocks clean --- .../tl/mlx5/alltoall/alltoall_coll.c | 126 +++++++----------- src/components/tl/mlx5/tl_mlx5.c | 6 +- src/components/tl/mlx5/tl_mlx5.h | 4 +- src/components/tl/mlx5/tl_mlx5_dm.c | 4 +- 4 files changed, 56 insertions(+), 84 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 1901497e90..2fbc06cb20 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -82,11 +82,9 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) } else { ucc_tl_mlx5_dm_chunk_t *dm = (ucc_tl_mlx5_dm_chunk_t *)wcs[i].wr_id; dm->task->alltoall.blocks_completed++; - ucc_assert(dm->completed_jobs < dm->posted_jobs); - dm->completed_jobs++; - // printf("!!!!!!PID %i, inside poll cq, dm=%p, dm->addr=%lu, dm->counter=%i, dm->nbr_jobs=%i, dm->posted_jobs=%i\n", getpid(), dm, dm->addr, dm->counter, dm->nbr_jobs, dm->posted_jobs); + dm->completed_sends++; /* printf("returning dm %p to pool\n", (void*)team->work_completion[i].wr_id); */ - if (dm->posted_all && dm->completed_jobs == dm->posted_jobs) { + if (dm->posted_all && dm->completed_sends == dm->posted_sends) { ucc_mpool_put(dm); } } @@ -464,21 +462,10 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) int node_idx, block_row, block_col, block_idx, rank, dest_rank, cyc_rank; uint64_t src_addr, remote_addr; ucc_tl_mlx5_dm_chunk_t *dm = NULL; - int remaining_blocks; + int i,j, k, send_to_self; int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; - int nbr_blocks_to_handle = nbr_serialized_batches * batch_size; - - batch_size = ucc_min(batch_size, nbr_batches_per_passage); - - while (nbr_batches_per_passage % batch_size) { - batch_size--; - } - - ucc_assert(nbr_batches_per_passage % batch_size == 0); - - // printf("block_msgsize=%d, node_grid_dim=%d\n", block_msgsize, node_grid_dim); coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -494,86 +481,71 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) for (node_idx = 0; node_idx < net_size; node_idx++) { cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; dest_rank = a2a->net.rank_map[cyc_rank]; + send_to_self = (cyc_rank == a2a->net.sbgp->group_rank); if (tl_mlx5_barrier_flag(task, cyc_rank) != task->alltoall.seq_num) { continue; } - remaining_blocks = SQUARED(node_grid_dim) - task->alltoall.op->blocks_sent[cyc_rank]; - dm=NULL; - if (cyc_rank != a2a->net.sbgp->group_rank && remaining_blocks > 0) { + dm = NULL; + if (!send_to_self && task->alltoall.op->blocks_sent[cyc_rank] < SQUARED(node_grid_dim)) { dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk (task); if (status != UCC_OK) { return status; } - // dm->nbr_jobs = ucc_div_round_up(ucc_min(nbr_serialized_batches, remaining_blocks), batch_size); } - send_start(team, cyc_rank); - remote_addr = 0; - for (i = 0; - i < nbr_blocks_to_handle && task->alltoall.op->blocks_sent[cyc_rank] < SQUARED(node_grid_dim); - i++, 
task->alltoall.op->blocks_sent[cyc_rank]++) { - block_idx = task->alltoall.op->blocks_sent[cyc_rank]; - block_row = block_idx / node_grid_dim; - block_col = block_idx % node_grid_dim; - src_addr = (uintptr_t)(node_msgsize * dest_rank + - col_msgsize * block_col + block_msgsize * block_row); - uintptr_t remote_addr_i = - (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * block_col + col_msgsize * block_row); - if (i % batch_size == 0) { - // src_addr = (uintptr_t)(node_msgsize * dest_rank + - // col_msgsize * block_row + block_msgsize * block_col); - remote_addr = - (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * block_col + col_msgsize * block_row); - } - if (cyc_rank == a2a->net.sbgp->group_rank) { - status = ucc_tl_mlx5_post_transpose( - tl_mlx5_get_qp(a2a, cyc_rank), - a2a->node.ops[seq_index].send_mkeys[0]->lkey, - a2a->net.rkeys[cyc_rank], src_addr, remote_addr_i, - task->alltoall.msg_size, block_size, block_size, - (j == 0 && k == 0) ? IBV_SEND_SIGNALED : 0); - if (UCC_OK != status) { - return status; - } - } else { - // printf("PID: %d, at rank %d, cyc_rank %d, i = %d, batch_size=%d, block_idx=%d, block_msgsize=%d transpose from src_addr %ld to dm_addr_i=%ld (%ld)\n", - // getpid(), a2a->net.sbgp->group_rank, cyc_rank, i, batch_size, block_idx, block_msgsize, src_addr, dm_addr_i, dm_addr_i / block_msgsize); - uintptr_t dm_addr_i = dm->addr + (i % batch_size) * block_msgsize; - status = ucc_tl_mlx5_post_transpose( - tl_mlx5_get_qp(a2a, cyc_rank), - a2a->node.ops[seq_index].send_mkeys[0]->lkey, - team->dm_mr->rkey, src_addr, dm_addr_i, - task->alltoall.msg_size, block_size, block_size, 0); - if (UCC_OK != status) { - return status; + send_start(team, cyc_rank); for (i = 0; i < nbr_serialized_batches; i++) { + for (k = 0; + k < batch_size && task->alltoall.op->blocks_sent[cyc_rank] < SQUARED(node_grid_dim); + k++, task->alltoall.op->blocks_sent[cyc_rank]++) { + block_idx = task->alltoall.op->blocks_sent[cyc_rank]; + block_row = block_idx / node_grid_dim; + block_col = block_idx % node_grid_dim; + src_addr = (uintptr_t)(op_msgsize * seq_index + node_msgsize * dest_rank + + col_msgsize * block_col + block_msgsize * block_row); + if (send_to_self || !k) { + remote_addr = + (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + + block_msgsize * block_col + col_msgsize * block_row); } - if (((i+1) % batch_size) == 0 - || (i+1) == nbr_blocks_to_handle - || (task->alltoall.op->blocks_sent[cyc_rank]+1) == SQUARED(node_grid_dim)) { - // if (!a2a->net.sbgp->group_rank) { - // printf("!PID: %d, at rank %d, cyc_rank %d, i = %d, batch_size=%d, block_idx=%d, block_msgsize=%d SEND from dm_addr %ld to remote_addr=%ld\n dm = %p, dm->counter = %d, dm->nbr_jobs = %d, dm->posted_jobs=%i\n", - // getpid(), a2a->net.sbgp->group_rank, cyc_rank, i, batch_size, block_idx, block_msgsize, dm->addr, remote_addr, dm, dm->counter, dm->nbr_jobs, dm->posted_jobs); - // } - status = send_block_data( - a2a, cyc_rank, dm->addr, block_msgsize * batch_size, - team->dm_mr->lkey, remote_addr, - a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); - if (status != UCC_OK) { - tl_error(lib, "failed sending block [%d,%d,%d]", node_idx, block_row, - block_col); + if (send_to_self) { + status = ucc_tl_mlx5_post_transpose( + tl_mlx5_get_qp(a2a, cyc_rank), + a2a->node.ops[seq_index].send_mkeys[0]->lkey, + a2a->net.rkeys[cyc_rank], src_addr, remote_addr, + task->alltoall.msg_size, block_size, block_size, + (block_row == 0 && block_col == 0) ? 
IBV_SEND_SIGNALED : 0); + if (UCC_OK != status) { + return status; + } + } else { + status = ucc_tl_mlx5_post_transpose( + tl_mlx5_get_qp(a2a, cyc_rank), + a2a->node.ops[seq_index].send_mkeys[0]->lkey, + team->dm_mr->rkey, src_addr, dm->addr + k * block_msgsize, + task->alltoall.msg_size, block_size, block_size, 0); + if (UCC_OK != status) { return status; } - dm->posted_jobs++; } } + if (!send_to_self && k) { + status = send_block_data( + a2a, cyc_rank, dm->addr, block_msgsize * k, + team->dm_mr->lkey, remote_addr, + a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); + if (status != UCC_OK) { + tl_error(lib, "failed sending block [%d,%d,%d]", node_idx, block_row, + block_col); + return status; + } + dm->posted_sends++; + } } status = send_done(team, cyc_rank); if (status != UCC_OK) { return status; } if (dm) { - dm->posted_all=1; + dm->posted_all = 1; } if (task->alltoall.op->blocks_sent[cyc_rank] == SQUARED(node_grid_dim)) { send_start(team, cyc_rank); diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 6a2d8ae02d..47e775e2ee 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -108,18 +108,18 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), UCC_CONFIG_TYPE_UINT}, - {"SEND_BATCH_SIZE", "1", "number of blocks that are transposed " + {"SEND_BATCH_SIZE", "8", "number of blocks that are transposed " "on the NIC before being sent as a batch to a remote peer", ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size), UCC_CONFIG_TYPE_UINT}, - {"NBR_SERIALIZED_BATCHES", "8", "number of block batches " + {"NBR_SERIALIZED_BATCHES", "2", "number of block batches " "(within the set of blocks to be sent to a given remote peer)" "serialized on the same device memory chunk", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches), UCC_CONFIG_TYPE_UINT}, - {"NBR_BATCHES_PER_PASSAGE", "1024", "", + {"NBR_BATCHES_PER_PASSAGE", "32", "", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage), UCC_CONFIG_TYPE_UINT}, diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 0d56b45e1f..6bc92ff56c 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -102,9 +102,9 @@ typedef struct ucc_tl_mlx5_dm_chunk_t { uintptr_t addr; // 0 based offset from the beginning of // memic_mr (obtained with ibv_reg_dm_mr) ucc_tl_mlx5_schedule_t *task; - int posted_jobs; + int posted_sends; int posted_all; - int completed_jobs; + int completed_sends; } ucc_tl_mlx5_dm_chunk_t; typedef struct ucc_tl_mlx5_alltoall ucc_tl_mlx5_alltoall_t; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 6a77b775db..7ad216b154 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -84,9 +84,9 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT team->dm_offset = team->dm_offset + UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size * UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - c->posted_jobs = 0; + c->posted_sends = 0; c->posted_all=0; - c->completed_jobs = 0; + c->completed_sends = 0; } c->offset = (ptrdiff_t)team->dm_offset; From 5917644b989127c6c79666dd60b19a6324578b4e Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jun 2023 17:55:12 +0300 Subject: [PATCH 4/9] TL/MLX5: support rectangular blocks clean and working TL/MLX5: add more config for block dimensions force longer by default --- .../tl/mlx5/alltoall/alltoall_coll.c | 124 
++++++++++++------ .../tl/mlx5/alltoall/alltoall_mkeys.c | 23 ++-- src/components/tl/mlx5/tl_mlx5.c | 16 ++- src/components/tl/mlx5/tl_mlx5.h | 3 + src/components/tl/mlx5/tl_mlx5_coll.h | 3 +- src/components/tl/mlx5/tl_mlx5_dm.c | 1 - src/components/tl/mlx5/tl_mlx5_wqe.c | 2 +- test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc | 2 +- 8 files changed, 119 insertions(+), 55 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 2fbc06cb20..324e9909b8 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -443,7 +443,6 @@ ucc_tl_mlx5_dm_chunk_t* ucc_tl_mlx5_a2a_wait_for_dm_chunk (ucc_tl_mlx5_schedule_ // add polling mechanism for blocks in order to maintain const qp tx rx static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { - // printf("ucc_tl_mlx5_send_blocks_start\n"); ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); ucc_base_lib_t * lib = UCC_TASK_LIB(task); ucc_tl_mlx5_team_t * team = TASK_TEAM(task); @@ -453,14 +452,17 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; - int block_size = task->alltoall.block_size; - int col_msgsize = task->alltoall.msg_size * block_size * node_size; - int block_msgsize = SQUARED(block_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; ucc_status_t status = UCC_OK; - int node_grid_dim = node_size / block_size; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); int seq_index = task->alltoall.seq_index; - int node_idx, block_row, block_col, block_idx, rank, dest_rank, cyc_rank; - uint64_t src_addr, remote_addr; + int node_idx, block_row = 0, block_col = 0, block_idx, rank, dest_rank, cyc_rank; + uint64_t src_addr, remote_addr = 0; ucc_tl_mlx5_dm_chunk_t *dm = NULL; int i,j, k, send_to_self; int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; @@ -486,7 +488,7 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) continue; } dm = NULL; - if (!send_to_self && task->alltoall.op->blocks_sent[cyc_rank] < SQUARED(node_grid_dim)) { + if (!send_to_self && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk (task); if (status != UCC_OK) { return status; @@ -494,24 +496,24 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) } send_start(team, cyc_rank); for (i = 0; i < nbr_serialized_batches; i++) { for (k = 0; - k < batch_size && task->alltoall.op->blocks_sent[cyc_rank] < SQUARED(node_grid_dim); + k < batch_size && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks; k++, task->alltoall.op->blocks_sent[cyc_rank]++) { block_idx = task->alltoall.op->blocks_sent[cyc_rank]; - block_row = block_idx / node_grid_dim; - block_col = block_idx % node_grid_dim; + block_col = block_idx % node_grid_w; + block_row = block_idx / node_grid_w; src_addr = (uintptr_t)(op_msgsize * seq_index + node_msgsize * dest_rank + col_msgsize * block_col + block_msgsize * block_row); if 
(send_to_self || !k) { remote_addr = (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * block_col + col_msgsize * block_row); + block_msgsize * block_col + line_msgsize * block_row); } if (send_to_self) { status = ucc_tl_mlx5_post_transpose( tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, a2a->net.rkeys[cyc_rank], src_addr, remote_addr, - task->alltoall.msg_size, block_size, block_size, + task->alltoall.msg_size, block_w, block_h, (block_row == 0 && block_col == 0) ? IBV_SEND_SIGNALED : 0); if (UCC_OK != status) { return status; @@ -521,7 +523,7 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, team->dm_mr->rkey, src_addr, dm->addr + k * block_msgsize, - task->alltoall.msg_size, block_size, block_size, 0); + task->alltoall.msg_size, block_w, block_h, 0); if (UCC_OK != status) { return status; } @@ -547,7 +549,7 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) if (dm) { dm->posted_all = 1; } - if (task->alltoall.op->blocks_sent[cyc_rank] == SQUARED(node_grid_dim)) { + if (task->alltoall.op->blocks_sent[cyc_rank] == node_nbr_blocks) { send_start(team, cyc_rank); status = send_atomic(a2a, cyc_rank, tl_mlx5_atomic_addr(task, cyc_rank), tl_mlx5_atomic_rkey(task, cyc_rank)); @@ -600,7 +602,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; int mkey_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team); - int block_size = task->alltoall.block_size; + int block_size = task->alltoall.block_height; //ucc_assert(task->alltoall.block_height == task->alltoall.block_width); int col_msgsize = msg_size * block_size * node_size; int block_msgsize = SQUARED(block_size) * msg_size; int block_size_leftovers_side = node_size % block_size; @@ -810,28 +812,45 @@ static inline int power2(int value) return p; } -static inline int block_size_fits(size_t msgsize, int block_size) +static inline int block_size_fits(size_t msgsize, int height, int width) { - return block_size * ucc_max(power2(block_size) * msgsize, MAX_MSG_SIZE) <= - MAX_TRANSPOSE_SIZE; + size_t t1 = power2(ucc_max(msgsize, 8)); + size_t tsize = height * ucc_max(power2(width) * t1, MAX_MSG_SIZE); + + return tsize <= MAX_TRANSPOSE_SIZE && msgsize <= 128 && height <= 64 && width <= 64; } -static inline int get_block_size(ucc_tl_mlx5_schedule_t *task) +static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, int force_longer, int force_wider, int* block_height, int* block_width) { - ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task); - int ppn = team->a2a->node.sbgp->group_size; - size_t effective_msgsize = power2(ucc_max( - task->alltoall.msg_size, 8)); - int block_size; + int h_best = 1; + int w_best = 1; + int h,w; - block_size = ppn; - while (!block_size_fits(effective_msgsize, block_size)) { - block_size--; + for (h = 1; h <= 64; h++) { + if (force_regular && (ppn % h)) { + continue; + } + for (w = 1; w <= 64; w++) { + if ((force_regular && (ppn % w)) + || (force_wider && (w < h)) + || (force_longer && (w > h))) { + continue; + } + if (block_size_fits(msgsize, h, w) + && h*w >= h_best*w_best) { + if ( h*w > h_best*w_best || abs(h/w-1) < abs(h_best/w_best-1)) { + h_best = h; + w_best = w; + } + } + } } - return ucc_max(1, block_size); + *block_height = h_best; + *block_width = w_best; } + 
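/*
 * Editor's note (not part of the patch): a standalone, hypothetical C sketch of
 * the block-dimension search that PATCH 4/9 adds in get_block_dimensions().
 * MAX_MSG_SIZE and MAX_TRANSPOSE_SIZE are device limits taken from the mlx5
 * headers in the real code; the SKETCH_* constants below are placeholders chosen
 * only for illustration. The FORCE_LONGER / FORCE_WIDER constraints are omitted
 * here for brevity; only the FORCE_REGULAR ("dimensions divide ppn") rule is kept.
 */
#include <stdio.h>
#include <stdlib.h>

#define SKETCH_MAX_MSG_SIZE       128  /* assumption, illustration only */
#define SKETCH_MAX_TRANSPOSE_SIZE 8192 /* assumption, illustration only */

static int power2(int v)
{
    int p = 1;

    while (p < v) {
        p *= 2;
    }
    return p;
}

static int fits(size_t msgsize, int h, int w)
{
    size_t line  = (size_t)power2(w) * power2(msgsize < 8 ? 8 : (int)msgsize);
    size_t tsize = (size_t)h * (line > SKETCH_MAX_MSG_SIZE ? line
                                                           : SKETCH_MAX_MSG_SIZE);

    return tsize <= SKETCH_MAX_TRANSPOSE_SIZE && msgsize <= 128 &&
           h <= 64 && w <= 64;
}

/* pick the largest (and, on ties, most square) h x w block whose transposed
   size fits the device limit, optionally requiring h and w to divide ppn */
static void pick_block(int ppn, size_t msgsize, int force_regular,
                       int *bh, int *bw)
{
    int h, w, h_best = 1, w_best = 1;

    for (h = 1; h <= 64; h++) {
        if (force_regular && (ppn % h)) {
            continue;
        }
        for (w = 1; w <= 64; w++) {
            if (force_regular && (ppn % w)) {
                continue;
            }
            if (fits(msgsize, h, w) && h * w >= h_best * w_best) {
                if (h * w > h_best * w_best ||
                    abs(h / w - 1) < abs(h_best / w_best - 1)) {
                    h_best = h;
                    w_best = w;
                }
            }
        }
    }
    *bh = h_best;
    *bw = w_best;
}

int main(void)
{
    int h, w;

    pick_block(32, 64, 1, &h, &w);
    printf("ppn=32, msgsize=64 -> %dx%d blocks\n", h, w);
    return 0;
}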
UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, (coll_args, team, task_h), ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, @@ -844,6 +863,9 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, == a2a->node.asr_rank); int n_tasks = is_asr ? 5 : 3; int curr_task = 0; + int ppn = tl_team->a2a->node.sbgp->group_size; + ucc_tl_mlx5_lib_config_t* cfg = &UCC_TL_MLX5_TEAM_LIB(tl_team)->cfg; + int i; ucc_schedule_t *schedule; ucc_tl_mlx5_schedule_t *task; size_t msg_size; @@ -893,13 +915,35 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, tl_trace(UCC_TL_TEAM_LIB(tl_team), "Seq num is %d", task->alltoall.seq_num); a2a->sequence_number += 1; - block_size = a2a->requested_block_size ? a2a->requested_block_size - : get_block_size(task); + if (a2a->requested_block_size) { + task->alltoall.block_height = task->alltoall.block_width = a2a->requested_block_size; + if (cfg->force_regular + && ((ppn % task->alltoall.block_height) + || (ppn % task->alltoall.block_width))){ + tl_debug(UCC_TL_TEAM_LIB(tl_team), "the requested block size implies irregular case" + "consider changing the block size or turn off the config FORCE_REGULAR"); + return UCC_ERR_INVALID_PARAM; + } + } else { + if (!cfg->force_regular) { + if (!(cfg->force_longer && cfg->force_wider)) { + tl_debug(UCC_TL_TEAM_LIB(tl_team), "turning off FORCE_REGULAR automatically forces the blocks to be square"); + cfg->force_longer = 1; + cfg->force_wider = 1; + } + } + get_block_dimensions(ppn, task->alltoall.msg_size, + cfg->force_regular, + cfg->force_longer, + cfg->force_wider, + &task->alltoall.block_height, &task->alltoall.block_width); + } + tl_debug(UCC_TL_TEAM_LIB(tl_team), "block dimensions: [%d,%d]", task->alltoall.block_height, task->alltoall.block_width); //todo following section correct assuming homogenous PPN across all nodes task->alltoall.num_of_blocks_columns = - (a2a->node.sbgp->group_size % block_size) - ? ucc_div_round_up(a2a->node.sbgp->group_size, block_size) + (a2a->node.sbgp->group_size % task->alltoall.block_height) + ? 
ucc_div_round_up(a2a->node.sbgp->group_size, task->alltoall.block_height) : 0; task->alltoall.block_size = block_size; @@ -910,14 +954,16 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, size_t limit = (1ULL << 16); // TODO We need to query this from device (or device type) and not user hardcoded values + ucc_assert(task->alltoall.block_height == task->alltoall.block_width); + int block_size = task->alltoall.block_height; size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; - int ppn = a2a->node.sbgp->group_size; - int bs = block_size; - bytes_count_last = (ppn % bs) * msg_size; - bytes_skip_last = (ppn - (ppn % bs)) * msg_size; - bytes_count = bs * msg_size; - bytes_skip = (ppn - bs) * msg_size; + ucc_assert(task->alltoall.block_height == task->alltoall.block_width); + + bytes_count_last = (ppn % block_size) * msg_size; + bytes_skip_last = (ppn - (ppn % block_size)) * msg_size; + bytes_count = block_size * msg_size; + bytes_skip = (ppn - block_size) * msg_size; if ((bytes_count + bytes_skip >= limit) || (bytes_count_last + bytes_skip_last >= limit)) { tl_debug(UCC_TL_TEAM_LIB(tl_team), "unsupported operation"); diff --git a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c index e8b3052501..3500aed52c 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c +++ b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c @@ -291,14 +291,15 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; int nbc = req->alltoall.num_of_blocks_columns; int seq_index = req->alltoall.seq_index; - int repeat_count = nbc ? a2a->net.sbgp->group_size - : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_size; int n_mkeys = nbc ? nbc : 1; + int repeat_count; int i; ucc_status_t status; if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_SEND_MKEY_UPDATE) { + repeat_count = nbc ? a2a->net.sbgp->group_size + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, send_mem_access_flags, node->ops[seq_index].send_mkeys[i], @@ -313,6 +314,8 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, } if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_RECV_MKEY_UPDATE) { + repeat_count = nbc ? a2a->net.sbgp->group_size + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, recv_mem_access_flags, node->ops[seq_index].recv_mkeys[i], @@ -332,7 +335,8 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, ucc_tl_mlx5_schedule_t *req, int direction_send) { ucc_tl_mlx5_alltoall_node_t *node = &a2a->node; - int block_size = req->alltoall.block_size; + int block_height = req->alltoall.block_height; + int block_width = req->alltoall.block_width; size_t msg_size = req->alltoall.msg_size; int nbc = req->alltoall.num_of_blocks_columns; struct ibv_mr *buff = direction_send @@ -345,26 +349,27 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, 0) : MY_RECV_UMR_DATA(req, a2a, 0)); mkey_entry->addr = (uintptr_t)buff->addr; - mkey_entry->bytes_count = block_size * msg_size; + mkey_entry->bytes_count = (direction_send? block_width : block_height) * msg_size; mkey_entry->bytes_skip = 0; mkey_entry->lkey = direction_send ? 
buff->lkey : buff->rkey; } else { for (i = 0; i < nbc; i++) { + ucc_assert(block_height == block_width); mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, i) : MY_RECV_UMR_DATA(req, a2a, i)); mkey_entry->addr = - (uintptr_t)buff->addr + i * (block_size * msg_size); + (uintptr_t)buff->addr + i * (block_height * msg_size); mkey_entry->bytes_count = (i == (nbc - 1)) - ? ((node->sbgp->group_size % block_size) * msg_size) - : (block_size * msg_size); + ? ((node->sbgp->group_size % block_height) * msg_size) + : (block_height * msg_size); mkey_entry->bytes_skip = (i == (nbc - 1)) ? ((node->sbgp->group_size - - (node->sbgp->group_size % block_size)) * + (node->sbgp->group_size % block_height)) * msg_size) - : ((node->sbgp->group_size - block_size) * msg_size); + : ((node->sbgp->group_size - block_height) * msg_size); mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey; } } diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 47e775e2ee..b8830d4613 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -28,6 +28,16 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, dm_buf_num), UCC_CONFIG_TYPE_ULUNITS}, + {"FORCE_REGULAR", "y", "Force the regular case where the block dimensions " + "divide ppn. Requires BLOCK_SIZE=0", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular), UCC_CONFIG_TYPE_BOOL}, + + {"FORCE_LONGER", "y", "Force the blocks to have more height than width", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer), UCC_CONFIG_TYPE_BOOL}, + + {"FORCE_WIDER", "n", "Force the blocks to have more width than height", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_wider), UCC_CONFIG_TYPE_BOOL}, + {"BLOCK_SIZE", "0", "Size of the blocks that are sent using blocked AlltoAll Algorithm", ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_size), UCC_CONFIG_TYPE_UINT}, @@ -108,13 +118,13 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), UCC_CONFIG_TYPE_UINT}, - {"SEND_BATCH_SIZE", "8", "number of blocks that are transposed " + {"SEND_BATCH_SIZE", "1", "number of blocks that are transposed " "on the NIC before being sent as a batch to a remote peer", ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size), UCC_CONFIG_TYPE_UINT}, - {"NBR_SERIALIZED_BATCHES", "2", "number of block batches " - "(within the set of blocks to be sent to a given remote peer)" + {"NBR_SERIALIZED_BATCHES", "1", "number of block batches " + "(within the set of blocks to be sent to a given remote peer) " "serialized on the same device memory chunk", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches), UCC_CONFIG_TYPE_UINT}, diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 6bc92ff56c..20f71d67e8 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -64,6 +64,9 @@ typedef struct ucc_tl_mlx5_lib_config { int nbr_serialized_batches; int nbr_batches_per_passage; int block_batch_size; + int force_regular; + int force_longer; + int force_wider; } ucc_tl_mlx5_lib_config_t; typedef struct ucc_tl_mlx5_context_config { diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h index 8ffe3eaf64..9f05d01d90 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.h +++ b/src/components/tl/mlx5/tl_mlx5_coll.h @@ -27,7 +27,8 @@ typedef struct ucc_tl_mlx5_schedule { int seq_num; int seq_index; int num_of_blocks_columns; - int 
block_size; + int block_height; + int block_width; int started; int send_blocks_enqueued; int blocks_sent; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 7ad216b154..23d9aa723d 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -236,7 +236,6 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) goto err_dm_alloc; } team->dm_offset = 0; - printf("Number of MEMIC chunks = %ld, chunck size = %ld\n", cfg->dm_buf_num, cfg->dm_buf_size * cfg->block_batch_size); // TODO: fix case dm_host=true status = ucc_mpool_init(&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0, UCC_CACHE_LINE_SIZE, 1, diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index cf4d590658..399b5626ca 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -39,8 +39,8 @@ static inline uint8_t get_umr_mr_flags(uint32_t acc) typedef struct transpose_seg { __be32 element_size; /* 8 bit value */ + __be16 num_cols; /* 7 bit value */ //TODO: from PRM we should have the rows first and then the colls... is this a bug ? __be16 num_rows; /* 7 bit value */ - __be16 num_cols; /* 7 bit value */ __be64 padding; } transpose_seg_t; diff --git a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc index 70cc9a4190..3103ba2380 100644 --- a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc +++ b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc @@ -64,7 +64,7 @@ UCC_TEST_P(test_tl_mlx5_transpose, transposeWqe) ibv_wr_start(qp.qp_ex); post_transpose(qp.qp, src_mr->lkey, dst_mr->rkey, (uintptr_t)src, - (uintptr_t)dst, elem_size, nrows, ncols, IBV_SEND_SIGNALED); + (uintptr_t)dst, elem_size, ncols, nrows, IBV_SEND_SIGNALED); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); while (!completions_num) { From 0bf47f8541e6dab3c82cd7e2c6d416fcd2e4755a Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Aug 2024 13:54:51 +0300 Subject: [PATCH 5/9] TL/MLX5: clean and fix after rebase lintrunner cleaning --- src/components/tl/mlx5/alltoall/alltoall.h | 6 +- .../tl/mlx5/alltoall/alltoall_coll.c | 255 ++++++++++-------- .../tl/mlx5/alltoall/alltoall_mkeys.c | 12 +- src/components/tl/mlx5/tl_mlx5.c | 27 +- src/components/tl/mlx5/tl_mlx5.h | 6 +- src/components/tl/mlx5/tl_mlx5_dm.c | 39 ++- src/components/tl/mlx5/tl_mlx5_wqe.c | 3 +- test/mpi/buffer.cc | 7 - 8 files changed, 194 insertions(+), 161 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall.h b/src/components/tl/mlx5/alltoall/alltoall.h index 92b69e2e6f..c49a0ff9cf 100644 --- a/src/components/tl/mlx5/alltoall/alltoall.h +++ b/src/components/tl/mlx5/alltoall/alltoall.h @@ -58,9 +58,9 @@ typedef struct ucc_tl_mlx5_alltoall_node { struct mlx5dv_mkey *team_recv_mkey; void *umr_entries_buf; struct ibv_mr *umr_entries_mr; - int fanin_index; - int fanin_dist; - int fanin_max_dist; + int fanin_index; + int fanin_dist; + int fanin_max_dist; } ucc_tl_mlx5_alltoall_node_t; typedef struct alltoall_net_ctrl { diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 324e9909b8..10a08989a9 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -83,7 +83,6 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) ucc_tl_mlx5_dm_chunk_t *dm = (ucc_tl_mlx5_dm_chunk_t *)wcs[i].wr_id; dm->task->alltoall.blocks_completed++; dm->completed_sends++; - /* printf("returning dm %p to pool\n", 
(void*)team->work_completion[i].wr_id); */ if (dm->posted_all && dm->completed_sends == dm->posted_sends) { ucc_mpool_put(dm); } @@ -95,20 +94,19 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, ucc_tl_mlx5_schedule_t *task) { - ucc_tl_mlx5_alltoall_t * a2a = team->a2a; + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int seq_index = task->alltoall.seq_index; int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; - int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; - int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; - int *dist = &a2a->node.fanin_dist; - int size = a2a->node.sbgp->group_size; - int seq_num = task->alltoall.seq_num; - int c_flag = 0; - int polls; - int peer, vpeer, pos, i; + int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls; + int peer, vpeer, pos, i; ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; - while (*dist <= a2a->node.fanin_max_dist) { if (vrank % *dist == 0) { pos = (vrank / *dist) % radix; @@ -119,7 +117,7 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, a2a->node.fanin_index = radix; break; } - peer = (vpeer + a2a->node.asr_rank) % size; + peer = (vpeer + a2a->node.asr_rank) % size; ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, peer); for (polls = 0; polls < npolls; polls++) { if (ctrl_v->seq_num == seq_num) { @@ -133,22 +131,26 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, } } else { ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = seq_num; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_fanin_done", 0); return UCC_OK; } } *dist *= radix; a2a->node.fanin_index = 1; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done", + 0); } UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); for (i = 0; i < size; i++) { ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); - ucc_assert(i == a2a->node.sbgp->group_rank || ctrl_v->seq_num == seq_num); + ucc_assert(i == a2a->node.sbgp->group_rank || + ctrl_v->seq_num == seq_num); c_flag |= ctrl_v->mkey_cache_flag; } ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = c_flag; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_retrieve_cache_flags_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_retrieve_cache_flags_done", 0); return UCC_OK; } @@ -224,7 +226,9 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) a2a->node.fanin_dist = 1; for (a2a->node.fanin_max_dist = 1; a2a->node.fanin_max_dist < a2a->node.sbgp->group_size; - a2a->node.fanin_max_dist *= UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix){} + a2a->node.fanin_max_dist *= + UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix) { + } if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) { tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete"); @@ -279,9 +283,11 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, 
"mlx5_alltoall_wait-on-data_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_start", 0); } else { - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", + 0); } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); @@ -305,8 +311,10 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task) coll_task->status = UCC_INPROGRESS; return; } - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_wait-on-data_complete", 0); - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_complete", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", + 0); } if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) { @@ -331,20 +339,21 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) coll_task->super.status = UCC_INPROGRESS; task->alltoall.started = 0; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); - // despite while statement in poll_umr_cq, non blocking because have independent cq, // will be finished in a finite time - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_start", + 0); ucc_tl_mlx5_populate_send_recv_mkeys(team, task); - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_end", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_end", + 0); UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); //Reset atomic notification counter to 0 #if ATOMIC_IN_MEMIC tl_mlx5_atomic_t zero = 0; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_start", + 0); if (0 != ibv_memcpy_to_dm(a2a->net.atomic.counters, task->alltoall.seq_index * sizeof(tl_mlx5_atomic_t), @@ -394,10 +403,12 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); return status; } - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_send_posted", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_barrier_send_posted", 0); } coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", + 0); return ucc_task_complete(coll_task); } return UCC_OK; @@ -420,11 +431,12 @@ static void ucc_tl_mlx5_asr_barrier_progress(ucc_coll_task_t *coll_task) } } -ucc_tl_mlx5_dm_chunk_t* ucc_tl_mlx5_a2a_wait_for_dm_chunk (ucc_tl_mlx5_schedule_t *task) +ucc_tl_mlx5_dm_chunk_t * +ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task) { - ucc_base_lib_t * lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t * team = TASK_TEAM(task); - ucc_tl_mlx5_dm_chunk_t *dm = NULL; + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(&task->super); + ucc_tl_mlx5_dm_chunk_t *dm = NULL; dm = ucc_mpool_get(&team->dm_pool); if (!dm) { @@ -444,30 +456,33 @@ ucc_tl_mlx5_dm_chunk_t* ucc_tl_mlx5_a2a_wait_for_dm_chunk (ucc_tl_mlx5_schedule_ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_base_lib_t * lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t * 
team = TASK_TEAM(task); - ucc_tl_mlx5_a2a_t * a2a = team->a2a; + ucc_base_lib_t *lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super); + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int node_size = a2a->node.sbgp->group_size; int net_size = a2a->net.sbgp->group_size; int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; - int block_h = task->alltoall.block_height; - int block_w = task->alltoall.block_width; - int col_msgsize = task->alltoall.msg_size * block_w * node_size; - int line_msgsize = task->alltoall.msg_size * block_h * node_size; - int block_msgsize = block_h * block_w * task->alltoall.msg_size; - ucc_status_t status = UCC_OK; - int node_grid_w = node_size / block_w; - int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); - int seq_index = task->alltoall.seq_index; - int node_idx, block_row = 0, block_col = 0, block_idx, rank, dest_rank, cyc_rank; - uint64_t src_addr, remote_addr = 0; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); + int seq_index = task->alltoall.seq_index; + int block_row = 0, block_col = 0; + uint64_t remote_addr = 0; ucc_tl_mlx5_dm_chunk_t *dm = NULL; - int i,j, k, send_to_self; int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; - int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int nbr_serialized_batches = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; + uint64_t src_addr; coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -481,32 +496,38 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) for (j = 0; j < nbr_batches_per_passage; j++) { for (node_idx = 0; node_idx < net_size; node_idx++) { - cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; - dest_rank = a2a->net.rank_map[cyc_rank]; + cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; + dest_rank = a2a->net.rank_map[cyc_rank]; send_to_self = (cyc_rank == a2a->net.sbgp->group_rank); - if (tl_mlx5_barrier_flag(task, cyc_rank) != task->alltoall.seq_num) { + if (tl_mlx5_barrier_flag(task, cyc_rank) != + task->alltoall.seq_num) { continue; } dm = NULL; - if (!send_to_self && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { - dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk (task); + if (!send_to_self && + task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { + dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk(task); if (status != UCC_OK) { return status; } } - send_start(team, cyc_rank); for (i = 0; i < nbr_serialized_batches; i++) { + send_start(team, cyc_rank); + for (i = 0; i < nbr_serialized_batches; i++) { for (k = 0; - k < batch_size && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks; - k++, 
task->alltoall.op->blocks_sent[cyc_rank]++) { + k < batch_size && + task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks; + k++, task->alltoall.op->blocks_sent[cyc_rank]++) { block_idx = task->alltoall.op->blocks_sent[cyc_rank]; block_col = block_idx % node_grid_w; block_row = block_idx / node_grid_w; - src_addr = (uintptr_t)(op_msgsize * seq_index + node_msgsize * dest_rank + - col_msgsize * block_col + block_msgsize * block_row); + src_addr = (uintptr_t)( + op_msgsize * seq_index + node_msgsize * dest_rank + + col_msgsize * block_col + block_msgsize * block_row); if (send_to_self || !k) { - remote_addr = - (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * block_col + line_msgsize * block_row); + remote_addr = (uintptr_t)(op_msgsize * seq_index + + node_msgsize * rank + + block_msgsize * block_col + + line_msgsize * block_row); } if (send_to_self) { status = ucc_tl_mlx5_post_transpose( @@ -514,7 +535,9 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) a2a->node.ops[seq_index].send_mkeys[0]->lkey, a2a->net.rkeys[cyc_rank], src_addr, remote_addr, task->alltoall.msg_size, block_w, block_h, - (block_row == 0 && block_col == 0) ? IBV_SEND_SIGNALED : 0); + (block_row == 0 && block_col == 0) + ? IBV_SEND_SIGNALED + : 0); if (UCC_OK != status) { return status; } @@ -522,7 +545,8 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) status = ucc_tl_mlx5_post_transpose( tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, - team->dm_mr->rkey, src_addr, dm->addr + k * block_msgsize, + team->dm_mr->rkey, src_addr, + dm->addr + k * block_msgsize, task->alltoall.msg_size, block_w, block_h, 0); if (UCC_OK != status) { return status; @@ -535,8 +559,8 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) team->dm_mr->lkey, remote_addr, a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); if (status != UCC_OK) { - tl_error(lib, "failed sending block [%d,%d,%d]", node_idx, block_row, - block_col); + tl_error(lib, "failed sending block [%d,%d,%d]", + node_idx, block_row, block_col); return status; } dm->posted_sends++; @@ -551,18 +575,23 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) } if (task->alltoall.op->blocks_sent[cyc_rank] == node_nbr_blocks) { send_start(team, cyc_rank); - status = send_atomic(a2a, cyc_rank, tl_mlx5_atomic_addr(task, cyc_rank), - tl_mlx5_atomic_rkey(task, cyc_rank)); + status = send_atomic(a2a, cyc_rank, + tl_mlx5_atomic_addr(task, cyc_rank), + tl_mlx5_atomic_rkey(task, cyc_rank)); - if (UCC_OK == status) { - status = send_done(team, cyc_rank); - } - task->alltoall.op->blocks_sent[cyc_rank] = 1; - task->alltoall.started++; - if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "Failed sending atomic to node [%d]", - cyc_rank); - return status; + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "Failed sending atomic to node [%d]", cyc_rank); + return status; + } + status = send_done(team, cyc_rank); + if (UCC_OK != status) { + tl_error(lib, "Failed sending atomic to node %d", node_idx); + return status; + } + task->alltoall.op->blocks_sent[cyc_rank]++; + task->alltoall.started++; + } } } if (!task->alltoall.send_blocks_enqueued) { @@ -602,7 +631,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; int mkey_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team); - int block_size 
= task->alltoall.block_height; //ucc_assert(task->alltoall.block_height == task->alltoall.block_width); + int block_size = task->alltoall.block_height; int col_msgsize = msg_size * block_size * node_size; int block_msgsize = SQUARED(block_size) * msg_size; int block_size_leftovers_side = node_size % block_size; @@ -684,7 +713,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) dm = ucc_mpool_get(&team->dm_pool); send_start(team, cyc_rank); } - dm_addr = dm->addr; + dm_addr = dm->addr; dm->task = task; status = ucc_tl_mlx5_post_transpose( @@ -817,28 +846,30 @@ static inline int block_size_fits(size_t msgsize, int height, int width) size_t t1 = power2(ucc_max(msgsize, 8)); size_t tsize = height * ucc_max(power2(width) * t1, MAX_MSG_SIZE); - return tsize <= MAX_TRANSPOSE_SIZE && msgsize <= 128 && height <= 64 && width <= 64; + return tsize <= MAX_TRANSPOSE_SIZE && msgsize <= 128 && height <= 64 && + width <= 64; } -static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, int force_longer, int force_wider, int* block_height, int* block_width) +static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, + int force_longer, int force_wider, + int *block_height, int *block_width) { int h_best = 1; int w_best = 1; - int h,w; + int h, w; for (h = 1; h <= 64; h++) { if (force_regular && (ppn % h)) { continue; } for (w = 1; w <= 64; w++) { - if ((force_regular && (ppn % w)) - || (force_wider && (w < h)) - || (force_longer && (w > h))) { + if ((force_regular && (ppn % w)) || (force_wider && (w < h)) || + (force_longer && (w > h))) { continue; } - if (block_size_fits(msgsize, h, w) - && h*w >= h_best*w_best) { - if ( h*w > h_best*w_best || abs(h/w-1) < abs(h_best/w_best-1)) { + if (block_size_fits(msgsize, h, w) && h * w >= h_best * w_best) { + if (h * w > h_best * w_best || + abs(h / w - 1) < abs(h_best / w_best - 1)) { h_best = h; w_best = w; } @@ -850,7 +881,6 @@ static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, *block_width = w_best; } - UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, (coll_args, team, task_h), ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, @@ -863,15 +893,15 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, == a2a->node.asr_rank); int n_tasks = is_asr ? 
5 : 3; int curr_task = 0; - int ppn = tl_team->a2a->node.sbgp->group_size; - ucc_tl_mlx5_lib_config_t* cfg = &UCC_TL_MLX5_TEAM_LIB(tl_team)->cfg; - int i; + int ppn = tl_team->a2a->node.sbgp->group_size; + ucc_tl_mlx5_lib_config_t *cfg = &UCC_TL_MLX5_TEAM_LIB(tl_team)->cfg; ucc_schedule_t *schedule; ucc_tl_mlx5_schedule_t *task; size_t msg_size; int block_size, i; ucc_coll_task_t *tasks[5]; ucc_status_t status; + size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; if (UCC_IS_INPLACE(coll_args->args)) { return UCC_ERR_NOT_SUPPORTED; @@ -916,36 +946,40 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, a2a->sequence_number += 1; if (a2a->requested_block_size) { - task->alltoall.block_height = task->alltoall.block_width = a2a->requested_block_size; - if (cfg->force_regular - && ((ppn % task->alltoall.block_height) - || (ppn % task->alltoall.block_width))){ - tl_debug(UCC_TL_TEAM_LIB(tl_team), "the requested block size implies irregular case" - "consider changing the block size or turn off the config FORCE_REGULAR"); + task->alltoall.block_height = task->alltoall.block_width = + a2a->requested_block_size; + if (cfg->force_regular && ((ppn % task->alltoall.block_height) || + (ppn % task->alltoall.block_width))) { + tl_debug(UCC_TL_TEAM_LIB(tl_team), + "the requested block size implies irregular case" + "consider changing the block size or turn off the config " + "FORCE_REGULAR"); return UCC_ERR_INVALID_PARAM; } } else { if (!cfg->force_regular) { if (!(cfg->force_longer && cfg->force_wider)) { - tl_debug(UCC_TL_TEAM_LIB(tl_team), "turning off FORCE_REGULAR automatically forces the blocks to be square"); + tl_debug(UCC_TL_TEAM_LIB(tl_team), + "turning off FORCE_REGULAR automatically forces the " + "blocks to be square"); cfg->force_longer = 1; - cfg->force_wider = 1; + cfg->force_wider = 1; } } - get_block_dimensions(ppn, task->alltoall.msg_size, - cfg->force_regular, - cfg->force_longer, - cfg->force_wider, - &task->alltoall.block_height, &task->alltoall.block_width); + get_block_dimensions(ppn, task->alltoall.msg_size, cfg->force_regular, + cfg->force_longer, cfg->force_wider, + &task->alltoall.block_height, + &task->alltoall.block_width); } - tl_debug(UCC_TL_TEAM_LIB(tl_team), "block dimensions: [%d,%d]", task->alltoall.block_height, task->alltoall.block_width); + tl_debug(UCC_TL_TEAM_LIB(tl_team), "block dimensions: [%d,%d]", + task->alltoall.block_height, task->alltoall.block_width); //todo following section correct assuming homogenous PPN across all nodes task->alltoall.num_of_blocks_columns = (a2a->node.sbgp->group_size % task->alltoall.block_height) - ? ucc_div_round_up(a2a->node.sbgp->group_size, task->alltoall.block_height) + ? 
ucc_div_round_up(a2a->node.sbgp->group_size, + task->alltoall.block_height) : 0; - task->alltoall.block_size = block_size; // TODO remove for connectX-7 - this is mkey_entry->stride (count+skip) limitation - only 16 bits if (task->alltoall @@ -955,8 +989,7 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, (1ULL << 16); // TODO We need to query this from device (or device type) and not user hardcoded values ucc_assert(task->alltoall.block_height == task->alltoall.block_width); - int block_size = task->alltoall.block_height; - size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; + block_size = task->alltoall.block_height; ucc_assert(task->alltoall.block_height == task->alltoall.block_width); diff --git a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c index 3500aed52c..468068024c 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c +++ b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c @@ -299,7 +299,7 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_SEND_MKEY_UPDATE) { repeat_count = nbc ? a2a->net.sbgp->group_size - : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width; + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, send_mem_access_flags, node->ops[seq_index].send_mkeys[i], @@ -314,8 +314,9 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, } if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_RECV_MKEY_UPDATE) { - repeat_count = nbc ? a2a->net.sbgp->group_size - : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height; + repeat_count = + nbc ? a2a->net.sbgp->group_size + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, recv_mem_access_flags, node->ops[seq_index].recv_mkeys[i], @@ -336,7 +337,7 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, { ucc_tl_mlx5_alltoall_node_t *node = &a2a->node; int block_height = req->alltoall.block_height; - int block_width = req->alltoall.block_width; + int block_width = req->alltoall.block_width; size_t msg_size = req->alltoall.msg_size; int nbc = req->alltoall.num_of_blocks_columns; struct ibv_mr *buff = direction_send @@ -349,7 +350,8 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, 0) : MY_RECV_UMR_DATA(req, a2a, 0)); mkey_entry->addr = (uintptr_t)buff->addr; - mkey_entry->bytes_count = (direction_send? block_width : block_height) * msg_size; + mkey_entry->bytes_count = + (direction_send ? block_width : block_height) * msg_size; mkey_entry->bytes_skip = 0; mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey; } else { diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index b8830d4613..d9082fa833 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -28,12 +28,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, dm_buf_num), UCC_CONFIG_TYPE_ULUNITS}, - {"FORCE_REGULAR", "y", "Force the regular case where the block dimensions " - "divide ppn. Requires BLOCK_SIZE=0", - ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular), UCC_CONFIG_TYPE_BOOL}, + {"FORCE_REGULAR", "y", + "Force the regular case where the block dimensions " + "divide ppn. 
Requires BLOCK_SIZE=0", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular), + UCC_CONFIG_TYPE_BOOL}, {"FORCE_LONGER", "y", "Force the blocks to have more height than width", - ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer), UCC_CONFIG_TYPE_BOOL}, + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer), + UCC_CONFIG_TYPE_BOOL}, {"FORCE_WIDER", "n", "Force the blocks to have more width than height", ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_wider), UCC_CONFIG_TYPE_BOOL}, @@ -118,18 +121,21 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), UCC_CONFIG_TYPE_UINT}, - {"SEND_BATCH_SIZE", "1", "number of blocks that are transposed " - "on the NIC before being sent as a batch to a remote peer", + {"SEND_BATCH_SIZE", "1", + "number of blocks that are transposed " + "on the NIC before being sent as a batch to a remote peer", ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size), UCC_CONFIG_TYPE_UINT}, - {"NBR_SERIALIZED_BATCHES", "1", "number of block batches " - "(within the set of blocks to be sent to a given remote peer) " + {"NBR_SERIALIZED_BATCHES", "1", + "number of block batches " + "(within the set of blocks to be sent to a given remote peer) " "serialized on the same device memory chunk", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches), UCC_CONFIG_TYPE_UINT}, - {"NBR_BATCHES_PER_PASSAGE", "32", "", + {"NBR_BATCHES_PER_PASSAGE", "32", + "number of batches of blocks sent to one remote node before enqueing", ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage), UCC_CONFIG_TYPE_UINT}, @@ -155,7 +161,8 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, {"FANIN_NPOLLS", "1000", - "Number of shared memory polling before returning UCC_INPROGRESS during internode FANIN", + "Number of shared memory polling before returning UCC_INPROGRESS during " + "internode FANIN", ucc_offsetof(ucc_tl_mlx5_context_config_t, npolls), UCC_CONFIG_TYPE_UINT}, {NULL}}; diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 20f71d67e8..fe44921004 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -65,8 +65,8 @@ typedef struct ucc_tl_mlx5_lib_config { int nbr_batches_per_passage; int block_batch_size; int force_regular; - int force_longer; - int force_wider; + int force_longer; + int force_wider; } ucc_tl_mlx5_lib_config_t; typedef struct ucc_tl_mlx5_context_config { @@ -103,7 +103,7 @@ typedef struct ucc_tl_mlx5_task ucc_tl_mlx5_task_t; typedef struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t; typedef struct ucc_tl_mlx5_dm_chunk_t { uintptr_t addr; // 0 based offset from the beginning of - // memic_mr (obtained with ibv_reg_dm_mr) + // memic_mr (obtained with ibv_reg_dm_mr) ucc_tl_mlx5_schedule_t *task; int posted_sends; int posted_all; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 23d9aa723d..65ac46fe69 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -77,21 +77,16 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT ucc_tl_mlx5_dm_chunk_t *c = (ucc_tl_mlx5_dm_chunk_t *)obj; ucc_tl_mlx5_team_t *team = ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool); - c->addr = (uintptr_t)PTR_OFFSET( - (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host)? 
- team->dm_ptr : NULL, - team->dm_offset); - team->dm_offset = team->dm_offset + - UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size - * UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - c->posted_sends = 0; - c->posted_all=0; - c->completed_sends = 0; -} - c->offset = (ptrdiff_t)team->dm_offset; - team->dm_offset = PTR_OFFSET(team->dm_offset, - UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size); + c->addr = (uintptr_t)PTR_OFFSET( + (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host) ? team->dm_ptr : NULL, + team->dm_offset); + c->posted_sends = 0; + c->posted_all = 0; + c->completed_sends = 0; + team->dm_offset = + team->dm_offset + UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size * + UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; } static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = { @@ -230,17 +225,19 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) } status = ucc_tl_mlx5_dm_alloc_reg( - ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size * cfg->block_batch_size, - &cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); + ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, + cfg->dm_buf_size * cfg->block_batch_size, &cfg->dm_buf_num, + &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); if (status != UCC_OK) { goto err_dm_alloc; } team->dm_offset = 0; - // TODO: fix case dm_host=true - status = ucc_mpool_init(&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), - 0, UCC_CACHE_LINE_SIZE, 1, - cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, - ctx->super.super.ucc_context->thread_mode, "mlx5 dm pool"); + // TODO: fix/check the case dm_host=true + ucc_assert(!cfg->dm_host); + status = ucc_mpool_init( + &team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0, + UCC_CACHE_LINE_SIZE, 1, cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, + ctx->super.super.ucc_context->thread_mode, "mlx5 dm pool"); if (status != UCC_OK) { tl_debug(UCC_TL_TEAM_LIB(team), "failed to init dm pool"); goto err_mpool_init; diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index 399b5626ca..e783abf791 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -39,7 +39,8 @@ static inline uint8_t get_umr_mr_flags(uint32_t acc) typedef struct transpose_seg { __be32 element_size; /* 8 bit value */ - __be16 num_cols; /* 7 bit value */ //TODO: from PRM we should have the rows first and then the colls... is this a bug ? + //From PRM we should have the rows first and then the colls. This is probably a naming error + __be16 num_cols; /* 7 bit value */ __be16 num_rows; /* 7 bit value */ __be64 padding; } transpose_seg_t; diff --git a/test/mpi/buffer.cc b/test/mpi/buffer.cc index 91cf1c311b..f31f42c553 100644 --- a/test/mpi/buffer.cc +++ b/test/mpi/buffer.cc @@ -182,13 +182,6 @@ ucc_status_t compare_buffers(void *_rst, void *expected, size_t count, } else { status = memcmp(rst, expected, count*ucc_dt_size(dt)) ? 
UCC_ERR_NO_MESSAGE : UCC_OK; - // uint8_t* a = (uint8_t*)rst; - // uint8_t* b = (uint8_t*)expected; - // for (int i=0; i Date: Mon, 30 Dec 2024 20:06:22 +0200 Subject: [PATCH 6/9] CODESTYLE: fix alignments and minor comments --- .../tl/mlx5/alltoall/alltoall_coll.c | 81 +++++++++---------- 1 file changed, 38 insertions(+), 43 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 10a08989a9..17735c7867 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -96,15 +96,17 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, { ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int seq_index = task->alltoall.seq_index; - int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; - int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; - int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; - int *dist = &a2a->node.fanin_dist; - int size = a2a->node.sbgp->group_size; - int seq_num = task->alltoall.seq_num; - int c_flag = 0; - int polls; - int peer, vpeer, pos, i; + int npolls = + UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; + int radix = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = + a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls, peer, vpeer, pos, i; ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; while (*dist <= a2a->node.fanin_max_dist) { @@ -282,12 +284,10 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { UCC_TL_MLX5_PROFILE_REQUEST_EVENT( task, "mlx5_alltoall_wait-on-data_start", 0); - } else { - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", - 0); } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); @@ -455,34 +455,32 @@ ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task) // add polling mechanism for blocks in order to maintain const qp tx rx static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_base_lib_t *lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super); - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int node_size = a2a->node.sbgp->group_size; - int net_size = a2a->net.sbgp->group_size; - int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * - a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; - int block_h = task->alltoall.block_height; - int block_w = task->alltoall.block_width; - int col_msgsize = task->alltoall.msg_size * block_w * node_size; - int line_msgsize = task->alltoall.msg_size * block_h * node_size; - int block_msgsize = block_h * block_w * task->alltoall.msg_size; - ucc_status_t status = UCC_OK; - int node_grid_w = node_size / block_w; - int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); - int seq_index = task->alltoall.seq_index; - int block_row = 0, block_col = 0; - uint64_t remote_addr = 0; - ucc_tl_mlx5_dm_chunk_t *dm = NULL; - int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - int nbr_serialized_batches = - 
UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; - int nbr_batches_per_passage = - UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; - int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; - uint64_t src_addr; + ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); + ucc_base_lib_t *lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super); + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; + int node_size = a2a->node.sbgp->group_size; + int net_size = a2a->net.sbgp->group_size; + int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); + int seq_index = task->alltoall.seq_index; + int block_row = 0; + int block_col = 0; + uint64_t remote_addr = 0; + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; + uint64_t src_addr; coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -507,9 +505,6 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) if (!send_to_self && task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk(task); - if (status != UCC_OK) { - return status; - } } send_start(team, cyc_rank); for (i = 0; i < nbr_serialized_batches; i++) { From 0e39489aa65af78f744905c81decb7fdb3e94008 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 30 Dec 2024 20:11:49 +0200 Subject: [PATCH 7/9] CODESTYLE: git-clang-format --- .../tl/mlx5/alltoall/alltoall_coll.c | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 17735c7867..2a1af759b5 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -96,17 +96,14 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, { ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int seq_index = task->alltoall.seq_index; - int npolls = - UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; - int radix = - UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; - int vrank = - a2a->node.sbgp->group_rank - a2a->node.asr_rank; - int *dist = &a2a->node.fanin_dist; - int size = a2a->node.sbgp->group_size; - int seq_num = task->alltoall.seq_num; - int c_flag = 0; - int polls, peer, vpeer, pos, i; + int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; + int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls, peer, vpeer, pos, i; ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; while (*dist <= 
a2a->node.fanin_max_dist) { @@ -455,31 +452,34 @@ ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task) // add polling mechanism for blocks in order to maintain const qp tx rx static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_base_lib_t *lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super); - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int node_size = a2a->node.sbgp->group_size; - int net_size = a2a->net.sbgp->group_size; - int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; - int block_h = task->alltoall.block_height; - int block_w = task->alltoall.block_width; - int col_msgsize = task->alltoall.msg_size * block_w * node_size; - int line_msgsize = task->alltoall.msg_size * block_h * node_size; - int block_msgsize = block_h * block_w * task->alltoall.msg_size; - ucc_status_t status = UCC_OK; - int node_grid_w = node_size / block_w; - int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); - int seq_index = task->alltoall.seq_index; - int block_row = 0; - int block_col = 0; - uint64_t remote_addr = 0; - ucc_tl_mlx5_dm_chunk_t *dm = NULL; - int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; - int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; - int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; + ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(&task->super); + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; + int node_size = a2a->node.sbgp->group_size; + int net_size = a2a->net.sbgp->group_size; + int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * + a2a->max_num_of_columns; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); + int seq_index = task->alltoall.seq_index; + int block_row = 0; + int block_col = 0; + uint64_t remote_addr = 0; + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + int nbr_serialized_batches = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; uint64_t src_addr; coll_task->status = UCC_INPROGRESS; From 83123003d44c94102f426d6325e6ddbc54a148d1 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 30 Dec 2024 20:20:15 +0200 Subject: [PATCH 8/9] REVIEW: add assert to fix clang-tidy --- src/components/tl/mlx5/alltoall/alltoall_coll.c | 1 + 1 file changed, 1 insertion(+) diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 2a1af759b5..ea17fb0253 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ 
b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -537,6 +537,7 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) return status; } } else { + ucc_assert(dm != NULL); status = ucc_tl_mlx5_post_transpose( tl_mlx5_get_qp(a2a, cyc_rank), a2a->node.ops[seq_index].send_mkeys[0]->lkey, From 4bd59957adc3d2d7947f7b5a102154aa46ca8456 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 2 Jan 2025 17:44:13 +0200 Subject: [PATCH 9/9] REVIEW: use PTR_OFFSET for dm ptr arith --- src/components/tl/mlx5/tl_mlx5_dm.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 65ac46fe69..39e4779752 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -84,9 +84,9 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT c->posted_sends = 0; c->posted_all = 0; c->completed_sends = 0; - team->dm_offset = - team->dm_offset + UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size * - UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + team->dm_offset = PTR_OFFSET( + team->dm_offset, UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size * + UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size); } static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = {
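
For reference, a minimal standalone sketch of the dm_offset bookkeeping that this last patch moves to PTR_OFFSET-style arithmetic. The PTR_OFFSET definition below is an assumed stand-in (plain byte-wise offset addition), and the buffer size, batch size, and chunk count are made-up example values, not the real TL/MLX5 config defaults:

#include <stdint.h>
#include <stdio.h>

/* Stand-in for UCC's PTR_OFFSET macro; assumed here to be plain byte-wise
 * pointer/offset addition, which is what the dm_offset bookkeeping needs. */
#define PTR_OFFSET(ptr, off) ((void *)((uintptr_t)(ptr) + (size_t)(off)))

int main(void)
{
    /* Hypothetical values standing in for cfg.dm_buf_size,
     * cfg.block_batch_size and cfg.dm_buf_num. */
    size_t dm_buf_size      = 8192;
    size_t block_batch_size = 8;
    size_t dm_buf_num       = 4;

    /* Like team->dm_offset: a 0-based offset kept in pointer form. */
    void *dm_offset = NULL;

    for (size_t i = 0; i < dm_buf_num; i++) {
        /* Each mpool chunk covers dm_buf_size * block_batch_size bytes. */
        printf("chunk %zu -> offset %zu\n", i, (size_t)(uintptr_t)dm_offset);
        dm_offset = PTR_OFFSET(dm_offset, dm_buf_size * block_batch_size);
    }
    return 0;
}

Each chunk advances the offset by dm_buf_size * block_batch_size bytes, matching the per-chunk region size that ucc_tl_mlx5_dm_init passes to ucc_tl_mlx5_dm_alloc_reg.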
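
As a recap of the knomial fanin that this series adds, here is a simplified single-process sketch of the schedule it follows. The function name, the size and radix values, and the printed output are hypothetical; the shared-memory ctrl polling, the asr_rank rotation, and the resumable fanin_index/fanin_dist progress state of the real implementation are deliberately left out:

#include <stdio.h>

/* Single-process model of a knomial fanin schedule: for each rank (vrank),
 * print which peers it waits for at each distance level and where it
 * signals its parent. */
static void knomial_fanin_schedule(int vrank, int size, int radix)
{
    int dist, pos, i, vpeer;

    for (dist = 1; dist < size; dist *= radix) {
        pos = (vrank / dist) % radix;
        if (pos != 0) {
            /* Non-zero position in the radix group: signal the parent, done. */
            printf("vrank %d: signal parent vrank %d at dist %d\n",
                   vrank, vrank - pos * dist, dist);
            return;
        }
        /* Position 0: wait for up to radix-1 children at this distance. */
        for (i = 1; i < radix; i++) {
            vpeer = vrank + i * dist;
            if (vpeer >= size) {
                break;
            }
            printf("vrank %d: wait for vrank %d at dist %d\n",
                   vrank, vpeer, dist);
        }
    }
    printf("vrank %d: fanin complete (root)\n", vrank);
}

int main(void)
{
    int size  = 7; /* hypothetical node group size */
    int radix = 3; /* hypothetical FANIN_KN_RADIX   */
    int vrank;

    for (vrank = 0; vrank < size; vrank++) {
        knomial_fanin_schedule(vrank, size, radix);
    }
    return 0;
}

With size 7 and radix 3, vrank 0 first waits for vranks 1 and 2, then for 3 and 6, while every other rank signals its parent exactly once; this is the distance-times-radix progression that the fanin_dist and fanin_max_dist fields track across progress calls.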