diff --git a/Makefile b/Makefile index cbbd9d0..d38d1fb 100644 --- a/Makefile +++ b/Makefile @@ -6,8 +6,8 @@ USE_ILUVATAR_COREX ?= 0 USE_CAMBRICON ?= 0 # set to empty if not provided -DEVICE_HOME ?= -CCL_HOME ?= +DEVICE_HOME ?= +CCL_HOME ?= ifeq ($(strip $(DEVICE_HOME)),) ifeq ($(USE_NVIDIA), 1) @@ -15,7 +15,7 @@ ifeq ($(strip $(DEVICE_HOME)),) else ifeq ($(USE_ILUVATAR_COREX), 1) DEVICE_HOME = /usr/local/corex else ifeq ($(USE_CAMBRICON), 1) - DEVICE_HOME = /torch/neuware_home + DEVICE_HOME = $(NEUWARE_HOME) else DEVICE_HOME = /usr/local/cuda endif @@ -27,7 +27,7 @@ ifeq ($(strip $(CCL_HOME)),) else ifeq ($(USE_ILUVATAR_COREX), 1) CCL_HOME = /usr/local/corex else ifeq ($(USE_CAMBRICON), 1) - CCL_HOME = /torch/neuware_home + CCL_HOME = $(NEUWARE_HOME) else CCL_HOME = /usr/local/nccl/build endif diff --git a/flagcx/adaptor/cncl_adaptor.cc b/flagcx/adaptor/cncl_adaptor.cc index 03a4942..24603ec 100644 --- a/flagcx/adaptor/cncl_adaptor.cc +++ b/flagcx/adaptor/cncl_adaptor.cc @@ -58,8 +58,11 @@ flagcxResult_t cnclAdaptorCommInitRank(flagcxHomoComm_t *comm, int nranks, flagc if (*comm == NULL) { flagcxCalloc(comm, 1); } - return (flagcxResult_t)cnclInitComms(&(*comm)->base, 1/*num_comm*/, &rank/*dev_list*/, - &rank/*rank_list*/, nranks, (cnclCliqueId *)commId); + unsigned int device_count = 0; + DEVCHECK(cnrtGetDeviceCount(&device_count)); + int dev_id = rank % device_count; + return (flagcxResult_t)c2f_ret_map[cnclInitComms(&(*comm)->base, 1/*num_comm*/, &dev_id/*dev_list*/, + &rank/*rank_list*/, nranks, (cnclCliqueId *)commId)]; } //TODO: unsupported diff --git a/flagcx/adaptor/mlu_adaptor.cc b/flagcx/adaptor/mlu_adaptor.cc index 5cef4f4..e7b6cda 100644 --- a/flagcx/adaptor/mlu_adaptor.cc +++ b/flagcx/adaptor/mlu_adaptor.cc @@ -17,7 +17,7 @@ flagcxResult_t mluAdaptorDeviceMemcpy(void *dst, void *src, size_t size, flagcxM if (stream == NULL) { DEVCHECK(cnrtMemcpy(dst, src, size, memcpy_type_map[type])); } else { - DEVCHECK(cnrtMemcpyAsync_V3(dst, src, size, stream->base, memcpy_type_map[type])); + DEVCHECK(cnrtMemcpyAsync_V2(dst, src, size, stream->base, memcpy_type_map[type])); } return flagcxSuccess; } @@ -73,12 +73,18 @@ flagcxResult_t mluAdaptorGetVendor(char *vendor) { } flagcxResult_t mluAdaptorGdrMemAlloc(void **ptr, size_t size, void *memHandle) { - // TODO: Implement GDR memory allocation + if (ptr == NULL) { + return flagcxInvalidArgument; + } + DEVCHECK(cnrtMalloc(ptr, size)); return flagcxSuccess; } flagcxResult_t mluAdaptorGdrMemFree(void *ptr, void *memHandle) { - // TODO: Implement GDR memory free + if (ptr == NULL) { + return flagcxSuccess; + } + DEVCHECK(cnrtFree(ptr)); return flagcxSuccess; }