From 82cfa5a61e35c9736322ebb5d693966bf0d999c8 Mon Sep 17 00:00:00 2001
From: seungrokj <144636725+seungrokj@users.noreply.github.com>
Date: Thu, 17 Oct 2024 00:25:41 +0900
Subject: [PATCH] cuda graph + num-scheduler-steps bug fix (#236)

* cuda graph + num-scheduler-steps bug fix

* cuda graph + num-scheduler-steps bug fix

* linting
---
 vllm/attention/backends/utils.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 2b8c373178ab3..e451cd5522d18 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -218,9 +218,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
             input_block_tables = self.runner.graph_block_tables[:batch_size]
+            max_blocks = input_block_tables.shape[1]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
-                    input_block_tables[i, :len(block_table)] = block_table
+                    num_blocks = len(block_table)
+                    if num_blocks <= max_blocks:
+                        input_block_tables[i, :num_blocks] = block_table
+                    else:
+                        # It may be possible to have more blocks allocated due
+                        # to lookahead slots of multi-step, however, they are
+                        # not used anyway, so can be safely ignored.
+                        input_block_tables[
+                            i, :max_blocks] = block_table[:max_blocks]
             block_tables = torch.from_numpy(input_block_tables).to(
                 device, non_blocking=True)
         else:
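
Note: the following is a minimal, standalone sketch of the failure mode the
hunk above fixes; it is not part of the patch, and the array sizes and block
IDs are invented for illustration. With multi-step scheduling
(num-scheduler-steps > 1), lookahead slots can make a sequence's block table
longer than the second dimension of the CUDA-graph-captured
graph_block_tables, and the old unclamped NumPy assignment then raises a
broadcast error:

    import numpy as np

    # Invented sizes: a captured block table with room for 4 blocks per
    # sequence, and one sequence holding 6 blocks because multi-step
    # scheduling reserved lookahead slots.
    max_batch_size = 8
    graph_block_tables = np.zeros((max_batch_size, 4), dtype=np.int32)
    block_table = [10, 11, 12, 13, 14, 15]

    max_blocks = graph_block_tables.shape[1]
    num_blocks = len(block_table)

    # Old code path (unclamped) fails, because NumPy clips the slice
    # [0, :6] to the 4 available columns:
    #   graph_block_tables[0, :num_blocks] = block_table
    #   ValueError: could not broadcast input array from shape (6,)
    #   into shape (4,)

    # Fixed code path (clamped): the trailing blocks are unused
    # lookahead slots, so truncating them is safe.
    if num_blocks <= max_blocks:
        graph_block_tables[0, :num_blocks] = block_table
    else:
        graph_block_tables[0, :max_blocks] = block_table[:max_blocks]

    print(graph_block_tables[0])  # -> [10 11 12 13]

Truncation is safe here because the captured graph addresses at most
max_blocks blocks per sequence; per the comment in the hunk, the extra
entries exist only because of the lookahead slots and are never used during
graph replay.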