From 73907f1c643684cb56203db02772443e4e04f591 Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Wed, 6 Nov 2024 02:15:02 +0000 Subject: [PATCH 1/2] Add kernel-based verification for executor_test --- python/test/executor_test.py | 206 +++++++++++++++++--------- python/test/executor_test_verifier.cu | 77 ++++++++++ 2 files changed, 212 insertions(+), 71 deletions(-) create mode 100644 python/test/executor_test_verifier.cu diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 5dd41a2c3..a032f0f94 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -10,18 +10,33 @@ npkit, ) import mscclpp.comm as mscclpp_comm +from mscclpp.utils import KernelBuilder, pack import os +import struct import cupy as cp from mpi4py import MPI -def bench_time(niters: int, ngraphIters: int, func): - # capture cuda graph for niters of the kernel launch +def parse_dtype(dtype_str): + """Convert a human-readable data type string to a numpy data type.""" + dtype_str = dtype_str.strip().lower() + if dtype_str == "float16": + return cp.float16 + elif dtype_str == "float32": + return cp.float32 + elif dtype_str == "int32": + return cp.int32 + else: + raise ValueError(f"Unknown data type: {dtype_str}") + + +def bench_time(n_iters: int, n_graph_iters: int, func): + # capture cuda graph for n_iters of the kernel launch stream = cp.cuda.Stream(non_blocking=True) with stream: stream.begin_capture() - for i in range(niters): + for i in range(n_iters): func(stream) graph = stream.end_capture() @@ -33,12 +48,64 @@ def bench_time(niters: int, ngraphIters: int, func): end = cp.cuda.Event() start.record(stream) - for _ in range(ngraphIters): + for _ in range(n_graph_iters): graph.launch(stream) end.record(stream) end.synchronize() - return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters + return cp.cuda.get_elapsed_time(start, end) / n_iters * 1000.0 / n_graph_iters + + +def bench_correctness( + execution_plan_name: str, + input_buf: 
cp.ndarray, + result_buf: cp.ndarray, + test_buf: cp.ndarray, + dtype_str: str, + rank: int, + num_ranks: int, + n_iters: int, + func, +): + type_size = cp.dtype(parse_dtype(dtype_str)).itemsize + + fill_data_kernel_name = "fill_data_%s" % dtype_str + if "allgather" in execution_plan_name: + coll = "all_gather" + else: + coll = "all_reduce" + test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str) + + file_dir = os.path.dirname(os.path.abspath(__file__)) + fill_data_kernel = KernelBuilder( + file="executor_test_verifier.cu", kernel_name=fill_data_kernel_name, file_dir=file_dir + ).get_compiled_kernel() + test_data_kernel = KernelBuilder( + file="executor_test_verifier.cu", kernel_name=test_data_kernel_name, file_dir=file_dir + ).get_compiled_kernel() + nblocks = 64 + nthreads = 1024 + + stream = cp.cuda.Stream(non_blocking=True) + with stream: + stream.begin_capture() + for i in range(n_iters): + fill_data_params = ( + pack(input_buf) + + struct.pack("Q", input_buf.nbytes // type_size) + + pack(rank, i) + ) + fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream) + func(stream) + test_data_params = ( + pack(result_buf, test_buf) + + struct.pack("Q", input_buf.nbytes // type_size) + + pack(num_ranks, i) + ) + test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream) + graph = stream.end_capture() + graph.launch(stream) + stream.synchronize() def parse_size(size_str): @@ -53,19 +120,6 @@ def parse_size(size_str): return int(size_str) -def parse_dtype(dtype_str): - """Convert a human-readable data type string to a numpy data type.""" - dtype_str = dtype_str.strip().lower() - if dtype_str == "float16": - return cp.float16 - elif dtype_str == "float32": - return cp.float32 - elif dtype_str == "int32": - return cp.int32 - else: - raise ValueError(f"Unknown data type: {dtype_str}") - - def dtype_to_mscclpp_dtype(dtype): if dtype == cp.float16: return DataType.float16 @@ -77,20 +131,36 @@ def dtype_to_mscclpp_dtype(dtype): 
raise ValueError(f"Unknown data type: {dtype}") -def determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name): - if "allgather" in execution_plan_name and in_place: - return recvbuf - else: - return sendbuf - +def build_bufs( + execution_plan_name: str, + size: int, + in_place: bool, + dtype: cp.dtype, + rank: int, + num_ranks: int, +): + type_size = cp.dtype(dtype).itemsize + assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size) + nelems = size // type_size -def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name): if "allgather" in execution_plan_name: - return recvbuf - elif in_place: - return sendbuf + assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks) + nelems_input = nelems if in_place else nelems // num_ranks + else: + nelems_input = nelems + nelems_output = nelems + + result_buf = cp.zeros(nelems_output, dtype=dtype) + if in_place: + if "allgather" in execution_plan_name: + input_buf = cp.split(result_buf, num_ranks)[rank] + else: + input_buf = result_buf else: - return recvbuf + input_buf = cp.zeros(nelems_input, dtype=dtype) + test_buf = cp.zeros(nelems_output, dtype=dtype) + + return input_buf, result_buf, test_buf def main( @@ -98,9 +168,10 @@ def main( execution_plan_path: str, size: int, in_place: bool = True, - dtype: cp.dtype = cp.float16, + dtype_str: str = "float16", packet_type: PacketType = PacketType.LL16, - seed: int = 42, + n_iters: int = 10, + n_graph_iters: int = 10, ): mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use() @@ -110,57 +181,49 @@ def main( npkit.init(mscclpp_group.my_rank) execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path) - cp.random.seed(seed) - nelems = size // cp.dtype(dtype).itemsize - buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype) - sub_arrays = cp.split(buffer, 
MPI.COMM_WORLD.size) - sendbuf = cp.zeros(nelems, dtype=dtype) - for i in range(nelems): - sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i] - - if "allgather" in execution_plan_name: - recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype) - if in_place: - for i in range(nelems): - recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i] - expected = buffer - else: - recvbuf = cp.zeros(nelems, dtype=dtype) - expected = cp.zeros_like(sendbuf, dtype=dtype) - for i in range(mscclpp_group.nranks): - expected += sub_arrays[i] - mscclpp_group.barrier() + dtype = parse_dtype(dtype_str) + input_buf, result_buf, test_buf = build_bufs( + execution_plan_name, + size, + in_place, + dtype, + mscclpp_group.my_rank, + mscclpp_group.nranks, + ) executor_func = lambda stream: executor.execute( - MPI.COMM_WORLD.rank, - determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, - determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, - determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, - determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, + mscclpp_group.my_rank, + input_buf.data.ptr, + result_buf.data.ptr, + input_buf.nbytes, + result_buf.nbytes, dtype_to_mscclpp_dtype(dtype), execution_plan, stream.ptr, packet_type, ) - # check correctness - stream = cp.cuda.Stream(non_blocking=True) - executor_func(stream) - stream.synchronize() - assert cp.allclose( - determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name), - expected, - atol=1e-2 * mscclpp_group.nranks, + mscclpp_group.barrier() + bench_correctness( + execution_plan_name, + input_buf, + result_buf, + test_buf, + dtype_str, + mscclpp_group.my_rank, + mscclpp_group.nranks, + n_iters, + executor_func, ) mscclpp_group.barrier() - execution_time = bench_time(10, 10, executor_func) + execution_time = bench_time(n_iters, n_graph_iters, executor_func) if npkit_dump_dir is not None: npkit.dump(npkit_dump_dir) 
npkit.shutdown() print( - f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, " - f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} " + f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, " + f"data size: {result_buf.nbytes} bytes data type: {dtype().dtype.name} " f"packet type: {packet_type}" ) executor = None @@ -175,7 +238,8 @@ def main( parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation") parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32") parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16") - parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--n_iters", type=int, default=10) + parser.add_argument("--n_graph_iters", type=int, default=10) args = parser.parse_args() packet_type = PacketType.LL16 @@ -183,13 +247,13 @@ def main( packet_type = PacketType.LL8 buffer_size = parse_size(args.size) - dtype = parse_dtype(args.dtype) main( args.execution_plan_name, args.execution_plan_path, buffer_size, args.in_place, - dtype, + args.dtype, packet_type, - args.seed, + args.n_iters, + args.n_graph_iters, ) diff --git a/python/test/executor_test_verifier.cu b/python/test/executor_test_verifier.cu new file mode 100644 index 000000000..83975ac29 --- /dev/null +++ b/python/test/executor_test_verifier.cu @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include <assert.h> + +#if defined(__HIP_PLATFORM_AMD__) +#include <hip/hip_fp16.h> +#else +#include <cuda_fp16.h> +#endif + +// Numerical Recipes ranqd1, Chapter 7.1, §An Even Quicker Generator, Eq. 7.1.6 +// parameters from Knuth and H. W. Lewis +static __device__ unsigned int ranqd1(unsigned int seed) { + const unsigned int a = 1664525; + const unsigned int c = 1013904223; + return a * seed + c; +} + +// fill/test kernel pairs must have the same thread block size to +// match their random number series.
+ +#define FILL_DATA(FuncNameType, DataType) \ +extern "C" __global__ void __launch_bounds__(1024, 1) \ + fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + seed = ranqd1(seed); \ + input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ + } \ +} + +FILL_DATA(float16, __half) +FILL_DATA(float32, float) +FILL_DATA(int32, int) + +#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \ +extern "C" __global__ void __launch_bounds__(1024, 1) \ + test_data_all_gather_##FuncNameType(DataType* result_buf, DataType* test_buf, \ + size_t num_elems, int num_ranks, int seq) { \ + for (int rank = 0; rank < num_ranks; rank++) { \ + size_t rank_offset = rank * num_elems; \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + seed = ranqd1(seed); \ + test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ + assert(result_buf[i] == test_buf[i]); \ + } \ + } \ +} + +TEST_DATA_ALL_GATHER(float16, __half) +TEST_DATA_ALL_GATHER(float32, float) +TEST_DATA_ALL_GATHER(int32, int) + +#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \ +extern "C" __global__ void __launch_bounds__(1024, 1) \ + test_data_all_reduce_##FuncNameType(DataType* result_buf, DataType* test_buf, \ + size_t num_elems, int num_ranks, int seq) { \ + for (int rank = 0; rank < num_ranks; rank++) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + if (rank == 0) { \ + test_buf[i] = 0; \ + } \ + seed = ranqd1(seed); \ + test_buf[i] += DataType(seed % 
blockDim.x) / DataType(blockDim.x); \ + } \ + } \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \ + } \ +} + +TEST_DATA_ALL_REDUCE(float16, __half) +TEST_DATA_ALL_REDUCE(float32, float) +TEST_DATA_ALL_REDUCE(int32, int) From bf809a6147d3d4b04df8f22b8d44d089ad9af50f Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Wed, 6 Nov 2024 02:58:50 +0000 Subject: [PATCH 2/2] Fix lint --- python/test/executor_test.py | 10 +--- python/test/executor_test_verifier.cu | 86 +++++++++++++-------------- 2 files changed, 44 insertions(+), 52 deletions(-) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index a032f0f94..d0cda18a3 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -90,17 +90,11 @@ def bench_correctness( with stream: stream.begin_capture() for i in range(n_iters): - fill_data_params = ( - pack(input_buf) - + struct.pack("Q", input_buf.nbytes // type_size) - + pack(rank, i) - ) + fill_data_params = pack(input_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(rank, i) fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream) func(stream) test_data_params = ( - pack(result_buf, test_buf) - + struct.pack("Q", input_buf.nbytes // type_size) - + pack(num_ranks, i) + pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, i) ) test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream) graph = stream.end_capture() diff --git a/python/test/executor_test_verifier.cu b/python/test/executor_test_verifier.cu index 83975ac29..d99b93611 100644 --- a/python/test/executor_test_verifier.cu +++ b/python/test/executor_test_verifier.cu @@ -12,65 +12,63 @@ // Numerical Recipes ranqd1, Chapter 7.1, §An Even Quicker Generator, Eq. 7.1.6 // parameters from Knuth and H. W. 
Lewis static __device__ unsigned int ranqd1(unsigned int seed) { - const unsigned int a = 1664525; - const unsigned int c = 1013904223; - return a * seed + c; + const unsigned int a = 1664525; + const unsigned int c = 1013904223; + return a * seed + c; } // fill/test kernel pairs must have the same thread block size to // match their random number series. -#define FILL_DATA(FuncNameType, DataType) \ -extern "C" __global__ void __launch_bounds__(1024, 1) \ - fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \ - unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ - seed = ranqd1(seed); \ - input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ - } \ -} +#define FILL_DATA(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) \ + fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + seed = ranqd1(seed); \ + input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ + } \ + } FILL_DATA(float16, __half) FILL_DATA(float32, float) FILL_DATA(int32, int) -#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \ -extern "C" __global__ void __launch_bounds__(1024, 1) \ - test_data_all_gather_##FuncNameType(DataType* result_buf, DataType* test_buf, \ - size_t num_elems, int num_ranks, int seq) { \ - for (int rank = 0; rank < num_ranks; rank++) { \ - size_t rank_offset = rank * num_elems; \ - unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ - seed = ranqd1(seed); \ - 
test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ - assert(result_buf[i] == test_buf[i]); \ - } \ - } \ -} +#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \ + DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \ + for (int rank = 0; rank < num_ranks; rank++) { \ + size_t rank_offset = rank * num_elems; \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + seed = ranqd1(seed); \ + test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \ + assert(result_buf[rank_offset + i] == test_buf[rank_offset + i]); /* compare the slot of this rank, not only slot 0 */ \ + } \ + } \ + } TEST_DATA_ALL_GATHER(float16, __half) TEST_DATA_ALL_GATHER(float32, float) TEST_DATA_ALL_GATHER(int32, int) -#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \ -extern "C" __global__ void __launch_bounds__(1024, 1) \ - test_data_all_reduce_##FuncNameType(DataType* result_buf, DataType* test_buf, \ - size_t num_elems, int num_ranks, int seq) { \ - for (int rank = 0; rank < num_ranks; rank++) { \ - unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ - if (rank == 0) { \ - test_buf[i] = 0; \ - } \ - seed = ranqd1(seed); \ - test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \ - } \ - } \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ - assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \ - } \ -} +#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \ + extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \ + DataType* result_buf, DataType* test_buf, size_t num_elems,
int num_ranks, int seq) { \ + for (int rank = 0; rank < num_ranks; rank++) { \ + unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + if (rank == 0) { \ + test_buf[i] = 0; \ + } \ + seed = ranqd1(seed); \ + test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \ + } \ + } \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \ + assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \ + } \ + } TEST_DATA_ALL_REDUCE(float16, __half) TEST_DATA_ALL_REDUCE(float32, float)