Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: various optimizations #1012

Closed
3 changes: 3 additions & 0 deletions src/components/tl/mlx5/alltoall/alltoall.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ typedef struct ucc_tl_mlx5_alltoall_node {
struct mlx5dv_mkey *team_recv_mkey;
void *umr_entries_buf;
struct ibv_mr *umr_entries_mr;
int fanin_index;
int fanin_dist;
int fanin_max_dist;
} ucc_tl_mlx5_alltoall_node_t;

typedef struct alltoall_net_ctrl {
Expand Down
447 changes: 290 additions & 157 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c

Large diffs are not rendered by default.

25 changes: 16 additions & 9 deletions src/components/tl/mlx5/alltoall/alltoall_mkeys.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,15 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int nbc = req->alltoall.num_of_blocks_columns;
int seq_index = req->alltoall.seq_index;
int repeat_count = nbc ? a2a->net.sbgp->group_size
: UCC_TL_TEAM_SIZE(team) / req->alltoall.block_size;
int n_mkeys = nbc ? nbc : 1;
int repeat_count;
int i;
ucc_status_t status;

if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag &
UCC_MLX5_NEED_SEND_MKEY_UPDATE) {
repeat_count = nbc ? a2a->net.sbgp->group_size
: UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width;
for (i = 0; i < n_mkeys; i++) {
status = populate_strided_mkey(a2a, send_mem_access_flags,
node->ops[seq_index].send_mkeys[i],
Expand All @@ -313,6 +314,9 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team,
}
if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag &
UCC_MLX5_NEED_RECV_MKEY_UPDATE) {
repeat_count =
nbc ? a2a->net.sbgp->group_size
: UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height;
for (i = 0; i < n_mkeys; i++) {
status = populate_strided_mkey(a2a, recv_mem_access_flags,
node->ops[seq_index].recv_mkeys[i],
Expand All @@ -332,7 +336,8 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a,
ucc_tl_mlx5_schedule_t *req, int direction_send)
{
ucc_tl_mlx5_alltoall_node_t *node = &a2a->node;
int block_size = req->alltoall.block_size;
int block_height = req->alltoall.block_height;
int block_width = req->alltoall.block_width;
size_t msg_size = req->alltoall.msg_size;
int nbc = req->alltoall.num_of_blocks_columns;
struct ibv_mr *buff = direction_send
Expand All @@ -345,26 +350,28 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a,
mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, 0)
: MY_RECV_UMR_DATA(req, a2a, 0));
mkey_entry->addr = (uintptr_t)buff->addr;
mkey_entry->bytes_count = block_size * msg_size;
mkey_entry->bytes_count =
(direction_send ? block_width : block_height) * msg_size;
mkey_entry->bytes_skip = 0;
mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey;
} else {
for (i = 0; i < nbc; i++) {
ucc_assert(block_height == block_width);
mkey_entry =
(umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, i)
: MY_RECV_UMR_DATA(req, a2a, i));
mkey_entry->addr =
(uintptr_t)buff->addr + i * (block_size * msg_size);
(uintptr_t)buff->addr + i * (block_height * msg_size);
mkey_entry->bytes_count =
(i == (nbc - 1))
? ((node->sbgp->group_size % block_size) * msg_size)
: (block_size * msg_size);
? ((node->sbgp->group_size % block_height) * msg_size)
: (block_height * msg_size);
mkey_entry->bytes_skip =
(i == (nbc - 1))
? ((node->sbgp->group_size -
(node->sbgp->group_size % block_size)) *
(node->sbgp->group_size % block_height)) *
msg_size)
: ((node->sbgp->group_size - block_size) * msg_size);
: ((node->sbgp->group_size - block_height) * msg_size);
mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey;
}
}
Expand Down
39 changes: 39 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, dm_buf_num),
UCC_CONFIG_TYPE_ULUNITS},

{"FORCE_REGULAR", "y",
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
"Force the regular case where the block dimensions "
"divide ppn. Requires BLOCK_SIZE=0",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular),
UCC_CONFIG_TYPE_BOOL},

{"FORCE_LONGER", "y", "Force the blocks to have more height than width",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer),
UCC_CONFIG_TYPE_BOOL},

{"FORCE_WIDER", "n", "Force the blocks to have more width than height",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_wider), UCC_CONFIG_TYPE_BOOL},

{"BLOCK_SIZE", "0",
"Size of the blocks that are sent using blocked AlltoAll Algorithm",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_size), UCC_CONFIG_TYPE_UINT},
Expand Down Expand Up @@ -104,6 +117,28 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_BOOL},

{"FANIN_KN_RADIX", "4", "Radix of the knomial tree fanin algorithm",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix),
UCC_CONFIG_TYPE_UINT},

{"SEND_BATCH_SIZE", "1",
"number of blocks that are transposed "
"on the NIC before being sent as a batch to a remote peer",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size),
UCC_CONFIG_TYPE_UINT},

{"NBR_SERIALIZED_BATCHES", "1",
"number of block batches "
"(within the set of blocks to be sent to a given remote peer) "
"serialized on the same device memory chunk",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches),
UCC_CONFIG_TYPE_UINT},

