Commit

disable IB for now

chhwang committed Feb 21, 2024
1 parent 4b3e27c commit 5bf684b
Showing 4 changed files with 80 additions and 12 deletions.
2 changes: 1 addition & 1 deletion apps/nccl/CMakeLists.txt
@@ -12,7 +12,7 @@ add_library(mscclpp_nccl_obj OBJECT)
target_sources(mscclpp_nccl_obj PRIVATE ${SOURCES})
target_sources(mscclpp_nccl_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
target_include_directories(mscclpp_nccl_obj PRIVATE ${GPU_INCLUDE_DIRS} include)
-target_link_libraries(mscclpp_nccl_obj PRIVATE ${GPU_LIBRARIES} mscclpp_obj)
+target_link_libraries(mscclpp_nccl_obj PRIVATE ${GPU_LIBRARIES} PUBLIC mscclpp_obj)
set_target_properties(mscclpp_nccl_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
if(USE_CUDA)
target_compile_definitions(mscclpp_nccl_obj PRIVATE USE_CUDA)
14 changes: 14 additions & 0 deletions apps/nccl/README.md
@@ -0,0 +1,14 @@
# NCCL Interfaces of MSCCL++

Compile

```bash
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_APPS_NCCL=ON -DBUILD_PYTHON_BINDINGS=OFF ..
make -j
```

Run rccl-tests

```bash
mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD="$MSCCLPP_BUILD/libmscclpp.so $MSCCLPP_BUILD/apps/nccl/libmscclpp_nccl.so" -x MSCCLPP_DEBUG=WARN -x MSCCLPP_DEBUG_SUBSYS=ALL -x NCCL_DEBUG=WARN ./build/all_reduce_perf -b 1K -e 256M -f 2 -d half -G 20 -w 10 -n 50
```
53 changes: 53 additions & 0 deletions apps/nccl/rccl_test.py
@@ -0,0 +1,53 @@
import os
from mpi4py import MPI
import torch
from cupy.cuda import nccl

ROOT_RANK = 0
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

is_group_root = rank == ROOT_RANK

world_size = comm.Get_size()

os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

device_type = "cuda"
torch.cuda.set_device(0)
device_index = 0
device = torch.device(type=device_type, index=device_index)

if is_group_root:
    id_ = nccl.get_unique_id()
else:
    id_ = None

ranks = range(world_size)
id_, ranks = comm.bcast((id_, ranks), root=0)
group = nccl.NcclCommunicator(len(ranks), id_, rank)
print(f"{rank=}, {device=}, {group=}")

M = 1024
N = 4096
K = 2048
shape_a = (M, K)
shape_b = (K, N)
shape_c = (M, N)

a = torch.ones(shape_a, device="cuda")
b = torch.ones(shape_b, device="cuda")
c = torch.mm(a, b)

print(c)

nccl_op = nccl.NCCL_SUM
group.allReduce(
    sendbuf=c.data_ptr(),
    recvbuf=c.data_ptr(),
    count=c.nelement(),
    datatype=nccl.NCCL_FLOAT,
    op=nccl_op,
    stream=torch.cuda.current_stream().cuda_stream)

print(c)
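
This commit does not document how to launch the new script; below is a minimal sketch of running `rccl_test.py` under the MSCCL++ NCCL interposer, reusing the `LD_PRELOAD` setup from the README above (the `python3` invocation, process count, and script path are assumptions, not part of this commit):

```bash
# Hypothetical launch command: reuses the interposer libraries and mpirun flags
# shown in the README; adjust paths and the rank count for your machine.
mpirun -np 8 --bind-to numa --allow-run-as-root \
  -x LD_PRELOAD="$MSCCLPP_BUILD/libmscclpp.so $MSCCLPP_BUILD/apps/nccl/libmscclpp_nccl.so" \
  -x MSCCLPP_DEBUG=WARN \
  python3 apps/nccl/rccl_test.py
```
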
23 changes: 12 additions & 11 deletions apps/nccl/src/nccl.cu
@@ -139,9 +139,9 @@ static const int nRanksPerNode = 8;
// Only use scratch buffer for message size less than 1MB
static const int scratchSize = 1024 * 1024 * 8;

-static const mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
-                                         mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
-                                         mscclpp::Transport::IB6, mscclpp::Transport::IB7};
+// static const mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2,
+//                                          mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
+//                                          mscclpp::Transport::IB6, mscclpp::Transport::IB7};

__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmChannels[8];
__constant__ mscclpp::DeviceHandle<mscclpp::SmChannel> constSmOutChannels[8];
@@ -444,11 +444,12 @@ static size_t ncclTypeSize(ncclDataType_t type) {
}

static mscclpp::Transport getTransport(int rank, int peerRank) {
-  if (rank / nRanksPerNode == peerRank / nRanksPerNode) {
-    return mscclpp::Transport::CudaIpc;
-  } else {
-    return IBs[rank % nRanksPerNode];
-  }
+  // if (rank / nRanksPerNode == peerRank / nRanksPerNode) {
+  //   return mscclpp::Transport::CudaIpc;
+  // } else {
+  //   return IBs[rank % nRanksPerNode];
+  // }
+  return mscclpp::Transport::CudaIpc;
}

static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
@@ -537,7 +538,7 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
// using scratch buffer for message size less than 1MB
commPtr->scratchBuff = mscclpp::allocExtSharedCuda<char>(scratchSize);
commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), scratchSize,
-                                                         mscclpp::Transport::CudaIpc | IBs[rank % nRanksPerNode]);
+                                                         mscclpp::Transport::CudaIpc);

*comm = commPtr;
return ncclSuccess;
@@ -665,12 +666,12 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
if (it == comm->smChannels.end()) {
std::vector<mscclpp::RegisteredMemory> remoteMemories =
setupRemoteMemories(comm->comm, rank, const_cast<void*>(sendbuff), bytes,
-                            mscclpp::Transport::CudaIpc | IBs[rank % nRanksPerNode]);
+                            mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels = setupSmChannels(comm, remoteMemories, const_cast<void*>(sendbuff));
it = comm->smChannels.emplace(key, channels).first;
if (sendbuff != recvbuff) {
std::vector<mscclpp::RegisteredMemory> remoteMemories =
-            setupRemoteMemories(comm->comm, rank, recvbuff, bytes, mscclpp::Transport::CudaIpc | IBs[rank % nRanksPerNode]);
+            setupRemoteMemories(comm->comm, rank, recvbuff, bytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> outChannels = setupSmChannels(comm, remoteMemories, recvbuff);
outIt = comm->smOutChannels.emplace(key, outChannels).first;
}
