UCC Allreduce example
(Wiki page last edited by Valentin Petrov on Mar 11, 2022 · 2 revisions)
The code snippet below demonstrates how the UCC API can be used to execute an Allreduce collective operation over a group of processes. It is an MPI-based application, but MPI is used only to bootstrap the job: it spawns the processes and implements the OOB (out-of-band) allgather exchange that UCC uses for wire-up between the processes.
Main steps to execute a UCC allreduce:
- Read the UCC library configuration
- Initialize the UCC library
- Read the UCC context configuration
- Initialize the UCC context
- Initialize the UCC team
- Fill in the collective descriptor and initialize the collective request
- Post the collective
- Test for completion while progressing the UCC context
- Clean up the collective request
- Clean up UCC
If UCC is built and installed into ${UCC_PATH} and the MPI tools (mpicc/mpirun) are available in your PATH, the example can be compiled and run with the commands below:
mpicc ucc_allreduce.c -g -o ucc_allreduce -I${UCC_PATH}/include -L${UCC_PATH}/lib -lucc -Wl,-rpath="${UCC_PATH}/lib"
mpirun -x UCC_TLS=ucp -np 4 ./ucc_allreduce
#include <mpi.h>
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <ucc/api/ucc.h>
/* Turns its (already macro-expanded) argument into a string literal. */
#define STR(x) #x
/* Evaluates a UCC API call exactly once. If it does not return UCC_OK, the
   call text is printed to stderr and the whole MPI job is aborted.
   The do { } while (0) wrapper makes the macro a single statement, so it is
   safe inside an unbraced if/else and requires the usual trailing ';'. */
#define UCC_CHECK(_call)                                                \
    do {                                                                \
        if (UCC_OK != (_call)) {                                        \
            fprintf(stderr, "*** UCC TEST FAIL: %s\n", STR(_call));     \
            MPI_Abort(MPI_COMM_WORLD, -1);                              \
        }                                                               \
    } while (0)
static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen,
void *coll_info, void **req)
{
MPI_Comm comm = (MPI_Comm)coll_info;
MPI_Request request;
MPI_Iallgather(sbuf, msglen, MPI_BYTE, rbuf, msglen, MPI_BYTE, comm,
&request);
*req = (void *)request;
return UCC_OK;
}
static ucc_status_t oob_allgather_test(void *req)
{
MPI_Request request = (MPI_Request)req;
int completed;
MPI_Test(&request, &completed, MPI_STATUS_IGNORE);
return completed ? UCC_OK : UCC_INPROGRESS;
}
static ucc_status_t oob_allgather_free(void *req)
{
return UCC_OK;
}
/* Builds a UCC team containing every process of the MPI communicator `comm`
   on top of the UCC context `ctx`. UCC offers several ways to create a team;
   this one relies on the out-of-band (OOB) allgather callbacks defined in this
   file, which run over `comm`. Any failure aborts the whole MPI job, so the
   returned handle is always valid. */
static ucc_team_h create_ucc_team(MPI_Comm comm, ucc_context_h ctx)
{
    int          my_rank;
    int          n_procs;
    ucc_team_h   new_team;
    ucc_status_t st;

    MPI_Comm_rank(comm, &my_rank);
    MPI_Comm_size(comm, &n_procs);

    /* Only the OOB field is enabled in the mask; one OOB endpoint per
       process of the communicator. */
    ucc_team_params_t params = {
        .mask = UCC_TEAM_PARAM_FIELD_OOB,
        .oob  = {
            .allgather = oob_allgather,
            .req_test  = oob_allgather_test,
            .req_free  = oob_allgather_free,
            .coll_info = (void *)comm,
            .n_oob_eps = n_procs,
            .oob_ep    = my_rank
        }
    };

    UCC_CHECK(ucc_team_create_post(&ctx, 1, &params, &new_team));

    /* Team creation is non-blocking: drive the context until it settles. */
    for (;;) {
        st = ucc_team_create_test(new_team);
        if (UCC_INPROGRESS != st) {
            break;
        }
        UCC_CHECK(ucc_context_progress(ctx));
    }

    if (UCC_OK != st) {
        fprintf(stderr, "failed to create ucc team\n");
        MPI_Abort(MPI_COMM_WORLD, st);
    }
    return new_team;
}
int main (int argc, char **argv) {
ucc_lib_config_h lib_config;
ucc_context_config_h ctx_config;
int rank, size, i;
ucc_team_h team;
ucc_context_h ctx;
ucc_lib_h lib;
size_t msglen;
size_t count;
int *sbuf, *rbuf;
ucc_coll_req_h req;
ucc_coll_args_t args;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
/* Init ucc library */
ucc_lib_params_t lib_params = {
.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE,
.thread_mode = UCC_THREAD_SINGLE
};
UCC_CHECK(ucc_lib_config_read(NULL, NULL, &lib_config));
UCC_CHECK(ucc_init(&lib_params, lib_config, &lib));
ucc_lib_config_release(lib_config);
/* Init ucc context for a specified UCC_TEST_TLS */
ucc_context_params_t ctx_params = {
.mask = UCC_CONTEXT_PARAM_FIELD_OOB,
.oob.allgather = oob_allgather,
.oob.req_test = oob_allgather_test,
.oob.req_free = oob_allgather_free,
.oob.coll_info = (void*)MPI_COMM_WORLD,
.oob.n_oob_eps = size,
.oob.oob_ep = rank
};
UCC_CHECK(ucc_context_config_read(lib, NULL, &ctx_config));
/* UCC_CHECK(ucc_context_config_modify(ctx_config, "TLS", &lib_config)); */
UCC_CHECK(ucc_context_create(lib, &ctx_params, ctx_config, &ctx));
ucc_context_config_release(ctx_config);
team = create_ucc_team(MPI_COMM_WORLD, ctx);
count = argc > 1 ? atoi(argv[1]) : 1;
msglen = count * sizeof(int);
sbuf = malloc(msglen);
rbuf = malloc(msglen);
for (i = 0; i < count; i++) {
sbuf[i] = rank + 1;
rbuf[i] = 0;
}
args.mask = 0;
args.coll_type = UCC_COLL_TYPE_ALLREDUCE;
args.src.info.buffer = sbuf;
args.src.info.count = count;
args.src.info.datatype = UCC_DT_INT32;
args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.dst.info.buffer = rbuf;
args.dst.info.count = count;
args.dst.info.datatype = UCC_DT_INT32;
args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST;
args.op = UCC_OP_SUM;
UCC_CHECK(ucc_collective_init(&args, &req, team));
UCC_CHECK(ucc_collective_post(req));
while (UCC_INPROGRESS == ucc_collective_test(req)) {
UCC_CHECK(ucc_context_progress(ctx));
}
ucc_collective_finalize(req);
/* Check result */
int sum = ((size + 1) * size) / 2;
for (i = 0; i < count; i++) {
if (rbuf[i] != sum) {
printf("ERROR at rank %d, pos %d, value %d, expected %d\n", rank, i, rbuf[i], sum);
break;
}
}
/* Cleanup UCC */
UCC_CHECK(ucc_team_destroy(team));
UCC_CHECK(ucc_context_destroy(ctx));
UCC_CHECK(ucc_finalize(lib));
MPI_Finalize();
free(sbuf);
free(rbuf);
return 0;
}