Skip to content

Commit

Permalink
TL/CUDA: minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ikryukov authored and Ilya Kryukov committed Jul 4, 2024
1 parent ed59138 commit 6caea67
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions src/components/tl/cuda/bcast/bcast_linear.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,8 @@ ucc_status_t ucc_tl_cuda_bcast_linear_setup_test(ucc_tl_cuda_task_t *task)
return ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar);
}

static inline size_t get_scratch_size(ucc_tl_cuda_team_t *team,
ucc_datatype_t dt)
static inline size_t get_raw_scratch_size(ucc_tl_cuda_team_t *team)
{
size_t dt_size = ucc_dt_size(dt);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);

ucc_assert((dt_size > 0) && (tsize > 0));

return UCC_TL_CUDA_TEAM_LIB(team)->cfg.scratch_size;
}

Expand Down Expand Up @@ -101,8 +95,8 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
ucc_tl_cuda_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_datatype_t dt = task->bcast_linear.dt;
ucc_status_t st;
// ucc_datatype_t dt = task->bcast_linear.dt;
ucc_status_t st;
(void)team;
(void)st;
ucc_ee_executor_t *exec;
Expand Down Expand Up @@ -148,7 +142,7 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
break;
}

size_t scratch_size = get_scratch_size(team, dt);
size_t scratch_size = get_raw_scratch_size(team);
size_t chunk_size = task->bcast_linear.step < task->bcast_linear.num_steps
? ucc_min(scratch_size, task->bcast_linear.size)
: task->bcast_linear.size -
Expand Down Expand Up @@ -239,8 +233,6 @@ void ucc_tl_cuda_bcast_linear_progress(ucc_coll_task_t *coll_task)
task->bcast_linear.exec_task = NULL;
++task->bcast_linear.step;
set_rank_step(task, trank, task->bcast_linear.step, 0);
// task->bcast_linear.stage =
// STAGE_DONE; // TODO: just for debug
if (task->bcast_linear.step <
task->bcast_linear.num_steps) {
task->bcast_linear.stage = STAGE_WAIT_ROOT;
Expand Down Expand Up @@ -276,7 +268,7 @@ ucc_status_t ucc_tl_cuda_bcast_linear_start(ucc_coll_task_t *coll_task)
args->src.info.count);

task->bcast_linear.size = ucc_dt_size(dt) * args->src.info.count;
size_t scratch_size = get_scratch_size(team, dt);
size_t scratch_size = get_raw_scratch_size(team);
task->bcast_linear.num_steps =
ucc_div_round_up(task->bcast_linear.size, scratch_size);

Expand Down

0 comments on commit 6caea67

Please sign in to comment.