diff --git a/src/coll_patterns/recursive_knomial.h b/src/coll_patterns/recursive_knomial.h index e234b0cafb..5935fbddb6 100644 --- a/src/coll_patterns/recursive_knomial.h +++ b/src/coll_patterns/recursive_knomial.h @@ -85,7 +85,7 @@ static inline ucc_rank_t ucc_kn_pattern_radix_pow_init(ucc_knomial_pattern_t *p, static inline void ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix, ucc_knomial_pattern_t *p, - int backward, int extra) + int backward, int has_extra) { ucc_rank_t fs = radix; ucc_rank_t n_full_subtrees; @@ -102,7 +102,7 @@ ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank, p->backward = backward; p->iteration = 0; n_full_subtrees = ucc_kn_pattern_n_full(p); - p->n_extra = extra ? size - n_full_subtrees * p->full_pow_size : 0; + p->n_extra = has_extra ? size - n_full_subtrees * p->full_pow_size : 0; p->n_iters = (p->n_extra && n_full_subtrees == 1) ? p->pow_radix_sup - 1 : p->pow_radix_sup; p->radix_pow = ucc_kn_pattern_radix_pow_init(p, backward); diff --git a/src/components/tl/ucp/gather/gather_knomial.c b/src/components/tl/ucp/gather/gather_knomial.c index 7871a9209e..7052540306 100644 --- a/src/components/tl/ucp/gather/gather_knomial.c +++ b/src/components/tl/ucp/gather/gather_knomial.c @@ -16,19 +16,20 @@ task->gather_kn.phase = _phase; \ } while (0) -static inline uint32_t calc_buffer_size(ucc_rank_t trank, uint32_t radix, +static inline uint32_t calc_buffer_size(ucc_rank_t vrank, uint32_t radix, ucc_rank_t tsize) { uint32_t radix_valuation; - if (trank == 0) { + if (vrank == 0) { return tsize; } - radix_valuation = calc_valuation(trank, radix); - return (uint32_t)ucc_min(pow(radix, radix_valuation), tsize - trank); + radix_valuation = calc_valuation(vrank, radix); + return (uint32_t)ucc_min(pow(radix, radix_valuation), tsize - vrank); } +/* gather knomial is used as regular gather collective and as part of reduce SRG */ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task) { 
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); @@ -86,8 +87,8 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task) task->gather_kn.dist); } UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(scratch_offset, - msg_size, mtype, peer, - team, task), + msg_size, mtype, peer, + team, task), task, out); } else { /* @@ -176,10 +177,11 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *coll_task) root_at_level, team, task), task, out); } else { + /* root and tree topology require splitting the transfer into two messages: (tsize - rank) blocks, then the remaining num_blocks - (tsize - rank) */ msg_size = data_size * (tsize - rank); UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->gather_kn.scratch, - msg_size, mtype, - root_at_level, team, task), + msg_size, mtype, + root_at_level, team, task), task, out); msg_size = data_size * (num_blocks - (tsize - rank)); UCPCHECK_GOTO( @@ -226,6 +228,8 @@ ucc_status_t ucc_tl_ucp_gather_knomial_start(ucc_coll_task_t *coll_task) task->gather_kn.radix, args->src.info.count * size, &task->gather_kn.p); } else { + /* reduce SRG */ + ucc_assert(args->coll_type == UCC_COLL_TYPE_REDUCE); task->gather_kn.scratch = args->dst.info.buffer; ucc_kn_gx_pattern_init(size, VRANK(trank, root, size), task->gather_kn.radix, args->dst.info.count, @@ -265,7 +269,7 @@ ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task, ucc_datatype_t dt; size_t count, data_size; uint32_t buffer_size; - int isleaf; + int is_leaf; if (UCC_IS_ROOT(*args, trank)) { count = args->dst.info.count; @@ -288,10 +292,11 @@ ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task, if (args->coll_type == UCC_COLL_TYPE_REDUCE) { task->gather_kn.scratch = args->dst.info.buffer; } else { - isleaf = ((vrank % radix != 0) || (vrank == tsize - 1)); + ucc_assert(args->coll_type == UCC_COLL_TYPE_GATHER); + is_leaf = ((vrank % radix != 0) || (vrank == tsize - 1)); if (vrank == 0) { task->gather_kn.scratch = args->dst.info.buffer; - } else if (isleaf) { + } else if (is_leaf) { task->gather_kn.scratch =
args->src.info.buffer; } else { buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, tsize);