Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Oct 16, 2023
1 parent 17db0b3 commit ee3c2a7
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 16 deletions.
Empty file added python/benchmark/__init__.py
Empty file.
11 changes: 5 additions & 6 deletions python/test/allreduce1.cu → python/benchmark/allreduce1.cu
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <mscclpp/sm_channel_device.hpp>
#include <mscclpp/concurrency.hpp>
#include <mscclpp/sm_channel_device.hpp>

__device__ mscclpp::DeviceSyncer deviceSyncer;
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;


__device__ void localReduceScatterSm2(mscclpp::SmChannelDeviceHandle* smChans, int* buff, int* scratch, int rank, int nRanksPerNode, size_t chunkSize,
size_t nelems, int nBlocks) {
__device__ void localReduceScatterSm2(mscclpp::SmChannelDeviceHandle* smChans, int* buff, int* scratch, int rank,
int nRanksPerNode, size_t chunkSize, size_t nelems, int nBlocks) {
if (nRanksPerNode == 1) return;
if (blockIdx.x >= nBlocks) return;
const int nPeer = nRanksPerNode - 1;
Expand Down Expand Up @@ -53,7 +52,8 @@ __device__ void localReduceScatterSm2(mscclpp::SmChannelDeviceHandle* smChans, i
}
}

__device__ void localRingAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode, uint64_t size, size_t nBlocks) {
__device__ void localRingAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, int rank, int nRanksPerNode,
uint64_t size, size_t nBlocks) {
if (nRanksPerNode == 1) return;
if (blockIdx.x >= nBlocks) return;

Expand All @@ -76,7 +76,6 @@ __device__ void localRingAllGatherSm(mscclpp::SmChannelDeviceHandle* smChans, in
}
}


// be careful about using channels[my_rank] as it is invalid; it is there just for simplicity of indexing
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce1(mscclpp::SmChannelDeviceHandle* smChans, int* buff, int rank, int nranks, int nelems) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from mscclpp_group import MscclppGroup
import os
from test.mscclpp_group import MscclppGroup
import cupy as cp
from mscclpp_mpi import MpiGroup
from utils import KernelBuilder, pack
from test.mscclpp_mpi import MpiGroup
from test.utils import KernelBuilder, pack
from mscclpp import Transport
from mpi4py import MPI
from prettytable import PrettyTable
Expand Down Expand Up @@ -33,7 +34,8 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):

# create a sm_channel for each remote neighbor
sm_channels = group.make_sm_channels(memory, connections)
kernel = KernelBuilder(file="allreduce1.cu", kernel_name="allreduce1").get_compiled_kernel()
file_dir = os.path.dirname(os.path.abspath(__file__))
kernel = KernelBuilder(file="allreduce1.cu", kernel_name="allreduce1", file_dir=file_dir).get_compiled_kernel()
params = b""
device_handles = []
for rank in range(group.nranks):
Expand Down Expand Up @@ -76,7 +78,7 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):


if __name__ == "__main__":

# Create a table
table = PrettyTable()

Expand All @@ -86,6 +88,6 @@ def benchmark(table: PrettyTable, niter: int, nelem: int):

for i in range(10,28):
benchmark(table, 1000, 2**i)

if MPI.COMM_WORLD.rank == 0:
print(table)
print(table)
2 changes: 1 addition & 1 deletion python/test/mscclpp_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
import numpy as np

from mscclpp_mpi import MpiGroup
from .mscclpp_mpi import MpiGroup

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions python/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ def __del__(self):
class KernelBuilder:
kernel_map: dict = {}

def __init__(self, file: str, kernel_name: str):
def __init__(self, file: str, kernel_name: str, file_dir: str = None):
if kernel_name in self.kernel_map:
self._kernel = self.kernel_map[kernel_name]
return
self._tempdir = tempfile.TemporaryDirectory(suffix=f"{os.getpid()}")
self._current_file_dir = os.path.dirname(os.path.abspath(__file__))
self._current_file_dir = file_dir if file_dir else os.path.dirname(os.path.abspath(__file__))
device_id = cp.cuda.Device().id
ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx", device_id)
self._kernel = Kernel(ptx, kernel_name, device_id)
Expand Down

0 comments on commit ee3c2a7

Please sign in to comment.