Skip to content

Commit

Permalink
CL/HIER: Add allgatherv
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarka committed Dec 13, 2024
1 parent 73651ea commit 9ee9a17
Show file tree
Hide file tree
Showing 15 changed files with 661 additions and 44 deletions.
59 changes: 33 additions & 26 deletions src/components/cl/hier/Makefile.am
Original file line number Diff line number Diff line change
@@ -1,48 +1,55 @@
#
# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#

allgatherv = \
allgatherv/unpack.h \
allgatherv/unpack.c \
allgatherv/allgatherv.h \
allgatherv/allgatherv.c

allreduce = \
allreduce/allreduce.h \
allreduce/allreduce.c \
allreduce/allreduce_rab.c \
allreduce/allreduce_split_rail.c

alltoallv = \
alltoallv/alltoallv.h \
alltoallv = \
alltoallv/alltoallv.h \
alltoallv/alltoallv.c

alltoall = \
alltoall/alltoall.h \
alltoall = \
alltoall/alltoall.h \
alltoall/alltoall.c

barrier = \
barrier/barrier.h \
barrier = \
barrier/barrier.h \
barrier/barrier.c

bcast = \
bcast/bcast.h \
bcast/bcast.c \
bcast = \
bcast/bcast.h \
bcast/bcast.c \
bcast/bcast_2step.c

reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce = \
reduce/reduce.h \
reduce/reduce.c \
reduce/reduce_2step.c

sources = \
cl_hier.h \
cl_hier.c \
cl_hier_lib.c \
cl_hier_context.c \
cl_hier_team.c \
cl_hier_coll.c \
cl_hier_coll.h \
$(allreduce) \
$(alltoallv) \
$(alltoall) \
$(barrier) \
$(bcast) \
sources = \
cl_hier.h \
cl_hier.c \
cl_hier_lib.c \
cl_hier_context.c \
cl_hier_team.c \
cl_hier_coll.c \
cl_hier_coll.h \
$(allgatherv) \
$(allreduce) \
$(alltoallv) \
$(alltoall) \
$(barrier) \
$(bcast) \
$(reduce)

module_LTLIBRARIES = libucc_cl_hier.la
Expand Down
307 changes: 307 additions & 0 deletions src/components/cl/hier/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
/**
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "allgatherv.h"
#include "unpack.h"
#include "../cl_hier_coll.h"
#include "core/ucc_team.h"

/* Capacity of the tasks[] array used when building an allgatherv schedule.
 * The init routine in this file creates at most four sub-tasks: a node-level
 * gatherv/allgatherv, a node-leaders allgatherv, a node-level bcast and an
 * unpack task. */
#define MAX_ALLGATHERV_TASKS 4

/* Table of CL/HIER allgatherv algorithms, indexed by algorithm id.
 * The slot at UCC_CL_HIER_ALLGATHERV_ALG_LAST is a terminator entry whose
 * name and description are NULL. */
ucc_base_coll_alg_info_t
    ucc_cl_hier_allgatherv_algs[UCC_CL_HIER_ALLGATHERV_ALG_LAST + 1] = {
        [UCC_CL_HIER_ALLGATHERV_ALG_GAB] = {
            .id   = UCC_CL_HIER_ALLGATHERV_ALG_GAB,
            .name = "gab",
            .desc = "gatherv + allgatherv + bcast",
        },
        [UCC_CL_HIER_ALLGATHERV_ALG_LAST] = {
            .id   = 0,
            .name = NULL,
            .desc = NULL,
        },
};

/* Post callback of the allgatherv schedule: emits a profiling event and then
 * starts the schedule, returning whatever ucc_schedule_start() reports. */
static ucc_status_t ucc_cl_hier_allgatherv_start(ucc_coll_task_t *task)
{
    ucc_status_t start_status;

    UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_allgatherv_start", 0);
    start_status = ucc_schedule_start(task);

    return start_status;
}

/* Finalize callback of the allgatherv schedule.
 *
 * Releases the scratch buffer attached to the CL/HIER schedule, emits a
 * profiling event, finalizes the schedule and hands the schedule object back
 * via ucc_cl_hier_put_schedule(). The status of ucc_schedule_finalize() is
 * what gets returned to the caller. */
