diff --git a/src/components/tl/sharp/tl_sharp.c b/src/components/tl/sharp/tl_sharp.c index 464ef50478..fe86950bf7 100644 --- a/src/components/tl/sharp/tl_sharp.c +++ b/src/components/tl/sharp/tl_sharp.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -72,6 +72,11 @@ static ucc_config_field_t ucc_tl_sharp_context_config_table[] = { ucc_offsetof(ucc_tl_sharp_context_config_t, team_max_ppn), UCC_CONFIG_TYPE_UINT}, + {"USE_MULTI_CHANNEL", "0", + "Use SHARP Multi-channel feature. Options: 0-disable 1-enable", + ucc_offsetof(ucc_tl_sharp_context_config_t, use_multi_channel), + UCC_CONFIG_TYPE_BOOL}, + {NULL}}; UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_sharp_lib_t, ucc_base_lib_t, diff --git a/src/components/tl/sharp/tl_sharp.h b/src/components/tl/sharp/tl_sharp.h index adfbc86036..875b9c6689 100644 --- a/src/components/tl/sharp/tl_sharp.h +++ b/src/components/tl/sharp/tl_sharp.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -53,6 +53,7 @@ typedef struct ucc_tl_sharp_context_config { int context_per_team; int enable_lazy_group_alloc; int team_max_ppn; + int use_multi_channel; } ucc_tl_sharp_context_config_t; typedef struct ucc_tl_sharp_lib { diff --git a/src/components/tl/sharp/tl_sharp_context.c b/src/components/tl/sharp/tl_sharp_context.c index 42d10f8d87..ed7d50578b 100644 --- a/src/components/tl/sharp/tl_sharp_context.c +++ b/src/components/tl/sharp/tl_sharp_context.c @@ -305,7 +305,11 @@ ucc_status_t ucc_tl_sharp_context_init(ucc_tl_sharp_context_t *sharp_ctx, init_spec.progress_func = NULL; init_spec.world_local_rank = local_rank; - init_spec.group_channel_idx = 0; + if (sharp_ctx->cfg.use_multi_channel) { + init_spec.group_channel_idx = local_rank; + } else { + init_spec.group_channel_idx = 0; + } init_spec.oob_ctx = oob_ctx; init_spec.config = sharp_coll_default_config; init_spec.config.user_progress_num_polls = sharp_ctx->cfg.uprogress_num_polls;