Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into ziyue/fix-double-buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
yzygitzh authored Nov 7, 2024
2 parents 0683623 + 9526d76 commit 32777df
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 66 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

A GPU-driven communication stack for scalable AI applications.

See [Quick Start](docs/quickstart.md) to quickly get started.
See [Quick Start](docs/getting-started/quickstart.md) to quickly get started.

## Overview

Expand Down
195 changes: 130 additions & 65 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,33 @@
npkit,
)
import mscclpp.comm as mscclpp_comm
from mscclpp.utils import KernelBuilder, pack
import os
import struct

import cupy as cp
from mpi4py import MPI


def bench_time(niters: int, ngraphIters: int, func):
# capture cuda graph for niters of the kernel launch
def parse_dtype(dtype_str):
"""Convert a human-readable data type string to a numpy data type."""
dtype_str = dtype_str.strip().lower()
if dtype_str == "float16":
return cp.float16
elif dtype_str == "float32":
return cp.float32
elif dtype_str == "int32":
return cp.int32
else:
raise ValueError(f"Unknown data type: {dtype_str}")


def bench_time(n_iters: int, n_graph_iters: int, func):
# capture cuda graph for n_iters of the kernel launch
stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(niters):
for i in range(n_iters):
func(stream)
graph = stream.end_capture()

Expand All @@ -33,12 +48,58 @@ def bench_time(niters: int, ngraphIters: int, func):
end = cp.cuda.Event()

start.record(stream)
for _ in range(ngraphIters):
for _ in range(n_graph_iters):
graph.launch(stream)
end.record(stream)
end.synchronize()

return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters
return cp.cuda.get_elapsed_time(start, end) / n_iters * 1000.0 / n_graph_iters


def bench_correctness(
execution_plan_name: str,
input_buf: cp.ndarray,
result_buf: cp.ndarray,
test_buf: cp.ndarray,
dtype_str: str,
rank: int,
num_ranks: int,
n_iters: int,
func,
):
type_size = cp.dtype(parse_dtype(dtype_str)).itemsize

fill_data_kernel_name = "fill_data_%s" % dtype_str
if "allgather" in execution_plan_name:
coll = "all_gather"
else:
coll = "all_reduce"
test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str)

file_dir = os.path.dirname(os.path.abspath(__file__))
fill_data_kernel = KernelBuilder(
file="executor_test_verifier.cu", kernel_name=fill_data_kernel_name, file_dir=file_dir
).get_compiled_kernel()
test_data_kernel = KernelBuilder(
file="executor_test_verifier.cu", kernel_name=test_data_kernel_name, file_dir=file_dir
).get_compiled_kernel()
nblocks = 64
nthreads = 1024

stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(n_iters):
fill_data_params = pack(input_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(rank, i)
fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream)
func(stream)
test_data_params = (
pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, i)
)
test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream)
graph = stream.end_capture()
graph.launch(stream)
stream.synchronize()


def parse_size(size_str):
Expand All @@ -53,19 +114,6 @@ def parse_size(size_str):
return int(size_str)


def parse_dtype(dtype_str):
"""Convert a human-readable data type string to a numpy data type."""
dtype_str = dtype_str.strip().lower()
if dtype_str == "float16":
return cp.float16
elif dtype_str == "float32":
return cp.float32
elif dtype_str == "int32":
return cp.int32
else:
raise ValueError(f"Unknown data type: {dtype_str}")


def dtype_to_mscclpp_dtype(dtype):
if dtype == cp.float16:
return DataType.float16
Expand All @@ -77,23 +125,47 @@ def dtype_to_mscclpp_dtype(dtype):
raise ValueError(f"Unknown data type: {dtype}")


def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
def build_bufs(
execution_plan_name: str,
size: int,
in_place: bool,
dtype: cp.dtype,
rank: int,
num_ranks: int,
):
type_size = cp.dtype(dtype).itemsize
assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size)
nelems = size // type_size

if "allgather" in execution_plan_name:
return recvbuf
elif in_place:
return sendbuf
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
nelems_input = nelems if in_place else nelems // num_ranks
else:
return recvbuf
nelems_input = nelems
nelems_output = nelems

result_buf = cp.zeros(nelems_output, dtype=dtype)
if in_place:
if "allgather" in execution_plan_name:
input_buf = cp.split(result_buf, num_ranks)[rank]
else:
input_buf = result_buf
else:
input_buf = cp.zeros(nelems_input, dtype=dtype)
test_buf = cp.zeros(nelems_output, dtype=dtype)