static ucc_status_t ucc_cl_hier_allgatherv_finalize(ucc_coll_task_t *task)
{
    ucc_cl_hier_schedule_t *hier_sched =
        ucc_derived_of(task, ucc_cl_hier_schedule_t);
    ucc_schedule_t *sched = ucc_derived_of(task, ucc_schedule_t);
    ucc_status_t    rc;

    /* The scratch buffer is no longer needed once the schedule is done. */
    ucc_mc_free(hier_sched->scratch);

    UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_allgatherv_finalize", 0);

    rc = ucc_schedule_finalize(task);
    ucc_cl_hier_put_schedule(sched);

    return rc;
}

/**
 * Return team_rank's node leader in team space.
 *
 * On the first call for a given team this lazily builds two caches stored on
 * the CL/HIER team:
 *   - node_leaders[i]: team rank of the NODE_LEADERS member that is on the
 *                      same node as team rank i;
 *   - leader_list[j]:  team rank that NODE_LEADERS sbgp rank j maps to
 *                      (written when some team rank shares its node).
 * Subsequent calls are a plain array lookup.
 *
 * @param team       base team; must be a ucc_cl_hier_team_t with the
 *                   NODE_LEADERS subgroup enabled
 * @param team_rank  rank in team space, must be < team size
 *
 * @return the leader's team rank, or UCC_RANK_INVALID if the caches could not
 *         be allocated. On allocation failure nothing is left half-built, so
 *         a later call will retry the allocation.
 */
static inline ucc_rank_t find_leader_rank(ucc_base_team_t *team,
                                          ucc_rank_t       team_rank)
{
    ucc_cl_hier_team_t *cl_team       = ucc_derived_of(team, ucc_cl_hier_team_t);
    ucc_team_t         *core_team     = team->params.team;
    ucc_rank_t          team_size     = UCC_CL_TEAM_SIZE(cl_team);
    ucc_rank_t          ldr_sbgp_size = SBGP_SIZE(cl_team, NODE_LEADERS);
    ucc_rank_t          i, j;

    /* ucc_rank_t is unsigned, so only the upper bound needs checking. */
    ucc_assert(team_rank < team_size);
    ucc_assert(SBGP_ENABLED(cl_team, NODE_LEADERS));

    /* Allocate and populate node_leaders and leader_list */
    if (ucc_unlikely(cl_team->node_leaders == NULL)) {
        cl_team->node_leaders = ucc_malloc(sizeof(ucc_rank_t) * team_size);
        if (!cl_team->node_leaders) {
            cl_error(team->context->lib,
                     "Could not allocate node_leaders array");
            return UCC_RANK_INVALID;
        }

        cl_team->leader_list = ucc_malloc(sizeof(ucc_rank_t) * ldr_sbgp_size);
        if (!cl_team->leader_list) {
            cl_error(team->context->lib,
                     "Could not allocate leader_list array");
            /* node_leaders doubles as the "cache is built" flag: release it
               so the next call does not skip initialization and then read a
               NULL leader_list. */
            ucc_free(cl_team->node_leaders);
            cl_team->node_leaders = NULL;
            return UCC_RANK_INVALID;
        }

        /* For every team rank, find the leader that lives on the same node. */
        for (i = 0; i < team_size; i++) {
            for (j = 0; j < ldr_sbgp_size; j++) {
                ucc_rank_t ldr_team_rank = ucc_ep_map_eval(
                    SBGP_MAP(cl_team, NODE_LEADERS), j);
                if (ucc_team_ranks_on_same_node(i, ldr_team_rank, core_team)) {
                    cl_team->node_leaders[i] = ldr_team_rank;
                    cl_team->leader_list[j]  = ldr_team_rank;
                    break;
                }
            }
        }
    }

    //NOLINTNEXTLINE
    return cl_team->node_leaders[team_rank];
}

