Skip to content

Commit

Permalink
TL/UCP: reduce srg
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Oct 30, 2024
1 parent 070eb64 commit d7d5bb9
Show file tree
Hide file tree
Showing 16 changed files with 808 additions and 195 deletions.
40 changes: 34 additions & 6 deletions src/coll_patterns/recursive_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum {
KN_PATTERN_ALLGATHER,
KN_PATTERN_ALLGATHERV,
KN_PATTERN_ALLGATHERX,
KN_PATTERN_GATHER,
KN_PATTERN_GATHERX,
};

typedef struct ucc_knomial_pattern {
Expand Down Expand Up @@ -83,7 +85,7 @@ static inline ucc_rank_t ucc_kn_pattern_radix_pow_init(ucc_knomial_pattern_t *p,
static inline void
ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix, ucc_knomial_pattern_t *p,
int backward)
int backward, int extra)
{
ucc_rank_t fs = radix;
ucc_rank_t n_full_subtrees;
Expand All @@ -100,7 +102,7 @@ ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
p->backward = backward;
p->iteration = 0;
n_full_subtrees = ucc_kn_pattern_n_full(p);
p->n_extra = size - n_full_subtrees * p->full_pow_size;
p->n_extra = extra ? size - n_full_subtrees * p->full_pow_size : 0;
p->n_iters = (p->n_extra && n_full_subtrees == 1) ?
p->pow_radix_sup - 1 : p->pow_radix_sup;
p->radix_pow = ucc_kn_pattern_radix_pow_init(p, backward);
Expand All @@ -115,14 +117,21 @@ ucc_knomial_pattern_init_backward(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 1);
ucc_knomial_pattern_init_impl(size, rank, radix, p, 1, 1);
}

static inline void
ucc_knomial_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0);
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 1);
}

static inline void
ucc_knomial_pattern_init_no_extra(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 0);
}

static inline ucc_rank_t
Expand Down Expand Up @@ -186,6 +195,23 @@ ucc_knomial_pattern_get_loop_peer(ucc_knomial_pattern_t *p, ucc_rank_t rank,
ucc_knomial_pattern_loop_rank_inv(p, peer);
}

static inline ucc_rank_t
ucc_knomial_pattern_get_base_rank(ucc_knomial_pattern_t *p, ucc_rank_t rank)
{
ucc_rank_t step_size = p->radix_pow * p->radix;
ucc_rank_t lrank;
ucc_kn_radix_t s;

lrank = ucc_knomial_pattern_loop_rank(p, rank);
s = ucc_div_round_up(step_size - (lrank % step_size), p->radix_pow);

if (s == p->radix) {
return rank;
} else {
return ucc_knomial_pattern_get_loop_peer(p, rank, s);
}
}

static inline void
ucc_knomial_pattern_next_iteration(ucc_knomial_pattern_t *p)
{
Expand Down Expand Up @@ -224,11 +250,13 @@ static inline ucc_rank_t
ucc_knomial_calc_recv_dist(ucc_rank_t team_size, ucc_rank_t rank,
ucc_rank_t radix, ucc_rank_t root)
{
ucc_rank_t root_base = 0 ;
ucc_rank_t dist = 1;

if (rank == root) {
return 0;
}
ucc_rank_t root_base = 0 ;
ucc_rank_t dist = 1;

while (dist <= team_size) {
if (rank < root_base + radix * dist) {
break;
Expand Down
76 changes: 76 additions & 0 deletions src/coll_patterns/sra_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,60 @@ ucc_kn_ag_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
p->block_size;
}

static inline void
ucc_kn_g_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_no_extra(size, rank, radix, p);
p->type = KN_PATTERN_GATHER;
p->count = count;
p->block_size = p->radix_pow * radix;
p->block_offset = ucc_knomial_pattern_loop_rank(p, rank) / p->block_size *
p->block_size;
}

static inline void
ucc_kn_gx_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_backward(size, rank, radix, p);
p->type = KN_PATTERN_GATHERX;
p->count = count;
if (p->node_type != KN_NODE_EXTRA) {
p->block_size = ucc_kn_compute_step_radix(p);
ucc_knx_block(rank, size, radix, count, p->n_iters - 1,
&p->block_size_counts, &p->block_offset);

}

}

static inline void
ucc_kn_g_pattern_peer_seg(ucc_rank_t peer, ucc_knomial_pattern_t *p,
size_t *seg_count, ptrdiff_t *seg_offset)
{
ucc_rank_t step_radix, seg_index;

*seg_count = 0;
*seg_offset = 0;
switch (p->type) {
case KN_PATTERN_GATHER:
*seg_count = ucc_min(p->radix_pow, p->size - peer) * (p->count / p->size);
*seg_offset = peer * (p->count / p->size);
return;
case KN_PATTERN_GATHERX:
step_radix = ucc_kn_compute_step_radix(p);
seg_index = ucc_kn_compute_seg_index(peer, p->radix_pow, p);
*seg_offset = ucc_buffer_block_offset(p->block_size_counts, step_radix,
seg_index) + p->block_offset;
*seg_count = ucc_buffer_block_count(p->block_size_counts, step_radix,
seg_index);
return;
default:
ucc_assert(0);
}
}

static inline void
ucc_kn_agx_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
Expand Down Expand Up @@ -283,6 +337,28 @@ static inline void ucc_kn_ag_pattern_next_iter(ucc_knomial_pattern_t *p)
}
}

