Skip to content

Commit

Permalink
TL/MLX5: add device mem mcast bcast (#989)
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB authored Sep 5, 2024
1 parent 7f85fba commit 313f2da
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/components/mc/ucc_mc.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ ucc_status_t ucc_mc_get_attr(ucc_mc_attr_t *attr, ucc_memory_type_t mem_type)
return mc->get_attr(attr);
}

/* TODO: add the flexbility to bypass the mpool if the user asks for it */
UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_alloc, (h_ptr, size, mem_type),
ucc_mc_buffer_header_t **h_ptr, size_t size,
ucc_memory_type_t mem_type)
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "components/tl/ucc_tl_log.h"
#include "utils/ucc_rcache.h"
#include "core/ucc_service_coll.h"
#include "components/mc/ucc_mc.h"

#define POLL_PACKED 16
#define REL_DONE ((void*)-1)
Expand Down Expand Up @@ -98,6 +99,7 @@ typedef struct mcast_coll_comm_init_spec {
int scq_moderation;
int wsize;
int max_eager;
int cuda_mem_enabled;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -261,6 +263,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
ucc_mc_buffer_header_t *pp_buf_header;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
Expand Down Expand Up @@ -293,6 +296,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
int cuda_mem_enabled;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_tl_mlx5_mcast_service_coll_t service_coll;
Expand Down
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
}
}

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_mcast_coll_comm_t *comm = mlx5_team->mcast->mcast_comm;
ucc_coll_args_t *args = &coll_args->args;

if ((comm->cuda_mem_enabled &&
args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) ||
(!comm->cuda_mem_enabled &&
args->src.info.mem_type == UCC_MEMORY_TYPE_HOST)) {
return UCC_OK;
}

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
#endif
10 changes: 8 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
return UCC_ERR_NO_RESOURCE;
}

comm->max_inline = qp_init_attr.cap.max_inline_data;
if (comm->cuda_mem_enabled) {
/* max inline send otherwise it segfault during ibv send */
comm->max_inline = 0;
} else {
comm->max_inline = qp_init_attr.cap.max_inline_data;
}

return UCC_OK;
}
Expand Down Expand Up @@ -609,6 +614,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
return UCC_ERR_NO_RESOURCE;
}
}

if (comm->grh_buf) {
ucc_free(comm->grh_buf);
}
Expand All @@ -626,7 +632,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
}

if (comm->pp_buf) {
ucc_free(comm->pp_buf);
ucc_mc_free(comm->pp_buf_header);
}

if (comm->call_rwr) {
Expand Down
9 changes: 8 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
int rc;
int length;
ucc_status_t status;
ucc_memory_type_t mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
: UCC_MEMORY_TYPE_HOST;

for (i = 0; i < num_packets; i++) {
if (comm->params.sx_depth <=
Expand Down Expand Up @@ -75,7 +77,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
if (zcopy) {
pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
} else {
memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
status = ucc_mc_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy cuda buffer");
return status;
}
ssg[0].addr = (uint64_t) pp->buf;
}

Expand Down
15 changes: 14 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);

Expand All @@ -379,7 +380,19 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com

if (pp->length > 0 ) {
dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
memcpy(dest, (void*) pp->buf, pp->length);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}
}

comm->r_window[pp->psn & (comm->wsize-1)] = pp;
Expand Down
85 changes: 58 additions & 27 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
#include "mcast/tl_mlx5_mcast_helper.h"
#include "mcast/tl_mlx5_mcast_service_coll.h"