/**
 * Build the hierarchical allgatherv schedule ("gab": gatherv + allgatherv +
 * bcast).
 *
 * Up to MAX_ALLGATHERV_TASKS sub-tasks are created and chained so that each
 * one starts when the previous one completes:
 *   1. NODE sbgp (if enabled): a plain allgatherv when NODE is the top sbgp,
 *      otherwise a gatherv to NODE sbgp rank 0 into the scratch buffer.
 *   2. NODE_LEADERS sbgp (if enabled): an allgatherv that takes the
 *      node-gathered data from scratch and writes into the user's dst buffer
 *      using per-leader counts/displacements computed here.
 *   3. NODE sbgp (if enabled and not the top sbgp): a bcast of total_count
 *      elements of the user's dst buffer from root 0.
 *   4. Only inside step 3's branch and only if dst is not contiguous: an
 *      unpack task created by ucc_cl_hier_allgatherv_unpack_init().
 *
 * A single host scratch allocation holds, in order: node_counts, node_disps,
 * leader_counts, leader_disps and the node-gathered data.
 *
 * @param coll_args  collective arguments; copied into local working copies
 * @param team       base team (a ucc_cl_hier_team_t)
 * @param task       [out] set to the schedule's task on success
 *
 * @return UCC_OK on success; UCC_ERR_NO_MEMORY if no schedule is obtained;
 *         otherwise the status of the failing call, after the scratch buffer,
 *         any already-created sub-tasks and the schedule have been released.
 */
UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allgatherv_init,
(coll_args, team, task),
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team,
ucc_cl_hier_team_t);
ucc_coll_task_t *tasks[MAX_ALLGATHERV_TASKS]
= {NULL};
ucc_rank_t rank = UCC_CL_TEAM_RANK(cl_team);
ucc_rank_t node_sbgp_size = SBGP_SIZE(cl_team, NODE);
ucc_rank_t leader_sbgp_size = SBGP_SIZE(cl_team, NODE_LEADERS);
ucc_rank_t team_size = UCC_CL_TEAM_SIZE(cl_team);
ucc_aint_t *node_disps = NULL;
ucc_count_t *node_counts = NULL;
ucc_aint_t *leader_disps = NULL;
ucc_count_t *leader_counts = NULL;
size_t dt_size = ucc_dt_size(coll_args->args.
dst.info_v.datatype);
int in_place = 0;
int is_contig = 1;
ucc_schedule_t *schedule;
ucc_cl_hier_schedule_t *cl_schedule;
ucc_status_t status;
ucc_base_coll_args_t args, args_old;
int n_tasks, i;
size_t scratch_size;
size_t node_counts_size;
size_t node_disps_size;
size_t leader_counts_size;
size_t leader_disps_size;
size_t total_count;
void *node_gathered_data;

/* NOTE(review): schedule is the address of a nested member of the returned
   object, so this NULL check can only fire if that member sits at offset 0 of
   ucc_cl_hier_schedule_t -- confirm against the struct layout. */
schedule = &ucc_cl_hier_get_schedule(cl_team)->super.super;
if (ucc_unlikely(!schedule)) {
return UCC_ERR_NO_MEMORY;
}
cl_schedule = ucc_derived_of(schedule, ucc_cl_hier_schedule_t);

/* args is the working copy that is rewritten for each sub-collective;
   args_old keeps the caller's original arguments so args can be reset. */
memcpy(&args, coll_args, sizeof(args));
memcpy(&args_old, coll_args, sizeof(args));
in_place = UCC_IS_INPLACE(args.args);
is_contig = UCC_COLL_IS_DST_CONTIG(&args.args);
/* n_tasks must be valid before the first goto: the error path iterates over
   tasks[0..n_tasks). */
n_tasks = 0;
UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), free_sched, status);

/* One scratch allocation: four count/displacement arrays followed by room
   for total_count elements of gathered data. */
node_counts_size = node_sbgp_size * sizeof(ucc_count_t);
node_disps_size = node_sbgp_size * sizeof(ucc_aint_t);
leader_counts_size = leader_sbgp_size * sizeof(ucc_count_t);
leader_disps_size = leader_sbgp_size * sizeof(ucc_aint_t);
total_count = ucc_coll_args_get_total_count(&args.args,
args.args.dst.info_v.counts, team_size);
scratch_size = node_counts_size + node_disps_size + leader_counts_size
+ leader_disps_size + (total_count * dt_size);

UCC_CHECK_GOTO(
ucc_mc_alloc(&cl_schedule->scratch, scratch_size, UCC_MEMORY_TYPE_HOST),
free_sched, status);
/* Zeroing matters: leader_counts is accumulated into below and relies on
   starting from 0. */
memset(cl_schedule->scratch->addr, 0, scratch_size);

