From 154814f7e6af4112c3756ab89536392970f29e63 Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Thu, 12 Dec 2024 20:27:33 +0000 Subject: [PATCH] Update benchmarking code and remove empty files --- .../cutlass_benchmarks/sparse_benchmarks.py | 359 +++++++++ .../sparse_fp8_benchmarks.py | 760 ------------------ csrc/cutlass_extensions/common.hpp | 18 +- csrc/sparse/cutlass/sparse_compressor.cu | 19 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 10 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 33 +- .../layers/sparsity/__init__.py | 0 .../layers/sparsity/utils/__init__.py | 0 8 files changed, 392 insertions(+), 807 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/sparse_benchmarks.py delete mode 100644 benchmarks/cutlass_benchmarks/sparse_fp8_benchmarks.py delete mode 100644 vllm/model_executor/layers/sparsity/__init__.py delete mode 100644 vllm/model_executor/layers/sparsity/utils/__init__.py diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..f93ab89049360 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,359 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, + torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + 
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.bfloat16, + bias)) + + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = 
bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
+ """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/sparse_fp8_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_fp8_benchmarks.py deleted file mode 100644 index 4a76b289e3696..0000000000000 --- a/benchmarks/cutlass_benchmarks/sparse_fp8_benchmarks.py +++ /dev/null @@ -1,760 +0,0 @@ -import argparse -import copy -import dataclasses -import itertools -import multiprocessing as mp -import pickle as pkl -import time -import traceback -from multiprocessing import Process, Queue -from queue import Empty -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple - -import torch -import torch.utils.benchmark as TBenchmark -from torch.utils.benchmark import Measurement as TMeasurement -from utils import make_n_rand_sparse_tensors -from weight_shapes import WEIGHT_SHAPES - -import vllm._custom_ops as ops -from vllm.utils import FlexibleArgumentParser - -DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] -DEFAULT_TP_SIZES = [1] - - -@dataclasses.dataclass -class CudaGraphBenchParams: - num_ops_in_cuda_graph: int - - -@dataclasses.dataclass -class ArgPool: - ''' - When some argument of the benchmarking function is annotated with this type, - the benchmarking class (BenchMM) will collapse the argument to a pick a - single value from the given list of values, during function invocation. - - For every invocation during a benchmarking run, it will choose a - different value from the list. 
- ''' - values: Iterable[Any] - - -class BenchMM: - - class ArgsIterator: - - def __init__(self, args_list, kwargs_list): - assert len(args_list) == len(kwargs_list) - self.args_list = args_list - self.kwargs_list = kwargs_list - self.n = len(self.args_list) - self.idx = 0 - - def __next__(self): - while True: - yield (self.args_list[self.idx], self.kwargs_list[self.idx]) - self.idx += 1 - self.idx = self.idx % self.n - - def reset(self): - self.idx = 0 - - @property - def n_args(self): - return self.n - - def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], - label: str, sub_label: str, description: str, fn: Callable, - *args, **kwargs): - - self.cuda_graph_params = cuda_graph_params - self.use_cuda_graph = self.cuda_graph_params is not None - self.label = label - self.sub_label = sub_label - self.description = description - self.fn = fn - - # Process args - self._args = args - self._kwargs = kwargs - self.args_list, self.kwargs_list = self.collapse_argpool( - *args, **kwargs) - self.args_iterator = self.ArgsIterator(self.args_list, - self.kwargs_list) - - # Cudagraph runner - self.g = None - if self.use_cuda_graph: - self.g = self.get_cuda_graph_runner() - - # benchmark run params - self.min_run_time = 1 - - def collapse_argpool(self, *args, **kwargs): - kwargs = kwargs if kwargs is not None else {} - assert kwargs is None or all([ - not isinstance(v, ArgPool) for k, v in kwargs.items() - ]), 'ArgPools in kwargs are not supported yet' - - arg_pool_indices = [ - i for i, x in enumerate(args) if isinstance(x, ArgPool) - ] - if len(arg_pool_indices) == 0: - return [args], [kwargs] - - # make sure all the Arg pools have the same number of choices - arg_pool_size = len(args[arg_pool_indices[0]].values) - assert all( - [len(args[i].values) == arg_pool_size for i in arg_pool_indices]) - - # create copies of the args - args_list = [] - kwargs_list = [] - for _ in range(arg_pool_size): - args_list.append(args) - kwargs_list.append(kwargs.copy()) - - # collapse the arg pools by simply choosing the ith value - for i in range(arg_pool_size): - assert isinstance(args_list[i], tuple) - # get as list - args_i = list(args_list[i]) - # collapse - make replacements - for arg_pool_idx in arg_pool_indices: - val_from_pool = args_i[arg_pool_idx].values[i] - args_i[arg_pool_idx] = val_from_pool - # store back as tuple - args_list[i] = tuple(args_i) - - return args_list, kwargs_list - - def get_cuda_graph_runner(self): - assert self.use_cuda_graph - assert self.args_iterator is not None - - num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph - - # warmup - args_it = self.args_iterator.__next__() - for _ in range(5): - args, kwargs = next(args_it) - self.fn(*args, **kwargs) - - self.args_iterator.reset() - args_it = self.args_iterator.__next__() - - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(num_graph_ops): - args, kwargs = next(args_it) - self.fn(*args, **kwargs) - return g - - def run_cudagraph(self) -> TMeasurement: - assert self.use_cuda_graph - globals = {'g': self.g} - - return TBenchmark.Timer( - stmt="g.replay()", - globals=globals, - label=self.label, - sub_label=self.sub_label, - description=self.description, - ).blocked_autorange(min_run_time=self.min_run_time) - - def run_eager(self) -> TMeasurement: - setup = None - stmt = None - globals = None - - has_arg_pool = self.args_iterator.n_args > 1 - if has_arg_pool: - setup = ''' - args_iterator.reset() - args_it = 
args_iterator.__next__() - ''' - stmt = ''' - args, kwargs = next(args_it) - fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args_iterator': self.args_iterator} - else: - # no arg pool. Just use the args and kwargs directly - self.args_iterator.reset() - args_it = self.args_iterator.__next__() - args, kwargs = next(args_it) - - setup = "" - stmt = ''' - fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} - - return TBenchmark.Timer( - stmt=stmt, - setup=setup, - globals=globals, - label=self.label, - sub_label=self.sub_label, - description=self.description, - ).blocked_autorange(min_run_time=self.min_run_time) - - def run(self) -> TMeasurement: - timer = None - if self.use_cuda_graph: # noqa SIM108 - timer = self.run_cudagraph() - else: - timer = self.run_eager() - #assert timer.meets_confidence() - #assert not timer.has_warnings, f"Warnings {timer._warnings}" - if not timer.meets_confidence() or timer.has_warnings: - print("Doesn't meet confidence - re-running bench ...") - return self.run() - return timer - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type: - print(f"exc type {exc_type}") - print(f"exc value {exc_value}") - print(f"exc traceback {traceback}") - - -def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, - queue: Queue): - """ - Run a single kernel benchmark in an isolated process. - Puts (success, result, config) tuple in the queue. - """ - try: - torch.cuda.set_device(gpu_id) - - # Initialize CUDA tensors - m, k, n = kernel_config['m'], kernel_config['k'], kernel_config['n'] - dtype = kernel_config['dtype'] - - # Create tensors - BComps, Es, As, Bs = make_n_rand_sparse_tensors( - kernel_config.get('arg_pool_size', 1), dtype, m, n, k) - bf16_As = [x.to(dtype=torch.bfloat16) for x in As] - bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs] - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - # Because the transposed output will be computed - # out = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda") - - # Setup benchmark params - cuda_graph_params = None - if cgops := kernel_config.get('cuda_graph_ops'): - cuda_graph_params = CudaGraphBenchParams(cgops) - - label = kernel_config['label'] - sub_label = kernel_config['sub_label'] - - # Initialize benchmark based on kernel type - bench = None - kernel_type = kernel_config['kernel_type'] - - if kernel_type == 'pytorch_mm': - bench = BenchMM(cuda_graph_params, label, sub_label, - "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, ArgPool(bf16_As), ArgPool(bf16_Bs)) - - elif kernel_type == 'pytorch_scaled_mm': - bench = BenchMM(cuda_graph_params, - label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - ArgPool(As), - ArgPool(Bs), - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16) - - elif kernel_type == 'pytorch_scaled_mm_fast': - bench = BenchMM(cuda_graph_params, - label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - ArgPool(As), - ArgPool(Bs), - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True) - - elif kernel_type == 'cutlass_scaled_mm': - bench = BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, ArgPool(As), ArgPool(Bs), - scale_a, scale_b, torch.bfloat16) - - elif kernel_type == 'cutlass_scaled_sparse_mm': - bench = BenchMM(cuda_graph_params, label, sub_label, - 
"cutlass_fp8_fp8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, ArgPool(As), - ArgPool(BComps), ArgPool(Es), scale_a, scale_b, - torch.bfloat16) - - # Run the benchmark - result = bench.run() - queue.put((True, result, kernel_config)) - - except Exception as e: - print(f"Error in benchmark process: {str(e)}") - print(traceback.format_exc()) - queue.put((False, None, kernel_config)) - finally: - # Explicit cleanup - torch.cuda.empty_cache() - - -def benchmark_gpu_worker(gpu_id: int, task_queue: Queue, result_queue: Queue): - """ - Worker process that spawns individual benchmark processes for each kernel. - """ - try: - while True: - try: - kernel_config = task_queue.get_nowait() - if kernel_config is None: # Poison pill - break - - # Create a new process queue for this specific benchmark - process_queue = Queue() - - # Create and start a new process for this kernel benchmark - p = Process(target=run_single_benchmark_process, - args=(kernel_config, gpu_id, process_queue)) - p.start() - - # Wait for result with timeout (5 minutes for benchmarking) - try: - success, result, config = process_queue.get(timeout=300) - result_queue.put((success, result, config)) - except Empty: - print(f"Kernel {kernel_config.get('kernel_type')} ", - "benchmark timed out") - result_queue.put((False, None, kernel_config)) - - # Cleanup - p.join(timeout=1) # Give it 1 second to join - if p.is_alive(): - p.terminate() - p.join() - - except Empty: - break - except Exception as e: - print(f"Error in GPU {gpu_id} worker: {str(e)}") - print(traceback.format_exc()) - if 'kernel_config' in locals(): - result_queue.put((False, None, kernel_config)) - - finally: - print(f"GPU {gpu_id} worker finished") - - -def run_kernels_on_gpus( - configs: List[Dict] -) -> List[Tuple[bool, Optional[TMeasurement], Dict]]: - MULTI_GPU_MULTI_PROCESS = False # Set to False for single GPU testing - if MULTI_GPU_MULTI_PROCESS: - gpus_list = [0] - task_queue = Queue() - result_queue = Queue() - - configs = configs[:10] - - # Fill task queue - for config in configs: - task_queue.put(config) - for _ in gpus_list: # Add poison pills - task_queue.put(None) - - # Start GPU workers - workers = [] - for gpu_id in gpus_list: - p = Process(target=benchmark_gpu_worker, - args=(gpu_id, task_queue, result_queue)) - p.start() - workers.append(p) - - # Collect results - results = [] - completed = 0 - total_tasks = len(configs) - - while completed < total_tasks: - success, result, config = result_queue.get() - results.append((success, result, config)) - completed += 1 - - # Print progress - status = "Success" if success else "Failed" - print(f"{status}: {config['kernel_type']}") - - # Cleanup workers - for worker in workers: - worker.join(timeout=1) - if worker.is_alive(): - worker.terminate() - worker.join() - - return results - else: - """Run kernel benchmarks in a single process.""" - results = [] - gpu_id = 0 # Using the same GPU as before - torch.cuda.set_device(gpu_id) - # configs = configs[:10] # Keep the original slice - - for config in configs: - try: - # Initialize CUDA tensors - m, k, n = config['m'], config['k'], config['n'] - dtype = config['dtype'] - - # Create tensors - BComps, Es, As, Bs = make_n_rand_sparse_tensors( - config.get('arg_pool_size', 1), dtype, m, n, k) - bf16_As = [x.to(dtype=torch.bfloat16) for x in As] - bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs] - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - # out = torch.zeros((n, m), 
dtype=torch.bfloat16, device="cuda") - - # Setup benchmark params - cuda_graph_params = None - if cgops := config.get('cuda_graph_ops'): - cuda_graph_params = CudaGraphBenchParams(cgops) - - label = config['label'] - sub_label = config['sub_label'] - - # Initialize benchmark based on kernel type - bench = None - kernel_type = config['kernel_type'] - - if kernel_type == 'pytorch_mm': - bench = BenchMM(cuda_graph_params, label, sub_label, - "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, ArgPool(bf16_As), - ArgPool(bf16_Bs)) - - elif kernel_type == 'pytorch_scaled_mm': - bench = BenchMM(cuda_graph_params, - label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - ArgPool(As), - ArgPool(Bs), - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16) - - elif kernel_type == 'pytorch_scaled_mm_fast': - bench = BenchMM( - cuda_graph_params, - label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - ArgPool(As), - ArgPool(Bs), - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True) - - elif kernel_type == 'cutlass_scaled_mm': - bench = BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, ArgPool(As), - ArgPool(Bs), scale_a, scale_b, - torch.bfloat16) - - elif kernel_type == 'cutlass_scaled_sparse_mm': - bench = BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, ArgPool(As), - ArgPool(BComps), ArgPool(Es), scale_a, - scale_b, torch.bfloat16) - - # Run the benchmark - result = bench.run() - - # Print progress - print(f"Success: {kernel_type}") - - results.append((True, result, config)) - - # Cleanup - torch.cuda.empty_cache() - - except Exception as e: - print(f"Error in benchmark: {str(e)}") - print(traceback.format_exc()) - results.append((False, None, config)) - torch.cuda.empty_cache() - - return results - - -def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int], - with_arg_pool: Optional[int], m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - - # Check if context is not set - try: # noqa: SIM105 - mp.set_start_method('spawn', force=True) - except RuntimeError: - pass - - timers = [] - gpus_list = [5] # Using the same GPU list as original code - - # Base configuration for all kernels - base_config = { - 'm': m, - 'k': k, - 'n': n, - 'dtype': dtype, - 'cuda_graph_ops': with_cuda_graph, - 'arg_pool_size': with_arg_pool if with_arg_pool else 1, - 'label': label, - 'sub_label': sub_label - } - - # Prepare configs for all kernels - standard_kernels = [{ - 'kernel_type': 'pytorch_mm' - }, { - 'kernel_type': 'pytorch_scaled_mm' - }, { - 'kernel_type': 'pytorch_scaled_mm_fast' - }, { - 'kernel_type': 'cutlass_scaled_mm' - }, { - 'kernel_type': 'cutlass_scaled_sparse_mm' - }] - - # Create configs for standard kernels - all_configs = [{**base_config, **kernel} for kernel in standard_kernels] - - # Run all kernels distributed across GPUs - print( - f"Running {len(all_configs)} benchmarks across {len(gpus_list)} GPUs..." 
- ) - results = run_kernels_on_gpus(all_configs) - - # Process results - for success, result, _ in results: - if success and result is not None: - timers.append(result) - - return timers - - -def bench(dtype: torch.dtype, with_cuda_graph: Optional[int], - with_arg_pool: Optional[int], m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, with_cuda_graph, with_arg_pool, m, k, n, label, - sub_label) - raise ValueError("unsupported type") - - -# runner -def print_timers(timers: Iterable[TMeasurement]): - compare = TBenchmark.Compare(timers) - compare.print() - - -def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - results = [] - dtype = args.dtype - - for m, k, n in MKNs: - label = f"scaled-sparse-{dtype}-gemm" - label = f"{label}-cugraph_{args.with_cuda_graph}" \ - if args.with_cuda_graph else label - label = f"{label}-argpool_{args.with_arg_pool}" \ - if args.with_arg_pool else label - timers = bench(args.dtype, args.with_cuda_graph, args.with_arg_pool, m, - k, n, label, f"MKN=({m}x{k}x{n})") - - print_timers(timers) - results.extend(timers) - - return results - - -# output makers -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], - base_description: str, - timestamp=None): - print(f"== All Results {base_description} ====") - print_timers(data) - - # pickle all the results - timestamp = int(time.time()) if timestamp is None else timestamp - with open(f"{base_description}-{timestamp}.pkl", "wb") as f: - pkl.dump(data, f) - - -# argparse runners - - -def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) - MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args, MKNs) - - make_output(data, MKNs, f"square_bench-{args.dtype}") - - -def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) - data = run(args, MKNs) - - make_output(data, MKNs, f"range_bench-{args.dtype}") - - -def run_model_bench(args): - print("Benchmarking models:") - for i, model in enumerate(args.models): - print(f"[{i}] {model}") - - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: - KNs = [] - for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): - if tp_split_dim is not None: - KN[tp_split_dim] = KN[tp_split_dim] // tp_size - KNs.append(KN) - return KNs - - model_bench_data = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - Ms = args.batch_sizes - KNs = model_shapes(model, tp_size) - MKNs = [] - for m in Ms: - for k, n in KNs: - MKNs.append((m, k, n)) - - data = run(args, MKNs) - model_bench_data.append(data) - - # Print all results - for data, model_tp in zip(model_bench_data, models_tps): - model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") - print_timers(data) - - timestamp = int(time.time()) - - all_data = [] - for d in model_bench_data: - all_data.extend(d) - # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) - - -if __name__ == '__main__': - - def to_torch_dtype(dt): - if dt == "fp8": - return 
torch.float8_e4m3fn - raise ValueError("unsupported dtype") - - parser = FlexibleArgumentParser( - description=""" -Benchmark Cutlass GEMM. - - To run square GEMMs: - python3 ./benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 - - To run constant N and K and sweep M: - python3 ./benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 - - To run dimensions from a model: - python3 ./benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 - - Output: - - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument( - "--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8', 'fp16', 'bf16']") - parser.add_argument( - '--with-cuda-graph', - type=int, - default=32, - help="Number of ops/matmuls in a cudagraph execution. When set" - "cuda-graphs is enabled") - parser.add_argument( - '--with-arg-pool', - type=int, - default=None, - help="Number of A and B tensors to use as arg-pool. When not set," - "it defaults to 1") - - subparsers = parser.add_subparsers(dest="cmd") - - square_parser = subparsers.add_parser("square_bench") - square_parser.add_argument("--dim-start", type=int, required=True) - square_parser.add_argument("--dim-end", type=int, required=True) - square_parser.add_argument("--dim-increment", type=int, required=True) - square_parser.set_defaults(func=run_square_bench) - - range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) - range_parser.set_defaults(func=run_range_bench) - - model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - model_parser.set_defaults(func=run_model_bench) - - args = parser.parse_args() - args.func(args) diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index e918c2a77bc85..5c1098971b462 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -19,14 +19,6 @@ inline uint32_t next_pow_2(uint32_t const num) { return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - /** * Panic wrapper for unwinding CUDA runtime errors */ @@ -38,4 +30,12 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { << " at line: " << 
__LINE__ << std::endl; \ exit(EXIT_FAILURE); \ } \ - } \ No newline at end of file + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor.cu index d551a71e51601..ebb1c975121ac 100644 --- a/csrc/sparse/cutlass/sparse_compressor.cu +++ b/csrc/sparse/cutlass/sparse_compressor.cu @@ -67,23 +67,8 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e, using StrideA = cutlass::gemm::TagToStrideA_t; using StrideE = StrideA; - using Gemm = typename std::conditional< - std::is_same_v, - typename sm90_int8_config_default::Cutlass3xGemm, - typename std::conditional< - std::is_same_v, - typename sm90_fp8_config_default::Cutlass3xGemm, - typename std::conditional< - std::is_same_v, - typename sm90_fp16_config_default< - cutlass::half_t, cutlass::half_t, - c3x::ScaledEpilogue>::Cutlass3xGemm, - typename sm90_bf16_config_default< - cutlass::bfloat16_t, cutlass::half_t, - c3x::ScaledEpilogue>::Cutlass3xGemm>::type>::type>::type; + using Gemm = typename sm90_config_default::Cutlass3xGemm; using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu index ea191015b4159..76b6b2e395c04 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -45,8 +45,8 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; + typename sm90_config_default::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM128 = @@ -151,7 +151,7 @@ void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat16); using Cutlass3xGemmDefault = - typename sm90_fp16_config_default::Cutlass3xGemm; // m in (128, inf) @@ -172,7 +172,7 @@ void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kBFloat16); using Cutlass3xGemmDefault = - typename sm90_bf16_config_default::Cutlass3xGemm; // m in (128, inf) @@ -193,7 +193,7 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); using Cutlass3xGemmDefault = - typename sm90_int8_config_default::Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_int8_config_M128::Cutlass3xGemm; diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index a5925c715940b..59027d61debae 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -183,29 +183,32 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a, template typename Epilogue> -struct sm90_fp16_config_default { +struct sm90_config_default {}; + + +template typename Epilogue> +struct sm90_config_default { // M in (128, inf) - static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, 
_128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = - cutlass_sparse_3x_gemm; + cutlass_sparse_3x_gemm; }; -template typename Epilogue> -struct sm90_bf16_config_default { +struct sm90_config_default { // M in (128, inf) - static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = - cutlass_sparse_3x_gemm; }; @@ -323,17 +326,16 @@ struct sm90_fp8_config_8 { }; //////////////////////////////////////////////////////////////////////// -template typename Epilogue> -struct sm90_fp8_config_default { +struct sm90_config_default { // M in (128, inf) - static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; using Cutlass3xGemm = - cutlass_sparse_3x_gemm; }; @@ -415,18 +417,17 @@ struct sm90_fp8_config_M512 { TileSchedule>; }; -template typename Epilogue> -struct sm90_int8_config_default { +struct sm90_config_default { // For M > 128 and any N - static_assert(std::is_same()); using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = - cutlass_sparse_3x_gemm; }; diff --git a/vllm/model_executor/layers/sparsity/__init__.py b/vllm/model_executor/layers/sparsity/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/model_executor/layers/sparsity/utils/__init__.py b/vllm/model_executor/layers/sparsity/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000
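
---

Reviewer note: as a quick reference, below is a minimal sketch of the sparse scaled-mm path this new benchmark exercises. It is not part of the patch; it assumes the `make_rand_sparse_tensors` helper imported from the benchmarks' local `utils` module and the `ops.cutlass_scaled_sparse_mm` / `ops.cutlass_scaled_mm` argument order used throughout `sparse_benchmarks.py` above, with shapes chosen arbitrarily for illustration.

```python
# Minimal sketch, assuming make_rand_sparse_tensors returns
# (b_compressed, e, a, b) and the argument order used in
# sparse_benchmarks.py (a, b_compressed, e, scale_a, scale_b, out_dtype).
import torch
from utils import make_rand_sparse_tensors  # benchmarks/cutlass_benchmarks/utils.py

from vllm import _custom_ops as ops

m, k, n = 16, 4096, 4096
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# Dense fp8 scaled mm as a reference point.
dense = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

# 2:4 sparse kernel: takes the compressed weight and its metadata tensor.
sparse = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                      torch.bfloat16)
print(dense.shape, sparse.shape)
```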
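The benchmark runners pickle the raw `torch.utils.benchmark.Measurement` objects (`make_output` and `run_model_bench` above). A small post-processing snippet like the following, with an assumed example file name matching the `model_bench-{dtype}-{timestamp}.pkl` pattern, can reload a results file and reprint the same comparison table the benchmark prints at run time.

```python
# Sketch: reload a results pickle written by sparse_benchmarks.py and
# print the comparison table. The file name below is a placeholder.
import pickle

import torch.utils.benchmark as TBenchmark

with open("model_bench-torch.float8_e4m3fn-1734033000.pkl", "rb") as f:
    timers = pickle.load(f)  # list of torch.utils.benchmark.Measurement

TBenchmark.Compare(timers).print()
```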