Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TOOLS: add persistent colls to perftest #863

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ void ucc_coll_args_str(const ucc_coll_args_t *args, ucc_rank_t trank,
strncat(hdr, tmp, left);
}

if (UCC_IS_PERSISTENT(*args)) {
ucc_snprintf_safe(tmp, sizeof(tmp), " persistent");
left = COLL_ARGS_HEADER_STR_MAX_SIZE - strlen(hdr);
strncat(hdr, tmp, left);
}

if (ucc_coll_args_is_rooted(ct)) {
ucc_snprintf_safe(tmp, sizeof(tmp), " root %u", root);
left = COLL_ARGS_HEADER_STR_MAX_SIZE - strlen(hdr);
Expand Down
63 changes: 44 additions & 19 deletions tools/perf/ucc_pt_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,61 @@ ucc_pt_benchmark::ucc_pt_benchmark(ucc_pt_benchmark_config cfg,
{
switch (cfg.op_type) {
case UCC_PT_OP_TYPE_ALLGATHER:
coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_allgather(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLGATHERV:
coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_allgatherv(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLREDUCE:
coll = new ucc_pt_coll_allreduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace,
comm);
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLTOALL:
coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_alltoall(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_ALLTOALLV:
coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace, comm);
coll = new ucc_pt_coll_alltoallv(cfg.dt, cfg.mt, cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_BARRIER:
coll = new ucc_pt_coll_barrier(comm);
break;
case UCC_PT_OP_TYPE_BCAST:
coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift, comm);
coll = new ucc_pt_coll_bcast(cfg.dt, cfg.mt, cfg.root_shift,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_GATHER:
coll = new ucc_pt_coll_gather(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_GATHERV:
coll = new ucc_pt_coll_gatherv(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_REDUCE:
coll = new ucc_pt_coll_reduce(cfg.dt, cfg.mt, cfg.op, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_REDUCE_SCATTER:
coll = new ucc_pt_coll_reduce_scatter(cfg.dt, cfg.mt, cfg.op,
cfg.inplace, comm);
cfg.inplace,
cfg.persistent, comm);
break;
case UCC_PT_OP_TYPE_REDUCE_SCATTERV:
coll = new ucc_pt_coll_reduce_scatterv(cfg.dt, cfg.mt, cfg.op,
cfg.inplace, comm);
cfg.inplace, cfg.persistent,
comm);
break;
case UCC_PT_OP_TYPE_SCATTER:
coll = new ucc_pt_coll_scatter(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_SCATTERV:
coll = new ucc_pt_coll_scatterv(cfg.dt, cfg.mt, cfg.inplace,
cfg.root_shift, comm);
cfg.persistent, cfg.root_shift, comm);
break;
case UCC_PT_OP_TYPE_MEMCPY:
coll = new ucc_pt_op_memcpy(cfg.dt, cfg.mt, cfg.n_bufs, comm);
Expand Down Expand Up @@ -137,10 +144,11 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
double &time)
noexcept
{
const bool triggered = config.triggered;
ucc_team_h team = comm->get_team();
ucc_context_h ctx = comm->get_context();
ucc_status_t st = UCC_OK;
const bool triggered = config.triggered;
const bool persistent = config.persistent;
ucc_team_h team = comm->get_team();
ucc_context_h ctx = comm->get_context();
ucc_status_t st = UCC_OK;
ucc_coll_req_h req;
ucc_ee_h ee;
ucc_ev_t comp_ev, *post_ev;
Expand All @@ -161,10 +169,18 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
comp_ev.ev_context_size = 0;
}

if (persistent) {
UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st);
}

args.root = config.root % comm->get_size();
for (int i = 0; i < nwarmup + niter; i++) {
double s = get_time_us();
UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st);

if (!persistent) {
UCCCHECK_GOTO(ucc_collective_init(&args, &req, team), exit_err, st);
}

if (triggered) {
comp_ev.req = req;
UCCCHECK_GOTO(ucc_collective_triggered_post(ee, &comp_ev),
Expand All @@ -175,12 +191,16 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
} else {
UCCCHECK_GOTO(ucc_collective_post(req), free_req, st);
}

st = ucc_collective_test(req);
while (st > 0) {
UCCCHECK_GOTO(ucc_context_progress(ctx), free_req, st);
st = ucc_collective_test(req);
}
ucc_collective_finalize(req);

if (!persistent) {
ucc_collective_finalize(req);
}
double f = get_time_us();
if (st != UCC_OK) {
goto exit_err;
Expand All @@ -191,6 +211,11 @@ ucc_status_t ucc_pt_benchmark::run_single_coll_test(ucc_coll_args_t args,
args.root = (args.root + config.root_shift) % comm->get_size();
UCCCHECK_GOTO(comm->barrier(), exit_err, st);
}

if (persistent) {
ucc_collective_finalize(req);
}

if (niter != 0) {
time /= niter;
}
Expand Down
32 changes: 18 additions & 14 deletions tools/perf/ucc_pt_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class ucc_pt_coll {
class ucc_pt_coll_allgather: public ucc_pt_coll {
public:
ucc_pt_coll_allgather(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -67,7 +68,8 @@ class ucc_pt_coll_allgather: public ucc_pt_coll {
class ucc_pt_coll_allgatherv: public ucc_pt_coll {
public:
ucc_pt_coll_allgatherv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
};
Expand All @@ -76,7 +78,7 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll {
public:
ucc_pt_coll_allreduce(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -85,7 +87,8 @@ class ucc_pt_coll_allreduce: public ucc_pt_coll {
class ucc_pt_coll_alltoall: public ucc_pt_coll {
public:
ucc_pt_coll_alltoall(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -94,7 +97,8 @@ class ucc_pt_coll_alltoall: public ucc_pt_coll {
class ucc_pt_coll_alltoallv: public ucc_pt_coll {
public:
ucc_pt_coll_alltoallv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, ucc_pt_comm *communicator);
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
};
Expand All @@ -109,7 +113,7 @@ class ucc_pt_coll_barrier: public ucc_pt_coll {
class ucc_pt_coll_bcast: public ucc_pt_coll {
public:
ucc_pt_coll_bcast(ucc_datatype_t dt, ucc_memory_type mt, int root_shift,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -118,7 +122,7 @@ class ucc_pt_coll_bcast: public ucc_pt_coll {
class ucc_pt_coll_gather: public ucc_pt_coll {
public:
ucc_pt_coll_gather(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand All @@ -128,7 +132,7 @@ class ucc_pt_coll_gather: public ucc_pt_coll {
class ucc_pt_coll_gatherv: public ucc_pt_coll {
public:
ucc_pt_coll_gatherv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand All @@ -137,8 +141,8 @@ class ucc_pt_coll_gatherv: public ucc_pt_coll {
class ucc_pt_coll_reduce: public ucc_pt_coll {
public:
ucc_pt_coll_reduce(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace, int root_shift,
ucc_pt_comm *communicator);
ucc_reduction_op_t op, bool is_inplace, bool is_persistent,
int root_shift, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -148,7 +152,7 @@ class ucc_pt_coll_reduce_scatter: public ucc_pt_coll {
public:
ucc_pt_coll_reduce_scatter(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
float get_bw(float time_ms, int grsize, ucc_pt_test_args_t args) override;
Expand All @@ -158,15 +162,15 @@ class ucc_pt_coll_reduce_scatterv: public ucc_pt_coll {
public:
ucc_pt_coll_reduce_scatterv(ucc_datatype_t dt, ucc_memory_type mt,
ucc_reduction_op_t op, bool is_inplace,
ucc_pt_comm *communicator);
bool is_persistent, ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
};

class ucc_pt_coll_scatter: public ucc_pt_coll {
public:
ucc_pt_coll_scatter(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand All @@ -176,7 +180,7 @@ class ucc_pt_coll_scatter: public ucc_pt_coll {
class ucc_pt_coll_scatterv: public ucc_pt_coll {
public:
ucc_pt_coll_scatterv(ucc_datatype_t dt, ucc_memory_type mt,
bool is_inplace, int root_shift,
bool is_inplace, bool is_persistent, int root_shift,
ucc_pt_comm *communicator);
ucc_status_t init_args(size_t count, ucc_pt_test_args_t &args) override;
void free_args(ucc_pt_test_args_t &args) override;
Expand Down
14 changes: 11 additions & 3 deletions tools/perf/ucc_pt_coll_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt,
ucc_memory_type mt, bool is_inplace,
bool is_persistent,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)

{
Expand All @@ -21,16 +22,23 @@ ucc_pt_coll_allgather::ucc_pt_coll_allgather(ucc_datatype_t dt,
has_bw_ = true;
root_shift_ = 0;

coll_args.mask = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll_args.mask = 0;
coll_args.flags = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll_args.src.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.dst.info.datatype = dt;
coll_args.dst.info.mem_type = mt;

if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (is_persistent) {
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
}
}

ucc_status_t ucc_pt_coll_allgather::init_args(size_t single_rank_count,
Expand Down
18 changes: 13 additions & 5 deletions tools/perf/ucc_pt_coll_allgatherv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt,
ucc_memory_type mt, bool is_inplace,
bool is_persistent,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)
{
has_inplace_ = true;
Expand All @@ -20,16 +21,23 @@ ucc_pt_coll_allgatherv::ucc_pt_coll_allgatherv(ucc_datatype_t dt,
has_bw_ = false;
root_shift_ = 0;

coll_args.mask = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
coll_args.src.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.mask = 0;
coll_args.flags = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLGATHERV;
coll_args.src.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.dst.info_v.datatype = dt;
coll_args.dst.info_v.mem_type = mt;

if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (is_persistent) {
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;
}
}

ucc_status_t ucc_pt_coll_allgatherv::init_args(size_t count,
Expand Down
22 changes: 15 additions & 7 deletions tools/perf/ucc_pt_coll_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt,
ucc_memory_type mt, ucc_reduction_op_t op,
bool is_inplace,
bool is_inplace, bool is_persistent,
ucc_pt_comm *communicator) : ucc_pt_coll(communicator)
{
has_inplace_ = true;
Expand All @@ -21,17 +21,25 @@ ucc_pt_coll_allreduce::ucc_pt_coll_allreduce(ucc_datatype_t dt,
has_bw_ = true;
root_shift_ = 0;

coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE;
coll_args.mask = 0;
if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
coll_args.mask = 0;
coll_args.flags = 0;
coll_args.coll_type = UCC_COLL_TYPE_ALLREDUCE;
coll_args.op = op;
coll_args.src.info.datatype = dt;
coll_args.dst.info.datatype = dt;
coll_args.src.info.mem_type = mt;
coll_args.dst.info.mem_type = mt;

if (is_inplace) {
coll_args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (is_persistent) {
coll_args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll_args.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT;

}
}

ucc_status_t ucc_pt_coll_allreduce::init_args(size_t count,
Expand Down
Loading
Loading