static inline void ucc_kn_g_pattern_next_iter(ucc_knomial_pattern_t *p)
{
ucc_rank_t rank;
if (p->type == KN_PATTERN_GATHERX) {
ucc_knomial_pattern_next_iteration_backward(p);

if (!ucc_knomial_pattern_loop_done(p)) {
ucc_knx_block(p->rank, p->size, p->radix, p->count,
p->n_iters - 1 - p->iteration,
&p->block_size_counts, &p->block_offset);
}
} else {
rank = ucc_knomial_pattern_loop_rank(p, p->rank);
ucc_knomial_pattern_next_iteration(p);

if (!ucc_knomial_pattern_loop_done(p)) {
p->block_size *= ucc_kn_compute_step_radix(p);
p->block_offset = rank / p->block_size * p->block_size;
}
}
}

static inline void ucc_kn_rs_pattern_init(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix, size_t count,
ucc_knomial_pattern_t *p)
Expand Down
1 change: 1 addition & 0 deletions src/components/cl/hier/bcast/bcast_2step.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ find_root_node_rank(ucc_rank_t root, ucc_cl_hier_team_t *cl_team)
return UCC_RANK_INVALID;
}


static ucc_status_t
ucc_cl_hier_bcast_2step_init_schedule(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
11 changes: 6 additions & 5 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ gatherv = \
gatherv/gatherv.c \
gatherv/gatherv_linear.c

reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_knomial.c \
reduce/reduce_dbt.c
reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_knomial.c \
reduce/reduce_dbt.c \
reduce/reduce_srg_knomial.c

reduce_scatter = \
reduce_scatter/reduce_scatter.h \
Expand Down
8 changes: 4 additions & 4 deletions src/components/tl/ucp/allreduce/allreduce_sliding_window.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ ucc_tl_ucp_allreduce_sliding_window_rdma_task_post(
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
Expand Down Expand Up @@ -325,9 +325,9 @@ void ucc_tl_ucp_allreduce_sliding_window_rdma_progress(ucc_coll_task_t *coll_tas
ucc_tl_ucp_allreduce_sw_buf_t *accbuf = &pipe->accbuf;
ucp_request_param_t req_param = {0};
int i = 0;
ucc_coll_task_t *allgather_task =
ucc_coll_task_t *allgather_task =
task->allreduce_sliding_window.allgather_task;
ucc_ee_executor_task_t **reduce_task =
ucc_ee_executor_task_t **reduce_task =
&task->allreduce_sliding_window.reduce_task;
ucc_rank_t put_window_size =
UCC_TL_UCP_TEAM_LIB(tl_team)->
Expand Down Expand Up @@ -490,7 +490,7 @@ void ucc_tl_ucp_allreduce_sliding_window_rdma_progress(ucc_coll_task_t *coll_tas

ucp_worker_fence(tl_ctx->worker.ucp_worker);
ucc_tl_ucp_get_ep(tl_team, dst_rank, &ep);
task->allreduce_sliding_window.put_requests[put_idx] =
task->allreduce_sliding_window.put_requests[put_idx] =
ucp_put_nbx(
ep, src_addr,
data_size, (uint64_t)dst_addr,
Expand Down
59 changes: 5 additions & 54 deletions src/components/tl/ucp/gather/gather.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,13 @@ ucc_base_coll_alg_info_t
[UCC_TL_UCP_GATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

static inline uint32_t calc_buffer_size(ucc_rank_t rank, uint32_t radix, ucc_rank_t team_size)
{
uint32_t radix_valuation;

if (rank == 0) {
return team_size;
}
radix_valuation = calc_valuation(rank, radix);
return (uint32_t)ucc_min(pow(radix, radix_valuation), team_size - rank);
}

ucc_status_t ucc_tl_ucp_gather_init(ucc_tl_ucp_task_t *task)
{
ucc_coll_args_t * args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t myrank = UCC_TL_TEAM_RANK(team);
ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t root = args->root;
ucc_rank_t vrank = (myrank - root + team_size) % team_size;
ucc_status_t status = UCC_OK;
ucc_memory_type_t mtype;
ucc_datatype_t dt;
size_t count, data_size;
uint32_t buffer_size;
int isleaf;

if (root == myrank) {
count = args->dst.info.count;
dt = args->dst.info.datatype;
mtype = args->dst.info.mem_type;
} else {
count = args->src.info.count;
dt = args->src.info.datatype;
mtype = args->src.info.mem_type;
}
data_size = count * ucc_dt_size(dt);
task->super.post = ucc_tl_ucp_gather_knomial_start;
task->super.progress = ucc_tl_ucp_gather_knomial_progress;
task->super.finalize = ucc_tl_ucp_gather_knomial_finalize;
task->gather_kn.radix =
ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, team_size);
CALC_KN_TREE_DIST(team_size, task->gather_kn.radix,
task->gather_kn.max_dist);
isleaf = (vrank % task->gather_kn.radix != 0 || vrank == team_size - 1);
task->gather_kn.scratch_mc_header = NULL;
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_kn_radix_t radix;

if (vrank == 0) {
task->gather_kn.scratch = args->dst.info.buffer;
} else if (isleaf) {
task->gather_kn.scratch = args->src.info.buffer;
} else {
buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, team_size);
status = ucc_mc_alloc(&task->gather_kn.scratch_mc_header,
buffer_size * data_size, mtype);
task->gather_kn.scratch = task->gather_kn.scratch_mc_header->addr;
}
radix = ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, size);

return status;
return ucc_tl_ucp_gather_knomial_init_common(task, radix);
}
8 changes: 8 additions & 0 deletions src/components/tl/ucp/gather/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,12 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_gather_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task,
ucc_kn_radix_t radix);

/* Internal interface with custom radix */
ucc_status_t ucc_tl_ucp_gather_knomial_init_r(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h,
ucc_kn_radix_t radix);
#endif
Loading

0 comments on commit d7d5bb9

Please sign in to comment.