Skip to content

Commit

Permalink
TL/MLX5: add ctx option to disable a2a
Browse files Browse the repository at this point in the history
  • Loading branch information
samnordmann committed Dec 30, 2024
1 parent 73651ea commit ff27a77
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 15 deletions.
5 changes: 5 additions & 0 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init,
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team,
ucc_tl_mlx5_team_t);
ucc_tl_mlx5_alltoall_t *a2a = tl_team->a2a;
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(tl_team);
int is_asr = (a2a->node.sbgp->group_rank
== a2a->node.asr_rank);
int n_tasks = is_asr ? 5 : 3;
Expand All @@ -793,6 +794,10 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init,
ucc_coll_task_t *tasks[5];
ucc_status_t status;

if (!ctx->cfg.enable_alltoall) {
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_IS_INPLACE(coll_args->args)) {
return UCC_ERR_NOT_SUPPORTED;
}
Expand Down
7 changes: 6 additions & 1 deletion src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_context_config_t, devices),
UCC_CONFIG_TYPE_STRING_ARRAY},

{"MCAST_TIMEOUT", "10000", "Timeout [usec] for the reliability NACK in Mcast",
{"MCAST_TIMEOUT", "10000",
"Timeout [usec] for the reliability NACK in Mcast",
ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.timeout),
UCC_CONFIG_TYPE_INT},

Expand All @@ -126,6 +127,10 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name),
UCC_CONFIG_TYPE_STRING},

{"ALLTOALL_ENABLE", "1", "Enable Accelerated alltoall",
ucc_offsetof(ucc_tl_mlx5_context_config_t, enable_alltoall),
UCC_CONFIG_TYPE_BOOL},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_mlx5_lib_t, ucc_base_lib_t,
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ typedef struct ucc_tl_mlx5_context_config {
/* Per-context configuration of the MLX5 TL. The fields below are the targets
 * of the ucc_offsetof() entries in ucc_tl_mlx5_context_config_table. */
/* Generic TL context configuration, embedded as the first member. */
ucc_tl_context_config_t super;
/* Names array; registered in the config table as UCC_CONFIG_TYPE_STRING_ARRAY. */
ucs_config_names_array_t devices;
/* Multicast context parameters; the config table fills at least its
 * `timeout` (MCAST_TIMEOUT, usec, default "10000") and `ib_dev_name` members. */
ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf;
/* Boolean set from the ALLTOALL_ENABLE option (default "1"; referred to in
 * the context log message as env variable UCC_TL_MLX5_ALLTOALL_ENABLE).
 * When zero: ucc_tl_mlx5_alltoall_init returns UCC_ERR_NOT_SUPPORTED, the
 * context skips rcache creation and the ib ctx/pd setup returns UCC_OK
 * early, and team creation skips ucc_tl_mlx5_team_init_alltoall. */
int enable_alltoall;
} ucc_tl_mlx5_context_config_t;

typedef struct ucc_tl_mlx5_lib {
Expand Down
18 changes: 14 additions & 4 deletions src/components/tl/mlx5/tl_mlx5_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t,
return status;
}

status = tl_mlx5_rcache_create(self);
if (UCC_OK != status) {
tl_debug(self->super.super.lib, "failed to create rcache");
goto err_rcache;
if (self->cfg.enable_alltoall) {
status = tl_mlx5_rcache_create(self);
if (UCC_OK != status) {
tl_debug(self->super.super.lib, "failed to create rcache");
goto err_rcache;
}
} else {
tl_debug(self->super.super.lib,
"alltoall is disabled by the env variable "
"`UCC_TL_MLX5_ALLTOALL_ENABLE`");
}

self->mcast.mcast_ctx_ready = 0;
Expand Down Expand Up @@ -180,6 +186,10 @@ ucc_status_t ucc_tl_mlx5_context_ib_ctx_pd_setup(ucc_base_context_t *context)
ucc_coll_task_t *req;
ucc_tl_mlx5_context_create_sbcast_data_t *sbcast_data;

if (!ctx->cfg.enable_alltoall) {
return UCC_OK;
}

if (!core_ctx->service_team) {
tl_debug(context->lib, "failed to init ctx: need service team");
return UCC_ERR_NO_MESSAGE;
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_dm.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = {

void ucc_tl_mlx5_dm_pool_cleanup(ucc_tl_mlx5_team_t *team)
{
if (!team->dm_ptr) {
if (!team->dm_ptr || !team->a2a) {
return;
}

Expand Down
26 changes: 17 additions & 9 deletions src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context,
self->global_sync_req = NULL;

self->a2a = NULL;
status = ucc_tl_mlx5_team_init_alltoall(self);
if (UCC_OK != status) {
return status;
if (ctx->cfg.enable_alltoall) {
status = ucc_tl_mlx5_team_init_alltoall(self);
if (UCC_OK != status) {
return status;
}
}

self->mcast = NULL;
Expand Down Expand Up @@ -155,6 +157,7 @@ static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team)
ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(tl_team);
ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team);
ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team),
.myrank = UCC_TL_TEAM_RANK(tl_team)};
Expand Down Expand Up @@ -242,7 +245,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
}

ucc_assert(tl_team->global_sync_req == NULL);

if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_CTX_CHECK &&
tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK) {
// check if ctx is ready for a2a and mcast
Expand All @@ -253,10 +256,15 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
goto initial_sync_post;
}

a2a_status = ucc_tl_mlx5_a2a_team_test(team);
if (a2a_status < 0) {
tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d",
team, a2a_status);
if (ctx->cfg.enable_alltoall) {
a2a_status = ucc_tl_mlx5_a2a_team_test(team);
if (a2a_status < 0) {
tl_warn(team->context->lib,
"ALLTOALL tl team: %p creation failed %d", team,
a2a_status);
tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE;
}
} else {
tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE;
}

Expand All @@ -269,7 +277,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
}
}

if (UCC_OK != a2a_status || UCC_OK != mcast_status) {
if (UCC_INPROGRESS == a2a_status || UCC_INPROGRESS == mcast_status) {
return UCC_INPROGRESS;
}

Expand Down

0 comments on commit ff27a77

Please sign in to comment.