diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 70439263fb..fb4ffd03e6 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -782,6 +782,7 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); ucc_tl_mlx5_alltoall_t *a2a = tl_team->a2a; + ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(tl_team); int is_asr = (a2a->node.sbgp->group_rank == a2a->node.asr_rank); int n_tasks = is_asr ? 5 : 3; @@ -793,6 +794,10 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init, ucc_coll_task_t *tasks[5]; ucc_status_t status; + if (!ctx->cfg.enable_alltoall) { + return UCC_ERR_NOT_SUPPORTED; + } + if (UCC_IS_INPLACE(coll_args->args)) { return UCC_ERR_NOT_SUPPORTED; } diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 5cdd6c51a1..9e2b9ce548 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -114,7 +114,8 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { ucc_offsetof(ucc_tl_mlx5_context_config_t, devices), UCC_CONFIG_TYPE_STRING_ARRAY}, - {"MCAST_TIMEOUT", "10000", "Timeout [usec] for the reliability NACK in Mcast", + {"MCAST_TIMEOUT", "10000", + "Timeout [usec] for the reliability NACK in Mcast", ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.timeout), UCC_CONFIG_TYPE_INT}, @@ -126,6 +127,10 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name), UCC_CONFIG_TYPE_STRING}, + {"ALLTOALL_ENABLE", "1", "Enable Accelerated alltoall", + ucc_offsetof(ucc_tl_mlx5_context_config_t, enable_alltoall), + UCC_CONFIG_TYPE_BOOL}, + {NULL}}; UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_mlx5_lib_t, ucc_base_lib_t, diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h index 159ecda8ed..8576b5730d 100644 --- a/src/components/tl/mlx5/tl_mlx5.h +++ b/src/components/tl/mlx5/tl_mlx5.h @@ -66,6 +66,7 @@ typedef struct ucc_tl_mlx5_context_config { ucc_tl_context_config_t super; ucs_config_names_array_t devices; ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf; + int enable_alltoall; } ucc_tl_mlx5_context_config_t; typedef struct ucc_tl_mlx5_lib { diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c index 1011fe72c0..81bcf81013 100644 --- a/src/components/tl/mlx5/tl_mlx5_context.c +++ b/src/components/tl/mlx5/tl_mlx5_context.c @@ -43,10 +43,16 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t, return status; } - status = tl_mlx5_rcache_create(self); - if (UCC_OK != status) { - tl_debug(self->super.super.lib, "failed to create rcache"); - goto err_rcache; + if (self->cfg.enable_alltoall) { + status = tl_mlx5_rcache_create(self); + if (UCC_OK != status) { + tl_debug(self->super.super.lib, "failed to create rcache"); + goto err_rcache; + } + } else { + tl_debug(self->super.super.lib, + "alltoall is disabled by the env variable " + "`UCC_TL_MLX5_ALLTOALL_ENABLE`"); } self->mcast.mcast_ctx_ready = 0; @@ -180,6 +186,10 @@ ucc_status_t ucc_tl_mlx5_context_ib_ctx_pd_setup(ucc_base_context_t *context) ucc_coll_task_t *req; ucc_tl_mlx5_context_create_sbcast_data_t *sbcast_data; + if (!ctx->cfg.enable_alltoall) { + return UCC_OK; + } + if (!core_ctx->service_team) { tl_debug(context->lib, "failed to init ctx: need service team"); return UCC_ERR_NO_MESSAGE; diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 2a0c474a39..0881698ee3 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -91,7 +91,7 @@ static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = { void ucc_tl_mlx5_dm_pool_cleanup(ucc_tl_mlx5_team_t *team) { - if (!team->dm_ptr) { + if (!team->dm_ptr || !team->a2a) { return; } diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 1e5f6ddf56..de0820db5b 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -69,9 +69,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context, self->global_sync_req = NULL; self->a2a = NULL; - status = ucc_tl_mlx5_team_init_alltoall(self); - if (UCC_OK != status) { - return status; + if (ctx->cfg.enable_alltoall) { + status = ucc_tl_mlx5_team_init_alltoall(self); + if (UCC_OK != status) { + return status; + } } self->mcast = NULL; @@ -155,6 +157,7 @@ static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); + ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(tl_team); ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); ucc_subset_t subset = {.map = UCC_TL_TEAM_MAP(tl_team), .myrank = UCC_TL_TEAM_RANK(tl_team)}; @@ -242,7 +245,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) } ucc_assert(tl_team->global_sync_req == NULL); - + if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_CTX_CHECK && tl_team->a2a_state == TL_MLX5_TEAM_STATE_ALLTOALL_CTX_CHECK) { // check if ctx is ready for a2a and mcast @@ -253,10 +256,15 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) goto initial_sync_post; } - a2a_status = ucc_tl_mlx5_a2a_team_test(team); - if (a2a_status < 0) { - tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d", - team, a2a_status); + if (ctx->cfg.enable_alltoall) { + a2a_status = ucc_tl_mlx5_a2a_team_test(team); + if (a2a_status < 0) { + tl_warn(team->context->lib, + "ALLTOALL tl team: %p creation failed %d", team, + a2a_status); + tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE; + } + } else { tl_team->a2a_state = TL_MLX5_TEAM_STATE_ALLTOALL_NOT_AVAILABLE; } @@ -269,7 +277,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) } } - if (UCC_OK != a2a_status || UCC_OK != mcast_status) { + if (UCC_INPROGRESS == a2a_status || UCC_INPROGRESS == mcast_status) { return UCC_INPROGRESS; }