diff --git a/src/components/tl/mlx5/alltoall/alltoall.h b/src/components/tl/mlx5/alltoall/alltoall.h index 9fd9d787cc..c49a0ff9cf 100644 --- a/src/components/tl/mlx5/alltoall/alltoall.h +++ b/src/components/tl/mlx5/alltoall/alltoall.h @@ -58,6 +58,9 @@ typedef struct ucc_tl_mlx5_alltoall_node { struct mlx5dv_mkey *team_recv_mkey; void *umr_entries_buf; struct ibv_mr *umr_entries_mr; + int fanin_index; + int fanin_dist; + int fanin_max_dist; } ucc_tl_mlx5_alltoall_node_t; typedef struct alltoall_net_ctrl { diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 70439263fb..ea17fb0253 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -82,7 +82,10 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) } else { ucc_tl_mlx5_dm_chunk_t *dm = (ucc_tl_mlx5_dm_chunk_t *)wcs[i].wr_id; dm->task->alltoall.blocks_completed++; - ucc_mpool_put(dm); + dm->completed_sends++; + if (dm->posted_all && dm->completed_sends == dm->posted_sends) { + ucc_mpool_put(dm); + } } } return UCC_OK; @@ -91,40 +94,68 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib) static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, ucc_tl_mlx5_schedule_t *task) { - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int seq_index = task->alltoall.seq_index; - int i; + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; + int seq_index = task->alltoall.seq_index; + int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; + int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls, peer, vpeer, pos, i; ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; - if (a2a->node.sbgp->group_rank != a2a->node.asr_rank) { - ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = - task->alltoall.seq_num; - } else { - for (i = 0; i < a2a->node.sbgp->group_size; i++) { - if (i == a2a->node.sbgp->group_rank) { - continue; - } - ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); - if (ctrl_v->seq_num != task->alltoall.seq_num) { - return UCC_INPROGRESS; - } - } - for (i = 0; i < a2a->node.sbgp->group_size; i++) { - if (i == a2a->node.sbgp->group_rank) { - continue; + while (*dist <= a2a->node.fanin_max_dist) { + if (vrank % *dist == 0) { + pos = (vrank / *dist) % radix; + if (pos == 0) { + while (a2a->node.fanin_index < radix) { + vpeer = vrank + a2a->node.fanin_index * *dist; + if (vpeer >= size) { + a2a->node.fanin_index = radix; + break; + } + peer = (vpeer + a2a->node.asr_rank) % size; + ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, peer); + for (polls = 0; polls < npolls; polls++) { + if (ctrl_v->seq_num == seq_num) { + a2a->node.fanin_index++; + break; + } + } + if (polls == npolls) { + return UCC_INPROGRESS; + } + } + } else { + ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = seq_num; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_fanin_done", 0); + return UCC_OK; } - ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); - ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag |= - ctrl_v->mkey_cache_flag; } + *dist *= radix; + a2a->node.fanin_index = 1; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done", + 0); } UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); + for (i = 0; i < size; i++) { + ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i); + ucc_assert(i == a2a->node.sbgp->group_rank || + ctrl_v->seq_num == seq_num); + c_flag |= ctrl_v->mkey_cache_flag; + } + ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = c_flag; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_retrieve_cache_flags_done", 0); return UCC_OK; } /* Each rank registers sbuf and rbuf and place the registration data - in the shared memory location. Next, all rank in node nitify the - ASR the registration data is ready using SHM Fanin */ + in the shared memory location. Next, all rank in node notify the + ASR that the registration data is ready using SHM Fanin */ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) { ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); @@ -137,7 +168,7 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) ucc_tl_mlx5_rcache_region_t *send_ptr; ucc_tl_mlx5_rcache_region_t *recv_ptr; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_start", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_reg_start", 0); tl_debug(UCC_TASK_LIB(task), "register memory buffers"); coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -187,11 +218,20 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task) /* Start fanin */ ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = flag; ucc_tl_mlx5_update_mkeys_entries(a2a, task, flag); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_reg_done", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_start", 0); + + a2a->node.fanin_index = 1; + a2a->node.fanin_dist = 1; + for (a2a->node.fanin_max_dist = 1; + a2a->node.fanin_max_dist < a2a->node.sbgp->group_size; + a2a->node.fanin_max_dist *= + UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix) { + } if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) { tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete"); coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0); return ucc_task_complete(coll_task); } @@ -204,7 +244,6 @@ void ucc_tl_mlx5_reg_fanin_progress(ucc_coll_task_t *coll_task) ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task); - ucc_assert(team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank); if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) { tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete"); coll_task->status = UCC_OK; @@ -243,6 +282,10 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_start", 0); + } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); return UCC_OK; @@ -265,6 +308,10 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task) coll_task->status = UCC_INPROGRESS; return; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_complete", 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", + 0); } if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) { @@ -292,11 +339,18 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); // despite while statement in poll_umr_cq, non blocking because have independent cq, // will be finished in a finite time + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_start", + 0); ucc_tl_mlx5_populate_send_recv_mkeys(team, task); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_populate_UMR_end", + 0); + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_start", 0); //Reset atomic notification counter to 0 #if ATOMIC_IN_MEMIC tl_mlx5_atomic_t zero = 0; + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_start", + 0); if (0 != ibv_memcpy_to_dm(a2a->net.atomic.counters, task->alltoall.seq_index * sizeof(tl_mlx5_atomic_t), @@ -304,6 +358,7 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) tl_error(UCC_TASK_LIB(task), "failed to reset atomic in memic"); return UCC_ERR_NO_MESSAGE; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_ibv_memcpy_done", 0); #else a2a->net.atomic.counters[task->alltoall.seq_index] = 0; #endif @@ -342,12 +397,14 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) status = send_done(team, i); } if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); + tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); return status; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_barrier_send_posted", 0); } coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barreir_done", + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", 0); return ucc_task_complete(coll_task); } @@ -371,33 +428,59 @@ static void ucc_tl_mlx5_asr_barrier_progress(ucc_coll_task_t *coll_task) } } +ucc_tl_mlx5_dm_chunk_t * +ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task) +{ + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(&task->super); + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + + dm = ucc_mpool_get(&team->dm_pool); + if (!dm) { + while (!dm) { + if (UCC_OK != ucc_tl_mlx5_poll_cq(team->a2a->net.cq, lib)) { + return NULL; + } + dm = ucc_mpool_get(&team->dm_pool); + } + } + dm->task = task; + + return dm; +} + // add polling mechanism for blocks in order to maintain const qp tx rx static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task); - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int node_size = a2a->node.sbgp->group_size; - int net_size = a2a->net.sbgp->group_size; - int op_msgsize = node_size * a2a->max_msg_size - * UCC_TL_TEAM_SIZE(team) - * a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) - * task->alltoall.msg_size; - int block_size = task->alltoall.block_size; - int col_msgsize = task->alltoall.msg_size - * block_size * node_size; - int block_msgsize = SQUARED(block_size) - * task->alltoall.msg_size; - int dm_host = UCC_TL_MLX5_TEAM_LIB(team) - ->cfg.dm_host; - ucc_status_t status = UCC_OK; - ucc_base_lib_t *lib = UCC_TASK_LIB(task); - int seq_index = task->alltoall.seq_index; - int i, j, k, rank, dest_rank, cyc_rank; - uint64_t src_addr, remote_addr; - ucc_tl_mlx5_dm_chunk_t *dm; - uintptr_t dm_addr; + ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(&task->super); + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; + int node_size = a2a->node.sbgp->group_size; + int net_size = a2a->net.sbgp->group_size; + int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * + a2a->max_num_of_columns; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); + int seq_index = task->alltoall.seq_index; + int block_row = 0; + int block_col = 0; + uint64_t remote_addr = 0; + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + int nbr_serialized_batches = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; + uint64_t src_addr; coll_task->status = UCC_INPROGRESS; coll_task->super.status = UCC_INPROGRESS; @@ -409,95 +492,103 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) "mlx5_alltoall_block_send_start", 0); } - for (i = 0; i < net_size; i++) { - cyc_rank = (i + a2a->net.sbgp->group_rank) % net_size; - dest_rank = a2a->net.rank_map[cyc_rank]; - if (task->alltoall.op->blocks_sent[cyc_rank] || - tl_mlx5_barrier_flag(task, cyc_rank) != task->alltoall.seq_num) { - continue; - } - - //send all blocks from curr node to some ARR - for (j = 0; j < (node_size / block_size); j++) { - for (k = 0; k < (node_size / block_size); k++) { - src_addr = (uintptr_t)(node_msgsize * dest_rank + - col_msgsize * j + block_msgsize * k); - remote_addr = - (uintptr_t)(op_msgsize * seq_index + node_msgsize * rank + - block_msgsize * j + col_msgsize * k); - - send_start(team, cyc_rank); - if (cyc_rank == a2a->net.sbgp->group_rank) { - status = ucc_tl_mlx5_post_transpose( - tl_mlx5_get_qp(a2a, cyc_rank), - a2a->node.ops[seq_index].send_mkeys[0]->lkey, - a2a->net.rkeys[cyc_rank], src_addr, remote_addr, - task->alltoall.msg_size, block_size, block_size, - (j == 0 && k == 0) ? IBV_SEND_SIGNALED : 0); - if (UCC_OK != status) { - return status; + for (j = 0; j < nbr_batches_per_passage; j++) { + for (node_idx = 0; node_idx < net_size; node_idx++) { + cyc_rank = (node_idx + a2a->net.sbgp->group_rank) % net_size; + dest_rank = a2a->net.rank_map[cyc_rank]; + send_to_self = (cyc_rank == a2a->net.sbgp->group_rank); + if (tl_mlx5_barrier_flag(task, cyc_rank) != + task->alltoall.seq_num) { + continue; + } + dm = NULL; + if (!send_to_self && + task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) { + dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk(task); + } + send_start(team, cyc_rank); + for (i = 0; i < nbr_serialized_batches; i++) { + for (k = 0; + k < batch_size && + task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks; + k++, task->alltoall.op->blocks_sent[cyc_rank]++) { + block_idx = task->alltoall.op->blocks_sent[cyc_rank]; + block_col = block_idx % node_grid_w; + block_row = block_idx / node_grid_w; + src_addr = (uintptr_t)( + op_msgsize * seq_index + node_msgsize * dest_rank + + col_msgsize * block_col + block_msgsize * block_row); + if (send_to_self || !k) { + remote_addr = (uintptr_t)(op_msgsize * seq_index + + node_msgsize * rank + + block_msgsize * block_col + + line_msgsize * block_row); } - } else { - dm = ucc_mpool_get(&team->dm_pool); - while (!dm) { - status = send_done(team, cyc_rank); + if (send_to_self) { + status = ucc_tl_mlx5_post_transpose( + tl_mlx5_get_qp(a2a, cyc_rank), + a2a->node.ops[seq_index].send_mkeys[0]->lkey, + a2a->net.rkeys[cyc_rank], src_addr, remote_addr, + task->alltoall.msg_size, block_w, block_h, + (block_row == 0 && block_col == 0) + ? IBV_SEND_SIGNALED + : 0); if (UCC_OK != status) { return status; } - - status = ucc_tl_mlx5_poll_cq(a2a->net.cq, lib); + } else { + ucc_assert(dm != NULL); + status = ucc_tl_mlx5_post_transpose( + tl_mlx5_get_qp(a2a, cyc_rank), + a2a->node.ops[seq_index].send_mkeys[0]->lkey, + team->dm_mr->rkey, src_addr, + dm->addr + k * block_msgsize, + task->alltoall.msg_size, block_w, block_h, 0); if (UCC_OK != status) { return status; } - dm = ucc_mpool_get(&team->dm_pool); - send_start(team, cyc_rank); - } - if (dm_host) { - dm_addr = - (uintptr_t)PTR_OFFSET(team->dm_ptr, dm->offset); - } else { - dm_addr = dm->offset; // dm reg mr 0 based - } - dm->task = task; - - status = ucc_tl_mlx5_post_transpose( - tl_mlx5_get_qp(a2a, cyc_rank), - a2a->node.ops[seq_index].send_mkeys[0]->lkey, - team->dm_mr->rkey, src_addr, dm_addr, - task->alltoall.msg_size, block_size, block_size, 0); - if (UCC_OK != status) { - return status; } + } + if (!send_to_self && k) { status = send_block_data( - a2a, cyc_rank, dm_addr, block_msgsize, + a2a, cyc_rank, dm->addr, block_msgsize * k, team->dm_mr->lkey, remote_addr, a2a->net.rkeys[cyc_rank], IBV_SEND_SIGNALED, dm); if (status != UCC_OK) { - tl_error(lib, "failed sending block [%d,%d,%d]", i, j, - k); + tl_error(lib, "failed sending block [%d,%d,%d]", + node_idx, block_row, block_col); return status; } + dm->posted_sends++; } - status = send_done(team, cyc_rank); + } + status = send_done(team, cyc_rank); + if (status != UCC_OK) { + return status; + } + if (dm) { + dm->posted_all = 1; + } + if (task->alltoall.op->blocks_sent[cyc_rank] == node_nbr_blocks) { + send_start(team, cyc_rank); + status = send_atomic(a2a, cyc_rank, + tl_mlx5_atomic_addr(task, cyc_rank), + tl_mlx5_atomic_rkey(task, cyc_rank)); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "Failed sending atomic to node [%d]", cyc_rank); + return status; + } + status = send_done(team, cyc_rank); + if (UCC_OK != status) { + tl_error(lib, "Failed sending atomic to node %d", node_idx); return status; } + task->alltoall.op->blocks_sent[cyc_rank]++; + task->alltoall.started++; } } - send_start(team, cyc_rank); - status = send_atomic(a2a, cyc_rank, tl_mlx5_atomic_addr(task, cyc_rank), - tl_mlx5_atomic_rkey(task, cyc_rank)); - - if (UCC_OK == status) { - status = send_done(team, cyc_rank); - } - task->alltoall.op->blocks_sent[cyc_rank] = 1; - task->alltoall.started++; - if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "Failed sending atomic to node [%d]", - cyc_rank); - return status; - } } if (!task->alltoall.send_blocks_enqueued) { ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); @@ -536,7 +627,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; int mkey_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team); - int block_size = task->alltoall.block_size; + int block_size = task->alltoall.block_height; int col_msgsize = msg_size * block_size * node_size; int block_msgsize = SQUARED(block_size) * msg_size; int block_size_leftovers_side = node_size % block_size; @@ -545,7 +636,6 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) int block_msgsize_leftovers = block_size_leftovers_side * block_size * msg_size; int corner_msgsize = SQUARED(block_size_leftovers_side) * msg_size; - int dm_host = UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host; ucc_status_t status = UCC_OK; ucc_base_lib_t *lib = UCC_TASK_LIB(coll_task); int nbc = task->alltoall.num_of_blocks_columns; @@ -619,12 +709,7 @@ ucc_tl_mlx5_send_blocks_leftovers_start(ucc_coll_task_t *coll_task) dm = ucc_mpool_get(&team->dm_pool); send_start(team, cyc_rank); } - if (dm_host) { - dm_addr = - (uintptr_t)PTR_OFFSET(team->dm_ptr, dm->offset); - } else { - dm_addr = dm->offset; // dm reg mr 0 based - } + dm_addr = dm->addr; dm->task = task; status = ucc_tl_mlx5_post_transpose( @@ -752,26 +837,44 @@ static inline int power2(int value) return p; } -static inline int block_size_fits(size_t msgsize, int block_size) +static inline int block_size_fits(size_t msgsize, int height, int width) { - return block_size * ucc_max(power2(block_size) * msgsize, MAX_MSG_SIZE) <= - MAX_TRANSPOSE_SIZE; + size_t t1 = power2(ucc_max(msgsize, 8)); + size_t tsize = height * ucc_max(power2(width) * t1, MAX_MSG_SIZE); + + return tsize <= MAX_TRANSPOSE_SIZE && msgsize <= 128 && height <= 64 && + width <= 64; } -static inline int get_block_size(ucc_tl_mlx5_schedule_t *task) +static inline void get_block_dimensions(int ppn, int msgsize, int force_regular, + int force_longer, int force_wider, + int *block_height, int *block_width) { - ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task); - int ppn = team->a2a->node.sbgp->group_size; - size_t effective_msgsize = power2(ucc_max( - task->alltoall.msg_size, 8)); - int block_size; + int h_best = 1; + int w_best = 1; + int h, w; - block_size = ppn; - while (!block_size_fits(effective_msgsize, block_size)) { - block_size--; + for (h = 1; h <= 64; h++) { + if (force_regular && (ppn % h)) { + continue; + } + for (w = 1; w <= 64; w++) { + if ((force_regular && (ppn % w)) || (force_wider && (w < h)) || + (force_longer && (w > h))) { + continue; + } + if (block_size_fits(msgsize, h, w) && h * w >= h_best * w_best) { + if (h * w > h_best * w_best || + abs(h / w - 1) < abs(h_best / w_best - 1)) { + h_best = h; + w_best = w; + } + } + } } - return ucc_max(1, block_size); + *block_height = h_best; + *block_width = w_best; } UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, @@ -786,12 +889,15 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, == a2a->node.asr_rank); int n_tasks = is_asr ? 5 : 3; int curr_task = 0; + int ppn = tl_team->a2a->node.sbgp->group_size; + ucc_tl_mlx5_lib_config_t *cfg = &UCC_TL_MLX5_TEAM_LIB(tl_team)->cfg; ucc_schedule_t *schedule; ucc_tl_mlx5_schedule_t *task; size_t msg_size; int block_size, i; ucc_coll_task_t *tasks[5]; ucc_status_t status; + size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; if (UCC_IS_INPLACE(coll_args->args)) { return UCC_ERR_NOT_SUPPORTED; @@ -835,15 +941,41 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, tl_trace(UCC_TL_TEAM_LIB(tl_team), "Seq num is %d", task->alltoall.seq_num); a2a->sequence_number += 1; - block_size = a2a->requested_block_size ? a2a->requested_block_size - : get_block_size(task); + if (a2a->requested_block_size) { + task->alltoall.block_height = task->alltoall.block_width = + a2a->requested_block_size; + if (cfg->force_regular && ((ppn % task->alltoall.block_height) || + (ppn % task->alltoall.block_width))) { + tl_debug(UCC_TL_TEAM_LIB(tl_team), + "the requested block size implies irregular case" + "consider changing the block size or turn off the config " + "FORCE_REGULAR"); + return UCC_ERR_INVALID_PARAM; + } + } else { + if (!cfg->force_regular) { + if (!(cfg->force_longer && cfg->force_wider)) { + tl_debug(UCC_TL_TEAM_LIB(tl_team), + "turning off FORCE_REGULAR automatically forces the " + "blocks to be square"); + cfg->force_longer = 1; + cfg->force_wider = 1; + } + } + get_block_dimensions(ppn, task->alltoall.msg_size, cfg->force_regular, + cfg->force_longer, cfg->force_wider, + &task->alltoall.block_height, + &task->alltoall.block_width); + } + tl_debug(UCC_TL_TEAM_LIB(tl_team), "block dimensions: [%d,%d]", + task->alltoall.block_height, task->alltoall.block_width); //todo following section correct assuming homogenous PPN across all nodes task->alltoall.num_of_blocks_columns = - (a2a->node.sbgp->group_size % block_size) - ? ucc_div_round_up(a2a->node.sbgp->group_size, block_size) + (a2a->node.sbgp->group_size % task->alltoall.block_height) + ? ucc_div_round_up(a2a->node.sbgp->group_size, + task->alltoall.block_height) : 0; - task->alltoall.block_size = block_size; // TODO remove for connectX-7 - this is mkey_entry->stride (count+skip) limitation - only 16 bits if (task->alltoall @@ -852,14 +984,15 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, size_t limit = (1ULL << 16); // TODO We need to query this from device (or device type) and not user hardcoded values - size_t bytes_count, bytes_count_last, bytes_skip, bytes_skip_last; - int ppn = a2a->node.sbgp->group_size; - int bs = block_size; - - bytes_count_last = (ppn % bs) * msg_size; - bytes_skip_last = (ppn - (ppn % bs)) * msg_size; - bytes_count = bs * msg_size; - bytes_skip = (ppn - bs) * msg_size; + ucc_assert(task->alltoall.block_height == task->alltoall.block_width); + block_size = task->alltoall.block_height; + + ucc_assert(task->alltoall.block_height == task->alltoall.block_width); + + bytes_count_last = (ppn % block_size) * msg_size; + bytes_skip_last = (ppn - (ppn % block_size)) * msg_size; + bytes_count = block_size * msg_size; + bytes_skip = (ppn - block_size) * msg_size; if ((bytes_count + bytes_skip >= limit) || (bytes_count_last + bytes_skip_last >= limit)) { tl_debug(UCC_TL_TEAM_LIB(tl_team), "unsupported operation"); diff --git a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c index e8b3052501..468068024c 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c +++ b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c @@ -291,14 +291,15 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; int nbc = req->alltoall.num_of_blocks_columns; int seq_index = req->alltoall.seq_index; - int repeat_count = nbc ? a2a->net.sbgp->group_size - : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_size; int n_mkeys = nbc ? nbc : 1; + int repeat_count; int i; ucc_status_t status; if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_SEND_MKEY_UPDATE) { + repeat_count = nbc ? a2a->net.sbgp->group_size + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, send_mem_access_flags, node->ops[seq_index].send_mkeys[i], @@ -313,6 +314,9 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team, } if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag & UCC_MLX5_NEED_RECV_MKEY_UPDATE) { + repeat_count = + nbc ? a2a->net.sbgp->group_size + : UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height; for (i = 0; i < n_mkeys; i++) { status = populate_strided_mkey(a2a, recv_mem_access_flags, node->ops[seq_index].recv_mkeys[i], @@ -332,7 +336,8 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, ucc_tl_mlx5_schedule_t *req, int direction_send) { ucc_tl_mlx5_alltoall_node_t *node = &a2a->node; - int block_size = req->alltoall.block_size; + int block_height = req->alltoall.block_height; + int block_width = req->alltoall.block_width; size_t msg_size = req->alltoall.msg_size; int nbc = req->alltoall.num_of_blocks_columns; struct ibv_mr *buff = direction_send @@ -345,26 +350,28 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a, mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, 0) : MY_RECV_UMR_DATA(req, a2a, 0)); mkey_entry->addr = (uintptr_t)buff->addr; - mkey_entry->bytes_count = block_size * msg_size; + mkey_entry->bytes_count = + (direction_send ? block_width : block_height) * msg_size; mkey_entry->bytes_skip = 0; mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey; } else { for (i = 0; i < nbc; i++) { + ucc_assert(block_height == block_width); mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, i) : MY_RECV_UMR_DATA(req, a2a, i)); mkey_entry->addr = - (uintptr_t)buff->addr + i * (block_size * msg_size); + (uintptr_t)buff->addr + i * (block_height * msg_size); mkey_entry->bytes_count = (i == (nbc - 1)) - ? ((node->sbgp->group_size % block_size) * msg_size) - : (block_size * msg_size); + ? ((node->sbgp->group_size % block_height) * msg_size) + : (block_height * msg_size); mkey_entry->bytes_skip = (i == (nbc - 1)) ? ((node->sbgp->group_size - - (node->sbgp->group_size % block_size)) * + (node->sbgp->group_size % block_height)) * msg_size) - : ((node->sbgp->group_size - block_size) * msg_size); + : ((node->sbgp->group_size - block_height) * msg_size); mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey; } } diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 5cdd6c51a1..d9082fa833 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -28,6 +28,19 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, dm_buf_num), UCC_CONFIG_TYPE_ULUNITS}, + {"FORCE_REGULAR", "y", + "Force the regular case where the block dimensions " + "divide ppn. Requires BLOCK_SIZE=0", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular), + UCC_CONFIG_TYPE_BOOL}, + + {"FORCE_LONGER", "y", "Force the blocks to have more height than width", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer), + UCC_CONFIG_TYPE_BOOL}, + + {"FORCE_WIDER", "n", "Force the blocks to have more width than height", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_wider), UCC_CONFIG_TYPE_BOOL}, + {"BLOCK_SIZE", "0", "Size of the blocks that are sent using blocked AlltoAll Algorithm", ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_size), UCC_CONFIG_TYPE_UINT}, @@ -104,6 +117,28 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable), UCC_CONFIG_TYPE_BOOL}, + {"FANIN_KN_RADIX", "4", "Radix of the knomial tree fanin algorithm", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix), + UCC_CONFIG_TYPE_UINT}, + + {"SEND_BATCH_SIZE", "1", + "number of blocks that are transposed " + "on the NIC before being sent as a batch to a remote peer", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size), + UCC_CONFIG_TYPE_UINT}, + + {"NBR_SERIALIZED_BATCHES", "1", + "number of block batches " + "(within the set of blocks to be sent to a given remote peer) " + "serialized on the same device memory chunk", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches), + UCC_CONFIG_TYPE_UINT}, + + {"NBR_BATCHES_PER_PASSAGE", "32", + "number of batches of blocks sent to one remote node before enqueing", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage), + UCC_CONFIG_TYPE_UINT}, + {NULL}}; static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { @@ -125,6 +160,10 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { {"MCAST_NET_DEVICE", "", "Specifies which network device to use for Mcast", ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, + {"FANIN_NPOLLS", "1000", + "Number of shared memory polling before returning UCC_INPROGRESS during " + "internode FANIN", + ucc_offsetof(ucc_tl_mlx5_context_config_t, npolls), UCC_CONFIG_TYPE_UINT}, {NULL}}; diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 159ecda8ed..fe44921004 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -60,12 +60,20 @@ typedef struct ucc_tl_mlx5_lib_config { int dm_host; ucc_tl_mlx5_ib_qp_conf_t qp_conf; ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf; + int fanin_kn_radix; + int nbr_serialized_batches; + int nbr_batches_per_passage; + int block_batch_size; + int force_regular; + int force_longer; + int force_wider; } ucc_tl_mlx5_lib_config_t; typedef struct ucc_tl_mlx5_context_config { ucc_tl_context_config_t super; ucs_config_names_array_t devices; ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf; + int npolls; } ucc_tl_mlx5_context_config_t; typedef struct ucc_tl_mlx5_lib { @@ -93,10 +101,13 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t*, typedef struct ucc_tl_mlx5_task ucc_tl_mlx5_task_t; typedef struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t; -typedef struct ucc_tl_mlx5_dm_chunk { - ptrdiff_t offset; /* 0 based offset from the beginning of - memic_mr (obtained with ibv_reg_dm_mr) */ +typedef struct ucc_tl_mlx5_dm_chunk_t { + uintptr_t addr; // 0 based offset from the beginning of + // memic_mr (obtained with ibv_reg_dm_mr) ucc_tl_mlx5_schedule_t *task; + int posted_sends; + int posted_all; + int completed_sends; } ucc_tl_mlx5_dm_chunk_t; typedef struct ucc_tl_mlx5_alltoall ucc_tl_mlx5_alltoall_t; diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h index 8ffe3eaf64..9f05d01d90 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.h +++ b/src/components/tl/mlx5/tl_mlx5_coll.h @@ -27,7 +27,8 @@ typedef struct ucc_tl_mlx5_schedule { int seq_num; int seq_index; int num_of_blocks_columns; - int block_size; + int block_height; + int block_width; int started; int send_blocks_enqueued; int blocks_sent; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 2a0c474a39..39e4779752 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -78,9 +78,15 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT ucc_tl_mlx5_team_t *team = ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool); - c->offset = (ptrdiff_t)team->dm_offset; - team->dm_offset = PTR_OFFSET(team->dm_offset, - UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size); + c->addr = (uintptr_t)PTR_OFFSET( + (UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host) ? team->dm_ptr : NULL, + team->dm_offset); + c->posted_sends = 0; + c->posted_all = 0; + c->completed_sends = 0; + team->dm_offset = PTR_OFFSET( + team->dm_offset, UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size * + UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size); } static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = { @@ -219,13 +225,15 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team) } status = ucc_tl_mlx5_dm_alloc_reg( - ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size, - &cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); + ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, + cfg->dm_buf_size * cfg->block_batch_size, &cfg->dm_buf_num, + &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team)); if (status != UCC_OK) { goto err_dm_alloc; } - team->dm_offset = NULL; - + team->dm_offset = 0; + // TODO: fix/check the case dm_host=true + ucc_assert(!cfg->dm_host); status = ucc_mpool_init( &team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0, UCC_CACHE_LINE_SIZE, 1, cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops, diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 1e5f6ddf56..6a65274d1d 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -117,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) +static inline ucc_status_t ucc_tl_mlx5_alltoall_team_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); @@ -253,7 +253,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) goto initial_sync_post; } - a2a_status = ucc_tl_mlx5_a2a_team_test(team); + a2a_status = ucc_tl_mlx5_alltoall_team_test(team); if (a2a_status < 0) { tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d", team, a2a_status); diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index cf4d590658..e783abf791 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -39,8 +39,9 @@ static inline uint8_t get_umr_mr_flags(uint32_t acc) typedef struct transpose_seg { __be32 element_size; /* 8 bit value */ - __be16 num_rows; /* 7 bit value */ + //From PRM we should have the rows first and then the colls. This is probably a naming error __be16 num_cols; /* 7 bit value */ + __be16 num_rows; /* 7 bit value */ __be64 padding; } transpose_seg_t; diff --git a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc index 70cc9a4190..3103ba2380 100644 --- a/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc +++ b/test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc @@ -64,7 +64,7 @@ UCC_TEST_P(test_tl_mlx5_transpose, transposeWqe) ibv_wr_start(qp.qp_ex); post_transpose(qp.qp, src_mr->lkey, dst_mr->rkey, (uintptr_t)src, - (uintptr_t)dst, elem_size, nrows, ncols, IBV_SEND_SIGNALED); + (uintptr_t)dst, elem_size, ncols, nrows, IBV_SEND_SIGNALED); GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0); while (!completions_num) {