From 2d1b9baa8f57fc59912c7bcd07fd630fb9d72c9d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Dec 2024 13:26:32 -0700 Subject: [PATCH 01/23] [Bugfix] Fix request cancellation without polling (#11190) --- tests/entrypoints/openai/test_basic.py | 51 ++++++++++++++++ tests/test_utils.py | 6 +- tests/utils.py | 11 ++-- vllm/engine/async_llm_engine.py | 46 +++++++++------ vllm/entrypoints/api_server.py | 11 ++-- vllm/entrypoints/openai/api_server.py | 8 +++ vllm/entrypoints/openai/serving_chat.py | 5 -- vllm/entrypoints/openai/serving_completion.py | 3 +- vllm/entrypoints/openai/serving_embedding.py | 5 +- vllm/entrypoints/openai/serving_score.py | 5 +- vllm/entrypoints/utils.py | 57 ++++++++++++++++++ vllm/utils.py | 59 ++----------------- 12 files changed, 164 insertions(+), 103 deletions(-) create mode 100644 vllm/entrypoints/utils.py diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 4616f363cc04a..547c1fd020928 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,6 +1,8 @@ +import asyncio from http import HTTPStatus from typing import List +import openai import pytest import pytest_asyncio import requests @@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer): response = requests.get(server.url_for("health")) assert response.status_code == HTTPStatus.OK + + +@pytest.mark.parametrize( + "server_args", + [ + pytest.param(["--max-model-len", "10100"], + id="default-frontend-multiprocessing"), + pytest.param( + ["--disable-frontend-multiprocessing", "--max-model-len", "10100"], + id="disable-frontend-multiprocessing") + ], + indirect=True, +) +@pytest.mark.asyncio +async def test_request_cancellation(server: RemoteOpenAIServer): + # clunky test: send an ungodly amount of load in with short timeouts + # then ensure that it still responds quickly afterwards + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client(timeout=0.5) + tasks = [] + # Request about 2 million tokens + for _ in range(200): + task = asyncio.create_task( + client.chat.completions.create(messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_body={"min_tokens": 10000})) + tasks.append(task) + + done, pending = await asyncio.wait(tasks, + return_when=asyncio.ALL_COMPLETED) + + # Make sure all requests were sent to the server and timed out + # (We don't want to hide other errors like 400s that would invalidate this + # test) + assert len(pending) == 0 + for d in done: + with pytest.raises(openai.APITimeoutError): + d.result() + + # If the server had not cancelled all the other requests, then it would not + # be able to respond to this one within the timeout + client = server.get_async_client(timeout=5) + response = await client.chat.completions.create(messages=chat_input, + model=MODEL_NAME, + max_tokens=10) + + assert len(response.choices) == 1 diff --git a/tests/test_utils.py b/tests/test_utils.py index 0bc9e5bc32a46..32a6b0aed66aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import asyncio import os import socket -from functools import partial from typing import AsyncIterator, Tuple import pytest @@ -26,10 +25,7 @@ async def mock_async_iterator(idx: int): print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator = merge_async_iterators(*iterators, - is_cancelled=partial(asyncio.sleep, - 0, - result=False)) + merged_iterator = merge_async_iterators(*iterators) 
async def stream_output(generator: AsyncIterator[Tuple[int, str]]): async for idx, output in generator: diff --git a/tests/utils.py b/tests/utils.py index afeb708f3bcdc..bf3d88194e4ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -163,12 +163,11 @@ def get_client(self): api_key=self.DUMMY_API_KEY, ) - def get_async_client(self): - return openai.AsyncOpenAI( - base_url=self.url_for("v1"), - api_key=self.DUMMY_API_KEY, - max_retries=0, - ) + def get_async_client(self, **kwargs): + return openai.AsyncOpenAI(base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs) def _test_completion( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 32396fd10188d..f50e20cf70323 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1065,16 +1065,20 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ): - yield LLMEngine.validate_output(output, RequestOutput) + try: + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ): + yield LLMEngine.validate_output(output, RequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise async def encode( self, @@ -1147,15 +1151,19 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ): - yield LLMEngine.validate_output(output, PoolingRequestOutput) + try: + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, PoolingRequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index ea3c93f733038..95da1c6e7b9bf 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -17,11 +17,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation, - random_uuid) +from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -47,6 +47,11 @@ async def generate(request: Request) -> Response: - other fields: the sampling parameters (See `SamplingParams` for details). 
""" request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) @@ -54,8 +59,6 @@ async def generate(request: Request) -> Response: assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) - results_generator = iterate_with_cancellation( - results_generator, is_cancelled=request.is_disconnected) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 14e3a34ce141c..00e2d1a56f160 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -59,6 +59,7 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, @@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response: @router.post("/tokenize") +@with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post("/detokenize") +@with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -353,6 +356,7 @@ async def show_version(): @router.post("/v1/chat/completions") +@with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) @@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") +@with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: @@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") +@with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) if handler is None: @@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @router.post("/score") +@with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) if handler is None: @@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): @router.post("/v1/score") +@with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( "To indicate that Score API is not part of standard OpenAI API, we " diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 527418c635093..81bce0dd370bb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -32,7 +32,6 @@ from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls -from vllm.utils import iterate_with_cancellation logger = 
init_logger(__name__) @@ -234,10 +233,6 @@ async def create_chat_completion( assert len(generators) == 1 result_generator, = generators - if raw_request: - result_generator = iterate_with_cancellation( - result_generator, raw_request.is_disconnected) - # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index bd39a4c42e938..5cf9df92e296e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -159,8 +159,7 @@ async def create_completion( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, is_cancelled=raw_request.is_disconnected) + result_generator = merge_async_iterators(*generators) model_name = self._get_model_name(lora_request) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index fd501ad4f833e..879276646d2ba 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -202,10 +202,7 @@ async def create_embedding( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected if raw_request else None, - ) + result_generator = merge_async_iterators(*generators) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 6f5cc14ac37cc..101d170bee4d6 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -186,10 +186,7 @@ async def create_score( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected if raw_request else None, - ) + result_generator = merge_async_iterators(*generators) num_prompts = len(engine_prompts) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py new file mode 100644 index 0000000000000..e8a78d216d0f0 --- /dev/null +++ b/vllm/entrypoints/utils.py @@ -0,0 +1,57 @@ +import asyncio +import functools + +from fastapi import Request + + +async def listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. Instead this follows the pattern from + starlette.StreamingResponse, which simultaneously awaits on two tasks- one + to wait for an http disconnect message, and the other to do the work that we + want done. When the first task finishes, the other is cancelled. + + A core assumption of this method is that the body of the request has already + been read. This is a safe assumption to make for fastapi handlers that have + already parsed the body of the request into a pydantic model for us. + This decorator is unsafe to use elsewhere, as it will consume and throw away + all incoming messages for the request while it looks for a disconnect + message. 
+ + In the case where a `StreamingResponse` is returned by the handler, this + wrapper will stop listening for disconnects and instead the response object + will start listening for disconnects. + """ + + # Functools.wraps is required for this wrapper to appear to fastapi as a + # normal route handler, with the correct request type hinting. + @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + + # The request is either the second positional arg or `raw_request` + request = args[1] if len(args) > 1 else kwargs["raw_request"] + + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + + if handler_task in done: + return handler_task.result() + return None + + return wrapper diff --git a/vllm/utils.py b/vllm/utils.py index 73d2ae25f15ca..38c7dea6d2d3d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -20,7 +20,7 @@ import uuid import warnings import weakref -from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping from dataclasses import dataclass, field @@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None], return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] -async def iterate_with_cancellation( - iterator: AsyncGenerator[T, None], - is_cancelled: Callable[[], Awaitable[bool]], -) -> AsyncGenerator[T, None]: - """Convert async iterator into one that polls the provided function - at least once per second to check for client cancellation. - """ - - loop = asyncio.get_running_loop() - - awaits: List[Future[T]] = [_next_task(iterator, loop)] - next_cancel_check: float = 0 - while True: - done, pending = await asyncio.wait(awaits, timeout=1.5) - - # Check for cancellation at most once per second - time_now = time.time() - if time_now >= next_cancel_check: - if await is_cancelled(): - with contextlib.suppress(BaseException): - awaits[0].cancel() - await iterator.aclose() - raise asyncio.CancelledError("client cancelled") - next_cancel_check = time_now + 1 - - if done: - try: - item = await awaits[0] - awaits[0] = _next_task(iterator, loop) - yield item - except StopAsyncIteration: - # we are done - return - - async def merge_async_iterators( - *iterators: AsyncGenerator[T, None], - is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None, -) -> AsyncGenerator[Tuple[int, T], None]: + *iterators: AsyncGenerator[T, + None], ) -> AsyncGenerator[Tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. - - It also optionally polls a provided function at least once per second - to check for client cancellation. 
""" loop = asyncio.get_running_loop() awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} - timeout = None if is_cancelled is None else 1.5 - next_cancel_check: float = 0 try: while awaits: - done, pending = await asyncio.wait(awaits.keys(), - return_when=FIRST_COMPLETED, - timeout=timeout) - if is_cancelled is not None: - # Check for cancellation at most once per second - time_now = time.time() - if time_now >= next_cancel_check: - if await is_cancelled(): - raise asyncio.CancelledError("client cancelled") - next_cancel_check = time_now + 1 + done, _ = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED) for d in done: pair = awaits.pop(d) try: From c77eb8a33ceb62858d951ffef87ae626a0d09973 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 17 Dec 2024 19:34:06 -0500 Subject: [PATCH 02/23] [Bugfix] Set temperature=0.7 in test_guided_choice_chat (#11264) --- tests/entrypoints/openai/test_chat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8d23a2be6f9bb..47c521a9b5eb5 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -482,6 +482,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, model=MODEL_NAME, messages=messages, max_completion_tokens=10, + temperature=0.7, extra_body=dict(guided_choice=sample_guided_choice, guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content @@ -496,6 +497,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, model=MODEL_NAME, messages=messages, max_completion_tokens=10, + temperature=0.7, extra_body=dict(guided_choice=sample_guided_choice, guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content From bf8717ebaea8d74279df84fbe127ad22cf62e219 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 17 Dec 2024 16:37:59 -0800 Subject: [PATCH 03/23] [V1] Prefix caching for vision language models (#11187) Signed-off-by: Cody Yu --- tests/v1/core/test_prefix_caching.py | 88 +++++++++++++++++++- tests/v1/engine/test_engine_args.py | 15 ---- vllm/engine/arg_utils.py | 27 ++++--- vllm/inputs/data.py | 20 +++++ vllm/multimodal/inputs.py | 3 + vllm/v1/core/kv_cache_manager.py | 74 +++++++++++------ vllm/v1/core/kv_cache_utils.py | 115 ++++++++++++++++++++++++--- vllm/v1/core/scheduler.py | 2 + vllm/v1/engine/async_llm.py | 10 ++- vllm/v1/engine/core.py | 8 +- vllm/v1/engine/llm_engine.py | 9 ++- vllm/v1/engine/mm_input_mapper.py | 33 ++++---- vllm/v1/engine/processor.py | 12 +-- vllm/v1/request.py | 24 +++++- 14 files changed, 342 insertions(+), 98 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 00f7b0fcfe1dc..ed04f0a373c51 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2,16 +2,23 @@ import pytest from vllm.inputs import token_inputs +from vllm.multimodal.inputs import PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens -def make_request(request_id, prompt_token_ids): +def make_request(request_id, + prompt_token_ids, + mm_positions=None, + mm_hashes=None): return Request( request_id=request_id, - inputs=token_inputs(prompt_token_ids=prompt_token_ids), + 
inputs=token_inputs(prompt_token_ids=prompt_token_ids, + multi_modal_placeholders={"image": mm_positions} + if mm_positions else None, + multi_modal_hashes=mm_hashes), sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, arrival_time=0, @@ -38,6 +45,7 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks = manager.get_computed_blocks(req0) + assert len(req0.kv_block_hashes) == 3 assert not computed_blocks blocks = manager.allocate_slots(req0, 55, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -61,6 +69,7 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks = manager.get_computed_blocks(req1) + assert len(req1.kv_block_hashes) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) @@ -90,6 +99,7 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_block = manager.get_computed_blocks(req2) + assert len(req2.kv_block_hashes) == 3 assert [b.block_id for b in computed_block] == [0, 1, 2] num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) @@ -416,3 +426,77 @@ def test_cache_blocks(): ) assert len(manager.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None + + +def test_mm_prefix_caching(): + """ + This tests that the multi-modal prefix caching is correct. + """ + manager = KVCacheManager( + block_size=16, + num_gpu_blocks=10, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=16, + ) + + # Common prompt tokens (T is text tokens and P is image placeholder tokens) + # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1] + common_token_ids = list(range(10)) + [-1] * 6 + common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2 + common_token_ids += [-1] * 16 + + common_mm_positions = [ + PlaceholderRange(offset=11, length=10), + PlaceholderRange(offset=30, length=18), + ] + common_mm_hashes = ["aaa", "bbb"] + + # A unique image plus some text tokens. + unique_token_ids = [-1] * 7 + [100] * 4 + all_token_ids = common_token_ids + unique_token_ids + mm_positions = common_mm_positions + [ + PlaceholderRange(offset=48, length=7) + ] + mm_hashes = common_mm_hashes + ["ccc"] + req0 = make_request("0", + all_token_ids, + mm_positions=mm_positions, + mm_hashes=mm_hashes) + computed_blocks = manager.get_computed_blocks(req0) + + # Completed block should have hashes with extra keys. + assert not computed_blocks + assert len(req0.kv_block_hashes) == 3 + assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), ) + assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0)) + assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), ) + + blocks = manager.allocate_slots(req0, 59, computed_blocks) + assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + req0.num_computed_tokens = 59 + + # Append slots without allocating a new block. + for _ in range(5): + req0.append_output_token_ids(8) + new_blocks = manager.append_slots(req0, 5) + assert new_blocks is not None and len(new_blocks) == 0 + + # The just completed block should have hashes with extra keys. + assert len(req0.kv_block_hashes) == 4 + assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), ) + + # Cache hit. 
+ unique_token_ids = [-1] * 7 + [200] * 5 + all_token_ids = common_token_ids + unique_token_ids + mm_positions = common_mm_positions + [ + PlaceholderRange(offset=48, length=7) + ] + mm_hashes = common_mm_hashes + ["ccc"] + req1 = make_request("1", + all_token_ids, + mm_positions=mm_positions, + mm_hashes=mm_hashes) + computed_blocks = manager.get_computed_blocks(req1) + assert len(computed_blocks) == 3 diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index ac5e7dde525a7..ff38a4568ecb1 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -31,14 +31,6 @@ def test_prefix_caching_from_cli(): assert engine_args.enable_prefix_caching -def test_defaults(): - engine_args = EngineArgs(model="facebook/opt-125m") - - # Assert V1 defaults - assert (engine_args.enable_prefix_caching - ), "V1 turns on prefix caching by default" - - def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") vllm_config: VllmConfig = engine_args.create_engine_config( @@ -52,10 +44,3 @@ def test_defaults_with_usage_context(): UsageContext.OPENAI_API_SERVER) assert vllm_config.scheduler_config.max_num_seqs == 1024 assert vllm_config.scheduler_config.max_num_batched_tokens == 2048 - - -def test_prefix_cache_disabled_with_multimodel(): - engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf") - - vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) - assert not vllm_config.cache_config.enable_prefix_caching diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f6d276fe7c0c8..674577f23eba6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -205,6 +205,7 @@ def __post_init__(self): # by user. if self.enable_prefix_caching is None: self.enable_prefix_caching = bool(envs.VLLM_USE_V1) + # Override max_num_seqs if it's not set by user. if self.max_num_seqs is None: self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024 @@ -1026,11 +1027,11 @@ def create_engine_config(self, device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() - if model_config.is_multimodal_model: - if self.enable_prefix_caching: - logger.warning( - "--enable-prefix-caching is currently not " - "supported for multimodal models and has been disabled.") + if (model_config.is_multimodal_model and not envs.VLLM_USE_V1 + and self.enable_prefix_caching): + logger.warning("--enable-prefix-caching is currently not " + "supported for multimodal models in v0 and " + "has been disabled.") self.enable_prefix_caching = False cache_config = CacheConfig( @@ -1249,11 +1250,14 @@ def _override_v1_engine_args(self, usage_context: UsageContext) -> None: # When no user override, set the default values based on the usage # context. # TODO(woosuk): Tune the default values for different hardware. 
- if self.max_num_batched_tokens is None: - if usage_context == UsageContext.LLM_CLASS: - self.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - self.max_num_batched_tokens = 2048 + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 8192, + UsageContext.OPENAI_API_SERVER: 2048, + } + if (self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens): + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] logger.warning( "Setting max_num_batched_tokens to %d for %s usage context.", self.max_num_batched_tokens, usage_context.value) @@ -1263,9 +1267,6 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - if engine_config.model_config.is_multimodal_model: - # TODO (ywang96): Enable APC by default when VLM supports it. - assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 85aaaa776907f..d54cbb5c37819 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -162,6 +162,11 @@ class TokenInputs(TypedDict): Placeholder ranges for the multi-modal data. """ + multi_modal_hashes: NotRequired[List[str]] + """ + The hashes of the multi-modal data. + """ + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the @@ -177,6 +182,7 @@ def token_inputs( prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_inputs: Optional["MultiModalKwargs"] = None, + multi_modal_hashes: Optional[List[str]] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: @@ -191,6 +197,8 @@ def token_inputs( inputs["multi_modal_data"] = multi_modal_data if multi_modal_inputs is not None: inputs["multi_modal_inputs"] = multi_modal_inputs + if multi_modal_hashes is not None: + inputs["multi_modal_hashes"] = multi_modal_hashes if multi_modal_placeholders is not None: inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: @@ -295,6 +303,18 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: assert_never(inputs) + @cached_property + def multi_modal_hashes(self) -> List[str]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_hashes", []) + + if inputs["type"] == "multimodal": + return inputs.get("mm_hashes", []) + + assert_never(inputs) + @cached_property def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": inputs = self.inputs diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 229a8fbdf5831..c00943a5f26d9 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict): mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" + mm_hashes: NotRequired[List[str]] + """The hashes of the multi-modal data.""" + mm_placeholders: MultiModalPlaceholderDict """ For each modality, information about the placeholder tokens in diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index aaa44c930e324..61a3f5fd6d841 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,7 +4,9 @@ from vllm.logger import init_logger from vllm.utils 
import cdiv from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, hash_block_tokens, + KVCacheBlock, + generate_block_hash_extra_keys, + hash_block_tokens, hash_request_tokens) from vllm.v1.request import Request @@ -83,10 +85,12 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: computed_blocks = [] - # TODO(rickyx): potentially we could cache this so we don't have to - # recompute it every time. - block_hashes = hash_request_tokens(self.block_size, - request.all_token_ids) + # The block hashes for the request may already be computed + # if the request was preempted and resumed. + if not request.kv_block_hashes: + request.set_kv_block_hashes( + hash_request_tokens(self.block_size, request)) + block_hashes = request.kv_block_hashes for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not @@ -242,14 +246,16 @@ def allocate_slots( num_computed_tokens = len(computed_blocks) * self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - self._cache_full_blocks( - request=request, - blk_start_idx=len(computed_blocks), - # The new full blocks are the full blocks that are not computed. - full_blocks=self.req_to_blocks[request.request_id] - [len(computed_blocks):num_full_blocks], - prev_block=computed_blocks[-1] if computed_blocks else None, - ) + new_full_blocks = self.req_to_blocks[ + request.request_id][len(computed_blocks):num_full_blocks] + if new_full_blocks: + self._cache_full_blocks( + request=request, + blk_start_idx=len(computed_blocks), + # The new full blocks are the full blocks that are not computed. + full_blocks=new_full_blocks, + prev_block=computed_blocks[-1] if computed_blocks else None, + ) return new_blocks @@ -376,6 +382,8 @@ def _cache_full_blocks( full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ + num_cached_block_hashes = len(request.kv_block_hashes) + # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None if prev_block is not None: @@ -387,17 +395,35 @@ def _cache_full_blocks( for i, blk in enumerate(full_blocks): blk_idx = blk_start_idx + i - block_tokens = request.all_token_ids[blk_idx * - self.block_size:(blk_idx + - 1) * - self.block_size] - assert len(block_tokens) == self.block_size, ( - f"Expected {self.block_size} tokens, got {len(block_tokens)} " - f"at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, block_tokens) + if blk_idx < num_cached_block_hashes: + # The block hash may already be computed in + # "get_computed_blocks" if the tokens are not generated by + # this request (either the prompt tokens or the previously + # generated tokens with preemption). In this case we simply + # reuse the block hash. + block_hash = request.kv_block_hashes[blk_idx] + else: + # Otherwise compute the block hash and cache it in the request + # in case it will be preempted in the future. + start_token_idx = blk_idx * self.block_size + end_token_idx = (blk_idx + 1) * self.block_size + block_tokens = request.all_token_ids[ + start_token_idx:end_token_idx] + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got " + f"{len(block_tokens)} at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Generate extra keys for multi-modal inputs. 
Note that since + # we reach to this branch only when the block is completed with + # generated tokens, we only need to consider the last mm input. + extra_keys, _ = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, -1) + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash_value, + block_tokens, extra_keys) + request.append_kv_block_hashes(block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 0ba338aa5a3d2..d80ea128c7749 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,20 +1,25 @@ """KV-Cache Utilities.""" from collections.abc import Sequence from dataclasses import dataclass -from typing import List, NamedTuple, Optional, Tuple +from typing import Any, List, NamedTuple, Optional, Tuple from vllm.logger import init_logger +from vllm.v1.request import Request logger = init_logger(__name__) class BlockHashType(NamedTuple): - """Hash value of a block and the token IDs in the block. - The reason we keep a tuple of token IDs is to make sure no hash - collision happens when the hash value is the same. + """Hash value of a block (int), the token IDs in the block, and extra keys. + The reason we keep a tuple of token IDs and extra keys is to make sure + no hash collision happens when the hash value is the same. """ + # Hash value of the block in an integer. hash_value: int + # Token IDs in the block. token_ids: Tuple[int, ...] + # Extra keys for the block. + extra_keys: Optional[Any] = None @dataclass @@ -159,8 +164,80 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: return ret -def hash_block_tokens(parent_block_hash: Optional[int], - curr_block_token_ids: Sequence[int]) -> BlockHashType: +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + For multi-modal inputs, the extra keys are (mm_hash, start_offset) that + indicate a mm input contained in the block and its starting offset in + the block tokens. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if not mm_positions: + return None, start_mm_idx + + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match. This " + "is likely because you do not enable MM preprocessor hashing. " + "Please set mm_cache_preprocessor=True.") + + # Note that we assume mm_positions is sorted by offset. + # We do not need to check all mm inputs if the start token index is out of + # range. This usually happens in the late prefill phase and decoding phase. + if mm_positions[-1]["offset"] + mm_positions[-1][ + "length"] < start_token_idx: + return None, start_mm_idx + + # Support start_mm_idx == -1 to indicate the last mm input. 
+ if start_mm_idx < 0: + assert -start_mm_idx <= len(mm_positions) + start_mm_idx = len(mm_positions) + start_mm_idx + + extra_keys = [] + curr_mm_idx = start_mm_idx + while mm_positions and curr_mm_idx < len(mm_positions): + assert mm_hashes[curr_mm_idx] is not None + offset = mm_positions[curr_mm_idx]["offset"] + length = mm_positions[curr_mm_idx]["length"] + if end_token_idx > offset: + if start_token_idx > offset + length: + # This block has passed the current mm input. + curr_mm_idx += 1 + continue + + # The block contains the current mm input. + mm_start = max(0, start_token_idx - offset) + extra_keys.append((mm_hashes[curr_mm_idx], mm_start)) + if end_token_idx >= offset + length: + # If this block contains the end of the current mm input, + # move to the next mm input as this block may also contain + # the next mm input. + curr_mm_idx += 1 + else: + # Otherwise this block is done with mm inputs. + break + else: + # This block has not reached the current mm input. + break + return tuple(extra_keys), curr_mm_idx + + +def hash_block_tokens( + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int], if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. + extra_keys: Extra keys for the block. Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. """ return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), - tuple(curr_block_token_ids)) + tuple(curr_block_token_ids), extra_keys) def hash_request_tokens(block_size: int, - token_ids: Sequence[int]) -> List[BlockHashType]: + request: Request) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. Args: block_size: The size of each block. - token_ids: A sequence of token ids in the request. + request: The request object. Returns: The list of computed hash values. """ + token_ids = request.all_token_ids + mm_positions, mm_hashes = request.mm_positions, request.mm_hashes + if mm_positions and len(mm_positions) != len(mm_hashes): + raise ValueError( + "The number of multi-modal positions and hashes must match.") + + # TODO: Extend this to support other features such as LoRA. + need_extra_keys = bool(mm_positions) + extra_keys = None + curr_mm_idx = 0 + ret = [] parent_block_hash_value = None for start in range(0, len(token_ids), block_size): @@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int, # Do not hash the block if it is not full. if len(block_token_ids) < block_size: break + + # Add extra keys if the block is a multi-modal block. 
+ if need_extra_keys: + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start, end, curr_mm_idx) + block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids) + block_token_ids, extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 178532e477dae..08e7c0fd4dc9b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -516,6 +516,7 @@ class NewRequestData: prompt_token_ids: List[int] prompt: Optional[str] mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams block_ids: List[int] @@ -533,6 +534,7 @@ def from_request( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, sampling_params=request.sampling_params, block_ids=block_ids, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b36de5f66917c..41fb4b25d45bb 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -60,9 +60,13 @@ def __init__( self.client_aborted_requests: List[str] = [] # Processor (converts Inputs --> EngineCoreRequests). - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry) + self.processor = Processor( + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + tokenizer=self.tokenizer, + input_registry=input_registry, + ) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 56d4dc67e4a0e..497d5db5b4c99 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -65,7 +65,8 @@ def __init__( self._last_logging_time = time.time() - self.mm_input_mapper_server = MMInputMapperServer() + self.mm_input_mapper_server = MMInputMapperServer( + vllm_config.model_config) def _initialize_kv_caches(self, cache_config: CacheConfig) -> Tuple[int, int]: @@ -98,9 +99,8 @@ def add_request(self, request: EngineCoreRequest): # MM mapper, so anything that has a hash must have a HIT cache # entry here as well. 
assert request.mm_inputs is not None - request.mm_inputs, request.mm_hashes = ( - self.mm_input_mapper_server.process_inputs( - request.mm_inputs, request.mm_hashes)) + request.mm_inputs = self.mm_input_mapper_server.process_inputs( + request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 15dedbd0f9529..bea8c5502f612 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -55,9 +55,12 @@ def __init__( self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config.model_config, - vllm_config.lora_config, self.tokenizer, - input_registry, mm_registry) + self.processor = Processor(model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + tokenizer=self.tokenizer, + input_registry=input_registry, + mm_registry=mm_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput) self.detokenizer = Detokenizer( diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 6cdeba6f3f71e..e53ba092ede04 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import PIL from blake3 import blake3 @@ -42,6 +42,8 @@ def __init__( model_config) self.mm_registry.init_mm_limits_per_prompt(model_config) + # Init cache + self.use_cache = model_config.mm_cache_preprocessor self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) # DEBUG: Set to None to disable @@ -61,7 +63,7 @@ def process_inputs( mm_hashes: Optional[List[str]], mm_processor_kwargs: Optional[Dict[str, Any]], precomputed_mm_inputs: Optional[List[MultiModalKwargs]], - ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]: + ) -> List[MultiModalKwargs]: if precomputed_mm_inputs is None: image_inputs = mm_data["image"] if not isinstance(image_inputs, list): @@ -70,26 +72,21 @@ def process_inputs( else: num_inputs = len(precomputed_mm_inputs) - # Check if hash is enabled - use_hash = mm_hashes is not None - if use_hash: + # Sanity + if self.use_cache: assert mm_hashes is not None - assert num_inputs == len( - mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format( - num_inputs, len(mm_hashes)) + assert num_inputs == len(mm_hashes) # Process each image input separately, so that later we can schedule # them in a fine-grained manner. 
# Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_hashes: Optional[List[str]] = [] if use_hash else None ret_inputs: List[MultiModalKwargs] = [] for input_id in range(num_inputs): if self.mm_debug_cache_hit_ratio_steps is not None: self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) - mm_hash = None mm_input = None - if use_hash: + if self.use_cache: assert mm_hashes is not None mm_hash = mm_hashes[input_id] mm_input = self.mm_cache.get(mm_hash) @@ -106,7 +103,7 @@ def process_inputs( mm_processor_kwargs=mm_processor_kwargs, ) - if use_hash: + if self.use_cache: # Add to cache assert mm_hash is not None self.mm_cache.put(mm_hash, mm_input) @@ -114,18 +111,15 @@ def process_inputs( self.mm_cache_hits += 1 mm_input = None # Avoids sending mm_input to Server - if use_hash: - assert mm_hash is not None - assert ret_hashes is not None - ret_hashes.append(mm_hash) ret_inputs.append(mm_input) - return ret_inputs, ret_hashes + return ret_inputs class MMInputMapperServer: - def __init__(self, ): + def __init__(self, model_config): + self.use_cache = model_config.mm_cache_preprocessor self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) def process_inputs( @@ -135,6 +129,9 @@ def process_inputs( ) -> List[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) + if not self.use_cache: + return mm_inputs + full_mm_inputs = [] for mm_input, mm_hash in zip(mm_inputs, mm_hashes): assert mm_hash is not None diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 679bf8e25e9ca..732757d6b0ac2 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,7 +1,7 @@ import time from typing import Any, Dict, Mapping, Optional, Tuple, Union -from vllm.config import LoRAConfig, ModelConfig +from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs @@ -23,6 +23,7 @@ class Processor: def __init__( self, model_config: ModelConfig, + cache_config: CacheConfig, lora_config: Optional[LoRAConfig], tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, @@ -45,8 +46,9 @@ def __init__( self.mm_input_mapper_client = MMInputMapperClient(model_config) # Multi-modal hasher (for images) - self.mm_hasher = MMHasher( - ) if model_config.mm_cache_preprocessor else None + self.use_hash = model_config.mm_cache_preprocessor or \ + cache_config.enable_prefix_caching + self.mm_hasher = MMHasher() # TODO: run in an ThreadpoolExecutor or BackgroundProcess. # This ideally should releases the GIL, so we should not block the @@ -77,7 +79,7 @@ def process_inputs( # Compute MM hashes (if enabled) mm_hashes = None - if self.mm_hasher is not None: + if self.use_hash: mm_hashes = self.mm_hasher.hash(prompt) # Process inputs. 
@@ -118,7 +120,7 @@ def process_inputs( # Apply MM mapper mm_inputs = None if len(decoder_inputs.multi_modal_data) > 0: - mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs( + mm_inputs = self.mm_input_mapper_client.process_inputs( decoder_inputs.multi_modal_data, mm_hashes, decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1737d096e811d..f4783ae366ef0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,5 +1,5 @@ import enum -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.lora.request import LoRARequest @@ -9,6 +9,9 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import BlockHashType + class Request: @@ -45,6 +48,7 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 + # Multi-modal input metadata. mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. @@ -56,6 +60,12 @@ def __init__( if self.inputs.multi_modal_inputs: self.mm_inputs = self.inputs.multi_modal_inputs + self.mm_hashes: List[str] = self.inputs.multi_modal_hashes + + # Cache the computed kv block hashes of the request to avoid + # recomputing. + self._kv_block_hashes: List[BlockHashType] = [] + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( @@ -65,6 +75,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": prompt=request.prompt, multi_modal_data=None, multi_modal_inputs=request.mm_inputs, + multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, mm_processor_kwargs=None, ), @@ -121,6 +132,17 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens + @property + def kv_block_hashes(self) -> ConstantList["BlockHashType"]: + # Prevent directly appending to the kv_block_hashes. + return ConstantList(self._kv_block_hashes) + + def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: + self._kv_block_hashes = value + + def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: + self._kv_block_hashes.append(block_hash) + class RequestStatus(enum.IntEnum): """Status of a request.""" From 866fa4550d572f4ff3521ccf503e0df2e76591a1 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 18 Dec 2024 01:39:07 +0100 Subject: [PATCH 04/23] [Bugfix] Restore support for larger block sizes (#11259) Signed-off-by: Konrad Zawora --- vllm/config.py | 4 ++++ vllm/engine/arg_utils.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9ecd3e72afa9f..307cf9c8d5b2a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -917,6 +917,10 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if (current_platform.is_cuda() and self.block_size is not None + and self.block_size > 32): + raise ValueError("CUDA Paged Attention kernel only supports " + f"block sizes up to 32. 
Got {self.block_size}.") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 674577f23eba6..64cc4592c2861 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -424,10 +424,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], + choices=[8, 16, 32, 64, 128], help='Token block size for contiguous chunks of ' 'tokens. This is ignored on neuron devices and ' - 'set to max-model-len') + 'set to max-model-len. On CUDA devices, ' + 'only block sizes up to 32 are supported. ' + 'On HPU devices, block size defaults to 128.') parser.add_argument( "--enable-prefix-caching", From 8b79f9e107fd4214187bf65485b3ea1bb3191a46 Mon Sep 17 00:00:00 2001 From: Wallas Henrique Date: Wed, 18 Dec 2024 03:34:08 -0300 Subject: [PATCH 05/23] [Bugfix] Fix guided decoding with tokenizer mode mistral (#11046) --- .buildkite/test-pipeline.yaml | 6 +- requirements-common.txt | 3 +- .../model_executor/test_guided_processors.py | 54 ++++++++- .../decoder_only/language/test_mistral.py | 86 ++++++++++++- .../guided_decoding/xgrammar_decoding.py | 113 +++++++++++------- vllm/transformers_utils/tokenizer.py | 2 +- vllm/transformers_utils/tokenizers/mistral.py | 5 +- 7 files changed, 217 insertions(+), 52 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 44f47fac1c1b3..b563c96343f92 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -224,8 +224,12 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/model_executor/layers + - vllm/model_executor/guided_decoding - tests/test_logits_processor - command: pytest -v -s test_logits_processor.py + - tests/model_executor/test_guided_processors + commands: + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 30min source_file_dependencies: diff --git a/requirements-common.txt b/requirements-common.txt index bd2b4b7a01668..1c935303c8d79 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,12 +14,13 @@ aiohttp openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) uvicorn[standard] pydantic >= 2.9 # Required for fastapi >= 0.113.0 -pillow # Required for image processing prometheus_client >= 0.18.0 +pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines == 0.1.11 +lark == 1.2.2 xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 9f4d81b583141..3334c0df149b5 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,13 +1,19 @@ +import pickle + import pytest import torch from transformers import AutoTokenizer +from vllm.config import ModelConfig from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) + get_guided_decoding_logits_processor, + get_local_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, 
RegexLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams +MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta' + def test_guided_logits_processors(sample_regex, sample_json_schema): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" @@ -38,14 +44,29 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -async def test_guided_logits_processor_black_box(backend: str, sample_regex, +@pytest.mark.parametrize("is_local", [True, False]) +async def test_guided_logits_processor_black_box(backend: str, is_local: bool, + sample_regex, sample_json_schema): - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + + config = ModelConfig( + MODEL_NAME, + task="generate", + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="bfloat16", + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) - regex_lp = await get_guided_decoding_logits_processor( - regex_request, tokenizer) + + regex_lp = get_local_guided_decoding_logits_processor( + regex_request, tokenizer, config) if is_local else \ + await get_guided_decoding_logits_processor( + regex_request, tokenizer, config) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -59,7 +80,7 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( - json_request, tokenizer) + json_request, tokenizer, config) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -84,3 +105,24 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): with pytest.raises(ValueError, match="You can only use one kind of guided"): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") + + +def test_pickle_xgrammar_tokenizer_data(): + + # TODO: move to another test file for xgrammar + try: + import xgrammar as xgr + except ImportError: + pytest.skip("Could not import xgrammar to run test") + + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + TokenizerData) + tokenizer_data = TokenizerData(vocab_type=xgr.VocabType.RAW) + pickled = pickle.dumps(tokenizer_data) + + assert pickled is not None + + depickled: TokenizerData = pickle.loads(pickled) + + assert depickled is not None + assert depickled.vocab_type == xgr.VocabType.RAW diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 99b5d5694f9f7..bdc1571784b5d 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -3,17 +3,20 @@ Run `pytest tests/models/test_mistral.py`. 
""" import copy +import json +import jsonschema +import jsonschema.exceptions import pytest -from vllm import SamplingParams from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close MODELS = [ - "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", ] MISTRAL_FORMAT_MODELS = [ @@ -126,6 +129,45 @@ } ] +SAMPLE_JSON_SCHEMA = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] +} + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -251,3 +293,43 @@ def test_mistral_function_calling( assert parsed_message.tool_calls[ 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa assert parsed_message.content is None + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("guided_backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) +def test_mistral_guided_decoding( + vllm_runner, + model: str, + guided_backend: str, +) -> None: + with vllm_runner(model, dtype='bfloat16', + tokenizer_mode="mistral") as vllm_model: + + guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, + backend=guided_backend) + params = SamplingParams(max_tokens=512, + temperature=0.7, + guided_decoding=guided_decoding) + + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}" + }] + outputs = vllm_model.model.chat(messages, sampling_params=params) + + generated_text = outputs[0].outputs[0].text + json_response = json.loads(generated_text) + assert outputs is not None + + try: + jsonschema.validate(instance=json_response, + schema=SAMPLE_JSON_SCHEMA) + except jsonschema.exceptions.ValidationError: + pytest.fail("Generated response is not valid with JSON schema") diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index fc45e37cf6f06..5b97f03257502 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any import torch from transformers import PreTrainedTokenizerFast @@ -16,6 +16,7 @@ from vllm.model_executor.guided_decoding.xgrammar_utils import ( convert_lark_to_gbnf, grammar_is_likely_lark) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -37,11 +38,21 @@ def get_local_xgrammar_guided_decoding_logits_processor( return XGrammarLogitsProcessor(config) -class TokenizerData(NamedTuple): +@dataclass(frozen=True) +class TokenizerData: """Immutable container for cached tokenizer data.""" - encoded_vocab: list[str] - 
stop_token_ids: list[int] | None - backend_str: str + encoded_vocab: list[str] = field(default_factory=list) + stop_token_ids: list[int] | None = None + # These fields are mutually exclusive: `backend_str` is used to create a + # TokenizeInfo with `TokenizerInfo.from_huggingface` while `vocab_type` is + # used within the constructor of TokenizeInfo + backend_str: str | None = None + vocab_type: xgr.VocabType | None = None + + def __post_init__(self): + # Check for mutual exclusive + assert not (self.backend_str and self.vocab_type), \ + "backend_str and vocab_type are mutual exclusive" class TokenizerDataCache: @@ -68,18 +79,27 @@ def get_tokenizer_data(cls, "get_vocab method.") from e stop_token_ids = None - backend_str = xgr.VocabType.RAW + backend_str = "" + vocab_type = xgr.VocabType.RAW + + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + if isinstance(tokenizer, PreTrainedTokenizerFast): backend_str = tokenizer.backend_tokenizer.to_str() - if stop_token_ids is None and hasattr( - tokenizer, - "eos_token_id") and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + vocab_type = None + + elif isinstance(tokenizer, MistralTokenizer): + # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + vocab_type = xgr.VocabType.BYTE_FALLBACK cls._cache[tokenizer_hash] = TokenizerData( encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, - backend_str=backend_str) + backend_str=backend_str, + vocab_type=vocab_type) return cls._cache[tokenizer_hash] @@ -98,11 +118,30 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: cache_key = str(config.tokenizer_hash) if cache_key not in cls._cache: - assert config.encoded_vocab is not None - tokenizer_info = xgr.TokenizerInfo._create_from_handle( - xgr_core.TokenizerInfo.from_huggingface( - config.encoded_vocab, config.backend_str, - config.vocab_size, config.stop_token_ids)) + assert config.tokenizer_data is not None + assert config.tokenizer_data.encoded_vocab is not None + + config_data = config.tokenizer_data + + # In TokenizerDataCache.get_tokenizer_data, a serializable + # tokenizer_data is created and cached. This data is used to build + # a tokenizer_info and create an xgrammar compiler. + # - If tokenizer_data has backend_str set, use + # xgr_core.TokenizerInfo.from_huggingface (a C++ bind). + # - Otherwise, use the default constructor with vocab_type. + # - xgr_core.TokenizerInfo.from_huggingface != + # xgr.TokenizerInfo.from_huggingface. 
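As an aside, and separate from the patch itself: a minimal sketch of the two mutually exclusive TokenizerData variants that the branch below distinguishes, assuming xgrammar is installed. It mirrors the BYTE_FALLBACK branch in TokenizerDataCache and the pickle round-trip test added earlier in this patch; the placeholder backend string is illustrative only.

# Minimal sketch (not part of the diff): the two TokenizerData variants.
import pickle

import xgrammar as xgr

from vllm.model_executor.guided_decoding.xgrammar_decoding import TokenizerData

# HF fast tokenizers: backend_str carries the serialized backend tokenizer
# (placeholder string here); vocab_type stays None.
hf_data = TokenizerData(backend_str="{...}")

# Mistral tokenizers: no backend string, so vocab_type selects the path.
mistral_data = TokenizerData(vocab_type=xgr.VocabType.BYTE_FALLBACK)

# The patch adds a test asserting this container pickles cleanly.
assert pickle.loads(pickle.dumps(mistral_data)).vocab_type == \
    xgr.VocabType.BYTE_FALLBACK

# Setting both fields trips the assertion in __post_init__:
# TokenizerData(backend_str="{...}", vocab_type=xgr.VocabType.RAW)  # raises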
+ if config_data.backend_str: + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config_data.encoded_vocab, config_data.backend_str, + config.vocab_size, config_data.stop_token_ids)) + else: + tokenizer_info = xgr.TokenizerInfo( + config_data.encoded_vocab, + config_data.vocab_type, + vocab_size=config.vocab_size, + stop_token_ids=config_data.stop_token_ids) cls._cache[cache_key] = xgr.GrammarCompiler( tokenizer_info, max_threads=config.max_threads) @@ -118,10 +157,7 @@ class GrammarConfig: grammar_str: str | None = None json_object: bool | None = None max_threads: int = 8 - # Only populated if tokenizer_hash not in cache - encoded_vocab: list[str] | None = None - stop_token_ids: list[int] | None = None - backend_str: str | None = None + tokenizer_data: TokenizerData | None = None @classmethod def from_guided_params(cls, @@ -132,9 +168,6 @@ def from_guided_params(cls, tokenizer_hash = hash(tokenizer) tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) - encoded_vocab = tokenizer_data.encoded_vocab - stop_token_ids = tokenizer_data.stop_token_ids - backend_str = tokenizer_data.backend_str if guided_params.json: if not isinstance(guided_params.json, str): @@ -152,11 +185,9 @@ def from_guided_params(cls, return cls(json_str=json_str, vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + tokenizer_data=tokenizer_data) elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must convert Lark if grammar_is_likely_lark(guided_params.grammar): @@ -181,19 +212,17 @@ def from_guided_params(cls, return cls(grammar_str=grammar_str, vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + max_threads=max_threads, + tokenizer_data=tokenizer_data) elif guided_params.json_object: - return cls(json_object=True, - vocab_size=model_config.hf_text_config.vocab_size, - encoded_vocab=encoded_vocab, - stop_token_ids=stop_token_ids, - backend_str=backend_str, - tokenizer_hash=tokenizer_hash, - max_threads=max_threads) + return cls( + json_object=True, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) else: raise ValueError( "Currently only support JSON and EBNF grammar mode for xgrammar" @@ -269,10 +298,14 @@ def __call__(self, input_ids: list[int], # fill_next_token_bitmask so we move it to the device of scores device_type = scores.device.type if device_type != "cuda": - scores = scores.to("cpu") + scores = scores.to("cpu").unsqueeze(0) + + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. 
Hence the + # unsqueeze above for scores, to match the token bitmask shape xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) if device_type != "cuda": - scores = scores.to(device_type) + scores = scores.to(device_type).squeeze() return scores diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 54f9f895fe541..e6701f4c4b835 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -132,7 +132,7 @@ def get_tokenizer( if is_from_mistral_org and tokenizer_mode != "mistral": warnings.warn( 'It is strongly recommended to run mistral models with ' - '`--tokenizer_mode "mistral"` to ensure correct ' + '`--tokenizer-mode "mistral"` to ensure correct ' 'encoding and decoding.', FutureWarning, stacklevel=2) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 83b3c37d6f04c..17d722e3d88fe 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -314,12 +314,15 @@ def _token_to_id(t: str): if regular_tokens: decoded_list.append( - self.decode(regular_tokens)) # type: ignore + self.tokenizer.decode(regular_tokens)) # type: ignore decoded = ''.join(decoded_list) return decoded + # WARN: Outlines logits processors can overwrite this method. + # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer + # for more. def decode(self, ids: Union[List[int], int], skip_special_tokens: bool = True) -> str: From f04e407e6b6b9ce65c16cffda836f05c2ad32682 Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Wed, 18 Dec 2024 14:34:23 +0800 Subject: [PATCH 06/23] [MISC][XPU]update ipex link for CI fix (#11278) --- requirements-xpu.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements-xpu.txt b/requirements-xpu.txt index e41295792283f..42c6c321d040c 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -9,8 +9,8 @@ setuptools-scm>=8 wheel jinja2 -torch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp310-cp310-linux_x86_64.whl -intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp310-cp310-linux_x86_64.whl -oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp310-cp310-linux_x86_64.whl +torch @ https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp310-cp310-linux_x86_64.whl +intel-extension-for-pytorch @ https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp310-cp310-linux_x86_64.whl +oneccl_bind_pt @ https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp310-cp310-linux_x86_64.whl triton-xpu == 3.0.0b1 From 60508ffda91c22e4cde3b18f149d222211db8886 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 18 Dec 2024 09:57:16 -0500 Subject: [PATCH 07/23] [Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995) Co-authored-by: Faraz Shahsavan Co-authored-by: ilmarkov Co-authored-by: Rahul Tuli Co-authored-by: rshaw@neuralmagic.com --- CMakeLists.txt | 26 +- .../cutlass_benchmarks/sparse_benchmarks.py | 384 ++++++++++++++ benchmarks/cutlass_benchmarks/utils.py | 96 ++++ .../cutlass_benchmarks/w8a8_benchmarks.py | 28 +- 
.../cutlass_benchmarks/weight_shapes.py | 2 +- csrc/core/math.hpp | 7 + csrc/cutlass_extensions/common.cpp | 11 + csrc/cutlass_extensions/common.hpp | 35 ++ .../epilogue/scaled_mm_epilogues_c3x.hpp | 4 +- csrc/ops.h | 9 + csrc/quantization/cutlass_w8a8/common.hpp | 27 - .../cutlass_w8a8/scaled_mm_c2x.cuh | 3 +- .../cutlass_w8a8/scaled_mm_c3x.cu | 3 +- .../cutlass_w8a8/scaled_mm_entry.cu | 12 +- csrc/sparse/cutlass/sparse_compressor_c3x.cu | 163 ++++++ .../sparse/cutlass/sparse_compressor_entry.cu | 42 ++ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 303 +++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 496 ++++++++++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 59 +++ csrc/torch_bindings.cpp | 15 + pyproject.toml | 2 +- tests/kernels/test_semi_structured.py | 131 +++++ tests/quantization/test_compressed_tensors.py | 103 +++- tests/weight_loading/models.txt | 2 + .../run_model_weight_loading_test.sh | 4 + tests/weight_loading/test_weight_loading.py | 7 + vllm/_custom_ops.py | 103 ++++ .../compressed_tensors/compressed_tensors.py | 187 ++++++- .../compressed_tensors/schemes/__init__.py | 15 +- .../schemes/compressed_tensors_24.py | 203 +++++++ 30 files changed, 2365 insertions(+), 117 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/sparse_benchmarks.py create mode 100644 benchmarks/cutlass_benchmarks/utils.py create mode 100644 csrc/core/math.hpp create mode 100644 csrc/cutlass_extensions/common.cpp create mode 100644 csrc/cutlass_extensions/common.hpp delete mode 100644 csrc/quantization/cutlass_w8a8/common.hpp create mode 100644 csrc/sparse/cutlass/sparse_compressor_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_compressor_entry.cu create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu create mode 100644 tests/kernels/test_semi_structured.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bf19b3d227171..51b49a18dddf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) endif() FetchContent_MakeAvailable(cutlass) @@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_compressor_entry.cu" + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # - # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels + # For Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/sparse/cutlass/sparse_compressor_c3x.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " + message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running FP8 quantized models on " + "later if you intend on running FP8 sparse or quantized models on " "Hopper.") else() - message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + message(STATUS "Not building cutlass_c3x as no compatible archs found " "in CUDA target architectures") endif() @@ -404,7 +410,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..3d1c5e392f9e2 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,384 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + 
globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. 
bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 
1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
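As a usage note, separate from the script itself: the pickled Measurement list can be re-rendered later with the same torch.utils.benchmark Compare helper this file already uses for live output; the file name below is illustrative.

# Minimal sketch (not part of the script): inspect the emitted .pkl afterwards.
import pickle

import torch.utils.benchmark as TBenchmark

with open("model_bench-torch.float8_e4m3fn-1700000000.pkl", "rb") as f:
    measurements = pickle.load(f)  # list of torch.utils.benchmark.Measurement

# Re-render the same comparison table the benchmark printed while running.
TBenchmark.Compare(measurements).print()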
+ """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000..ef06fcd6604dd --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,96 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, 
k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..d0353bc8cb42a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -8,6 +8,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops @@ -17,31 +18,6 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - # bench def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, @@ -386,4 +362,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000..ba9f40a230c8e --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,7 @@ +#include +#include + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp new file mode 100644 index 0000000000000..3d2093ab94297 --- /dev/null +++ b/csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, 
minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp new file mode 100644 index 0000000000000..85e359aa57113 --- /dev/null +++ b/csrc/cutlass_extensions/common.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..fcc17c7727f94 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -36,13 +36,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors diff --git a/csrc/ops.h b/csrc/ops.h index 816b471d062d2..c145e4eda0845 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -162,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& e, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, + torch::Tensor& e, torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp deleted file mode 100644 index bf04bb400790f..0000000000000 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - 
cutlassGetStatusString(status)) \ - } - -inline uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index d03242f44ab1d..75681f7f37820 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,8 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..8190277997161 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -24,7 +24,8 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e0..4f7b6588ef3f7 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -3,6 +3,8 @@ #include #include +#include "cutlass_extensions/common.hpp" + void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { return false; } -int32_t get_sm_version_num() { - int32_t major_capability, minor_capability; - cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, - 0); - cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, - 0); - int32_t version_num = major_capability * 10 + minor_capability; - return version_num; -} - void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu new file mode 100644 index 0000000000000..218c5317b4de6 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu @@ -0,0 +1,163 @@ +// clang-format will break include orders +// clang-format off +#include + +#include "sparse_scaled_mm_c3x.cuh" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +using namespace cute; +using namespace vllm; + +/// Make A structured sparse by replacing elements with 0 and compress it +template +bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + 
// Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity + TORCH_CHECK(a.stride(1) == 1) + + int m = a.size(0); + int k = a.size(1); + + // Sparse kernel setup; this kernel is not used for matmul, + // but just for setting up the compressor utility + // A matrix configuration + using ElementA = ElementA_; + using LayoutTagA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + // B matrix configuration + using ElementB = ElementA; + using LayoutTagB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + // C/D matrix configuration + using ElementC = float; + using LayoutTagC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Core kernel configurations + using ElementAccumulator = ElementAcc_; + using TileShape = Shape<_128, _128, _128>; + using TileShapeRef = Shape<_128, _128, _64>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = typename std::conditional< + std::is_same_v, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecialized>::type; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using ProblemShape = Shape; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC, + AlignmentC, ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA, + LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + using StrideA = Stride, int64_t>; + + // The n (=1) dimension does not matter for the compressor + typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; + + using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; + using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; + + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; + + // Offline compressor kernel + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, + cutlass::arch::Sm90>; + + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + auto [M, N, K, L] = prob_shape; + + StrideA stride_A; + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(prob_shape, 
stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto a_ptr = static_cast(a.data_ptr()); + + auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); + auto a_meta_ptr = static_cast( + a_meta.data_ptr()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + if (a.dtype() == torch::kBFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } else if (a.dtype() == torch::kFloat8_e4m3fn) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kInt8) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } + return false; +} \ No newline at end of file diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu new file mode 100644 index 0000000000000..d23d937b6ac28 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu @@ -0,0 +1,42 @@ +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a); +#endif + +bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); + TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && + a_nzs.size(1) * 2 == a.size(1) && + a_meta.size(1) * 2 * 4 == a.size(1)); + // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && + a_meta.stride(1) == 1); // Row-major + TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + int32_t version_num = get_sm_version_num(); + + // Guard against compilation issues for sm90 kernels +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X + if (version_num >= 90) { + return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_sparse_mm for a compute capability less than " + "CUDA device capability: ", + version_num); +} diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu new file mode 100644 index 0000000000000..b50e9a3a2c240 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -0,0 +1,303 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +#include 
"sparse_scaled_mm_c3x.cuh" +// clang-format on + +using namespace cute; +using namespace vllm; + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm90_fp8_config_M256::Cutlass3xGemm; + using Cutlass3xGemmM512 = + typename sm90_fp8_config_M512::Cutlass3xGemm; + + using Cutlass3xGemm1 = + typename sm90_fp8_config_1::Cutlass3xGemm; + using Cutlass3xGemm2 = + typename sm90_fp8_config_2::Cutlass3xGemm; + using Cutlass3xGemm3 = + typename sm90_fp8_config_3::Cutlass3xGemm; + using Cutlass3xGemm4 = + typename sm90_fp8_config_4::Cutlass3xGemm; + using Cutlass3xGemm5 = + typename sm90_fp8_config_5::Cutlass3xGemm; + using Cutlass3xGemm6 = + typename sm90_fp8_config_6::Cutlass3xGemm; + using Cutlass3xGemm7 = + typename sm90_fp8_config_7::Cutlass3xGemm; + using Cutlass3xGemm8 = + typename sm90_fp8_config_8::Cutlass3xGemm; + + uint32_t const n = bt_nzs.size(0); + uint32_t const m = a.size(0); // Batch size + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096 || n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 128) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 256) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else { + if (n == 6144 || n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } + + // Otherwise the default heuristic + if (mp2 <= 64) { + // n in [1, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // n in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 256) { + // n in (128, 256] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // n in (256, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template typename Epilogue, + typename... 
EpilogueArgs> +void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kBFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template