diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index 663ee636ed..19ea00d547 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -117,6 +117,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec {
     int                            max_eager;
     int                            cuda_mem_enabled;
     int                            one_sided_reliability_enable;
+    int                            truly_zero_copy_allgather_enabled;
+    int                            mcast_prepost_bucket_size;
     void                          *oob;
 } ucc_tl_mlx5_mcast_coll_comm_init_spec_t;
@@ -276,6 +278,8 @@ typedef struct ucc_tl_mlx5_mcast_allgather_comm {
     uint32_t coll_counter;
     uint32_t max_num_packets;
     uint32_t max_push_send;
+    uint8_t  truly_zero_copy_allgather_enabled;
+    uint32_t mcast_prepost_bucket_size;
 } ucc_tl_mlx5_mcast_allgather_comm_t;
 
 typedef struct ucc_tl_mlx5_mcast_bcast_comm {
@@ -434,6 +438,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
     ucc_memory_type_t                          buf_mem_type;
     enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme one_sided_reliability_scheme;
     uint32_t                                   ag_counter;
+    int                                        concurreny_level;
+    int                                        mcast_prepost_bucket_size;
     int                                        state;
     ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule;
     int                                        total_steps;
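
The header changes above add two tuning knobs and per-request schedule state. For orientation, the following self-contained C sketch (hypothetical values, no UCC types; the ONE_SIDED_MAX_CONCURRENT_LEVEL cap of 8 is an assumption, not taken from the header) models how these knobs bound each other: half of the mcast groups receive while the other half prepost, so two buckets per group must fit in the receive queue at once.

#include <stdio.h>

#define ASSUMED_MAX_CONCURRENT_LEVEL 8 /* stand-in for ONE_SIDED_MAX_CONCURRENT_LEVEL */

static int min_int(int a, int b) { return a < b ? a : b; }

int main(void)
{
    int mcast_group_count = 4;    /* hypothetical team setup */
    int commsize          = 16;
    int rx_depth          = 4096; /* mcast recv queue depth */
    int prepost_bucket    = 16;   /* MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE */

    /* mirrors the clamping done per request: half the groups receive
     * at a time, capped by the concurrency limit and the team size */
    int concurrency = min_int(mcast_group_count / 2, ASSUMED_MAX_CONCURRENT_LEVEL);
    concurrency     = min_int(concurrency, commsize);

    /* double buffering: bucket * concurrency recvs are live for the
     * receiving half while the same amount is preposted for the other */
    int ok = concurrency > 0 && 2 * prepost_bucket * concurrency <= rx_depth;
    printf("concurrency=%d fits_rx_depth=%s\n", concurrency, ok ? "yes" : "no");
    return 0;
}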
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
index 82592238d4..9d06041fd9 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
@@ -270,6 +270,152 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
     }
 }
 
+static inline ucc_status_t
+ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(ucc_tl_mlx5_mcast_coll_comm_t *comm,
+                                                      ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+    if ((req->concurreny_level % 2 == 0 && req->num_packets % req->mcast_prepost_bucket_size != 0) ||
+        (comm->commsize % req->concurreny_level != 0) ||
+        (req->length % comm->max_per_packet != 0)) {
+        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
+                "num_packets %d mcast_prepost_bucket_size %d "
+                "length %ld max_per_packet %d "
+                "team size %d concurreny_level %d",
+                req->num_packets, req->mcast_prepost_bucket_size, req->length,
+                comm->max_per_packet, comm->commsize, req->concurreny_level);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (req->mcast_prepost_bucket_size * req->concurreny_level * 2 > comm->params.rx_depth) {
+        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
+                "either reduce prepost_bucket_size or mcast group "
+                "count or increase recv queue size "
+                "mcast_prepost_bucket_size %d concurreny_level %d "
+                "rx_depth %d",
+                req->mcast_prepost_bucket_size, req->concurreny_level,
+                comm->params.rx_depth);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    return UCC_OK;
+}
+
+
+/*
+ * At each stage, half of the mcast groups are ready to receive mcast
+ * packets while the other half are being prepared by preposting recv
+ * buffers.
+ */
+static inline ucc_status_t
+ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(ucc_tl_mlx5_mcast_coll_comm_t *comm,
+                                              ucc_tl_mlx5_mcast_coll_req_t *req)
+{
+    ucc_tl_mlx5_mcast_reg_t                   *reg    = NULL;
+    ucc_rank_t                                 root   = 0;
+    int                                        offset = 0;
+    ucc_status_t                               status;
+    ucc_rank_t                                 j, i;
+    int                                        total_steps;
+    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *new_sched;
+
+    ucc_assert(comm->allgather_comm.truly_zero_copy_allgather_enabled);
+
+    req->concurreny_level = comm->mcast_group_count / 2;
+    req->concurreny_level = ucc_min(req->concurreny_level, ONE_SIDED_MAX_CONCURRENT_LEVEL);
+    req->concurreny_level = ucc_min(req->concurreny_level, comm->commsize);
+
+    if (req->concurreny_level == 0) {
+        tl_warn(comm->lib, "not enough concurrency to enable zero-copy pipelined allgather");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    if (comm->allgather_comm.mcast_prepost_bucket_size > req->num_packets) {
+        req->mcast_prepost_bucket_size = req->num_packets;
+    } else {
+        req->mcast_prepost_bucket_size = comm->allgather_comm.mcast_prepost_bucket_size;
+    }
+
+    status = ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(comm, req);
+    if (status != UCC_OK) {
+        return status;
+    }
+
+    /* calculate the schedule: what we should mcast and prepost to which
+     * mcast group at each stage */
+    total_steps = req->num_packets * (comm->commsize / req->concurreny_level)
+                  / req->mcast_prepost_bucket_size + 1;
+
+    new_sched = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t) * total_steps, "sched");
+    if (!new_sched) {
+        tl_warn(comm->lib, "cannot allocate memory for schedule list");
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    /* generate schedule */
+    for (i = 0; i < total_steps; i++) {
+        ucc_assert(root < comm->commsize);
+        if (i < total_steps - 1) {
+            for (j = 0; j < req->concurreny_level; j++) {
+                new_sched[i].prepost_buf_op[j].group_id = j + req->concurreny_level * (i % 2);
+                new_sched[i].prepost_buf_op[j].offset   = offset * comm->max_per_packet;
+                new_sched[i].prepost_buf_op[j].root     = root + j;
+                new_sched[i].prepost_buf_op[j].count    = req->mcast_prepost_bucket_size;
+            }
+        } else {
+            new_sched[i].prepost_buf_op_done = 1;
+        }
+
+        if (i > 0) {
+            for (j = 0; j < req->concurreny_level; j++) {
+                new_sched[i].multicast_op[j].group_id     = new_sched[i - 1].prepost_buf_op[j].group_id;
+                new_sched[i].multicast_op[j].offset       = new_sched[i - 1].prepost_buf_op[j].offset;
+                new_sched[i].multicast_op[j].offset_left  = new_sched[i - 1].prepost_buf_op[j].offset;
+                new_sched[i].multicast_op[j].root         = new_sched[i - 1].prepost_buf_op[j].root;
+                new_sched[i].multicast_op[j].to_send_left = new_sched[i - 1].prepost_buf_op[j].count;
+                new_sched[i].multicast_op[j].to_recv      = new_sched[i - 1].prepost_buf_op[j].count;
+                new_sched[i].to_recv                     += new_sched[i].multicast_op[j].to_recv;
+                if (new_sched[i].multicast_op[j].root == comm->rank) {
+                    new_sched[i].to_send += new_sched[i].multicast_op[j].to_send_left;
+                }
+            }
+        }
+
+        if (!new_sched[i].to_send || !new_sched[i].to_recv) {
+            new_sched[i].multicast_op_done = 1;
+        }
+
+        offset += req->mcast_prepost_bucket_size;
+
+        if (offset == req->num_packets) {
+            offset = 0;
+            root   = (root + req->concurreny_level) % comm->commsize;
+        }
+    }
+
+    tl_trace(comm->lib,
+             "generated the schedule for pipelined zero copy allgather with total_steps %d",
+             total_steps);
+    new_sched->total_steps = total_steps;
+    req->total_steps       = total_steps;
+    req->ag_schedule       = new_sched;
+    tl_trace(comm->lib, "registering recv buf of size %ld", req->length * comm->commsize);
+    ucc_assert(req->recv_rreg == NULL);
+
+    status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, req->length *
+                                            comm->commsize, &reg);
+    if (UCC_OK != status) {
+        tl_warn(comm->lib, "unable to register receive buffer %p of size %ld",
+                req->rptr, req->length * comm->commsize);
+        ucc_free(new_sched);
+        return status;
+    }
+
+    req->recv_rreg = reg;
+    req->recv_mr   = reg->mr;
+
+    return UCC_OK;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
 {
     ucc_coll_task_t *coll_task = &(task->super);
@@ -357,6 +503,13 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
     req->to_send = req->num_packets;
     req->to_recv = comm->commsize * req->num_packets;
 
+    if (comm->allgather_comm.truly_zero_copy_allgather_enabled) {
+        status = ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(comm, req);
+        if (UCC_OK != status) {
+            return status;
+        }
+    }
+
     comm->allgather_comm.coll_counter++;
 
     task->coll_mcast.req_handle = req;
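
To make the schedule arithmetic concrete, here is a self-contained sketch (hypothetical team numbers, plain C, no UCC calls) that mirrors the prepost rotation generated above: total_steps = num_packets * (commsize / concurreny_level) / mcast_prepost_bucket_size + 1, with the final step reserved for draining the last multicast round, and the groups alternating between the two halves by step parity. The multicast for each bucket happens one step after it is preposted, which the sketch notes but does not print.

#include <stdio.h>

int main(void)
{
    int commsize    = 8; /* team size, divisible by concurrency */
    int concurrency = 2; /* mcast groups receiving per stage */
    int num_packets = 8; /* length / max_per_packet, per rank */
    int bucket      = 4; /* packets preposted per group per step */

    /* one extra step drains the last multicast round */
    int total_steps = num_packets * (commsize / concurrency) / bucket + 1;
    int root = 0, offset = 0;

    printf("total_steps=%d\n", total_steps); /* 8 * (8/2) / 4 + 1 = 9 */
    for (int i = 0; i < total_steps - 1; i++) {
        for (int j = 0; j < concurrency; j++) {
            /* groups ping-pong between the two halves by step parity,
             * matching group_id = j + concurrency * (i % 2) above;
             * the multicast against these buffers runs at step i + 1 */
            int group = j + concurrency * (i % 2);
            printf("step %d: group %d preposts %d pkts for root %d at pkt offset %d\n",
                   i, group, bucket, root + j, offset);
        }
        offset += bucket;
        if (offset == num_packets) { /* finished the current block of roots */
            offset = 0;
            root   = (root + concurrency) % commsize;
        }
    }
    return 0;
}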
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
index b70ca6e2f6..46e5df8a24 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
@@ -99,6 +99,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
     memcpy(&comm->params, conf_params, sizeof(*conf_params));
 
+    comm->allgather_comm.mcast_prepost_bucket_size
+        = conf_params->mcast_prepost_bucket_size;
+    comm->allgather_comm.truly_zero_copy_allgather_enabled
+        = conf_params->truly_zero_copy_allgather_enabled;
     comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable;
     comm->bcast_comm.wsize              = conf_params->wsize;
     comm->allgather_comm.max_push_send  = conf_params->max_push_send;
diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c
index 5cdd6c51a1..b9d48edb7b 100644
--- a/src/components/tl/mlx5/tl_mlx5.c
+++ b/src/components/tl/mlx5/tl_mlx5.c
@@ -104,6 +104,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
      ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
      UCC_CONFIG_TYPE_BOOL},
 
+    {"MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", "Enable truly zero copy allgather design for mcast",
+     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.truly_zero_copy_allgather_enabled),
+     UCC_CONFIG_TYPE_BOOL},
+
+    {"MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "16", "Number of posted recvs during each stage of the pipeline"
+     " in truly zero copy mcast allgather design",
+     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.mcast_prepost_bucket_size),
+     UCC_CONFIG_TYPE_INT},
+
     {NULL}};
 
 static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
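
The two new lib-config entries are reachable from the environment like any other UCC option. Below is a hedged usage sketch in C; the UCC_TL_MLX5_ prefix is assumed from UCC's usual lib-config naming convention and should be verified against your build's option listing.

#include <stdlib.h>

/* Usage sketch, assuming UCC's "UCC_" + "TL_MLX5_" env prefixing of the
 * lib config table entries; call before initializing the UCC library. */
static void configure_mcast_zero_copy_allgather(void)
{
    /* enable the truly zero copy mcast allgather path (default "1") */
    setenv("UCC_TL_MLX5_MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", 1);

    /* recvs preposted per mcast group at each pipeline stage (default "16");
     * 2 * bucket * concurreny_level must stay within the mcast recv queue
     * depth, otherwise the allgather init warns and returns
     * UCC_ERR_NOT_SUPPORTED per the validation added above */
    setenv("UCC_TL_MLX5_MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "8", 1);
}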