diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 71f4e520135d4..d1f6105a47166 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1,2 @@ github: [vllm-project] -open_collective: [vllm] +open_collective: vllm diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 4f54eea564ecb..683b70cd89989 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -15,6 +15,8 @@ updates: allow: - dependency-type: "all" ignore: + - dependency-name: "*" + update-types: ["version-update:semver-patch"] - dependency-name: "torch" - dependency-name: "torchvision" - dependency-name: "xformers" @@ -24,9 +26,6 @@ updates: - dependency-name: "ray[adag]" - dependency-name: "lm-eval" groups: - patch-update: - applies-to: version-updates - update-types: ["patch"] minor-update: applies-to: version-updates update-types: ["minor"] diff --git a/.github/workflows/png-lint.yml b/.github/workflows/png-lint.yml new file mode 100644 index 0000000000000..4932af943a07b --- /dev/null +++ b/.github/workflows/png-lint.yml @@ -0,0 +1,37 @@ +name: Lint PNG exports from excalidraw +on: + push: + branches: + - "main" + paths: + - '*.excalidraw.png' + - '.github/workflows/png-lint.yml' + pull_request: + branches: + - "main" + paths: + - '*.excalidraw.png' + - '.github/workflows/png-lint.yml' + +env: + LC_ALL: en_US.UTF-8 + +defaults: + run: + shell: bash + +permissions: + contents: read + +jobs: + actionlint: + runs-on: ubuntu-latest + steps: + - name: "Checkout" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: "Run png-lint.sh to check excalidraw exported images" + run: | + tools/png-lint.sh diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 6d33096ca1d11..5e9381f712e10 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): print(f"cost time {end_time - start_time}") -def sample_requests( +@dataclasses.dataclass +class Request: + prompt: str + prompt_len: int + output_len: int + + +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: + vocab = tokenizer.get_vocab() + # Remove the special tokens. + vocab = { + k: v + for k, v in vocab.items() if k not in tokenizer.all_special_ids + } + return random.choices(list(vocab.values()), k=length) + + +def sample_requests_from_dataset( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, input_length_range: Tuple[int, int], fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: +) -> List[Request]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -77,31 +94,55 @@ def sample_requests( random.shuffle(dataset) min_len, max_len = input_length_range + assert min_len >= 0 and max_len >= min_len, "input_length_range too small" # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_requests: List[Request] = [] + for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: + if len(filtered_requests) == num_requests: break # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids + prompt_token_ids = tokenizer(dataset[i][0]).input_ids + prompt = tokenizer.decode(prompt_token_ids) completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue + output_len = (len(completion_token_ids) + if fixed_output_len is None else fixed_output_len) if min_len <= prompt_len <= max_len: - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_requests.append(Request(prompt, prompt_len, output_len)) + + return filtered_requests + + +def sample_requests_from_random( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: Tuple[int, int], + fixed_output_len: Optional[int], + prefix_len: int, +) -> List[Request]: - return filtered_dataset + requests = [] + prefix_token_ids = sample_tokens(tokenizer, prefix_len) + min_len, max_len = input_length_range + + for i in range(num_requests): + unique_part_token_ids = sample_tokens( + tokenizer, + random.randint(min_len - prefix_len, max_len - prefix_len)) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert (min_len <= prompt_len <= max_len + ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + requests.append(Request(prompt, prompt_len, fixed_output_len)) + return requests -def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], +def repeat_and_sort_requests(requests: List[Request], repeat_count: int, sort: bool = False) -> List[str]: repeated_requests = requests * repeat_count @@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], repeated_requests.sort(key=lambda x: x[1]) else: random.shuffle(repeated_requests) - return [req[0] for req in repeated_requests] + return [req.prompt for req in repeated_requests] def main(args): @@ -117,9 +158,12 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts" + if args.prefix_len > 0: + raise ValueError("prefix-len is not supported when " + "dataset-path is provided.") + print(f"Start to sample {args.num_prompts} prompts " f"from {args.dataset_path}") - filtered_datasets = sample_requests( + filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, @@ -127,9 +171,22 @@ def main(args): fixed_output_len=args.output_len, ) else: - prompt_len = len(tokenizer(PROMPT).input_ids) - filtered_datasets = [(PROMPT, prompt_len, args.output_len) - ] * args.num_prompts + print(f"Start to sample {args.num_prompts} prompts from random") + filtered_requests = sample_requests_from_random( + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + prefix_len=args.prefix_len, + ) + + # Print some helpful stats of the requests. + print(f"Sampled {len(filtered_requests)} requests.") + prompt_lens = [req.prompt_len for req in filtered_requests] + print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}") + print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}") + print(f"Min Prompt Length: {min(prompt_lens)}") + print(f"Max Prompt Length: {max(prompt_lens)}") engine_args = EngineArgs.from_cli_args(args) @@ -137,8 +194,8 @@ def main(args): sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) - print("Testing filtered datasets") - prompts = repeat_and_sort_requests(filtered_datasets, + print("Testing filtered requests") + prompts = repeat_and_sort_requests(filtered_requests, repeat_count=args.repeat_count, sort=args.sort) @@ -161,20 +218,29 @@ def main(args): parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--num-prompts', type=int, - default=1, + required=True, help="Number of the prompts sampled from dataset") parser.add_argument('--repeat-count', type=int, - default=100, + default=1, help='Number of times to repeat each prompt') parser.add_argument('--sort', action='store_true', help='Sort prompts by input length') parser.add_argument('--input-length-range', type=str, - default='128:256', + required=True, help='Range of input lengths for sampling prompts,' 'specified as "min:max" (e.g., "128:256").') + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Specifies the length of a common prefix to be " + "added to the input prompt. The input-length-range will " + "subtract this length when filtering prompts. Only used " + "when dataset-path is not provided.", + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 665b50bf18cf0..a0342d08f1db8 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -2,8 +2,10 @@ import copy import itertools import math +import os import pickle as pkl import time +from dataclasses import dataclass from itertools import product from typing import Callable, Iterable, List, Optional, Tuple @@ -15,11 +17,12 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, + marlin_zero_points) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, pack_rows, quantize_weights) + pack_rows, quantize_weights) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -27,149 +30,349 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_TP_SIZES = [1] +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx + + +def terse_type_name(dt): + return { + torch.bfloat16: "bf16", + torch.float16: "fp16", + torch.int8: "int8", + torch.float8_e4m3fn: "fp8", + torch.bfloat16: "bf16", + torch.float: "float", + torch.int: "int", + }[dt] + + +@dataclass +class BenchmarkTensors: + w_ref: torch.Tensor + a: torch.Tensor + + w_q: torch.Tensor + group_size: Optional[int] + wtype: ScalarType + w_g_s: torch.Tensor + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +def rand_data(shape, dtype=torch.float16, scale=1): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) + else: + return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") + + +def quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) -def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # make col major - return ops.machete_prepack_B(w_q, wtype) + return w_ref, w_q, w_s, w_zp -def make_bench_tensors( - atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, - k: int -) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, - torch.tensor]]]: - assert wtype.is_integer(), "TODO: support floating point weights" +def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> List[BenchmarkTensors]: + m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) - - a = torch.randn((m, k), device="cuda", dtype=atype) * 5 - weights = [ - torch.randn((k, n), device="cuda", dtype=atype) - for _ in range(num_weights) - ] - quanitized_weights = [ - quantize_weights(w, wtype, group_size) for w in weights - ] - - return a, quanitized_weights + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / + (k * n * types.weight_type.size_bits)) + + a = rand_data((m, k), types.act_type, scale=5) + + benchmark_tensors: List[BenchmarkTensors] = [] + for _ in range(num_weights): + w = rand_data((k, n), types.act_type, scale=5) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + benchmark_tensors.append( + BenchmarkTensors(w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s)) + + return benchmark_tensors + + +def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable: + a = bt.a + w = bt.w_ref.to(bt.a.dtype) # use float reference tensor + if a.dtype not in [torch.float16, torch.bfloat16]: + a = a.to(torch.float16) + w = w.to(torch.float16) + return lambda: torch.matmul(a, w) + + +def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: + if bt.w_ch_s is not None and bt.w_tok_s is not None: + scale_a = bt.w_tok_s.to(torch.float32) + scale_b = bt.w_ch_s.to(torch.float32) + else: + scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() + return lambda: ops.cutlass_scaled_mm( + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + + +def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: + device = bt.a.device + + workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + if bt.w_g_zp is None: + w_zp = torch.empty(0, dtype=torch.int, device=device) + else: + w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.group_size is None: + w_s = torch.tensor([], device="cuda", dtype=torch.half) + else: + w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.group_size) + + sort_indices = torch.empty(0, dtype=torch.int, device=device) + g_idx = torch.empty(0, dtype=torch.int, device=device) + w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.a.dtype.is_floating_point: + assert bt.w_ch_s is None + assert bt.w_tok_s is None + assert bt.group_size is not None + + fn = lambda: ops.gptq_marlin_gemm(a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True) + else: + assert bt.a.dtype == torch.int8 + assert bt.wtype == scalar_types.uint4b8 + + if bt.w_ch_s is not None: + s_ch = bt.w_ch_s.to(torch.float32) + else: + s_ch = torch.ones(bt.w_ref.shape[1], + dtype=torch.float32, + device=device) + + if bt.w_tok_s is not None: + s_tok = bt.w_tok_s.to(torch.float32) + else: + s_tok = torch.ones(bt.a.shape[0], + dtype=torch.float32, + device=device) + + fn = lambda: ops.marlin_qqq_gemm(a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0]) + + return fn + + +def machete_create_bench_fn(bt: BenchmarkTensors, + out_type=torch.dtype, + schedule=None) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, + None if bt.w_g_s is None else bt.w_g_s.dtype) + + w_g_zp = bt.w_g_zp + if w_g_zp is not None: + w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype)) + + return lambda: ops.machete_mm( + a=bt.a, + b_q=bt.w_q, + b_type=bt.wtype, + b_group_scales=bt.w_g_s, + b_group_zeros=w_g_zp, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + out_type=out_type, + schedule=schedule, + ) # impl - # bench -def bench_fn(label: str, sub_label: str, description: str, - fn: Callable) -> TMeasurement: - min_run_time = 1 - return TBenchmark.Timer( - stmt="fn()", + +def bench_fns(label: str, sub_label: str, description: str, + fns: List[Callable]): + + min_run_time = 1 if not NVTX_PROFILE else 0.1 + res = TBenchmark.Timer( + stmt=""" + for fn in fns: + fn() + """, globals={ - "fn": fn + "fns": fns }, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) + if NVTX_PROFILE: + with nvtx.annotate("mm-bench"), nvtx.annotate( + f"{label}|{sub_label}|{description}"): + fns[0]() -def loop_over_weights( - a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, - torch.tensor, torch.tensor]], - fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], - None]): - for w_ref, w_q, w_s, _ in weights: - fn(a, w_ref, w_q, w_s) + return res _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(atype: torch.dtype, - wtype: ScalarType, +def bench(types: TypeConfig, group_size: int, m: int, k: int, n: int, label: str, sub_label: str, - benchmark_marlinv1: bool = True, - sweep_schedules: bool = True) -> Iterable[TMeasurement]: - global _SWEEP_SCHEDULES_RESULTS - - a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) - sub_label += f", L={len(weights)}" - - weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) - for w_ref, w_q, w_s, w_zp in weights] + sweep_schedules: bool = True) -> List[TMeasurement]: + benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) + sub_label += f", L={len(benchmark_tensors)}" + + name_type_string = f"W{types.weight_type}"+\ + f"-A{terse_type_name(types.act_type)}" + if types.group_scale_type is not None: + name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" + if types.group_zero_type is not None: + name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}" + if group_size is not None: + name_type_string += f"-G{group_size}" + if types.channel_scale_type is not None: + name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}" + if types.token_scale_type is not None: + name_type_string += f"-TS{terse_type_name(types.token_scale_type)}" timers = [] # pytorch impl timers.append( - bench_fn( - label, sub_label, "torch.matmul", lambda: loop_over_weights( - a, - weights, - lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), - ))) + bench_fns( + label, sub_label, "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) + for bt in benchmark_tensors])) - if benchmark_marlinv1: - w_ref = weights[0][0] - - w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) - sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) - g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) - - def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: - w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) - return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, - wtype.size_bits) - - def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: - return marlin_permute_scales(w_s, *w_ref.shape, group_size) - - weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), - marlinv1_permute_scales(w_s), w_zp) - for w_ref, w_q, w_s, w_zp in weights] - - workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - # marlinv1 + if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: + timers.append( + bench_fns( + label, sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ + cutlass_scaled_mm_create_bench_fn(bt) + for bt in benchmark_tensors + ])) + + if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fn( - label, sub_label, "marlin_orig", lambda: loop_over_weights( - a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. - gptq_marlin_gemm(a, - w_q, - w_s, - w_zp_empty, - g_idx, - sort_indices, - workspace.scratch, - wtype, - size_m=a.shape[0], - size_n=w_ref.shape[1], - size_k=w_ref.shape[0], - is_k_full=True)))) + bench_fns(label, sub_label, f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) + for bt in benchmark_tensors])) # machete timers.append( - bench_fn( - label, sub_label, "machete_heuristic", lambda: loop_over_weights( - a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( - a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + bench_fns(label, sub_label, f"machete ({name_type_string})", [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ])) if sweep_schedules: + global _SWEEP_SCHEDULES_RESULTS + print("Finding best schedule for machete") best = None best_schedule = None - schedules = ops.machete_supported_schedules(wtype) + schedules = ops.machete_supported_schedules( + a_type=types.act_type, + b_type=types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_zero_type, + token_scales_type=types.token_scale_type, + channel_scales_type=types.channel_scale_type, + out_type=types.output_type) + + if schedules is None or len(schedules) == 0: + raise ValueError("No schedules found to sweep") + for schedule in reversed(schedules): schedule_M = int(schedule.split("_")[0].split("x")[1]) @@ -177,16 +380,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - def run(a, _, w_q, w_s, schedule=schedule): - ops.machete_gemm(a, - w_q, - wtype, - w_s, - b_group_size=group_size, - schedule=schedule) - - res = bench_fn(label, sub_label, "machete_best", - lambda: loop_over_weights(a, weights_machete, run)) + res = bench_fns(label, sub_label, "machete_best", [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule) + for bt in benchmark_tensors + ]) results_row = { "M": m, @@ -213,25 +411,33 @@ def run(a, _, w_q, w_s, schedule=schedule): # runner -def print_timers(timers: Iterable[TMeasurement]): +def print_timers(timers: List[TMeasurement]): compare = TBenchmark.Compare(timers) compare.print() -def run(dtype: torch.dtype, sweep_schedules: bool, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + types = TypeConfig( + act_type=args.act_type, + weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ + else scalar_types.uint4, + output_type=args.out_type, + group_scale_type=args.group_scale_type, + group_zero_type=args.group_zero_type, + channel_scale_type=args.channel_scale_type, + token_scale_type=args.token_scale_type, + ) - results = [] + results: List[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(dtype, - scalar_types.uint4b8, - 128, + timers = bench(types, + args.group_size, m, k, n, - f"{dtype}-gemm", + f"{args.act_type}-gemm", f"MKN=({m}x{k}x{n})", - sweep_schedules=sweep_schedules) + sweep_schedules=args.sweep_schedules) print_timers(timers) results.extend(timers) @@ -240,7 +446,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool, # output makers def make_output( - data: Iterable[TMeasurement], + data: List[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None, @@ -262,7 +468,6 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -306,33 +511,49 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, args.sweep_schedules, MKNs) + data = run(args, MKNs) model_bench_data.append(data) + type_string = f"{args.act_type}" + # Print all results for data, model_tp in zip(model_bench_data, models_tps): model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print(f"== Results {type_string} {model}-TP{tp_size} ====") print_timers(data) - timestamp = int(time.time()) + timestr = time.strftime("%Y%m%d-%H%M%S") - all_data = [] + all_results = [] for d in model_bench_data: - all_data.extend(d) + all_results.extend(d) + # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) + with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: + args_dict = vars(args) + args_dict.pop("func") + pkl.dump({ + "args": args_dict, + "results": all_results, + }, f) if __name__ == "__main__": def to_torch_dtype(dt): - if dt == "bfloat16": - return torch.bfloat16 - if dt == "float16": - return torch.float16 - raise ValueError("unsupported dtype") + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "int8": torch.int8, + "float8_e4m3fn": torch.float8_e4m3fn, + "int": torch.int, + "float": torch.float, + }[dt] + + class ToTorchDtype(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, to_torch_dtype(values)) parser = FlexibleArgumentParser( description=""" @@ -352,12 +573,42 @@ def to_torch_dtype(dt): """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument( - "--dtype", - type=to_torch_dtype, + "--act-type", + action=ToTorchDtype, required=True, - help="Available options are ['bfloat16', 'float16']", + choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + ) + parser.add_argument( + "--group-scale-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-zero-type", + type=to_torch_dtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--channel-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--token-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--out-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-size", + type=int, + help="Available options are ['None', '-1', '128'], default=128", + default=128, ) parser.add_argument( "--sweep-schedules", diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index de608fd05af70..7d0bd84150a27 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -20,10 +20,11 @@ args = parser.parse_args() with open(args.filename, 'rb') as f: - data: List[TMeasurement] = pickle.load(f) + data = pickle.load(f) + raw_results: List[TMeasurement] = data["results"] results = defaultdict(lambda: list()) - for v in data: + for v in raw_results: result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) if result is not None: KN = result.group(1) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 25ec9d6028627..51f24f3ba1774 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -40,4 +40,10 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], + "meta-llama/Llama-3.1-405b-hf": [ + ([16384, 18432], 1), + ([16384, 16384], 0), + ([16384, 106496], 1), + ([53248, 16384], 0), + ], } diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 1842fab8b2cac..f61fe3ceb978a 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { // is the layout f(x) = x template CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { return true; - else { + } else { constexpr auto coalesced_layout = coalesce(Layout{}); if constexpr (rank(coalesced_layout) == 1 && stride<0>(coalesced_layout) == 1) { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp similarity index 99% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp index d407d66ab2aa6..7aa87feb4cce2 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -52,6 +52,7 @@ // clang-format off #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cute/tensor.hpp" namespace cutlass::epilogue::threadblock { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 0000000000000..c69e87999ae71 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,317 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c2x \ No newline at end of file diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 0000000000000..95764ecddc79f --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,315 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 4fcfcd311aa91..a5beea1a35e49 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -35,6 +35,35 @@ class MixedInputKernelScheduleType(enum.Enum): } } +VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { + **DataTypeSize, # type: ignore + **{ + VLLMDataType.u4b8: 4, + VLLMDataType.u8b128: 8, + } +} + +VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + VLLMDataType.u4b8: "vllm::kU4B8", + VLLMDataType.u8b128: "vllm::kU8B128", + DataType.u4: "vllm::kU4", + DataType.u8: "vllm::kU8", + DataType.s4: "vllm::kS4", + DataType.s8: "vllm::kS8", + DataType.f16: "vllm::kFloat16", + DataType.bf16: "vllm::kBfloat16", +} + +VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + DataType.u8: "at::ScalarType::Byte", + DataType.s8: "at::ScalarType::Char", + DataType.e4m3: "at::ScalarType::Float8_e4m3fn", + DataType.s32: "at::ScalarType::Int", + DataType.f16: "at::ScalarType::Half", + DataType.bf16: "at::ScalarType::BFloat16", + DataType.f32: "at::ScalarType::Float", +} + VLLMKernelScheduleTag: Dict[Union[ MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 2ad914f8e9868..90f226cf64c0a 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -3,6 +3,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass_extensions/vllm_custom_types.cuh" #include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_type_utils.cuh" // this file extends: // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h @@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter { CUTLASS_DEVICE static result_type convert(source_type const& source) { - CUTE_INVALID_CONTROL_PATH( - "InterleavedNumericArrayConverter not implemented\n"); + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } return {}; } @@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter< result_type operator()(source_type const& s) const { return convert(s); } }; -// TODO (LucasWilkinson): Implement -// for Array <= Array - -// .... - template struct ArrayConverterPacked32Bit { using result_type = Array; @@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit { using ScalarConverter = NumericConverter; template - CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { if constexpr (sizeof(PackedSrc) == 1) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; } else if constexpr (sizeof(PackedSrc) == 2) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; } else { - static_assert(sizeof(PackedSrc) == 4); - return reinterpret_cast(source); + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); } } @@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit { static_assert(std::is_same_v); static_assert(std::is_same_v); - return RegConvert32bit::template convert(to_reg(source)); + return RegConvert32bit::template convert(to_regs(source)); } friend class detail::VectorizedConverter; @@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit { } }; +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + // for Array <= Array template struct NumericArrayConverter { @@ -148,7 +282,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -417,7 +554,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; PackedResultType r; // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of @@ -513,7 +652,8 @@ struct NumericArrayConverter { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; // Hold output BF16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -788,6 +930,61 @@ struct NumericArrayConverter { #endif +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 0000000000000..500ed508c8303 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index ee801e16573d4..dbb72e8bbd3f5 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -8,6 +8,10 @@ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" + +using namespace vllm; + /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). @@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales); } } @@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales); } } @@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { assert(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } } else { @@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_fp8_dispatch< - cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_fp8_dispatch( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales); } } @@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index 6329ff63623e2..d03242f44ab1d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,6 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "broadcast_load_epilogue_c2x.hpp" #include "common.hpp" // clang-format on @@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - template - using ColOrScalarLoad = - cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = - cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using RowOrZeroLoad = - cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - // it would technically work but no use case as data_ptr is never nullptr - static_assert(!std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - static_assert(std::is_same_v>); - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : protected ScaledEpilogueBase { - protected: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 292c9e4b34e1c..33581a63d4c3d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -23,11 +23,12 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "common.hpp" // clang-format on using namespace cute; +using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for @@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - template - using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; - - // Don't want to support nullptr by default - template - using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // Don't want to support nullptr by default - template - using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - static_assert(!std::is_same_v> && - !std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - static_assert(std::is_same_v> || - std::is_same_v>); - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch.scaled_mm_. - - A and B may be both either int8 or fp8_e4m3. A can be - quantized per-tensor or per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == c.dtype(), "currently bias dtype must match output dtype ", c.dtype()); - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( c, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_epilogue(c, a, b, a_scales, - b_scales); + return cutlass_scaled_mm_sm90_epilogue( + c, a, b, a_scales, b_scales); } } @@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index d126af1849024..ac63afe79a255 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -3,8 +3,10 @@ import os import shutil from collections.abc import Iterable -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from copy import deepcopy +from dataclasses import dataclass, fields +from functools import reduce +from typing import Dict, List, Optional, Tuple, Union import jinja2 # yapf conflicts with isort for this block @@ -14,7 +16,10 @@ MixedInputKernelScheduleType, TileSchedulerTag, TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, VLLMDataTypeTag, + VLLMDataTypeNames, + VLLMDataTypeSize, VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, VLLMKernelScheduleTag) # yapf: enable @@ -27,49 +32,125 @@ #include "../machete_mm_launcher.cuh" namespace machete { -using GemmDispatcher_ = GemmDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -{% for s in schedules %}extern torch::Tensor -impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); -{% endfor %} -template <> -torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + +{% for impl_config in impl_configs %} +{% set type_sig = gen_type_sig(impl_config.types) -%} +{% for s in impl_config.schedules %} +extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs); +{%- endfor %} + +torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) { [[maybe_unused]] auto M = args.A.size(0); [[maybe_unused]] auto N = args.B.size(1); [[maybe_unused]] auto K = args.A.size(1); - if (!args.schedule) { - {%- for cond, s in heuristic %} + if (!args.maybe_schedule) { + {%- for cond, s in impl_config.heuristic %} {%if cond is not none%}if ({{cond}}) {%- else %}else {%- endif %} - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %} } - {% for s in schedules %} - if (*args.schedule == "{{ gen_sch_name(s) }}") { - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); - } - {% endfor %} + {%- for s in impl_config.schedules %} + if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}") + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args); + {%- endfor %} TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " - "schedule = ", *args.schedule); + "schedule = ", *args.maybe_schedule); } +{%- endfor %} + -template <> -std::vector GemmDispatcher_::supported_schedules() { - return { - {% for s in schedules -%} - "{{ gen_sch_name(s) }}"{{ ", - " if not loop.last }}{%- endfor %} - }; +static inline std::optional maybe_scalartype( + c10::optional const& t) { + if (!t) { + return std::nullopt; + } else { + return t->scalar_type(); + }; +} + +torch::Tensor mm_dispatch(MMArgs args) { + auto out_type = args.maybe_out_type.value_or(args.A.scalar_type()); + auto a_type = args.A.scalar_type(); + auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales); + auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros); + auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales); + auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set type_sig = gen_type_sig(t) -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!maybe_g_scales_type{%endif%} + && {%if t.b_group_zeropoint != void -%} + maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!maybe_g_zeros_type{%endif%} + && {%if t.b_channel_scale != void -%} + maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}} + {%- else %}!maybe_ch_scales_type{%endif%} + && {%if t.a_token_scale != void -%} + maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}} + {%- else %}!maybe_tok_scales_type{%endif%} + ) { + return mm_dispatch_{{type_sig}}(args); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "machete_mm(..) is not implemented for " + "a_type=", args.A.scalar_type(), + ", b_type=", args.b_type.str(), + ", out_type=", out_type, + ", with_group_scale_type=", maybe_g_scales_type + ? toString(*maybe_g_scales_type) : "None", + ", with_group_zeropoint_type=", maybe_g_zeros_type + ? toString(*maybe_g_zeros_type) : "None", + ", with_channel_scale_type=", maybe_ch_scales_type + ? toString(*maybe_ch_scales_type) : "None", + ", with_token_scale_type=", maybe_tok_scales_type + ? toString(*maybe_tok_scales_type) : "None", + "; implemented types are: \\n", + {%- for impl_config in impl_configs %} + {% set t = impl_config.types -%} + "\\t{{gen_type_option_name(t)}}\\n", + {%- endfor %} + ""); } +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args) { + auto out_type = args.maybe_out_type.value_or(args.a_type); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set schs = impl_config.schedules -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && args.a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!args.maybe_group_scales_type{%endif%} + && {%if t.b_group_zeropoint != void-%} + args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!args.maybe_group_zeros_type{%endif%} + ) { + return { + {%- for s in impl_config.schedules %} + "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %} + {%- endfor %} + }; + } + {%- endfor %} + + return {}; +}; + }; // namespace machete """ @@ -77,20 +158,10 @@ #include "../machete_mm_launcher.cuh" namespace machete { -template -using Kernel = MacheteKernelTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, - Config, with_C, with_scales, with_zeropoints>; - -{% for sch in schedules %} -{% set schedule_name = gen_sch_name(sch) -%} -struct sch_{{schedule_name}} { + +{% for sch in unique_schedules(impl_configs) %} +{% set sch_sig = gen_sch_sig(sch) -%} +struct sch_{{sch_sig}} { using TileShapeNM = Shape<{{ to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; using ClusterShape = Shape<{{ @@ -101,27 +172,34 @@ using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; }; - +{% endfor %} + +{% for impl_config in impl_configs %} +{% set t = impl_config.types -%} +{% set schs = impl_config.schedules -%} +{% set type_sig = gen_type_sig(t) -%} + +template +using Kernel_{{type_sig}} = MacheteKernelTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[t.b]}}, // ElementB + {{DataTypeTag[t.out]}}, // ElementD + {{DataTypeTag[t.accumulator]}}, // Accumulator + {{DataTypeTag[t.b_group_scale]}}, // GroupScaleT + {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT + {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT + {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Sch>; + +{% for sch in schs %} +{% set sch_sig = gen_sch_sig(sch) -%} torch::Tensor -impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { - bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), - with_zeropoints = args.zeros.has_value(); - - {% for s in specializations %} - if (with_C == {{s.with_C|lower}} - && with_zeropoints == {{s.with_zeropoints|lower}} - && with_scales == {{s.with_scales|lower}}) { - return run_impl>(args); - }{% endfor %} - - TORCH_CHECK_NOT_IMPLEMENTED( - false, "for the sake of compile times and binary size machete_mm(..) is " - " not implemented for with_C=", with_C, ", with_scales=", with_scales, - ", with_zeropoints=", with_zeropoints, - " (for {{type_name}}_sch_{{schedule_name}})"); +impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) { + return run_impl>(args); } -{% endfor %} +{%- endfor %} +{%- endfor %} }; // namespace machete """ @@ -130,26 +208,34 @@ #include "../machete_prepack_launcher.cuh" namespace machete { -using PrepackBDispatcher_ = PrepackBDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -using PrepackedLayoutB = PrepackedLayoutBTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - cutlass::layout::ColumnMajor, - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; - -template <> -torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { - return prepack_impl(B); + +torch::Tensor prepack_B_dispatch(PrepackBArgs args) { + auto convert_type = args.maybe_group_scales_type.value_or(args.a_type); + {%- for t in types %} + {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %} + if (args.a_type == {{TorchTypeTag[t.a]}} + && args.b_type.size_bits() == {{t.b_num_bits}} + && convert_type == {{TorchTypeTag[t.convert]}}) { + return prepack_impl< + PrepackedLayoutBTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[b_type]}}, // ElementB + {{DataTypeTag[t.convert]}}, // ElementConvert + {{DataTypeTag[t.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> + >(args.B); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "prepack_B_dispatch(..) is not implemented for " + "atype = ", args.a_type, + ", b_type = ", args.b_type.str(), + ", with_group_scales_type= ", args.maybe_group_scales_type ? + toString(*args.maybe_group_scales_type) : "None"); } + }; // namespace machete """ @@ -166,32 +252,34 @@ class ScheduleConfig: tile_scheduler: TileSchedulerType -@dataclass +@dataclass(frozen=True) class TypeConfig: - element_a: DataType - element_b: Union[DataType, VLLMDataType] - element_b_scale: DataType - element_b_zeropoint: DataType - element_d: DataType + a: DataType + b: Union[DataType, VLLMDataType] + b_group_scale: DataType + b_group_zeropoint: DataType + b_channel_scale: DataType + a_token_scale: DataType + out: DataType accumulator: DataType -@dataclass -class Specialization: - with_C: bool - with_zeropoints: bool - with_scales: bool +@dataclass(frozen=True) +class PrepackTypeConfig: + a: DataType + b_num_bits: int + convert: DataType + accumulator: DataType @dataclass class ImplConfig: - type_config: TypeConfig - schedule_configs: List[ScheduleConfig] - specializations: List[Specialization] + types: TypeConfig + schedules: List[ScheduleConfig] heuristic: List[Tuple[Optional[str], ScheduleConfig]] -def generate_schedule_name(schedule_config: ScheduleConfig) -> str: +def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) @@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str: f"_{epilogue_schedule}_{tile_scheduler}") -# mostly unique shorter schedule_name -def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: +# mostly unique shorter sch_sig +def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: kernel_terse_names_replace = { "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", "TmaWarpSpecializedCooperative_": "TmaCoop_", "StreamKScheduler": "streamK", } - schedule_name = generate_schedule_name(schedule_config) + sch_sig = generate_sch_sig(schedule_config) for orig, terse in kernel_terse_names_replace.items(): - schedule_name = schedule_name.replace(orig, terse) - return schedule_name + sch_sig = sch_sig.replace(orig, terse) + return sch_sig # unique type_name -def generate_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - element_d = VLLMDataTypeNames[kernel_type_config.element_d] - accumulator = VLLMDataTypeNames[kernel_type_config.accumulator] - element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale] - element_zeropoint = VLLMDataTypeNames[ - kernel_type_config.element_b_zeropoint] - - return (f"{element_a}{element_b}{element_d}" - f"{accumulator}{element_scale}{element_zeropoint}") - +def generate_type_signature(kernel_types: TypeConfig): + return str("".join([ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ])) -# non-unique shorter type_name -def generate_terse_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - return f"{element_a}{element_b}" +def generate_type_option_name(kernel_types: TypeConfig): + return ", ".join([ + f"{field.name.replace('b_', 'with_')+'_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ]) def is_power_of_two(n): @@ -263,13 +345,36 @@ def _to_cute_constant(value: int): return _to_cute_constant(value) +def unique_schedules(impl_configs: List[ImplConfig]): + return list( + set(sch for impl_config in impl_configs + for sch in impl_config.schedules)) + + +def unsigned_type_with_bitwidth(num_bits): + return { + 4: DataType.u4, + 8: DataType.u8, + 16: DataType.u16, + 32: DataType.u32, + 64: DataType.u64, + }[num_bits] + + template_globals = { + "void": DataType.void, "DataTypeTag": VLLMDataTypeTag, + "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag, + "TorchTypeTag": VLLMDataTypeTorchDataTypeTag, "KernelScheduleTag": VLLMKernelScheduleTag, "EpilogueScheduleTag": EpilogueScheduleTag, "TileSchedulerTag": TileSchedulerTag, "to_cute_constant": to_cute_constant, - "gen_sch_name": generate_terse_schedule_name, + "gen_sch_sig": generate_terse_sch_sig, + "gen_type_sig": generate_type_signature, + "unique_schedules": unique_schedules, + "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, + "gen_type_option_name": generate_type_option_name } @@ -284,42 +389,82 @@ def create_template(template_str): prepack_dispatch_template = create_template(PREPACK_TEMPLATE) -def create_sources(impl_config: ImplConfig, num_impl_files=1): +def create_sources(impl_configs: List[ImplConfig], num_impl_files=8): sources = [] - type_name = generate_type_signature(impl_config.type_config) - terse_type_name = generate_terse_type_signature(impl_config.type_config) - sources.append(( - f"machete_mm_{terse_type_name}", - mm_dispatch_template.render(type_name=type_name, - type_config=impl_config.type_config, - schedules=impl_config.schedule_configs, - heuristic=impl_config.heuristic), + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), )) + prepack_types = [] + for impl_config in impl_configs: + convert_type = impl_config.types.a \ + if impl_config.types.b_group_scale == DataType.void \ + else impl_config.types.b_group_scale + prepack_types.append( + PrepackTypeConfig( + a=impl_config.types.a, + b_num_bits=VLLMDataTypeSize[impl_config.types.b], + convert=convert_type, + accumulator=impl_config.types.accumulator, + )) + + def prepacked_type_key(prepack_type: PrepackTypeConfig): + # For now we we can just use the first accumulator type seen since + # the tensor core shapes/layouts don't vary based on accumulator + # type so we can generate less code this way + return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) + + unique_prepack_types = [] + prepack_types_seen = set() + for prepack_type in prepack_types: + key = prepacked_type_key(prepack_type) + if key not in prepack_types_seen: + unique_prepack_types.append(prepack_type) + prepack_types_seen.add(key) + sources.append(( - f"machete_prepack_{terse_type_name}", - prepack_dispatch_template.render( - type_name=type_name, - type_config=impl_config.type_config, - ), + "machete_prepack", + prepack_dispatch_template.render(types=unique_prepack_types, ), )) - num_schedules = len(impl_config.schedule_configs) - schedules_per_file = math.ceil(num_schedules / num_impl_files) - for part, i in enumerate(range(0, num_schedules, schedules_per_file)): - file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + # Split up impls across files + num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) + num_impls_per_file = math.ceil(num_impls / num_impl_files) + + files_impls: List[List[ImplConfig]] = [[]] + + curr_num_impls_assigned = 0 + curr_impl_in_file = 0 + curr_impl_configs = deepcopy(list(reversed(impl_configs))) + + while curr_num_impls_assigned < num_impls: + room_left_in_file = num_impls_per_file - curr_impl_in_file + if room_left_in_file == 0: + files_impls.append([]) + room_left_in_file = num_impls_per_file + curr_impl_in_file = 0 + + curr_ic = curr_impl_configs[-1] + if len(curr_ic.schedules) >= room_left_in_file: + # Break apart the current impl config + tmp_ic = deepcopy(curr_ic) + tmp_ic.schedules = curr_ic.schedules[:room_left_in_file] + curr_ic.schedules = curr_ic.schedules[room_left_in_file:] + files_impls[-1].append(tmp_ic) + else: + files_impls[-1].append(curr_ic) + curr_impl_configs.pop() + curr_num_impls_assigned += len(files_impls[-1][-1].schedules) + curr_impl_in_file += len(files_impls[-1][-1].schedules) + for part, file_impls in enumerate(files_impls): sources.append(( - f"machete_mm_{terse_type_name}_impl_part{part}", - mm_impl_template.render( - type_name=type_name, - type_config=impl_config.type_config, - schedules=file_schedules, - specializations=impl_config.specializations, - ), + f"machete_mm_impl_part{part+1}", + mm_impl_template.render(impl_configs=file_impls), )) + return sources @@ -328,187 +473,169 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedule_common_params = dict( + sch_common_params = dict( kernel_schedule=TmaMI, epilogue_schedule=TmaCoop, tile_scheduler=TileSchedulerType.StreamK, ) - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - default_heuristic = [ + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + default_tile_heuristic_config = { #### M = 257+ - ( - "M > 256 && K <= 16384 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 256", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + "M > 256": ((128, 256), (2, 1, 1)), #### M = 129-256 - ( - "M > 128 && K <= 4096 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128 && K <= 8192 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + "M > 128": ((128, 256), (2, 1, 1)), #### M = 65-128 - ( - "M > 64 && K <= 4069 && N <= 4069", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K <= 4069 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K >= 8192 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), #### M = 33-64 - ( - "M > 32 && K <= 6144 && N <= 6144", - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32 && K >= 16384 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), #### M = 17-32 - ( - "M > 16 && K <= 12288 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 16", - ScheduleConfig( - tile_shape_mn=(256, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), #### M = 1-16 - ( - "N >= 26624", - ScheduleConfig( - tile_shape_mn=(256, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - None, - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + default_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in default_tile_heuristic_config.items() ] - # Do not use schedules = list(set(...)) because we need to make sure - # the output list is deterministic; otherwise the generated kernel file - # will be non-deterministic and causes ccache miss. - schedules = [] - for _, schedule_config in default_heuristic: - if schedule_config not in schedules: - schedules.append(schedule_config) + def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): + # Do not use schedules = list(set(...)) because we need to make sure + # the output list is deterministic; otherwise the generated kernel file + # will be non-deterministic and causes ccache miss. + schedules = [] + for _, schedule_config in heuristic: + if schedule_config not in schedules: + schedules.append(schedule_config) + return schedules impl_configs = [] GPTQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for element_a in (DataType.f16, DataType.bf16)) - - GPTQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=False, with_scales=True) - ] + ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16)) impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(GPTQ_kernel_specializations), + ImplConfig(x[0], x[1], x[2]) + for x in zip(GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), itertools.repeat(default_heuristic)) ] AWQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=a, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (DataType.u4, DataType.u8) - for element_a in (DataType.f16, DataType.bf16)) + ) for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16)) + + impl_configs += [ + ImplConfig(x[0], x[1], x[2]) + for x in zip(AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic)) + ] - AWQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=True, with_scales=True) + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # TODO (LucasWilkinson): Further tuning required + qqq_tile_heuristic_config = { + #### M = 257+ + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # "M > 256": ((128, 256), (2, 1, 1)), + "M > 256": ((128, 128), (2, 1, 1)), + #### M = 129-256 + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 128": ((128, 256), (2, 1, 1)), + "M > 128": ((128, 128), (2, 1, 1)), + #### M = 65-128 + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), + #### M = 33-64 + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), + #### M = 17-32 + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), + #### M = 1-16 + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + qqq_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in qqq_tile_heuristic_config.items() + ] + + QQQ_kernel_types = [ + *(TypeConfig( + a=DataType.s8, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.s32, + ) for b_group_scale in (DataType.f16, DataType.void)), + *(TypeConfig( + a=DataType.e4m3, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.f32, + ) for b_group_scale in (DataType.f16, DataType.void)), ] impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(AWQ_kernel_specializations), - itertools.repeat(default_heuristic)) + ImplConfig(x[0], x[1], x[2]) + for x in zip(QQQ_kernel_types, + itertools.repeat(get_unique_schedules(qqq_heuristic)), + itertools.repeat(qqq_heuristic)) ] output_dir = os.path.join(SCRIPT_DIR, "generated") @@ -521,12 +648,11 @@ def generate(): os.makedirs(output_dir) # Render each group of configurations into separate files - for impl_config in impl_configs: - for filename, code in create_sources(impl_config): - filepath = os.path.join(output_dir, f"{filename}.cu") - with open(filepath, "w") as output_file: - output_file.write(code) - print(f"Rendered template to {filepath}") + for filename, code in create_sources(impl_configs): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") if __name__ == "__main__": diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index e8e7b14de0da1..816f33a1078e5 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -171,6 +171,10 @@ struct MacheteCollectiveMma { make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + using SmemLayoutAtomARowMajor = decltype(rs_smem_selector(TileShape_MNK{})), @@ -288,14 +292,7 @@ struct MacheteCollectiveMma { static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutACopy = decltype(tile_to_shape( - SmemLayoutAtomARowMajor{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), - Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - + // Tile along modes in a way that maximizes the TMA box size using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), @@ -428,12 +425,12 @@ struct MacheteCollectiveMma { // clang-format on // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0))))); using ATensor = decltype(make_tensor( get_logical_ptr(static_cast(nullptr)), - shape(GmemLayoutA::TVbNbKL_to_offset( + shape(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0)))), PrepackedStrideA{})); @@ -450,8 +447,8 @@ struct MacheteCollectiveMma { static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { return make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), - shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}), + shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any } @@ -584,7 +581,7 @@ struct MacheteCollectiveMma { typename Params::TMA_Scale tma_load_scale; typename Params::TMA_Zero tma_load_zero; - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); tma_load_a = make_tma_copy_A( make_logical_tensor(ptr_A, shape(layout), stride(layout))); @@ -722,7 +719,7 @@ struct MacheteCollectiveMma { // (TILE_V,TILE_B,m,k,l) auto make_gA_mkl = [&]() { // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); return local_tile(mA_mkl, make_shape(size<0>(layout), PPBlocksPerTile_MK{}), diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 4d41b8d291484..d4d19ae5deec7 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -21,6 +21,8 @@ #include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/torch_utils.hpp" #include "machete_collective_builder.cuh" #include "machete_prepacked_layout.cuh" #include "machete_interleaving_utils.cuh" @@ -37,27 +39,42 @@ using namespace cute; // W is quantized, in this situation or right-hand operand is quantized so // we compute the transpose to move it to the left-hand side. template + typename AccumulatorT, typename GroupScaleT, typename GroupZeroT, + typename ChannelScaleT, typename TokenScaleT, class KernelSchedule, + typename ScheduleConfig> struct MacheteKernelTemplate { + static constexpr bool with_C = false; // not ever used + static constexpr bool with_group_scales = !std::is_same_v; + static constexpr bool with_group_zeropoints = + !std::is_same_v; + static constexpr bool with_channel_scales = + !std::is_same_v; + static constexpr bool with_token_scales = !std::is_same_v; + using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; using ElementD = ElementD_; using ElementC = cute::conditional_t; - using ElementZ = ZeroT; - using ElementS = ScaleT; - - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementCompute = AccumulatorT; // For Epilogue + // Use dummy values when we don't have scales or zeropoints + using ElementZGroup = + cute::conditional_t; + using ElementSGroup = + cute::conditional_t; + using ElementConvertGroup = + cute::conditional_t; + using ElementSChannel = + cute::conditional_t; + using ElementSToken = + cute::conditional_t; using BTypeTuple = cute::conditional_t< - with_scales, - cute::conditional_t, - cute::tuple>, + with_group_scales, + cute::conditional_t, + cute::tuple>, ElementB>; using LayoutA = cutlass::layout::RowMajor; @@ -71,8 +88,8 @@ struct MacheteKernelTemplate { using StrideA = cutlass::detail::TagToStrideA_t; using StrideC = cutlass::detail::TagToStrideA_t; using StrideD = cutlass::detail::TagToStrideA_t; - using StrideS = cutlass::detail::TagToStrideA_t; - using StrideZ = StrideS; + using StrideSGroup = cutlass::detail::TagToStrideA_t; + using StrideZGroup = StrideSGroup; using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; @@ -85,8 +102,8 @@ struct MacheteKernelTemplate { using OperatorClass = cutlass::arch::OpClassTensorOp; using PrepackedLayoutB = - PrepackedLayoutBTemplate; + PrepackedLayoutBTemplate; static int constexpr TileShapeK = 128 * 8 / cutlass::sizeof_bits::value; @@ -103,12 +120,42 @@ struct MacheteKernelTemplate { using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; using TileScheduler = typename ScheduleConfig::TileScheduler; + static_assert( + (!with_channel_scales && !with_token_scales) || + ((with_channel_scales && with_token_scales) && + std::is_same_v), + "Currently token and channel scales (if present) must be the same type"); + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + // Currently only supports float scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + static_assert((with_channel_scales || with_token_scales) || + (std::is_same_v && + std::is_same_v), + "Currently token and channel scales (if present) must be float " + "(and if one is present the other must be too)"); + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using EVTCompute = + std::conditional_t; + + // EVTCompute using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, - AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, - EpilogueSchedule>::CollectiveOp; + ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::VLLMCollectiveBuilder< @@ -131,26 +178,44 @@ struct MacheteKernelTemplate { using MainloopArguments = typename GemmKernel::MainloopArguments; using EpilogueArguments = typename GemmKernel::EpilogueArguments; - template static Arguments create_arguments( cudaStream_t stream, - ElementA const* A_ptr, // A is an MxK matrix - Layout const& layout_A, - ElementB const* B_ptr, // B is an KxN prepacked matrix - ElementD* D_ptr, // D is an MxN matrix - Layout const& layout_D, - ElementC const* C_ptr, // C is an MxN matrix - std::optional> const& layout_C, - ElementS const* S_ptr, // S is an scale_KxN matrix - std::optional> const& layout_S, - ElementZ const* Z_ptr, // Z is an scale_KxN matrix - std::optional> const& layout_Z, - ElementCompute alpha, ElementCompute beta, - std::optional maybe_group_size) { - static_assert(!with_zeropoints || with_scales); - - int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + torch::Tensor const& A, // MxK matrix + torch::Tensor const& B, // KxN prepacked matrix + torch::Tensor& D, // MxN matrix + c10::optional const& maybe_g_scales, // scale_KxN matrix + c10::optional const& maybe_g_zeros, // scale_KxN matrix + c10::optional maybe_group_size, + c10::optional const& maybe_ch_scales, // len N vector + c10::optional const& maybe_tok_scales) // len M vector + { + static_assert(!with_group_zeropoints || with_group_scales); + + int M = A.size(0), N = B.size(1), K = A.size(1); + TORCH_CHECK(D.size(0) == M && D.size(1) == N); + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_S_group = + maybe_make_cute_layout(maybe_g_scales, "group_scales"); + auto layout_Z_group = + maybe_make_cute_layout(maybe_g_zeros, "group_zeros"); + int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0; + int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0; + + auto unwrap = [](auto const& t) { + return t ? t->const_data_ptr() : nullptr; + }; + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto S_group_ptr = + static_cast(unwrap(maybe_g_scales)); + auto Z_group_ptr = static_cast(unwrap(maybe_g_zeros)); + auto S_channel_ptr = + static_cast(unwrap(maybe_ch_scales)); + auto S_token_ptr = + static_cast(unwrap(maybe_tok_scales)); int const group_size = maybe_group_size == -1 ? K : maybe_group_size.value_or(K); @@ -159,26 +224,28 @@ struct MacheteKernelTemplate { TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); - if constexpr (with_C) { - TORCH_CHECK(C_ptr && layout_C); + if constexpr (with_group_scales) { + TORCH_CHECK(S_group_ptr && layout_S_group); + TORCH_CHECK((size<0>(*layout_S_group) == scale_k && + size<1>(*layout_S_group) == N)); } else { - TORCH_CHECK(!C_ptr, "C not supported"); + TORCH_CHECK(!S_group_ptr, "Scales not supported"); } - if constexpr (with_scales) { - TORCH_CHECK(S_ptr && layout_S); - TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + if constexpr (with_group_zeropoints) { + TORCH_CHECK(Z_group_ptr && layout_Z_group); + TORCH_CHECK((size<0>(*layout_Z_group) == scale_k && + size<1>(*layout_Z_group) == N)); + TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group, + "Scales and zeros must have the same layout"); } else { - TORCH_CHECK(!S_ptr, "Scales not supported"); + TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported"); } - if constexpr (with_zeropoints) { - TORCH_CHECK(Z_ptr && layout_Z); - TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); - TORCH_CHECK(layout_S && *layout_Z == *layout_S, - "Scales and zeros must have the same layout"); - } else { - TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + if constexpr (with_channel_scales || with_token_scales) { + TORCH_CHECK( + (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) && + (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1)); } // Transpose A and D @@ -186,24 +253,33 @@ struct MacheteKernelTemplate { // for B (which is At) auto stride_At = layout_A.stride(); auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); - auto stride_Ct = stride_Dt; - if (layout_C) { - stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); - } MainloopArguments mainloop_arguments{}; - EpilogueArguments epilogue_arguments{ - {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + // {Accum, C, C_layout, D, D} + EpilogueArguments epilogue_arguments{}; + + if constexpr (with_channel_scales || with_token_scales) { + epilogue_arguments = + EpilogueArguments{ChTokScalesEpilogue::prepare_args( + *maybe_ch_scales, *maybe_tok_scales), + nullptr, + {}, + D_ptr, + stride_Dt}; + } else { + epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt}; + } - if constexpr (with_scales && with_zeropoints) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); - mainloop_arguments = - MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, - S_ptr, stride_S, group_size, Z_ptr}; - } else if constexpr (with_scales) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + if constexpr (with_group_scales && with_group_zeropoints) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); mainloop_arguments = MainloopArguments{ - B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size, Z_group_ptr}; + } else if constexpr (with_group_scales) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size}; } else { mainloop_arguments = MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index 60a4ed60535b7..4b0da5b303e0c 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -5,73 +5,61 @@ #include "machete_mm_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { -struct PyTorchArguments { +struct MMArgs { torch::Tensor const& A; torch::Tensor const& B; - c10::optional const& scales; - c10::optional const& zeros; - c10::optional group_size; - c10::optional const& C; - c10::optional alpha; - c10::optional beta; - c10::optional schedule; + vllm::ScalarType const& b_type; + c10::optional const& maybe_out_type; + c10::optional const& maybe_group_scales; + c10::optional const& maybe_group_zeros; + c10::optional maybe_group_size; + c10::optional const& maybe_channel_scales; + c10::optional const& maybe_token_scales; + c10::optional maybe_schedule; }; +struct SupportedSchedulesArgs { + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; + c10::optional maybe_group_zeros_type; + c10::optional maybe_channel_scales_type; + c10::optional maybe_token_scales_type; + c10::optional maybe_out_type; +}; + +torch::Tensor mm_dispatch(MMArgs args); + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args); + template -torch::Tensor run_impl(PyTorchArguments args) { +torch::Tensor run_impl(MMArgs args) { const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); auto device = args.A.device(); auto stream = at::cuda::getCurrentCUDAStream(device.index()); - using EleA = typename MacheteKernel::ElementA; - using EleB = typename MacheteKernel::ElementB; - using EleC = typename MacheteKernel::ElementC; - using EleD = typename MacheteKernel::ElementD; - using EleScale = typename MacheteKernel::ElementS; - using EleZero = typename MacheteKernel::ElementZ; - - using StrideA = typename MacheteKernel::StrideA; - using StrideC = typename MacheteKernel::StrideC; - using StrideD = typename MacheteKernel::StrideD; - using StrideS = typename MacheteKernel::StrideS; - using StrideZ = typename MacheteKernel::StrideZ; - int M = args.A.size(0); int N = args.B.size(1); int K = args.A.size(1); // Allocate output - torch::Tensor D = - torch::empty({M, N}, torch::TensorOptions() - .dtype(equivalent_scalar_type_v) - .device(device)); - - auto const &A = args.A, &B = args.B; - auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; - - auto layout_A = make_cute_layout(A, "A"); - auto layout_D = make_cute_layout(D, "D"); - auto layout_C = maybe_make_cute_layout(C, "C"); - auto layout_S = maybe_make_cute_layout(scales, "scales"); - auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); - - auto A_ptr = static_cast(A.const_data_ptr()); - auto B_ptr = static_cast(B.const_data_ptr()); - auto D_ptr = static_cast(D.mutable_data_ptr()); - auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); - auto S_ptr = - static_cast(scales ? scales->const_data_ptr() : nullptr); - auto Z_ptr = - static_cast(zeros ? zeros->const_data_ptr() : nullptr); + torch::Tensor D = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); auto arguments = MacheteKernel::create_arguments( - stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, - layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size); + stream, // + args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros, + args.maybe_group_size, args.maybe_channel_scales, + args.maybe_token_scales); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); @@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) { return D; }; -template -struct GemmDispatcher { - static torch::Tensor dispatch(PyTorchArguments args); - static std::vector supported_schedules(); -}; - }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh index f23483f928b47..d002355ca49d6 100644 --- a/csrc/quantization/machete/machete_prepack_kernel.cuh +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -6,31 +6,49 @@ namespace machete { -template -static __global__ void prepack_B_kernel(BInTensor B_in, - BTiledOutTensor B_tiled_out) { - auto tB_in = local_tile(B_in, TileShapeNKL{}, - make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); - auto tB_out = B_tiled_out(make_coord(_, _), - make_coord(blockIdx.x, blockIdx.y), blockIdx.z); +template +static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) { + auto constexpr block_size = + Int{}; + auto constexpr eles_per_thread = Int{}; + static_assert(block_size % threads == 0, + "block_size must be divisible by the number of threads"); - auto tiled_copy = make_tiled_copy(Copy_Atom{}, - Layout, Stride<_32, _1>>{}, - Layout>{}); + // Which pre-packed are we responsible for + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z); + auto tB_in = local_tile( + B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}), + blk_coord); - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + // Find the start offset in the output for this pre-packed block + auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in)); - Tensor thr_tile_S = thr_copy.partition_S(tB_in); - Tensor thr_tile_D = thr_copy.partition_D(tB_out); + // Tensor representing a 1:1 mapping to the output space in 1D + auto tB_out_linear = + make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord), + make_layout(make_shape(block_size))); + // Mapping from output space (1D) to input space + auto tB_in_linear = make_tensor( + tB_in.data(), + tB_in.layout() + .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset())) + .with_shape(make_shape(block_size))); + + // Tile for this specific thread (could have used a TiledCopy but these work + // best with 2d layouts, this is a simple 1d layout so local_tile is enough, + // we are also not that concerned with performance for this kernel) + auto thr_tB_in_linear = + local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x); + auto thr_tB_out_linear = + local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x); // Construct a register-backed Tensor with the same shape as each thread's // partition - auto fragment = make_tensor(shape(thr_tile_D)); + auto fragment = make_tensor(shape(thr_tB_in_linear)); - // Copy from GMEM to RMEM and from RMEM to GMEM - copy(tiled_copy, thr_tile_S, fragment); - copy(Copy_Atom{}, fragment, thr_tile_D); + copy(thr_tB_in_linear, fragment); + copy(Copy_Atom{}, fragment, thr_tB_out_linear); } template @@ -44,18 +62,15 @@ static void prepack_B_template( TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); - TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); - auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout); auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); - auto B_tiled_out = - make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); - prepack_B_kernel - <<>>(B_in, B_tiled_out); + prepack_B_kernel<128, PrepackedLayoutB> + <<>>(B_in, B_out_ptr); } }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index a33d8f9484cfe..3486d28be2126 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -2,9 +2,17 @@ #include "machete_prepack_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { +struct PrepackBArgs { + torch::Tensor const& B; + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; +}; + template torch::Tensor prepack_impl(torch::Tensor const B) { const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); @@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) { return D; }; -template -struct PrepackBDispatcher { - static torch::Tensor dispatch(torch::Tensor B); -}; +torch::Tensor prepack_B_dispatch(PrepackBArgs args); }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 78e2cc5eec7d8..680a858a893c1 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {}; // The contract here is that the `TiledMma` determined below matches the one // ultimately used in the kernel. (this is also why the other element types are // required along with the kernel schedule) -template // clang-format on @@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate { using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; - using ElementD = ElementD_; - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementMma = MmaType; - // Only use interleaved layouts for subbyte weights, prmt instructions makes - // non-interleaved layouts for 8bit+ weights efficient enough we don't need - // iterleaved layouts + // Interleave for 4bit bit types when we are not upconverting to fp8 or int8, + // in those cases case we use a LUT using prmt instructions to upconvert and + // is more efficient if the data is not interleaved For 8bit+ prmt + // instructions makes non-interleaved layouts efficient enough we don't need + // iterleaved layouts (and can reuse more of the existing cutlass converts) + static constexpr bool should_interleave = + sizeof_bits_v <= 4 && + !std::is_same_v && + !std::is_same_v; + + // Only use interleaved layouts for subbyte weights, using IlvdBlkLayout = std::conditional_t< std::is_same_v, - std::conditional_t <= 4, - decltype(get_interleaved_blk_layout< - ElementB, sizeof_bits_v, 32>()), - void>, + std::conditional_t< + should_interleave, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, IlvBlkLayout_>; // TODO (LucasWilkinson): compare the performance for other sizes @@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate { // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} auto frgV = get<1, 0>(layout_no_interleave); auto ilvdBlk = IlvdBlkLayout{}; - static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + static_assert(size(frgV) % size(ilvdBlk) == 0, + "FrgV must be divisible by size(ilvdBlk)"); auto ilvd_FrgV = make_layout( make_shape(shape(ilvdBlk), Int{}), make_stride(stride(ilvdBlk), size(ilvdBlk))); @@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy( + Shape_NKL shape_mkl) { + auto layout = TVbNbKL_to_offset(shape_mkl); + return make_layout(coalesce(get<0>(layout)), get<1>(layout), + get<2>(layout)); + } + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) template CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( @@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // (BlocksN, BlocksK, L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) { + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + auto stride = size(PPBlockShape_NK{}); + + // (BlocksN, BlocksK, L) -> (storage_idx) + return make_layout(blocks_shape, compact_col_major(blocks_shape, stride)); + } + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) template CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index 9f9073ded6191..da2c2fb0d3e77 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -8,89 +8,61 @@ namespace machete { using namespace vllm; -// -// Utils (type dispatching) -// - -template -static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { - if (type == vllm::kU4) { - return fn(cutlass::uint4b_t{}); - } else if (type == vllm::kU8) { - return fn(cutlass::uint8_t{}); - } else if (type == vllm::kU4B8) { - return fn(cutlass::vllm_uint4b8_t{}); - } else if (type == vllm::kU8B128) { - return fn(cutlass::vllm_uint8b128_t{}); - } else { - TORCH_CHECK(false, "Unsupported type ", type.str()); - } -} - -#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ - AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) - -#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, \ - AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) - -// -// Interface -// - -std::vector supported_schedules(ScalarTypeId const btype_id) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - vllm::ScalarType b_type = ScalarType::from_id(btype_id); - return scalar_type_dispatch(b_type, [&](auto BType) { - return GemmDispatcher::supported_schedules(); +std::vector supported_schedules( + at::ScalarType a_type, int64_t b_type_id, + c10::optional maybe_group_scales_type, + c10::optional maybe_group_zeros_type, + c10::optional maybe_channel_scales_type, + c10::optional maybe_token_scales_type, + c10::optional maybe_out_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return supported_schedules_dispatch({ + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type, + .maybe_group_zeros_type = maybe_group_zeros_type, + .maybe_channel_scales_type = maybe_channel_scales_type, + .maybe_token_scales_type = maybe_token_scales_type, + .maybe_out_type = maybe_out_type, }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif } -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - ScalarTypeId const btype_id, - c10::optional const& scales, - c10::optional const& zeros, - c10::optional group_size, - c10::optional const& C, - c10::optional alpha, c10::optional beta, - c10::optional schedule) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - ScalarType const btype = ScalarType::from_id(btype_id); - auto args = PyTorchArguments{.A = A, - .B = B, - .scales = scales, - .zeros = zeros, - .group_size = group_size, - .C = C, - .alpha = alpha, - .beta = beta, - .schedule = schedule}; - - return scalar_type_dispatch(btype, [&](auto BType) { - return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( - A.scalar_type(), "machete_gemm", [&] { - using ComputeType = equivalent_cutlass_type_t; - return GemmDispatcher::dispatch(args); - }); - }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + c10::optional const& maybe_out_type, + c10::optional const& maybe_group_scales, + c10::optional const& maybe_group_zeros, + c10::optional maybe_group_size, + c10::optional const& maybe_channel_scales, + c10::optional const& maybe_token_scales, + c10::optional maybe_schedule) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return mm_dispatch({.A = A, + .B = B, + .b_type = b_type, + .maybe_out_type = maybe_out_type, + .maybe_group_scales = maybe_group_scales, + .maybe_group_zeros = maybe_group_zeros, + .maybe_group_size = maybe_group_size, + .maybe_channel_scales = maybe_channel_scales, + .maybe_token_scales = maybe_token_scales, + .maybe_schedule = maybe_schedule}); } -torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) { - ScalarType const btype = ScalarType::from_id(btype_id); - return scalar_type_dispatch(btype, [&](auto BType) { - return PrepackBDispatcher::dispatch(B); - }); +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + c10::optional const& maybe_group_scales_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return prepack_B_dispatch( + {.B = B, + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type}); } TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("machete_prepack_B", &prepack_B); - m.impl("machete_gemm", &gemm); + m.impl("machete_mm", &mm); } // use CatchAll since supported_schedules has no tensor arguments diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 229fd554d3eee..e4cc7ec951848 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. - ops.def("machete_supported_schedules(int btype) -> str[]"); ops.def( - "machete_gemm(Tensor A, Tensor B, int btype, " - " Tensor? scales, Tensor? zeros, int? group_size, " - " Tensor? C, float? alpha, float? beta, str? schedule)" - "-> Tensor"); - ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor"); + "machete_supported_schedules(" + " ScalarType a_type," + " int b_type," + " ScalarType? maybe_group_scales_type," + " ScalarType? maybe_group_zeros_type," + " ScalarType? maybe_channel_scales_type," + " ScalarType? maybe_token_scales_type," + " ScalarType? maybe_out_type" + ") -> str[]"); + ops.def( + "machete_mm(" + " Tensor A," + " Tensor B," + " int b_type," + " ScalarType? out_type," + " Tensor? group_scales," + " Tensor? group_zeros," + " int? group_size," + " Tensor? channel_scales," + " Tensor? token_scales," + " str? schedule" + ") -> Tensor"); + ops.def( + "machete_prepack_B(" + " Tensor B," + " ScalarType a_type," + " int b_type," + " ScalarType? group_scales_type" + ") -> Tensor"); // conditionally compiled so impl registration is in source file ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); diff --git a/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png b/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png new file mode 100644 index 0000000000000..bbf46286cfe5d Binary files /dev/null and b/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png differ diff --git a/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png b/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png new file mode 100644 index 0000000000000..ade1d602a9187 Binary files /dev/null and b/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png differ diff --git a/docs/source/design/arch_overview.rst b/docs/source/design/arch_overview.rst new file mode 100644 index 0000000000000..a9e7b4bd69bc7 --- /dev/null +++ b/docs/source/design/arch_overview.rst @@ -0,0 +1,274 @@ +.. _arch_overview: + +Architecture Overview +====================== + +This document provides an overview of the vLLM architecture. + +.. contents:: Table of Contents + :local: + :depth: 2 + +Entrypoints +----------- + +vLLM provides a number of entrypoints for interacting with the system. The +following diagram shows the relationship between them. + +.. image:: /assets/design/arch_overview/entrypoints.excalidraw.png + :alt: Entrypoints Diagram + +LLM Class +^^^^^^^^^ + +The LLM class provides the primary Python interface for doing offline inference, +which is interacting with a model without using a separate model inference +server. + +Here is a sample of `LLM` class usage: + +.. code-block:: python + + from vllm import LLM, SamplingParams + + # Define a list of input prompts + prompts = [ + "Hello, my name is", + "The capital of France is", + "The largest ocean is", + ] + + # Define sampling parameters + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Initialize the LLM engine with the OPT-125M model + llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct") + + # Generate outputs for the input prompts + outputs = llm.generate(prompts, sampling_params) + + # Print the generated outputs + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +More API details can be found in the :doc:`Offline Inference +` section of the API docs. + +The code for the `LLM` class can be found in `vllm/entrypoints/llm.py +`_. + +OpenAI-compatible API server +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The second primary interface to vLLM is via its OpenAI-compatible API server. +This server can be started using the `vllm serve` command. + +.. code-block:: bash + + vllm serve + +The code for the `vllm` CLI can be found in `vllm/scripts.py +`_. + +Sometimes you may see the API server entrypoint used directly instead of via the +`vllm` CLI command. For example: + +.. code-block:: bash + + python -m vllm.entrypoints.openai.api_server --model + +That code can be found in `vllm/entrypoints/openai/api_server.py +`_. + +More details on the API server can be found in the :doc:`OpenAI Compatible +Server ` document. + +LLM Engine +---------- + +The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of +the vLLM system, handling model inference and asynchronous request processing. + +.. image:: /assets/design/arch_overview/llm_engine.excalidraw.png + :alt: LLMEngine Diagram + +LLMEngine +^^^^^^^^^ + +The `LLMEngine` class is the core component of the vLLM engine. It is +responsible for receiving requests from clients and generating outputs from the +model. The `LLMEngine` includes input processing, model execution (possibly +distributed across multiple hosts and/or GPUs), scheduling, and output +processing. + +- **Input Processing**: Handles tokenization of input text using the specified + tokenizer. + +- **Scheduling**: Chooses which requests are processed in each step. + +- **Model Execution**: Manages the execution of the language model, including + distributed execution across multiple GPUs. + +- **Output Processing**: Processes the outputs generated by the model, decoding the + token IDs from a language model into human-readable text. + +The code for `LLMEngine` can be found in `vllm/engine/llm_engine.py`_. + +.. _vllm/engine/llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/llm_engine.py + +AsyncLLMEngine +^^^^^^^^^^^^^^ + +The `AsyncLLMEngine` class is an asynchronous wrapper for the `LLMEngine` class. +It uses `asyncio` to create a background loop that continuously processes +incoming requests. The `AsyncLLMEngine` is designed for online serving, where it +can handle multiple concurrent requests and stream outputs to clients. + +The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo +API server that serves as a simpler example in +`vllm/entrypoints/api_server.py`_. + +.. _vllm/entrypoints/api_server.py: https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/api_server.py + +The code for `AsyncLLMEngine` can be found in `vllm/engine/async_llm_engine.py`_. + +.. _vllm/engine/async_llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/async_llm_engine.py + +Worker +------ + +A worker is a process that runs the model inference. vLLM follows the common +practice of using one process to control one accelerator device, such as GPUs. +For example, if we use tensor parallelism of size 2 and pipeline parallelism of +size 2, we will have 4 workers in total. Workers are identified by their +``rank`` and ``local_rank``. ``rank`` is used for global orchestration, while +``local_rank`` is mainly used for assigning the accelerator device and accessing +local resources such as the file system and shared memory. + +Model Runner +------------ + +Every worker has one model runner object, responsible for loading and running +the model. Much of the model execution logic resides here, such as preparing +input tensors and capturing cudagraphs. + +Model +----- + +Every model runner object has one model object, which is the actual +``torch.nn.Module`` instance. See :ref:`huggingface_integration` for how various +configurations affect the class we ultimately get. + +Class Hierarchy +--------------- + +The following figure shows the class hierarchy of vLLM: + + .. figure:: /assets/design/hierarchy.png + :alt: query + :width: 100% + :align: center + +There are several important design choices behind this class hierarchy: + +1. **Extensibility**: All classes in the hierarchy accept a configuration object +containing all the necessary information. The `VllmConfig +`__ +class is the main configuration object that is passed around. The class +hierarchy is quite deep, and every class needs to read the configuration it is +interested in. By encapsulating all configurations in one object, we can easily +pass the configuration object around and access the configuration we need. +Suppose we want to add a new feature (this is often the case given how fast the +field of LLM inference is evolving) that only touches the model runner. We will +have to add a new configuration option in the `VllmConfig` class. Since we pass +the whole config object around, we only need to add the configuration option to +the `VllmConfig` class, and the model runner can access it directly. We don't +need to change the constructor of the engine, worker, or model class to pass the +new configuration option. + +2. **Uniformity**: The model runner needs a unified interface to create and +initialize the model. vLLM supports more than 50 types of popular open-source +models. Each model has its own initialization logic. If the constructor +signature varies with models, the model runner does not know how to call the +constructor accordingly, without complicated and error-prone inspection logic. +By making the constructor of the model class uniform, the model runner can +easily create and initialize the model without knowing the specific model type. +This is also useful for composing models. Vision-language models often consist +of a vision model and a language model. By making the constructor uniform, we +can easily create a vision model and a language model and compose them into a +vision-language model. + +.. note:: + + To support this change, all vLLM models' signatures have been updated to: + + .. code-block:: python + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: + + .. code-block:: python + + class MyOldModel(nn.Module): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + ... + + from vllm.config import VllmConfig + class MyNewModel(MyOldModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + super().__init__(config, cache_config, quant_config, lora_config, prefix) + + if __version__ >= "0.6.4": + MyModel = MyNewModel + else: + MyModel = MyOldModel + + This way, the model can work with both old and new versions of vLLM. + +3. **Sharding and Quantization at Initialization**: Certain features require +changing the model weights. For example, tensor parallelism needs to shard the +model weights, and quantization needs to quantize the model weights. There are +two possible ways to implement this feature. One way is to change the model +weights after the model is initialized. The other way is to change the model +weights during the model initialization. vLLM chooses the latter. The first +approach is not scalable to large models. Suppose we want to run a 405B model +(with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should +only load 50GB weights. If we change the model weights after the model is +initialized, we need to load the full 810GB weights to every GPU and then shard +the weights, leading to a huge memory overhead. Instead, if we shard the weights +during the model initialization, every layer will only create a shard of the +weights it needs, leading to a much smaller memory overhead. The same idea +applies to quantization. Note that we also add an additional argument ``prefix`` +to the model's constructor so that the model can initialize itself differently +based on the prefix. This is useful for non-uniform quantization, where +different parts of the model are quantized differently. The ``prefix`` is +usually an empty string for the top-level model and a string like ``"vision"`` +or ``"language"`` for the sub-models. In general, it matches the name of the +module's state dict in the checkpoint file. + +One disadvantage of this design is that it is hard to write unit tests for +individual components in vLLM because every component needs to be initialized by +a complete config object. We solve this problem by providing a default +initialization function that creates a default config object with all fields set +to ``None``. If the component we want to test only cares about a few fields in +the config object, we can create a default config object and set the fields we +care about. This way, we can test the component in isolation. Note that many +tests in vLLM are end-to-end tests that test the whole system, so this is not a +big problem. + +In summary, the complete config object ``VllmConfig`` can be treated as an +engine-level global state that is shared among all vLLM classes. diff --git a/docs/source/design/class_hierarchy.rst b/docs/source/design/class_hierarchy.rst deleted file mode 100644 index 58a888b17ba53..0000000000000 --- a/docs/source/design/class_hierarchy.rst +++ /dev/null @@ -1,74 +0,0 @@ -.. _class_hierarchy: - -vLLM's Class Hierarchy -======================= - -This document describes the class hierarchy of vLLM. We will explain the relationships between the core classes, their responsibilities, and the design choices behind them to make vLLM more modular and extensible. - -1. **Entrypoints**: vLLM has two entrypoints: `command line usage `__ with ``vllm serve`` for launching an OpenAI-API compatible server, and `library-style usage `__ with the ``vllm.LLM`` class for running inference in a Python script. These are user-facing entrypoints that end-users interact with. Under the hood, both create an engine object to handle model inference. - -2. **Engine**: Each vLLM instance contains one engine object, orchestrating and serving as the control plane for model inference. Depending on the configuration, the engine can create multiple workers to handle the inference workload. - -3. **Worker**: A worker is a process that runs the model inference. vLLM follows the common practice of using one process to control one accelerator device, such as GPUs. For example, if we use tensor parallelism of size 2 and pipeline parallelism of size 2, we will have 4 workers in total. Workers are identified by their ``rank`` and ``local_rank``. ``rank`` is used for global orchestration, while ``local_rank`` is mainly used for assigning the accelerator device and accessing local resources such as the file system and shared memory. - -4. **Model Runner**: Every worker has one model runner object, responsible for loading and running the model. Much of the model execution logic resides here, such as preparing input tensors and capturing cudagraphs. - -5. **Model**: Every model runner object has one model object, which is the actual ``torch.nn.Module`` instance. See :ref:`huggingface_integration` for how various configurations affect the class we ultimately get. - -The following figure shows the class hierarchy of vLLM: - - .. figure:: ../assets/design/hierarchy.png - :alt: query - :width: 100% - :align: center - -There are several important design choices behind this class hierarchy: - -1. **Extensibility**: All classes in the hierarchy accept a configuration object containing all the necessary information. The `VllmConfig `__ class is the main configuration object that is passed around. The class hierarchy is quite deep, and every class needs to read the configuration it is interested in. By encapsulating all configurations in one object, we can easily pass the configuration object around and access the configuration we need. Suppose we want to add a new feature (this is often the case given how fast the field of LLM inference is evolving) that only touches the model runner. We will have to add a new configuration option in the `VllmConfig` class. Since we pass the whole config object around, we only need to add the configuration option to the `VllmConfig` class, and the model runner can access it directly. We don't need to change the constructor of the engine, worker, or model class to pass the new configuration option. - -2. **Uniformity**: The model runner needs a unified interface to create and initialize the model. vLLM supports more than 50 types of popular open-source models. Each model has its own initialization logic. If the constructor signature varies with models, the model runner does not know how to call the constructor accordingly, without complicated and error-prone inspection logic. By making the constructor of the model class uniform, the model runner can easily create and initialize the model without knowing the specific model type. This is also useful for composing models. Vision-language models often consist of a vision model and a language model. By making the constructor uniform, we can easily create a vision model and a language model and compose them into a vision-language model. - -.. note:: - - To support this change, all vLLM models' signatures have been updated to: - - .. code-block:: python - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - - To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: - - .. code-block:: python - - class MyOldModel(nn.Module): - def __init__( - self, - config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: - ... - - from vllm.config import VllmConfig - class MyNewModel(MyOldModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - super().__init__(config, cache_config, quant_config, lora_config, prefix) - - if __version__ >= "0.6.4": - MyModel = MyNewModel - else: - MyModel = MyOldModel - - This way, the model can work with both old and new versions of vLLM. - -3. **Sharding and Quantization at Initialization**: Certain features require changing the model weights. For example, tensor parallelism needs to shard the model weights, and quantization needs to quantize the model weights. There are two possible ways to implement this feature. One way is to change the model weights after the model is initialized. The other way is to change the model weights during the model initialization. vLLM chooses the latter. The first approach is not scalable to large models. Suppose we want to run a 405B model (with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should only load 50GB weights. If we change the model weights after the model is initialized, we need to load the full 810GB weights to every GPU and then shard the weights, leading to a huge memory overhead. Instead, if we shard the weights during the model initialization, every layer will only create a shard of the weights it needs, leading to a much smaller memory overhead. The same idea applies to quantization. Note that we also add an additional argument ``prefix`` to the model's constructor so that the model can initialize itself differently based on the prefix. This is useful for non-uniform quantization, where different parts of the model are quantized differently. The ``prefix`` is usually an empty string for the top-level model and a string like ``"vision"`` or ``"language"`` for the sub-models. In general, it matches the name of the module's state dict in the checkpoint file. - -One disadvantage of this design is that it is hard to write unit tests for individual components in vLLM because every component needs to be initialized by a complete config object. We solve this problem by providing a default initialization function that creates a default config object with all fields set to ``None``. If the component we want to test only cares about a few fields in the config object, we can create a default config object and set the fields we care about. This way, we can test the component in isolation. Note that many tests in vLLM are end-to-end tests that test the whole system, so this is not a big problem. - -In summary, the complete config object ``VllmConfig`` can be treated as an engine-level global state that is shared among all vLLM classes. diff --git a/docs/source/design/plugin_system.rst b/docs/source/design/plugin_system.rst index bfca702b9267a..5a96cc8b3a464 100644 --- a/docs/source/design/plugin_system.rst +++ b/docs/source/design/plugin_system.rst @@ -8,7 +8,7 @@ The community frequently requests the ability to extend vLLM with custom feature How Plugins Work in vLLM ------------------------ -Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`class_hierarchy`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. +Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`arch_overview`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. How vLLM Discovers Plugins -------------------------- @@ -59,4 +59,4 @@ Guidelines for Writing Plugins Compatibility Guarantee ----------------------- -vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. \ No newline at end of file +vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. diff --git a/docs/source/index.rst b/docs/source/index.rst index 3b2698a8845ed..c2afd806c50f9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -101,6 +101,7 @@ Documentation models/engine_args models/lora models/vlm + models/structured_outputs models/spec_decode models/performance @@ -156,7 +157,7 @@ Documentation :maxdepth: 2 :caption: Design - design/class_hierarchy + design/arch_overview design/huggingface_integration design/plugin_system design/input_processing/model_inputs_index diff --git a/docs/source/models/structured_outputs.rst b/docs/source/models/structured_outputs.rst new file mode 100644 index 0000000000000..484e1f17d191e --- /dev/null +++ b/docs/source/models/structured_outputs.rst @@ -0,0 +1,267 @@ +.. _structured_outputs: + +Structured Outputs +================== + +vLLM supports the generation of structured outputs using `outlines `_ or `lm-format-enforcer `_ as backends for the guided decoding. +This document shows you some examples of the different options that are available to generate structured outputs. + + +Online Inference (OpenAI API) +----------------------------- + +You can generate structured outputs using the OpenAI's `Completions `_ and `Chat `_ API. + +The following parameters are supported, which must be added as extra parameters: + +- ``guided_choice``: the output will be exactly one of the choices. +- ``guided_regex``: the output will follow the regex pattern. +- ``guided_json``: the output will follow the JSON schema. +- ``guided_grammar``: the output will follow the context free grammar. +- ``guided_whitespace_pattern``: used to override the default whitespace pattern for guided json decoding. +- ``guided_decoding_backend``: used to select the guided decoding backend to use. + +You can see the complete list of supported parameters on the `OpenAI Compatible Server `_ page. + +Now let´s see an example for each of the cases, starting with the ``guided_choice``, as it´s the easiest one: + +.. code-block:: python + + from openai import OpenAI + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + ], + extra_body={"guided_choice": ["positive", "negative"]}, + ) + print(completion.choices[0].message.content) + + +The next example shows how to use the ``guided_regex``. The idea is to generate an email address, given a simple regex template: + +.. code-block:: python + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", + } + ], + extra_body={"guided_regex": "\w+@\w+\.com\n", "stop": ["\n"]}, + ) + print(completion.choices[0].message.content) + +One of the most relevant features in structured text generation is the option to generate a valid JSON with pre-defined fields and formats. +For this we can use the ``guided_json`` parameter in two different ways: + +- Using directly a `JSON Schema `_ +- Defining a `Pydantic model `_ and then extracting the JSON Schema from it (which is normally an easier option). + +The next example shows how to use the ``guided_json`` parameter with a Pydantic model: + +.. code-block:: python + + from pydantic import BaseModel + from enum import Enum + + class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + + class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + + json_schema = CarDescription.model_json_schema() + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's", + } + ], + extra_body={"guided_json": json_schema}, + ) + print(completion.choices[0].message.content) + +.. tip:: + While not strictly necessary, normally it´s better to indicate in the prompt that a JSON needs to be generated and which fields and how should the LLM fill them. + This can improve the results notably in most cases. + + +Finally we have the ``guided_grammar``, which probably is the most difficult one to use but it´s really powerful, as it allows us to define complete languages like SQL queries. +It works by using a context free EBNF grammar, which for example we can use to define a specific format of simplified SQL queries, like in the example below: + +.. code-block:: python + + simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + """ + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", + } + ], + extra_body={"guided_grammar": simplified_sql_grammar}, + ) + print(completion.choices[0].message.content) + +The complete code of the examples can be found on `examples/openai_chat_completion_structured_outputs.py `_. + +Experimental Automatic Parsing (OpenAI API) +-------------------------------------------- + +This section covers the OpenAI beta wrapper over the ``client.chat.completions.create()`` method that provides richer integrations with Python specific types. + +At the time of writing (``openai==1.54.4``), this is a "beta" feature in the OpenAI client library. Code reference can be found `here `_. + +For the following examples, vLLM was setup using ``vllm serve meta-llama/Llama-3.1-8B-Instruct`` + +Here is a simple example demonstrating how to get structured output using Pydantic models: + +.. code-block:: python + + from pydantic import BaseModel + from openai import OpenAI + + + class Info(BaseModel): + name: str + age: int + + + client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") + completion = client.beta.chat.completions.parse( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "My name is Cameron, I'm 28. What's my name and age?"}, + ], + response_format=Info, + extra_body=dict(guided_decoding_backend="outlines"), + ) + + message = completion.choices[0].message + print(message) + assert message.parsed + print("Name:", message.parsed.name) + print("Age:", message.parsed.age) + +Output: + +.. code-block:: console + + ParsedChatCompletionMessage[Testing](content='{"name": "Cameron", "age": 28}', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[], parsed=Testing(name='Cameron', age=28)) + Name: Cameron + Age: 28 + + +Here is a more complex example using nested Pydantic models to handle a step-by-step math solution: + +.. code-block:: python + + from typing import List + from pydantic import BaseModel + from openai import OpenAI + + + class Step(BaseModel): + explanation: str + output: str + + + class MathResponse(BaseModel): + steps: List[Step] + final_answer: str + + + client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") + completion = client.beta.chat.completions.parse( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful expert math tutor."}, + {"role": "user", "content": "Solve 8x + 31 = 2."}, + ], + response_format=MathResponse, + extra_body=dict(guided_decoding_backend="outlines"), + ) + + message = completion.choices[0].message + print(message) + assert message.parsed + for i, step in enumerate(message.parsed.steps): + print(f"Step #{i}:", step) + print("Answer:", message.parsed.final_answer) + +Output: + +.. code-block:: console + + ParsedChatCompletionMessage[MathResponse](content='{ "steps": [{ "explanation": "First, let\'s isolate the term with the variable \'x\'. To do this, we\'ll subtract 31 from both sides of the equation.", "output": "8x + 31 - 31 = 2 - 31"}, { "explanation": "By subtracting 31 from both sides, we simplify the equation to 8x = -29.", "output": "8x = -29"}, { "explanation": "Next, let\'s isolate \'x\' by dividing both sides of the equation by 8.", "output": "8x / 8 = -29 / 8"}], "final_answer": "x = -29/8" }', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[], parsed=MathResponse(steps=[Step(explanation="First, let's isolate the term with the variable 'x'. To do this, we'll subtract 31 from both sides of the equation.", output='8x + 31 - 31 = 2 - 31'), Step(explanation='By subtracting 31 from both sides, we simplify the equation to 8x = -29.', output='8x = -29'), Step(explanation="Next, let's isolate 'x' by dividing both sides of the equation by 8.", output='8x / 8 = -29 / 8')], final_answer='x = -29/8')) + Step #0: explanation="First, let's isolate the term with the variable 'x'. To do this, we'll subtract 31 from both sides of the equation." output='8x + 31 - 31 = 2 - 31' + Step #1: explanation='By subtracting 31 from both sides, we simplify the equation to 8x = -29.' output='8x = -29' + Step #2: explanation="Next, let's isolate 'x' by dividing both sides of the equation by 8." output='8x / 8 = -29 / 8' + Answer: x = -29/8 + +Offline Inference +----------------- + +Offline inference allows for the same types of guided decoding. +To use it, we´ll need to configure the guided decoding using the class ``GuidedDecodingParams`` inside ``SamplingParams``. +The main available options inside ``GuidedDecodingParams`` are: + +- ``json`` +- ``regex`` +- ``choice`` +- ``grammar`` +- ``backend`` +- ``whitespace_pattern`` + +These parameters can be used in the same way as the parameters from the Online Inference examples above. +One example for the usage of the ``choices`` parameter is shown below: + +.. code-block:: python + + from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams + + llm = LLM(model="HuggingFaceTB/SmolLM2-1.7B-Instruct") + + guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) + sampling_params = SamplingParams(guided_decoding=guided_decoding_params) + outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, + ) + print(outputs[0].outputs[0].text) + +A complete example with all options can be found in `examples/offline_inference_structured_outputs.py `_. diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 96a513d42753b..e902d393f2f70 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -446,7 +446,7 @@ Text Generation - GLM-4V - T + I - :code:`THUDM/glm-4v-9b` etc. - - + - ✅︎ - ✅︎ * - :code:`H2OVLChatModel` - H2OVL diff --git a/docs/source/quantization/supported_hardware.rst b/docs/source/quantization/supported_hardware.rst index 9bf0cdb80376d..09f8e7112cf0c 100644 --- a/docs/source/quantization/supported_hardware.rst +++ b/docs/source/quantization/supported_hardware.rst @@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✅︎ - ✗ - - ✗ + - ✅︎ - ✅︎ - ✗ - ✗ @@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✅︎ - ✗ - - ✗ - - ✗ + - ✅︎ + - ✅︎ - ✗ - ✗ * - Marlin (GPTQ/AWQ/FP8) @@ -129,4 +129,4 @@ Notes: Please note that this compatibility chart may be subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. -For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory `_ or consult with the vLLM development team. \ No newline at end of file +For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory `_ or consult with the vLLM development team. diff --git a/examples/offline_inference_structured_outputs.py b/examples/offline_inference_structured_outputs.py new file mode 100644 index 0000000000000..00d864606eeff --- /dev/null +++ b/examples/offline_inference_structured_outputs.py @@ -0,0 +1,78 @@ +from enum import Enum + +from pydantic import BaseModel + +from vllm import LLM, SamplingParams +from vllm.sampling_params import GuidedDecodingParams + +llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100) + +# Guided decoding by Choice (list of possible options) +guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) + +# Guided decoding by Regex +guided_decoding_params = GuidedDecodingParams(regex="\w+@\w+\.com\n") +sampling_params = SamplingParams(guided_decoding=guided_decoding_params, + stop=["\n"]) +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") +outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) +print(outputs[0].outputs[0].text) + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +guided_decoding_params = GuidedDecodingParams(json=json_schema) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") +outputs = llm.generate( + prompts=prompt, + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) + +# Guided decoding by Grammar +simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ +""" +guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") +outputs = llm.generate( + prompts=prompt, + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) diff --git a/examples/openai_chat_completion_structured_outputs.py b/examples/openai_chat_completion_structured_outputs.py new file mode 100644 index 0000000000000..8c059c7ca07ce --- /dev/null +++ b/examples/openai_chat_completion_structured_outputs.py @@ -0,0 +1,94 @@ +from enum import Enum + +from openai import OpenAI +from pydantic import BaseModel + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", +) + +# Guided decoding by Choice (list of possible options) +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": "Classify this sentiment: vLLM is wonderful!" + }], + extra_body={"guided_choice": ["positive", "negative"]}, +) +print(completion.choices[0].message.content) + +# Guided decoding by Regex +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": "\w+@\w+\.com\n", + "stop": ["\n"] + }, +) +print(completion.choices[0].message.content) + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, +) +print(completion.choices[0].message.content) + +# Guided decoding by Grammar +simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ +""" + +prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_grammar": simplified_sql_grammar}, +) +print(completion.choices[0].message.content) diff --git a/format.sh b/format.sh index a57882d2ac3f9..b3dcdc15bf948 100755 --- a/format.sh +++ b/format.sh @@ -299,6 +299,10 @@ echo 'vLLM shellcheck:' tools/shellcheck.sh echo 'vLLM shellcheck: Done' +echo 'excalidraw png check:' +tools/png-lint.sh +echo 'excalidraw png check: Done' + if ! git diff --quiet &>/dev/null; then echo echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:" diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 169ce040d370c..d37f95d48d5b2 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -5,6 +5,7 @@ from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use +from vllm.platforms import cpu, cuda, openvino, rocm from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch): override_backend_env_variable(monkeypatch, name) if device == "cpu": - with patch("vllm.attention.selector.current_platform.is_cpu", - return_value=True): + with patch("vllm.attention.selector.current_platform", + cpu.CpuPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform.is_rocm", - return_value=True): + with patch("vllm.attention.selector.current_platform", + rocm.RocmPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "ROCM_FLASH" elif device == "openvino": - with patch("vllm.attention.selector.current_platform.is_openvino", - return_value=True): + with patch("vllm.attention.selector.current_platform", + openvino.OpenVinoPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "OPENVINO" else: - backend = which_attn_to_use(16, torch.float16, torch.float16, 16, - False) + with patch("vllm.attention.selector.current_platform", + cuda.CudaPlatform()): + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == name diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py deleted file mode 100644 index 59c0a24753c3b..0000000000000 --- a/tests/kernels/test_machete_gemm.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Tests for the machete kernel. - -Run `pytest tests/kernels/test_machete_gemm.py`. -""" - -import math -from typing import Optional, Tuple - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) -from vllm.platforms import current_platform -from vllm.scalar_type import ScalarType, scalar_types - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -MNK_SHAPES = [ - (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), - (1, 8192, 28672), - (13, 8192, 4096), - (26, 4096, 8192), - (64, 4096, 4096), - (64, 8192, 28672), - (257, 128, 4096), - (257, 4224, 4160), - (257, 4096, 4096), - (1024, 4096, 8192), - (1024, 8192, 4096), -] - -ACT_TYPES = [torch.float16, torch.bfloat16] -WTYPE_ZEROPOINTS = [ - # GPTQ style - (scalar_types.uint4b8, False), - (scalar_types.uint8b128, False), - # AWQ style - (scalar_types.uint4, True), - (scalar_types.uint8, True), -] - -# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel -# unit tests to a common utility function. Currently the use of -# `is_quant_method_supported` conflates kernels with quantization methods -# an assumption which is breaking down as quantizations methods can have -# have kernels and some kernels support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) - - -def rand_data(shape, dtype=torch.float16): - return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) - - -def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): - return zps if zps is None else -1 * s * (zps.to(s.dtype)) - - -def machete_quantize_and_pack(w: torch.Tensor, - wtype: ScalarType, - group_size: int, - zero_points: bool = False): - assert wtype.is_integer(), "TODO: support floating point weights" - - w_ref, w_q, w_s, w_zp = quantize_weights( - w, - wtype, - group_size, - zero_points=zero_points, - # to match how the kernel applies zps - ref_zero_points_after_scales=True) - - w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # convert to col major - w_q_machete = ops.machete_prepack_B(w_q, wtype) - - opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id)) - - return w_ref, w_q_machete, w_s, w_zp - - -def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, - wtype: ScalarType, group_size: int, - zero_points: bool): - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - output = ops.machete_gemm( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_all_schedules(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - print(f"MNK = {m} {n} {k}") - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - w = rand_data((k, n), atype) - - w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( - w, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - for schedule in ops.machete_supported_schedules(wtype): - print(f"Testing schedule {schedule}") - output = ops.machete_gemm( - a, - b_q=w_q_machete, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - schedule=schedule, - ) - - opcheck( - torch.ops._C.machete_gemm, - (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints( - w_zp, w_s), group_size, None, None, None, schedule)) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ - f"Schedule failed {schedule}" - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_heuristic(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - b = rand_data((k, n), atype) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_machete_devices(device: str): - m, n, k = 512, 4096, 4096 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - print(f"MNK = {m} {n} {k}, device = {device}") - - a = rand_data((m, k), torch.float16).to(device) - b = rand_data((k, n), torch.float16).to(device) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_subset(): - big_m, big_n, big_k = 1024, 1024, 1024 - m, n, k = 512, 512, 512 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - whole_a = rand_data((big_m, big_k), torch.float16) - whole_b = rand_data((big_k, big_n), torch.float16) - - a = whole_a[0:m, 0:k] - b = whole_b[0:k, 0:n] - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test to make sure cuda graphs work -class MacheteLayer(torch.nn.Module): - - def __init__(self, **kwargs): - super().__init__() - self.kwargs = kwargs - - def forward(self, a): - return ops.machete_gemm(**self.kwargs) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_cuda_graph(): - m, n, k = 512, 4096, 4096 - - a = rand_data((m, k), torch.float16) - b = rand_data((k, n), torch.float16) - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - # Construct a trivial model with a single layer that calls a machete kernel - model = MacheteLayer( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - output_ref = torch.matmul(a, w_ref) - - # Run the model with a cuda graph - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = model(a) - output.zero_() - g.replay() - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/test_machete_mm.py new file mode 100644 index 0000000000000..1c6eb2dd9a228 --- /dev/null +++ b/tests/kernels/test_machete_mm.py @@ -0,0 +1,406 @@ +"""Tests for the machete kernel. + +Run `pytest tests/kernels/test_machete_mm.py`. +""" + +import math +from dataclasses import dataclass, fields +from typing import List, Optional, Tuple + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (1, 8192, 28672), + (13, 8192, 4096), + (26, 4096, 8192), + (64, 4096, 4096), + (64, 8192, 28672), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] + +GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: Optional[torch.Tensor] + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +# NOTE: None "Scale Type" means the act type is floating point +# None "Output Type" means the output type is the same as the act type +TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + # GPTQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16]), + # AWQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16]), + # QQQ style + *(TypeConfig(act_type=torch.int8, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), + *(TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +def rand_data(shape, dtype=torch.float16, scale=1, offset=0): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - offset).to(dtype) + else: + return torch.randint(-8, 7, shape, dtype=dtype, device="cuda") + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def group_size_valid(shape: Tuple[int, int, int], + group_size: Optional[int]) -> bool: + return group_size is None or group_size == -1 or group_size % shape[2] == 0 + + +def machete_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype) + opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype)) + + return w_ref, w_q_machete, w_s, w_zp + + +def create_test_tensors(shape: Tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None) -> Tensors: + m, n, k = shape + factor = subset_stride_factor or 1 + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) + w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) + + if factor > 1: + a = a[0:m, 0:k] + w = w[0:k, 0:n] + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +# None stype means scales use the same dtype as a +def machete_mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) + output_ref_type = output_ref.dtype + + if tensors.w_ch_s is not None: + output_ref = (output_ref.to(tensors.w_ch_s.dtype) * + tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + if tensors.w_tok_s is not None: + output_ref = (output_ref.to(tensors.w_tok_s.dtype) * + tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + + output = ops.machete_mm( + a=tensors.a, + b_q=tensors.w_q, + b_type=types.weight_type, + b_group_scales=tensors.w_g_s, + b_group_zeros=tensors.w_g_zp, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + out_type=types.output_type, + schedule=schedule, + ) + + print(output) + print(output_ref) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if tensors.w_g_zp is not None\ + else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=rtol, + atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_all_schedules(shape, types: TypeConfig): + + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + print(f"MNK = {shape}") + for schedule in ops.machete_supported_schedules( + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type): + print(f"Testing schedule {schedule}") + machete_mm_test_helper(types, tensors, group_size, schedule) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_heuristic(shape, types: TypeConfig): + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + machete_mm_test_helper(types, tensors, group_size) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) + + for field in fields(Tensors): + tensor = getattr(tensors, field.name) + if isinstance(tensor, torch.Tensor): + setattr(tensors, field.name, tensor.to(device)) + + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), + type_config, + group_size, + subset_stride_factor=2) + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + stype = torch.float16 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, b, wtype, stype, group_size, zero_points) + + # Construct a trivial model with a single layer that calls a machete kernel + model = MacheteLayer( + b_q=w_q_packed, + b_type=wtype, + b_group_scales=w_s, + b_group_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py index d541efcefcac3..68a73f0f8ab48 100644 --- a/tests/quantization/test_ipex_quant.py +++ b/tests/quantization/test_ipex_quant.py @@ -1,5 +1,5 @@ """Test model set-up and inference for quantized HF models supported - on the CPU backend using IPEX (including AWQ). + on the CPU/GPU backend using IPEX (including AWQ/GPTQ). Validating the configuration and printing results for manual checking. @@ -11,13 +11,15 @@ from vllm.platforms import current_platform MODELS = [ - "casperhansen/llama-3-8b-instruct-awq", + "AMead10/Llama-3.2-1B-Instruct-AWQ", + "shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx ] DTYPE = ["bfloat16"] -@pytest.mark.skipif(not current_platform.is_cpu(), - reason="only supports the CPU backend.") +@pytest.mark.skipif(not current_platform.is_cpu() + and not current_platform.is_xpu(), + reason="only supports Intel CPU/XPU backend.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", DTYPE) def test_ipex_quant(vllm_runner, model, dtype): diff --git a/tools/png-lint.sh b/tools/png-lint.sh new file mode 100755 index 0000000000000..a80fe9837342f --- /dev/null +++ b/tools/png-lint.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Ensure that *.excalidraw.png files have the excalidraw metadata +# embedded in them. This ensures they can be loaded back into +# the tool and edited in the future. + +find . -iname '*.excalidraw.png' | while read -r file; do + if git check-ignore -q "$file"; then + continue + fi + if ! grep -q "excalidraw+json" "$file"; then + echo "$file was not exported from excalidraw with 'Embed Scene' enabled." + exit 1 + fi +done diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b276b8fc25473..aa89010ca8ecd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -444,18 +444,18 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake("_C::machete_gemm") - def machete_gemm_fake( + @register_fake("_C::machete_mm") + def machete_mm_fake( a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B + # b_q Should be the tensor returned by machete_prepack_B b_q: torch.Tensor, b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, schedule: Optional[str] = None, ) -> torch.Tensor: m = a.size(0) @@ -463,8 +463,9 @@ def machete_gemm_fake( return torch.empty((m, n), device=a.device, dtype=a.dtype) @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: + def machete_prepack_B_fake( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) @@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # machete -def machete_supported_schedules(b_type: ScalarType) -> List[str]: - return torch.ops._C.machete_supported_schedules(b_type.id) - - -def machete_gemm( - a: torch.Tensor, - b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B - b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, -) -> torch.Tensor: - return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, - b_group_size, c, alpha, beta, schedule) +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> List[str]: + return torch.ops._C.machete_supported_schedules( + a_type, b_type.id, group_scales_type, group_zeros_type, + channel_scales_type, token_scales_type, out_type) -def machete_prepack_B(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) if hasattr(torch.ops._C, "permute_cols"): diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py index ec1c37c5bcb0e..727a470ba6d0e 100644 --- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -157,19 +157,22 @@ def _fwd_kernel_inner( k = tl.load( k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kt, mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD), + other=0.0, ) else: if EVEN_D: k = tl.load(k_ptrs + start_n * stride_kt) else: k = tl.load(k_ptrs + start_n * stride_kt, - mask=offs_d[:, None] < D_HEAD) + mask=offs_d[:, None] < D_HEAD, + other=0.0) qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -200,19 +203,22 @@ def _fwd_kernel_inner( v = tl.load( v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vt, mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, ) else: if EVEN_D: v = tl.load(v_ptrs + start_n * stride_vt) else: v = tl.load(v_ptrs + start_n * stride_vt, - mask=offs_d[None, :] < D_HEAD) + mask=offs_d[None, :] < D_HEAD, + other=0.0) acc += tl.dot(p, v) @@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference( q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=offs_m[:, None] < q_seqlen, + other=0.0, ) else: q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - other=0, + other=0.0, ) sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 664707e9dc65d..d263839705690 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,3 @@ -import enum import os from contextlib import contextmanager from functools import lru_cache @@ -9,26 +8,12 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.platforms import _Backend, current_platform from vllm.utils import STR_BACKEND_ENV_VAR logger = init_logger(__name__) -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - FLASH_ATTN_VLLM_V1 = enum.auto() - XFORMERS = enum.auto() - ROCM_FLASH = enum.auto() - TORCH_SDPA = enum.auto() - OPENVINO = enum.auto() - FLASHINFER = enum.auto() - HPU_ATTN = enum.auto() - PALLAS = enum.auto() - IPEX = enum.auto() - NO_ATTENTION = enum.auto() - - def backend_name_to_enum(backend_name: str) -> _Backend: assert backend_name is not None @@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int, if backend_by_env_var is not None: selected_backend = backend_name_to_enum(backend_by_env_var) - if current_platform.is_cpu(): - if selected_backend != _Backend.TORCH_SDPA: - logger.info("Cannot use %s backend on CPU.", selected_backend) - return _Backend.TORCH_SDPA - - if current_platform.is_openvino(): - if selected_backend != _Backend.OPENVINO: - logger.info("Cannot use %s backend on OpenVINO.", selected_backend) - return _Backend.OPENVINO - - if current_platform.is_xpu(): - if selected_backend != _Backend.IPEX: - logger.info("Cannot use %s backend on XPU.", selected_backend) - return _Backend.IPEX - - if current_platform.is_tpu(): - if selected_backend != _Backend.PALLAS: - logger.info("Cannot use %s backend on TPU.", selected_backend) - return _Backend.PALLAS - - if current_platform.is_rocm(): - # AMD GPUs. - selected_backend = (_Backend.ROCM_FLASH if selected_backend - == _Backend.FLASH_ATTN else selected_backend) - if selected_backend == _Backend.ROCM_FLASH: - if not current_platform.has_device_capability(90): - # not Instinct series GPUs. - logger.info("flash_attn is not supported on NAVI GPUs.") - else: - logger.info("%s is not supported in AMD GPUs.", selected_backend) - return _Backend.ROCM_FLASH - - if current_platform.is_hpu(): - return _Backend.HPU_ATTN + # get device-specific default attn_backend + default_backend = current_platform.get_default_attn_backend( + selected_backend) + if default_backend is not None: + return default_backend if use_v1: return _Backend.FLASH_ATTN_VLLM_V1 diff --git a/vllm/config.py b/vllm/config.py index 14017bbdb3cf2..ea9ec43cc5a15 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4,8 +4,9 @@ import warnings from dataclasses import dataclass, field, replace from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, - Literal, Mapping, Optional, Set, Tuple, Type, Union) +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, + Final, List, Literal, Mapping, Optional, Set, Tuple, Type, + Union) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -2169,6 +2170,10 @@ class CompilationConfig(BaseModel): compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr + # keep track of enabled and disabled custom ops + enabled_custom_ops: Counter[str] = PrivateAttr + disabled_custom_ops: Counter[str] = PrivateAttr + def model_post_init(self, __context: Any) -> None: self.level = envs.VLLM_TORCH_COMPILE_LEVEL @@ -2190,6 +2195,9 @@ def model_post_init(self, __context: Any) -> None: func = __import__(module).__dict__[func_name] self.inductor_compile_config[k] = func + self.enabled_custom_ops = Counter() + self.disabled_custom_ops = Counter() + def init_backend(self) -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 92fa87c7fa45b..ee4b6addfd466 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -793,7 +793,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=[], help="The pattern(s) to ignore when loading the model." - "Default to 'original/**/*' to avoid repeated loading of llama's " + "Default to `original/**/*` to avoid repeated loading of llama's " "checkpoints.") parser.add_argument( '--preemption-mode', diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9a2d73a020c8f..e72dc81f35b67 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1716,7 +1716,7 @@ def _get_stats(self, # not counted (to avoid double counting) actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - num_generation_tokens_from_prefill_groups = 0. + num_generation_tokens_from_prefill_groups = 0 # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != # scheduler_outputs.num_prefill_groups, this means that diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index e896bcdded2d1..47472c274ccb6 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -512,6 +512,11 @@ def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_counter(self, counter, data: Union[int, float]) -> None: # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + logger.warning("Skipping negative increment of %g to %s", data, + counter) + return counter.labels(**self.labels).inc(data) def _log_counter_labels(self, counter, data: CollectionsCounter, diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 6a32387a6f36c..f176259fddc78 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -75,7 +75,9 @@ def _bgmv_expand_kernel( other=0.0, ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + tiled_out = tl.load(c_ptr + current_n * cn_stride, + mask=c_mask, + other=0.0) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 73628fd20d327..2c6ed96c253f0 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel( ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store + # operation uses the same mask, eliminating the risk of garbage + # values propagating + tiled_out = tl.load(c_ptr + current_n * cn_stride, + mask=c_mask, + other=None) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 4910cb4061298..ee2cd2e05e2ee 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -88,7 +88,10 @@ def _sgmv_expand_kernel( c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store operation + # uses the same mask, eliminating the risk of garbage values propagating + tiled_out = tl.load(c_ptr, mask=c_mask, other=None) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 844f5cec39e93..5244fa14913a4 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel( c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < (slice_offset + N)) if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store operation + # uses the same mask, eliminating the risk of garbage values propagating + tiled_out = tl.load(c_ptr, mask=c_mask, other=None) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6ae7d7cf6964f..b07966f2ab7d0 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -61,10 +61,13 @@ def forward_hpu(self, *args, **kwargs): def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - + compilation_config = get_current_vllm_config().compilation_config enabled = self.enabled() - logger.debug("custom op %s %s", self.__class__.name, - "enabled" if enabled else "disabled") + if enabled: + compilation_config.enabled_custom_ops.update([self.__class__.name]) + else: + compilation_config.disabled_custom_ops.update( + [self.__class__.name]) if not enabled: return self.forward_native diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 94f30412e43b3..9da38d4857d6d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -27,7 +27,7 @@ "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", - "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod" + "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod" ] @@ -470,7 +470,8 @@ def weight_loader(self, needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -480,6 +481,8 @@ def weight_loader(self, param_data.copy_(loaded_weight) return current_shard_offset = 0 + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) shard_offsets: List[Tuple[int, int, int]] = [] for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) @@ -495,7 +498,9 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - + if use_bitsandbytes_4bit: + shard_size = loaded_weight.shape[output_dim] // 2 + shard_offset = shard_size * shard_id loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -808,7 +813,8 @@ def weight_loader(self, needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index bbb7fc8ad5087..ace8f4a348812 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -42,7 +42,7 @@ def awq_dequantize_kernel( result_masks = result_masks_y[:, None] & result_masks_x[None, :] # Load the weights. - iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.load(qweight_ptr + offsets, masks, 0.0) iweights = tl.interleave(iweights, iweights) iweights = tl.interleave(iweights, iweights) iweights = tl.interleave(iweights, iweights) @@ -71,7 +71,7 @@ def awq_dequantize_kernel( zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] # Load the zeros. - zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -91,7 +91,7 @@ def awq_dequantize_kernel( scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] # Load the scales. - scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Dequantize. @@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): masks_k = offsets_k < K masks_a = masks_am[:, None] & masks_k[None, :] - a = tl.load(a_ptrs, mask=masks_a) + a = tl.load(a_ptrs, mask=masks_a, other=0.0) masks_b = masks_k[:, None] & masks_bn[None, :] - b = tl.load(b_ptrs, mask=masks_b) + b = tl.load(b_ptrs, mask=masks_b, other=0.0) b = tl.interleave(b, b) b = tl.interleave(b, b) b = tl.interleave(b, b) @@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] zeros_ptrs = zeros_ptr + offsets_z - zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_sk = offsets_szk < K // group_size masks_s = masks_sk[:, None] & masks_sn[None, :] scales_ptrs = scales_ptr + offsets_s - scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) b = (b >> shifts) & 0xF diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 0aa605e62454e..abafad0f1047e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -210,7 +210,6 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # for torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 1f72e3afbbce5..a3e58bf1b2a4c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -23,6 +23,7 @@ PackedColumnParameter, PackedvLLMParameter, RowvLLMParameter) +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -134,6 +135,9 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): sym = quant_config.get("sym") desc_act = quant_config.get("desc_act") + if not current_platform.is_cuda(): + return False + if quant_method != "gptq": return False diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 330c2ad195d78..c16a962134d06 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -2,21 +2,26 @@ import torch -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, + is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.platforms import current_platform +MIN_IPEX_VERSION = "2.5.0" + class IPEXConfig(QuantizationConfig): - """INT8 quantization config class using IPEX for the CPU backend, - including AWQ. + """INT8 quantization config class using IPEX for the CPU/XPU backend, + including AWQ, GPTQ. """ IPEX_QUANT_METHOD_MAP = { "awq": 1, - "gptq": 2, + "gptq": 0, } def __init__( @@ -24,29 +29,30 @@ def __init__( method: str, weight_bits: int, group_size: int, + modules_to_not_convert: Optional[List[str]] = None, + desc_act: Optional[bool] = None, + lm_head_quantized: Optional[bool] = None, ) -> None: self.method = method self.weight_bits = weight_bits self.group_size = group_size + self.modules_to_not_convert = modules_to_not_convert or [] + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized self.pack_factor = 32 // self.weight_bits if self.weight_bits not in [4]: raise ValueError(f"IPEX quantization supports weight bits [4], " f"but got {self.weight_bits}.") - if self.method == "awq": - self.quant_method = IPEXAWQLinearMethod - else: - raise ValueError(f"IPEX quantization supports [awq], " + if self.method not in ["awq", "gptq"]: + raise ValueError(f"IPEX quantization supports [awq, gptq], " f"but got {self.method}.") def __repr__(self) -> str: - return (f"IPEXConfig(method={self.method}" + return (f"IPEXConfig(method={self.method}," f"weight_bits={self.weight_bits}, " - f"group_size={self.group_size}") - - def get_ipex_quant_method_id(self) -> int: - return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method] + f"group_size={self.group_size})") @classmethod def get_name(cls) -> str: @@ -70,19 +76,32 @@ def get_config_filenames() -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig": method = cls.get_from_keys(config, ["quant_method"]).lower() - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) - return cls(method, weight_bits, group_size) + if method == "awq": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, + ["q_group_size", "group_size"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(method, weight_bits, group_size, modules_to_not_convert, + False, False) + # otherwise for gptq + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) + return cls(method, weight_bits, group_size, [], desc_act, + lm_head_quantized) @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - if not current_platform.is_cpu(): + if not current_platform.is_cpu() and not current_platform.is_xpu(): return None quant_method = hf_quant_cfg.get("quant_method", "").lower() - if quant_method in ["awq"]: + if quant_method in ["awq", "gptq"]: return cls.get_name() return None @@ -90,12 +109,81 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["LinearMethodBase"]: if isinstance(layer, LinearBase): - return self.quant_method(self) + if self.method == "awq": + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return IPEXAWQLinearMethod(self) + if self.method == "gptq": + return IPEXGPTQLinearMethod(self) return None +class IPEXGPTQLinearMethod(GPTQLinearMethod): + """GPTQ linear method using IPEX for the CPU/XPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method.") from err + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + layer.ipex_output_size = layer.qweight.shape[-1] + g_idx = layer.g_idx if self.quant_config.desc_act else None + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + if bias is not None: + out.add_(bias) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + + class IPEXAWQLinearMethod(AWQLinearMethod): - """AWQ linear method using IPEX for the CPU backend. + """AWQ linear method using IPEX for the CPU/XPU backend. """ def __init__(self, quant_config: IPEXConfig): @@ -108,15 +196,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: try: import intel_extension_for_pytorch as ipex - if ipex.__version__ < "2.4.0": - raise ImportError("intel_extension_for_pytorch version is " - "wrong. Please install " - "intel_extension_for_pytorch>=2.4.0.") + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") except ImportError as err: raise ImportError( "Please install " - "intel_extension_for_pytorch>=2.4.0 via " - "`pip install intel_extension_for_pytorch>=2.4.0`" + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" " to use IPEX-AWQ linear method.") from err # Using the compute dtype (lowp_mode) as INT8 to leverage instructions @@ -136,19 +225,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.ipex_output_size = layer.qweight.size( 1) * self.quant_config.pack_factor - layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\ - WeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - bias=bias, - group_size=self.quant_config.group_size, - quant_method= - self.quant_config.get_ipex_quant_method_id() # type: ignore - ) + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore + ) def apply(self, layer: torch.nn.Module, @@ -156,5 +244,4 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py index e5696d08f30f5..15df0200f30b5 100644 --- a/vllm/model_executor/layers/quantization/kernels/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -79,7 +79,9 @@ def transform_w_q(x): c.weight_type, packed_dim=0) x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), - self.config.weight_type) + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) return x def transform_w_s(x): @@ -105,12 +107,12 @@ def apply_weights(self, if c.has_g_idx: x_2d = self.act_perm(x_2d) - output = ops.machete_gemm(a=x_2d, - b_q=w_q, - b_type=c.weight_type, - b_zeros=None, - b_scales=w_s, - b_group_size=c.group_size) + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=None, + b_group_scales=w_s, + b_group_size=c.group_size) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index c217f5ca620a1..83055d6000d83 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor, def quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, + group_size: Optional[int], zero_points: bool = False, ref_zero_points_after_scales: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, \ + "to have group zero points, group_size must be provided "\ + "(-1 group_size is channelwise)" orig_device = w.device orig_type = w.dtype @@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor, if group_size == -1: group_size = size_k - assert group_size <= size_k # Reshape to [groupsize, -1] - if group_size < size_k: + if group_size is not None and group_size < size_k: w = w.reshape((-1, group_size, size_n)) w = w.permute(1, 0, 2) w = w.reshape((group_size, -1)) @@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor, max_q_val = quant_type.max() min_q_val = quant_type.min() - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ - .clamp(min_q_val, max_q_val).int() - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) - maybe_w_zp = None + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) @@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor, # For some kernels (namely Machete) the zero-points are applied after the # scales are applied, for this case computing the reference in similar way # allows us to use tighter error tolerances in our unit tests. - if ref_zero_points_after_scales and zero_points: + if ref_zero_points_after_scales and maybe_w_zp is not None: w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s else: w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s @@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor, w_q += quant_type.bias # Restore original shapes - if group_size < size_k: + if group_size is not None and group_size < size_k: def reshape_w(w): w = w.reshape((group_size, -1, size_n)) @@ -195,17 +199,16 @@ def reshape_w(w): w_q = reshape_w(w_q) w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() - w_s = w_s.reshape((-1, size_n)).contiguous() - - if zero_points: + if maybe_w_zp is not None: maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() maybe_w_zp = maybe_w_zp.to(device=orig_device) return ( w_ref.to(device=orig_device), w_q.to(device=orig_device), - w_s.to(device=orig_device), + w_s if group_size is not None else None, maybe_w_zp, ) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 2e9a0e170693b..3ab0ba9e9f5c2 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -368,7 +368,7 @@ def _smallest_positive_value(self) -> float: # Note that we always sample with replacement. # probs will be modified in place, but this is fine, as we pass # in a copy already. -@torch.jit.script +@torch.compile(dynamic=True) def _multinomial( probs: torch.Tensor, num_samples: int, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 52771f50a7a23..30548e656c557 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -133,13 +133,13 @@ def __post_init__(self): assert self.num_added_elements <= self.num_added_elements_padded -@torch.jit.script +@torch.compile(dynamic=True) def get_masked_input_and_mask( input_: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, num_org_vocab_padding: int, added_vocab_start_index: int, added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: - # torch.jit.script will fuse all of the pointwise ops below + # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d9ce85949e4ee..b41c23704b7ff 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -29,6 +29,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) @@ -348,7 +350,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) - if quant_method is not None: + if isinstance(quant_method, QuantizeMethodBase): # When quant methods need to process weights after loading # (for repacking, quantizing, etc), they expect parameters # to be on the global target device. This scope is for the diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 6db6462e97f3f..6af59697160a0 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -4,10 +4,11 @@ import torch import torch.nn as nn +import torch.nn.functional as F from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig -from transformers.models.blip.modeling_blip import BlipAttention +from vllm.attention.selector import _Backend from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -21,11 +22,7 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings -class BlipParallelAttention(nn.Module): +class BlipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( @@ -208,6 +205,12 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + # Detect attention implementation. + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"BLIP does not support {self.attn_backend} backend now.") + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -231,11 +234,26 @@ def forward( self.num_heads_per_partition, self.head_dim) - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query_states, key_states, value_states = (x.transpose(1, 2) + for x in (query_states, + key_states, + value_states)) + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + dropout_p=self.dropout, + scale=self.scale) + out = out.transpose(1, 2) + out = out.view(bsz, tgt_len, -1) attn_output, _ = self.projection(out) @@ -285,18 +303,11 @@ def __init__( super().__init__() # fallback to sdpa attention if tp unavailable - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - if USE_XFORMERS_OPS and num_heads % tp_size == 0: - self.self_attn = BlipParallelAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - # Blip doesn't have SDPA attention implemented in transformers - # use eager attention instead for cpu backend - self.self_attn = BlipAttention(config) + self.self_attn = BlipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = BlipMLP(config, @@ -374,11 +385,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - tp_size = get_tensor_model_parallel_world_size() - num_heads = config.num_attention_heads - self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 - self.config = config self.embeddings = BlipVisionEmbeddings(config) @@ -422,7 +428,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ] if self.shard_weight else [] + ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.encoder.layers) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 81e56381eabd8..2ea592aaba9f9 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -30,6 +30,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs @@ -574,25 +575,7 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) -class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, - SupportsMultiModal): - packed_modules_mapping = { - "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"] - } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - ] - embedding_modules = {} - embedding_padding_modules = [] +class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -692,3 +675,79 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, combined_weight) loaded_params.add(combined_name) return loaded_params + + +class ChatGLM(ChatGLMBaseModel): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + +class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"], + "merged_proj": ["gate_proj", "dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + # vision + "fc1", + "fc2", + "merged_proj", + "linear_proj" + ] + + embedding_modules = {} + embedding_padding_modules = [] + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.encoder", + connector="transformer.vision.linear_proj", + tower_model="transformer.vision.transformer") + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsMultiModal): + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__( + cls, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + config = vllm_config.model_config.hf_config + # Initialize VL + if hasattr(config, "visual"): + return ChatGLM(vllm_config=vllm_config, prefix=prefix) + # Initialize LLM + else: + return ChatGLMV(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 184758f4a8a45..7f638506f9fb2 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -5,10 +5,11 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from PIL import Image from transformers import CLIPVisionConfig -from transformers.models.clip.modeling_clip import CLIPSdpaAttention +from vllm.attention.selector import _Backend from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -23,11 +24,7 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -197,7 +194,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings -class CLIPParallelAttention(nn.Module): +class CLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( @@ -237,6 +234,12 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + # Detect attention implementation. + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"CLIP does not support {self.attn_backend} backend now.") + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -261,11 +264,26 @@ def forward( self.num_heads_per_partition, self.head_dim) - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query_states, key_states, value_states = (x.transpose(1, 2) + for x in (query_states, + key_states, + value_states)) + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + dropout_p=self.dropout, + scale=self.scale) + out = out.transpose(1, 2) + out = out.view(bsz, tgt_len, -1) attn_output, _ = self.out_proj(out) @@ -311,17 +329,11 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - if USE_XFORMERS_OPS and num_heads % tp_size == 0: - self.self_attn = CLIPParallelAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - self.self_attn = CLIPSdpaAttention(config) + self.self_attn = CLIPAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config, @@ -461,11 +473,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - tp_size = get_tensor_model_parallel_world_size() - num_heads = config.num_attention_heads - self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 - self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, @@ -490,7 +497,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ] if self.shard_weight else [] + ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index bd91a0806ae5c..c4346fcb3bd2a 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.attention.selector import _Backend from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -24,11 +25,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend NORM2FN = { 'rms_norm': RMSNorm, @@ -186,6 +183,11 @@ def __init__( prefix=f"{prefix}.proj", ) + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"InternViT does not support {self.attn_backend} backend now.") + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) @@ -211,11 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k = k.view(B, N, self.num_heads_per_partition, self.head_dim) v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale) - x = x.view(B, N, -1) + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops - x, _ = self.proj(x) - return x + out = xops.memory_efficient_attention_forward(q, + k, + v, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + q, k, v = (x.transpose(1, 2) for x in (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + out = out.transpose(1, 2) + + out = out.view(B, N, -1) + out, _ = self.proj(out) + return out class InternSdpaAttention(nn.Module): @@ -362,7 +374,7 @@ def _init_attn( tp_size = get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads - if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0: + if (num_heads + num_dummy_heads) % tp_size == 0: return InternParallelAttention(config, quant_config=quant_config, num_dummy_heads=num_dummy_heads, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 035a1e2ab7b02..2528f741864b3 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -13,7 +13,6 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.attention.selector import _Backend from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer +from vllm.platforms import _Backend from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) from vllm.transformers_utils.processor import get_processor @@ -187,7 +187,7 @@ def __init__( ) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend() + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py index 34141511ea791..54158bc141235 100644 --- a/vllm/model_executor/models/phi3.py +++ b/vllm/model_executor/models/phi3.py @@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM): "gate_up_proj", ], } + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_up_proj.", + ".down_proj.", + ".qkv_proj.", + ".o_proj.", + ] + # Initialize an empty dict when there is no stacked parameter mapping. + bitsandbytes_stacked_params_mapping = {} diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index a78e4d355a314..f71cbd1264c45 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -54,12 +54,12 @@ def weight_loader(self, param: torch.nn.Parameter, return load_column_parallel_weight(param, loaded_weight) -@torch.jit.script +@torch.compile(dynamic=True) def quick_gelu(x): return x * torch.sigmoid(1.702 * x) -@torch.jit.script +@torch.compile(dynamic=True) def gegelu(input, limit: Optional[float] = None): a_gelu, a_linear = input[..., ::2], input[..., 1::2] if limit is not None: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d44a538d56b8c..f7f46770057e2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionMetadata from vllm.config import ModelConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_and_mul_fn @@ -843,17 +844,20 @@ def __init__( self.config = config assert not config.hidden_size % config.num_attention_heads - self.n_heads = config.num_attention_heads + self.total_num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + self.n_heads = divide(config.num_attention_heads, tp_size) self.head_dim = config.hidden_size // config.num_attention_heads self.qkv_proj = QKVParallelLinear( hidden_size=config.hidden_size, head_size=self.head_dim, - total_num_heads=self.n_heads, + total_num_heads=self.total_num_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + assert self.total_num_heads * self.head_dim == config.hidden_size self.o_proj = RowParallelLinear( input_size=config.hidden_size, output_size=config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3978c176a2144..44ce6eda42943 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -870,7 +870,7 @@ def dummy_data_for_qwen( return DummyData(seq_data, mm_data) -class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): +class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1024,7 +1024,7 @@ class QWenLLM(QWenBaseModel): embedding_padding_modules = [] -class QWenVL(QWenBaseModel): +class QWenVL(QWenBaseModel, SupportsMultiModal): packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ @@ -1062,7 +1062,7 @@ def get_mm_mapping(self) -> MultiModelKeys: @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(QWenBaseModel, SupportsLoRA): +class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): """ QWenLMHeadModel is not only applicable to LLM but also to VL, which is not conducive to the current integration logic of LoRA in vLLM. Therefore, it @@ -1083,7 +1083,7 @@ def __new__( config = vllm_config.model_config.hf_config # Initialize VL if hasattr(config, "visual"): - return QWenVL(vllm_config=vllm_config) + return QWenVL(vllm_config=vllm_config, prefix=prefix) # Initialize LLM else: - return QWenLLM(vllm_config=vllm_config) + return QWenLLM(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ef6b52db6e17d..0ac81387b1bd8 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -39,7 +39,6 @@ make_batched_images, make_batched_videos, smart_resize) from vllm.attention import AttentionMetadata -from vllm.attention.selector import _Backend from vllm.config import VllmConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils @@ -65,6 +64,7 @@ from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs) from vllm.multimodal.utils import cached_get_tokenizer +from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor @@ -260,7 +260,7 @@ def __init__( prefix=f"{prefix}.proj") # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend() + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index c9e09b879843a..c58ad99692900 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -6,11 +6,12 @@ import numpy as np import torch +import torch.nn.functional as F from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention +from vllm.attention.selector import _Backend from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -27,11 +28,7 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -254,7 +251,7 @@ def forward(self, return embeddings -class SiglipParallelAttention(nn.Module): +class SiglipAttention(nn.Module): def __init__( self, @@ -293,6 +290,11 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"SIGLIP does not support {self.attn_backend} backend now.") + def forward( self, hidden_states: torch.Tensor, @@ -313,11 +315,26 @@ def forward( self.num_heads_per_partition, self.head_dim) - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query_states, key_states, value_states = (x.transpose(1, 2) + for x in (query_states, + key_states, + value_states)) + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + dropout_p=self.dropout, + scale=self.scale) + out = out.transpose(1, 2) + out = out.view(batch_size, q_len, -1) attn_output, _ = self.out_proj(out) @@ -372,17 +389,11 @@ def __init__( self.embed_dim = config.hidden_size - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - if USE_XFORMERS_OPS and num_heads % tp_size == 0: - self.self_attn = SiglipParallelAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - self.self_attn = SiglipSdpaAttention(config) - + self.self_attn = SiglipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -569,10 +580,6 @@ def __init__( ) -> None: super().__init__() - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 - self.vision_model = SiglipVisionTransformer( config, quant_config, @@ -601,7 +608,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ] if self.shard_weight else [] + ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 7a4fcce95603d..2ab9b19e22068 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -9,13 +9,13 @@ from transformers import PretrainedConfig import vllm.envs as envs -from vllm.attention.selector import (_Backend, backend_name_to_enum, +from vllm.attention.selector import (backend_name_to_enum, get_global_forced_attn_backend) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors -from vllm.platforms import current_platform +from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -587,7 +587,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return llm(*args, **kwargs) -def get_vit_attn_backend() -> _Backend: +def get_vit_attn_backend(support_fa: bool = False) -> _Backend: + """ + Get the available attention backend for Vision Transformer. + """ + # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. selected_backend: Optional[_Backend] = get_global_forced_attn_backend() if selected_backend is None: backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND @@ -596,7 +600,7 @@ def get_vit_attn_backend() -> _Backend: if selected_backend is None: # For Volta and Turing GPUs, use xformers instead. device_available = current_platform.has_device_capability(80) - if device_available: + if device_available and support_fa: from transformers.utils import is_flash_attn_2_available if is_flash_attn_2_available(): selected_backend = _Backend.FLASH_ATTN @@ -606,7 +610,8 @@ def get_vit_attn_backend() -> _Backend: "so we use xformers backend instead. You can run " "`pip install flash-attn` to use flash-attention backend.") selected_backend = _Backend.XFORMERS - elif current_platform.is_cpu(): + elif current_platform.is_cpu() or current_platform.is_rocm(): + # ROCM doesn't support xformers selected_backend = _Backend.TORCH_SDPA else: selected_backend = _Backend.XFORMERS diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 9e740837381f8..1f68fc2e25df3 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,3 +1,4 @@ +from .interface import _Backend # noqa: F401 from .interface import Platform, PlatformEnum, UnspecifiedPlatform current_platform: Platform diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 42bee31dfb0e9..f9a34a47959ec 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -5,7 +5,9 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import Platform, PlatformEnum, _Backend + +logger = init_logger(__name__) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -22,6 +24,12 @@ class CpuPlatform(Platform): def get_device_name(cls, device_id: int = 0) -> str: return "cpu" + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) + return _Backend.TORCH_SDPA + @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: return psutil.virtual_memory().total diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 170cfff94f90d..1e0888a30ba96 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,11 +1,15 @@ import torch -from .interface import Platform, PlatformEnum +from .interface import Platform, PlatformEnum, _Backend class HpuPlatform(Platform): _enum = PlatformEnum.HPU + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + return _Backend.HPU_ATTN + @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 970c0d1be617e..f4849fa2ccfb0 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -11,6 +11,20 @@ VllmConfig = None +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + FLASH_ATTN_VLLM_V1 = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + HPU_ATTN = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + NO_ATTENTION = enum.auto() + + class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() @@ -71,6 +85,11 @@ def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend): + """Get the default attention backend of a device.""" + return None + @classmethod def get_device_capability( cls, diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 31fe3f1fcbfe4..ad69ced5417b3 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -3,7 +3,7 @@ import vllm.envs as envs from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import Platform, PlatformEnum, _Backend logger = init_logger(__name__) @@ -11,6 +11,12 @@ class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.OPENVINO: + logger.info("Cannot use %s backend on OpenVINO.", selected_backend) + return _Backend.OPENVINO + @classmethod def get_device_name(self, device_id: int = 0) -> str: return "openvino" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index fd8afc92b0f28..022256996f97b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger -from .interface import DeviceCapability, Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend logger = init_logger(__name__) @@ -19,6 +19,18 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if selected_backend == _Backend.ROCM_FLASH: + if not cls.has_device_capability(90): + # not Instinct series GPUs. + logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) + return _Backend.ROCM_FLASH + @classmethod @lru_cache(maxsize=8) def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 643db835c85ff..9057afb6514e4 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -3,17 +3,27 @@ import torch -from .interface import Platform, PlatformEnum +from vllm.logger import init_logger + +from .interface import Platform, PlatformEnum, _Backend if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None +logger = init_logger(__name__) + class TpuPlatform(Platform): _enum = PlatformEnum.TPU + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.PALLAS: + logger.info("Cannot use %s backend on TPU.", selected_backend) + return _Backend.PALLAS + @classmethod def get_device_name(cls, device_id: int = 0) -> str: raise NotImplementedError diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 106e8eddf458f..d0b3dca9a4195 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,11 +1,21 @@ import torch -from .interface import DeviceCapability, Platform, PlatformEnum +from vllm.logger import init_logger + +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend + +logger = init_logger(__name__) class XPUPlatform(Platform): _enum = PlatformEnum.XPU + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.IPEX: + logger.info("Cannot use %s backend on XPU.", selected_backend) + return _Backend.IPEX + @staticmethod def get_device_capability(device_id: int = 0) -> DeviceCapability: major, minor, *_ = torch.xpu.get_device_capability( diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index a0c73a752b5e8..05a9739d99e71 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -6,18 +6,22 @@ if TYPE_CHECKING: from vllm.config import CompilationConfig, VllmConfig -else: - CompilationConfig = None - VllmConfig = None logger = logging.getLogger(__name__) +# make sure one process only loads plugins once +plugins_loaded = False + def load_general_plugins(): """WARNING: plugins can be loaded for multiple times in different processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ + global plugins_loaded + if plugins_loaded: + return + plugins_loaded = True import sys if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -50,23 +54,23 @@ def load_general_plugins(): logger.exception("Failed to load plugin %s", plugin.name) -_compilation_config: Optional[CompilationConfig] = None +_compilation_config: Optional["CompilationConfig"] = None -def set_compilation_config(config: Optional[CompilationConfig]): +def set_compilation_config(config: Optional["CompilationConfig"]): global _compilation_config _compilation_config = config -def get_compilation_config() -> Optional[CompilationConfig]: +def get_compilation_config() -> Optional["CompilationConfig"]: return _compilation_config -_current_vllm_config: Optional[VllmConfig] = None +_current_vllm_config: Optional["VllmConfig"] = None @contextmanager -def set_current_vllm_config(vllm_config: VllmConfig): +def set_current_vllm_config(vllm_config: "VllmConfig"): """ Temporarily set the current VLLM config. Used during model initialization. @@ -80,9 +84,19 @@ def set_current_vllm_config(vllm_config: VllmConfig): _current_vllm_config = vllm_config yield finally: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) _current_vllm_config = old_vllm_config -def get_current_vllm_config() -> VllmConfig: - assert _current_vllm_config is not None, "Current VLLM config is not set." +def get_current_vllm_config() -> "VllmConfig": + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current VLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() return _current_vllm_config diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 82824faa6629a..687d2cc79360f 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, +from vllm.attention.selector import (get_env_variable_attn_backend, get_global_forced_attn_backend) from vllm.config import VllmConfig from vllm.forward_context import set_forward_context @@ -18,6 +18,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) +from vllm.platforms import _Backend from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fb5813651680b..ed0360fb7f727 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1769,7 +1769,7 @@ def capture( # Run the model a few times without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - # Note one iteration is not enough for torch.jit.script + # Note one iteration is not enough for torch.compile for _ in range(_NUM_WARMUP_ITERS): self.model( input_ids=input_ids,