Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: use knomial pattern in gather #1044

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions src/coll_patterns/recursive_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum {
KN_PATTERN_ALLGATHER,
KN_PATTERN_ALLGATHERV,
KN_PATTERN_ALLGATHERX,
KN_PATTERN_GATHER,
KN_PATTERN_GATHERX,
};

typedef struct ucc_knomial_pattern {
Expand Down Expand Up @@ -83,7 +85,7 @@ static inline ucc_rank_t ucc_kn_pattern_radix_pow_init(ucc_knomial_pattern_t *p,
static inline void
ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix, ucc_knomial_pattern_t *p,
int backward)
int backward, int has_extra)
{
ucc_rank_t fs = radix;
ucc_rank_t n_full_subtrees;
Expand All @@ -100,7 +102,7 @@ ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
p->backward = backward;
p->iteration = 0;
n_full_subtrees = ucc_kn_pattern_n_full(p);
p->n_extra = size - n_full_subtrees * p->full_pow_size;
p->n_extra = has_extra ? size - n_full_subtrees * p->full_pow_size : 0;
p->n_iters = (p->n_extra && n_full_subtrees == 1) ?
p->pow_radix_sup - 1 : p->pow_radix_sup;
p->radix_pow = ucc_kn_pattern_radix_pow_init(p, backward);
Expand All @@ -115,14 +117,22 @@ ucc_knomial_pattern_init_backward(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 1);
ucc_knomial_pattern_init_impl(size, rank, radix, p, 1, 1);
}

static inline void
ucc_knomial_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0);
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 1);
}

static inline void
ucc_knomial_pattern_init_no_extra(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 0);
}

static inline ucc_rank_t
Expand Down Expand Up @@ -186,6 +196,23 @@ ucc_knomial_pattern_get_loop_peer(ucc_knomial_pattern_t *p, ucc_rank_t rank,
ucc_knomial_pattern_loop_rank_inv(p, peer);
}

static inline ucc_rank_t
ucc_knomial_pattern_get_base_rank(ucc_knomial_pattern_t *p, ucc_rank_t rank)
{
ucc_rank_t step_size = p->radix_pow * p->radix;
ucc_rank_t lrank;
ucc_kn_radix_t s;

lrank = ucc_knomial_pattern_loop_rank(p, rank);
s = ucc_div_round_up(step_size - (lrank % step_size), p->radix_pow);

if (s == p->radix) {
return rank;
} else {
return ucc_knomial_pattern_get_loop_peer(p, rank, s);
}
}

