diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh new file mode 100644 index 0000000000000..d06604f96f2b8 --- /dev/null +++ b/.buildkite/run-gh200-test.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# This script build the GH200 docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +DOCKER_BUILDKIT=1 docker build . \ + --target vllm-openai \ + --platform "linux/arm64" \ + -t gh200-test \ + --build-arg max_jobs=66 \ + --build-arg nvcc_threads=2 \ + --build-arg torch_cuda_arch_list="9.0+PTX" \ + --build-arg vllm_fa_cmake_gpu_arches="90-real" + +# Setup cleanup +remove_docker_container() { docker rm -f gh200-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and test offline inference +docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' + python3 examples/offline_inference.py +' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index df4fa7a6ee9ba..b563c96343f92 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -181,14 +181,14 @@ steps: commands: - VLLM_USE_V1=1 pytest -v -s v1 -- label: Examples Test # 15min +- label: Examples Test # 25min working_dir: "/vllm-workspace/examples" #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ commands: - - pip install awscli tensorizer # for llava example and tensorizer test + - pip install tensorizer # for tensorizer test - python3 offline_inference.py - python3 cpu_offload.py - python3 offline_inference_chat.py @@ -198,7 +198,10 @@ steps: - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py - - python3 offline_profile.py --model facebook/opt-125m + - python3 offline_inference_classification.py + - python3 offline_inference_embedding.py + - python3 offline_inference_scoring.py + - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min mirror_hardwares: [amd] @@ -221,8 +224,12 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/model_executor/layers + - vllm/model_executor/guided_decoding - tests/test_logits_processor - command: pytest -v -s test_logits_processor.py + - tests/model_executor/test_guided_processors + commands: + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 30min source_file_dependencies: @@ -321,7 +328,7 @@ steps: ##### models test ##### -- label: Basic Models Test # 30min +- label: Basic Models Test # 24min source_file_dependencies: - vllm/ - tests/models @@ -331,7 +338,7 @@ steps: - pytest -v -s models/test_registry.py - pytest -v -s models/test_initialization.py -- label: Language Models Test (Standard) # 42min +- label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -342,7 +349,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model -- label: Language Models Test (Extended) # 50min +- label: Language Models Test (Extended) # 1h10min optional: true source_file_dependencies: - vllm/ @@ -353,7 +360,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 26min +- label: Multi-Modal Models Test (Standard) # 28min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -369,7 +376,7 @@ steps: - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Multi-Modal Models Test (Extended) # 1h15m +- label: Multi-Modal Models Test (Extended) 1 # 1h16m optional: true source_file_dependencies: - vllm/ @@ -380,14 +387,24 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' +- label: Multi-Modal Models Test (Extended) 2 # 38m + optional: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/vision_language + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' + # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test optional: true @@ -422,11 +439,11 @@ steps: - tests/distributed/ commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - label: Distributed Tests (2 GPUs) # 40min #mirror_hardwares: [amd] @@ -445,12 +462,12 @@ steps: commands: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' # Avoid importing model tests that cause CUDA reinitialization error - - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus - - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - - pytest models/decoder_only/vision_language/test_models.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py @@ -540,7 +557,7 @@ steps: # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py - - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - pytest -v -s -x lora/test_mixtral.py - label: LM Eval Large Models # optional diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77a7e42..51b49a18dddf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" @@ -205,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -222,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) endif() FetchContent_MakeAvailable(cutlass) @@ -240,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_compressor_entry.cu" + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -270,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # - # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels + # For Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/sparse/cutlass/sparse_compressor_c3x.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -283,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " + message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running FP8 quantized models on " + "later if you intend on running FP8 sparse or quantized models on " "Hopper.") else() - message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + message(STATUS "Not building cutlass_c3x as no compatible archs found " "in CUDA target architectures") endif() @@ -300,7 +307,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. - cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) @@ -403,7 +410,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/Dockerfile b/Dockerfile index b38113f524a17..391ec2182a589 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,7 @@ ARG CUDA_VERSION=12.4.1 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base ARG CUDA_VERSION=12.4.1 ARG PYTHON_VERSION=3.12 +ARG TARGETPLATFORM ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies @@ -46,9 +47,14 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt +COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + python3 -m pip install -r requirements-cuda-arm64.txt; \ + fi # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -63,6 +69,7 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} #################### WHEEL BUILD IMAGE #################### FROM base AS build +ARG TARGETPLATFORM # install build dependencies COPY requirements-build.txt requirements-build.txt @@ -70,6 +77,11 @@ COPY requirements-build.txt requirements-build.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-build.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + python3 -m pip install -r requirements-cuda-arm64.txt; \ + fi + COPY . . ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ @@ -134,8 +146,8 @@ COPY requirements-test.txt requirements-test.txt COPY requirements-dev.txt requirements-dev.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt - #################### DEV IMAGE #################### + #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base @@ -143,6 +155,9 @@ ARG CUDA_VERSION=12.4.1 ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive +ARG TARGETPLATFORM + +COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment @@ -168,18 +183,25 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ # or future versions of triton. RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ -# install vllm wheel first, so that torch etc will be installed +# Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose RUN --mount=type=cache,target=/root/.cache/pip \ - . /etc/environment && \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + pip uninstall -y torch && \ + python3 -m pip install -r requirements-cuda-arm64.txt; \ + fi + +RUN --mount=type=cache,target=/root/.cache/pip \ +. /etc/environment && \ +if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ +fi COPY examples examples #################### vLLM installation IMAGE #################### - #################### TEST IMAGE #################### # image to run unit testing suite # note that this uses vllm installed by `pip` @@ -209,7 +231,6 @@ COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 RUN mkdir test_docs RUN mv docs test_docs/ RUN mv vllm test_docs/ - #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### @@ -218,8 +239,11 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10 - + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \ + else \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \ + fi ENV VLLM_USAGE_SOURCE production-docker-image ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/README.md b/README.md index ed5161ccffb45..93b71ddaccc61 100644 --- a/README.md +++ b/README.md @@ -134,3 +134,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs * For coordinating contributions and development, please use Slack. * For security disclosures, please use Github's security advisory feature. * For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. + +## Media Kit + +* If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3256692142c5e..4eb0e1f8ac903 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -781,6 +781,7 @@ def main(args: argparse.Namespace): backend = args.backend model_id = args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" @@ -790,6 +791,7 @@ def main(args: argparse.Namespace): base_url = f"http://{args.host}:{args.port}" tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) if args.dataset is not None: @@ -1210,5 +1212,15 @@ def main(args: argparse.Namespace): "from the sampled HF dataset.", ) + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer.') + args = parser.parse_args() main(args) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..3d1c5e392f9e2 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,384 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000..ef06fcd6604dd --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,96 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..d0353bc8cb42a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -8,6 +8,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops @@ -17,31 +18,6 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - # bench def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, @@ -386,4 +362,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py new file mode 100644 index 0000000000000..ef91f9f8eb529 --- /dev/null +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -0,0 +1,173 @@ +import pickle as pkl +import time +from dataclasses import dataclass +from itertools import product +from typing import Callable, Iterable, List, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.model_executor.layers.layernorm import RMSNorm + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + add_residual: bool + dtype: torch.dtype + + def description(self): + return (f'N {self.num_tokens} ' + f'x D {self.hidden_size} ' + f'x R {self.add_residual} ' + f'x DT {self.dtype}') + + +def get_bench_params() -> List[bench_params_t]: + ## Test Fixtures + NUM_TOKENS = [2**x for x in range(11)] + HIDDEN_SIZES = list(range(1024, 8129, 1024)) + ADD_RESIDUAL = [True, False] + DTYPES = [torch.bfloat16, torch.float] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + bench_params = list(map(lambda x: \ + bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + return bench_params + + +# Reference impls +def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _, _ = ops.scaled_int8_quant(torch_out) + + +def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_fp8_quant(torch_out) + + +def fused_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + out, _ = ops.rms_norm_dynamic_per_token_quant(x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + residual=residual) + + +# Bench functions +def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, + quant_dtype: torch.dtype, label: str, sub_label: str, + fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "rms_norm_layer": rms_norm_layer, + "x": x, + "residual": residual, + "quant_dtype": quant_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + +def bench(params: bench_params_t, label: str, sub_label: str) \ + -> Iterable[TMeasurement]: + + # Make inputs + layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + # Make inputs + scale = 1 / params.hidden_size + x = torch.randn(params.num_tokens, + params.hidden_size, + dtype=params.dtype, + device='cuda') * scale + residual = (torch.randn_like(x) * scale).to(device='cuda') \ + if params.add_residual else None + + timers = [] + + # unfused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, + unfused_int8_impl, "unfused_int8_impl")) + + # unfused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + unfused_fp8_impl, "unfused_fp8_impl")) + + # fused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, + "fused_int8_impl")) + + # fused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + fused_impl, "fused_fp8_impl")) + + print_timers(timers) + + return timers + + +# launch bench +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def main(): + torch.set_default_device('cuda') + bench_params = get_bench_params() + + timers = [] + for bp in tqdm(bench_params): + timers.extend( + bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + print_timers(timers) + + # pickle all the results + timestamp = int(time.time()) + with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: + pkl.dump(timers, f) + + +if __name__ == '__main__': + main() diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py new file mode 100644 index 0000000000000..baa5de0fff1bd --- /dev/null +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -0,0 +1,262 @@ +import itertools +from typing import Optional, Tuple, Union + +import torch +import triton +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn + +from vllm import _custom_ops as vllm_ops + + +class HuggingFaceRMSNorm(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_vllm = rmsnorm_vllm( + x.clone(), weight, + residual.clone() if residual is not None else None) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"VLLM output={output_vllm}") + + if torch.allclose(output_naive, output_flashinfer, atol=1e-2, + rtol=1e-2) and torch.allclose( + output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name= + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000..ba9f40a230c8e --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,7 @@ +#include +#include + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp new file mode 100644 index 0000000000000..3d2093ab94297 --- /dev/null +++ b/csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp new file mode 100644 index 0000000000000..85e359aa57113 --- /dev/null +++ b/csrc/cutlass_extensions/common.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..fcc17c7727f94 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -36,13 +36,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index a634e1c3d4886..03414b7e1ae93 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,6 +14,20 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +// TODO(luka/varun): use FP8_TYPE macro after refactoring +#ifndef USE_ROCM + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#endif + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index ea001190bc202..c145e4eda0845 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& weight, torch::Tensor& scale, double epsilon); +void rms_norm_dynamic_per_token_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, + double const epsilon, + std::optional scale_ub, + std::optional residual); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); @@ -154,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& e, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, + torch::Tensor& e, torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp deleted file mode 100644 index bf04bb400790f..0000000000000 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - cutlassGetStatusString(status)) \ - } - -inline uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index d03242f44ab1d..75681f7f37820 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,8 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..8190277997161 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -24,7 +24,8 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e0..4f7b6588ef3f7 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -3,6 +3,8 @@ #include #include +#include "cutlass_extensions/common.hpp" + void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { return false; } -int32_t get_sm_version_num() { - int32_t major_capability, minor_capability; - cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, - 0); - cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, - 0); - int32_t version_num = major_capability * 10 + minor_capability; - return version_num; -} - void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index d7c0297d5333f..15bd5b6ed1564 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -1,6 +1,9 @@ #pragma once +#include "quantization/vectorization.cuh" + #include +#include #ifndef USE_ROCM #include @@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz; // issue when running dynamic quantization. Here use 224.0f for rocm. constexpr auto FP8_E4M3_MAX = 224.0f; #endif +constexpr static auto kFp8Type = c10::CppTypeToScalarType::value; namespace vllm { @@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale, } } -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -typedef struct __align__(4) { - FP8_TYPE x; - FP8_TYPE y; - FP8_TYPE z; - FP8_TYPE w; -} -float8x4_t; - template __device__ float thread_max_vec(scalar_t const* __restrict__ input, int64_t const num_elems, int const tid, @@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, float const scale, int64_t const num_elems, int const tid, int const step) { + using float8x4_t = q8x4_t; // Vectorized input/output to better utilize memory bandwidth. - vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - float8x4_t* vectorized_out = reinterpret_cast(out); + auto const* vectorized_in = reinterpret_cast const*>(input); + auto* vectorized_out = reinterpret_cast(out); int64_t const num_vec_elems = num_elems >> 2; diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu new file mode 100644 index 0000000000000..3c4f183bf4b59 --- /dev/null +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -0,0 +1,160 @@ + +#include +#include + +#include "../../dispatch_utils.h" +#include "layernorm_utils.cuh" +#include "quant_conversions.cuh" + +namespace vllm { + +template +__device__ void rms_norm_dynamic_per_token_quant_vec( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute rms + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute scale + vllm::vectorized::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::vectorized::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert token_scale for exact match with FBGemm + vllm::vectorized::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} + +// RMS norm + quant kernel +template +__global__ void rms_norm_dynamic_per_token_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. + bool const can_vectorize = hidden_size % 4 == 0; + + if (can_vectorize) { + return rms_norm_dynamic_per_token_quant_vec( + out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor, + hidden_size, residual); + } + + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute RMS + vllm::compute_rms(&rms, input, hidden_size, + var_epsilon, residual); + // Compute Scale + vllm::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert s_token_scale for exact match with FBGemm + vllm::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} +} // namespace vllm + +// Residual add + RMS norm + dynamic per token +template +void rms_norm_dynamic_per_token_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual) { + int32_t hidden_size = input.size(-1); + int32_t num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const float min_scaling_factor = + out.dtype() == torch::kInt8 + ? std::numeric_limits::epsilon() + : 1.0f / (std::numeric_limits::max() * 512.f); + + if (residual.has_value()) { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, + residual->data_ptr()); + }); + + } else { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, nullptr); + }); + } +} + +void rms_norm_dynamic_per_token_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional scale_ub, std::optional residual) { + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + TORCH_CHECK(scales.dtype() == torch::kFloat32); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { + rms_norm_dynamic_per_token_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual); + }); +} diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh new file mode 100644 index 0000000000000..cec6b54edb569 --- /dev/null +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -0,0 +1,327 @@ +#pragma once + +/** + * __device__ layernorm utilities. + */ + +#include "quantization/vectorization.cuh" +#include "quant_conversions.cuh" + +#ifndef USE_ROCM + #include +#else + #include +#endif + +namespace vllm { + +// has_residual must be true, if residual is not a nullptr +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + // sum of squares + float ss = 0.0f; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + ss += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + float block_absmax_val_maybe = 0.0f; + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + residual[token_offset + i] = static_cast(x); + } + // Norm + x = static_cast(static_cast(x * rms) * weight[i]); + // Quant + output[token_offset + i] = + ScaledQuant::quant_fn(x, scale); + } +} + +namespace vectorized { + +// Compute 1.0/rms(input) +// hidden_size must be a multiple of 4 +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + // sum of squares + float ss = 0.0f; + + int32_t const num_vec_elems = hidden_size >> 2; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + ss += x.x * x.x; + ss += x.y * x.y; + ss += x.z * x.z; + ss += x.w * x.w; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +// Vectorized version of vllm::compute_dynamic_per_token_scales +// hidden_size must be a multiple of 4 +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + // Vectorized input/weight/residual to better utilize memory bandwidth. + vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + int32_t const num_vec_elems = hidden_size >> 2; + float block_absmax_val_maybe = 0.0f; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.x * rms) * w.x)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.y * rms) * w.y)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.z * rms) * w.z)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.w * rms) * w.w)); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +// hidden_size must be a multiple of 4 +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + // Vectorized input/output/weight/residual to better utilize memory bandwidth. + vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + q8x4_t* vec_output = + reinterpret_cast*>(&output[token_offset]); + vec4_t* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = reinterpret_cast*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = hidden_size >> 2; + +// TODO(luka/varun) extract into type-agnostic vectorized quant function to +// replace scaled_fp8_conversion_vec +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t const in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + // Update residual + r.x = static_cast(x.x); + r.y = static_cast(x.y); + r.z = static_cast(x.z); + r.w = static_cast(x.w); + vec_residual[i] = r; + } + + q8x4_t out; + out.x = ScaledQuant::quant_fn( + static_cast(x.x * rms) * w.x, scale); + out.y = ScaledQuant::quant_fn( + static_cast(x.y * rms) * w.y, scale); + out.z = ScaledQuant::quant_fn( + static_cast(x.z * rms) * w.z, scale); + out.w = ScaledQuant::quant_fn( + static_cast(x.w * rms) * w.w, scale); + vec_output[i] = out; + } +} + +} // namespace vectorized + +} // namespace vllm diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh new file mode 100644 index 0000000000000..f8a9872226a3a --- /dev/null +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -0,0 +1,81 @@ +#pragma once + +/** + * __device__ helper functions to deal with float -> quant datatype conversion + */ + +#include "quantization/vectorization.cuh" +// TODO(luka/varun):refactor common.cuh to use this file instead +#include "quantization/fp8/common.cuh" + +namespace vllm { + +// TODO(luka/varun): combine into common utilities for int8 +// (with int8_quant_kernels.cu) +static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) { + float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + +template +struct ScaledQuant; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_int8_rn(x * scale); + } else { + return float_to_int8_rn(x / scale); + } + } +}; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_fp8(x * scale); + } else { + return float_to_fp8(x / scale); + } + } +}; + +template +__device__ void scaled_quant_conversion(quant_type_t* __restrict__ output, + scalar_t const* __restrict__ input, + float const scale, int const tid, + int const num_elements, + int const step) { + for (int i = tid; i < num_elements; i += step) { + output[i] = ScaledQuant(input[i], scale); + } +} + +} // namespace vllm diff --git a/csrc/quantization/vectorization.cuh b/csrc/quantization/vectorization.cuh new file mode 100644 index 0000000000000..44c999130f756 --- /dev/null +++ b/csrc/quantization/vectorization.cuh @@ -0,0 +1,33 @@ +#pragma once +/** + * __device__ datatypes vectorized by 4 + */ + +// Include both AMD and NVIDIA fp8 types to avoid circular import +// TODO(luka/varun) use FP8_TYPE instead after refactoring +#include +#include + +namespace vllm { + +// Vectorization containers +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +template +struct __align__(4) q8x4_t { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + quant_type_t x; + quant_type_t y; + quant_type_t z; + quant_type_t w; +}; + +} // namespace vllm diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu new file mode 100644 index 0000000000000..218c5317b4de6 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu @@ -0,0 +1,163 @@ +// clang-format will break include orders +// clang-format off +#include + +#include "sparse_scaled_mm_c3x.cuh" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +using namespace cute; +using namespace vllm; + +/// Make A structured sparse by replacing elements with 0 and compress it +template +bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity + TORCH_CHECK(a.stride(1) == 1) + + int m = a.size(0); + int k = a.size(1); + + // Sparse kernel setup; this kernel is not used for matmul, + // but just for setting up the compressor utility + // A matrix configuration + using ElementA = ElementA_; + using LayoutTagA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + // B matrix configuration + using ElementB = ElementA; + using LayoutTagB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + // C/D matrix configuration + using ElementC = float; + using LayoutTagC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Core kernel configurations + using ElementAccumulator = ElementAcc_; + using TileShape = Shape<_128, _128, _128>; + using TileShapeRef = Shape<_128, _128, _64>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = typename std::conditional< + std::is_same_v, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecialized>::type; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using ProblemShape = Shape; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC, + AlignmentC, ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA, + LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + using StrideA = Stride, int64_t>; + + // The n (=1) dimension does not matter for the compressor + typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; + + using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; + using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; + + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; + + // Offline compressor kernel + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, + cutlass::arch::Sm90>; + + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + auto [M, N, K, L] = prob_shape; + + StrideA stride_A; + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(prob_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto a_ptr = static_cast(a.data_ptr()); + + auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); + auto a_meta_ptr = static_cast( + a_meta.data_ptr()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + if (a.dtype() == torch::kBFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } else if (a.dtype() == torch::kFloat8_e4m3fn) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kInt8) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } + return false; +} \ No newline at end of file diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu new file mode 100644 index 0000000000000..d23d937b6ac28 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu @@ -0,0 +1,42 @@ +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a); +#endif + +bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); + TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && + a_nzs.size(1) * 2 == a.size(1) && + a_meta.size(1) * 2 * 4 == a.size(1)); + // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && + a_meta.stride(1) == 1); // Row-major + TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + int32_t version_num = get_sm_version_num(); + + // Guard against compilation issues for sm90 kernels +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X + if (version_num >= 90) { + return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_sparse_mm for a compute capability less than " + "CUDA device capability: ", + version_num); +} diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu new file mode 100644 index 0000000000000..b50e9a3a2c240 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -0,0 +1,303 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +#include "sparse_scaled_mm_c3x.cuh" +// clang-format on + +using namespace cute; +using namespace vllm; + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm90_fp8_config_M256::Cutlass3xGemm; + using Cutlass3xGemmM512 = + typename sm90_fp8_config_M512::Cutlass3xGemm; + + using Cutlass3xGemm1 = + typename sm90_fp8_config_1::Cutlass3xGemm; + using Cutlass3xGemm2 = + typename sm90_fp8_config_2::Cutlass3xGemm; + using Cutlass3xGemm3 = + typename sm90_fp8_config_3::Cutlass3xGemm; + using Cutlass3xGemm4 = + typename sm90_fp8_config_4::Cutlass3xGemm; + using Cutlass3xGemm5 = + typename sm90_fp8_config_5::Cutlass3xGemm; + using Cutlass3xGemm6 = + typename sm90_fp8_config_6::Cutlass3xGemm; + using Cutlass3xGemm7 = + typename sm90_fp8_config_7::Cutlass3xGemm; + using Cutlass3xGemm8 = + typename sm90_fp8_config_8::Cutlass3xGemm; + + uint32_t const n = bt_nzs.size(0); + uint32_t const m = a.size(0); // Batch size + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096 || n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 128) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 256) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else { + if (n == 6144 || n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } + + // Otherwise the default heuristic + if (mp2 <= 64) { + // n in [1, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // n in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 256) { + // n in (128, 256] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // n in (256, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kBFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template