Skip to content

Commit

Permalink
Merge branch 'main' into binyli/nccl
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 authored Dec 1, 2024
2 parents 2b10c92 + ff18bb8 commit 0faade2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
6 changes: 4 additions & 2 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def bench_correctness(
fill_data_kernel_name = "fill_data_%s" % dtype_str
if "allgather" in execution_plan_name:
coll = "all_gather"
elif "reducescatter" in execution_plan_name:
coll = "reduce_scatter"
else:
coll = "all_reduce"
test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str)
Expand All @@ -96,7 +98,7 @@ def bench_correctness(
fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream)
func(stream)
test_data_params = (
pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, i)
pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, rank, i)
)
test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream)
graph = stream.end_capture()
Expand Down Expand Up @@ -128,7 +130,7 @@ def dtype_to_mscclpp_dtype(dtype):


def allocate_buffer(nelems, dtype):
if is_nvls_supported:
if is_nvls_supported():
buffer_raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize)
buffer_ptr = cp.cuda.MemoryPointer(
cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0
Expand Down
30 changes: 28 additions & 2 deletions python/test/executor_test_verifier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ FILL_DATA(int32, int)

#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * num_elems; \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
Expand All @@ -54,7 +54,7 @@ TEST_DATA_ALL_GATHER(int32, int)

#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
Expand All @@ -73,3 +73,29 @@ TEST_DATA_ALL_GATHER(int32, int)
TEST_DATA_ALL_REDUCE(float16, __half)
TEST_DATA_ALL_REDUCE(float32, float)
TEST_DATA_ALL_REDUCE(int32, int)

#define TEST_DATA_REDUCE_SCATTER(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_reduce_scatter_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
int nem_elems_per_rank = num_elems / num_ranks; \
int offset = nem_elems_per_rank * my_rank; \
for (int rank = 0; rank < num_ranks; rank++) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (rank == 0) { \
test_buf[i] = 0; \
} \
seed = ranqd1(seed); \
test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (i >= offset && i < offset + nem_elems_per_rank) { \
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
} \
} \
}

TEST_DATA_REDUCE_SCATTER(float16, __half)
TEST_DATA_REDUCE_SCATTER(float32, float)
TEST_DATA_REDUCE_SCATTER(int32, int)

0 comments on commit 0faade2

Please sign in to comment.