static inline void
ucc_knomial_pattern_next_iteration(ucc_knomial_pattern_t *p)
{
Expand Down Expand Up @@ -224,11 +251,13 @@ static inline ucc_rank_t
ucc_knomial_calc_recv_dist(ucc_rank_t team_size, ucc_rank_t rank,
ucc_rank_t radix, ucc_rank_t root)
{
ucc_rank_t root_base = 0;
ucc_rank_t dist = 1;

if (rank == root) {
return 0;
}
ucc_rank_t root_base = 0 ;
ucc_rank_t dist = 1;

while (dist <= team_size) {
if (rank < root_base + radix * dist) {
break;
Expand Down
76 changes: 76 additions & 0 deletions src/coll_patterns/sra_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,82 @@ ucc_knx_block(ucc_rank_t rank, ucc_rank_t size, ucc_kn_radix_t radix,
*b_offset = offset;
}

static inline void
ucc_kn_g_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_no_extra(size, rank, radix, p);
p->type = KN_PATTERN_GATHER;
p->count = count;
p->block_size = p->radix_pow * radix;
p->block_offset = ucc_knomial_pattern_loop_rank(p, rank) / p->block_size *
p->block_size;
}

static inline void
ucc_kn_gx_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_backward(size, rank, radix, p);
p->type = KN_PATTERN_GATHERX;
p->count = count;
if (p->node_type != KN_NODE_EXTRA) {
p->block_size = ucc_kn_compute_step_radix(p);
ucc_knx_block(rank, size, radix, count, p->n_iters - 1,
&p->block_size_counts, &p->block_offset);

}

}

static inline void
ucc_kn_g_pattern_peer_seg(ucc_rank_t peer, ucc_knomial_pattern_t *p,
size_t *seg_count, ptrdiff_t *seg_offset)
{
ucc_rank_t step_radix, seg_index;

*seg_count = 0;
*seg_offset = 0;
switch (p->type) {
case KN_PATTERN_GATHER:
*seg_count = ucc_min(p->radix_pow, p->size - peer) * (p->count / p->size);
*seg_offset = peer * (p->count / p->size);
return;
case KN_PATTERN_GATHERX:
step_radix = ucc_kn_compute_step_radix(p);
seg_index = ucc_kn_compute_seg_index(peer, p->radix_pow, p);
*seg_offset = ucc_buffer_block_offset(p->block_size_counts, step_radix,
seg_index) + p->block_offset;
*seg_count = ucc_buffer_block_count(p->block_size_counts, step_radix,
seg_index);
return;
default:
ucc_assert(0);
}
}

static inline void ucc_kn_g_pattern_next_iter(ucc_knomial_pattern_t *p)
{
ucc_rank_t rank;
if (p->type == KN_PATTERN_GATHERX) {
ucc_knomial_pattern_next_iteration_backward(p);

if (!ucc_knomial_pattern_loop_done(p)) {
ucc_knx_block(p->rank, p->size, p->radix, p->count,
p->n_iters - 1 - p->iteration,
&p->block_size_counts, &p->block_offset);
}
} else {
rank = ucc_knomial_pattern_loop_rank(p, p->rank);
ucc_knomial_pattern_next_iteration(p);

if (!ucc_knomial_pattern_loop_done(p)) {
p->block_size *= ucc_kn_compute_step_radix(p);
p->block_offset = rank / p->block_size * p->block_size;
}
}
}

static inline void
ucc_kn_ag_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
Expand Down
61 changes: 6 additions & 55 deletions src/components/tl/ucp/gather/gather.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -17,62 +17,13 @@ ucc_base_coll_alg_info_t
[UCC_TL_UCP_GATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

static inline uint32_t calc_buffer_size(ucc_rank_t rank, uint32_t radix, ucc_rank_t team_size)
{
uint32_t radix_valuation;

if (rank == 0) {
return team_size;
}
radix_valuation = calc_valuation(rank, radix);
return (uint32_t)ucc_min(pow(radix, radix_valuation), team_size - rank);
}

ucc_status_t ucc_tl_ucp_gather_init(ucc_tl_ucp_task_t *task)
{
ucc_coll_args_t * args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t myrank = UCC_TL_TEAM_RANK(team);
ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t root = args->root;
ucc_rank_t vrank = (myrank - root + team_size) % team_size;
ucc_status_t status = UCC_OK;
ucc_memory_type_t mtype;
ucc_datatype_t dt;
size_t count, data_size;
uint32_t buffer_size;
int isleaf;

if (root == myrank) {
count = args->dst.info.count;
dt = args->dst.info.datatype;
mtype = args->dst.info.mem_type;
} else {
count = args->src.info.count;
dt = args->src.info.datatype;
mtype = args->src.info.mem_type;
}
data_size = count * ucc_dt_size(dt);
task->super.post = ucc_tl_ucp_gather_knomial_start;
task->super.progress = ucc_tl_ucp_gather_knomial_progress;
task->super.finalize = ucc_tl_ucp_gather_knomial_finalize;
task->gather_kn.radix =
ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, team_size);
CALC_KN_TREE_DIST(team_size, task->gather_kn.radix,
task->gather_kn.max_dist);
isleaf = (vrank % task->gather_kn.radix != 0 || vrank == team_size - 1);
task->gather_kn.scratch_mc_header = NULL;
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_kn_radix_t radix;

if (vrank == 0) {
task->gather_kn.scratch = args->dst.info.buffer;
} else if (isleaf) {
task->gather_kn.scratch = args->src.info.buffer;
} else {
buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, team_size);
status = ucc_mc_alloc(&task->gather_kn.scratch_mc_header,
buffer_size * data_size, mtype);
task->gather_kn.scratch = task->gather_kn.scratch_mc_header->addr;
}
radix = ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, size);

return status;
return ucc_tl_ucp_gather_knomial_init_common(task, radix);
}
10 changes: 9 additions & 1 deletion src/components/tl/ucp/gather/gather.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -45,4 +45,12 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_gather_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task,
ucc_kn_radix_t radix);

/* Internal interface with custom radix */
ucc_status_t ucc_tl_ucp_gather_knomial_init_r(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h,
ucc_kn_radix_t radix);
#endif
Loading
Loading