Skip to content

Commit

Permalink
UCC/CTX: passing cuda check from tl ucp to others
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Aug 21, 2024
1 parent 777df69 commit ab17b0e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
ucc_rcache_t *rcache;
ucc_tl_mlx5_mcast_ctx_params_t params;
ucc_base_lib_t *lib;
enum ucc_tl_capabilities tl_caps;
} ucc_tl_mlx5_mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_join_info_t {
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
conf_params->rx_sge = 2;
conf_params->scq_moderation = 64;

mcast_context->tl_caps = base_context->ucc_context->tl_caps;

comm = (ucc_tl_mlx5_mcast_coll_comm_t*)
ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_comm_t) +
sizeof(struct pp_packet*)*(conf_params->wsize-1),
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucp/tl_ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
self);

self->ucp_memory_types = context_attr.memory_types;
if (self->ucp_memory_types & UCC_BIT(ucc_memtype_to_ucs[UCC_MEMORY_TYPE_CUDA])) {
/* TL MLX5 needs this information */
self->super.super.ucc_context->tl_caps |= UCC_TL_UCP_CUDA_ENABLED;
}

worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
switch (params->thread_mode) {
case UCC_THREAD_SINGLE:
Expand Down
7 changes: 7 additions & 0 deletions src/core/ucc_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ typedef struct ucc_context_id {
#define UCC_CTX_ID_EQUAL(_id1, _id2) (UCC_PROC_INFO_EQUAL((_id1).pi, (_id2).pi) \
&& (_id1).seq_num == (_id2).seq_num)

enum ucc_tl_capabilities {
/* capabalities that every TL needs to be aware of
* about other TLs */
UCC_TL_UCP_CUDA_ENABLED = UCC_BIT(0)
};

enum {
/* all ranks have identical set of TLs*/
UCC_ADDR_STORAGE_FLAG_TLS_SYMMETRIC = UCC_BIT(0),
Expand Down Expand Up @@ -78,6 +84,7 @@ typedef struct ucc_context {
uint64_t cl_flags;
ucc_tl_team_t *service_team;
int32_t throttle_progress;
enum ucc_tl_capabilities tl_caps;
} ucc_context_t;

typedef struct ucc_context_config {
Expand Down

0 comments on commit ab17b0e

Please sign in to comment.