Skip to content

Commit

Permalink
Asymmetric memory (#1000)
Browse files Browse the repository at this point in the history
* CORE: Implement weak asymmetric mem with gtests

* CORE: Fix asymmetric bug

---------

Co-authored-by: Nicholas Sarkauskas <[email protected]>
  • Loading branch information
nsarka and Nicholas Sarkauskas authored Sep 5, 2024
1 parent 2ada313 commit 7f85fba
Show file tree
Hide file tree
Showing 8 changed files with 761 additions and 20 deletions.
7 changes: 2 additions & 5 deletions src/coll_score/ucc_coll_score_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,12 @@ static ucc_status_t ucc_coll_score_map_lookup(ucc_score_map_t *map,
ucc_list_link_t *list;
ucc_msg_range_t *r;

if (mt == UCC_MEMORY_TYPE_ASYMMETRIC) {
/* TODO */
ucc_debug("asymmetric memory type is not supported");
return UCC_ERR_NOT_SUPPORTED;
} else if (mt == UCC_MEMORY_TYPE_NOT_APPLY) {
if (mt == UCC_MEMORY_TYPE_NOT_APPLY) {
/* Temporary solution: for Barrier, Fanin, Fanout - use
"host" range list */
mt = UCC_MEMORY_TYPE_HOST;
}
ucc_assert(ucc_coll_args_is_mem_symmetric(&bargs->args, map->team_rank));
if (msgsize == UCC_MSG_SIZE_INVALID || msgsize == UCC_MSG_SIZE_ASYMMETRIC) {
/* These algorithms require global communication to get the same msgsize estimation.
Can't use msg ranges. Use msize 0 (assuming the range list should only contain 1
Expand Down
17 changes: 13 additions & 4 deletions src/components/base/ucc_base_iface.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,20 @@ enum {
UCC_BASE_CARGS_MAX_FRAG_COUNT = UCC_BIT(0)
};

typedef struct ucc_buffer_info_asymmetric_memtype {
union {
ucc_coll_buffer_info_t info;
ucc_coll_buffer_info_v_t info_v;
} old_asymmetric_buffer;
ucc_mc_buffer_header_t *scratch;
} ucc_buffer_info_asymmetric_memtype_t;

typedef struct ucc_base_coll_args {
uint64_t mask;
ucc_coll_args_t args;
ucc_team_t *team;
size_t max_frag_count;
uint64_t mask;
ucc_coll_args_t args;
ucc_team_t *team;
size_t max_frag_count;
ucc_buffer_info_asymmetric_memtype_t asymmetric_save_info;
} ucc_base_coll_args_t;

typedef ucc_status_t (*ucc_base_coll_init_fn_t)(ucc_base_coll_args_t *coll_args,
Expand Down
42 changes: 38 additions & 4 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,31 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args, UCC_COLL_ARGS_FIELD_FLAGS,
flags);

if (!ucc_coll_args_is_mem_symmetric(&op_args.args, team->rank) &&
ucc_coll_args_is_rooted(op_args.args.coll_type)) {
status = ucc_coll_args_init_asymmetric_buffer(&op_args.args, team,
&op_args.asymmetric_save_info);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("handling asymmetric memory failed");
return status;
}
} else {
op_args.asymmetric_save_info.scratch = NULL;
}

status = ucc_coll_init(team->score_map, &op_args, &task);
if (UCC_ERR_NOT_SUPPORTED == status) {
ucc_debug("failed to init collective: not supported");
return status;
goto free_scratch;
} else if (ucc_unlikely(status < 0)) {
ucc_error("failed to init collective: %s", ucc_status_string(status));
return status;
goto free_scratch;
}

task->flags |= UCC_COLL_TASK_FLAG_TOP_LEVEL;
if (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR) {
task->flags |= UCC_COLL_TASK_FLAG_EXECUTOR_STOP;
coll_mem_type = ucc_coll_args_mem_type(coll_args, team->rank);
coll_mem_type = ucc_coll_args_mem_type(&op_args.args, team->rank);
switch(coll_mem_type) {
case UCC_MEMORY_TYPE_CUDA:
case UCC_MEMORY_TYPE_CUDA_MANAGED:
Expand All @@ -251,7 +263,7 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
case UCC_MEMORY_TYPE_ROCM:
coll_ee_type = UCC_EE_ROCM_STREAM;
break;
case UCC_MEMORY_TYPE_HOST:
case UCC_MEMORY_TYPE_HOST:
coll_ee_type = UCC_EE_CPU_THREAD;
break;
default:
Expand Down Expand Up @@ -299,6 +311,10 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,

coll_finalize:
task->finalize(task);
free_scratch:
if (op_args.asymmetric_save_info.scratch != NULL) {
ucc_mc_free(op_args.asymmetric_save_info.scratch);
}
return status;
}

Expand Down Expand Up @@ -341,6 +357,17 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request),
}
}

if (task->bargs.asymmetric_save_info.scratch != NULL &&
(task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTER ||
task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTERV)) {
status = ucc_copy_asymmetric_buffer(task);
if (status != UCC_OK) {
ucc_error("failure copying in asymmetric buffer: %s",
ucc_status_string(status));
return status;
}
}

COLL_POST_STATUS_CHECK(task);
if (UCC_COLL_TIMEOUT_REQUIRED(task)) {
task->start_time = ucc_get_time();
Expand Down Expand Up @@ -402,6 +429,13 @@ ucc_status_t ucc_collective_finalize_internal(ucc_coll_task_t *task)
return UCC_ERR_INVALID_PARAM;
}

if (task->bargs.asymmetric_save_info.scratch) {
st = ucc_coll_args_free_asymmetric_buffer(task);
if (ucc_unlikely(st != UCC_OK)) {
ucc_error("error freeing asymmetric buf: %s", ucc_status_string(st));
}
}

if (task->executor) {
st = ucc_ee_executor_finalize(task->executor);
if (ucc_unlikely(st != UCC_OK)) {
Expand Down
11 changes: 11 additions & 0 deletions src/schedule/ucc_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "utils/ucc_coll_utils.h"
#include "components/base/ucc_base_iface.h"
#include "components/ec/ucc_ec.h"
#include "components/mc/ucc_mc.h"

#define MAX_LISTENERS 4

Expand Down Expand Up @@ -185,6 +186,16 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
with schedules are not released during a callback (if set). */

if (ucc_likely(status == UCC_OK)) {
ucc_buffer_info_asymmetric_memtype_t *save = &task->bargs.asymmetric_save_info;
if (save->scratch &&
task->bargs.args.coll_type != UCC_COLL_TYPE_SCATTERV &&
task->bargs.args.coll_type != UCC_COLL_TYPE_SCATTER) {
status = ucc_copy_asymmetric_buffer(task);
if (status != UCC_OK) {
ucc_error("failure copying out asymmetric buffer: %s",
ucc_status_string(status));
}
}
status = ucc_event_manager_notify(task, UCC_EVENT_COMPLETED);
} else {
/* error in task status */
Expand Down
182 changes: 176 additions & 6 deletions src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ucc_memory_type_t ucc_mem_type_from_str(const char *str)
return UCC_MEMORY_TYPE_LAST;
}

static inline int
int
ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
ucc_rank_t rank)
{
Expand Down Expand Up @@ -94,6 +94,180 @@ ucc_coll_args_is_mem_symmetric(const ucc_coll_args_t *args,
return 0;
}


/* If this is the root and the src/dst buffers are asymmetric, one buffer needs
to have a new allocation to make the mem types match. If that buffer was the
dst buffer, copy the result back into the old dst on task completion */
ucc_status_t
ucc_coll_args_init_asymmetric_buffer(ucc_coll_args_t *args,
ucc_team_h team,
ucc_buffer_info_asymmetric_memtype_t *save_info)
{
ucc_status_t status = UCC_OK;

if (UCC_IS_INPLACE(*args)) {
return UCC_ERR_INVALID_PARAM;
}
switch (args->coll_type) {
case UCC_COLL_TYPE_REDUCE:
case UCC_COLL_TYPE_GATHER:
{
ucc_memory_type_t mem_type = args->src.info.mem_type;
if (args->coll_type == UCC_COLL_TYPE_SCATTERV) {
mem_type = args->src.info_v.mem_type;
}
memcpy(&save_info->old_asymmetric_buffer.info,
&args->dst.info, sizeof(ucc_coll_buffer_info_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_dt_size(args->dst.info.datatype) *
args->dst.info.count,
mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->dst.info.buffer = save_info->scratch->addr;
args->dst.info.mem_type = mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_GATHERV:
{
memcpy(&save_info->old_asymmetric_buffer.info_v,
&args->dst.info_v, sizeof(ucc_coll_buffer_info_v_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_coll_args_get_v_buffer_size(args,
args->dst.info_v.counts,
args->dst.info_v.displacements,
team->size),
args->src.info.mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->dst.info_v.buffer = save_info->scratch->addr;
args->dst.info_v.mem_type = args->src.info.mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_SCATTER:
{
ucc_memory_type_t mem_type = args->dst.info.mem_type;
memcpy(&save_info->old_asymmetric_buffer.info,
&args->src.info, sizeof(ucc_coll_buffer_info_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_dt_size(args->src.info.datatype) * args->src.info.count,
mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->src.info.buffer = save_info->scratch->addr;
args->src.info.mem_type = mem_type;
return UCC_OK;
}
case UCC_COLL_TYPE_SCATTERV:
{
ucc_memory_type_t mem_type = args->dst.info.mem_type;
memcpy(&save_info->old_asymmetric_buffer.info_v,
&args->src.info_v, sizeof(ucc_coll_buffer_info_v_t));
status = ucc_mc_alloc(&save_info->scratch,
ucc_coll_args_get_v_buffer_size(args,
args->src.info_v.counts,
args->src.info_v.displacements,
team->size),
mem_type);
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to allocate replacement "
"memory for asymmetric buffer");
return status;
}
args->src.info_v.buffer = save_info->scratch->addr;
args->src.info_v.mem_type = mem_type;
return UCC_OK;
}
default:
break;
}
return UCC_ERR_INVALID_PARAM;
}

ucc_status_t
ucc_coll_args_free_asymmetric_buffer(ucc_coll_task_t *task)
{
ucc_status_t status = UCC_OK;
ucc_buffer_info_asymmetric_memtype_t *save = &task->bargs.asymmetric_save_info;

if (UCC_IS_INPLACE(task->bargs.args)) {
return UCC_ERR_INVALID_PARAM;
}

if (save->scratch == NULL) {
ucc_error("failure trying to free NULL asymmetric buffer");
}

status = ucc_mc_free(save->scratch);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error freeing scratch asymmetric buffer: %s",
ucc_status_string(status));
}
save->scratch = NULL;

return status;
}

ucc_status_t ucc_copy_asymmetric_buffer(ucc_coll_task_t *task)
{
ucc_status_t status = UCC_OK;
ucc_coll_args_t *coll_args = &task->bargs.args;
ucc_buffer_info_asymmetric_memtype_t *save = &task->bargs.asymmetric_save_info;
ucc_rank_t size = task->team->params.size;

if(task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTERV) {
// copy in
status = ucc_mc_memcpy(save->scratch->addr,
save->old_asymmetric_buffer.info_v.buffer,
ucc_coll_args_get_v_buffer_size(coll_args,
save->old_asymmetric_buffer.info_v.counts,
save->old_asymmetric_buffer.info_v.displacements,
size),
save->scratch->mt,
save->old_asymmetric_buffer.info_v.mem_type);
} else if(task->bargs.args.coll_type == UCC_COLL_TYPE_SCATTER) {
// copy in
status = ucc_mc_memcpy(save->scratch->addr,
save->old_asymmetric_buffer.info.buffer,
ucc_dt_size(save->old_asymmetric_buffer.info.datatype) *
save->old_asymmetric_buffer.info.count,
save->scratch->mt,
save->old_asymmetric_buffer.info.mem_type);
} else if(task->bargs.args.coll_type == UCC_COLL_TYPE_GATHERV) {
// copy out
status = ucc_mc_memcpy(save->old_asymmetric_buffer.info_v.buffer,
save->scratch->addr,
ucc_coll_args_get_v_buffer_size(coll_args,
save->old_asymmetric_buffer.info_v.counts,
save->old_asymmetric_buffer.info_v.displacements,
size),
save->old_asymmetric_buffer.info_v.mem_type,
save->scratch->mt);
} else {
// copy out
status = ucc_mc_memcpy(save->old_asymmetric_buffer.info.buffer,
save->scratch->addr,
ucc_dt_size(save->old_asymmetric_buffer.info.datatype) *
save->old_asymmetric_buffer.info.count,
save->old_asymmetric_buffer.info.mem_type,
save->scratch->mt);
}
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("error copying back to old asymmetric buffer: %s",
ucc_status_string(status));
}
return status;
}

int ucc_coll_args_is_predefined_dt(const ucc_coll_args_t *args, ucc_rank_t rank)
{
switch (args->coll_type) {
Expand Down Expand Up @@ -163,9 +337,6 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
{
ucc_rank_t root = args->root;

if (!ucc_coll_args_is_mem_symmetric(args, rank)) {
return UCC_MEMORY_TYPE_ASYMMETRIC;
}
switch (args->coll_type) {
case UCC_COLL_TYPE_BARRIER:
case UCC_COLL_TYPE_FANIN:
Expand All @@ -180,7 +351,6 @@ ucc_memory_type_t ucc_coll_args_mem_type(const ucc_coll_args_t *args,
return args->dst.info.mem_type;
case UCC_COLL_TYPE_ALLGATHERV:
case UCC_COLL_TYPE_REDUCE_SCATTERV:
return args->dst.info_v.mem_type;
case UCC_COLL_TYPE_ALLTOALLV:
return args->dst.info_v.mem_type;
case UCC_COLL_TYPE_REDUCE:
Expand Down Expand Up @@ -323,7 +493,7 @@ ucc_ep_map_t ucc_ep_map_from_array_64(uint64_t **array, ucc_rank_t size,
need_free, 1);
}

static inline int ucc_coll_args_is_rooted(ucc_coll_type_t ct)
int ucc_coll_args_is_rooted(ucc_coll_type_t ct)
{
if (ct == UCC_COLL_TYPE_REDUCE || ct == UCC_COLL_TYPE_BCAST ||
ct == UCC_COLL_TYPE_GATHER || ct == UCC_COLL_TYPE_SCATTER ||
Expand Down
Loading

0 comments on commit 7f85fba

Please sign in to comment.