{"NBR_BATCHES_PER_PASSAGE", "32",
"number of batches of blocks sent to one remote node before enqueing",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage),
UCC_CONFIG_TYPE_UINT},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand All @@ -125,6 +160,10 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
{"MCAST_NET_DEVICE", "", "Specifies which network device to use for Mcast",
ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name),
UCC_CONFIG_TYPE_STRING},
{"FANIN_NPOLLS", "1000",
"Number of shared memory polling before returning UCC_INPROGRESS during "
"internode FANIN",
ucc_offsetof(ucc_tl_mlx5_context_config_t, npolls), UCC_CONFIG_TYPE_UINT},

{NULL}};

Expand Down
17 changes: 14 additions & 3 deletions src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,20 @@ typedef struct ucc_tl_mlx5_lib_config {
int dm_host;
ucc_tl_mlx5_ib_qp_conf_t qp_conf;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf;
int fanin_kn_radix;
int nbr_serialized_batches;
int nbr_batches_per_passage;
int block_batch_size;
int force_regular;
int force_longer;
int force_wider;
} ucc_tl_mlx5_lib_config_t;

typedef struct ucc_tl_mlx5_context_config {
ucc_tl_context_config_t super;
ucs_config_names_array_t devices;
ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf;
int npolls;
} ucc_tl_mlx5_context_config_t;

typedef struct ucc_tl_mlx5_lib {
Expand Down Expand Up @@ -93,10 +101,13 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t*,

typedef struct ucc_tl_mlx5_task ucc_tl_mlx5_task_t;
typedef struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t;
typedef struct ucc_tl_mlx5_dm_chunk {
ptrdiff_t offset; /* 0 based offset from the beginning of
memic_mr (obtained with ibv_reg_dm_mr) */
typedef struct ucc_tl_mlx5_dm_chunk_t {
uintptr_t addr; // 0 based offset from the beginning of
// memic_mr (obtained with ibv_reg_dm_mr)
ucc_tl_mlx5_schedule_t *task;
int posted_sends;
int posted_all;
int completed_sends;
} ucc_tl_mlx5_dm_chunk_t;

typedef struct ucc_tl_mlx5_alltoall ucc_tl_mlx5_alltoall_t;
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ typedef struct ucc_tl_mlx5_schedule {
int seq_num;
int seq_index;
int num_of_blocks_columns;
int block_size;
int block_height;
int block_width;
int started;
int send_blocks_enqueued;
int blocks_sent;
Expand Down
22 changes: 15 additions & 7 deletions src/components/tl/mlx5/tl_mlx5_dm.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,15 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT
ucc_tl_mlx5_team_t *team =
ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool);

c->offset = (ptrdiff_t)team->dm_offset;
team->dm_offset = PTR_OFFSET(team->dm_offset,
UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size);
c->addr = (uintptr_t)PTR_OFFSET(
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
(UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host) ? team->dm_ptr : NULL,
team->dm_offset);
c->posted_sends = 0;
c->posted_all = 0;
c->completed_sends = 0;
team->dm_offset = PTR_OFFSET(
team->dm_offset, UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size *
UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size);
}

static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = {
Expand Down Expand Up @@ -219,13 +225,15 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team)
}

status = ucc_tl_mlx5_dm_alloc_reg(
ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size,
&cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team));
ctx->shared_ctx, ctx->shared_pd, cfg->dm_host,
cfg->dm_buf_size * cfg->block_batch_size, &cfg->dm_buf_num,
&team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team));
if (status != UCC_OK) {
goto err_dm_alloc;
}
team->dm_offset = NULL;

team->dm_offset = 0;
// TODO: fix/check the case dm_host=true
ucc_assert(!cfg->dm_host);
status = ucc_mpool_init(
&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0,
UCC_CACHE_LINE_SIZE, 1, cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops,
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team)
return UCC_OK;
}

static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team)
static inline ucc_status_t ucc_tl_mlx5_alltoall_team_test(ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);

Expand Down Expand Up @@ -253,7 +253,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
goto initial_sync_post;
}

a2a_status = ucc_tl_mlx5_a2a_team_test(team);
a2a_status = ucc_tl_mlx5_alltoall_team_test(team);
if (a2a_status < 0) {
tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d",
team, a2a_status);
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/tl_mlx5_wqe.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ static inline uint8_t get_umr_mr_flags(uint32_t acc)

typedef struct transpose_seg {
__be32 element_size; /* 8 bit value */
__be16 num_rows; /* 7 bit value */
//From PRM we should have the rows first and then the colls. This is probably a naming error
__be16 num_cols; /* 7 bit value */
__be16 num_rows; /* 7 bit value */
__be64 padding;
} transpose_seg_t;

Expand Down
2 changes: 1 addition & 1 deletion test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ UCC_TEST_P(test_tl_mlx5_transpose, transposeWqe)

ibv_wr_start(qp.qp_ex);
post_transpose(qp.qp, src_mr->lkey, dst_mr->rkey, (uintptr_t)src,
(uintptr_t)dst, elem_size, nrows, ncols, IBV_SEND_SIGNALED);
(uintptr_t)dst, elem_size, ncols, nrows, IBV_SEND_SIGNALED);
GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0);

while (!completions_num) {
Expand Down
Loading