Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: generate schedule for zcopy allgather #1059

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec {
int max_eager;
int cuda_mem_enabled;
int one_sided_reliability_enable;
int truly_zero_copy_allgather_enabled;
int mcast_prepost_bucket_size;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -276,6 +278,8 @@ typedef struct ucc_tl_mlx5_mcast_allgather_comm {
uint32_t coll_counter;
uint32_t max_num_packets;
uint32_t max_push_send;
uint8_t truly_zero_copy_allgather_enabled;
uint32_t mcast_prepost_bucket_size;
} ucc_tl_mlx5_mcast_allgather_comm_t;

typedef struct ucc_tl_mlx5_mcast_bcast_comm {
Expand Down Expand Up @@ -434,6 +438,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_memory_type_t buf_mem_type;
enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme one_sided_reliability_scheme;
uint32_t ag_counter;
int concurreny_level;
int mcast_prepost_bucket_size;
int state;
ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule;
int total_steps;
Expand Down
153 changes: 153 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,152 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
}
}

/* Check that the communicator/request geometry admits a pipelined
 * zero-copy allgather schedule.
 *
 * Requirements enforced here:
 *  - with an even concurrency level, num_packets must be a multiple of
 *    the prepost bucket size so every stage preposts full buckets;
 *  - the team size must be divisible by the concurrency level so ranks
 *    can be processed in equal-sized groups of roots;
 *  - the per-rank message length must be a multiple of max_per_packet
 *    so packets tile the buffer exactly;
 *  - the receive queue must be deep enough to hold two in-flight
 *    buckets per mcast group (one stage receiving, one preposting).
 *
 * Returns UCC_OK when the schedule is feasible, UCC_ERR_NOT_SUPPORTED
 * otherwise (caller is expected to fall back to a non-zcopy path).
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                      ucc_tl_mlx5_mcast_coll_req_t *req)
{
    if ((req->concurreny_level % 2 == 0 && req->num_packets % req->mcast_prepost_bucket_size != 0) ||
        (comm->commsize % req->concurreny_level != 0) ||
        (req->length % comm->max_per_packet != 0)) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "num_packets %d mcast_prepost_bucket_size %d "
                "length %ld max_per_packet %d "
                "team size %d concurrency_level %d",
                req->num_packets, req->mcast_prepost_bucket_size, req->length,
                comm->max_per_packet, comm->commsize, req->concurreny_level);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* each group needs room for two buckets of preposted recvs at once */
    if (req->mcast_prepost_bucket_size * req->concurreny_level * 2 > comm->params.rx_depth) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "either reduce prepost_bucket_size or mcast group "
                "count or increase recv queue size "
                "mcast_prepost_bucket_size %d concurrency_level %d "
                "rx_depth %d",
                req->mcast_prepost_bucket_size, req->concurreny_level,
                comm->params.rx_depth);
        return UCC_ERR_NOT_SUPPORTED;
    }

    return UCC_OK;
}


