Skip to content

Commit

Permalink
TL/CUDA: double buffering
Browse files Browse the repository at this point in the history
  • Loading branch information
ikryukov committed Aug 2, 2024
1 parent 631ebdf commit e43c6b5
Showing 1 changed file with 24 additions and 44 deletions.
68 changes: 24 additions & 44 deletions src/components/tl/cuda/bcast/bcast_linear.c
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
ucc_tl_cuda_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
// ucc_datatype_t dt = task->bcast_linear.dt;
size_t half_scratch_size = get_raw_scratch_size(team) / 2;
size_t chunk_size =
task->bcast_linear.step < task->bcast_linear.num_steps
? ucc_min(half_scratch_size, task->bcast_linear.size)
: task->bcast_linear.size -
(task->bcast_linear.step - 1) * half_scratch_size;
size_t offset_buff = task->bcast_linear.step * half_scratch_size;

ucc_ee_executor_t * exec;
ucc_ee_executor_task_t *etask;
Expand All @@ -111,21 +117,18 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)

switch (task->bcast_linear.stage) {
case STAGE_SYNC:
// ucc_info("sync");
if (ucc_tl_cuda_get_sync(task) != UCC_OK) {
task->super.status = UCC_INPROGRESS;
return;
}
task->bcast_linear.step = 0;
// ucc_info("setup");
st = ucc_tl_cuda_bcast_linear_setup_start(task);
st = ucc_tl_cuda_bcast_linear_setup_start(task);
if (st != UCC_OK) {
task->super.status = st;
return;
}
task->bcast_linear.stage = STAGE_SETUP;
case STAGE_SETUP:
// ucc_info("test");
st = ucc_tl_cuda_bcast_linear_setup_test(task);
if (st != UCC_OK) {
task->super.status = st;
Expand All @@ -141,22 +144,14 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
break;
}

size_t scratch_size = get_raw_scratch_size(team);
size_t chunk_size = task->bcast_linear.step < task->bcast_linear.num_steps
? ucc_min(scratch_size, task->bcast_linear.size)
: task->bcast_linear.size -
(task->bcast_linear.step - 1) * scratch_size;
size_t offset_buff = task->bcast_linear.step * scratch_size;

// ucc_info("chunk_size: %ld", chunk_size);

if (trank == task->bcast_linear.root) {
// Root scenario
// fall-through between cases is intentional
switch (task->bcast_linear.stage) {
case STAGE_COPY:
// copy from src buffer to scratch
dbuf = TASK_SCRATCH(task, trank);
dbuf = PTR_OFFSET(TASK_SCRATCH(task, trank),
task->bcast_linear.step % 2 * half_scratch_size);
sbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff);
st = ecopy(dbuf, sbuf, chunk_size, exec,
&task->bcast_linear.exec_task);
Expand All @@ -166,7 +161,6 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
return;
}
task->bcast_linear.stage = STAGE_WAIT_COPY;
// break;
case STAGE_WAIT_COPY:
etask = task->bcast_linear.exec_task;
if (etask) {
Expand All @@ -180,23 +174,22 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
task->bcast_linear.step, 0);
task->bcast_linear.stage = STAGE_WAIT_ALL;
} else {
// ucc_info("not ready");
// not ready
return;
}
} else {
ucc_info("etask is nullptr");
return;
}
// break;
case STAGE_WAIT_ALL:
for (int i = 0; i < tsize; ++i) {
if (get_rank_step(task, i, 0) < task->bcast_linear.step) {
// need to wait until all ranks complete step - 1, because of double buffering
if (get_rank_step(task, i, 0) < task->bcast_linear.step - 1) {
// rank is not ready, lets wait
return;
}
}
task->bcast_linear.stage = STAGE_COPY;
// ucc_info("all others ready for next step");
if (task->bcast_linear.step < task->bcast_linear.num_steps) {
// go to next iteration
task->bcast_linear.stage = STAGE_COPY;
Expand All @@ -205,15 +198,15 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
// finish
task->bcast_linear.stage = STAGE_DONE;
}
// break;
case STAGE_DONE:
task->super.status = UCC_OK;
break;
default:
break;
}
} else {
// others
// clients
// fall-through between cases is intentional
switch (task->bcast_linear.stage) {
case STAGE_WAIT_ROOT:
/* code */
Expand All @@ -225,22 +218,19 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
} else {
return;
}
// break;
case STAGE_CLIENT_COPY:
dbuf = PTR_OFFSET(task->bcast_linear.sbuf, offset_buff);
sbuf = TASK_SCRATCH(
task,
task->bcast_linear
.root); // need to copy from root's scratch buffer
st = ecopy(dbuf, sbuf, chunk_size, exec,
// need to copy from root's scratch buffer
sbuf = PTR_OFFSET(TASK_SCRATCH(task, task->bcast_linear.root),
task->bcast_linear.step % 2 * chunk_size);
st = ecopy(dbuf, sbuf, chunk_size, exec,
&task->bcast_linear.exec_task);
if (st != UCC_OK) {
ucc_error("failed to post ecopy task at client");
task->super.status = st;
return;
}
task->bcast_linear.stage = STAGE_CLIENT_COPY_WAIT;
// break;
case STAGE_CLIENT_COPY_WAIT:
etask = task->bcast_linear.exec_task;
if (etask) {
Expand All @@ -263,7 +253,6 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
} else {
return;
}
// break;
case STAGE_DONE:
task->super.status = UCC_OK;
break;
Expand All @@ -278,20 +267,17 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task)
ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t);
ucc_tl_cuda_team_t *team = TASK_TEAM(task);
ucc_coll_args_t * args = &TASK_ARGS(task);
// ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_datatype_t dt = task->bcast_linear.dt;
ucc_datatype_t dt = task->bcast_linear.dt;
size_t half_scratch_size = get_raw_scratch_size(team) / 2;

