From 81ede99ca44a5b3518932a07ea4a76a719e7416e Mon Sep 17 00:00:00 2001
From: Kuntai Du
Date: Thu, 17 Oct 2024 11:38:15 -0500
Subject: [PATCH] [Core] Deprecating block manager v1 and make block manager v2 default (#8704)

Remove block manager v1. This is the first piece of the prefix-caching-centric design: to achieve it, we need to simplify the code path so that only the v2 block manager (which has much higher performance on prefix caching) is used.
---
 .buildkite/test-pipeline.yaml | 18 +-
 benchmarks/benchmark_latency.py | 4 -
 benchmarks/benchmark_prefix_caching.py | 6 -
 benchmarks/benchmark_throughput.py | 11 +-
 benchmarks/overheads/benchmark_hashing.py | 4 -
 docs/source/models/spec_decode.rst | 3 -
 examples/offline_inference_mlpspeculator.py | 2 -
 .../basic_correctness/test_chunked_prefill.py | 11 +-
 tests/core/block/e2e/test_correctness.py | 78 +-
 .../e2e/test_correctness_sliding_window.py | 19 +-
 ...ck_manager_v2.py => test_block_manager.py} | 57 +-
 tests/core/test_block_manager.py | 637 ---------------
 tests/core/test_chunked_prefill_scheduler.py | 68 +-
 tests/core/test_num_computed_tokens_update.py | 1 -
 tests/core/test_scheduler.py | 150 ++--
 tests/metrics/test_metrics.py | 16 +-
 .../multi_step/test_correctness_async_llm.py | 1 -
 tests/multi_step/test_correctness_llm.py | 4 -
 tests/prefix_caching/test_prefix_caching.py | 89 ---
 tests/spec_decode/e2e/test_compatibility.py | 68 +-
 .../spec_decode/e2e/test_eagle_correctness.py | 18 -
 tests/spec_decode/e2e/test_integration.py | 8 -
 .../e2e/test_integration_dist_tp2.py | 6 -
 .../e2e/test_integration_dist_tp4.py | 6 -
 tests/spec_decode/e2e/test_logprobs.py | 14 -
 .../e2e/test_medusa_correctness.py | 21 -
 tests/spec_decode/e2e/test_mlp_correctness.py | 27 -
 .../e2e/test_multistep_correctness.py | 36 -
 .../spec_decode/e2e/test_ngram_correctness.py | 16 -
 tests/spec_decode/e2e/test_seed.py | 3 -
 tests/utils.py | 9 -
 vllm/attention/backends/flash_attn.py | 8 +-
 vllm/attention/backends/flashinfer.py | 8 +-
 vllm/attention/backends/utils.py | 16 +-
 vllm/commit_id.py | 1 +
 vllm/config.py | 24 -
 vllm/core/block/utils.py | 24 +-
 .../{block_manager_v2.py => block_manager.py} | 2 +-
 vllm/core/block_manager_v1.py | 743 ------------------
 vllm/core/interfaces.py | 10 +-
 vllm/core/scheduler.py | 4 +-
 vllm/engine/arg_utils.py | 38 +-
 vllm/engine/llm_engine.py | 3 +-
 vllm/envs.py | 6 -
 vllm/worker/model_runner.py | 17 +-
 45 files changed, 206 insertions(+), 2109 deletions(-)
 rename tests/core/block/{test_block_manager_v2.py => test_block_manager.py} (91%)
 delete mode 100644 tests/core/test_block_manager.py
 create mode 100644 vllm/commit_id.py
 rename vllm/core/{block_manager_v2.py => block_manager.py} (99%)
 delete mode 100644 vllm/core/block_manager_v1.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 398fdc5f0ae2b..d2324d7cee60f 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -77,8 +77,8 @@ steps:
   - vllm/
   - tests/basic_correctness/test_chunked_prefill
   commands:
-  - VLLM_ATTENTION_BACKEND=XFORMERS VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
-  - VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
 
 - label: Core Test # 10min
   mirror_hardwares: [amd]
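For context on the user-facing effect of this patch: the (former v2) block manager becomes the only implementation, so the CI commands in this file no longer set VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1 and the benchmark scripts further down stop passing use_v2_block_manager. The sketch below is a minimal illustration under those assumptions; the model name and prompt are placeholders (not taken from this patch), and how vllm/engine/arg_utils.py treats a leftover --use-v2-block-manager flag is not shown in this excerpt.

    from vllm import LLM, SamplingParams

    # Prefix caching now always runs on the (former v2) block manager, which is
    # renamed to SelfAttnBlockSpaceManager later in this patch. There is no
    # use_v2_block_manager argument and no VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1
    # environment variable to set.
    llm = LLM(model="facebook/opt-125m",  # placeholder model
              enable_prefix_caching=True,
              enforce_eager=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    print(outputs[0].outputs[0].text)

The same simplification shows up throughout the diff: tests and benchmarks that previously parametrized over use_v2_block_manager now exercise a single code path.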
@@ -88,11 +88,7 @@ steps: - vllm/distributed - tests/core commands: - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/test_scheduler.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/test_chunked_prefill_scheduler.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness_sliding_window.py - - pytest -v -s core --ignore=core/block/e2e/test_correctness.py --ignore=core/test_scheduler.py --ignore=core/test_chunked_prefill_scheduler.py --ignore=core/block/e2e/test_correctness.py --ignore=core/block/e2e/test_correctness_sliding_window.py + - pytest -v -s core - label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" @@ -192,8 +188,7 @@ steps: - vllm/ - tests/prefix_caching commands: - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s prefix_caching/test_prefix_caching.py - - pytest -v -s prefix_caching --ignore=prefix_caching/test_prefix_caching.py + - pytest -v -s prefix_caching - label: Samplers Test # 36min source_file_dependencies: @@ -217,8 +212,7 @@ steps: - tests/spec_decode commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s spec_decode/e2e/test_compatibility.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_compatibility.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - label: LoRA Test %N # 15min each mirror_hardwares: [amd] @@ -405,7 +399,7 @@ steps: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 79a48b2a1a845..ea1a7788f621d 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -38,7 +38,6 @@ def main(args: argparse.Namespace): quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, - use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, block_size=args.block_size, @@ -221,9 +220,6 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument("--enable-prefix-caching", action='store_true', help="Enable automatic prefix caching") - parser.add_argument('--use-v2-block-manager', - action='store_true', - default=EngineArgs.use_v2_block_manager) parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index f14092d347343..a354358e43aa3 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -33,7 +33,6 @@ from transformers import 
PreTrainedTokenizerBase from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs from vllm.utils import FlexibleArgumentParser try: @@ -134,7 +133,6 @@ def main(args): tokenizer_mode='auto', trust_remote_code=True, enforce_eager=True, - use_v2_block_manager=args.use_v2_block_manager, tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=args.enable_prefix_caching) @@ -176,10 +174,6 @@ def main(args): parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - default=EngineArgs.use_v2_block_manager, - help='Use BlockSpaceMangerV2') parser.add_argument('--num-prompts', type=int, default=1, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index b7bc2a6402375..e26706af606b0 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -86,7 +86,6 @@ def run_vllm( distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, num_scheduler_steps: int = 1, - use_v2_block_manager: bool = False, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, @@ -113,7 +112,6 @@ def run_vllm( distributed_executor_backend=distributed_executor_backend, load_format=load_format, num_scheduler_steps=num_scheduler_steps, - use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, ) @@ -176,7 +174,6 @@ async def run_vllm_async( distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, num_scheduler_steps: int = 1, - use_v2_block_manager: bool = False, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, @@ -204,7 +201,6 @@ async def run_vllm_async( distributed_executor_backend=distributed_executor_backend, load_format=load_format, num_scheduler_steps=num_scheduler_steps, - use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, worker_use_ray=False, disable_log_requests=True, @@ -341,8 +337,7 @@ def main(args: argparse.Namespace): args.enable_prefix_caching, args.enable_chunked_prefill, args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, - args.use_v2_block_manager, args.download_dir, args.load_format, - args.disable_async_output_proc + args.download_dir, args.load_format, args.disable_async_output_proc ] if args.async_engine: @@ -471,10 +466,6 @@ def main(args: argparse.Namespace): type=int, default=1, help="Maximum number of forward steps per scheduler call.") - parser.add_argument("--use-v2-block-manager", - action='store_true', - default=EngineArgs.use_v2_block_manager, - help="Enable block manager v2.") parser.add_argument( "--enable-prefix-caching", action='store_true', diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 203699e9a8d06..d16d6f9fba442 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -16,7 +16,6 @@ def main(args): enforce_eager=True, enable_prefix_caching=True, tensor_parallel_size=args.tensor_parallel_size, - use_v2_block_manager=args.use_v2_block_manager, ) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) @@ -56,8 +55,5 @@ def main(args): parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix 
caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst index 0dc9cb383a7fd..b02c80aebec69 100644 --- a/docs/source/models/spec_decode.rst +++ b/docs/source/models/spec_decode.rst @@ -30,7 +30,6 @@ The following code configures vLLM in an offline mode to use speculative decodin tensor_parallel_size=1, speculative_model="facebook/opt-125m", num_speculative_tokens=5, - use_v2_block_manager=True, ) outputs = llm.generate(prompts, sampling_params) @@ -104,7 +103,6 @@ matching n-grams in the prompt. For more information read `this thread. 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) @@ -206,7 +199,6 @@ def test_with_prefix_caching( max_tokens: int, enforce_eager: bool, chunk_size: int, - use_v2_block_manager: bool, tensor_parallel_size: int, ) -> None: """ @@ -234,7 +226,6 @@ def test_with_prefix_caching( enable_chunked_prefill=True, enable_prefix_caching=enable, tensor_parallel_size=tensor_parallel_size, - use_v2_block_manager=use_v2_block_manager, enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, ) as vllm_model: diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index b3f626714d351..86502f613b187 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -2,18 +2,11 @@ import pytest -from tests.utils import check_deprecated_block_manager_usage from vllm import SamplingParams from .conftest import get_token_ids_from_llm_generator -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/core/block/e2e/test_correctness.py') - - @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -28,32 +21,32 @@ def check_deprecated_block_manager(): "num_gpu_blocks_override": 5 * (64 + 1), }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ - "use_v2_block_manager": True, "preemption_mode": "swap" }, { - "use_v2_block_manager": True, "preemption_mode": "recompute" }]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) -def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify block manager v2 produces same outputs as block manager v1, even - when there is preemption. +def test_block_manager_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager produces same outputs even when there is preemption. This constructs two LLM, each with limited number of GPU blocks. The limit is decided such that as the sequences in the batch grow, sequences must be preempted and removed from cache. If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted in the v2 block manager. + cache is not corrupted. NOTE: We want a significant number of generated tokens so that any incorrect KV mapping has time to build up error. + + NOTE(Kuntai): Though we have removed block manager v1, this test is still + useful as it asserts the behavior of block manager v2 (now it is called + SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we + keep this test. 
""" output_len = 1024 temperature = 0.0 @@ -77,11 +70,9 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids from block manager v1') baseline_token_ids = get_token_ids_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - print('Getting token ids from block manager v2') test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) @@ -104,9 +95,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, # skip cuda graph creation for fast test. "enforce_eager": True, - - # Lookahead scheduling only supported in v2 block manager. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -218,26 +206,22 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, "max_num_seqs": 10, }]) @pytest.mark.parametrize("baseline_llm_kwargs", [ - { - "use_v2_block_manager": False, - }, + {}, ]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "use_v2_block_manager": True, "num_lookahead_slots": 0, }, { - "use_v2_block_manager": True, "num_lookahead_slots": 5, }, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_chunked_prefill_block_manager_v2(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify that chunked prefill works with BlockManagerV2, with and without - lookahead scheduling. +def test_chunked_prefill_block_manager(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify that chunked prefill works with SelfAttnBlockSpaceManager, + with and without lookahead scheduling. """ output_len = 32 temperature = 0.0 @@ -258,11 +242,11 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids with BlockManagerV1') + print('Getting token ids with BlockManager') baseline_token_ids = get_token_ids_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - print('Getting token ids with BlockManagerV2') + print('Getting token ids with BlockManager, with lookahead slots.') test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) @@ -290,32 +274,32 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, "enable_prefix_caching": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ - "use_v2_block_manager": True, "preemption_mode": "swap" }, { - "use_v2_block_manager": True, "preemption_mode": "recompute" }]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) -def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( +def test_block_manager_prefix_caching_enabled_with_preemption( baseline_llm_generator, test_llm_generator, batch_size): - """Verify block manager v2 produces same outputs as block manager v1, even - when there is preemption. + """Verify block manager produces same outputs even when there is preemption. This constructs two LLM, each with limited number of GPU blocks. The limit is decided such that as the sequences in the batch grow, sequences must be preempted and removed from cache. If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted in the v2 block manager. + cache is not corrupted. 
NOTE: We want a significant number of generated tokens so that any incorrect KV mapping has time to build up error. + + NOTE(Kuntai): Though we have removed block manager v1, this test is still + useful as it asserts the behavior of block manager v2 (now it is called + SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we + keep this test. """ output_len = 1024 temperature = 0.0 @@ -339,11 +323,11 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( temperature=temperature, ) - print('Getting token ids from block manager v1') + print('Getting token ids from block manager') baseline_token_ids = get_token_ids_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - print('Getting token ids from block manager v2') + print('Getting token ids from block manager, with preemption') test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) @@ -366,9 +350,6 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( # Allow only 5 sequences of ~1024 tokens in worst case. "block_size": 16, "num_gpu_blocks_override": 5 * (64 + 1), - - # Test APC in v2 block - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ @@ -444,9 +425,6 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, "max_model_len": 48, "block_size": 16, "num_gpu_blocks_override": 3, - - # Test APC in v2 block - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 731131984b0eb..9320a9ef62314 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -3,7 +3,6 @@ import pytest -from tests.utils import check_deprecated_block_manager_usage from vllm import LLM, SamplingParams from .conftest import get_text_from_llm_generator @@ -13,12 +12,6 @@ BLOCK_SIZE = 16 -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/core/block/e2e/test_correctness_sliding_window.py') - - @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -31,10 +24,8 @@ def check_deprecated_block_manager(): "num_gpu_blocks_override": 100000 // BLOCK_SIZE, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, @@ -55,7 +46,6 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, prompts, answer, indices = prep_prompts(batch_size) - print('Getting token ids from block manager v1') baseline_texts = get_text_from_llm_generator(baseline_llm_generator, prompts, sampling_params, @@ -91,10 +81,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, "num_gpu_blocks_override": 100000 // BLOCK_SIZE, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ 
- "use_v2_block_manager": True, - "enable_chunked_prefill": True -}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager.py similarity index 91% rename from tests/core/block/test_block_manager_v2.py rename to tests/core/block/test_block_manager.py index e67883367879f..cfd749ad58694 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager.py @@ -2,7 +2,7 @@ from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.block_manager import SelfAttnBlockSpaceManager from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list @@ -17,7 +17,7 @@ @pytest.mark.parametrize("watermark", [0.0, 0.5]) def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -63,7 +63,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -117,16 +117,16 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, ''' SWA short for Sliding Window Attention. - At time of writing block manager v2 does not support SWA. + At time of writing block manager does not support SWA. - However even when SWA is implemented for block manager v2, + However even when SWA is implemented for block manager, there will still most likely be a separate workstream required to enable SWA for encoder/decoder models. Therefore this test enforces that one of the following cases hold true: - 1. Block manager v2 does not support SWA at all (true at time of writing) - 2. Block manager v2 fails with NotImplementError when SWA is enabled + 1. Block manager does not support SWA at all (true at time of writing) + 2. Block manager fails with NotImplementError when SWA is enabled AND a SequenceGroup with an encoder sequence (i.e. in support of an encoder/decoder model) is passed into can_allocate() as an argument @@ -135,7 +135,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, ''' with pytest.raises((NotImplementedError, AssertionError)) as exc_info: - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -158,7 +158,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, block_manager.can_allocate(seq_group) # Assert that either - # 1. Block manager v2 constructor fails with assertion that sliding window + # 1. Block manager constructor fails with assertion that sliding window # is not yet supported (most likely near-term outcome at time of # writing), or # 2. 
can_allocate() fails with NotImplementedError due to combination of @@ -177,7 +177,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -217,7 +217,7 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, num_gpu_blocks = 1024 watermark = 0.1 - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0, @@ -269,14 +269,15 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, """Verify blocks number on src/desc device is correct after swapping in/out sequence group (not missing or extra blocks). """ - block_manager = BlockSpaceManagerV2(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) + # Emulate a forward pass by appending a single token. # The block manager then knows how many unprocessed # tokens will be written in the next forward pass. @@ -321,11 +322,11 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, can be swapped in/out. """ num_cpu_blocks = num_gpu_blocks - block_manager = BlockSpaceManagerV2(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) prompt, seq_group = create_dummy_prompt( "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1) prompt.status = SequenceStatus.WAITING @@ -382,11 +383,11 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching): block_size = 8 num_cpu_blocks = 1 num_gpu_blocks = 1 - block_manager = BlockSpaceManagerV2(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) prompt_length = block_size - 3 assert prompt_length > 0 prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) @@ -434,7 +435,7 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append, num_gpu_blocks = 1024 watermark = 0.1 - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0, @@ -474,7 +475,7 @@ def num_blocks(num_tokens): seq.data.update_num_computed_tokens(prompt_len) check_used(num_blocks(prompt_len)) - # this is how we compute it in BlockSpaceManagerV2.__init__ + # this is how we compute it in SelfAttnBlockSpaceManager.__init__ sliding_blocks = (sliding_window // block_size) + 2 # plus one block for null block sliding_blocks += 1 diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py deleted file mode 100644 index 2ee9f20824f2f..0000000000000 --- a/tests/core/test_block_manager.py +++ /dev/null @@ -1,637 +0,0 @@ -import time -from collections import defaultdict -from typing import List - -import pytest - -from vllm 
import SamplingParams -from vllm.block import PhysicalTokenBlock -from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, - UncachedBlockAllocator) -from vllm.core.interfaces import AllocStatus -from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder - - -def test_block_allocator_allocate(): - block_size = 4 - num_cpu_blocks = 4 - cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) - - # Allocate all available cpu blocks. - num_free = num_cpu_blocks - assert cpu_allocator.get_num_free_blocks() == num_free - for _ in range(num_cpu_blocks): - block = cpu_allocator.allocate() - num_free -= 1 - - assert block not in cpu_allocator.free_blocks - assert cpu_allocator.get_num_free_blocks() == num_free - - with pytest.raises(ValueError): - cpu_allocator.allocate() - - -def test_block_allocator_free(): - block_size = 4 - num_cpu_blocks = 4 - cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) - - # Allocate all available cpu blocks. - blocks: List[PhysicalTokenBlock] = [] - for _ in range(num_cpu_blocks): - block = cpu_allocator.allocate() - blocks.append(block) - assert block not in cpu_allocator.free_blocks - - # Free all allocated cpu blocks. - num_free = 0 - assert cpu_allocator.get_num_free_blocks() == num_free - for block in blocks: - cpu_allocator.free(block) - num_free += 1 - assert block in cpu_allocator.free_blocks - assert cpu_allocator.get_num_free_blocks() == num_free - - with pytest.raises(ValueError): - cpu_allocator.free(block) - - -def test_allocate(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same sequence group to all available gpu blocks. - for i in range(num_gpu_blocks): - _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - # Allocate same sequence group to all available gpu blocks. - # Use watermark to reserve one gpu block. - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=1 / num_gpu_blocks) - for i in range(num_gpu_blocks - 1): - _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - -def test_allocate_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_req_per_seq_group = 2 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same sequence group to all available gpu blocks. - for i in range(num_gpu_blocks // block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - # Allocate same sequence group to all available gpu blocks. - # Use watermark to reserve one gpu block. 
- block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=1 / num_gpu_blocks) - for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - -def test_allocate_encoder_decoder_fails_with_swa(): - # SWA short for sliding window attention - - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - sliding_window=5) # swa - - # Allocate same sequence group to all available gpu blocks. - _, _, seq_group = create_dummy_prompt_encoder_decoder( - "0", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - - # Assert that can_allocate() fails due to SWA - with pytest.raises(NotImplementedError) as exc_info: - block_manager.can_allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA - - # Assert that allocate() fails due to SWA - with pytest.raises(NotImplementedError) as exc_info: - block_manager.allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA - - -def test_allocate_encoder_decoder_fails_with_prefix_caching(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=True) # Prefix cache - - # Allocate same sequence group to all available gpu blocks. - _, _, seq_group = create_dummy_prompt_encoder_decoder( - "0", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - - # Assert that can_allocate() fails due to prefix caching - with pytest.raises(NotImplementedError) as exc_info: - block_manager.can_allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE - - # Assert that allocate() fails due to prefix caching - with pytest.raises(NotImplementedError) as exc_info: - block_manager.allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE - - -def test_append_slot_single_seq(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate single seq to gpu block. - prompt, seq_group = create_dummy_prompt("1", block_size) - block_manager.allocate(seq_group) - - # Nothing to append. Sequence has no new logical blocks. - assert block_manager.can_append_slots(seq_group) - before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slots(prompt) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert before_blocks == after_blocks - - # Add block_size number of new tokens and append slot. 
- for i in range(block_size): - token_id = i + 5 - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - assert block_manager.can_append_slots(seq_group) - before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slots(prompt) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert before_blocks - after_blocks == 1 - - -def test_append_slot_cow(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size=block_size, - num_cpu_blocks=num_cpu_blocks, - num_gpu_blocks=num_gpu_blocks, - watermark=0) - - # Allocate prompt to gpu block. There is one slot left in the block. - prompt = Sequence(seq_id=1, - inputs={ - "prompt": "one two three", - "prompt_token_ids": [1, 2, 3], - }, - block_size=block_size) - - # Fork the sequence, such that a COW will be required when we append a new - # token id. - child = prompt.fork(new_seq_id=2) - - # Allocate space for the sequence group. - seq_group = SequenceGroup(request_id="1", - seqs=[prompt, child], - arrival_time=time.time(), - sampling_params=SamplingParams()) - block_manager.allocate(seq_group) - - # Fork and append a new token id. We expect a COW to be scheduled. - token_id = 4 - child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.fork(prompt, child) - - assert block_manager.can_append_slots(seq_group) - before_blocks = block_manager.get_num_free_gpu_blocks() - - cows = block_manager.append_slots(child) - assert cows - dict_cows = defaultdict(list) - for src_block, dst_block in cows: - dict_cows[src_block].append(dst_block) - for src_block, dst_blocks in dict_cows.items(): - assert src_block not in dst_blocks - - after_blocks = block_manager.get_num_free_gpu_blocks() - assert before_blocks - after_blocks == 1 - - -def test_fork(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - prompt, seq_group = create_dummy_prompt("1", - block_size - 1, - block_size=block_size) - block_manager.allocate(seq_group) - - # Fork prompt and copy block tables. - child = prompt.fork(2) - block_manager.fork(prompt, child) - assert block_manager.get_block_table( - prompt) == block_manager.get_block_table(child) - token_id = 4 - # Append token to child. Block is shared so copy on write occurs. - child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slots(child) - assert block_manager.get_block_table( - prompt) != block_manager.get_block_table(child) - - -def test_swap(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - prompt.status = SequenceStatus.RUNNING - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap seq group from GPU -> CPU. 
- gpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - assert [x[0] for x in mapping] == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - prompt.status = SequenceStatus.SWAPPED - - # Swap seq group from CPU -> GPU. - cpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_in(seq_group) == AllocStatus.OK - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_in(seq_group) - assert [x[0] for x in mapping] == cpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks - assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) - - -def test_swap_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - decoder_prompt, encoder_prompt, seq_group = \ - create_dummy_prompt_encoder_decoder( - "1", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - decoder_prompt.status = SequenceStatus.WAITING - encoder_prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - decoder_prompt.status = SequenceStatus.RUNNING - decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap encoder/decoder seq group from GPU -> CPU. - decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) - cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) - gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - assert [x[0] for x in mapping] == gpu_blocks - #assert list(mapping.keys()) == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - decoder_prompt.status = SequenceStatus.SWAPPED - - # Swap encoder/decoder seq group from CPU -> GPU. 
- decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) - cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) - cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks - assert block_manager.can_swap_in(seq_group) == AllocStatus.OK - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_in(seq_group) - assert [x[0] for x in mapping] == cpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks - assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) - - -def test_free(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - prompt, seq_group = create_dummy_prompt("1", block_size) - block_manager.allocate(seq_group) - - # Free allocated seq. - prompt_blocks = len(block_manager.get_block_table(prompt)) - before_blocks = block_manager.get_num_free_gpu_blocks() - block_manager.free(prompt) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert after_blocks == before_blocks + prompt_blocks - - # Block table for freed seq is deleted. - with pytest.raises(KeyError): - block_manager.get_block_table(prompt) - - -def test_free_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - decoder_prompt, encoder_prompt, seq_group = \ - create_dummy_prompt_encoder_decoder( - "1", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - block_manager.allocate(seq_group) - - # Free allocated seq. - decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) - encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) - prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks - before_blocks = block_manager.get_num_free_gpu_blocks() - block_manager.free(decoder_prompt) - block_manager.free_cross(seq_group) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert after_blocks == before_blocks + prompt_blocks - - # Block table for freed encoder & decoder seq's are deleted. - with pytest.raises(KeyError): - block_manager.get_block_table(decoder_prompt) - - # Block table for freed encoder & decoder seq's are deleted. - with pytest.raises(KeyError): - block_manager.get_block_table(encoder_prompt) - - -def test_reset(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same seq group on all available gpu blocks. - original_blocks = block_manager.get_num_free_gpu_blocks() - for i in range(num_gpu_blocks): - _, seq_group = create_dummy_prompt(str(i), block_size) - block_manager.allocate(seq_group) - assert block_manager.get_num_free_gpu_blocks() == 0 - - # Resetting block manager frees all allocated blocks. - block_manager.reset() - assert block_manager.get_num_free_gpu_blocks() == original_blocks - - -def test_reset_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_req_per_seq_group = 2 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same seq group on all available gpu blocks. 
- original_blocks = block_manager.get_num_free_gpu_blocks() - for i in range(num_gpu_blocks // block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder( - f"{i}", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - block_manager.allocate(seq_group) - assert block_manager.get_num_free_gpu_blocks() == 0 - - # Resetting block manager frees all allocated blocks. - block_manager.reset() - assert block_manager.get_num_free_gpu_blocks() == original_blocks - - -def test_sliding_window_multi_seq(): - """ - Tests that memory allocation and deallocation is handled - correctly with multiple sequences that exceed the sliding - window's capacity. - """ - block_size = 1 - num_cpu_blocks = 8 - num_gpu_blocks = 8 - sliding_window = 2 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - sliding_window=sliding_window, - watermark=0) - - assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - - parent = Sequence(seq_id=1, - inputs={ - "prompt": "one two three", - "prompt_token_ids": [0, 1, 2], - }, - block_size=block_size) - seq_group = SequenceGroup(request_id="1", - seqs=[parent], - arrival_time=time.time(), - sampling_params=SamplingParams(), - lora_request=None) - block_manager.allocate(seq_group) - - # assert the number of blocks allocated is correct - # the parent seq has len 3, but since sliding_window is 2, - # we will use at most 2 blocks - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - - # Fork prompt and copy block tables. - child = parent.fork(2) - block_manager.fork(parent, child) - - # assert the number of blocks allocated is correct - # forking does not increase memory consumption - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - - # assert both parent and child share all blocks - assert block_manager.get_block_table( - parent) == block_manager.get_block_table(child) - - token_id = 4 - # Append token to child. Block is shared so copy on write occurs. - child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slots(child) - - # assert the number of blocks allocated is correct - # we will use now one block more. Each seq will use 2 blocks, - # but only one can be shared - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - 1 - - token_id = 5 - parent.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slots(parent) - - # assert the number of blocks allocated is correct - # no change, because both sequences are still just sharing one block - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - 1 - - block_table_parent = block_manager.get_block_table(parent) - block_table_child = block_manager.get_block_table(child) - - assert block_table_parent != block_table_child - - # assert both blocks are sharing the second-last block - assert block_table_parent[-2] == block_table_child[-2] - - # now let's clean up... - block_manager.free(parent) - - # assert the number of blocks allocated is correct - # We have freed one seq, reducing the ref count of two blocks by one. - # One of the two was only used by the parent seq, so this is now free. 
- # The child seq still consumes sliding_window blocks - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - - # free all blocks - block_manager.free(child) - - # assert all blocks are free now - assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - - -def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill(): - """When prefix cache and chunked prefill are enabled, the block manager - should only mark a chunk of blocks as computed instead of all blocks. - """ - - block_size = 4 - num_cpu_blocks = 0 - num_gpu_blocks = 16 - block_manager = BlockSpaceManagerV1(block_size, - num_gpu_blocks, - num_cpu_blocks, - watermark=0, - enable_caching=True) - - # Set prompt size to have num_gpu_blocks - 1 full blocks. - prompt_length = block_size * num_gpu_blocks - 1 - - # Allocate (reserve) all blocks. - _, seq_group = create_dummy_prompt("0", - prompt_length, - block_size=block_size) - block_manager.allocate(seq_group) - assert seq_group.seqs[0].n_blocks == num_gpu_blocks - - # 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed. - token_chunk_size = int(block_size * 2.5) - block_manager.mark_blocks_as_computed(seq_group, token_chunk_size) - computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0]) - assert len(computed_blocks) == 2 - - # Actual computed tokens. - seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size) - - # 2nd chunk: Complete 3rd block and additional 4 blocks. - token_chunk_size = int(block_size * 4.5) - block_manager.mark_blocks_as_computed(seq_group, token_chunk_size) - computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0]) - assert len(computed_blocks) == 7 diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index c9495fd50d7c9..f97caa06ff02d 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -8,7 +8,6 @@ from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup -from ..utils import check_deprecated_block_manager_usage from .utils import create_dummy_prompt @@ -28,25 +27,16 @@ def schedule_and_update_computed_tokens(scheduler): return metas, out -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/core/test_chunked_prefill_scheduler.py') - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_simple(use_v2_block_manager: bool): +def test_simple(): """Verify basic scheduling works.""" block_size = 4 num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + scheduler_config = SchedulerConfig(max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -81,8 +71,7 @@ def test_simple(use_v2_block_manager: bool): assert len(seq_group_meta) == num_seq_group -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_chunk(use_v2_block_manager: bool): +def test_chunk(): """Verify prefills are chunked properly.""" block_size = 4 max_seqs = 60 @@ -93,7 +82,7 @@ def test_chunk(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - 
use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 32 cache_config.num_gpu_blocks = 32 @@ -131,8 +120,7 @@ def test_chunk(use_v2_block_manager: bool): assert out.num_batched_tokens == 57 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_complex(use_v2_block_manager: bool): +def test_complex(): block_size = 4 max_seqs = 60 max_model_len = 80 @@ -142,7 +130,7 @@ def test_complex(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 64 cache_config.num_gpu_blocks = 64 @@ -201,8 +189,7 @@ def test_complex(use_v2_block_manager: bool): assert running[2].is_prefill() -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_maximal_decoding(use_v2_block_manager: bool): +def test_maximal_decoding(): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 @@ -213,7 +200,7 @@ def test_maximal_decoding(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -295,8 +282,7 @@ def test_maximal_decoding(use_v2_block_manager: bool): assert out.num_batched_tokens == 2 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prompt_limit(use_v2_block_manager: bool): +def test_prompt_limit(): """Verify max_num_batched_tokens < max_model_len is possible.""" block_size = 4 max_seqs = 32 @@ -307,7 +293,7 @@ def test_prompt_limit(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -330,8 +316,7 @@ def test_prompt_limit(use_v2_block_manager: bool): assert out.num_batched_tokens == 32 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prompt_limit_exceed(use_v2_block_manager: bool): +def test_prompt_limit_exceed(): block_size = 4 max_seqs = 64 max_model_len = 32 @@ -356,8 +341,7 @@ def test_prompt_limit_exceed(use_v2_block_manager: bool): assert out.ignored_seq_groups[0] == seq_group -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_swap(use_v2_block_manager: bool): +def test_swap(): """Verify swapping works with chunked prefill requests""" block_size = 4 max_seqs = 30 @@ -368,7 +352,7 @@ def test_swap(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -414,8 +398,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool): +def test_running_prefill_prioritized_over_swap(): block_size = 4 max_seqs = 30 max_model_len = 200 @@ -425,7 +408,7 @@ def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") 
cache_config.num_cpu_blocks = 32 cache_config.num_gpu_blocks = 32 @@ -508,8 +491,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_chunked_prefill_preempt(use_v2_block_manager: bool): +def test_chunked_prefill_preempt(): """Verify preempt works with chunked prefill requests""" block_size = 4 max_seqs = 30 @@ -520,7 +502,7 @@ def test_chunked_prefill_preempt(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -575,8 +557,7 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): +def test_chunked_prefill_max_seqs(): block_size = 4 max_seqs = 2 max_model_len = 80 @@ -586,7 +567,7 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 128 cache_config.num_gpu_blocks = 128 @@ -629,8 +610,7 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): assert not running[1].is_prefill() -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_perfix_caching(use_v2_block_manager: bool): +def test_perfix_caching(): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 @@ -641,7 +621,7 @@ def test_perfix_caching(use_v2_block_manager: bool): max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index f3ec24e7bee3e..bd4accab7f37d 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -31,7 +31,6 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, # Make a vllm engine runner = VllmRunner(model_name=MODEL, gpu_memory_utilization=0.7, - use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, enable_chunked_prefill=enable_chunked_prefill, enforce_eager=enforce_eager) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 5cdf743a4509c..defa6c1bdaf78 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,7 +3,7 @@ from typing import List, Set, Tuple from unittest.mock import MagicMock -import pytest +import pytest # noqa from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig @@ -12,23 +12,18 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup, SequenceStatus -from ..utils import check_deprecated_block_manager_usage from .utils import (append_new_token, append_new_token_seq_group, create_dummy_prompt, get_sequence_groups, schedule_and_update_computed_tokens) -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - "tests/core/test_chunked_prefill_scheduler.py") - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def 
test_scheduler_add_seq_group(use_v2_block_manager: bool): +def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig( - 100, 64, 1, use_v2_block_manager=use_v2_block_manager) + 100, + 64, + 1, + ) cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -44,11 +39,13 @@ def test_scheduler_add_seq_group(use_v2_block_manager: bool): assert scheduler.get_num_unfinished_seq_groups() == i + 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_abort_seq_group(use_v2_block_manager: bool): +def test_scheduler_abort_seq_group(): block_size = 4 scheduler_config = SchedulerConfig( - 100, 64, 1, use_v2_block_manager=use_v2_block_manager) + 100, + 64, + 1, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -68,8 +65,7 @@ def test_scheduler_abort_seq_group(use_v2_block_manager: bool): assert scheduler.get_num_unfinished_seq_groups() == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_schedule_simple(use_v2_block_manager: bool): +def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 @@ -77,7 +73,7 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool): 64, num_seq_group, max_model_len, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -112,8 +108,7 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool): append_new_token(out, 1) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): +def test_scheduler_prefill_prioritized(): """Verify running batched tokens are not applied to prefill requests.""" block_size = 4 max_model_len = 30 @@ -122,7 +117,7 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): max_batched_num_tokens, 2, max_model_len, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -146,12 +141,14 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): assert get_sequence_groups(out) == [seq_group_b] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): +def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager) + 64, + 2, + max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -201,8 +198,7 @@ def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): assert scheduler.get_num_unfinished_seq_groups() == 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_max_seqs(use_v2_block_manager: bool): +def test_scheduler_max_seqs(): block_size = 4 num_seq_group = 4 max_seq_group = 2 @@ -211,7 +207,7 @@ def test_scheduler_max_seqs(use_v2_block_manager: bool): 64, max_seq_group, max_model_len, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -249,15 +245,14 @@ def 
test_scheduler_max_seqs(use_v2_block_manager: bool): assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_delay_factor(use_v2_block_manager: bool): +def test_scheduler_delay_factor(): block_size = 4 scheduler_config = SchedulerConfig( 100, 64, 16, delay_factor=0.5, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -294,12 +289,10 @@ def test_scheduler_delay_factor(use_v2_block_manager: bool): append_new_token(out, 1) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_swapped_out_prioritized(use_v2_block_manager: bool): +def test_swapped_out_prioritized(): block_size = 4 scheduler = initialize_scheduler(max_num_seqs=6, block_size=block_size, - use_v2_block_manager=use_v2_block_manager, num_cpu_blocks=64, num_gpu_blocks=64) # best_of=2 * 3 == 6 sequences. @@ -351,7 +344,6 @@ def initialize_scheduler( max_token_budget=1000, max_model_len=1000, lora_config=None, - use_v2_block_manager=False, block_size=4, num_cpu_blocks=8, num_gpu_blocks=8, @@ -361,7 +353,7 @@ def initialize_scheduler( max_token_budget, max_num_seqs, max_model_len, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = num_cpu_blocks cache_config.num_gpu_blocks = num_gpu_blocks @@ -386,15 +378,12 @@ def add_token_budget(budget: SchedulingBudget, budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): +def test_prefill_schedule_max_prompt_len(): """ Test prompt longer than max_prompt_len is aborted. """ block_size = 4 - scheduler = initialize_scheduler(max_model_len=30, - use_v2_block_manager=use_v2_block_manager, - block_size=block_size) + scheduler = initialize_scheduler(max_model_len=30, block_size=block_size) _, seq_group = create_dummy_prompt("0", prompt_length=60, block_size=block_size) @@ -409,14 +398,12 @@ def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): assert len(remaining_waiting) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_token_budget(use_v2_block_manager: bool): +def test_prefill_schedule_token_budget(): """ Test token budget respected. """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) budget = create_token_budget(token_budget=0) @@ -446,8 +433,7 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool): assert len(remaining_waiting) == 1 # Test when current_batched_tokens respected. - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=16, num_gpu_blocks=16) budget = create_token_budget(token_budget=60) @@ -474,14 +460,12 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool): assert len(remaining_waiting) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): +def test_prefill_schedule_max_seqs(): """ Test max seq respected. 
""" block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) budget = create_token_budget(max_num_seqs=2) @@ -515,15 +499,13 @@ def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): assert len(remaining_waiting) == 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_max_lora(use_v2_block_manager: bool): +def test_prefill_schedule_max_lora(): """ Test max lora is respected and prioritized. """ block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config, - use_v2_block_manager=use_v2_block_manager, block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) @@ -570,14 +552,12 @@ def test_prefill_schedule_max_lora(use_v2_block_manager: bool): assert budget.num_batched_tokens == 60 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): +def test_prefill_schedule_no_block_manager_capacity(): """ Test sequence cannot be scheduled due to block manager has no capacity. """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_gpu_blocks=128, num_cpu_blocks=128) budget = create_token_budget() @@ -614,14 +594,12 @@ def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): assert len(remaining_waiting) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_decode_schedule_preempted(use_v2_block_manager: bool): +def test_decode_schedule_preempted(): """ Test decodes cannot be scheduled and preempted. """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) curr_loras = None @@ -660,14 +638,12 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_decode_swap_beam_search(use_v2_block_manager: bool): +def test_decode_swap_beam_search(): """ Test best_of > 1 swap out blocks """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_gpu_blocks=64, num_cpu_blocks=64) curr_loras = None @@ -716,14 +692,12 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): +def test_schedule_decode_blocks_to_copy_update(): """ Verify blocks_to_copy is updated. 
""" block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=4, + scheduler = initialize_scheduler(block_size=4, num_cpu_blocks=16, num_gpu_blocks=16) _, seq_group = create_dummy_prompt("1", @@ -754,11 +728,9 @@ def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): assert output.blocks_to_copy == [(2, 3)] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_simple(use_v2_block_manager: bool): +def test_schedule_swapped_simple(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size) + scheduler = initialize_scheduler(block_size=block_size) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] _, seq_group = create_dummy_prompt("1", @@ -785,11 +757,9 @@ def test_schedule_swapped_simple(use_v2_block_manager: bool): assert blocks_to_swap_out == blocks_to_swap_in_reverse -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool): +def test_schedule_swapped_max_token_budget(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None @@ -822,11 +792,9 @@ def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool): assert len(output.prefill_seq_groups) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_max_seqs(use_v2_block_manager: bool): +def test_schedule_swapped_max_seqs(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) curr_loras = None @@ -859,12 +827,10 @@ def test_schedule_swapped_max_seqs(use_v2_block_manager: bool): assert len(output.prefill_seq_groups) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_max_loras(use_v2_block_manager: bool): +def test_schedule_swapped_max_loras(): block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config, - use_v2_block_manager=use_v2_block_manager, block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) @@ -894,11 +860,9 @@ def test_schedule_swapped_max_loras(use_v2_block_manager: bool): assert len(curr_loras) == 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): +def test_schedule_swapped_cannot_swap_in(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None @@ -927,11 +891,9 @@ def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): assert len(output.prefill_seq_groups) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_infeasible_swap(use_v2_block_manager: bool): +def test_infeasible_swap(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None @@ -961,11 +923,9 @@ def 
test_infeasible_swap(use_v2_block_manager: bool): assert len(output.prefill_seq_groups) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool): +def test_schedule_swapped_blocks_to_copy(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index f1003221ab518..8798ff078843a 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -185,13 +185,14 @@ def test_metric_spec_decode( ) -> None: k = 5 - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4, - speculative_model=model, - num_speculative_tokens=k, - use_v2_block_manager=True) as vllm_model: + with vllm_runner( + model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4, + speculative_model=model, + num_speculative_tokens=k, + ) as vllm_model: # Force log interval to be 0 to catch all metrics. stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] @@ -242,7 +243,6 @@ def test_metric_spec_decode_interval( gpu_memory_utilization=0.4, speculative_model=model, num_speculative_tokens=k, - use_v2_block_manager=True, enforce_eager=True) engine = LLMEngine.from_engine_args(engine_args) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 000c923ef3e6e..7203d635c2fa8 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -17,7 +17,6 @@ DEFAULT_SERVER_ARGS: List[str] = [ "--disable-log-requests", - "--use-v2-block-manager", "--worker-use-ray", "--gpu-memory-utilization", "0.85", diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index f45428675bde8..cc1fd19252019 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -76,7 +76,6 @@ def test_multi_step_llm( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, enable_chunked_prefill=enable_chunked_prefill, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: @@ -169,7 +168,6 @@ def test_multi_step_llm_w_prompt_logprobs( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( @@ -305,7 +303,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, max_model_len=48, max_num_batched_tokens=48, @@ -324,7 +321,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, enable_chunked_prefill=True, enable_prefix_caching=True, num_scheduler_steps=num_scheduler_steps, diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 88437425feb31..366b030eaa399 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,15 +2,9 @@ Run `pytest 
tests/prefix_caching/test_prefix_caching.py`. """ -from typing import List - import pytest from tests.kernels.utils import override_backend_env_variable -from tests.utils import check_deprecated_block_manager_usage -from vllm.block import PhysicalTokenBlock -from vllm.core.block_manager_v1 import CachedBlockAllocator -from vllm.utils import Device from ..models.utils import check_outputs_equal @@ -19,92 +13,11 @@ ] -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/prefix_caching/test_prefix_caching.py') - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_blocks", [16]) -def test_block_allocator( - block_size: int, - num_blocks: int, -): - block_hash = 1 - block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) - - # Allocate two PysicalTokenBlocks with the same hash and check - # that they are the same PhysicalTokenBlock - first_block = block_allocator.allocate(block_hash, 0) - second_block = block_allocator.allocate(block_hash, 0) - assert (first_block == second_block) - assert (second_block.ref_count == 2) - - # Check metric: 1 hit of 2 queries - assert block_allocator.get_prefix_cache_hit_rate() == 0.5 - - # Free the first_block and confirm that the ref_count is correctly - # decremented on the second block - block_allocator.free(first_block) - assert (second_block.ref_count == 1) - - # Free the second block - block_allocator.free(second_block) - - # Reallocate the first block and confirm that, even after the block - # had its ref_count go to 0, we still get the same block back - first_block = block_allocator.allocate(block_hash, 0) - assert (first_block == second_block) - assert (first_block.block_hash == block_hash) - - # Allocate one more time to get 3/4 hit rate for easy checking - block_allocator.allocate(block_hash, 0) - assert block_allocator.get_prefix_cache_hit_rate() == 0.75 - - -@pytest.mark.parametrize("num_blocks", [16]) -def test_eviction(num_blocks: int, ): - block_size = 16 - block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) - blocks: List[PhysicalTokenBlock] = [] - - for i in range(num_blocks): - # use i as the block_hash - blocks.append(block_allocator.allocate(i, 0)) - - #Free all blocks - for block in blocks: - block_allocator.free(block) - - # Allocate a new block and confirm that it's the first block freed. 
- # I.E The Least Recently Used block - new_block_hash = block_size - new_block = block_allocator.allocate(new_block_hash, 0) - assert (new_block == blocks[0]) - assert (new_block.block_hash == new_block_hash) - - # Reallocate the second in blocks to remove it from the free list - realloc_block_hash = 1 - realloc_block = block_allocator.allocate(realloc_block_hash, 0) - assert (realloc_block == blocks[realloc_block_hash]) - assert (realloc_block.block_hash == realloc_block_hash) - - # Allocate a new block and confirm that it's not the realloc_block, - # since the realloc_block shouldn't be in the free list - new_block_hash = block_size + 1 - new_block = block_allocator.allocate(new_block_hash, 0) - assert (realloc_block != new_block) - assert (new_block.block_hash == new_block_hash) - assert (new_block.block_number == 2) - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) -@pytest.mark.parametrize("use_v2_block_manager", [False, True]) def test_mixed_requests( hf_runner, vllm_runner, @@ -114,7 +27,6 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, - use_v2_block_manager: bool, monkeypatch, ) -> None: """ @@ -132,7 +44,6 @@ def test_mixed_requests( model, dtype=dtype, enable_prefix_caching=True, - use_v2_block_manager=use_v2_block_manager, ) as vllm_model: # Run the first prompt so the cache is populated vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 69ea81cfffed4..629074188a6c1 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -1,27 +1,15 @@ import pytest -from tests.utils import check_deprecated_block_manager_usage from vllm import SamplingParams from .conftest import get_output_from_llm_generator -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/spec_decode/e2e/test_compatibility.py') - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { "enable_chunked_prefill": True, @@ -51,16 +39,11 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator): sampling_params) -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "meta-llama/Llama-2-7b-chat-hf", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. 
- "use_v2_block_manager": True - }]) +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ @@ -101,34 +84,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): with pytest.raises(ValueError, match="cannot be larger than"): get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) - - -@pytest.mark.parametrize("common_llm_kwargs", [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "use_v2_block_manager": False, -}]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_block_manager_v1(test_llm_generator): - """Verify that speculative decoding with block manager v1 fails. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises(ValueError, - match="Speculative decoding requires usage of the V2"): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index d7ca8815ec259..5bc70de9dac56 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -43,9 +43,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -86,9 +83,6 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -143,9 +137,6 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, [{ "enforce_eager": False, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -191,9 +182,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -235,9 +223,6 @@ def test_eagle_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -283,9 +268,6 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index d04e312689bcc..b89e5849727f4 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -12,8 +12,6 @@ @pytest.mark.parametrize( "common_llm_kwargs", [{ - # Required for spec decode. - "use_v2_block_manager": True, # Verify equality when cuda graphs allowed. 
"enforce_eager": False, @@ -57,9 +55,6 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { @@ -111,9 +106,6 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, }]) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 679a6ded9ee79..b829d1a5be784 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -17,9 +17,6 @@ [[ # Skip cuda graph recording for fast test. "--enforce-eager", - - # Required for spec decode. - "--use-v2-block-manager", "--tensor-parallel-size", "2" ]]) @@ -74,9 +71,6 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, [[ # Skip cuda graph recording for fast test. "--enforce-eager", - - # Required for spec decode. - "--use_v2_block_manager", "--tensor_parallel_size", "2", diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index 3f7c5d749e4f9..555aef99218c3 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -19,9 +19,6 @@ [[ # Skip cuda graph recording for fast test. "--enforce_eager", - - # Required for spec decode. - "--use-v2-block-manager", "--tensor-parallel-size", "4", ]]) @@ -71,9 +68,6 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, # Skip cuda graph recording for fast test. "--enforce-eager", - - # Required for spec decode. - "--use-v2-block-manager", "--tensor-parallel-size", "4", ]]) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index b7d54991e0535..4cfca8b78e79b 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -14,9 +14,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -67,9 +64,6 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -119,9 +113,6 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -173,9 +164,6 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. 
- "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -251,8 +239,6 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, "model_name": "JackFram/llama-160m", # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 0b36e712a11b2..b8965606b3d0e 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -45,9 +45,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -93,9 +90,6 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -151,9 +145,6 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, [{ "enforce_eager": False, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -204,9 +195,6 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -253,9 +241,6 @@ def test_medusa_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -306,9 +291,6 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -356,9 +338,6 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 52b48a33c3097..5ecc0d4e95719 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -47,9 +47,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -94,9 +91,6 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -149,9 +143,6 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -195,9 +186,6 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. 
"enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -258,9 +246,6 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -311,9 +296,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -366,9 +348,6 @@ def patched_pad_vocab_size(vocab_size, pad_to=None): # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -419,9 +398,6 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -469,9 +445,6 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, "speculative_model": SPEC_MODEL, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index df6f12d57b400..5f240d42d9e09 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -55,9 +55,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -124,9 +121,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -190,9 +184,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -246,9 +237,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( [{ # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -303,9 +291,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -353,9 +338,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -404,9 +386,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. 
- "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { @@ -454,9 +433,6 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -514,9 +490,6 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -570,9 +543,6 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -611,9 +581,6 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -660,9 +627,6 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 5862459383167..31bedad480283 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -35,9 +35,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -82,9 +79,6 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -145,9 +139,6 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { @@ -195,9 +186,6 @@ def test_ngram_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -254,9 +242,6 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -303,7 +288,6 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, "enforce_eager": True, # Required for spec decode. 
- "use_v2_block_manager": True, "speculative_model": "[ngram]", "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py index b17013216ae23..e42cf416b159f 100644 --- a/tests/spec_decode/e2e/test_seed.py +++ b/tests/spec_decode/e2e/test_seed.py @@ -17,9 +17,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # speculative model "speculative_model": "JackFram/llama-160m", diff --git a/tests/utils.py b/tests/utils.py index 924465057468f..115cab80691f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -678,12 +678,3 @@ def get_client_text_logprob_generations( return [(text_generations, text, (None if x.logprobs is None else x.logprobs.top_logprobs)) for completion in completions for x in completion.choices] - - -def check_deprecated_block_manager_usage(test_name: str): - assert envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1 is True, ( - f"To allow the use of deprecated BlockSpaceManagerV1, set the " - f"environment variable VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. " - f"You can run the tests with: " - f"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest {test_name}`" #noqa - ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8457bde066eb7..d54dbdcb19495 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -305,8 +305,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", @@ -355,9 +353,9 @@ def _add_seq_group( # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ba9b2d043c640..dd9a0fb9d94df 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -475,8 +475,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. @@ -542,9 +540,9 @@ def _add_seq_group( is_profile_run = is_block_tables_empty(block_tables) # Compute slot mapping. 
- start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 53e3a53badeae..358a223e7ed0e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, - context_len: int, sliding_window: int, - use_v2_block_manager: bool): + context_len: int, sliding_window: int): """ Compute the start index of slot mapping. """ start_idx = 0 if is_prompt and sliding_window is not None: - assert use_v2_block_manager or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # When prefill, we use it to not write slots to kv cache - # to save memory. start_idx = max(0, query_len - sliding_window) return start_idx @@ -138,8 +132,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", @@ -180,9 +172,9 @@ def _add_seq_group( # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) diff --git a/vllm/commit_id.py b/vllm/commit_id.py new file mode 100644 index 0000000000000..d857066f1f51b --- /dev/null +++ b/vllm/commit_id.py @@ -0,0 +1 @@ +__commit__ = "93ec62b8556e279d2c050bdc1c3247831bd39466" diff --git a/vllm/config.py b/vllm/config.py index 2e98923a3cb24..4533fb017188c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -949,7 +949,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. num_lookahead_slots: The number of slots to allocate per sequence per step, beyond the known token ids. 
This is used in speculative decoding to store KV activations of tokens which may or may not be @@ -976,7 +975,6 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - use_v2_block_manager: bool = True, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, @@ -1026,7 +1024,6 @@ def __init__(self, self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill @@ -1067,18 +1064,6 @@ def _verify_args(self) -> None: f"({self.num_scheduler_steps}) must be greater than or " "equal to 1.") - if (not self.use_v2_block_manager \ - and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1): - raise ValueError( - "The use of BlockSpaceManagerV1 is deprecated and will " - "be removed in a future release. Please switch to " - "BlockSpaceManagerV2 by setting --use-v2-block-manager to " - "True. If you wish to suppress this error temporarily, " - "you can set the environment variable " - "`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. If your use " - "case is not supported in BlockSpaceManagerV2, please " - "file an issue with detailed information.") - @property def is_multi_step(self) -> bool: return self.num_scheduler_steps > 1 @@ -1137,7 +1122,6 @@ def maybe_create_spec_config( speculative_disable_mqa_scorer: Optional[bool], speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, - use_v2_block_manager: bool, disable_log_stats: bool, speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], @@ -1178,9 +1162,6 @@ def maybe_create_spec_config( enable_chunked_prefill (bool): Whether vLLM is configured to use chunked prefill or not. Used for raising an error since its not yet compatible with spec decode. - use_v2_block_manager (bool): Whether vLLM is configured to use the - v2 block manager or not. Used for raising an error since the v2 - block manager is required with spec decode. speculative_disable_by_batch_size (Optional[int]): Disable speculative decoding for new incoming requests when the number of enqueue requests is larger than this value, if provided. @@ -1231,11 +1212,6 @@ def maybe_create_spec_config( "Speculative decoding and chunked prefill are " f"currently mutually exclusive ({enable_chunked_prefill=}).") - if not use_v2_block_manager: - raise ValueError( - "Speculative decoding requires usage of the V2 " - "block manager. Enable it with --use-v2-block-manager.") - # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. draft_revision = None diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 28839437c33c5..1c6578e4cc6ab 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -4,28 +4,6 @@ STR_NOT_IMPL_ENC_DEC_SWA) -def _get_block_mgr_sliding_window_attr(block_mgr): - ''' - BlockManagerV1 and BlockManagerV2 have slightly different - members related to sliding window attention (SWA). This - function extracts the appropriate member to use for determining - whether SWA is enabled. 
- - Arguments: - - * block_mgr: BlockManagerV1 or BlockManagerV2 instance - ''' - - if hasattr(block_mgr, 'block_sliding_window'): - return block_mgr.block_sliding_window - if hasattr(block_mgr, 'max_block_sliding_window'): - return block_mgr.max_block_sliding_window - - raise AttributeError("Block manager instance has neither " + \ - "block_sliding_window nor " + \ - "max_block_sliding_window attributes.") - - def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: ''' @@ -41,7 +19,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec( ''' if seq_group.is_encoder_decoder(): - if _get_block_mgr_sliding_window_attr(block_mgr) is not None: + if block_mgr.max_block_sliding_window is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager.py similarity index 99% rename from vllm/core/block_manager_v2.py rename to vllm/core/block_manager.py index cb047c832e6cb..61ed7afba12ed 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager.py @@ -17,7 +17,7 @@ EncoderSeqId = str -class BlockSpaceManagerV2(BlockSpaceManager): +class SelfAttnBlockSpaceManager(BlockSpaceManager): """BlockSpaceManager which manages the allocation of KV cache. It owns responsibility for allocation, swapping, allocating memory for diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py deleted file mode 100644 index 8bc0ce2bc6626..0000000000000 --- a/vllm/core/block_manager_v1.py +++ /dev/null @@ -1,743 +0,0 @@ -"""A block manager that manages token blocks.""" -import math -from abc import ABC, abstractmethod -from itertools import count, takewhile -from os.path import commonprefix -from typing import Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple - -from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.block.common import CacheMetricData -from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.logger import init_logger -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -logger = init_logger(__name__) - - -class BlockAllocatorBase(ABC): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - @abstractmethod - def __init__(self, - device: Device, - block_size: int, - num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU): - pass - - @abstractmethod - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - pass - - @abstractmethod - def free(self, block: PhysicalTokenBlock) -> None: - pass - - @abstractmethod - def get_num_free_blocks(self) -> int: - pass - - @abstractmethod - def get_num_total_blocks(self) -> int: - pass - - @abstractmethod - def contains_block(self, block_hash: int) -> bool: - pass - - @abstractmethod - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - pass - - @abstractmethod - def get_prefix_cache_hit_rate(self) -> float: - """Prefix cache hit rate. 
-1 means not supported or disabled.""" - pass - - -class CachedBlockAllocator(BlockAllocatorBase): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - def __init__(self, - device: Device, - block_size: int, - num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None: - self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - self.current_num_blocks = 0 - self.cached_blocks: Dict[int, PhysicalTokenBlock] = {} - - self.evictor: Evictor = make_evictor(eviction_policy) - - self.default_hash_ctr = count() - - self.cache_metric_data = CacheMetricData() - - def allocate_block(self, block_hash: int, - num_hashed_tokens: int) -> PhysicalTokenBlock: - if self.current_num_blocks == self.num_blocks: - block = self.evictor.evict() - block.block_hash = block_hash - block.num_hashed_tokens = num_hashed_tokens - return block - block = PhysicalTokenBlock(device=self.device, - block_number=self.current_num_blocks, - block_size=self.block_size, - block_hash=block_hash, - num_hashed_tokens=num_hashed_tokens) - self.current_num_blocks += 1 - return block - - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - if block_hash is None: - block_hash = next(self.default_hash_ctr) - - if block_hash in self.evictor: - assert block_hash not in self.cached_blocks - block = self.evictor.remove(block_hash) - assert block.ref_count == 0 - self.cached_blocks[block_hash] = block - - if block_hash in self.cached_blocks: - self.cache_metric_data.query(hit=True) - else: - self.cache_metric_data.query(hit=False) - self.cached_blocks[block_hash] = self.allocate_block( - block_hash, num_hashed_tokens) - block = self.cached_blocks[block_hash] - assert block.block_hash == block_hash - block.ref_count += 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError(f"Double free! {block} is already freed.") - block.ref_count -= 1 - if block.ref_count == 0: - assert block.block_hash not in self.evictor - self.evictor.add(block) - - # Remove the block from the cached_blocks - del self.cached_blocks[block.block_hash] - - def get_num_free_blocks(self) -> int: - return (self.num_blocks - self.current_num_blocks + - self.evictor.num_blocks) - - def get_num_total_blocks(self) -> int: - return self.num_blocks - - def contains_block(self, block_hash: int) -> bool: - return block_hash in self.cached_blocks or block_hash in self.evictor - - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - # Update the hash of block and the cached_blocks dictionary. - assert not self.contains_block(block_hash) - old_hash = block.block_hash - block.block_hash = block_hash - del self.cached_blocks[old_hash] - self.cached_blocks[block_hash] = block - - def get_prefix_cache_hit_rate(self) -> float: - return self.cache_metric_data.get_hit_rate() - - -class UncachedBlockAllocator(BlockAllocatorBase): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. 
- """ - - def __init__( - self, - device: Device, - block_size: int, - num_blocks: int, - ) -> None: - self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - # Initialize the free blocks. - self.free_blocks: List[PhysicalTokenBlock] = [] - for i in range(num_blocks): - block = PhysicalTokenBlock(device=device, - block_number=i, - block_size=block_size, - block_hash=-1, - num_hashed_tokens=0) - self.free_blocks.append(block) - - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - if not self.free_blocks: - raise ValueError("Out of memory! No free blocks are available.") - block = self.free_blocks.pop() - block.ref_count = 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError(f"Double free! {block} is already freed.") - block.ref_count -= 1 - if block.ref_count == 0: - self.free_blocks.append(block) - - def get_num_free_blocks(self) -> int: - return len(self.free_blocks) - - def get_num_total_blocks(self) -> int: - return self.num_blocks - - def contains_block(self, block_hash: int) -> bool: - raise NotImplementedError( - "Invalid codepath for uncached block allocator.") - - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - raise NotImplementedError( - "Invalid codepath for uncached block allocator.") - - def get_prefix_cache_hit_rate(self) -> float: - return -1 - - -class BlockSpaceManagerV1(BlockSpaceManager): - """Manages the mapping between logical and physical token blocks.""" - - def __init__( - self, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, - watermark: float = 0.01, - sliding_window: Optional[int] = None, - enable_caching: bool = False, - ) -> None: - self.block_size = block_size - self.num_total_gpu_blocks = num_gpu_blocks - self.num_total_cpu_blocks = num_cpu_blocks - - if enable_caching and sliding_window is not None: - raise NotImplementedError( - "Sliding window is not allowed with prefix caching enabled!") - - self.block_sliding_window = None - if sliding_window is not None: - # Round up to nearest block size to regularize sliding window - # allocation sizes. - self.block_sliding_window = math.ceil(sliding_window / block_size) - - self.watermark = watermark - assert watermark >= 0.0 - - self.enable_caching = enable_caching - - self.watermark_blocks = int(watermark * num_gpu_blocks) - - if self.enable_caching: - logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( - Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( - Device.CPU, block_size, num_cpu_blocks) - else: - self.gpu_allocator = UncachedBlockAllocator( - Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator = UncachedBlockAllocator( - Device.CPU, block_size, num_cpu_blocks) - # Mapping: seq_id -> BlockTable. - self.block_tables: Dict[int, BlockTable] = {} - - # Mapping: req_id -> BlockTable - # Note that each SequenceGroup has a unique - # request ID - self.cross_block_tables: Dict[str, BlockTable] = {} - - def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: - return 0 if seq is None else seq.n_blocks - - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. 
- - assert (num_lookahead_slots == 0 - ), "lookahead allocation not supported in BlockSpaceManagerV1" - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - self_num_required_blocks = self._get_seq_num_required_blocks( - seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self._get_seq_num_required_blocks( - seq_group.get_encoder_seq()) - num_required_blocks = self_num_required_blocks + \ - cross_num_required_blocks - - if self.block_sliding_window is not None: - - num_required_blocks = min(num_required_blocks, - self.block_sliding_window) - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - - # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _allocate_sequence(self, \ - seq: Optional[Sequence], \ - ref_count: int, \ - is_encoder_decoder: bool = True) -> BlockTable: - # Allocate new physical token blocks that will store the prompt tokens. - num_prompt_blocks = self._get_seq_num_required_blocks(seq) - - block_table: BlockTable = BlockTable() - assert seq is not None - for logical_idx in range(num_prompt_blocks): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % self.block_sliding_window] - # Set the reference counts of the token blocks. - block.ref_count = ref_count - elif not is_encoder_decoder and self.enable_caching: - block = self.gpu_allocator.allocate( - seq.hash_of_block(logical_idx), - seq.num_hashed_tokens_of_block(logical_idx)) - else: - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - block.ref_count = ref_count - block_table.append(block) - - return block_table - - def allocate(self, seq_group: SequenceGroup) -> None: - is_encoder_decoder = seq_group.is_encoder_decoder() - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - # Allocate decoder sequences - # - # NOTE: Here we assume that all sequences in the group have the same - # decoder prompt. - wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - seq = wait_seqs[0] - block_table: BlockTable = \ - self._allocate_sequence(seq, - seq_group.num_seqs(), - is_encoder_decoder) - - # Assign the self-attention block tables for each sequence. - if len(wait_seqs) == 1: - self.block_tables[seq.seq_id] = block_table - else: - for seq in wait_seqs: - self.block_tables[seq.seq_id] = block_table.copy() - - # Allocate encoder sequence - if is_encoder_decoder: - # A SequenceGroup has only a single encoder sequence (at most), - # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(seq_group.get_encoder_seq(), - 1, is_encoder_decoder) - # Assign the cross-attention block table for the SequenceGroup. - self.cross_block_tables[seq_group.request_id] = block_table - - def can_append_slots(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: - assert (num_lookahead_slots == 0 - ), "lookahead allocation not supported in BlockSpaceManagerV1" - - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. 
- num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks - - def _promote_last_block( - self, - seq: Sequence, - last_block: PhysicalTokenBlock, - ) -> PhysicalTokenBlock: - assert self.enable_caching - - # Compute a new hash for the block so that it can be shared by other - # Sequences - new_hash = seq.hash_of_block(seq.n_blocks - 1) - - # if new_hash is already in the cached table, then free last_block - # and return the cached version - if self.gpu_allocator.contains_block(new_hash): - self.gpu_allocator.free(last_block) - return self.gpu_allocator.allocate(new_hash) - else: - self.gpu_allocator.update_hash(new_hash, last_block) - return last_block - - def _is_last_block_full( - self, - seq: Sequence, - ) -> bool: - token_ids_len = seq.data.get_len() - return token_ids_len > 0 and token_ids_len % seq.block_size == 0 - - def _maybe_promote_last_block( - self, - seq: Sequence, - last_block: PhysicalTokenBlock, - ) -> PhysicalTokenBlock: - if self._is_last_block_full(seq): - return self._promote_last_block(seq, last_block) - else: - return last_block - - def _allocate_last_physical_block( - self, - seq: Sequence, - ) -> PhysicalTokenBlock: - # Called before a new block is appended. - # This is in charge of allocating a new physical block (to be appended). - - # None if the last block is not full. Otherwise, we set it to the - # content hash. - if not self.enable_caching: - return self.gpu_allocator.allocate() - block_hash: Optional[int] = None - n_blocks = seq.n_blocks - if (self._is_last_block_full(seq)): - block_hash = seq.hash_of_block(n_blocks - 1) - num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1) - - # num_hashed_tokens is used to compute future hashes - # (e.g. in the hashing function, it is used to ask the sequence for - # prefix tokens) - new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) - - # If the block_hash is None, then the block is not full. - # If the block is not full, then we expect it to have a refcount of 1. - if block_hash is None: - assert new_block.ref_count == 1 - return new_block - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int = 0, - ) -> List[Tuple[int, int]]: - """Allocate a physical slot for a new token.""" - n_blocks = seq.n_blocks - block_table = self.block_tables[seq.seq_id] - # If we need to allocate a new physical block - if len(block_table) < n_blocks: - # Currently this code only supports adding one physical block - assert len(block_table) == n_blocks - 1 - - if (self.block_sliding_window - and len(block_table) >= self.block_sliding_window): - # reuse a block - block_table.append(block_table[len(block_table) % - self.block_sliding_window]) - else: - # The sequence hash a new logical block. - # Allocate a new physical block. - new_block = self._allocate_last_physical_block(seq) - block_table.append(new_block) - return [] - - # We want to append the token to the last physical block. - last_block = block_table[-1] - assert last_block.device == Device.GPU - if last_block.ref_count == 1: - # Not shared with other sequences. Appendable. - if self.enable_caching: - # If the last block is now complete, we may reuse an old block - # to save memory. - maybe_new_block = self._maybe_promote_last_block( - seq, last_block) - block_table[-1] = maybe_new_block - return [] - else: - # The last block is shared with other sequences. - # Copy on Write: Allocate a new block and copy the tokens. 
- new_block = self._allocate_last_physical_block(seq) - - block_table[-1] = new_block - self.gpu_allocator.free(last_block) - return [(last_block.block_number, new_block.block_number)] - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - # NOTE: fork does not allocate a new physical block. - # Thus, it is always safe from OOM. - if parent_seq.seq_id not in self.block_tables: - # Parent sequence has either been freed or never existed. - return - src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.copy() - - # When using a sliding window, blocks will be eventually reused. - # In this case the block tables will contain repeated blocks. - # When forking, we must make sure that each block's `ref_count` - # is only incremented by one, so we deduplicate them by wrapping - # them in a set. - for block in set(src_block_table): - block.ref_count += 1 - - def _get_physical_blocks( - self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: - - # NOTE: Here, we assume that the physical blocks are only shared by - # the sequences in the same group. - request_id = seq_group.request_id - blocks: Set[PhysicalTokenBlock] = set() - for seq in seq_group.get_seqs(): - if seq.is_finished(): - continue - blocks.update(self.block_tables[seq.seq_id]) - # Cross-attention blocks - if seq_group.is_encoder_decoder(): - blocks.update(self.cross_block_tables[request_id]) - return list(blocks) - - def can_swap_in(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - assert (num_lookahead_slots == 0 - ), "BlockSpaceManagerV1 does not support lookahead allocation" - - blocks = self._get_physical_blocks(seq_group) - num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if seq_group.is_encoder_decoder(): - num_swapped_seqs += 1 - num_free_blocks = self.gpu_allocator.get_num_free_blocks() - # NOTE: Conservatively, we assume that every sequence will allocate - # at least one free block right after the swap-in. - # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) + num_swapped_seqs - if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: - return AllocStatus.NEVER - elif num_free_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _swap_block_table( - self, block_table: BlockTable, src_allocator: BlockAllocatorBase, - dest_allocator: BlockAllocatorBase, - mapping: Dict[PhysicalTokenBlock, - PhysicalTokenBlock]) -> BlockTable: - new_block_table: BlockTable = BlockTable() - - for from_block in block_table: - if from_block in mapping: - to_block = mapping[from_block] - to_block.ref_count += 1 - else: - to_block = dest_allocator.allocate( - from_block.block_hash, from_block.num_hashed_tokens) - mapping[from_block] = to_block - new_block_table.append(to_block) - # Free the source block swapped in to destination. - src_allocator.free(from_block) - - return new_block_table - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - - request_id = seq_group.request_id - - # CPU block -> GPU block. 
- # dict is efficient in lookup `if cpu_block in mapping` - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.cpu_allocator, self.gpu_allocator, - mapping) - - if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.cpu_allocator, - self.gpu_allocator, - mapping) - - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - blocks = self._get_physical_blocks(seq_group) - return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - request_id = seq_group.request_id - - # GPU block -> CPU block. - # dict is efficient in lookup `if gpu_block in mapping` - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.gpu_allocator, self.cpu_allocator, - mapping) - - if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.gpu_allocator, - self.cpu_allocator, - mapping) - - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] - - def _free_block_table(self, block_table: BlockTable) -> None: - # when using a sliding window, each seq will only use up - # to `self.block_sliding_window` blocks. When freeing - # the block table, we must make sure to not free blocks more - # than once. If no sliding window is used, there is no block - # reuse in the block table, so we must free all blocks. - blocks_to_free = (block_table[-self.block_sliding_window:] - if self.block_sliding_window is not None else - block_table) - for block in set(blocks_to_free): - if block.device == Device.GPU: - self.gpu_allocator.free(block) - else: - self.cpu_allocator.free(block) - - def free(self, seq: Sequence) -> None: - if seq.seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. - return - block_table = self.block_tables[seq.seq_id] - self._free_block_table(block_table) - del self.block_tables[seq.seq_id] - - def free_cross(self, seq_group: SequenceGroup) -> None: - if seq_group.request_id not in self.cross_block_tables: - # Already freed or hasn't ben scheduled yet. 
- return - block_table = self.cross_block_tables[seq_group.request_id] - self._free_block_table(block_table) - del self.cross_block_tables[seq_group.request_id] - - def reset(self) -> None: - # Free decoder block tables - for block_table in self.block_tables.values(): - self._free_block_table(block_table) - self.block_tables.clear() - # Free cross-attention block tables - for block_table in self.cross_block_tables.values(): - self._free_block_table(block_table) - self.cross_block_tables.clear() - - def get_block_table(self, seq: Sequence) -> List[int]: - return self.block_tables[seq.seq_id].ids() - - def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: - block_table = self.cross_block_tables[seq_group.request_id] - return [block.block_number for block in block_table] - - def get_num_free_gpu_blocks(self) -> int: - return self.gpu_allocator.get_num_free_blocks() - - def get_num_free_cpu_blocks(self) -> int: - return self.cpu_allocator.get_num_free_blocks() - - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.block_tables[seq.seq_id] - for block in block_table: - block.last_accessed = access_time - - def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int): - if seq.seq_id not in self.block_tables: - return - - # When chunked prefill is enabled, the computed full blocks - # should be calculated based on the number of computed tokens. - max_computed_tokens = (seq.data.get_num_computed_tokens() + - token_chunk_size) - computed_full_blocks = max_computed_tokens // self.block_size - - block_table = self.block_tables[seq.seq_id] - if computed_full_blocks == 0: - return - for i in reversed(range(computed_full_blocks)): - if block_table[i].computed: - break - block_table[i].computed = True - - def get_all_computed_blocks(self, seq: Sequence) -> List[int]: - if seq.seq_id not in self.block_tables: - return [] - block_table = self.block_tables[seq.seq_id] - # NOTE We exclude the last block to avoid the case where the entire - # prompt is cached. This would cause erroneous behavior in model - # runner. - return [ - b.block_number - for b in takewhile(lambda b: b.computed, block_table[:-1]) - ] - - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - """Return the block ids that are common for a given sequence group. - - Used in prefill (can skip prefill of some blocks). - """ - # Can return non-empty result only with prefix caching enabled. 
- if not self.enable_caching: - return [] - - ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] - return commonprefix([ids for ids in ids_list if ids != []]) - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - if self.enable_caching: - for seq in seq_group.get_seqs(): - self.compute_full_blocks_in_seq(seq, token_chunk_size) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - if device == Device.GPU: - return self.gpu_allocator.get_prefix_cache_hit_rate() - if device == Device.CPU: - return self.cpu_allocator.get_prefix_cache_hit_rate() - raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 9e1d1b02f6805..9501a516bf020 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -28,13 +28,9 @@ class BlockSpaceManager(ABC): def get_block_space_manager_class(version: str): version = version.lower() - if version == "v1": - from vllm.core.block_manager_v1 import BlockSpaceManagerV1 - return BlockSpaceManagerV1 - - if version == "v2": - from vllm.core.block_manager_v2 import BlockSpaceManagerV2 - return BlockSpaceManagerV2 + if version == "selfattn": + from vllm.core.block_manager import SelfAttnBlockSpaceManager + return SelfAttnBlockSpaceManager if version == "placeholder": from vllm.core.placeholder_block_space_manager import ( diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e7eaaf12272d6..f0c8e6bab4862 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -312,9 +312,7 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - version = "v1" - if self.scheduler_config.use_v2_block_manager: - version = "v2" + version = "selfattn" if (self.scheduler_config.embedding_mode or self.cache_config.is_attention_free): version = "placeholder" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1ce9e62007f64..41963dcb16922 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -373,12 +373,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='Disables sliding window, ' 'capping to sliding window size') - parser.add_argument( - '--use-v2-block-manager', - default=EngineArgs.use_v2_block_manager, - action='store_true', - help='Use BlockSpaceMangerV2. By default this is set to True. ' - 'Set to False to use BlockSpaceManagerV1') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='[DEPRECATED] block manager v1 has been ' + 'removed and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') parser.add_argument( '--num-lookahead-slots', type=int, @@ -969,12 +970,6 @@ def create_engine_config(self) -> EngineConfig: "in low performance due to small KV cache space. 
Consider " "setting --max-model-len to a smaller value.", max_model_len) - if self.num_scheduler_steps > 1 and not self.use_v2_block_manager: - self.use_v2_block_manager = True - logger.warning( - "Enabled BlockSpaceManagerV2 because it is " - "required for multi-step (--num-scheduler-steps > 1)") - speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -990,7 +985,6 @@ def create_engine_config(self) -> EngineConfig: speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, - use_v2_block_manager=self.use_v2_block_manager, disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, @@ -1021,11 +1015,20 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is None \ else speculative_config.num_lookahead_slots + if not self.use_v2_block_manager: + logger.warning( + "[DEPRECATED] Block manager v1 has been removed, " + "and setting --use-v2-block-manager to True or False has " + "no effect on vLLM behavior. Please remove " + "--use-v2-block-manager in your engine argument. " + "If your use case is not supported by " + "SelfAttnBlockSpaceManager (i.e. block manager v2)," + " please file an issue with detailed information.") + scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, - use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, @@ -1081,13 +1084,6 @@ def create_engine_config(self) -> EngineConfig: or "all" in detailed_trace_modules, ) - if (model_config.get_sliding_window() is not None - and scheduler_config.chunked_prefill_enabled - and not scheduler_config.use_v2_block_manager): - raise ValueError( - "Chunked prefill is not supported with sliding window. 
" - "Set --disable-sliding-window to disable sliding window.") - return EngineConfig( model_config=model_config, cache_config=cache_config, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a570d096d4cd0..61c21887e6816 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -247,7 +247,7 @@ def __init__( "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "seed=%d, served_model_name=%s, " "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " @@ -280,7 +280,6 @@ def __init__( observability_config, model_config.seed, model_config.served_model_name, - scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, diff --git a/vllm/envs.py b/vllm/envs.py index 45a9999610f6a..2d283fae23849 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -64,7 +64,6 @@ VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_DISABLED_KERNELS: List[str] = [] @@ -427,11 +426,6 @@ def get_default_config_root(): "VLLM_SKIP_P2P_CHECK": lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1", - # If set, allowing the use of deprecated block manager V1 - "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1": - lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0" - ) == "1", - # List of quantization kernels that should be disabled, used for testing # and performance comparisons. Currently only affects MPLinearKernel # selection diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 36753b8580f6f..a82956985af55 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -574,17 +574,12 @@ def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, # paged attn. We can remove it if we make paged attn kernel # to properly handle slinding window attn. curr_sliding_window_block = self.sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = inter_data.seq_lens[seq_idx] % self.block_size - sliding_seq_len = min( - inter_data.seq_lens[seq_idx], - self.block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_block += 1 - else: - sliding_seq_len = min(inter_data.seq_lens[seq_idx], - self.sliding_window) + # number of elements in last block + suff_len = inter_data.seq_lens[seq_idx] % self.block_size + sliding_seq_len = min(inter_data.seq_lens[seq_idx], + self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_block += 1 inter_data.curr_sliding_window_blocks[ seq_idx] = curr_sliding_window_block