/*
 * Build the per-step schedule for the pipelined truly-zero-copy
 * allgather: at each stage half of the mcast groups are ready for
 * receiving mcast packets while the other half are getting prepared
 * by preposting recv buffers.
 *
 * On success the request owns the allocated schedule (req->ag_schedule)
 * and the registered receive buffer (req->recv_rreg / req->recv_mr).
 * On failure the request is left untouched so no dangling pointers
 * remain. Returns UCC_OK, UCC_ERR_NOT_SUPPORTED when the geometry
 * does not admit the pipeline, or UCC_ERR_NO_MEMORY / registration
 * status on resource failure.
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                              ucc_tl_mlx5_mcast_coll_req_t *req)
{
    ucc_tl_mlx5_mcast_reg_t *reg = NULL;
    ucc_rank_t root = 0;
    int offset = 0;
    ucc_status_t status;
    ucc_rank_t j, i;
    int total_steps;
    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *new_sched;

    ucc_assert(comm->allgather_comm.truly_zero_copy_allgather_enabled);

    /* use half of the mcast groups per stage; the other half is being
     * preposted, and never exceed the configured cap or the team size */
    req->concurreny_level = comm->mcast_group_count / 2;
    req->concurreny_level = ucc_min(req->concurreny_level, ONE_SIDED_MAX_CONCURRENT_LEVEL);
    req->concurreny_level = ucc_min(req->concurreny_level, comm->commsize);

    if (req->concurreny_level == 0) {
        tl_warn(comm->lib, "not enough concurrency level to enable zcopy pipeline allgather");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* clamp the bucket size so one bucket never exceeds the message */
    if (comm->allgather_comm.mcast_prepost_bucket_size > req->num_packets) {
        req->mcast_prepost_bucket_size = req->num_packets;
    } else {
        req->mcast_prepost_bucket_size = comm->allgather_comm.mcast_prepost_bucket_size;
    }

    status = ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(comm, req);
    if (status != UCC_OK) {
        return status;
    }

    /* calculate the schedule and details of what we should
     * mcast and prepost to which mcast group at each stage;
     * +1 because the first step only preposts (nothing to send yet) */
    total_steps = req->num_packets * (comm->commsize / req->concurreny_level)
                  / req->mcast_prepost_bucket_size + 1;

    new_sched = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t) * total_steps, "sched");
    if (!new_sched) {
        tl_warn(comm->lib, "cannot allocate memory for schedule list");
        return UCC_ERR_NO_MEMORY;
    }

    /* generate schedule */
    for (i = 0; i < total_steps; i++) {
        ucc_assert(root < comm->commsize);
        /* every step but the last preposts the buckets consumed by the
         * NEXT step; groups alternate halves via the (i % 2) term */
        if (i < total_steps - 1) {
            for (j = 0; j < req->concurreny_level; j++) {
                new_sched[i].prepost_buf_op[j].group_id = j + req->concurreny_level * (i % 2);
                new_sched[i].prepost_buf_op[j].offset   = offset * comm->max_per_packet;
                new_sched[i].prepost_buf_op[j].root     = root + j;
                new_sched[i].prepost_buf_op[j].count    = req->mcast_prepost_bucket_size;
            }
        } else {
            new_sched[i].prepost_buf_op_done = 1;
        }

        /* every step but the first multicasts into the buckets that were
         * preposted one step earlier */
        if (i > 0) {
            for (j = 0; j < req->concurreny_level; j++) {
                new_sched[i].multicast_op[j].group_id     = new_sched[i - 1].prepost_buf_op[j].group_id;
                new_sched[i].multicast_op[j].offset       = new_sched[i - 1].prepost_buf_op[j].offset;
                new_sched[i].multicast_op[j].offset_left  = new_sched[i - 1].prepost_buf_op[j].offset;
                new_sched[i].multicast_op[j].root         = new_sched[i - 1].prepost_buf_op[j].root;
                new_sched[i].multicast_op[j].to_send_left = new_sched[i - 1].prepost_buf_op[j].count;
                new_sched[i].multicast_op[j].to_recv      = new_sched[i - 1].prepost_buf_op[j].count;
                new_sched[i].to_recv += new_sched[i].multicast_op[j].to_recv;
                /* only the root of a bucket actually sends it */
                if (new_sched[i].multicast_op[j].root == comm->rank) {
                    new_sched[i].to_send += new_sched[i].multicast_op[j].to_send_left;
                }
            }
        }

        if (!new_sched[i].to_send || !new_sched[i].to_recv) {
            new_sched[i].multicast_op_done = 1;
        }

        offset += req->mcast_prepost_bucket_size;

        /* finished this group of roots' packets; advance to next group */
        if (offset == req->num_packets) {
            offset = 0;
            root   = (root + req->concurreny_level) % comm->commsize;
        }
    }

    tl_trace(comm->lib,
             "generated the schedule for pipelined zero copy allgather with total_steps %d",
             total_steps);
    new_sched->total_steps = total_steps;

    tl_trace(comm->lib, "registering recv buf of size %ld", req->length * comm->commsize);
    ucc_assert(req->recv_rreg == NULL);

    status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, req->length *
                                            comm->commsize, &reg);
    if (UCC_OK != status) {
        tl_warn(comm->lib, "unable to register receive buffer %p of size %ld",
                req->rptr, req->length * comm->commsize);
        ucc_free(new_sched);
        return status;
    }

    /* publish into the request only after every fallible step succeeded,
     * so a failure above cannot leave req->ag_schedule dangling */
    req->total_steps = total_steps;
    req->ag_schedule = new_sched;
    req->recv_rreg   = reg;
    req->recv_mr     = reg->mr;

    return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
ucc_coll_task_t *coll_task = &(task->super);
Expand Down Expand Up @@ -357,6 +503,13 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
req->to_send = req->num_packets;
req->to_recv = comm->commsize * req->num_packets;

if (comm->allgather_comm.truly_zero_copy_allgather_enabled) {
status = ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(comm, req);
if (UCC_OK != status) {
return status;
}
}

comm->allgather_comm.coll_counter++;

task->coll_mcast.req_handle = req;
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

memcpy(&comm->params, conf_params, sizeof(*conf_params));

comm->allgather_comm.mcast_prepost_bucket_size
= conf_params->mcast_prepost_bucket_size;
comm->allgather_comm.truly_zero_copy_allgather_enabled
= conf_params->truly_zero_copy_allgather_enabled;
comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable;
comm->bcast_comm.wsize = conf_params->wsize;
comm->allgather_comm.max_push_send = conf_params->max_push_send;
Expand Down
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_BOOL},

{"MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", "Enable truly zero copy allgather design for mcast",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.truly_zero_copy_allgather_enabled),
UCC_CONFIG_TYPE_BOOL},

{"MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "16", "Number of posted recvs during each stage of the pipeline"
" in truly zero copy mcast allgather design",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.mcast_prepost_bucket_size),
UCC_CONFIG_TYPE_INT},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand Down
Loading