Add kernel-based verification for executor_test #378

Merged · 2 commits · Nov 7, 2024
200 changes: 129 additions & 71 deletions python/test/executor_test.py
@@ -10,18 +10,33 @@
npkit,
)
import mscclpp.comm as mscclpp_comm
from mscclpp.utils import KernelBuilder, pack
import os
import struct

import cupy as cp
from mpi4py import MPI


def bench_time(niters: int, ngraphIters: int, func):
# capture cuda graph for niters of the kernel launch
def parse_dtype(dtype_str):
"""Convert a human-readable data type string to a numpy data type."""
dtype_str = dtype_str.strip().lower()
if dtype_str == "float16":
return cp.float16
elif dtype_str == "float32":
return cp.float32
elif dtype_str == "int32":
return cp.int32
else:
raise ValueError(f"Unknown data type: {dtype_str}")


def bench_time(n_iters: int, n_graph_iters: int, func):
# capture cuda graph for n_iters of the kernel launch
stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(niters):
for i in range(n_iters):
func(stream)
graph = stream.end_capture()

@@ -33,12 +48,58 @@ def bench_time(niters: int, ngraphIters: int, func):
end = cp.cuda.Event()

start.record(stream)
for _ in range(ngraphIters):
for _ in range(n_graph_iters):
graph.launch(stream)
end.record(stream)
end.synchronize()

return cp.cuda.get_elapsed_time(start, end) / niters * 1000.0 / ngraphIters
return cp.cuda.get_elapsed_time(start, end) / n_iters * 1000.0 / n_graph_iters
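
bench_time captures n_iters launches of func into a single CUDA graph and replays the graph n_graph_iters times, so per-launch CPU overhead is amortized; the return value is microseconds per launch (elapsed_ms / n_iters * 1000 / n_graph_iters). A minimal usage sketch, not part of this PR; the buffer and workload below are hypothetical:

```python
# Illustration only: timing an arbitrary stream-ordered workload with bench_time.
import cupy as cp

buf = cp.zeros(1 << 20, dtype=cp.float32)
# func is invoked while bench_time's capture stream is current, so stream-ordered
# CuPy calls are recorded into the graph; the explicit stream argument can be ignored.
workload = lambda stream: cp.add(buf, 1, out=buf)

# elapsed_us = bench_time(10, 10, workload)
```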


def bench_correctness(
execution_plan_name: str,
input_buf: cp.ndarray,
result_buf: cp.ndarray,
test_buf: cp.ndarray,
dtype_str: str,
rank: int,
num_ranks: int,
n_iters: int,
func,
):
type_size = cp.dtype(parse_dtype(dtype_str)).itemsize

fill_data_kernel_name = "fill_data_%s" % dtype_str
if "allgather" in execution_plan_name:
coll = "all_gather"
else:
coll = "all_reduce"
test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str)

file_dir = os.path.dirname(os.path.abspath(__file__))
fill_data_kernel = KernelBuilder(
file="executor_test_verifier.cu", kernel_name=fill_data_kernel_name, file_dir=file_dir
).get_compiled_kernel()
test_data_kernel = KernelBuilder(
file="executor_test_verifier.cu", kernel_name=test_data_kernel_name, file_dir=file_dir
).get_compiled_kernel()
nblocks = 64
nthreads = 1024

stream = cp.cuda.Stream(non_blocking=True)
with stream:
stream.begin_capture()
for i in range(n_iters):
fill_data_params = pack(input_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(rank, i)
fill_data_kernel.launch_kernel(fill_data_params, nblocks, nthreads, 0, stream)
func(stream)
test_data_params = (
pack(result_buf, test_buf) + struct.pack("Q", input_buf.nbytes // type_size) + pack(num_ranks, i)
)
test_data_kernel.launch_kernel(test_data_params, nblocks, nthreads, 0, stream)
graph = stream.end_capture()
graph.launch(stream)
stream.synchronize()
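
bench_correctness replaces the old host-side `expected` comparison: every captured iteration first launches fill_data_<dtype> to overwrite the input with a pseudo-random pattern keyed by (rank, iteration), runs the executor, then launches test_data_<collective>_<dtype>, which regenerates the expected values on the device and asserts against result_buf. The sketch below (not part of the PR) shows how the packed argument buffer lines up with the fill kernel's signature, under the assumption that mscclpp.utils.pack emits an 8-byte pointer per array and a 4-byte int per Python integer, consistent with how this test uses it:

```python
# Illustrative check of the fill_data_<dtype> argument layout (assumptions noted above).
import struct
import cupy as cp
from mscclpp.utils import pack

buf = cp.zeros(1024, dtype=cp.float16)
rank, iteration = 0, 0
params = pack(buf) + struct.pack("Q", buf.size) + pack(rank, iteration)
# Matches fill_data_float16(__half* input_buf, size_t num_elems, int rank, int seq):
# 8-byte pointer + 8-byte size_t + 4-byte int + 4-byte int = 24 bytes on a 64-bit
# build, if the packing assumptions above hold.
print(len(params))
```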


def parse_size(size_str):
@@ -53,19 +114,6 @@ def parse_size(size_str):
return int(size_str)


def parse_dtype(dtype_str):
"""Convert a human-readable data type string to a numpy data type."""
dtype_str = dtype_str.strip().lower()
if dtype_str == "float16":
return cp.float16
elif dtype_str == "float32":
return cp.float32
elif dtype_str == "int32":
return cp.int32
else:
raise ValueError(f"Unknown data type: {dtype_str}")


def dtype_to_mscclpp_dtype(dtype):
if dtype == cp.float16:
return DataType.float16
@@ -77,30 +125,47 @@ def dtype_to_mscclpp_dtype(dtype):
raise ValueError(f"Unknown data type: {dtype}")


def determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name):
if "allgather" in execution_plan_name and in_place:
return recvbuf
else:
return sendbuf

def build_bufs(
execution_plan_name: str,
size: int,
in_place: bool,
dtype: cp.dtype,
rank: int,
num_ranks: int,
):
type_size = cp.dtype(dtype).itemsize
assert (size % type_size) == 0, "size %d not multiple of type size %d" % (size, type_size)
nelems = size // type_size

def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
if "allgather" in execution_plan_name:
return recvbuf
elif in_place:
return sendbuf
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
nelems_input = nelems if in_place else nelems // num_ranks
else:
return recvbuf
nelems_input = nelems
nelems_output = nelems

result_buf = cp.zeros(nelems_output, dtype=dtype)
if in_place:
if "allgather" in execution_plan_name:
input_buf = cp.split(result_buf, num_ranks)[rank]
else:
input_buf = result_buf
else:
input_buf = cp.zeros(nelems_input, dtype=dtype)
test_buf = cp.zeros(nelems_output, dtype=dtype)

return input_buf, result_buf, test_buf
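
build_bufs encodes the buffer-shape rules per collective: `size` is always the output size; for allgather the in-place input is a view of this rank's chunk of result_buf (via cp.split), while for allreduce the in-place input aliases the whole result_buf. A worked sizing example, not part of the PR; the plan name below is hypothetical and only its "allgather" substring matters:

```python
# Illustrative sizing check: 1 MiB in-place allgather, float16, 8 ranks.
input_buf, result_buf, test_buf = build_bufs(
    execution_plan_name="allgather_example_plan",  # hypothetical name
    size=1 << 20,          # 1 MiB output -> 524288 float16 elements
    in_place=True,
    dtype=cp.float16,
    rank=3,
    num_ranks=8,
)
assert result_buf.size == (1 << 20) // 2          # full output buffer
assert input_buf.size == result_buf.size // 8     # this rank's 65536-element chunk
assert input_buf.data.ptr == result_buf[3 * input_buf.size:].data.ptr  # view, not a copy
```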


def main(
execution_plan_name: str,
execution_plan_path: str,
size: int,
in_place: bool = True,
dtype: cp.dtype = cp.float16,
dtype_str: str = "float16",
packet_type: PacketType = PacketType.LL16,
seed: int = 42,
n_iters: int = 10,
n_graph_iters: int = 10,
):
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
@@ -110,57 +175,49 @@ def main(
npkit.init(mscclpp_group.my_rank)
execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path)

cp.random.seed(seed)
nelems = size // cp.dtype(dtype).itemsize
buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype)
sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size)
sendbuf = cp.zeros(nelems, dtype=dtype)
for i in range(nelems):
sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i]

if "allgather" in execution_plan_name:
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
if in_place:
for i in range(nelems):
recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i]
expected = buffer
else:
recvbuf = cp.zeros(nelems, dtype=dtype)
expected = cp.zeros_like(sendbuf, dtype=dtype)
for i in range(mscclpp_group.nranks):
expected += sub_arrays[i]
mscclpp_group.barrier()
dtype = parse_dtype(dtype_str)
input_buf, result_buf, test_buf = build_bufs(
execution_plan_name,
size,
in_place,
dtype,
mscclpp_group.my_rank,
mscclpp_group.nranks,
)

executor_func = lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
mscclpp_group.my_rank,
input_buf.data.ptr,
result_buf.data.ptr,
input_buf.nbytes,
result_buf.nbytes,
dtype_to_mscclpp_dtype(dtype),
execution_plan,
stream.ptr,
packet_type,
)
# check correctness
stream = cp.cuda.Stream(non_blocking=True)
executor_func(stream)
stream.synchronize()

assert cp.allclose(
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name),
expected,
atol=1e-2 * mscclpp_group.nranks,
mscclpp_group.barrier()
bench_correctness(
execution_plan_name,
input_buf,
result_buf,
test_buf,
dtype_str,
mscclpp_group.my_rank,
mscclpp_group.nranks,
n_iters,
executor_func,
)

mscclpp_group.barrier()
execution_time = bench_time(10, 10, executor_func)
execution_time = bench_time(n_iters, n_graph_iters, executor_func)
if npkit_dump_dir is not None:
npkit.dump(npkit_dump_dir)
npkit.shutdown()
print(
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, "
f"data size: {result_buf.nbytes} bytes data type: {dtype().dtype.name} "
f"packet type: {packet_type}"
)
executor = None
@@ -175,21 +232,22 @@ def main(
parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation")
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--n_iters", type=int, default=10)
parser.add_argument("--n_graph_iters", type=int, default=10)
args = parser.parse_args()

packet_type = PacketType.LL16
if args.packet_type == "LL8":
packet_type = PacketType.LL8

buffer_size = parse_size(args.size)
dtype = parse_dtype(args.dtype)
main(
args.execution_plan_name,
args.execution_plan_path,
buffer_size,
args.in_place,
dtype,
args.dtype,
packet_type,
args.seed,
args.n_iters,
args.n_graph_iters,
)
75 changes: 75 additions & 0 deletions python/test/executor_test_verifier.cu
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <assert.h>

#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

// Numerical Recipes ranqd1, Chapter 7.1, §An Even Quicker Generator, Eq. 7.1.6
// parameters from Knuth and H. W. Lewis
static __device__ unsigned int ranqd1(unsigned int seed) {
const unsigned int a = 1664525;
const unsigned int c = 1013904223;
return a * seed + c;
}
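
ranqd1 is the Numerical Recipes "quick and dirty" linear congruential generator: next = 1664525 * seed + 1013904223 (mod 2^32), the modulus coming for free from unsigned int overflow. Because each call is a pure function of the previous seed, the fill and test kernels can regenerate identical sequences independently as long as they start from the same seed and use the same thread geometry. A host-side equivalent for cross-checking, illustrative and not part of the PR:

```python
# Same LCG step, with the 32-bit truncation made explicit.
def ranqd1(seed: int) -> int:
    return (1664525 * seed + 1013904223) & 0xFFFFFFFF

# The device then maps the state into [0, 1): value = (ranqd1(seed) % blockDim_x) / blockDim_x
```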

// fill/test kernel pairs must have the same thread block size to
// match their random number series.

#define FILL_DATA(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) \
fill_data_##FuncNameType(DataType* input_buf, size_t num_elems, int rank, int seq) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
input_buf[i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
}

FILL_DATA(float16, __half)
FILL_DATA(float32, float)
FILL_DATA(int32, int)
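
Each filling thread seeds its generator with global_thread_id + rank + seq and chains it across its grid-stride loop, so every element gets a deterministic value in [0, 1) that depends only on its position, the rank, the iteration, and the launch geometry. Below is a NumPy model of that pattern, assuming the same 64x1024 launch used in bench_correctness; it is exact for float32, while float16 and int32 additionally round or truncate to the device type:

```python
# Illustrative host model of fill_data_<dtype> (not part of the PR).
import numpy as np

def host_fill(num_elems, rank, seq, nblocks=64, nthreads=1024):
    out = np.zeros(num_elems, dtype=np.float32)
    stride = nblocks * nthreads                      # blockDim.x * gridDim.x
    for tid in range(min(stride, num_elems)):        # one chained LCG per thread
        seed = (tid + rank + seq) & 0xFFFFFFFF
        for i in range(tid, num_elems, stride):      # grid-stride loop
            seed = (1664525 * seed + 1013904223) & 0xFFFFFFFF
            out[i] = (seed % nthreads) / nthreads    # (seed % blockDim.x) / blockDim.x
    return out
```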

#define TEST_DATA_ALL_GATHER(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_gather_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
size_t rank_offset = rank * num_elems; \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
seed = ranqd1(seed); \
test_buf[rank_offset + i] = DataType(seed % blockDim.x) / DataType(blockDim.x); \
assert(result_buf[rank_offset + i] == test_buf[rank_offset + i]); \
} \
} \
}

TEST_DATA_ALL_GATHER(float16, __half)
TEST_DATA_ALL_GATHER(float32, float)
TEST_DATA_ALL_GATHER(int32, int)

#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int seq) { \
for (int rank = 0; rank < num_ranks; rank++) { \
unsigned int seed = (unsigned int)(blockIdx.x * blockDim.x + threadIdx.x + rank + seq); \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (rank == 0) { \
test_buf[i] = 0; \
} \
seed = ranqd1(seed); \
test_buf[i] += DataType(seed % blockDim.x) / DataType(blockDim.x); \
} \
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
} \
}

TEST_DATA_ALL_REDUCE(float16, __half)
TEST_DATA_ALL_REDUCE(float32, float)
TEST_DATA_ALL_REDUCE(int32, int)
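
The allreduce checker accumulates the regenerated pattern of every rank into test_buf and accepts an absolute error of 1e-3 * num_ranks, which leaves room for half-precision accumulation-order differences. The same check expressed on the host, reusing host_fill from the sketch above; `result` is assumed to be a NumPy copy of result_buf (e.g. via cp.asnumpy):

```python
# Illustrative host-side version of test_data_all_reduce_<dtype> (not part of the PR).
import numpy as np

def host_check_all_reduce(result, num_ranks, seq, atol_scale=1e-3):
    # Every rank contributed host_fill(..., rank, seq), so the expected allreduce
    # output is the elementwise sum of those patterns.
    expected = sum(host_fill(result.size, rank, seq) for rank in range(num_ranks))
    return np.allclose(result, expected, atol=atol_scale * num_ranks)
```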