task->bcast_linear.stage = STAGE_SYNC;

ucc_info("bcast start with dt: %s and count: %ld", ucc_datatype_str(dt),
args->src.info.count);

task->bcast_linear.size = ucc_dt_size(dt) * args->src.info.count;
size_t scratch_size = get_raw_scratch_size(team);
task->bcast_linear.num_steps =
ucc_div_round_up(task->bcast_linear.size, scratch_size);
ucc_div_round_up(task->bcast_linear.size, half_scratch_size);

ucc_info("bcast buffer size: %ld, num_steps: %d", task->bcast_linear.size,
ucc_info("bcast dt: %s, buffer size: %ld, num_steps: %d",
ucc_datatype_str(dt), task->bcast_linear.size,
task->bcast_linear.num_steps);

task->bcast_linear.sbuf = args->src.info.buffer;
Expand All @@ -308,8 +294,6 @@ ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args,
ucc_tl_cuda_task_t *task;
ucc_status_t status;

ucc_info("bcast init");

if (ucc_unlikely(!ucc_tl_cuda_team_topo_is_fully_conntected(team->topo) ||
UCC_TL_TEAM_SIZE(team) - 1 >
UCC_EE_EXECUTOR_MULTI_OP_NUM_BUFS)) {
Expand All @@ -323,8 +307,6 @@ ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args,

task->bcast_linear.root = coll_args->args.root;
task->bcast_linear.dt = coll_args->args.src.info.datatype;
ucc_info("bcast init with dt: %s", ucc_datatype_str(task->bcast_linear.dt));

task->bcast_linear.sbuf = coll_args->args.src.info.buffer;

task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
Expand All @@ -333,8 +315,6 @@ ucc_status_t ucc_tl_cuda_bcast_linear_init(ucc_base_coll_args_t *coll_args,
task->super.finalize = ucc_tl_cuda_bcast_linear_finalize;
task->bar = TASK_BAR(task);

ucc_info("bcast init success");

*task_p = &task->super;
return UCC_OK;
}

0 comments on commit e43c6b5

Please sign in to comment.