/* Carve the scratch buffer into its five consecutive regions. */
node_counts = PTR_OFFSET(cl_schedule->scratch->addr, 0);
node_disps = PTR_OFFSET(node_counts, node_counts_size);
leader_counts = PTR_OFFSET(node_disps, node_disps_size);
leader_disps = PTR_OFFSET(leader_counts, leader_counts_size);
node_gathered_data = PTR_OFFSET(leader_disps, leader_disps_size);

/* Stage 1: node-level collective. */
if (SBGP_ENABLED(cl_team, NODE)) {
ucc_assert(n_tasks == 0);
if (cl_team->top_sbgp == UCC_HIER_SBGP_NODE) {
/* NODE is the top subgroup: run the allgatherv on it with the caller's
   arguments unchanged apart from coll_type. */
args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
} else {
/* Build per-node counts from the caller's counts and pack the
   displacements back to back, in NODE sbgp rank order. */
size_t disp_counter = 0;
for (i = 0; i < node_sbgp_size; i++) {
ucc_rank_t team_rank =
ucc_ep_map_eval(SBGP_MAP(cl_team, NODE), i);
ucc_coll_args_set_count(
&args.args, node_counts, i,
ucc_coll_args_get_count(&args.args,
args.args.dst.info_v.counts,
team_rank));
ucc_coll_args_set_displacement(&args.args, node_disps,
i, disp_counter);
disp_counter += ucc_coll_args_get_count(&args.args,
node_counts, i);
}

/* In-place: this rank's contribution lives in the dst buffer at its own
   displacement, so point src there. This must happen before the in-place
   flag is cleared and before dst.info_v is redirected below. */
if (in_place) {
args.args.src.info.buffer =
PTR_OFFSET(args.args.dst.info_v.buffer,
dt_size * ucc_coll_args_get_displacement(
&args.args,
args.args.dst.info_v.displacements,
rank));
args.args.src.info.count =
ucc_coll_args_get_count(&args.args,
args.args.dst.info_v.counts,
rank);
args.args.src.info.datatype = args.args.dst.info_v.datatype;
args.args.src.info.mem_type = args.args.dst.info_v.mem_type;
}

/* Gather the node's data to NODE sbgp rank 0, into scratch. */
args.args.coll_type = UCC_COLL_TYPE_GATHERV;
args.args.root = 0;
args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE;
args.args.dst.info_v.displacements = node_disps;
args.args.dst.info_v.counts = node_counts;
args.args.dst.info_v.buffer = node_gathered_data;
}
UCC_CHECK_GOTO(
ucc_coll_init(SCORE_MAP(cl_team, NODE), &args, &tasks[n_tasks]),
free_scratch, status);
n_tasks++;
}

/* Reset the working copy to the caller's original arguments. */
args = args_old;

/* Stage 2: allgatherv among node leaders. */
if (SBGP_ENABLED(cl_team, NODE_LEADERS)) {
ucc_assert(cl_team->top_sbgp == UCC_HIER_SBGP_NODE_LEADERS);
size_t disp_counter = 0;

/* Sum up the counts on each node to get the count for each node leader */
/* NOTE(review): the value returned by find_leader_rank() is used without
   any validity check. */
for (i = 0; i < team_size; i++) {
ucc_rank_t leader_team_rank = find_leader_rank(team, i);
ucc_rank_t leader_sbgp_rank = ucc_ep_map_local_rank(
SBGP_MAP(cl_team, NODE_LEADERS),
leader_team_rank);
size_t leader_old_count = ucc_coll_args_get_count(
&args.args, leader_counts,
leader_sbgp_rank);
size_t add_count = ucc_coll_args_get_count(
&args.args,
args.args.dst.info_v.counts, i);
size_t new_count = add_count + leader_old_count;
ucc_coll_args_set_count(&args.args, leader_counts,
leader_sbgp_rank, new_count);
}

/* Lay the leaders' blocks out back to back, walking leader_list in order. */
for (i = 0; i < leader_sbgp_size; i++) {
//NOLINTNEXTLINE
ucc_rank_t leader_sgbp_rank = ucc_ep_map_local_rank(
SBGP_MAP(cl_team, NODE_LEADERS),
cl_team->leader_list[i]); //NOLINT
ucc_coll_args_set_displacement(&args.args, leader_disps,
leader_sgbp_rank, disp_counter);
disp_counter += ucc_coll_args_get_count(&args.args,
leader_counts,
leader_sgbp_rank);
}
/* Source is the node-gathered data in scratch; destination is the user's
   dst buffer, treated as contiguous with per-leader counts/displacements. */
args.args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
args.args.flags &= ~UCC_COLL_ARGS_FLAG_IN_PLACE;
args.args.flags |= UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
args.args.src.info.buffer = node_gathered_data;
/* NOTE(review): node_counts is only filled in the NODE gatherv branch above;
   if that branch did not run it is still all zeros from the memset, making
   this count 0 -- confirm that configuration is excluded elsewhere. */
args.args.src.info.count = ucc_coll_args_get_total_count(
&args.args,
node_counts,
node_sbgp_size);
args.args.src.info.datatype = args.args.dst.info_v.datatype;
args.args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.args.dst.info_v.displacements = leader_disps;
args.args.dst.info_v.counts = leader_counts;
args.args.dst.info_v.buffer = args_old.args.dst.info_v.buffer;
UCC_CHECK_GOTO(ucc_coll_init(SCORE_MAP(cl_team, NODE_LEADERS), &args,
&tasks[n_tasks]),
free_scratch, status);
n_tasks++;
}