return input_buf, result_buf, test_buf


def main(
execution_plan_name: str,
execution_plan_path: str,
size: int,
in_place: bool = True,
dtype: cp.dtype = cp.float16,
dtype_str: str = "float16",
packet_type: PacketType = PacketType.LL16,
seed: int = 42,
n_iters: int = 10,
n_graph_iters: int = 10,
):
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
Expand All @@ -103,57 +175,49 @@ def main(
npkit.init(mscclpp_group.my_rank)
execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)

cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = cp.zeros(nelems, dtype=dtype)
for i in range(nelems):
sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i]

if "allgather" in execution_plan_name:
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
if in_place:
for i in range(nelems):
recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i]
expected = buffer
else:
recvbuf = cp.zeros(nelems, dtype=dtype)
expected = cp.zeros_like(sendbuf, dtype=dtype)
for i in range(mscclpp_group.nranks):
expected += sub_arrays[i]
mscclpp_group.barrier()
dtype = parse_dtype(dtype_str)
input_buf, result_buf, test_buf = build_bufs(
execution_plan_name,
size,
in_place,
dtype,
mscclpp_group.my_rank,
mscclpp_group.nranks,
)

executor_func = lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
sendbuf.nbytes,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
mscclpp_group.my_rank,
input_buf.data.ptr,
result_buf.data.ptr,
input_buf.nbytes,
result_buf.nbytes,
dtype_to_mscclpp_dtype(dtype),
execution_plan,
stream.ptr,
packet_type,
)
# check correctness
stream = cp.cuda.Stream(non_blocking=True)
executor_func(stream)
stream.synchronize()

assert cp.allclose(
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name),
expected,
atol=1e-2 * mscclpp_group.nranks,
mscclpp_group.barrier()
bench_correctness(
execution_plan_name,
input_buf,
result_buf,
test_buf,
dtype_str,
mscclpp_group.my_rank,
mscclpp_group.nranks,
n_iters,
executor_func,
)

mscclpp_group.barrier()
execution_time = bench_time(10, 10, executor_func)
execution_time = bench_time(n_iters, n_graph_iters, executor_func)
if npkit_dump_dir is not None:
npkit.dump(npkit_dump_dir)
npkit.shutdown()
print(
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, "
f"data size: {result_buf.nbytes} bytes data type: {dtype().dtype.name} "
f"packet type: {packet_type}"
)
executor = None
Expand All @@ -168,21 +232,22 @@ def main(
parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation")
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--n_iters", type=int, default=10)
parser.add_argument("--n_graph_iters", type=int, default=10)
args = parser.parse_args()

packet_type = PacketType.LL16
if args.packet_type == "LL8":
packet_type = PacketType.LL8

buffer_size = parse_size(args.size)
dtype = parse_dtype(args.dtype)
main(
args.execution_plan_name,
args.execution_plan_path,
buffer_size,
args.in_place,
dtype,
args.dtype,
packet_type,
args.seed,
args.n_iters,
args.n_graph_iters,
)
75 changes: 75 additions & 0 deletions python/test/executor_test_verifier.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <assert.h>

#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

// Numerical Recipes ranqd1, Chapter 7.1, §An Even Quicker Generator, Eq. 7.1.6
// parameters from Knuth and H. W. Lewis
static __device__ unsigned int ranqd1(unsigned int seed) {
const unsigned int a = 1664525;
const unsigned int c = 1013904223;
return a * seed + c;
}

// fill/test kernel pairs must have the same thread block size to
// match their random number series.

#define FILL_DATA(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
}

FILL_DATA(float16, __half)
FILL_DATA(float32, float)
FILL_DATA(int32, int)

#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * num_elems; \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[i] == test_buf[i]); \
} \
} \
}

TEST_DATA_ALL_GATHER(float16, __half)
TEST_DATA_ALL_GATHER(float32, float)
TEST_DATA_ALL_GATHER(int32, int)

#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (rank == 0) { \
test_buf[i] = 0; \
} \
seed = ranqd1(seed); \
test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
} \
}

TEST_DATA_ALL_REDUCE(float16, __half)
TEST_DATA_ALL_REDUCE(float32, float)
TEST_DATA_ALL_REDUCE(int32, int)

0 comments on commit 32777df

Please sign in to comment.