static ucc_status_t ucc_tl_mlx5_check_gpudirect_driver()
{
const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";

if (!access(file, F_OK)) {
return UCC_OK;
}

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
Expand Down Expand Up @@ -88,23 +99,14 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

memcpy(&comm->params, conf_params, sizeof(*conf_params));

comm->wsize = conf_params->wsize;
comm->max_eager = conf_params->max_eager;
comm->comm_id = team_params->id;
comm->ctx = mcast_context;
comm->grh_buf = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf");
if (!comm->grh_buf) {
status = UCC_ERR_NO_MEMORY;
goto cleanup;
}
comm->wsize = conf_params->wsize;
comm->max_eager = conf_params->max_eager;
comm->cuda_mem_enabled = conf_params->cuda_mem_enabled;
comm->comm_id = team_params->id;
comm->ctx = mcast_context;

memset(comm->grh_buf, 0, GRH_LENGTH);

comm->grh_mr = ibv_reg_mr(mcast_context->pd, comm->grh_buf, GRH_LENGTH,
IBV_ACCESS_REMOTE_WRITE |
IBV_ACCESS_LOCAL_WRITE);
if (!comm->grh_mr) {
tl_error(mcast_context->lib, "could not register memory for GRH, errno %d", errno);
if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) {
tl_warn(mcast_context->lib, "cuda-aware mcast not available as gpu direct is not ready");
status = UCC_ERR_NO_RESOURCE;
goto cleanup;
}
Expand Down Expand Up @@ -162,9 +164,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
ucc_status_t status;
size_t page_size;
int buf_size, i, ret;
ucc_status_t status;
size_t page_size;
int buf_size, i, ret;
ucc_memory_type_t supported_mem_type;

status = ucc_tl_mlx5_mcast_init_qps(comm->ctx, comm);
if (UCC_OK != status) {
Expand Down Expand Up @@ -197,19 +200,47 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
comm->pending_recv = 0;
comm->buf_n = comm->params.rx_depth * 2;

ret = posix_memalign((void**) &comm->pp_buf, page_size, buf_size * comm->buf_n);
if (ret) {
tl_error(comm->ctx->lib, "posix_memalign failed");
return UCC_ERR_NO_MEMORY;
supported_mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
: UCC_MEMORY_TYPE_HOST;

comm->grh_buf = ucc_malloc(GRH_LENGTH * sizeof(char), "grh");
if (ucc_unlikely(!comm->grh_buf)) {
tl_error(comm->ctx->lib, "failed to allocate grh memory");
return status;
}

status = ucc_mc_memset(comm->grh_buf, 0, GRH_LENGTH, UCC_MEMORY_TYPE_HOST);
if (status != UCC_OK) {
tl_error(comm->ctx->lib, "could not cuda memset");
goto error;
}

comm->grh_mr = ibv_reg_mr(comm->ctx->pd, comm->grh_buf, GRH_LENGTH,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
if (!comm->grh_mr) {
tl_error(comm->ctx->lib, "could not register device memory for GRH, errno %d", errno);
status = UCC_ERR_NO_RESOURCE;
goto error;
}

status = ucc_mc_alloc(&comm->pp_buf_header, buf_size * comm->buf_n, supported_mem_type);
comm->pp_buf = comm->pp_buf_header->addr;
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->ctx->lib, "failed to allocate cuda memory");
goto error;
}

status = ucc_mc_memset(comm->pp_buf, 0, buf_size * comm->buf_n, supported_mem_type);
if (status != UCC_OK) {
tl_error(comm->ctx->lib, "could not memset");
goto error;
}

memset(comm->pp_buf, 0, buf_size * comm->buf_n);

comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
if (!comm->pp_mr) {
tl_error(comm->ctx->lib, "could not register pp_buf mr, errno %d", errno);
status = UCC_ERR_NO_MEMORY;
tl_error(comm->ctx->lib, "could not register pp_buf device mr, errno %d", errno);
status = UCC_ERR_NO_RESOURCE;
goto error;
}

Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_eager),
UCC_CONFIG_TYPE_MEMUNITS},

{"MCAST_CUDA_MEM_ENABLE", "0", "Enable GPU CUDA memory support for Mcast. GPUDirect RDMA must be enabled",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.cuda_mem_enabled),
UCC_CONFIG_TYPE_BOOL},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
return UCC_ERR_NOT_SUPPORTED;
}

task = ucc_tl_mlx5_get_task(coll_args, team);

Expand Down

0 comments on commit 313f2da

Please sign in to comment.