/* Stage 3: broadcast the full result within the node from NODE sbgp rank 0.
   Skipped when NODE is the top sbgp, since stage 1 was already a complete
   allgatherv in that case. */
if (SBGP_ENABLED(cl_team, NODE) &&
cl_team->top_sbgp != UCC_HIER_SBGP_NODE) {
args = args_old;
args.args.coll_type = UCC_COLL_TYPE_BCAST;
args.args.root = 0;
args.args.src.info.buffer = args_old.args.dst.info_v.buffer;
args.args.src.info.count = total_count;
args.args.src.info.datatype = args_old.args.dst.info_v.datatype;
args.args.src.info.mem_type = args_old.args.dst.info_v.mem_type;

/* If using tl_shm and the shmem segment size is less than total_count,
this node-level bcast will cause the allgatherv to fail and fall back */
UCC_CHECK_GOTO(
ucc_coll_init(SCORE_MAP(cl_team, NODE), &args, &tasks[n_tasks]),
free_scratch, status);
n_tasks++;

/* Stage 4: the leaders' allgatherv wrote dst as a contiguous buffer; when
   the caller's dst is not contiguous an extra unpack task is appended,
   initialized with the caller's original arguments. */
if (!is_contig) {
args = args_old;
UCC_CHECK_GOTO(
ucc_cl_hier_allgatherv_unpack_init(&args, team, &tasks[n_tasks]),
free_scratch, status);
n_tasks++;
}
}

/* Chain the tasks: the first starts when the schedule starts, each later one
   starts when its predecessor completes. */
/* NOTE(review): tasks[0] is used unconditionally; if neither NODE nor
   NODE_LEADERS is enabled it is still NULL -- presumably that cannot happen,
   confirm. */
UCC_CHECK_GOTO(ucc_event_manager_subscribe(
&schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
ucc_task_start_handler),
free_scratch, status);
UCC_CHECK_GOTO(
ucc_schedule_add_task(schedule, tasks[0]), free_scratch, status);
for (i = 1; i < n_tasks; i++) {
UCC_CHECK_GOTO(
ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
tasks[i], ucc_task_start_handler),
free_scratch, status);
UCC_CHECK_GOTO(
ucc_schedule_add_task(schedule, tasks[i]), free_scratch, status);
}

schedule->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
schedule->super.post = ucc_cl_hier_allgatherv_start;
schedule->super.finalize = ucc_cl_hier_allgatherv_finalize;
*task = &schedule->super;
return UCC_OK;

/* Error unwinding: free_scratch falls through into free_sched, which
   finalizes every sub-task created so far and returns the schedule. */
free_scratch:
ucc_mc_free(cl_schedule->scratch);
free_sched:
for (i = 0; i < n_tasks; i++) {
tasks[i]->finalize(tasks[i]);
}
ucc_cl_hier_put_schedule(schedule);
return status;
}
Loading

0 comments on commit 9ee9a17

Please sign in to comment.