From 7eb0e0d7a42b3ac64a7912faf1f2822601da5f2a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 09:44:51 -0400 Subject: [PATCH 001/443] added block manager tests --- tests/core/test_block_manager.py | 132 ++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 1 deletion(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 22a9f0cf47d32..6b2fa21f2ef46 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -12,7 +12,7 @@ from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from .utils import create_dummy_prompt +from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder def test_block_allocator_allocate(): @@ -89,6 +89,34 @@ def test_allocate(): block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks//block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size, block_size) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range((num_gpu_blocks - 1)//block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size//2, block_size//2) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK def test_append_slot_single_seq(): block_size = 4 @@ -240,6 +268,58 @@ def test_swap(): assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) +def test_swap_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + decoder_prompt.status = SequenceStatus.WAITING + encoder_prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + decoder_prompt.status = SequenceStatus.RUNNING + decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap encoder/decoder seq group from GPU -> CPU. + decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) + encoder_gpu_blocks = block_manager.get_encoder_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + encoder_gpu_blocks + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert [x[0] for x in mapping] == gpu_blocks + #assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + decoder_prompt.status = SequenceStatus.SWAPPED + + # Swap decoder seq group from CPU -> GPU. + decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) + encoder_cpu_blocks = block_manager.get_encoder_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + encoder_cpu_blocks + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert [x[0] for x in mapping] == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) def test_free(): block_size = 4 @@ -264,6 +344,34 @@ def test_free(): with pytest.raises(KeyError): block_manager.get_block_table(prompt) +def test_free_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", + decoder_prompt_length=block_size//2, + encoder_prompt_length=block_size//2) + block_manager.allocate(seq_group) + + # Free allocated seq. + decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) + encoder_prompt_blocks = len(block_manager.get_encoder_block_table(seq_group)) + prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(decoder_prompt) + block_manager.free_encoder(seq_group) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(decoder_prompt) + block_manager.get_block_table(encoder_prompt) def test_reset(): block_size = 4 @@ -285,6 +393,28 @@ def test_reset(): block_manager.reset() assert block_manager.get_num_free_gpu_blocks() == original_blocks +def test_reset_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks//block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder(f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. + block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks def test_sliding_window_multi_seq(): """ From 6e41c39b24e8bdcff76ebbab0b95e16c0603e0b3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 09:52:03 -0400 Subject: [PATCH 002/443] passing block manager encoder/decoder test --- tests/core/utils.py | 29 ++++++++ vllm/core/block_manager_v1.py | 130 ++++++++++++++++++++++++++++++++-- vllm/sequence.py | 12 ++++ 3 files changed, 166 insertions(+), 5 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 8fb13177a2d6c..170bf9fff3dd2 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -32,6 +32,35 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_prompt_encoder_decoder( + request_id: str, + decoder_prompt_length: int, + encoder_prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = decoder_prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + decoder_prompt_tokens = list(range(decoder_prompt_length)) + decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + decoder_prompt = Sequence(int(request_id), decoder_prompt_str, decoder_prompt_tokens, block_size) + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + encoder_prompt = Sequence(int(request_id), encoder_prompt_str, encoder_prompt_tokens, block_size) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams(use_beam_search=use_beam_search, best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) + + return decoder_prompt, encoder_prompt, seq_group def create_seq_group( seq_prompt_len: int = 1024, diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 52a170d79e4e7..bd2ccbbb86572 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -255,12 +255,23 @@ def __init__( Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + # Mapping: req_id -> BlockTable + # Note that each SequenceGroup has a unique + # request ID + self.encoder_block_tables: Dict[str, BlockTable] = {} + + def get_seq_num_required_blocks(self, seq: Sequence) -> int: + if seq is None: + return 0 + return len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = len(seq.logical_token_blocks) + + decoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + encoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) + num_required_blocks = decoder_num_required_blocks+encoder_num_required_blocks if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -276,9 +287,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: + def allocate_decoder(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same - # prompt. + # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] # Allocate new physical token blocks that will store the prompt tokens. @@ -301,10 +312,46 @@ def allocate(self, seq_group: SequenceGroup) -> None: block.ref_count = seq_group.num_seqs() block_table.append(block) - # Assign the block table for each sequence. + # Assign the decoder block table for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() + def allocate_encoder(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + seq = seq_group.get_encoder_seq() + + # Allocate new physical token blocks that will store the prompt tokens. + block_table: BlockTable = [] + if seq is None: + # Assign empty encoder block table for the SequenceGroup + self.encoder_block_tables[seq_group.request_id] = block_table + else: + num_prompt_blocks = len(seq.logical_token_blocks) + for logical_idx in range(num_prompt_blocks): + if (self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window): + block = block_table[logical_idx % self.block_sliding_window] + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() + elif self.enable_caching: + block = self.gpu_allocator.allocate( + seq.hash_of_block(logical_idx), + seq.num_hashed_tokens_of_block(logical_idx)) + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + # TODO: feature not supported with encoder/decoder + block.ref_count = seq_group.num_seqs() + block_table.append(block) + + # Assign the encoder block table for the SequenceGroup. + self.encoder_block_tables[seq_group.request_id] = block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + self.allocate_decoder(seq_group) + self.allocate_encoder(seq_group) + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: @@ -445,11 +492,15 @@ def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. + request_id = seq_group.request_id blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) + # Encoder blocks + if seq_group.encoder_seq is not None: + blocks.update(self.encoder_block_tables[request_id]) return list(blocks) def can_swap_in(self, @@ -459,6 +510,8 @@ def can_swap_in(self, ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + if seq_group.encoder_seq is not None: + num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. @@ -477,6 +530,8 @@ def swap_in(self, assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + request_id = seq_group.request_id + # CPU block -> GPU block. # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} @@ -497,6 +552,23 @@ def swap_in(self, self.cpu_allocator.free(cpu_block) self.block_tables[seq.seq_id] = new_block_table + if seq_group.encoder_seq is not None: + new_block_table: BlockTable = [] + block_table = self.encoder_block_tables[request_id] + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate( + cpu_block.block_hash, cpu_block.num_hashed_tokens) + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. + self.cpu_allocator.free(cpu_block) + self.encoder_block_tables[request_id] = new_block_table + block_number_mapping = { cpu_block.block_number: gpu_block.block_number for cpu_block, gpu_block in mapping.items() @@ -509,6 +581,8 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: return len(blocks) <= self.cpu_allocator.get_num_free_blocks() def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + request_id = seq_group.request_id + # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} @@ -529,6 +603,23 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator.free(gpu_block) self.block_tables[seq.seq_id] = new_block_table + if seq_group.encoder_seq is not None: + new_block_table: BlockTable = [] + block_table = self.encoder_block_tables[request_id] + + for gpu_block in block_table: + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate( + gpu_block.block_hash, gpu_block.num_hashed_tokens) + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + self.gpu_allocator.free(gpu_block) + self.encoder_block_tables[request_id] = new_block_table + block_number_mapping = { gpu_block.block_number: cpu_block.block_number for gpu_block, cpu_block in mapping.items() @@ -559,15 +650,32 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_encoder(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.encoder_block_tables: + # Already freed or hasn't ben scheduled yet. + return + block_table = self.encoder_block_tables[seq_group.request_id] + self._free_block_table(block_table) + del self.encoder_block_tables[seq_group.request_id] + def reset(self) -> None: + # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() + # Free encoder block tables + for block_table in self.encoder_block_tables.values(): + self._free_block_table(block_table) + self.encoder_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] + def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.encoder_block_tables[seq_group.request_id] + return [block.block_number for block in block_table] + def get_num_free_gpu_blocks(self) -> int: return self.gpu_allocator.get_num_free_blocks() @@ -586,6 +694,18 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time + def access_all_encoder_blocks_in_seq_group( + self, + seq_group: SequenceGroup, + access_time: float, + ) -> None: + if self.enable_caching: + # Update the last accessed time of all the blocks accessed + # in this step. + block_table = self.encoder_block_tables[seq_group.request_id] + for block in block_table: + block.last_accessed = access_time + def compute_full_blocks_in_seq(self, seq: Sequence): if seq.seq_id not in self.block_tables: return diff --git a/vllm/sequence.py b/vllm/sequence.py index aa759448d82b1..ca2de3ef0d774 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -420,6 +420,7 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. + encoder_seq: Optional, the single encoder sequence. """ def __init__( @@ -432,6 +433,7 @@ def __init__( multi_modal_data: Optional[MultiModalData] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, + encoder_seq: Optional[Sequence] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -447,6 +449,7 @@ def __init__( self.multi_modal_data = multi_modal_data self.embeddings = embeddings self.pooling_params = pooling_params + self.encoder_seq = encoder_seq @property def prompt(self) -> str: @@ -524,6 +527,9 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def get_encoder_seq(self) -> Sequence: + return self.encoder_seq + def get_unfinished_seqs(self) -> List[Sequence]: return [ seq for seq in self.seqs_dict.values() if not seq.is_finished() @@ -607,6 +613,8 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. + encoder_seq_data: Optional, the sequence data for the single encoder prompt. + encoder_block_table: Optional, the block table for the single encoder prompt. """ def __init__( @@ -623,6 +631,8 @@ def __init__( computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, + encoder_seq_data: Optional[SequenceData] = None, + encoder_block_table: Optional[Dict[int, List[int]]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -634,6 +644,8 @@ def __init__( self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self.encoder_seq_data = encoder_seq_data + self.encoder_block_table = encoder_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample From f04ee73114eb50dbf03cb1d2a9ecd238705db035 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 14:22:04 -0400 Subject: [PATCH 003/443] block manager v2 changes to pass test_can_allocate_seq_group_encoder_decoder --- tests/core/block/test_block_manager_v2.py | 49 ++++++++++++++++++++++- tests/core/utils.py | 49 +++++++++++++++++++++++ vllm/core/block_manager_v2.py | 6 +++ 3 files changed, 103 insertions(+), 1 deletion(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 1e8e4ccdfb151..6cb2f3708199f 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -5,7 +5,7 @@ from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import create_seq_group +from ..utils import create_seq_group, create_seq_group_encoder_decoder @pytest.mark.parametrize("block_size", [16]) @@ -52,6 +52,53 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_group: int, + num_gpu_blocks: int, watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. + num_output_blocks = num_output_blocks_per_seq + + for bdx,num_prompt_blocks in enumerate(range(1, num_gpu_blocks - num_output_blocks)): + num_encoder_blocks_per_seq = num_prompt_blocks + + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id=str(bdx) + ) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + num_output_blocks + num_encoder_blocks_per_seq + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/utils.py b/tests/core/utils.py index 170bf9fff3dd2..91930457bd25b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -102,5 +102,54 @@ def create_seq_group( return seq_group +def create_seq_group_encoder_decoder( + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + if sampling_params is None: + sampling_params = SamplingParams() + + prompt_token_ids = [0] * seq_prompt_len + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + # Encoder sequence + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, + ) + + seq_group = SequenceGroup( + request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq + ) + + return seq_group + + def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f0bc96564050a..06bfbba78dce6 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -96,6 +96,12 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) + if seq_group.encoder_seq is not None: + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + assert self.block_sliding_window is None if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, From 07bbd8ac4c44f50f42137350ee928483842d02ee Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 14:47:47 -0400 Subject: [PATCH 004/443] block manager v2 support for encoder/decoder --- vllm/core/block_manager_v1.py | 9 ++---- vllm/core/block_manager_v2.py | 59 ++++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index bd2ccbbb86572..812d1ee3197a5 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -319,14 +319,11 @@ def allocate_decoder(self, seq_group: SequenceGroup) -> None: def allocate_encoder(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # encoder prompt. - seq = seq_group.get_encoder_seq() # Allocate new physical token blocks that will store the prompt tokens. - block_table: BlockTable = [] - if seq is None: - # Assign empty encoder block table for the SequenceGroup - self.encoder_block_tables[seq_group.request_id] = block_table - else: + seq = seq_group.get_encoder_seq() + if seq is not None: + block_table: BlockTable = [] num_prompt_blocks = len(seq.logical_token_blocks) for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 06bfbba78dce6..2f7a11bacc1a1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -10,6 +10,7 @@ from vllm.utils import Device SeqId = int +EncoderSeqId = str class BlockSpaceManagerV2(BlockSpaceManager): @@ -85,6 +86,7 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.encoder_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -119,7 +121,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: + def allocate_decoder(self, seq_group: SequenceGroup) -> None: waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -140,6 +142,28 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() + def allocate_encoder(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + request_id = seq_group.request_id + seq = seq_group.encoder_seq + + assert not (request_id in self.encoder_block_tables), "block table already exists" + + seq = seq_group.get_encoder_seq() + if seq is not None: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + ) + assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) + self.encoder_block_tables[request_id] = block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + self.allocate_decoder(seq_group) + self.allocate_encoder(seq_group) + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue @@ -193,12 +217,29 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] + def free_encoder(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.encoder_block_tables: + # Already freed or hasn't ben scheduled yet. + return + self.encoder_block_tables[request_id].free() + del self.encoder_block_tables[request_id] + + del self.encoder_block_tables[seq_group.request_id] + def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) return block_ids # type: ignore + def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.encoder_block_tables + block_ids = self.block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids + def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed # in this step. @@ -215,6 +256,22 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids, # type: ignore now) + def access_all_encoder_blocks_in_seq_group( + self, + seq_group: SequenceGroup, + now: float, + ) -> None: + if self.enable_caching: + # Update the last accessed time of all the blocks accessed + # in this step. + block_table = self.encoder_block_tables[seq_group.request_id] + block_ids = [] + for block_id in block_table.physical_block_ids: + block_ids.append(block_id) + self.block_allocator.mark_blocks_as_accessed( + block_ids, # type: ignore + now) + def mark_blocks_as_computed(self, seq_group: SequenceGroup): # The only need for mark block as computed is for prefix caching, # while currently we could determine whether one block is computed From 3e95602f9c408f82628e881f30540ac82b3cb5f7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 15:11:35 -0400 Subject: [PATCH 005/443] renamed encoder to cross in block manager v2, regarding block tables --- vllm/core/block_manager_v2.py | 32 ++++++++++++++++---------------- vllm/sequence.py | 6 +++--- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 2f7a11bacc1a1..426612f615508 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -86,7 +86,7 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} - self.encoder_block_tables: Dict[EncoderSeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -121,7 +121,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_decoder(self, seq_group: SequenceGroup) -> None: + def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -142,13 +142,13 @@ def allocate_decoder(self, seq_group: SequenceGroup) -> None: for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() - def allocate_encoder(self, seq_group: SequenceGroup) -> None: + def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. request_id = seq_group.request_id seq = seq_group.encoder_seq - assert not (request_id in self.encoder_block_tables), "block table already exists" + assert not (request_id in self.cross_block_tables), "block table already exists" seq = seq_group.get_encoder_seq() if seq is not None: @@ -158,11 +158,11 @@ def allocate_encoder(self, seq_group: SequenceGroup) -> None: ) assert self.block_sliding_window is None block_table.allocate(seq.get_token_ids()) - self.encoder_block_tables[request_id] = block_table + self.cross_block_tables[request_id] = block_table def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_decoder(seq_group) - self.allocate_encoder(seq_group) + self.allocate_self_block_tables(seq_group) + self.allocate_cross_block_table(seq_group) def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: @@ -217,15 +217,15 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] - def free_encoder(self, seq_group: SequenceGroup) -> None: + def free_cross(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id - if request_id not in self.encoder_block_tables: + if request_id not in self.cross_block_tables: # Already freed or hasn't ben scheduled yet. return - self.encoder_block_tables[request_id].free() - del self.encoder_block_tables[request_id] + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] - del self.encoder_block_tables[seq_group.request_id] + del self.cross_block_tables[seq_group.request_id] def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables @@ -233,9 +233,9 @@ def get_block_table(self, seq: Sequence) -> List[int]: assert all(b is not None for b in block_ids) return block_ids # type: ignore - def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: request_id = seq_group.request_id - assert request_id in self.encoder_block_tables + assert request_id in self.cross_block_tables block_ids = self.block_tables[request_id].physical_block_ids assert all(b is not None for b in block_ids) return block_ids @@ -256,7 +256,7 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids, # type: ignore now) - def access_all_encoder_blocks_in_seq_group( + def access_all_cross_blocks_in_seq_group( self, seq_group: SequenceGroup, now: float, @@ -264,7 +264,7 @@ def access_all_encoder_blocks_in_seq_group( if self.enable_caching: # Update the last accessed time of all the blocks accessed # in this step. - block_table = self.encoder_block_tables[seq_group.request_id] + block_table = self.cross_block_tables[seq_group.request_id] block_ids = [] for block_id in block_table.physical_block_ids: block_ids.append(block_id) diff --git a/vllm/sequence.py b/vllm/sequence.py index ca2de3ef0d774..a73e70c1ae69d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -614,7 +614,7 @@ class SequenceGroupMetadata: state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. encoder_seq_data: Optional, the sequence data for the single encoder prompt. - encoder_block_table: Optional, the block table for the single encoder prompt. + cross_block_table: Optional, the cross-attention block table associated with the single encoder prompt. """ def __init__( @@ -632,7 +632,7 @@ def __init__( state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, encoder_seq_data: Optional[SequenceData] = None, - encoder_block_table: Optional[Dict[int, List[int]]] = None, + cross_block_table: Optional[Dict[int, List[int]]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -645,7 +645,7 @@ def __init__( self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state self.encoder_seq_data = encoder_seq_data - self.encoder_block_table = encoder_block_table + self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample From 04f38a819445c0141246feeb6969cc4b1e67891f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 15:22:53 -0400 Subject: [PATCH 006/443] renamed encoder to cross where appropriate --- tests/core/block/test_block_manager_v2.py | 4 +- tests/core/test_block_manager.py | 12 ++--- vllm/core/block_manager_v1.py | 54 +++++++++++------------ 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 6cb2f3708199f..9b1c6cd68a15a 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -74,7 +74,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr num_output_blocks = num_output_blocks_per_seq for bdx,num_prompt_blocks in enumerate(range(1, num_gpu_blocks - num_output_blocks)): - num_encoder_blocks_per_seq = num_prompt_blocks + num_cross_blocks_per_seq = num_prompt_blocks seq_group = create_seq_group_encoder_decoder( seq_prompt_len=block_size * num_prompt_blocks, @@ -89,7 +89,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr can_allocate_result = block_manager.can_allocate(seq_group) - num_required_blocks = num_prompt_blocks + num_output_blocks + num_encoder_blocks_per_seq + num_required_blocks = num_prompt_blocks + num_output_blocks + num_cross_blocks_per_seq if num_gpu_blocks - num_required_blocks < num_watermark_blocks: assert can_allocate_result == AllocStatus.NEVER diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 6b2fa21f2ef46..62b7132e40462 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -293,8 +293,8 @@ def test_swap_encoder_decoder(): # Swap encoder/decoder seq group from GPU -> CPU. decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) - encoder_gpu_blocks = block_manager.get_encoder_block_table(seq_group) - gpu_blocks = decoder_gpu_blocks + encoder_gpu_blocks + cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks assert block_manager.can_swap_out(seq_group) before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -309,8 +309,8 @@ def test_swap_encoder_decoder(): # Swap decoder seq group from CPU -> GPU. decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) - encoder_cpu_blocks = block_manager.get_encoder_block_table(seq_group) - cpu_blocks = decoder_cpu_blocks + encoder_cpu_blocks + cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks assert block_manager.can_swap_in(seq_group) == AllocStatus.OK before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -360,11 +360,11 @@ def test_free_encoder_decoder(): # Free allocated seq. decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) - encoder_prompt_blocks = len(block_manager.get_encoder_block_table(seq_group)) + encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks before_blocks = block_manager.get_num_free_gpu_blocks() block_manager.free(decoder_prompt) - block_manager.free_encoder(seq_group) + block_manager.free_cross(seq_group) after_blocks = block_manager.get_num_free_gpu_blocks() assert after_blocks == before_blocks + prompt_blocks diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 812d1ee3197a5..11a52b3618b44 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -258,7 +258,7 @@ def __init__( # Mapping: req_id -> BlockTable # Note that each SequenceGroup has a unique # request ID - self.encoder_block_tables: Dict[str, BlockTable] = {} + self.cross_block_tables: Dict[str, BlockTable] = {} def get_seq_num_required_blocks(self, seq: Sequence) -> int: if seq is None: @@ -269,9 +269,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - decoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - encoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) - num_required_blocks = decoder_num_required_blocks+encoder_num_required_blocks + self_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks+cross_num_required_blocks if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -287,7 +287,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_decoder(self, seq_group: SequenceGroup) -> None: + def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] @@ -316,7 +316,7 @@ def allocate_decoder(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - def allocate_encoder(self, seq_group: SequenceGroup) -> None: + def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # encoder prompt. @@ -342,12 +342,12 @@ def allocate_encoder(self, seq_group: SequenceGroup) -> None: block.ref_count = seq_group.num_seqs() block_table.append(block) - # Assign the encoder block table for the SequenceGroup. - self.encoder_block_tables[seq_group.request_id] = block_table + # Assign the cross-attention block table for the SequenceGroup. + self.cross_block_tables[seq_group.request_id] = block_table def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_decoder(seq_group) - self.allocate_encoder(seq_group) + self.allocate_self_block_tables(seq_group) + self.allocate_cross_block_table(seq_group) def can_append_slots(self, seq_group: SequenceGroup, @@ -495,9 +495,9 @@ def _get_physical_blocks( if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) - # Encoder blocks + # Cross-attention blocks if seq_group.encoder_seq is not None: - blocks.update(self.encoder_block_tables[request_id]) + blocks.update(self.cross_block_tables[request_id]) return list(blocks) def can_swap_in(self, @@ -551,7 +551,7 @@ def swap_in(self, if seq_group.encoder_seq is not None: new_block_table: BlockTable = [] - block_table = self.encoder_block_tables[request_id] + block_table = self.cross_block_tables[request_id] for cpu_block in block_table: if cpu_block in mapping: @@ -564,7 +564,7 @@ def swap_in(self, new_block_table.append(gpu_block) # Free the CPU block swapped in to GPU. self.cpu_allocator.free(cpu_block) - self.encoder_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = new_block_table block_number_mapping = { cpu_block.block_number: gpu_block.block_number @@ -602,7 +602,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: if seq_group.encoder_seq is not None: new_block_table: BlockTable = [] - block_table = self.encoder_block_tables[request_id] + block_table = self.cross_block_tables[request_id] for gpu_block in block_table: if gpu_block in mapping: @@ -615,7 +615,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: new_block_table.append(cpu_block) # Free the GPU block swapped out to CPU. self.gpu_allocator.free(gpu_block) - self.encoder_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = new_block_table block_number_mapping = { gpu_block.block_number: cpu_block.block_number @@ -647,30 +647,30 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] - def free_encoder(self, seq_group: SequenceGroup) -> None: - if seq_group.request_id not in self.encoder_block_tables: + def free_cross(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.cross_block_tables: # Already freed or hasn't ben scheduled yet. return - block_table = self.encoder_block_tables[seq_group.request_id] + block_table = self.cross_block_tables[seq_group.request_id] self._free_block_table(block_table) - del self.encoder_block_tables[seq_group.request_id] + del self.cross_block_tables[seq_group.request_id] def reset(self) -> None: # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() - # Free encoder block tables - for block_table in self.encoder_block_tables.values(): + # Free cross-attention block tables + for block_table in self.cross_block_tables.values(): self._free_block_table(block_table) - self.encoder_block_tables.clear() + self.cross_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] - def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: - block_table = self.encoder_block_tables[seq_group.request_id] + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.cross_block_tables[seq_group.request_id] return [block.block_number for block in block_table] def get_num_free_gpu_blocks(self) -> int: @@ -691,7 +691,7 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time - def access_all_encoder_blocks_in_seq_group( + def access_all_cross_blocks_in_seq_group( self, seq_group: SequenceGroup, access_time: float, @@ -699,7 +699,7 @@ def access_all_encoder_blocks_in_seq_group( if self.enable_caching: # Update the last accessed time of all the blocks accessed # in this step. - block_table = self.encoder_block_tables[seq_group.request_id] + block_table = self.cross_block_tables[seq_group.request_id] for block in block_table: block.last_accessed = access_time From 2dcd663d40bdcc1cf2aca19b9cec64395ac6d528 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 15:45:12 -0400 Subject: [PATCH 007/443] formatting --- tests/core/block/test_block_manager_v2.py | 16 +++++--- tests/core/test_block_manager.py | 43 +++++++++++++++------- tests/core/utils.py | 45 ++++++++++++----------- vllm/core/block_manager_v1.py | 18 +++++---- vllm/core/block_manager_v2.py | 8 ++-- vllm/sequence.py | 9 +++-- 6 files changed, 85 insertions(+), 54 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 9b1c6cd68a15a..06c3389cfa0f0 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -56,8 +56,10 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, @pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) @pytest.mark.parametrize("num_seqs_per_group", [1, 4]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): +def test_can_allocate_seq_group_encoder_decoder(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): block_manager = BlockSpaceManagerV2( block_size=block_size, num_gpu_blocks=num_gpu_blocks, @@ -73,7 +75,8 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr # different output lens. num_output_blocks = num_output_blocks_per_seq - for bdx,num_prompt_blocks in enumerate(range(1, num_gpu_blocks - num_output_blocks)): + for bdx, num_prompt_blocks in enumerate( + range(1, num_gpu_blocks - num_output_blocks)): num_cross_blocks_per_seq = num_prompt_blocks seq_group = create_seq_group_encoder_decoder( @@ -82,14 +85,15 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr block_size * num_output_blocks_per_seq for _ in range(num_seqs_per_group) ], - request_id=str(bdx) - ) + request_id=str(bdx)) assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks can_allocate_result = block_manager.can_allocate(seq_group) - num_required_blocks = num_prompt_blocks + num_output_blocks + num_cross_blocks_per_seq + num_required_blocks = num_prompt_blocks + \ + num_output_blocks + \ + num_cross_blocks_per_seq if num_gpu_blocks - num_required_blocks < num_watermark_blocks: assert can_allocate_result == AllocStatus.NEVER diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62b7132e40462..d6ab246699903 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -89,6 +89,7 @@ def test_allocate(): block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK + def test_allocate_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -100,8 +101,9 @@ def test_allocate_encoder_decoder(): watermark=0) # Allocate same sequence group to all available gpu blocks. - for i in range(num_gpu_blocks//block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size, block_size) + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), block_size, block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -112,12 +114,14 @@ def test_allocate_encoder_decoder(): num_cpu_blocks, num_gpu_blocks, watermark=1 / num_gpu_blocks) - for i in range((num_gpu_blocks - 1)//block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size//2, block_size//2) + for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), block_size // 2, block_size // 2) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 @@ -268,6 +272,7 @@ def test_swap(): assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + def test_swap_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -277,9 +282,11 @@ def test_swap_encoder_decoder(): num_gpu_blocks, watermark=0) - decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) decoder_prompt.status = SequenceStatus.WAITING encoder_prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) @@ -321,6 +328,7 @@ def test_swap_encoder_decoder(): assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + def test_free(): block_size = 4 num_cpu_blocks = 4 @@ -344,6 +352,7 @@ def test_free(): with pytest.raises(KeyError): block_manager.get_block_table(prompt) + def test_free_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -353,9 +362,11 @@ def test_free_encoder_decoder(): num_gpu_blocks, watermark=0) - decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", - decoder_prompt_length=block_size//2, - encoder_prompt_length=block_size//2) + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size // 2, + encoder_prompt_length=block_size // 2) block_manager.allocate(seq_group) # Free allocated seq. @@ -373,6 +384,7 @@ def test_free_encoder_decoder(): block_manager.get_block_table(decoder_prompt) block_manager.get_block_table(encoder_prompt) + def test_reset(): block_size = 4 num_cpu_blocks = 4 @@ -393,6 +405,7 @@ def test_reset(): block_manager.reset() assert block_manager.get_num_free_gpu_blocks() == original_blocks + def test_reset_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -405,10 +418,11 @@ def test_reset_encoder_decoder(): # Allocate same seq group on all available gpu blocks. original_blocks = block_manager.get_num_free_gpu_blocks() - for i in range(num_gpu_blocks//block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder(f"{i}", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) block_manager.allocate(seq_group) assert block_manager.get_num_free_gpu_blocks() == 0 @@ -416,6 +430,7 @@ def test_reset_encoder_decoder(): block_manager.reset() assert block_manager.get_num_free_gpu_blocks() == original_blocks + def test_sliding_window_multi_seq(): """ Tests that memory allocation and deallocation is handled diff --git a/tests/core/utils.py b/tests/core/utils.py index 91930457bd25b..376af0f0eac4f 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -32,6 +32,7 @@ def create_dummy_prompt( return prompt, seq_group + def create_dummy_prompt_encoder_decoder( request_id: str, decoder_prompt_length: int, @@ -48,20 +49,24 @@ def create_dummy_prompt_encoder_decoder( # and prompt "0 ... block_size". decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) - decoder_prompt = Sequence(int(request_id), decoder_prompt_str, decoder_prompt_tokens, block_size) + decoder_prompt = Sequence(int(request_id), decoder_prompt_str, + decoder_prompt_tokens, block_size) encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - encoder_prompt = Sequence(int(request_id), encoder_prompt_str, encoder_prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id=request_id, - seqs=[decoder_prompt], - sampling_params=SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - arrival_time=time.time(), - lora_request=lora_request, - encoder_seq=encoder_prompt) + encoder_prompt = Sequence(int(request_id), encoder_prompt_str, + encoder_prompt_tokens, block_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) return decoder_prompt, encoder_prompt, seq_group + def create_seq_group( seq_prompt_len: int = 1024, seq_output_lens: Iterable[int] = (128, ), @@ -134,20 +139,18 @@ def create_seq_group_encoder_decoder( # Encoder sequence encoder_seq = Sequence( - seq_id=seq_id_start + len(seq_output_lens), - prompt="", - prompt_token_ids=prompt_token_ids, - block_size=16, - ) - - seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq + seq_id=seq_id_start + len(seq_output_lens), + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, ) + seq_group = SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) + return seq_group diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 11a52b3618b44..03eba2e80c78d 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -263,15 +263,18 @@ def __init__( def get_seq_num_required_blocks(self, seq: Sequence) -> int: if seq is None: return 0 - return len(seq.logical_token_blocks) + return len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - self_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) - num_required_blocks = self_num_required_blocks+cross_num_required_blocks + self_num_required_blocks = self.get_seq_num_required_blocks( + seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self.get_seq_num_required_blocks( + seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks + \ + cross_num_required_blocks if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -328,7 +331,8 @@ def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % self.block_sliding_window] + block = block_table[logical_idx % + self.block_sliding_window] # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() elif self.enable_caching: @@ -550,7 +554,7 @@ def swap_in(self, self.block_tables[seq.seq_id] = new_block_table if seq_group.encoder_seq is not None: - new_block_table: BlockTable = [] + new_block_table = [] block_table = self.cross_block_tables[request_id] for cpu_block in block_table: @@ -601,7 +605,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.block_tables[seq.seq_id] = new_block_table if seq_group.encoder_seq is not None: - new_block_table: BlockTable = [] + new_block_table = [] block_table = self.cross_block_tables[request_id] for gpu_block in block_table: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 426612f615508..4ae3361e7b234 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -148,7 +148,9 @@ def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id seq = seq_group.encoder_seq - assert not (request_id in self.cross_block_tables), "block table already exists" + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" seq = seq_group.get_encoder_seq() if seq is not None: @@ -236,9 +238,9 @@ def get_block_table(self, seq: Sequence) -> List[int]: def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: request_id = seq_group.request_id assert request_id in self.cross_block_tables - block_ids = self.block_tables[request_id].physical_block_ids + block_ids = self.cross_block_tables[request_id].physical_block_ids assert all(b is not None for b in block_ids) - return block_ids + return block_ids # type: ignore def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed diff --git a/vllm/sequence.py b/vllm/sequence.py index a73e70c1ae69d..a11c411876ea8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -528,7 +528,7 @@ def get_seqs( ] def get_encoder_seq(self) -> Sequence: - return self.encoder_seq + return self.encoder_seq # type: ignore def get_unfinished_seqs(self) -> List[Sequence]: return [ @@ -613,8 +613,11 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. - encoder_seq_data: Optional, the sequence data for the single encoder prompt. - cross_block_table: Optional, the cross-attention block table associated with the single encoder prompt. + encoder_seq_data: Optional, the sequence data + for the single encoder prompt. + cross_block_table: Optional, the cross-attention + block table associated with + the single encoder prompt. """ def __init__( From 8ff1ddf224d742cf3aea6b7fa9f55409b5815b7b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 19:40:45 -0400 Subject: [PATCH 008/443] attention test & xformers backend changes --- tests/layer/test_self_and_cross_attn.py | 466 ++++++++++++++++++++++++ vllm/attention/backends/xformers.py | 29 +- 2 files changed, 489 insertions(+), 6 deletions(-) create mode 100644 tests/layer/test_self_and_cross_attn.py diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py new file mode 100644 index 0000000000000..c29f7224ab57f --- /dev/null +++ b/tests/layer/test_self_and_cross_attn.py @@ -0,0 +1,466 @@ +import random +from typing import List, Optional +import itertools + +import pytest +import torch +import copy +from vllm.attention import Attention, AttentionMetadata, AttentionMetadataPerStage + +from vllm.attention.backends.xformers import XFormersBackend +from vllm.attention.backends.abstract import AttentionBackend + +from vllm.attention.ops.paged_attn import PagedAttention + +from vllm.utils import get_max_shared_memory_bytes +from vllm.utils import is_hip +from vllm.utils import make_tensor_with_pad + +from vllm.attention.layer import Attention + +import random + +# FlashAttention forward only supports head dimension at most 128 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +HEAD_SIZES = [64] + +# [64, 80, 96, 112, 128, 256 +# ] if not is_hip() else [64, 80, 96, 112, 128] + +NUM_HEADS = [1] + +BATCH_SIZES = [16] +BLOCK_SIZES = [16] +#KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +BACKEND_NAMES = ["xformers"] +#CUDA_DEVICES = [ +# f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +#] + +PROMPT_LENS = [32] + +def build_causal_mask(q_max_prompt_len, k_max_prompt_len): + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_prompt_len, k_max_prompt_len), diagonal=1) #.transpose(0, 1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) + return mask + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + #query=query.unsqueeze(-2) + #key=key.unsqueeze(-2) + #value=value.unsqueeze(-2) + #assert False,f"{query.shape} ; {key.shape}" + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() + #assert False,f"{query.shape} ; {key.shape} ; {attn_weights.shape}" + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) + #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" + return out + +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): + if force_max_len: + q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] + kv_prompt_lens = None + if not is_cross_attn: + # K,V prompt lens match Q for self-attention + kv_prompt_lens = q_prompt_lens + else: + # K,V prompt lens come from K,V operands + kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] + else: + q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = None + if not is_cross_attn: + # K,V prompt lens match Q for self-attention + kv_prompt_lens = q_prompt_lens + else: + # K,V prompt lens come from K,V operands + kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + + query=torch.rand((batch_size,max_q_prompt_len,head_size)) + key=torch.rand((batch_size,max_kv_prompt_len,head_size)) + value=torch.rand((batch_size,max_kv_prompt_len,head_size)) + + for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): + query[bdx,q_prompt_len:] = 0 + key[bdx,kv_prompt_len:] = 0 + value[bdx,kv_prompt_len:] = 0 + + query=query.unsqueeze(-2) + key=key.unsqueeze(-2) + value=value.unsqueeze(-2) + + return query,key,value,q_prompt_lens,kv_prompt_lens + +def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): + num_tok = sum(prompt_lens) + num_heads = unpacked_tensor.shape[-2] + head_size = unpacked_tensor.shape[-1] + start_loc_list = [0]+list(itertools.accumulate(prompt_lens)) + packed_tensor = torch.zeros((num_tok,num_heads,head_size), + device=device) + + #assert False, f"{start_loc_list}" + + #assert False, f"{packed_tensor.shape} ; {unpacked_tensor.shape}" + + for bdx,(prompt_len,start_loc) in enumerate(zip(prompt_lens,start_loc_list)): + try: + packed_tensor[start_loc:(start_loc+prompt_len),:,:] = unpacked_tensor[bdx,:prompt_len,:,:] + except: + assert False, f"{start_loc} ; {prompt_len} ; {packed_tensor.shape} ; {unpacked_tensor.shape}" + + return packed_tensor,start_loc_list + +def pack_qkv(query,key,value,q_prompt_lens,kv_prompt_lens): + packed_query,q_start_loc_list = pack_tensor(query,q_prompt_lens) + packed_key,kv_start_loc_list = pack_tensor(key,kv_prompt_lens) + packed_value,_ = pack_tensor(value,kv_prompt_lens) + packed_query=packed_query.view(-1,packed_query.shape[-1]*packed_query.shape[-2]) + packed_key=packed_key.view(-1,packed_key.shape[-1]*packed_key.shape[-2]) + packed_value=packed_value.view(-1,packed_value.shape[-1]*packed_value.shape[-2]) + return packed_query,packed_key,packed_value,q_start_loc_list,kv_start_loc_list + +def make_backend(backend_name: str) -> AttentionBackend: + if backend_name == "xformers": + return XFormersBackend() + assert False, f"Unrecognized backend_name {backend_name} for unit test" + +def make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> AttentionMetadataPerStage: + ''' + Assumptions: + * No chunked prefill + * No (automatic) prefix caching + * Packed variable-length sequences + ''' + prompt_lens_tensor=torch.tensor(prompt_lens, + dtype=torch.int, + device=device) + context_lens_tensor=None if context_lens is None else torch.tensor(context_lens, + dtype=torch.int, + device=device) + max_query_len=None if prompt_lens is None else max(prompt_lens) + max_context_len=None if context_lens is None else max(context_lens) + max_prompt_len=max_query_len + + seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + query_start_loc = copy.deepcopy(seq_start_loc) + + return attn_backend.make_metadata( + is_prompt=is_prompt, + is_cross_attn=is_cross_attn, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + cross_seq_lens=cross_prompt_lens, + max_query_len=max_query_len, + #max_context_len=max_context_len, + max_seq_len=max_prompt_len, + subquery_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) + +def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): + #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) + #val_cache = torch.rand((num_blocks, num_heads, head_size, block_size),device=device) + kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(device) + if default_val is not None: + kv_cache[:,:,:] = default_val + return kv_cache + +def num_tokens_to_min_blocks(num_tokens,block_size): + return (num_tokens+block_size)//block_size + +def make_flat_block_tables_slot_mapping(block_size,prompt_lens): + ''' + Naive block table: + * For each batch element... + * Block table has + ''' + num_tokens = sum(prompt_lens) + num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) + block_tables = list(range(num_blocks*100)) + slot_mapping = [(idx % block_size) + block_tables[idx//block_size]*block_size for idx in range(num_tokens)] + prefill_block_tables_tensor = torch.tensor( + [], + device='cuda:0' + ) + block_tables_tensor = torch.tensor( + block_tables, + device='cuda:0' + ) + slot_mapping_tensor = torch.tensor( + slot_mapping, + dtype=torch.long, + device='cuda:0' + ) + + return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor + +def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): + ''' + Naive block table: + * For each batch element... + * Block table has + ''' + num_prompts = len(prompt_lens) + total_num_tokens = sum(prompt_lens) + # Over-provision block table blocks by 1 + num_blocks_list = [num_tokens_to_min_blocks(num_tokens,block_size)+1 for num_tokens in prompt_lens] + max_block_table_len = max(num_blocks_list) + #block_tables = [list(range(num_blocks*10)) for num_blocks in num_blocks_list] + block_table_pad_tokens = 10 + + block_tables = [] + slot_mapping = [] + block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed + #seq_base_idx = 0 + for sdx,num_tokens in enumerate(prompt_lens): + #num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) + num_blocks = num_blocks_list[sdx] + block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) + for idx in range(num_tokens): + slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + + #seq_base_idx += num_tokens + block_base_idx -= num_blocks + block_tables.append(block_table) + + prefill_block_tables_tensor = torch.tensor( + [], + device='cuda:0' + ) + block_tables_tensor = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len+block_table_pad_tokens, + pad=0, + dtype=torch.int, + device=device, + ) + slot_mapping_tensor = torch.tensor( + slot_mapping, + dtype=torch.long, + device=device + ) + + return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor + + +def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): + ''' + Assumptions: + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + ''' + + if is_prompt: + num_prefills = len(prompt_lens) + num_prefill_tokens = sum(prompt_lens) + num_decode_tokens = 0 + + # make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) + stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) + + return AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=stage_metadata, + decode_metadata=None, + kv_cache_dtype=kv_cache_dtype, + ) + + else: # not is_prompt + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = sum(context_lens) + + stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) + + return AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=None, + decode_metadata=stage_metadata, + kv_cache_dtype=kv_cache_dtype, + ) + +def make_attention(num_heads: int, head_size: int, scale: float): + # Attention operator instance + return Attention(num_heads, + head_size, + scale=scale,) + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_prompt_len",PROMPT_LENS) +def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_prompt_len: int) -> None: + # Attention operator instance + is_cross_attn=False + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + max_q_prompt_len = max_prompt_len + max_kv_prompt_len = max_q_prompt_len + context_lens = None + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, num_heads, head_size) + #(key_cache, value_cache) = kv_cache + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) + #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + attn_mask=causal_mask + ) + + prefill_query = query[:,:-1] + prefill_key = key[:,:-1] + prefill_value = value[:,:-1] + decode_query = query[:,-1:] + decode_key = key[:,-1:] + decode_value = value[:,-1:] + prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] + prefill_kv_prompt_lens = [plen-1 for plen in kv_prompt_lens] + decode_q_prompt_lens = [1 for _ in q_prompt_lens] + decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] + prefill_ideal_output = ideal_output[:,:-1] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_ideal_output = ideal_output[:,-1:] + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_q_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + # Put KVs in KV cache + # Deprecated - handled automatically inside attention + # PagedAttention.write_to_paged_cache(key, value, key_cache, + # value_cache, + # prefill_attn_metadata.slot_mapping, + # prefill_attn_metadata.kv_cache_dtype, + # scale) + + is_prompt = False + context_lens = [1 for _ in range(batch_size)] + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len",PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",PROMPT_LENS) +def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: + # Attention operator instance + is_cross_attn=True + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + #max_q_prompt_len = max_prompt_len + #max_kv_prompt_len = max_prompt_len + context_lens = None + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + # key_cache, value_cache = PagedAttention.split_kv_cache( + # kv_cache, num_heads, head_size) + #(key_cache, value_cache) = kv_cache + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=True) + #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) + #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + #attn_mask=causal_mask + ) + + prefill_query = query[:,:-1] + prefill_key = key #key[:,:-1] + prefill_value = value #value[:,:-1] + decode_query = query[:,-1:] + decode_key = key #key[:,-1:] + decode_value = value #value[:,-1:] + prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] + prefill_kv_prompt_lens = kv_prompt_lens + decode_q_prompt_lens = [1 for _ in q_prompt_lens] + decode_kv_prompt_lens = kv_prompt_lens + prefill_ideal_output = ideal_output[:,:-1] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_ideal_output = ideal_output[:,-1:] + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_kv_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=prefill_kv_prompt_lens) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + is_prompt = False + context_lens = [1 for _ in range(batch_size)] + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index fc46af054de4f..0a35a41a69a93 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -5,6 +5,7 @@ import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalMask, BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) @@ -108,6 +109,14 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): _cached_prefill_metadata: Optional["XFormersMetadata"] = None _cached_decode_metadata: Optional["XFormersMetadata"] = None + # Need to make KV cache read-only for cross-attention + is_cross_attn: bool = False + + # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value + # sequence length (usually encoder sequence length) in the cross-attention + # computation. None if this is self-attention + cross_seq_lens: Optional[List[int]] = None + def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -270,16 +279,20 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + is_cross_attn = (attn_metadata.prefill_metadata is not None and attn_metadata.prefill_metadata.is_cross_attn) or (attn_metadata.decode_metadata is not None and attn_metadata.decode_metadata.is_cross_attn) + assert is_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) + assert is_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + + if not is_cross_attn: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens @@ -374,8 +387,12 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + if attn_metadata.is_cross_attn: + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) From 685afc07eaa6560821676e8bb212d88ed09a206b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 16 May 2024 12:52:49 -0400 Subject: [PATCH 009/443] wip attn tests --- tests/layer/__init__.py | 0 tests/layer/test_self_and_cross_attn.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tests/layer/__init__.py diff --git a/tests/layer/__init__.py b/tests/layer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index c29f7224ab57f..2da312b03afaf 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -5,7 +5,7 @@ import pytest import torch import copy -from vllm.attention import Attention, AttentionMetadata, AttentionMetadataPerStage +from vllm.attention import Attention, AttentionMetadata #, AttentionMetadataPerStage from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.abstract import AttentionBackend From 5278f13426fb17574b6060fb62afcc56e1b20077 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 10:58:48 -0400 Subject: [PATCH 010/443] wip self attention test --- tests/layer/test_self_and_cross_attn.py | 128 +++++++++++++++++------- 1 file changed, 92 insertions(+), 36 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 2da312b03afaf..96419a54b7ad6 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -135,7 +135,7 @@ def make_backend(backend_name: str) -> AttentionBackend: return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" -def make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> AttentionMetadataPerStage: +def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> tuple: ''' Assumptions: * No chunked prefill @@ -156,27 +156,35 @@ def make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_ dtype=torch.int32, device=device) - torch.cumsum(prompt_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) + # torch.cumsum(prompt_lens_tensor, + # dim=0, + # dtype=seq_start_loc.dtype, + # out=seq_start_loc[1:]) query_start_loc = copy.deepcopy(seq_start_loc) - return attn_backend.make_metadata( - is_prompt=is_prompt, - is_cross_attn=is_cross_attn, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, - cross_seq_lens=cross_prompt_lens, - max_query_len=max_query_len, - #max_context_len=max_context_len, - max_seq_len=max_prompt_len, - subquery_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) + return prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + max_context_len, \ + max_prompt_len, \ + seq_start_loc, \ + query_start_loc + + # return attn_backend.make_metadata( + # is_prompt=is_prompt, + # is_cross_attn=is_cross_attn, + # seq_lens=prompt_lens, + # seq_lens_tensor=prompt_lens_tensor, + # cross_seq_lens=cross_prompt_lens, + # max_query_len=max_query_len, + # #max_context_len=max_context_len, + # max_seq_len=max_prompt_len, + # subquery_start_loc=query_start_loc, + # seq_start_loc=seq_start_loc, + # context_lens_tensor=context_lens_tensor, + # block_tables=block_tables, + # use_cuda_graph=False, + # ) def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) @@ -275,17 +283,40 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b num_prefill_tokens = sum(prompt_lens) num_decode_tokens = 0 - # make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) - stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) - - return AttentionMetadata( + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + max_context_len, \ + max_prompt_len, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(attn_backend, + is_prompt, + is_cross_attn, + prompt_lens, + context_lens, + block_tables, + device=device, + cross_prompt_lens=cross_prompt_lens) + + slot_mapping_tensor=torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + + return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - prefill_metadata=stage_metadata, - decode_metadata=None, - kv_cache_dtype=kv_cache_dtype, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max(prompt_lens), + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, ) else: # not is_prompt @@ -294,16 +325,40 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b num_prefill_tokens = 0 num_decode_tokens = sum(context_lens) - stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) - - return AttentionMetadata( + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + max_context_len, \ + max_prompt_len, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(attn_backend, + is_prompt, + is_cross_attn, + prompt_lens, + context_lens, + block_tables, + device=device, + cross_prompt_lens=cross_prompt_lens) + + slot_mapping_tensor=torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + + return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - prefill_metadata=None, - decode_metadata=stage_metadata, - kv_cache_dtype=kv_cache_dtype, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max(prompt_lens), + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, ) def make_attention(num_heads: int, head_size: int, scale: float): @@ -326,7 +381,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n is_prompt = True max_q_prompt_len = max_prompt_len max_kv_prompt_len = max_q_prompt_len - context_lens = None + context_lens = [0 for _ in range(batch_size)] key_read_width = 4 num_blocks = 4096 kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') @@ -392,6 +447,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) +@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) From 1824a9523064f6410f766cc22a032db54b7a429a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 13:45:59 -0400 Subject: [PATCH 011/443] tests run but do not pass --- tests/layer/test_self_and_cross_attn.py | 31 ++++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 96419a54b7ad6..61ab4075e0597 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -148,19 +148,28 @@ def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cros context_lens_tensor=None if context_lens is None else torch.tensor(context_lens, dtype=torch.int, device=device) - max_query_len=None if prompt_lens is None else max(prompt_lens) max_context_len=None if context_lens is None else max(context_lens) - max_prompt_len=max_query_len + max_prompt_len=None if prompt_lens is None else max(prompt_lens) seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) - # torch.cumsum(prompt_lens_tensor, - # dim=0, - # dtype=seq_start_loc.dtype, - # out=seq_start_loc[1:]) - query_start_loc = copy.deepcopy(seq_start_loc) + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + if is_prompt: + # Prefill: query_start_loc matches seq_start_loc + query_start_loc = copy.deepcopy(seq_start_loc) + max_query_len=max_prompt_len + else: + # Decode: one new query input token per batch + # element, thus query_start_loc is the cumsum + # of [1,1,1,...] + query_start_loc = list(range(len(seq_start_loc))) + max_query_len = 1 return prompt_lens_tensor, \ context_lens_tensor, \ @@ -323,7 +332,7 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b num_prefills = 0 num_prefill_tokens = 0 - num_decode_tokens = sum(context_lens) + num_decode_tokens = len(prompt_lens) prompt_lens_tensor, \ context_lens_tensor, \ @@ -352,8 +361,8 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b seq_lens=prompt_lens, seq_lens_tensor=prompt_lens_tensor, max_query_len=max_query_len, - max_prefill_seq_len=max(prompt_lens), - max_decode_seq_len=0, + max_prefill_seq_len=0, + max_decode_seq_len=max(prompt_lens), query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -437,7 +446,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # scale) is_prompt = False - context_lens = [1 for _ in range(batch_size)] + context_lens = copy.deepcopy(prefill_kv_prompt_lens) decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) From 64b7b6154c077c78d84cc26f67011e3a04a3d9b0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 15:06:38 -0400 Subject: [PATCH 012/443] passing self-attention --- tests/layer/test_self_and_cross_attn.py | 38 ++++++++++++------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 61ab4075e0597..fd0682996c8da 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -12,8 +12,6 @@ from vllm.attention.ops.paged_attn import PagedAttention -from vllm.utils import get_max_shared_memory_bytes -from vllm.utils import is_hip from vllm.utils import make_tensor_with_pad from vllm.attention.layer import Attention @@ -247,15 +245,18 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): block_table_pad_tokens = 10 block_tables = [] - slot_mapping = [] + prefill_slot_mapping = [] + decode_slot_mapping = [] block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed #seq_base_idx = 0 for sdx,num_tokens in enumerate(prompt_lens): #num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) num_blocks = num_blocks_list[sdx] block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) - for idx in range(num_tokens): - slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + for idx in range(num_tokens-1): + prefill_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + idx = num_tokens-1 + decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) #seq_base_idx += num_tokens block_base_idx -= num_blocks @@ -265,20 +266,25 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): [], device='cuda:0' ) - block_tables_tensor = make_tensor_with_pad( + decode_block_tables_tensor = make_tensor_with_pad( block_tables, max_len=max_block_table_len+block_table_pad_tokens, pad=0, dtype=torch.int, device=device, ) - slot_mapping_tensor = torch.tensor( - slot_mapping, + prefill_slot_mapping_tensor = torch.tensor( + prefill_slot_mapping, + dtype=torch.long, + device=device + ) + decode_slot_mapping_tensor = torch.tensor( + decode_slot_mapping, dtype=torch.long, device=device ) - return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor + return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): @@ -427,8 +433,8 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_ideal_output = ideal_output[:,-1:] decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,q_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) @@ -437,17 +443,9 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) - # Put KVs in KV cache - # Deprecated - handled automatically inside attention - # PagedAttention.write_to_paged_cache(key, value, key_cache, - # value_cache, - # prefill_attn_metadata.slot_mapping, - # prefill_attn_metadata.kv_cache_dtype, - # scale) - is_prompt = False context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) From c99aa0dcc579a22fcaf698c30169c5d7bfb69898 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 16:01:50 -0400 Subject: [PATCH 013/443] passing self-attention test with variable lengths! --- tests/layer/test_self_and_cross_attn.py | 124 ++++++++++++++++-------- 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index fd0682996c8da..5cabef7e5f8b6 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -64,40 +64,79 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): - if force_max_len: - q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] - kv_prompt_lens = None - if not is_cross_attn: - # K,V prompt lens match Q for self-attention - kv_prompt_lens = q_prompt_lens - else: - # K,V prompt lens come from K,V operands - kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True): + q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = None + if not is_cross_attn: + # K,V prompt lens match Q for self-attention + kv_prompt_lens = q_prompt_lens else: - q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] - kv_prompt_lens = None - if not is_cross_attn: - # K,V prompt lens match Q for self-attention - kv_prompt_lens = q_prompt_lens - else: - # K,V prompt lens come from K,V operands - kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] - + # K,V prompt lens come from K,V operands + kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + + actual_max_q_prompt_len = max(q_prompt_lens) + actual_max_kv_prompt_len = max(kv_prompt_lens) + query=torch.rand((batch_size,max_q_prompt_len,head_size)) key=torch.rand((batch_size,max_kv_prompt_len,head_size)) value=torch.rand((batch_size,max_kv_prompt_len,head_size)) + prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)) + prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) + prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) + + decode_query=torch.zeros((batch_size,1,head_size)) + decode_key=torch.zeros((batch_size,1,head_size)) + decode_value=torch.zeros((batch_size,1,head_size)) + for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): - query[bdx,q_prompt_len:] = 0 - key[bdx,kv_prompt_len:] = 0 - value[bdx,kv_prompt_len:] = 0 + query[bdx,q_prompt_len:,:] = 0 + key[bdx,kv_prompt_len:,:] = 0 + value[bdx,kv_prompt_len:,:] = 0 + + prefill_query[bdx,0:(q_prompt_len-1),:] = query[bdx,0:(q_prompt_len-1),:] + prefill_key[bdx,0:(kv_prompt_len-1),:] = key[bdx,0:(kv_prompt_len-1),:] + prefill_value[bdx,0:(kv_prompt_len-1),:] = value[bdx,0:(kv_prompt_len-1),:] + + decode_query[bdx,:,:] = query[bdx,(q_prompt_len-1):q_prompt_len,:] + decode_key[bdx,:,:] = key[bdx,(kv_prompt_len-1):kv_prompt_len,:] + decode_value[bdx,:,:] = value[bdx,(kv_prompt_len-1):kv_prompt_len,:] + + prefill_q_prompt_lens = [plen - 1 for plen in q_prompt_lens] + prefill_kv_prompt_lens = [plen - 1 for plen in kv_prompt_lens] + + decode_q_prompt_lens = [1 for _ in q_prompt_lens] + decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] query=query.unsqueeze(-2) key=key.unsqueeze(-2) value=value.unsqueeze(-2) - return query,key,value,q_prompt_lens,kv_prompt_lens + prefill_query=prefill_query.unsqueeze(-2) + prefill_key=prefill_key.unsqueeze(-2) + prefill_value=prefill_value.unsqueeze(-2) + + decode_query=decode_query.unsqueeze(-2) + decode_key=decode_key.unsqueeze(-2) + decode_value=decode_value.unsqueeze(-2) + + return query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): num_tok = sum(prompt_lens) @@ -400,13 +439,26 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n key_read_width = 4 num_blocks = 4096 kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, num_heads, head_size) - #(key_cache, value_cache) = kv_cache scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) - query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) @@ -418,19 +470,13 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n attn_mask=causal_mask ) - prefill_query = query[:,:-1] - prefill_key = key[:,:-1] - prefill_value = value[:,:-1] - decode_query = query[:,-1:] - decode_key = key[:,-1:] - decode_value = value[:,-1:] - prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] - prefill_kv_prompt_lens = [plen-1 for plen in kv_prompt_lens] - decode_q_prompt_lens = [1 for _ in q_prompt_lens] - decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] - prefill_ideal_output = ideal_output[:,:-1] + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_ideal_output = ideal_output[:,-1:] decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,q_prompt_lens) From 270d95e78af170692d19bc9d37d6badec5a91869 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 17:16:44 -0400 Subject: [PATCH 014/443] wip cross-attention; is_cross_atn and cross_seq_lens is transferred from parent metadata struct to child metadata structs; cross-attn test runs without functional errors but fails all_close --- tests/layer/test_self_and_cross_attn.py | 147 ++++++++++-------------- vllm/attention/backends/xformers.py | 6 +- 2 files changed, 68 insertions(+), 85 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 5cabef7e5f8b6..d5ca2c5ccbc5b 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -146,10 +146,6 @@ def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): packed_tensor = torch.zeros((num_tok,num_heads,head_size), device=device) - #assert False, f"{start_loc_list}" - - #assert False, f"{packed_tensor.shape} ; {unpacked_tensor.shape}" - for bdx,(prompt_len,start_loc) in enumerate(zip(prompt_lens,start_loc_list)): try: packed_tensor[start_loc:(start_loc+prompt_len),:,:] = unpacked_tensor[bdx,:prompt_len,:,:] @@ -216,22 +212,6 @@ def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cros seq_start_loc, \ query_start_loc - # return attn_backend.make_metadata( - # is_prompt=is_prompt, - # is_cross_attn=is_cross_attn, - # seq_lens=prompt_lens, - # seq_lens_tensor=prompt_lens_tensor, - # cross_seq_lens=cross_prompt_lens, - # max_query_len=max_query_len, - # #max_context_len=max_context_len, - # max_seq_len=max_prompt_len, - # subquery_start_loc=query_start_loc, - # seq_start_loc=seq_start_loc, - # context_lens_tensor=context_lens_tensor, - # block_tables=block_tables, - # use_cuda_graph=False, - # ) - def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) #val_cache = torch.rand((num_blocks, num_heads, head_size, block_size),device=device) @@ -243,32 +223,6 @@ def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, def num_tokens_to_min_blocks(num_tokens,block_size): return (num_tokens+block_size)//block_size -def make_flat_block_tables_slot_mapping(block_size,prompt_lens): - ''' - Naive block table: - * For each batch element... - * Block table has - ''' - num_tokens = sum(prompt_lens) - num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) - block_tables = list(range(num_blocks*100)) - slot_mapping = [(idx % block_size) + block_tables[idx//block_size]*block_size for idx in range(num_tokens)] - prefill_block_tables_tensor = torch.tensor( - [], - device='cuda:0' - ) - block_tables_tensor = torch.tensor( - block_tables, - device='cuda:0' - ) - slot_mapping_tensor = torch.tensor( - slot_mapping, - dtype=torch.long, - device='cuda:0' - ) - - return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor - def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): ''' Naive block table: @@ -286,6 +240,7 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): block_tables = [] prefill_slot_mapping = [] decode_slot_mapping = [] + slot_mapping = [] block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed #seq_base_idx = 0 for sdx,num_tokens in enumerate(prompt_lens): @@ -294,8 +249,10 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) for idx in range(num_tokens-1): prefill_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) idx = num_tokens-1 decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) #seq_base_idx += num_tokens block_base_idx -= num_blocks @@ -322,8 +279,18 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): dtype=torch.long, device=device ) + slot_mapping_tensor = torch.tensor( + slot_mapping, + dtype=torch.long, + device=device + ) + empty_slot_mapping_tensor = torch.tensor( + [], + dtype=torch.long, + device=device + ) - return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor + return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): @@ -371,6 +338,8 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + is_cross_attn=is_cross_attn, + cross_seq_lens=cross_prompt_lens ) else: # not is_prompt @@ -413,6 +382,8 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + is_cross_attn=is_cross_attn, + cross_seq_lens=cross_prompt_lens ) def make_attention(num_heads: int, head_size: int, scale: float): @@ -421,6 +392,7 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) +@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -452,15 +424,14 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_key, \ decode_value, \ q_prompt_lens, \ - kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ + _, \ + _, \ + _, \ prefill_q_prompt_lens, \ prefill_kv_prompt_lens, \ decode_q_prompt_lens, \ decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) - #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) - #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) ideal_output = ref_masked_attention( query, @@ -479,10 +450,10 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,q_prompt_lens) + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) @@ -493,14 +464,13 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n context_lens = copy.deepcopy(prefill_kv_prompt_lens) decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) -@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -514,21 +484,32 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ device='cuda:0' kv_cache_dtype='auto' is_prompt = True - #max_q_prompt_len = max_prompt_len - #max_kv_prompt_len = max_prompt_len - context_lens = None + context_lens = [0 for _ in range(batch_size)] key_read_width = 4 num_blocks = 4096 kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - # key_cache, value_cache = PagedAttention.split_kv_cache( - # kv_cache, num_heads, head_size) - #(key_cache, value_cache) = kv_cache scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) - query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=True) - #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) - #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + + query, \ + key, \ + value, \ + prefill_query, \ + _, \ + _, \ + decode_query, \ + _, \ + _, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + _, \ + decode_q_prompt_lens, \ + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) + #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) ideal_output = ref_masked_attention( query, @@ -538,25 +519,23 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ #attn_mask=causal_mask ) - prefill_query = query[:,:-1] - prefill_key = key #key[:,:-1] - prefill_value = value #value[:,:-1] - decode_query = query[:,-1:] - decode_key = key #key[:,-1:] - decode_value = value #value[:,-1:] - prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] - prefill_kv_prompt_lens = kv_prompt_lens - decode_q_prompt_lens = [1 for _ in q_prompt_lens] - decode_kv_prompt_lens = kv_prompt_lens - prefill_ideal_output = ideal_output[:,:-1] + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_ideal_output = ideal_output[:,-1:] decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_kv_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=prefill_kv_prompt_lens) + # Unlike self-attention: + # - Prefill slot-mapping includes all key slots + # - Decode slot-mapping is empty + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) + + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) @@ -564,12 +543,12 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) is_prompt = False - context_lens = [1 for _ in range(batch_size)] - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + context_lens = copy.deepcopy(kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) - decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3f5752ab4d445..0ef80ca410355 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -154,6 +154,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, + is_cross_attn=self.is_cross_attn, + cross_seq_lens=self.cross_seq_lens ) return self._cached_prefill_metadata @@ -182,6 +184,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, + is_cross_attn=self.is_cross_attn, + cross_seq_lens=self.cross_seq_lens ) return self._cached_decode_metadata @@ -280,7 +284,7 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - is_cross_attn = (attn_metadata.prefill_metadata is not None and attn_metadata.prefill_metadata.is_cross_attn) or (attn_metadata.decode_metadata is not None and attn_metadata.decode_metadata.is_cross_attn) + is_cross_attn = attn_metadata.is_cross_attn assert is_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) assert is_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) From c41079153e5f2d4b1f9a21dc92d4e6e7f99b2182 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 18 May 2024 15:04:41 -0400 Subject: [PATCH 015/443] moved ideal to cuda --- tests/layer/test_self_and_cross_attn.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index d5ca2c5ccbc5b..feab7aad97d8a 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -27,7 +27,7 @@ NUM_HEADS = [1] -BATCH_SIZES = [16] +BATCH_SIZES = [1] BLOCK_SIZES = [16] #KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] @@ -77,17 +77,17 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) - query=torch.rand((batch_size,max_q_prompt_len,head_size)) - key=torch.rand((batch_size,max_kv_prompt_len,head_size)) - value=torch.rand((batch_size,max_kv_prompt_len,head_size)) + query=torch.rand((batch_size,max_q_prompt_len,head_size)).cuda() + key=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() + value=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() - prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)) - prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) - prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) + prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)).cuda() + prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() + prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() - decode_query=torch.zeros((batch_size,1,head_size)) - decode_key=torch.zeros((batch_size,1,head_size)) - decode_value=torch.zeros((batch_size,1,head_size)) + decode_query=torch.zeros((batch_size,1,head_size)).cuda() + decode_key=torch.zeros((batch_size,1,head_size)).cuda() + decode_value=torch.zeros((batch_size,1,head_size)).cuda() for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): query[bdx,q_prompt_len:,:] = 0 @@ -432,7 +432,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_q_prompt_lens, \ decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() ideal_output = ref_masked_attention( query, key, From 3719f5c2dff4532720703119bb141ac5dd3c9053 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 18 May 2024 16:18:36 -0400 Subject: [PATCH 016/443] wip cross-attention; appears to be power-of-two key-length requirement? --- tests/layer/test_self_and_cross_attn.py | 22 ++++++++++++++++------ vllm/attention/backends/xformers.py | 4 ++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index feab7aad97d8a..1621ee99d4a50 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -35,7 +35,11 @@ # f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) #] -PROMPT_LENS = [32] +PROMPT_LENS = [8] + +Q_PROMPT_LENS = [7] + +K_PROMPT_LENS = [32] def build_causal_mask(q_max_prompt_len, k_max_prompt_len): # Create a matrix where entry (i, j) is True if i >= j @@ -64,15 +68,21 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True): - q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): + if force_max_len: + q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] + else: + q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention kv_prompt_lens = q_prompt_lens else: # K,V prompt lens come from K,V operands - kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + if force_max_len: + kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] + else: + kv_prompt_lens = [random.randint(1, max_kv_prompt_len) for _ in range(batch_size)] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -476,8 +486,8 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",PROMPT_LENS) +@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0ef80ca410355..46862d72cd7f9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -392,8 +392,8 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: if attn_metadata.is_cross_attn: - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + attn_bias = None #BlockDiagonalMask.from_seqlens( + # attn_metadata.seq_lens,attn_metadata.cross_seq_lens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) From 96082e155e501951890c25fbefa387805a817452 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 18 May 2024 17:35:01 -0400 Subject: [PATCH 017/443] trying to debug cross-attention issue --- tests/layer/test_self_and_cross_attn.py | 13 ++++++++----- vllm/attention/backends/xformers.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 1621ee99d4a50..8c3b3b0c0e359 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -37,9 +37,9 @@ PROMPT_LENS = [8] -Q_PROMPT_LENS = [7] +Q_PROMPT_LENS = [128] -K_PROMPT_LENS = [32] +K_PROMPT_LENS = [128] def build_causal_mask(q_max_prompt_len, k_max_prompt_len): # Create a matrix where entry (i, j) is True if i >= j @@ -68,7 +68,9 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): + assert max_kv_prompt_len >= max_q_prompt_len + if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: @@ -82,7 +84,8 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [random.randint(1, max_kv_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = [16*((q_prompt_len + random.randint(0, max_kv_prompt_len-q_prompt_len))//16) + for q_prompt_len,_ in zip(q_prompt_lens,range(batch_size))] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -547,7 +550,7 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output=attn.forward(prefill_packed_query.contiguous(),prefill_packed_key.contiguous(),prefill_packed_value.contiguous(),kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 46862d72cd7f9..21d828edefc78 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,7 +420,8 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale) + scale=self.scale, + op=xops.MemoryEfficientAttentionOp()) return out.view_as(original_query) # Attention with alibi slopes. From d99c5d94164dc7695b5933465cce15c74d39609a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 11:07:04 -0400 Subject: [PATCH 018/443] wip --- tests/layer/test_self_and_cross_attn.py | 251 +++++++++++++++++++++++- vllm/attention/backends/xformers.py | 3 +- 2 files changed, 243 insertions(+), 11 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 8c3b3b0c0e359..ddf67c04bfef5 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -18,6 +18,13 @@ import random +from xformers import ops as xops + +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) + # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64] @@ -27,7 +34,7 @@ NUM_HEADS = [1] -BATCH_SIZES = [1] +BATCH_SIZES = [2] BLOCK_SIZES = [16] #KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] @@ -35,9 +42,9 @@ # f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) #] -PROMPT_LENS = [8] +PROMPT_LENS = [128] -Q_PROMPT_LENS = [128] +Q_PROMPT_LENS = [129] K_PROMPT_LENS = [128] @@ -68,13 +75,13 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): - assert max_kv_prompt_len >= max_q_prompt_len +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): + #assert max_kv_prompt_len >= max_q_prompt_len if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: - q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + q_prompt_lens = [random.randint(3, max_q_prompt_len) for _ in range(batch_size)] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention @@ -84,8 +91,8 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [16*((q_prompt_len + random.randint(0, max_kv_prompt_len-q_prompt_len))//16) - for q_prompt_len,_ in zip(q_prompt_lens,range(batch_size))] + kv_prompt_lens = [min(q_prompt_len-1,max_kv_prompt_len) + for q_prompt_len in q_prompt_lens] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -405,7 +412,115 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) -@pytest.mark.skip +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) +def test_xops_memory_efficient_attention_forward_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int): + # Attention operator instance + is_cross_attn=True + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + context_lens = [0 for _ in range(batch_size)] + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + + query, \ + key, \ + value, \ + prefill_query, \ + _, \ + _, \ + decode_query, \ + _, \ + _, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + _, \ + decode_q_prompt_lens, \ + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) + + #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + #attn_mask=causal_mask + ) + + original_query = query + #query = query.unsqueeze(0) + #key = key.unsqueeze(0) + #value = value.unsqueeze(0) + xops_out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=None, + p=0.0, + scale=scale) + + assert torch.allclose(ideal_output,xops_out) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + # Unlike self-attention: + # - Prefill slot-mapping includes all key slots + # - Decode slot-mapping is empty + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) + + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + xops_out = xops.memory_efficient_attention_forward( + prefill_packed_query.unsqueeze(0), + prefill_packed_key.unsqueeze(0), + prefill_packed_value.unsqueeze(0), + attn_bias=None, + p=0.0, + scale=scale) + xops_out=xops_out.view_as(prefill_packed_query) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,xops_out) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + is_prompt = False + context_lens = copy.deepcopy(kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + +#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -470,6 +585,18 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + # attn_bias = BlockDiagonalCausalMask.from_seqlens(prefill_q_prompt_lens) + # xops_out = xops.memory_efficient_attention_forward( + # prefill_packed_query.unsqueeze(0), + # prefill_packed_key.unsqueeze(0), + # prefill_packed_value.unsqueeze(0), + # attn_bias=attn_bias, + # p=0.0, + # scale=scale) + # xops_out=xops_out.view_as(prefill_packed_query) + + # assert torch.allclose(xops_out,prefill_packed_ideal_output[:,0,:]) + # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) @@ -484,6 +611,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -492,6 +620,111 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n @pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) @pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: + # Attention operator instance + is_cross_attn=False + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + #max_q_prompt_len = max_prompt_len + max_kv_prompt_len = max_q_prompt_len + context_lens = [0 for _ in range(batch_size)] + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + _, \ + _, \ + _, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + attn_mask=causal_mask + ) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + + + shorten_amt = 17 + new_ideal_output = ref_masked_attention( + query[:,:(prefill_q_prompt_lens[0]-shorten_amt),:,:], + key[:,:prefill_kv_prompt_lens[0],:,:], + value[:,:prefill_kv_prompt_lens[0],:,:], + scale=scale, + attn_mask=None #causal_mask[:(prefill_q_prompt_lens[0]-shorten_amt),:prefill_kv_prompt_lens[0]] + ) + + attn_bias = None #BlockDiagonalCausalMask.from_seqlens([(prefill_q_prompt_lens[0]-shorten_amt)],[prefill_kv_prompt_lens[0]]) + + xops_out = xops.memory_efficient_attention_forward( + prefill_packed_query.view(-1, num_heads, head_size)[:-shorten_amt,:,:].unsqueeze(0), + prefill_packed_key.view(-1, num_heads, head_size).unsqueeze(0), + prefill_packed_value.view(-1, num_heads, head_size).unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale) + #xops_out=xops_out.view_as(prefill_packed_query) + + assert torch.allclose(xops_out,new_ideal_output.view_as(xops_out)) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + is_prompt = False + context_lens = copy.deepcopy(prefill_kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + + decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + + +@pytest.mark.skip +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) +def test_prefill_decode_cross_attention_old(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True device='cuda:0' diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 21d828edefc78..46862d72cd7f9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,8 +420,7 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale, - op=xops.MemoryEfficientAttentionOp()) + scale=self.scale) return out.view_as(original_query) # Attention with alibi slopes. From 7880b0ea9ecf60d49802d9944a20466b404a7e1b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 16:31:39 -0400 Subject: [PATCH 019/443] cross-attention prefill works! --- tests/layer/test_self_and_cross_attn.py | 255 +++--------------------- vllm/attention/backends/xformers.py | 4 +- 2 files changed, 29 insertions(+), 230 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index ddf67c04bfef5..954e56d8bcc39 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -10,21 +10,12 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.ops.paged_attn import PagedAttention - from vllm.utils import make_tensor_with_pad from vllm.attention.layer import Attention import random -from xformers import ops as xops - -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalMask, - BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias) - # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64] @@ -60,7 +51,9 @@ def ref_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, - attn_mask: Optional[torch.Tensor] = None, + custom_mask: Optional[torch.Tensor] = None, + q_prompt_lens: Optional[List] = None, + kv_prompt_lens: Optional[List] = None ) -> torch.Tensor: #query=query.unsqueeze(-2) #key=key.unsqueeze(-2) @@ -68,8 +61,23 @@ def ref_masked_attention( #assert False,f"{query.shape} ; {key.shape}" attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() #assert False,f"{query.shape} ; {key.shape} ; {attn_weights.shape}" - if attn_mask is not None: + + # Lowest-level attention mask, derived from prompt lens + if (q_prompt_lens is not None) or (kv_prompt_lens is not None): + attn_mask = torch.zeros_like(attn_weights) + if q_prompt_lens is not None: + for bdx,plen in enumerate(q_prompt_lens): + attn_mask[bdx,:,plen:,:] = -torch.inf + if kv_prompt_lens is not None: + for bdx,plen in enumerate(kv_prompt_lens): + attn_mask[bdx,:,:,plen:] = -torch.inf + attn_weights = attn_weights + attn_mask.float() + + # Custom attention mask + if custom_mask is not None: + attn_weights = attn_weights + custom_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" @@ -412,114 +420,6 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) -def test_xops_memory_efficient_attention_forward_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int): - # Attention operator instance - is_cross_attn=True - device='cuda:0' - kv_cache_dtype='auto' - is_prompt = True - context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - - query, \ - key, \ - value, \ - prefill_query, \ - _, \ - _, \ - decode_query, \ - _, \ - _, \ - q_prompt_lens, \ - kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ - prefill_q_prompt_lens, \ - _, \ - decode_q_prompt_lens, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) - - #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - #attn_mask=causal_mask - ) - - original_query = query - #query = query.unsqueeze(0) - #key = key.unsqueeze(0) - #value = value.unsqueeze(0) - xops_out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=None, - p=0.0, - scale=scale) - - assert torch.allclose(ideal_output,xops_out) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - - # Unlike self-attention: - # - Prefill slot-mapping includes all key slots - # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) - - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) - - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) - - xops_out = xops.memory_efficient_attention_forward( - prefill_packed_query.unsqueeze(0), - prefill_packed_key.unsqueeze(0), - prefill_packed_value.unsqueeze(0), - attn_bias=None, - p=0.0, - scale=scale) - xops_out=xops_out.view_as(prefill_packed_query) - - # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,xops_out) - - # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) - - is_prompt = False - context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) - - decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) - - # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) - #@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -552,7 +452,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_key, \ decode_value, \ q_prompt_lens, \ - _, \ + kv_prompt_lens, \ _, \ _, \ prefill_q_prompt_lens, \ @@ -566,7 +466,9 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n key, value, scale=scale, - attn_mask=causal_mask + custom_mask=causal_mask, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens ) prefill_ideal_output = torch.zeros_like(ideal_output) @@ -609,9 +511,10 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) +#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -620,111 +523,6 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n @pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) @pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: - # Attention operator instance - is_cross_attn=False - device='cuda:0' - kv_cache_dtype='auto' - is_prompt = True - #max_q_prompt_len = max_prompt_len - max_kv_prompt_len = max_q_prompt_len - context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_prompt_lens, \ - _, \ - _, \ - _, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) - - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - attn_mask=causal_mask - ) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) - - prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) - - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) - - - - shorten_amt = 17 - new_ideal_output = ref_masked_attention( - query[:,:(prefill_q_prompt_lens[0]-shorten_amt),:,:], - key[:,:prefill_kv_prompt_lens[0],:,:], - value[:,:prefill_kv_prompt_lens[0],:,:], - scale=scale, - attn_mask=None #causal_mask[:(prefill_q_prompt_lens[0]-shorten_amt),:prefill_kv_prompt_lens[0]] - ) - - attn_bias = None #BlockDiagonalCausalMask.from_seqlens([(prefill_q_prompt_lens[0]-shorten_amt)],[prefill_kv_prompt_lens[0]]) - - xops_out = xops.memory_efficient_attention_forward( - prefill_packed_query.view(-1, num_heads, head_size)[:-shorten_amt,:,:].unsqueeze(0), - prefill_packed_key.view(-1, num_heads, head_size).unsqueeze(0), - prefill_packed_value.view(-1, num_heads, head_size).unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale) - #xops_out=xops_out.view_as(prefill_packed_query) - - assert torch.allclose(xops_out,new_ideal_output.view_as(xops_out)) - - # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) - - is_prompt = False - context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) - - decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) - - decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) - - # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) - - -@pytest.mark.skip -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) -def test_prefill_decode_cross_attention_old(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True device='cuda:0' @@ -762,7 +560,8 @@ def test_prefill_decode_cross_attention_old(num_heads: int, head_size: int, back key, value, scale=scale, - #attn_mask=causal_mask + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens ) prefill_ideal_output = torch.zeros_like(ideal_output) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 46862d72cd7f9..0ef80ca410355 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -392,8 +392,8 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: if attn_metadata.is_cross_attn: - attn_bias = None #BlockDiagonalMask.from_seqlens( - # attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens,attn_metadata.cross_seq_lens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) From 2ad68f1932d50a4cbd71068efa1b06df8f18eb54 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 16:35:17 -0400 Subject: [PATCH 020/443] reintroduced completely random Q/K sequence lengths --- tests/layer/test_self_and_cross_attn.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 954e56d8bcc39..42dec244b5c9c 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -25,7 +25,7 @@ NUM_HEADS = [1] -BATCH_SIZES = [2] +BATCH_SIZES = [16] BLOCK_SIZES = [16] #KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] @@ -89,7 +89,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: - q_prompt_lens = [random.randint(3, max_q_prompt_len) for _ in range(batch_size)] + q_prompt_lens = [random.randint(2, max_q_prompt_len) for _ in range(batch_size)] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention @@ -99,8 +99,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [min(q_prompt_len-1,max_kv_prompt_len) - for q_prompt_len in q_prompt_lens] + kv_prompt_lens = [random.randint(2, max_kv_prompt_len) for _ in range(batch_size)] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -420,7 +419,6 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) -#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -487,18 +485,6 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) - # attn_bias = BlockDiagonalCausalMask.from_seqlens(prefill_q_prompt_lens) - # xops_out = xops.memory_efficient_attention_forward( - # prefill_packed_query.unsqueeze(0), - # prefill_packed_key.unsqueeze(0), - # prefill_packed_value.unsqueeze(0), - # attn_bias=attn_bias, - # p=0.0, - # scale=scale) - # xops_out=xops_out.view_as(prefill_packed_query) - - # assert torch.allclose(xops_out,prefill_packed_ideal_output[:,0,:]) - # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) From 93e96d493c13484038f7874c7ee7e8fe4119751c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 17:21:29 -0400 Subject: [PATCH 021/443] cross-attention works for both prefill and decode! --- tests/layer/test_self_and_cross_attn.py | 4 +-- vllm/attention/backends/xformers.py | 35 +++++++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 42dec244b5c9c..d024b28cbf13b 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -499,8 +499,6 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) - -#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -568,7 +566,7 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query.contiguous(),prefill_packed_key.contiguous(),prefill_packed_value.contiguous(),kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0ef80ca410355..57316f9e0c967 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -266,20 +266,27 @@ def forward( shape = [num_tokens, num_heads * head_size] """ query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache is not None: + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + + if (kv_cache is not None): + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + if (key is not None) and (value is not None): + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -294,7 +301,7 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if not is_cross_attn: + if not is_cross_attn and key is not None and value is not None: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -339,8 +346,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + decode_meta.seq_lens_tensor if not is_cross_attn else torch.tensor(decode_meta.cross_seq_lens,dtype=decode_meta.seq_lens_tensor.dtype,device=decode_meta.seq_lens_tensor.device), + decode_meta.max_decode_seq_len if not is_cross_attn else max(decode_meta.cross_seq_lens), self.kv_cache_dtype, self.num_kv_heads, self.scale, From 5d91c94c990d531a059ce3e42b131b88d8c0f0bd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 19:01:31 -0400 Subject: [PATCH 022/443] test refactoring: adding function comments, removing unnecessary comments & arguments --- tests/layer/test_self_and_cross_attn.py | 222 ++++++++++++++---------- 1 file changed, 131 insertions(+), 91 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index d024b28cbf13b..751f6b3d15f5d 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -5,7 +5,7 @@ import pytest import torch import copy -from vllm.attention import Attention, AttentionMetadata #, AttentionMetadataPerStage +from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.abstract import AttentionBackend @@ -23,25 +23,33 @@ # [64, 80, 96, 112, 128, 256 # ] if not is_hip() else [64, 80, 96, 112, 128] -NUM_HEADS = [1] +NUM_HEADS = [16] BATCH_SIZES = [16] BLOCK_SIZES = [16] -#KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] -#CUDA_DEVICES = [ -# f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -#] +CUDA_DEVICE="cuda:0" PROMPT_LENS = [128] -Q_PROMPT_LENS = [129] +Q_PROMPT_LENS = [128] K_PROMPT_LENS = [128] -def build_causal_mask(q_max_prompt_len, k_max_prompt_len): +def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): + ''' + Create a q_max_prompt_len x kv_max_prompt_len causal mask + + Arguments: + * q_max_prompt_len: query max prompt len + * kv_max_prompt_len: key/value max prompt len + + Returns: + * 2D tensor, q_max_prompt_len x kv_max_prompt_len + ''' + # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_prompt_len, k_max_prompt_len), diagonal=1) #.transpose(0, 1) + mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), diagonal=1) # Replace True with float('-inf') and False with 0 mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) return mask @@ -55,14 +63,31 @@ def ref_masked_attention( q_prompt_lens: Optional[List] = None, kv_prompt_lens: Optional[List] = None ) -> torch.Tensor: - #query=query.unsqueeze(-2) - #key=key.unsqueeze(-2) - #value=value.unsqueeze(-2) - #assert False,f"{query.shape} ; {key.shape}" + ''' + "Golden" masked attention reference. Supports two types of masking: + * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. causal + + Arguments: + * query: batch_size x q_padded_seq_len x num_heads x head_size + * key: batch_size x kv_padded_seq_len x num_heads x head_size + * value: batch_size x kv_padded_seq_len x num_heads x head_size + * scale: Attention scale factor + * Custom mask: custom attention mask; good place to inject a causal attention mask + * q_prompt_lens: list of unpadded query seq_lens for each batch index + * kv_prompt_lens: list of unpadded key/value seq_lens for each batch index + + Returns: + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size + ''' + + batch_size = query.shape[0] + assert(len(q_prompt_lens) == batch_size) + assert(len(kv_prompt_lens) == batch_size) + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() - #assert False,f"{query.shape} ; {key.shape} ; {attn_weights.shape}" - # Lowest-level attention mask, derived from prompt lens + # Basic attention mask, derived from prompt lens if (q_prompt_lens is not None) or (kv_prompt_lens is not None): attn_mask = torch.zeros_like(attn_weights) if q_prompt_lens is not None: @@ -80,11 +105,48 @@ def ref_masked_attention( attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) - #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): - #assert max_kv_prompt_len >= max_q_prompt_len +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, is_cross_attn=True, force_max_len=False, device=CUDA_DEVICE): + ''' + Construct QKV test tensors for self- and cross-attention. + + Generates three query/key/value triplets: + * "Baseline" query/key/value (for input to reference attention function) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, for use as input to decode kernel) + + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v seqlens + + Arguments: + * batch_size + * max_q_prompt_len: max query prompt len + * max_kv_prompt_len: max key/value prompt len + * num_heads + * head_size + * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_prompt_len is unused) + * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens and max_kv_prompt_len, unless forced by is_cross_attn=False + * device: CPU or CUDA device + + Returns: + * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x head_size + * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x head_size + * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x head_size + * prefill_query: batch_size x (max_q_prompt_len-1) x num_heads x head_size + * prefill_key: batch_size x (max_kv_prompt_len-1) x num_heads x head_size + * prefill_value: batch_size x (max_kv_prompt_len-1) x num_heads x head_size + * decode_query: batch_size x 1 x num_heads x head_size + * decode_key: batch_size x 1 x num_heads x head_size + * decode_value: batch_size x 1 x num_heads x head_size + * q_prompt_lens: "baseline" query seqlen list + * kv_prompt_lens: "baseline" key/value seqlen list + * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= max_q_prompt_len due to randomness) + * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may be <= max_kv_prompt_len due to randomness) + * prefill_q_prompt_lens: "prefill" query seqlen list + * prefill_kv_prompt_lens: "prefill" key/value seqlen list + * decode_q_prompt_lens: "decode" query seqlen list (all ones) + * decode_kv_prompt_lens: "decode" key/value seqlen list + ''' if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] @@ -95,7 +157,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a # K,V prompt lens match Q for self-attention kv_prompt_lens = q_prompt_lens else: - # K,V prompt lens come from K,V operands + # K,V prompt lens are distinct from Q prompt lens & random if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: @@ -104,17 +166,17 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) - query=torch.rand((batch_size,max_q_prompt_len,head_size)).cuda() - key=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() - value=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() + query=torch.rand((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) + key=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) + value=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)).cuda() - prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() - prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() + prefill_query=torch.zeros((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) + prefill_key=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) + prefill_value=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - decode_query=torch.zeros((batch_size,1,head_size)).cuda() - decode_key=torch.zeros((batch_size,1,head_size)).cuda() - decode_value=torch.zeros((batch_size,1,head_size)).cuda() + decode_query=torch.zeros((batch_size,1,num_heads*head_size)).to(device) + decode_key=torch.zeros((batch_size,1,num_heads*head_size)).to(device) + decode_value=torch.zeros((batch_size,1,num_heads*head_size)).to(device) for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): query[bdx,q_prompt_len:,:] = 0 @@ -135,17 +197,17 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a decode_q_prompt_lens = [1 for _ in q_prompt_lens] decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] - query=query.unsqueeze(-2) - key=key.unsqueeze(-2) - value=value.unsqueeze(-2) + query=query.view(batch_size,query.shape[1],num_heads,head_size) + key=key.view(batch_size,key.shape[1],num_heads,head_size) + value=value.view(batch_size,value.shape[1],num_heads,head_size) - prefill_query=prefill_query.unsqueeze(-2) - prefill_key=prefill_key.unsqueeze(-2) - prefill_value=prefill_value.unsqueeze(-2) + prefill_query=prefill_query.view(batch_size,prefill_query.shape[1],num_heads,head_size) + prefill_key=prefill_key.view(batch_size,prefill_key.shape[1],num_heads,head_size) + prefill_value=prefill_value.view(batch_size,prefill_value.shape[1],num_heads,head_size) - decode_query=decode_query.unsqueeze(-2) - decode_key=decode_key.unsqueeze(-2) - decode_value=decode_value.unsqueeze(-2) + decode_query=decode_query.view(batch_size,decode_query.shape[1],num_heads,head_size) + decode_key=decode_key.view(batch_size,decode_key.shape[1],num_heads,head_size) + decode_value=decode_value.view(batch_size,decode_value.shape[1],num_heads,head_size) return query, \ key, \ @@ -165,7 +227,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a decode_q_prompt_lens, \ decode_kv_prompt_lens -def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): +def pack_tensor(unpacked_tensor,prompt_lens, device=CUDA_DEVICE): num_tok = sum(prompt_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] @@ -195,7 +257,7 @@ def make_backend(backend_name: str) -> AttentionBackend: return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" -def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> tuple: +def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:List[int], device=CUDA_DEVICE) -> tuple: ''' Assumptions: * No chunked prefill @@ -239,9 +301,7 @@ def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cros seq_start_loc, \ query_start_loc -def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): - #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) - #val_cache = torch.rand((num_blocks, num_heads, head_size, block_size),device=device) +def make_kv_cache(num_blocks, num_heads, head_size, block_size, device=CUDA_DEVICE, default_val=0.0): kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(device) if default_val is not None: kv_cache[:,:,:] = default_val @@ -250,18 +310,16 @@ def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, def num_tokens_to_min_blocks(num_tokens,block_size): return (num_tokens+block_size)//block_size -def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): +def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): ''' Naive block table: * For each batch element... * Block table has ''' - num_prompts = len(prompt_lens) - total_num_tokens = sum(prompt_lens) + # Over-provision block table blocks by 1 num_blocks_list = [num_tokens_to_min_blocks(num_tokens,block_size)+1 for num_tokens in prompt_lens] max_block_table_len = max(num_blocks_list) - #block_tables = [list(range(num_blocks*10)) for num_blocks in num_blocks_list] block_table_pad_tokens = 10 block_tables = [] @@ -269,9 +327,7 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): decode_slot_mapping = [] slot_mapping = [] block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed - #seq_base_idx = 0 for sdx,num_tokens in enumerate(prompt_lens): - #num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) num_blocks = num_blocks_list[sdx] block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) for idx in range(num_tokens-1): @@ -281,13 +337,12 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - #seq_base_idx += num_tokens block_base_idx -= num_blocks block_tables.append(block_table) prefill_block_tables_tensor = torch.tensor( [], - device='cuda:0' + device=CUDA_DEVICE ) decode_block_tables_tensor = make_tensor_with_pad( block_tables, @@ -320,7 +375,7 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor -def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): +def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device=CUDA_DEVICE, cross_prompt_lens:Optional[List[int]] = None): ''' Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both @@ -334,17 +389,13 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b prompt_lens_tensor, \ context_lens_tensor, \ max_query_len, \ - max_context_len, \ - max_prompt_len, \ + _, \ + _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(attn_backend, - is_prompt, - is_cross_attn, + query_start_loc = make_metadata_tensors(is_prompt, prompt_lens, context_lens, - block_tables, - device=device, - cross_prompt_lens=cross_prompt_lens) + device=device) slot_mapping_tensor=torch.tensor(slot_mapping, dtype=torch.long, @@ -378,17 +429,13 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b prompt_lens_tensor, \ context_lens_tensor, \ max_query_len, \ - max_context_len, \ - max_prompt_len, \ + _, \ + _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(attn_backend, - is_prompt, - is_cross_attn, + query_start_loc = make_metadata_tensors(is_prompt, prompt_lens, context_lens, - block_tables, - device=device, - cross_prompt_lens=cross_prompt_lens) + device=device) slot_mapping_tensor=torch.tensor(slot_mapping, dtype=torch.long, @@ -428,15 +475,12 @@ def make_attention(num_heads: int, head_size: int, scale: float): def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_prompt_len: int) -> None: # Attention operator instance is_cross_attn=False - device='cuda:0' - kv_cache_dtype='auto' is_prompt = True max_q_prompt_len = max_prompt_len max_kv_prompt_len = max_q_prompt_len context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -456,9 +500,9 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_q_prompt_lens, \ prefill_kv_prompt_lens, \ decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) ideal_output = ref_masked_attention( query, key, @@ -479,25 +523,25 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=None) prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping) decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -509,13 +553,10 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True - device='cuda:0' - kv_cache_dtype='auto' is_prompt = True context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -531,14 +572,13 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ _, \ q_prompt_lens, \ kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ + _, \ + _, \ prefill_q_prompt_lens, \ _, \ decode_q_prompt_lens, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=is_cross_attn) - #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) ideal_output = ref_masked_attention( query, key, @@ -562,22 +602,22 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ # - Decode slot-mapping is empty decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=kv_prompt_lens) - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) + prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, cross_prompt_lens=kv_prompt_lens) - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) + decode_packed_query,_,_,_,_ = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) \ No newline at end of file + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) \ No newline at end of file From 7cc88f780598d86a51563b20a1acf552eddc1afd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 19:08:02 -0400 Subject: [PATCH 023/443] formatting --- tests/layer/test_self_and_cross_attn.py | 556 +++++++++++++++--------- vllm/attention/backends/xformers.py | 28 +- 2 files changed, 361 insertions(+), 223 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 751f6b3d15f5d..18afff46d8371 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -18,7 +18,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64] +HEAD_SIZES = [64] # [64, 80, 96, 112, 128, 256 # ] if not is_hip() else [64, 80, 96, 112, 128] @@ -28,7 +28,7 @@ BATCH_SIZES = [16] BLOCK_SIZES = [16] BACKEND_NAMES = ["xformers"] -CUDA_DEVICE="cuda:0" +CUDA_DEVICE = "cuda:0" PROMPT_LENS = [128] @@ -36,6 +36,7 @@ K_PROMPT_LENS = [128] + def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): ''' Create a q_max_prompt_len x kv_max_prompt_len causal mask @@ -49,20 +50,22 @@ def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): ''' # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), diagonal=1) + mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), + diagonal=1) # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) return mask + def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_prompt_lens: Optional[List] = None, - kv_prompt_lens: Optional[List] = None -) -> torch.Tensor: + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_prompt_lens: Optional[List] = None, + kv_prompt_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out padding elements @@ -82,8 +85,8 @@ def ref_masked_attention( ''' batch_size = query.shape[0] - assert(len(q_prompt_lens) == batch_size) - assert(len(kv_prompt_lens) == batch_size) + assert (len(q_prompt_lens) == batch_size) + assert (len(kv_prompt_lens) == batch_size) attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() @@ -91,11 +94,11 @@ def ref_masked_attention( if (q_prompt_lens is not None) or (kv_prompt_lens is not None): attn_mask = torch.zeros_like(attn_weights) if q_prompt_lens is not None: - for bdx,plen in enumerate(q_prompt_lens): - attn_mask[bdx,:,plen:,:] = -torch.inf + for bdx, plen in enumerate(q_prompt_lens): + attn_mask[bdx, :, plen:, :] = -torch.inf if kv_prompt_lens is not None: - for bdx,plen in enumerate(kv_prompt_lens): - attn_mask[bdx,:,:,plen:] = -torch.inf + for bdx, plen in enumerate(kv_prompt_lens): + attn_mask[bdx, :, :, plen:] = -torch.inf attn_weights = attn_weights + attn_mask.float() @@ -107,7 +110,15 @@ def ref_masked_attention( out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, is_cross_attn=True, force_max_len=False, device=CUDA_DEVICE): + +def make_qkv(batch_size, + max_q_prompt_len, + max_kv_prompt_len, + num_heads, + head_size, + is_cross_attn=True, + force_max_len=False, + device=CUDA_DEVICE): ''' Construct QKV test tensors for self- and cross-attention. @@ -151,7 +162,9 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: - q_prompt_lens = [random.randint(2, max_q_prompt_len) for _ in range(batch_size)] + q_prompt_lens = [ + random.randint(2, max_q_prompt_len) for _ in range(batch_size) + ] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention @@ -161,35 +174,53 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [random.randint(2, max_kv_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = [ + random.randint(2, max_kv_prompt_len) for _ in range(batch_size) + ] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) - query=torch.rand((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) - key=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - value=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - - prefill_query=torch.zeros((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) - prefill_key=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - prefill_value=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - - decode_query=torch.zeros((batch_size,1,num_heads*head_size)).to(device) - decode_key=torch.zeros((batch_size,1,num_heads*head_size)).to(device) - decode_value=torch.zeros((batch_size,1,num_heads*head_size)).to(device) - - for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): - query[bdx,q_prompt_len:,:] = 0 - key[bdx,kv_prompt_len:,:] = 0 - value[bdx,kv_prompt_len:,:] = 0 - - prefill_query[bdx,0:(q_prompt_len-1),:] = query[bdx,0:(q_prompt_len-1),:] - prefill_key[bdx,0:(kv_prompt_len-1),:] = key[bdx,0:(kv_prompt_len-1),:] - prefill_value[bdx,0:(kv_prompt_len-1),:] = value[bdx,0:(kv_prompt_len-1),:] - - decode_query[bdx,:,:] = query[bdx,(q_prompt_len-1):q_prompt_len,:] - decode_key[bdx,:,:] = key[bdx,(kv_prompt_len-1):kv_prompt_len,:] - decode_value[bdx,:,:] = value[bdx,(kv_prompt_len-1):kv_prompt_len,:] + query = torch.rand( + (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + key = torch.rand( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + value = torch.rand( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + + prefill_query = torch.zeros( + (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + prefill_key = torch.zeros( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + prefill_value = torch.zeros( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + + decode_query = torch.zeros( + (batch_size, 1, num_heads * head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads * head_size)).to(device) + decode_value = torch.zeros( + (batch_size, 1, num_heads * head_size)).to(device) + + for bdx, (q_prompt_len, + kv_prompt_len) in enumerate(zip(q_prompt_lens, kv_prompt_lens)): + query[bdx, q_prompt_len:, :] = 0 + key[bdx, kv_prompt_len:, :] = 0 + value[bdx, kv_prompt_len:, :] = 0 + + prefill_query[bdx, + 0:(q_prompt_len - 1), :] = query[bdx, + 0:(q_prompt_len - 1), :] + prefill_key[bdx, + 0:(kv_prompt_len - 1), :] = key[bdx, + 0:(kv_prompt_len - 1), :] + prefill_value[bdx, 0:(kv_prompt_len - + 1), :] = value[bdx, 0:(kv_prompt_len - 1), :] + + decode_query[bdx, :, :] = query[bdx, + (q_prompt_len - 1):q_prompt_len, :] + decode_key[bdx, :, :] = key[bdx, (kv_prompt_len - 1):kv_prompt_len, :] + decode_value[bdx, :, :] = value[bdx, + (kv_prompt_len - 1):kv_prompt_len, :] prefill_q_prompt_lens = [plen - 1 for plen in q_prompt_lens] prefill_kv_prompt_lens = [plen - 1 for plen in kv_prompt_lens] @@ -197,17 +228,23 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, decode_q_prompt_lens = [1 for _ in q_prompt_lens] decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] - query=query.view(batch_size,query.shape[1],num_heads,head_size) - key=key.view(batch_size,key.shape[1],num_heads,head_size) - value=value.view(batch_size,value.shape[1],num_heads,head_size) + query = query.view(batch_size, query.shape[1], num_heads, head_size) + key = key.view(batch_size, key.shape[1], num_heads, head_size) + value = value.view(batch_size, value.shape[1], num_heads, head_size) - prefill_query=prefill_query.view(batch_size,prefill_query.shape[1],num_heads,head_size) - prefill_key=prefill_key.view(batch_size,prefill_key.shape[1],num_heads,head_size) - prefill_value=prefill_value.view(batch_size,prefill_value.shape[1],num_heads,head_size) + prefill_query = prefill_query.view(batch_size, prefill_query.shape[1], + num_heads, head_size) + prefill_key = prefill_key.view(batch_size, prefill_key.shape[1], num_heads, + head_size) + prefill_value = prefill_value.view(batch_size, prefill_value.shape[1], + num_heads, head_size) - decode_query=decode_query.view(batch_size,decode_query.shape[1],num_heads,head_size) - decode_key=decode_key.view(batch_size,decode_key.shape[1],num_heads,head_size) - decode_value=decode_value.view(batch_size,decode_value.shape[1],num_heads,head_size) + decode_query = decode_query.view(batch_size, decode_query.shape[1], + num_heads, head_size) + decode_key = decode_key.view(batch_size, decode_key.shape[1], num_heads, + head_size) + decode_value = decode_value.view(batch_size, decode_value.shape[1], + num_heads, head_size) return query, \ key, \ @@ -227,51 +264,62 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, decode_q_prompt_lens, \ decode_kv_prompt_lens -def pack_tensor(unpacked_tensor,prompt_lens, device=CUDA_DEVICE): + +def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): num_tok = sum(prompt_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] - start_loc_list = [0]+list(itertools.accumulate(prompt_lens)) - packed_tensor = torch.zeros((num_tok,num_heads,head_size), - device=device) + start_loc_list = [0] + list(itertools.accumulate(prompt_lens)) + packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - for bdx,(prompt_len,start_loc) in enumerate(zip(prompt_lens,start_loc_list)): + for bdx, (prompt_len, + start_loc) in enumerate(zip(prompt_lens, start_loc_list)): try: - packed_tensor[start_loc:(start_loc+prompt_len),:,:] = unpacked_tensor[bdx,:prompt_len,:,:] + packed_tensor[start_loc:( + start_loc + + prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] except: assert False, f"{start_loc} ; {prompt_len} ; {packed_tensor.shape} ; {unpacked_tensor.shape}" - return packed_tensor,start_loc_list - -def pack_qkv(query,key,value,q_prompt_lens,kv_prompt_lens): - packed_query,q_start_loc_list = pack_tensor(query,q_prompt_lens) - packed_key,kv_start_loc_list = pack_tensor(key,kv_prompt_lens) - packed_value,_ = pack_tensor(value,kv_prompt_lens) - packed_query=packed_query.view(-1,packed_query.shape[-1]*packed_query.shape[-2]) - packed_key=packed_key.view(-1,packed_key.shape[-1]*packed_key.shape[-2]) - packed_value=packed_value.view(-1,packed_value.shape[-1]*packed_value.shape[-2]) - return packed_query,packed_key,packed_value,q_start_loc_list,kv_start_loc_list + return packed_tensor, start_loc_list + + +def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): + packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) + packed_key, kv_start_loc_list = pack_tensor(key, kv_prompt_lens) + packed_value, _ = pack_tensor(value, kv_prompt_lens) + packed_query = packed_query.view( + -1, packed_query.shape[-1] * packed_query.shape[-2]) + packed_key = packed_key.view(-1, + packed_key.shape[-1] * packed_key.shape[-2]) + packed_value = packed_value.view( + -1, packed_value.shape[-1] * packed_value.shape[-2]) + return packed_query, packed_key, packed_value, q_start_loc_list, kv_start_loc_list + def make_backend(backend_name: str) -> AttentionBackend: if backend_name == "xformers": return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" -def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:List[int], device=CUDA_DEVICE) -> tuple: + +def make_metadata_tensors(is_prompt: bool, + prompt_lens: List[int], + context_lens: List[int], + device=CUDA_DEVICE) -> tuple: ''' Assumptions: * No chunked prefill * No (automatic) prefix caching * Packed variable-length sequences ''' - prompt_lens_tensor=torch.tensor(prompt_lens, - dtype=torch.int, - device=device) - context_lens_tensor=None if context_lens is None else torch.tensor(context_lens, - dtype=torch.int, - device=device) - max_context_len=None if context_lens is None else max(context_lens) - max_prompt_len=None if prompt_lens is None else max(prompt_lens) + prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.int, + device=device) + context_lens_tensor = None if context_lens is None else torch.tensor( + context_lens, dtype=torch.int, device=device) + max_context_len = None if context_lens is None else max(context_lens) + max_prompt_len = None if prompt_lens is None else max(prompt_lens) seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, dtype=torch.int32, @@ -285,7 +333,7 @@ def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:Li if is_prompt: # Prefill: query_start_loc matches seq_start_loc query_start_loc = copy.deepcopy(seq_start_loc) - max_query_len=max_prompt_len + max_query_len = max_prompt_len else: # Decode: one new query input token per batch # element, thus query_start_loc is the cumsum @@ -301,16 +349,27 @@ def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:Li seq_start_loc, \ query_start_loc -def make_kv_cache(num_blocks, num_heads, head_size, block_size, device=CUDA_DEVICE, default_val=0.0): - kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(device) + +def make_kv_cache(num_blocks, + num_heads, + head_size, + block_size, + device=CUDA_DEVICE, + default_val=0.0): + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) if default_val is not None: - kv_cache[:,:,:] = default_val + kv_cache[:, :, :] = default_val return kv_cache -def num_tokens_to_min_blocks(num_tokens,block_size): - return (num_tokens+block_size)//block_size -def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): +def num_tokens_to_min_blocks(num_tokens, block_size): + return (num_tokens + block_size) // block_size + + +def make_block_tables_slot_mapping(block_size, + prompt_lens, + device=CUDA_DEVICE): ''' Naive block table: * For each batch element... @@ -318,7 +377,10 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): ''' # Over-provision block table blocks by 1 - num_blocks_list = [num_tokens_to_min_blocks(num_tokens,block_size)+1 for num_tokens in prompt_lens] + num_blocks_list = [ + num_tokens_to_min_blocks(num_tokens, block_size) + 1 + for num_tokens in prompt_lens + ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -326,56 +388,60 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed - for sdx,num_tokens in enumerate(prompt_lens): + block_base_idx = sum( + num_blocks_list) * 2 - 1 # Support more blocks than needed + for sdx, num_tokens in enumerate(prompt_lens): num_blocks = num_blocks_list[sdx] - block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) - for idx in range(num_tokens-1): - prefill_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - idx = num_tokens-1 - decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + block_table = list( + range(block_base_idx, block_base_idx - num_blocks, -1)) + for idx in range(num_tokens - 1): + prefill_slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * + block_size) + slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * block_size) + idx = num_tokens - 1 + decode_slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * block_size) + slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * block_size) block_base_idx -= num_blocks block_tables.append(block_table) - - prefill_block_tables_tensor = torch.tensor( - [], - device=CUDA_DEVICE - ) + + prefill_block_tables_tensor = torch.tensor([], device=CUDA_DEVICE) decode_block_tables_tensor = make_tensor_with_pad( block_tables, - max_len=max_block_table_len+block_table_pad_tokens, + max_len=max_block_table_len + block_table_pad_tokens, pad=0, dtype=torch.int, device=device, ) - prefill_slot_mapping_tensor = torch.tensor( - prefill_slot_mapping, - dtype=torch.long, - device=device - ) - decode_slot_mapping_tensor = torch.tensor( - decode_slot_mapping, - dtype=torch.long, - device=device - ) - slot_mapping_tensor = torch.tensor( - slot_mapping, - dtype=torch.long, - device=device - ) - empty_slot_mapping_tensor = torch.tensor( - [], - dtype=torch.long, - device=device - ) + prefill_slot_mapping_tensor = torch.tensor(prefill_slot_mapping, + dtype=torch.long, + device=device) + decode_slot_mapping_tensor = torch.tensor(decode_slot_mapping, + dtype=torch.long, + device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + empty_slot_mapping_tensor = torch.tensor([], + dtype=torch.long, + device=device) return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor - - -def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device=CUDA_DEVICE, cross_prompt_lens:Optional[List[int]] = None): + + +def make_metadata(attn_backend: AttentionBackend, + is_prompt: bool, + is_cross_attn: bool, + prompt_lens: List[int], + context_lens: List[int], + block_tables, + slot_mapping, + device=CUDA_DEVICE, + cross_prompt_lens: Optional[List[int]] = None): ''' Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both @@ -392,14 +458,14 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b _, \ _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, device=device) - slot_mapping_tensor=torch.tensor(slot_mapping, - dtype=torch.long, - device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -417,10 +483,9 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b block_tables=block_tables, use_cuda_graph=False, is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens - ) + cross_seq_lens=cross_prompt_lens) - else: # not is_prompt + else: # not is_prompt num_prefills = 0 num_prefill_tokens = 0 @@ -432,14 +497,14 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b _, \ _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, device=device) - slot_mapping_tensor=torch.tensor(slot_mapping, - dtype=torch.long, - device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -457,30 +522,36 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b block_tables=block_tables, use_cuda_graph=False, is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens - ) + cross_seq_lens=cross_prompt_lens) + def make_attention(num_heads: int, head_size: int, scale: float): # Attention operator instance - return Attention(num_heads, - head_size, - scale=scale,) + return Attention( + num_heads, + head_size, + scale=scale, + ) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_prompt_len",PROMPT_LENS) -def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_prompt_len: int) -> None: +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_prompt_len", PROMPT_LENS) +def test_prefill_decode_self_attention(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, + max_prompt_len: int) -> None: # Attention operator instance - is_cross_attn=False + is_cross_attn = False is_prompt = True max_q_prompt_len = max_prompt_len max_kv_prompt_len = max_q_prompt_len context_lens = [0 for _ in range(batch_size)] num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -502,61 +573,94 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_q_prompt_lens, \ decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - custom_mask=causal_mask, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens - ) + causal_mask = build_causal_mask(max_q_prompt_len, + max_kv_prompt_len).to(CUDA_DEVICE) + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + custom_mask=causal_mask, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=None) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) + + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping( + block_size, q_prompt_lens) + prefill_attn_metadata: AttentionMetadata = make_metadata( + attn_backend, + is_prompt, + is_cross_attn, + prefill_q_prompt_lens, + context_lens, + prefill_block_tables, + prefill_slot_mapping, + cross_prompt_lens=None) - prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, + prefill_kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output = attn.forward(prefill_packed_query, + prefill_packed_key, + prefill_packed_value, kv_cache, + prefill_attn_metadata, scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) + assert torch.allclose( + prefill_packed_actual_output, + prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping) - - decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, + is_cross_attn, q_prompt_lens, + context_lens, decode_block_tables, + decode_slot_mapping) + + decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( + decode_query, decode_key, decode_value, decode_q_prompt_lens, + decode_kv_prompt_lens) - decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + decode_packed_actual_output = attn.forward(decode_packed_query, + decode_packed_key, + decode_packed_value, kv_cache, + decode_attn_metadata, scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) + assert torch.allclose( + decode_packed_actual_output, + decode_packed_ideal_output.view_as(decode_packed_actual_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) -def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) +def test_prefill_decode_cross_attention(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_prompt_len: int, + max_kv_prompt_len: int) -> None: # Attention operator instance - is_cross_attn=True + is_cross_attn = True is_prompt = True context_lens = [0 for _ in range(batch_size)] num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -579,45 +683,75 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ decode_q_prompt_lens, \ _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=is_cross_attn) - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens - ) + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) # Unlike self-attention: # - Prefill slot-mapping includes all key slots # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) - - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=kv_prompt_lens) + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping( + block_size, kv_prompt_lens) + + prefill_attn_metadata: AttentionMetadata = make_metadata( + attn_backend, + is_prompt, + is_cross_attn, + prefill_q_prompt_lens, + context_lens, + prefill_block_tables, + prefill_slot_mapping, + cross_prompt_lens=kv_prompt_lens) - prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) + prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_query, key, value, prefill_q_prompt_lens, kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output = attn.forward(prefill_packed_query, + prefill_packed_key, + prefill_packed_value, kv_cache, + prefill_attn_metadata, scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) + assert torch.allclose( + prefill_packed_actual_output, + prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, cross_prompt_lens=kv_prompt_lens) - - decode_packed_query,_,_,_,_ = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) - - decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) + decode_attn_metadata = make_metadata(attn_backend, + is_prompt, + is_cross_attn, + q_prompt_lens, + context_lens, + decode_block_tables, + decode_slot_mapping, + cross_prompt_lens=kv_prompt_lens) + + decode_packed_query, _, _, _, _ = pack_qkv(decode_query, key, value, + decode_q_prompt_lens, + kv_prompt_lens) + + decode_packed_actual_output = attn.forward(decode_packed_query, None, None, + kv_cache, decode_attn_metadata, + scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) \ No newline at end of file + assert torch.allclose( + decode_packed_actual_output, + decode_packed_ideal_output.view_as(decode_packed_actual_output)) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 57316f9e0c967..3dfe363cbe7f2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -4,8 +4,7 @@ import torch from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalMask, +from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalMask, BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) @@ -155,8 +154,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens - ) + cross_seq_lens=self.cross_seq_lens) return self._cached_prefill_metadata @property @@ -185,8 +183,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens - ) + cross_seq_lens=self.cross_seq_lens) return self._cached_decode_metadata @@ -286,14 +283,17 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + self.kv_cache_dtype, + kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens is_cross_attn = attn_metadata.is_cross_attn - assert is_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert is_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) + assert is_cross_attn or (key.shape[0] + == num_prefill_tokens + num_decode_tokens) + assert is_cross_attn or (value.shape[0] + == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. @@ -346,8 +346,12 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seq_lens_tensor if not is_cross_attn else torch.tensor(decode_meta.cross_seq_lens,dtype=decode_meta.seq_lens_tensor.dtype,device=decode_meta.seq_lens_tensor.device), - decode_meta.max_decode_seq_len if not is_cross_attn else max(decode_meta.cross_seq_lens), + decode_meta.seq_lens_tensor if not is_cross_attn else + torch.tensor(decode_meta.cross_seq_lens, + dtype=decode_meta.seq_lens_tensor.dtype, + device=decode_meta.seq_lens_tensor.device), + decode_meta.max_decode_seq_len + if not is_cross_attn else max(decode_meta.cross_seq_lens), self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -400,7 +404,7 @@ def _run_memory_efficient_xformers_forward( if self.alibi_slopes is None: if attn_metadata.is_cross_attn: attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) From 86591214f2116d04b38688cc691e93b1ecd33c71 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 22:44:15 -0400 Subject: [PATCH 024/443] Self & cross attention tests pass with new cross-compatible attention metadata structure! --- tests/layer/test_self_and_cross_attn.py | 430 +++++++++++++++++++++++- vllm/attention/backends/xformers.py | 288 ++++++++++++---- 2 files changed, 647 insertions(+), 71 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 18afff46d8371..5f1e35ceb21d7 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -285,11 +285,16 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): - packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) + if query is None: + packed_query = None + q_start_loc_list = None + else: + packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) packed_key, kv_start_loc_list = pack_tensor(key, kv_prompt_lens) packed_value, _ = pack_tensor(value, kv_prompt_lens) - packed_query = packed_query.view( - -1, packed_query.shape[-1] * packed_query.shape[-2]) + if packed_query is not None: + packed_query = packed_query.view( + -1, packed_query.shape[-1] * packed_query.shape[-2]) packed_key = packed_key.view(-1, packed_key.shape[-1] * packed_key.shape[-2]) packed_value = packed_value.view( @@ -369,6 +374,7 @@ def num_tokens_to_min_blocks(num_tokens, block_size): def make_block_tables_slot_mapping(block_size, prompt_lens, + block_base_addr=0, device=CUDA_DEVICE): ''' Naive block table: @@ -388,8 +394,8 @@ def make_block_tables_slot_mapping(block_size, prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = sum( - num_blocks_list) * 2 - 1 # Support more blocks than needed + block_base_idx = block_base_addr + sum(num_blocks_list) * 2 - 1 # Support more blocks than needed + max_block_idx = block_base_idx for sdx, num_tokens in enumerate(prompt_lens): num_blocks = num_blocks_list[sdx] block_table = list( @@ -430,7 +436,7 @@ def make_block_tables_slot_mapping(block_size, dtype=torch.long, device=device) - return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor + return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx def make_metadata(attn_backend: AttentionBackend, @@ -525,6 +531,102 @@ def make_metadata(attn_backend: AttentionBackend, cross_seq_lens=cross_prompt_lens) +def make_metadata_self_cross(attn_backend: AttentionBackend, + is_prompt: bool, + prompt_lens: List[int], + context_lens: List[int], + block_tables, + slot_mapping, + device=CUDA_DEVICE, + cross_seq_lens: Optional[List[int]] = None, + cross_block_tables: Optional[torch.Tensor] = None, + cross_slot_mapping: Optional[List[int]] = None,): + ''' + Assumptions: + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + ''' + + if is_prompt: + num_prefills = len(prompt_lens) + num_prefill_tokens = sum(prompt_lens) + num_decode_tokens = 0 + + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + _, \ + _, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, + device=device) + + slot_mapping_tensor = slot_mapping + + cross_slot_mapping_tensor = cross_slot_mapping + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max(prompt_lens), + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + is_cross_attn=False, + cross_seq_lens=cross_seq_lens, + cross_slot_mapping=cross_slot_mapping_tensor, + cross_block_tables=cross_block_tables) + + else: # not is_prompt + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = len(prompt_lens) + + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + _, \ + _, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, + device=device) + + slot_mapping_tensor = slot_mapping + + cross_slot_mapping_tensor = cross_slot_mapping + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=max(prompt_lens), + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + is_cross_attn=False, + cross_seq_lens=cross_seq_lens, + cross_slot_mapping=cross_slot_mapping_tensor, + cross_block_tables=cross_block_tables) + def make_attention(num_heads: int, head_size: int, scale: float): # Attention operator instance return Attention( @@ -534,6 +636,322 @@ def make_attention(num_heads: int, head_size: int, scale: float): ) +def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): + scale = float(1.0 / (head_size**0.5)) + attn_backend = make_backend(backend_name) + attn = make_attention(num_heads, head_size, scale) + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + return scale, attn_backend, attn, kv_cache + +def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, block_base_addr=0): + + max_kv_prompt_len = max_q_prompt_len + + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + kv_prompt_lens, \ + _, \ + _, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) + + causal_mask = build_causal_mask(max_q_prompt_len, + max_kv_prompt_len).to(CUDA_DEVICE) + + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + custom_mask=causal_mask, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) + + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _, max_block_idx = make_block_tables_slot_mapping( + block_size, q_prompt_lens, block_base_addr=block_base_addr) + + prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, + prefill_kv_prompt_lens) + + decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( + decode_query, decode_key, decode_value, decode_q_prompt_lens, + decode_kv_prompt_lens) + + return query, \ + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, \ + prefill_packed_ideal_output, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_packed_query, \ + decode_packed_key, \ + decode_packed_value, \ + decode_packed_ideal_output, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens, \ + q_prompt_lens, \ + kv_prompt_lens, \ + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + max_block_idx + + +def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, max_kv_prompt_len, block_base_addr=0): + + _, \ + key, \ + value, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _, \ + kv_prompt_lens, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=True) + + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) + + # Unlike self-attention: + # - Prefill slot-mapping includes all key slots + # - Decode slot-mapping is empty + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( + block_size, kv_prompt_lens, block_base_addr=block_base_addr) + + _, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + None, key, value, prefill_q_prompt_lens, kv_prompt_lens) + + return prefill_packed_key, \ + prefill_packed_value, \ + prefill_packed_ideal_output, \ + decode_packed_ideal_output, \ + kv_prompt_lens, \ + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + max_block_idx + +def run_self_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + attn_metadata.do_cross_attn = False + return attn.forward(packed_query, + packed_key, + packed_value, + kv_cache, + attn_metadata, + scale) + +def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + attn_metadata.do_cross_attn = True + return attn.forward(packed_query, + packed_key, + packed_value, + kv_cache, + attn_metadata, + scale) + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) +def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_prompt_len: int, + max_kv_prompt_len: int) -> None: + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, + # attention backend instance, + # attention wrapper instance, + # KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr=0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_prompt_lens, \ + self_prefill_kv_prompt_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + decode_q_prompt_lens, \ + self_decode_kv_prompt_lens, \ + q_prompt_lens, \ + self_kv_prompt_lens, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_prompt_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + final_max_block_idx = cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, + True, + prefill_q_prompt_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + cross_seq_lens = cross_kv_prompt_lens, + cross_block_tables = cross_prefill_block_tables, + cross_slot_mapping = cross_prefill_slot_mapping,) + + self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test(attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + scale) + + # - Prefill self-attention correct? + assert torch.allclose(self_prefill_packed_ideal_output,self_prefill_packed_actual_output.view_as(self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, + prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + scale) + + # - Prefill cross-attention correct? + assert torch.allclose(cross_prefill_packed_ideal_output,cross_prefill_packed_actual_output.view_as(cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_prompt_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, + False, + q_prompt_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + cross_seq_lens = cross_kv_prompt_lens, + cross_block_tables = cross_decode_block_tables, + cross_slot_mapping = cross_decode_slot_mapping,) + + self_decode_packed_actual_output: torch.Tensor = run_self_attention_test(attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + scale) + + assert torch.allclose(self_decode_packed_ideal_output,self_decode_packed_actual_output.view_as(self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, + decode_packed_query, + None, + None, + kv_cache, + decode_attn_metadata, + scale) + + assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3dfe363cbe7f2..5b6d2ac0e144f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -105,16 +105,36 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - _cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cached_decode_metadata: Optional["XFormersMetadata"] = None - # Need to make KV cache read-only for cross-attention + # Self-attention prefill/decode metadata cache + _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None + _self_cached_decode_metadata: Optional["XFormersMetadata"] = None + # Cross-attention prefill/decode metadata cache + _cross_cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin cross-attention fields... + + # If True, prefill_metadata() and decode_metadata() will return + # seqlen & memory-mapping data structures for cross-attention; + # otherwise, self-attention data structures will be returned. is_cross_attn: bool = False # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention # computation. None if this is self-attention cross_seq_lens: Optional[List[int]] = None + cross_seq_lens_tensor: Optional[torch.Tensor] = None + + # The maximum cross-sequence-length, if cross_seq_lens is specified. + # Note that for cross-attention there is no difference in key/value + # sequence length between prefill and decode + max_cross_seq_len: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -124,67 +144,183 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None + @property + def has_valid_cross_attn_metadata(self): + # No cross-attention metadata is present whatsoever + no_md = (self.cross_seq_lens is None) and (self.cross_slot_mapping is None) and (self.cross_block_tables is None) + # If any cross-attention metadata is present, it is invalid + invalid_md_if_not_no_md = (self.cross_seq_lens is None) or (self.cross_slot_mapping is None) or (self.cross_block_tables is None) + + if no_md: + return False + + assert (not invalid_md_if_not_no_md), "Invalid cross-attention metadata" + + return True + + @property + def do_cross_attn(self): + return self.is_cross_attn + + @do_cross_attn.setter + def do_cross_attn(self,state:bool): + + if state: + assert self.has_valid_cross_attn_metadata, "Must have self.cross_seq_lens not None in order to enable cross-attention" + + # Infer implicit cross-attention fields from user-provided fields, if needed + if self.cross_seq_lens_tensor is None: + self.cross_seq_lens_tensor = torch.tensor(self.cross_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) + if self.max_cross_seq_len is None: + self.max_cross_seq_len = max(self.cross_seq_lens) + + self.is_cross_attn = True + else: + self.is_cross_attn = False + @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - self._cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens) - return self._cached_prefill_metadata + if not self.do_cross_attn: + # Self-attention prefill + + if self._self_cached_prefill_metadata is not None: + return self._self_cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._self_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + is_cross_attn=False, # Begin cross-attention fields below... + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_block_tables=None, + cross_slot_mapping=None) + return self._self_cached_prefill_metadata + + else: + # Cross-attention prefill + + if self._cross_cached_prefill_metadata is not None: + return self._cross_cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cross_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + is_cross_attn=True, # Begin cross-attention fields below... + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cross_cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens) - return self._cached_decode_metadata + if not self.do_cross_attn: + # Self-attention decode + + if self._self_cached_decode_metadata is not None: + return self._self_cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._self_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + is_cross_attn=False, # Begin cross-attention fields below... + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_block_tables=None, + cross_slot_mapping=None) + return self._self_cached_decode_metadata + + else: + # Cross-attention decode + + if self._cross_cached_decode_metadata is not None: + return self._cross_cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cross_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + is_cross_attn=True, # Begin cross-attention fields below... + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cross_cached_decode_metadata class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -268,6 +404,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + do_cross_attn = attn_metadata.do_cross_attn + if (kv_cache is not None): # Even if there are no new key/value pairs to cache, # we still need to break out key_cache and value_cache @@ -277,22 +418,30 @@ def forward( if (key is not None) and (value is not None): + if do_cross_attn: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, - attn_metadata.slot_mapping, + updated_slot_mapping, self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - is_cross_attn = attn_metadata.is_cross_attn - assert is_cross_attn or (key.shape[0] + assert do_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert is_cross_attn or (value.shape[0] + assert do_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) @@ -301,7 +450,7 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if not is_cross_attn and key is not None and value is not None: + if not do_cross_attn and key is not None and value is not None: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -323,6 +472,8 @@ def forward( # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. + # + # TODO(afeldman-nm): support cross-attention out = PagedAttention.forward_prefix( query, key, @@ -341,17 +492,24 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: + if do_cross_attn: + # Paged attention against cross-attention KV cache + seq_lens_arg = decode_meta.cross_seq_lens_tensor + max_seq_len_arg = decode_meta.max_cross_seq_len + block_tables_arg = decode_meta.cross_block_tables + else: + # Paged attention against self-attention KV cache + seq_lens_arg = decode_meta.seq_lens_tensor + max_seq_len_arg = decode_meta.max_decode_seq_len + block_tables_arg = decode_meta.block_tables + output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor if not is_cross_attn else - torch.tensor(decode_meta.cross_seq_lens, - dtype=decode_meta.seq_lens_tensor.dtype, - device=decode_meta.seq_lens_tensor.device), - decode_meta.max_decode_seq_len - if not is_cross_attn else max(decode_meta.cross_seq_lens), + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, self.kv_cache_dtype, self.num_kv_heads, self.scale, From 78ce588834c17c7e9b73a5544b1c691492ab8c5c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 00:15:13 -0400 Subject: [PATCH 025/443] refactoring --- tests/layer/test_self_and_cross_attn.py | 667 +++++++++++------------- 1 file changed, 311 insertions(+), 356 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 5f1e35ceb21d7..5cab054b61069 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -16,25 +16,22 @@ import random +# If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] +# +# TODO: # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64] +HEAD_SIZES = [64,256] -# [64, 80, 96, 112, 128, 256 -# ] if not is_hip() else [64, 80, 96, 112, 128] +NUM_HEADS = [1,16] -NUM_HEADS = [16] - -BATCH_SIZES = [16] +BATCH_SIZES = [1,16] BLOCK_SIZES = [16] BACKEND_NAMES = ["xformers"] CUDA_DEVICE = "cuda:0" -PROMPT_LENS = [128] - -Q_PROMPT_LENS = [128] - -K_PROMPT_LENS = [128] +MAX_Q_PROMPT_LENS = [128] +MAX_K_PROMPT_LENS = [128] def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): @@ -266,6 +263,22 @@ def make_qkv(batch_size, def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): + ''' + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where + number_of_tokens = sum(prompt_lens) + + Arguments: + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size + * prompt_lens: list of token counts for each prompt + * device: CPU or CUDA device + + Returns + * packed_tensor: number_of_tokens x num_heads x head_size + * start_loc_list: start idx of each batch elt in packed_tensor; + [0] + list(itertools.accumulate(prompt_lens)) + ''' + num_tok = sum(prompt_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] @@ -285,6 +298,30 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): + ''' + Individually pack each of Q, K and V, each with dimensions + batch_size x padded_seq_len x num_heads x head_size, into + respective number_of_tokens x num_heads x head_size tensors. + + For Q, number_of_tokens = sum(q_prompt_lens). + + For K and V, number_of_tokens = sum(kv_prompt_lens) + + Arguments: + * query: batch_size x padded_seq_len x num_heads x head_size + * key: batch_size x padded_seq_len x num_heads x head_size + * value: batch_size x padded_seq_len x num_heads x head_size + * q_prompt_lens: list of token counts for each query + * kv_prompt_lens: list of token counts for each key/value + + Returns + * packed_query: number_of_tokens x num_heads x head_size + * packed_key: number_of_tokens x num_heads x head_size + * packed_value: number_of_tokens x num_heads x head_size + * q_start_loc_list: start idx of each query in packed_query + * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} + ''' + if query is None: packed_query = None q_start_loc_list = None @@ -303,6 +340,16 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): def make_backend(backend_name: str) -> AttentionBackend: + ''' + Construct the backend instance determined by the backend_name string argument. + + "xformers" -> construct xformers backend + + TODO: flash attention backend + + Returns: + * Backend instance + ''' if backend_name == "xformers": return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" @@ -313,10 +360,22 @@ def make_metadata_tensors(is_prompt: bool, context_lens: List[int], device=CUDA_DEVICE) -> tuple: ''' - Assumptions: - * No chunked prefill - * No (automatic) prefix caching - * Packed variable-length sequences + Build scalar & tensor values required to build attention metadata structure. + + Arguments: + * is_prompt: True -> Prefill, False -> Decode + * prompt_lens: list of token-counts for each prompt + * context_lens: list of context length values for each prompt + * device: CPU or CUDA device + + Returns: + * prompt_lens_tensor: prompt_lens list, as tensor + * context_lens_tensor: context_lens list, as tensor + * max_query_len: max(prompt_lens) if is_prompt, o/w 1 + * max_context_len: max(context_lens) + * max_prompt_len: max(prompt_lens) + * seq_start_loc: start idx of each sequence + * query_start_loc: start idx of each query ''' prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.int, @@ -361,6 +420,21 @@ def make_kv_cache(num_blocks, block_size, device=CUDA_DEVICE, default_val=0.0): + ''' + Create a fake KV cache. + + Arguments: + * num_blocks: number of blocks in the KV cache + * num_heads: number of attention heads + * head_size: head dimension + * block_size: number of offsets within a block + * device: CPU or CUDA device + * default_val: initialization value for KV cache elements + + Returns: + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + ''' + kv_cache = torch.rand( (2, num_blocks, block_size * num_heads * head_size)).to(device) if default_val is not None: @@ -369,6 +443,10 @@ def make_kv_cache(num_blocks, def num_tokens_to_min_blocks(num_tokens, block_size): + ''' + Compute the minimum number of blocks required + to hold num_tokens tokens, given block_size + ''' return (num_tokens + block_size) // block_size @@ -377,9 +455,29 @@ def make_block_tables_slot_mapping(block_size, block_base_addr=0, device=CUDA_DEVICE): ''' - Naive block table: - * For each batch element... - * Block table has + Construct fake block tables & slot mappings. + + The first block is at + + block_base_addr + sum(num_blocks_list) * 2 - 1 + + and subsequent blocks count downward toward + block_base_addr + + Arguments: + * block_size: number of offsets per block + * prompt_lens: list of token-counts for each sequence + * block_base_addr: the block table base address + * device: CPU or CUDA device + + Return: + * decode_block_tables_tensor: fake the state of the block tables during decode + * decode_slot_mapping_tensor: fake the state of the slot mapping during decode + * prefill_slot_mapping_tensor: fake the state of the slot mapping during prefill + * prefill_block_tables_tensor: fake the state of the block tables during prefill + * slot_mapping_tensor: union of prefill and decode slot mappings + * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase cross attention) + * max_block_idx: the highest block address within this block table ''' # Over-provision block table blocks by 1 @@ -438,99 +536,6 @@ def make_block_tables_slot_mapping(block_size, return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx - -def make_metadata(attn_backend: AttentionBackend, - is_prompt: bool, - is_cross_attn: bool, - prompt_lens: List[int], - context_lens: List[int], - block_tables, - slot_mapping, - device=CUDA_DEVICE, - cross_prompt_lens: Optional[List[int]] = None): - ''' - Assumptions: - * No chunked prefill -> a batch is 100% prefill or 100% decode, never both - ''' - - if is_prompt: - num_prefills = len(prompt_lens) - num_prefill_tokens = sum(prompt_lens) - num_decode_tokens = 0 - - prompt_lens_tensor, \ - context_lens_tensor, \ - max_query_len, \ - _, \ - _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, - device=device) - - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max(prompt_lens), - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens) - - else: # not is_prompt - - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = len(prompt_lens) - - prompt_lens_tensor, \ - context_lens_tensor, \ - max_query_len, \ - _, \ - _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, - device=device) - - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=max(prompt_lens), - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens) - - def make_metadata_self_cross(attn_backend: AttentionBackend, is_prompt: bool, prompt_lens: List[int], @@ -540,10 +545,29 @@ def make_metadata_self_cross(attn_backend: AttentionBackend, device=CUDA_DEVICE, cross_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None,): + cross_slot_mapping: Optional[List[int]] = None,) -> AttentionMetadata: ''' + Construct fake attention metadata for a combined + self-/cross-attention scenario i.e. an encoder/decoder + model. + Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + + Arguments: + * attn_backend: Backend for sourcing attention kernels + * is_prompt: prefill if True, o/w decode + * prompt_lens: list of token counts for each sequence + * context_lens: list of context lengths for each sequence + * block_tables: self-attention block tables + * slot_mapping: self-attention slot_mapping + * device: CPU or CUDA device + * cross_seq_lens: list of token counts for each encoder sequence, if any exist + * cross_block_tables: cross-attention block tables, if required + * cross_slot_mapping: cross-attention slot mapping, if required + + Return: + * AttentionMetadata structure supporting self- and cross-attention ''' if is_prompt: @@ -628,7 +652,13 @@ def make_metadata_self_cross(attn_backend: AttentionBackend, cross_block_tables=cross_block_tables) def make_attention(num_heads: int, head_size: int, scale: float): - # Attention operator instance + ''' + Construct an instance of the Attention wrapper, suited to + the number of attention heads and head dimension + (num_heads and head_size respectively) as well as the + attention scale factor (scale) + ''' + return Attention( num_heads, head_size, @@ -637,6 +667,23 @@ def make_attention(num_heads: int, head_size: int, scale: float): def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): + ''' + Compute & build entities required for the self-/cross-attention test. + + Arguments: + * num_heads: Number of attention heads + * head_size: Head dimension + * num_blocks: Number of KV cache blocks + * block_size: Number of offsets within a KV cache block + * backend_name: selection of backend + + Returns: + * scale: 1/sqrt(head_size) + * attn_backend: backend instance + * attn: Attention wrapper instance + * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * head_size) + ''' + scale = float(1.0 / (head_size**0.5)) attn_backend = make_backend(backend_name) attn = make_attention(num_heads, head_size, scale) @@ -644,6 +691,64 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): return scale, attn_backend, attn, kv_cache def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, block_base_addr=0): + ''' + Set up test vectors & data structures for self-attention test. + + A triplet of synthetic query/key/value tensors are constructed ("baseline" query/key/value). + Given this is a self-attention test, the key & value sequences will have the same length + as the corresponding queries. + + "Prefill" query/key/value tensors are derived by masking out the last value in each + baseline query/key/value. These tensors are used to test prefill & populate KV cache + for a subsequent decode test. + + "Decode" query/key/value tensors are derived by extracting *only* the last value from + each baseline query/key/value (i.e. complement of the prefill tensors.) These tensors + are used to test decode, conditional on the kv cache being populated during the + prefill test. + + The baseline query/key/value tensors are passed to an ideal reference self-attention implementation + to generate a "Baseline" ideal output tensor. This tensor is split into the "Prefill" + ideal output tensor (all but the last element of each output sequence) and the "Decode" + ideal output tensor (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode test + results, respectively. + + This function also constructs the self-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts + at block_base_addr + + Arguments: + * batch_size + * num_heads: Number of attention heads + * head_size: Head dimension + * block_size: Number of offsets per KV cache block + * scale: attention scale parameter + * max_q_prompt_len: upper limit on query length for synthetic test vectors + * block_base_addr: self-attention block table base address + + Returns: + * query: "baseline" query; batch_size x padded_seq_len x num_heads x head_size + * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x head_size + * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads x head_size + * prefill_packed_value: self-attn "prefill" value; number_of_tokens x num_heads x head_size + * prefill_packed_ideal_output: self-attn "prefill" ideal output; number_of_tokens x num_heads x head_size + * prefill_q_prompt_lens: list of token counts for each *prefill query* (one less than baseline query) + * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill key/value* (should match prefill_q_prompt_lens) + * decode_packed_query: "decode" query; number_of_tokens x num_heads x head_size + * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x head_size + * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads x head_size + * decode_packed_ideal_output: self-attn "decode" ideal output; number_of_tokens x num_heads x head_size + * decode_q_prompt_lens: list of token counts for each *decode query* (should be 1) + * decode_kv_prompt_lens: list of token counts for each self-attn *decode key/value* (should match decode_q_prompt_lens) + * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x head_size + * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens x num_heads x head_size + * decode_block_tables: fake self-attn decode-phase block table + * decode_slot_mapping: fake self-attn decode-phase slot mapping + * prefill_slot_mapping: fake self-attn prefill-phase slot mapping + * prefill_block_tables: fake self-attn prefill-phase block table + * max_block_idx: highest block address in the self-attention block-table + ''' max_kv_prompt_len = max_q_prompt_len @@ -723,6 +828,55 @@ def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_p def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, max_kv_prompt_len, block_base_addr=0): + ''' + Set up test vectors & data structures for cross-attention test. + + A triplet of synthetic cross-attention key/value tensors are constructed ("baseline" key/value). + Given this is a cross-attention test, we assume query tensors were already synthesized for a + prior self-attention test and will be reused for cross-attention. The key & value sequences + generated here will may have a different length than the corresponding queries (as is often + the case for cross-attention between decoder and encoder sequences.) + + Cross attention key & value tensors do not grow during autoregressive inference; thus + this function obtains a single key/value pair suitable for both prefill and decode. + + The "baseline" query tensor is received as an argument. The "baseline" query/key/value tensors + are passed to an ideal reference cross-attention implementation + to generate a "baseline" ideal output tensor. This tensor is split into the "Prefill" + ideal output tensor (all but the last element of each output sequence) and the "Decode" + ideal output tensor (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode test + results, respectively. + + This function also constructs the cross-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts + at block_base_addr. + + Arguments: + * query: pre-existing "baseline" query; batch_size x padded_seq_len x num_heads x head_size + * q_prompt_lens: list of token-counts for each "baseline" query sequence + * prefill_q_prompt_lens: list of token-counts for each "prefill" query sequence + * batch_size + * num_heads: Number of attention heads + * head_size: Head dimension + * block_size: Number of offsets per KV cache block + * scale: attention scale parameter + * max_q_prompt_len: upper limit on query length for synthetic test vectors + * max_kv_prompt_len: upper limit on key/value length for synthetic test vectors + * block_base_addr: cross-attention block table base address + + Returns: + * packed_key: cross-attention key; number_of_tokens x num_heads x head_size + * packed_value: cross-attention value; number_of_tokens x num_heads x head_size + * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x num_heads x head_size + * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x num_heads x head_size + * kv_prompt_lens: list of token-counts for each key/value + * decode_block_tables: fake decode-phase block tables + * decode_slot_mapping: fake decode-phase slot mapping + * prefill_slot_mapping: fake prefill-phase slot mapping + * prefill_block_tables: fake prefill-phase block tables + * max_block_idx: highest block address in the cross-attention block-table + ''' _, \ key, \ @@ -768,11 +922,12 @@ def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, b decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( block_size, kv_prompt_lens, block_base_addr=block_base_addr) - _, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( - None, key, value, prefill_q_prompt_lens, kv_prompt_lens) + # Packed key/value (query is already provided) + _, packed_key, packed_value, _, _ = pack_qkv( + None, key, value, None, kv_prompt_lens) - return prefill_packed_key, \ - prefill_packed_value, \ + return packed_key, \ + packed_value, \ prefill_packed_ideal_output, \ decode_packed_ideal_output, \ kv_prompt_lens, \ @@ -805,12 +960,32 @@ def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache, @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) +@pytest.mark.parametrize("max_q_prompt_len", MAX_Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len", MAX_K_PROMPT_LENS) def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: + ''' + Test: + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention attributes + * Test self- and cross-attention in the following order + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the self-/cross-attention block tables, + which we attempt to avoid + * Validate output correctness against ideal reference attention implementation + + Block tables are constructed such that cross-attention KV cache is in a higher, non-intersecting + address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V tensors. Self-attention + K/Vs must have the same seq len as Q while cross-attention K/Vs are allowed to differ in seq + len, as is often the case for cross-attention. + ''' # Num KV cache blocks num_blocks = 4096 @@ -843,10 +1018,10 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ - decode_q_prompt_lens, \ - self_decode_kv_prompt_lens, \ + _, \ + _, \ q_prompt_lens, \ - self_kv_prompt_lens, \ + _, \ self_decode_block_tables, \ self_decode_slot_mapping, \ self_prefill_slot_mapping, \ @@ -870,17 +1045,17 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - final_max_block_idx = cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_prompt_len, - max_kv_prompt_len, - block_base_addr=cross_block_base_addr) + _ = cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, + block_base_addr=cross_block_base_addr) # PREFILL: self- and cross-attention tests @@ -940,6 +1115,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, decode_attn_metadata, scale) + # - Decode self-attention correct? assert torch.allclose(self_decode_packed_ideal_output,self_decode_packed_actual_output.view_as(self_decode_packed_ideal_output)) cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, @@ -950,226 +1126,5 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, decode_attn_metadata, scale) - assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_prompt_len", PROMPT_LENS) -def test_prefill_decode_self_attention(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, - max_prompt_len: int) -> None: - # Attention operator instance - is_cross_attn = False - is_prompt = True - max_q_prompt_len = max_prompt_len - max_kv_prompt_len = max_q_prompt_len - context_lens = [0 for _ in range(batch_size)] - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_prompt_lens, \ - kv_prompt_lens, \ - _, \ - _, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) - - causal_mask = build_causal_mask(max_q_prompt_len, - max_kv_prompt_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention(query, - key, - value, - scale=scale, - custom_mask=causal_mask, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) - - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping( - block_size, q_prompt_lens) - prefill_attn_metadata: AttentionMetadata = make_metadata( - attn_backend, - is_prompt, - is_cross_attn, - prefill_q_prompt_lens, - context_lens, - prefill_block_tables, - prefill_slot_mapping, - cross_prompt_lens=None) - - prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, - prefill_kv_prompt_lens) - - prefill_packed_actual_output = attn.forward(prefill_packed_query, - prefill_packed_key, - prefill_packed_value, kv_cache, - prefill_attn_metadata, scale) - - # eval correctness of prefill output - assert torch.allclose( - prefill_packed_actual_output, - prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) - - is_prompt = False - context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, - is_cross_attn, q_prompt_lens, - context_lens, decode_block_tables, - decode_slot_mapping) - - decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( - decode_query, decode_key, decode_value, decode_q_prompt_lens, - decode_kv_prompt_lens) - - decode_packed_actual_output = attn.forward(decode_packed_query, - decode_packed_key, - decode_packed_value, kv_cache, - decode_attn_metadata, scale) - - # eval correctness of decode output - assert torch.allclose( - decode_packed_actual_output, - decode_packed_ideal_output.view_as(decode_packed_actual_output)) - - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) -def test_prefill_decode_cross_attention(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_prompt_len: int, - max_kv_prompt_len: int) -> None: - # Attention operator instance - is_cross_attn = True - is_prompt = True - context_lens = [0 for _ in range(batch_size)] - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - - query, \ - key, \ - value, \ - prefill_query, \ - _, \ - _, \ - decode_query, \ - _, \ - _, \ - q_prompt_lens, \ - kv_prompt_lens, \ - _, \ - _, \ - prefill_q_prompt_lens, \ - _, \ - decode_q_prompt_lens, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=is_cross_attn) - - ideal_output = ref_masked_attention(query, - key, - value, - scale=scale, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) - - # Unlike self-attention: - # - Prefill slot-mapping includes all key slots - # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping( - block_size, kv_prompt_lens) - - prefill_attn_metadata: AttentionMetadata = make_metadata( - attn_backend, - is_prompt, - is_cross_attn, - prefill_q_prompt_lens, - context_lens, - prefill_block_tables, - prefill_slot_mapping, - cross_prompt_lens=kv_prompt_lens) - - prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( - prefill_query, key, value, prefill_q_prompt_lens, kv_prompt_lens) - - prefill_packed_actual_output = attn.forward(prefill_packed_query, - prefill_packed_key, - prefill_packed_value, kv_cache, - prefill_attn_metadata, scale) - - # eval correctness of prefill output - assert torch.allclose( - prefill_packed_actual_output, - prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) - - is_prompt = False - context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, - is_prompt, - is_cross_attn, - q_prompt_lens, - context_lens, - decode_block_tables, - decode_slot_mapping, - cross_prompt_lens=kv_prompt_lens) - - decode_packed_query, _, _, _, _ = pack_qkv(decode_query, key, value, - decode_q_prompt_lens, - kv_prompt_lens) - - decode_packed_actual_output = attn.forward(decode_packed_query, None, None, - kv_cache, decode_attn_metadata, - scale) - - # eval correctness of decode output - assert torch.allclose( - decode_packed_actual_output, - decode_packed_ideal_output.view_as(decode_packed_actual_output)) + # - Decode cross-attention correct? + assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) \ No newline at end of file From 92701a3bd21cd89c0e004c0bdee2b783a12ce846 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 00:22:55 -0400 Subject: [PATCH 026/443] some format fixes --- tests/layer/test_self_and_cross_attn.py | 244 +++++++++++++----------- vllm/attention/backends/xformers.py | 40 ++-- 2 files changed, 157 insertions(+), 127 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 5cab054b61069..2fe5add8ae453 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -21,11 +21,11 @@ # TODO: # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64,256] +HEAD_SIZES = [64, 256] -NUM_HEADS = [1,16] +NUM_HEADS = [1, 16] -BATCH_SIZES = [1,16] +BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] BACKEND_NAMES = ["xformers"] CUDA_DEVICE = "cuda:0" @@ -492,7 +492,8 @@ def make_block_tables_slot_mapping(block_size, prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = block_base_addr + sum(num_blocks_list) * 2 - 1 # Support more blocks than needed + block_base_idx = block_base_addr + sum( + num_blocks_list) * 2 - 1 # Support more blocks than needed max_block_idx = block_base_idx for sdx, num_tokens in enumerate(prompt_lens): num_blocks = num_blocks_list[sdx] @@ -536,16 +537,19 @@ def make_block_tables_slot_mapping(block_size, return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx -def make_metadata_self_cross(attn_backend: AttentionBackend, - is_prompt: bool, - prompt_lens: List[int], - context_lens: List[int], - block_tables, - slot_mapping, - device=CUDA_DEVICE, - cross_seq_lens: Optional[List[int]] = None, - cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None,) -> AttentionMetadata: + +def make_metadata_self_cross( + attn_backend: AttentionBackend, + is_prompt: bool, + prompt_lens: List[int], + context_lens: List[int], + block_tables, + slot_mapping, + device=CUDA_DEVICE, + cross_seq_lens: Optional[List[int]] = None, + cross_block_tables: Optional[torch.Tensor] = None, + cross_slot_mapping: Optional[List[int]] = None, +) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention scenario i.e. an encoder/decoder @@ -651,6 +655,7 @@ def make_metadata_self_cross(attn_backend: AttentionBackend, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) + def make_attention(num_heads: int, head_size: int, scale: float): ''' Construct an instance of the Attention wrapper, suited to @@ -690,7 +695,14 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache -def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, block_base_addr=0): + +def self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -772,7 +784,7 @@ def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_p causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) - + ideal_output = ref_masked_attention(query, key, value, @@ -827,7 +839,17 @@ def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_p max_block_idx -def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, max_kv_prompt_len, block_base_addr=0): +def cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, + block_base_addr=0): ''' Set up test vectors & data structures for cross-attention test. @@ -919,12 +941,12 @@ def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, b # Unlike self-attention: # - Prefill slot-mapping includes all key slots # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( block_size, kv_prompt_lens, block_base_addr=block_base_addr) - + # Packed key/value (query is already provided) - _, packed_key, packed_value, _, _ = pack_qkv( - None, key, value, None, kv_prompt_lens) + _, packed_key, packed_value, _, _ = pack_qkv(None, key, value, None, + kv_prompt_lens) return packed_key, \ packed_value, \ @@ -937,23 +959,21 @@ def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, b prefill_block_tables, \ max_block_idx -def run_self_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + +def run_self_attention_test(attn, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata, scale): attn_metadata.do_cross_attn = False - return attn.forward(packed_query, - packed_key, - packed_value, - kv_cache, - attn_metadata, - scale) - -def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + return attn.forward(packed_query, packed_key, packed_value, kv_cache, + attn_metadata, scale) + + +def run_cross_attention_test(attn, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata, + scale): attn_metadata.do_cross_attn = True - return attn.forward(packed_query, - packed_key, - packed_value, - kv_cache, - attn_metadata, - scale) + return attn.forward(packed_query, packed_key, packed_value, kv_cache, + attn_metadata, scale) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -962,10 +982,10 @@ def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_prompt_len", MAX_Q_PROMPT_LENS) @pytest.mark.parametrize("max_kv_prompt_len", MAX_K_PROMPT_LENS) -def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_prompt_len: int, - max_kv_prompt_len: int) -> None: +def test_prefill_decode_self_and_cross_attention( + num_heads: int, head_size: int, backend_name: str, batch_size: int, + block_size: int, max_q_prompt_len: int, + max_kv_prompt_len: int) -> None: ''' Test: * Construct fake test vectors for self- and cross-attention @@ -997,15 +1017,15 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, backend_name) # Self-attention setup - self_block_base_addr=0 + self_block_base_addr = 0 query, \ prefill_packed_query, \ @@ -1026,11 +1046,11 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = self_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, + cross_block_base_addr = self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, max_q_prompt_len, block_base_addr=self_block_base_addr) @@ -1045,86 +1065,86 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_prompt_len, - max_kv_prompt_len, + _ = cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, block_base_addr=cross_block_base_addr) # PREFILL: self- and cross-attention tests context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, - True, - prefill_q_prompt_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - cross_seq_lens = cross_kv_prompt_lens, - cross_block_tables = cross_prefill_block_tables, - cross_slot_mapping = cross_prefill_slot_mapping,) - - self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test(attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - scale) + prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + True, + prefill_q_prompt_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + cross_seq_lens=cross_kv_prompt_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, prefill_packed_query, self_prefill_packed_key, + self_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) # - Prefill self-attention correct? - assert torch.allclose(self_prefill_packed_ideal_output,self_prefill_packed_actual_output.view_as(self_prefill_packed_ideal_output)) + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) - cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, - prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - scale) + cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) # - Prefill cross-attention correct? - assert torch.allclose(cross_prefill_packed_ideal_output,cross_prefill_packed_actual_output.view_as(cross_prefill_packed_ideal_output)) + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) context_lens = copy.deepcopy(self_prefill_kv_prompt_lens) # DECODE: self- and cross-attention tests - decode_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, - False, - q_prompt_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - cross_seq_lens = cross_kv_prompt_lens, - cross_block_tables = cross_decode_block_tables, - cross_slot_mapping = cross_decode_slot_mapping,) - - self_decode_packed_actual_output: torch.Tensor = run_self_attention_test(attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - scale) + decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + False, + q_prompt_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + cross_seq_lens=cross_kv_prompt_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, decode_packed_query, self_decode_packed_key, + self_decode_packed_value, kv_cache, decode_attn_metadata, scale) # - Decode self-attention correct? - assert torch.allclose(self_decode_packed_ideal_output,self_decode_packed_actual_output.view_as(self_decode_packed_ideal_output)) + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) - cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, - decode_packed_query, - None, - None, - kv_cache, - decode_attn_metadata, - scale) + cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata, + scale) # - Decode cross-attention correct? - assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) \ No newline at end of file + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5b6d2ac0e144f..0c8db7e47a50d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -115,7 +115,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Begin cross-attention fields... - # If True, prefill_metadata() and decode_metadata() will return + # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; # otherwise, self-attention data structures will be returned. is_cross_attn: bool = False @@ -147,14 +147,19 @@ def __post_init__(self): @property def has_valid_cross_attn_metadata(self): # No cross-attention metadata is present whatsoever - no_md = (self.cross_seq_lens is None) and (self.cross_slot_mapping is None) and (self.cross_block_tables is None) + no_md = (self.cross_seq_lens is + None) and (self.cross_slot_mapping is + None) and (self.cross_block_tables is None) # If any cross-attention metadata is present, it is invalid - invalid_md_if_not_no_md = (self.cross_seq_lens is None) or (self.cross_slot_mapping is None) or (self.cross_block_tables is None) + invalid_md_if_not_no_md = (self.cross_seq_lens is None) or ( + self.cross_slot_mapping is None) or (self.cross_block_tables is + None) if no_md: return False - - assert (not invalid_md_if_not_no_md), "Invalid cross-attention metadata" + + assert ( + not invalid_md_if_not_no_md), "Invalid cross-attention metadata" return True @@ -163,17 +168,20 @@ def do_cross_attn(self): return self.is_cross_attn @do_cross_attn.setter - def do_cross_attn(self,state:bool): + def do_cross_attn(self, state: bool): if state: assert self.has_valid_cross_attn_metadata, "Must have self.cross_seq_lens not None in order to enable cross-attention" # Infer implicit cross-attention fields from user-provided fields, if needed if self.cross_seq_lens_tensor is None: - self.cross_seq_lens_tensor = torch.tensor(self.cross_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) + assert self.seq_lens_tensor is not None + self.cross_seq_lens_tensor = torch.tensor( + self.cross_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) if self.max_cross_seq_len is None: + assert self.cross_seq_lens is not None self.max_cross_seq_len = max(self.cross_seq_lens) self.is_cross_attn = True @@ -209,10 +217,11 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=0, query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=False, # Begin cross-attention fields below... + is_cross_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -244,10 +253,11 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=0, query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=True, # Begin cross-attention fields below... + is_cross_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -283,7 +293,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=False, # Begin cross-attention fields below... + is_cross_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -314,7 +324,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=True, # Begin cross-attention fields below... + is_cross_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, From 39788150b2efd4177fc2a0da6615ced032e8a3b0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 00:56:59 -0400 Subject: [PATCH 027/443] refactored long lines in self/cross attn test --- tests/layer/test_self_and_cross_attn.py | 304 +++++++++++++++--------- 1 file changed, 187 insertions(+), 117 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 2fe5add8ae453..4fe9862de99a2 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -12,14 +12,9 @@ from vllm.utils import make_tensor_with_pad -from vllm.attention.layer import Attention - -import random - # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # -# TODO: -# FlashAttention forward only supports head dimension at most 128 +# TODO: FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64, 256] @@ -39,10 +34,12 @@ def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): Create a q_max_prompt_len x kv_max_prompt_len causal mask Arguments: + * q_max_prompt_len: query max prompt len * kv_max_prompt_len: key/value max prompt len Returns: + * 2D tensor, q_max_prompt_len x kv_max_prompt_len ''' @@ -65,19 +62,25 @@ def ref_masked_attention( kv_prompt_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: - * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out padding elements - * Custom attention mask, which can force an arbitrary mask tensor, i.e. causal + + * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out + padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. + causal Arguments: + * query: batch_size x q_padded_seq_len x num_heads x head_size * key: batch_size x kv_padded_seq_len x num_heads x head_size * value: batch_size x kv_padded_seq_len x num_heads x head_size * scale: Attention scale factor - * Custom mask: custom attention mask; good place to inject a causal attention mask + * Custom mask: custom attention mask; good place to inject a causal + attention mask * q_prompt_lens: list of unpadded query seq_lens for each batch index * kv_prompt_lens: list of unpadded key/value seq_lens for each batch index Returns: + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size ''' @@ -120,26 +123,39 @@ def make_qkv(batch_size, Construct QKV test tensors for self- and cross-attention. Generates three query/key/value triplets: + * "Baseline" query/key/value (for input to reference attention function) - * "Prefill" query/key/value (last sequence offset zero'd out, for use as input to prefill kernel) - * "Decode" query/key/value (only the last sequence offset from baseline, for use as input to decode kernel) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as + input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, + for use as input to decode kernel) - Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v seqlens + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v + seqlens Arguments: + * batch_size * max_q_prompt_len: max query prompt len * max_kv_prompt_len: max key/value prompt len * num_heads * head_size - * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_prompt_len is unused) - * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens and max_kv_prompt_len, unless forced by is_cross_attn=False + * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as + is often the case for cross-attention); o/w, query/key/value seqlens match + at each batch index (max_kv_prompt_len is unused) + * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query + seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens + and max_kv_prompt_len, unless forced by is_cross_attn=False * device: CPU or CUDA device Returns: - * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x head_size - * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x head_size - * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x head_size + + * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x + head_size + * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x + head_size + * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x + head_size * prefill_query: batch_size x (max_q_prompt_len-1) x num_heads x head_size * prefill_key: batch_size x (max_kv_prompt_len-1) x num_heads x head_size * prefill_value: batch_size x (max_kv_prompt_len-1) x num_heads x head_size @@ -148,8 +164,10 @@ def make_qkv(batch_size, * decode_value: batch_size x 1 x num_heads x head_size * q_prompt_lens: "baseline" query seqlen list * kv_prompt_lens: "baseline" key/value seqlen list - * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= max_q_prompt_len due to randomness) - * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may be <= max_kv_prompt_len due to randomness) + * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= + max_q_prompt_len due to randomness) + * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may + be <= max_kv_prompt_len due to randomness) * prefill_q_prompt_lens: "prefill" query seqlen list * prefill_kv_prompt_lens: "prefill" key/value seqlen list * decode_q_prompt_lens: "decode" query seqlen list (all ones) @@ -264,19 +282,21 @@ def make_qkv(batch_size, def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): ''' - Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an - unpadded number_of_tokens x num_heads x head_size tensor, where + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where number_of_tokens = sum(prompt_lens) Arguments: + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size * prompt_lens: list of token counts for each prompt * device: CPU or CUDA device Returns + * packed_tensor: number_of_tokens x num_heads x head_size - * start_loc_list: start idx of each batch elt in packed_tensor; - [0] + list(itertools.accumulate(prompt_lens)) + * start_loc_list: start idx of each batch elt in packed_tensor; [0] + + list(itertools.accumulate(prompt_lens)) ''' num_tok = sum(prompt_lens) @@ -299,15 +319,16 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): ''' - Individually pack each of Q, K and V, each with dimensions - batch_size x padded_seq_len x num_heads x head_size, into - respective number_of_tokens x num_heads x head_size tensors. + Individually pack each of Q, K and V, each with dimensions batch_size x + padded_seq_len x num_heads x head_size, into respective number_of_tokens x + num_heads x head_size tensors. For Q, number_of_tokens = sum(q_prompt_lens). For K and V, number_of_tokens = sum(kv_prompt_lens) Arguments: + * query: batch_size x padded_seq_len x num_heads x head_size * key: batch_size x padded_seq_len x num_heads x head_size * value: batch_size x padded_seq_len x num_heads x head_size @@ -315,6 +336,7 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): * kv_prompt_lens: list of token counts for each key/value Returns + * packed_query: number_of_tokens x num_heads x head_size * packed_key: number_of_tokens x num_heads x head_size * packed_value: number_of_tokens x num_heads x head_size @@ -341,13 +363,15 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): def make_backend(backend_name: str) -> AttentionBackend: ''' - Construct the backend instance determined by the backend_name string argument. + Construct the backend instance determined by the backend_name string + argument. "xformers" -> construct xformers backend TODO: flash attention backend Returns: + * Backend instance ''' if backend_name == "xformers": @@ -363,12 +387,14 @@ def make_metadata_tensors(is_prompt: bool, Build scalar & tensor values required to build attention metadata structure. Arguments: + * is_prompt: True -> Prefill, False -> Decode * prompt_lens: list of token-counts for each prompt * context_lens: list of context length values for each prompt * device: CPU or CUDA device Returns: + * prompt_lens_tensor: prompt_lens list, as tensor * context_lens_tensor: context_lens list, as tensor * max_query_len: max(prompt_lens) if is_prompt, o/w 1 @@ -399,9 +425,8 @@ def make_metadata_tensors(is_prompt: bool, query_start_loc = copy.deepcopy(seq_start_loc) max_query_len = max_prompt_len else: - # Decode: one new query input token per batch - # element, thus query_start_loc is the cumsum - # of [1,1,1,...] + # Decode: one new query input token per batch element, thus + # query_start_loc is the cumsum of [1,1,1,...] query_start_loc = list(range(len(seq_start_loc))) max_query_len = 1 @@ -424,6 +449,7 @@ def make_kv_cache(num_blocks, Create a fake KV cache. Arguments: + * num_blocks: number of blocks in the KV cache * num_heads: number of attention heads * head_size: head dimension @@ -432,6 +458,7 @@ def make_kv_cache(num_blocks, * default_val: initialization value for KV cache elements Returns: + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) ''' @@ -444,8 +471,8 @@ def make_kv_cache(num_blocks, def num_tokens_to_min_blocks(num_tokens, block_size): ''' - Compute the minimum number of blocks required - to hold num_tokens tokens, given block_size + Compute the minimum number of blocks required to hold num_tokens tokens, + given block_size ''' return (num_tokens + block_size) // block_size @@ -461,22 +488,28 @@ def make_block_tables_slot_mapping(block_size, block_base_addr + sum(num_blocks_list) * 2 - 1 - and subsequent blocks count downward toward - block_base_addr + and subsequent blocks count downward toward block_base_addr Arguments: + * block_size: number of offsets per block * prompt_lens: list of token-counts for each sequence * block_base_addr: the block table base address * device: CPU or CUDA device Return: - * decode_block_tables_tensor: fake the state of the block tables during decode - * decode_slot_mapping_tensor: fake the state of the slot mapping during decode - * prefill_slot_mapping_tensor: fake the state of the slot mapping during prefill - * prefill_block_tables_tensor: fake the state of the block tables during prefill + + * decode_block_tables_tensor: fake the state of the block tables during + decode + * decode_slot_mapping_tensor: fake the state of the slot mapping during + decode + * prefill_slot_mapping_tensor: fake the state of the slot mapping during + prefill + * prefill_block_tables_tensor: fake the state of the block tables during + prefill * slot_mapping_tensor: union of prefill and decode slot mappings - * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase cross attention) + * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase + cross attention) * max_block_idx: the highest block address within this block table ''' @@ -551,14 +584,15 @@ def make_metadata_self_cross( cross_slot_mapping: Optional[List[int]] = None, ) -> AttentionMetadata: ''' - Construct fake attention metadata for a combined - self-/cross-attention scenario i.e. an encoder/decoder - model. + Construct fake attention metadata for a combined self-/cross-attention + scenario i.e. an encoder/decoder model. Assumptions: + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both Arguments: + * attn_backend: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode * prompt_lens: list of token counts for each sequence @@ -566,11 +600,13 @@ def make_metadata_self_cross( * block_tables: self-attention block tables * slot_mapping: self-attention slot_mapping * device: CPU or CUDA device - * cross_seq_lens: list of token counts for each encoder sequence, if any exist + * cross_seq_lens: list of token counts for each encoder sequence, if any + exist * cross_block_tables: cross-attention block tables, if required * cross_slot_mapping: cross-attention slot mapping, if required Return: + * AttentionMetadata structure supporting self- and cross-attention ''' @@ -658,10 +694,9 @@ def make_metadata_self_cross( def make_attention(num_heads: int, head_size: int, scale: float): ''' - Construct an instance of the Attention wrapper, suited to - the number of attention heads and head dimension - (num_heads and head_size respectively) as well as the - attention scale factor (scale) + Construct an instance of the Attention wrapper, suited to the number of + attention heads and head dimension (num_heads and head_size respectively) as + well as the attention scale factor (scale) ''' return Attention( @@ -676,6 +711,7 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): Compute & build entities required for the self-/cross-attention test. Arguments: + * num_heads: Number of attention heads * head_size: Head dimension * num_blocks: Number of KV cache blocks @@ -683,10 +719,12 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): * backend_name: selection of backend Returns: + * scale: 1/sqrt(head_size) * attn_backend: backend instance * attn: Attention wrapper instance - * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * head_size) + * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * + head_size) ''' scale = float(1.0 / (head_size**0.5)) @@ -706,31 +744,33 @@ def self_attn_setup(batch_size, ''' Set up test vectors & data structures for self-attention test. - A triplet of synthetic query/key/value tensors are constructed ("baseline" query/key/value). - Given this is a self-attention test, the key & value sequences will have the same length - as the corresponding queries. + A triplet of synthetic query/key/value tensors are constructed ("baseline" + query/key/value). Given this is a self-attention test, the key & value + sequences will have the same length as the corresponding queries. - "Prefill" query/key/value tensors are derived by masking out the last value in each - baseline query/key/value. These tensors are used to test prefill & populate KV cache - for a subsequent decode test. + "Prefill" query/key/value tensors are derived by masking out the last value + in each baseline query/key/value. These tensors are used to test prefill & + populate KV cache for a subsequent decode test. - "Decode" query/key/value tensors are derived by extracting *only* the last value from - each baseline query/key/value (i.e. complement of the prefill tensors.) These tensors - are used to test decode, conditional on the kv cache being populated during the - prefill test. + "Decode" query/key/value tensors are derived by extracting *only* the last + value from each baseline query/key/value (i.e. complement of the prefill + tensors.) These tensors are used to test decode, conditional on the kv cache + being populated during the prefill test. - The baseline query/key/value tensors are passed to an ideal reference self-attention implementation - to generate a "Baseline" ideal output tensor. This tensor is split into the "Prefill" - ideal output tensor (all but the last element of each output sequence) and the "Decode" - ideal output tensor (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode test - results, respectively. + The baseline query/key/value tensors are passed to an ideal reference + self-attention implementation to generate a "Baseline" ideal output tensor. + This tensor is split into the "Prefill" ideal output tensor (all but the + last element of each output sequence) and the "Decode" ideal output tensor + (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode + test results, respectively. This function also constructs the self-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts - at block_base_addr + (slot mapping and block table), ensuring that the block table starts at + block_base_addr Arguments: + * batch_size * num_heads: Number of attention heads * head_size: Head dimension @@ -740,21 +780,37 @@ def self_attn_setup(batch_size, * block_base_addr: self-attention block table base address Returns: - * query: "baseline" query; batch_size x padded_seq_len x num_heads x head_size - * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x head_size - * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads x head_size - * prefill_packed_value: self-attn "prefill" value; number_of_tokens x num_heads x head_size - * prefill_packed_ideal_output: self-attn "prefill" ideal output; number_of_tokens x num_heads x head_size - * prefill_q_prompt_lens: list of token counts for each *prefill query* (one less than baseline query) - * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill key/value* (should match prefill_q_prompt_lens) - * decode_packed_query: "decode" query; number_of_tokens x num_heads x head_size - * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x head_size - * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads x head_size - * decode_packed_ideal_output: self-attn "decode" ideal output; number_of_tokens x num_heads x head_size - * decode_q_prompt_lens: list of token counts for each *decode query* (should be 1) - * decode_kv_prompt_lens: list of token counts for each self-attn *decode key/value* (should match decode_q_prompt_lens) - * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x head_size - * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens x num_heads x head_size + + * query: "baseline" query; batch_size x padded_seq_len x num_heads x + head_size + * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x + head_size + * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads + x head_size + * prefill_packed_value: self-attn "prefill" value; number_of_tokens x + num_heads x head_size + * prefill_packed_ideal_output: self-attn "prefill" ideal output; + number_of_tokens x num_heads x head_size + * prefill_q_prompt_lens: list of token counts for each *prefill query* (one + less than baseline query) + * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill + key/value* (should match prefill_q_prompt_lens) + * decode_packed_query: "decode" query; number_of_tokens x num_heads x + head_size + * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x + head_size + * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads + x head_size + * decode_packed_ideal_output: self-attn "decode" ideal output; + number_of_tokens x num_heads x head_size + * decode_q_prompt_lens: list of token counts for each *decode query* (should + be 1) + * decode_kv_prompt_lens: list of token counts for each self-attn *decode + key/value* (should match decode_q_prompt_lens) + * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x + head_size + * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens + x num_heads x head_size * decode_block_tables: fake self-attn decode-phase block table * decode_slot_mapping: fake self-attn decode-phase slot mapping * prefill_slot_mapping: fake self-attn prefill-phase slot mapping @@ -853,45 +909,56 @@ def cross_attn_setup_reuses_query(query, ''' Set up test vectors & data structures for cross-attention test. - A triplet of synthetic cross-attention key/value tensors are constructed ("baseline" key/value). - Given this is a cross-attention test, we assume query tensors were already synthesized for a - prior self-attention test and will be reused for cross-attention. The key & value sequences - generated here will may have a different length than the corresponding queries (as is often + A triplet of synthetic cross-attention key/value tensors are constructed + ("baseline" key/value). Given this is a cross-attention test, we assume + query tensors were already synthesized for a prior self-attention test and + will be reused for cross-attention. The key & value sequences generated here + will may have a different length than the corresponding queries (as is often the case for cross-attention between decoder and encoder sequences.) - Cross attention key & value tensors do not grow during autoregressive inference; thus - this function obtains a single key/value pair suitable for both prefill and decode. + Cross attention key & value tensors do not grow during autoregressive + inference; thus this function obtains a single key/value pair suitable for + both prefill and decode. - The "baseline" query tensor is received as an argument. The "baseline" query/key/value tensors - are passed to an ideal reference cross-attention implementation - to generate a "baseline" ideal output tensor. This tensor is split into the "Prefill" - ideal output tensor (all but the last element of each output sequence) and the "Decode" - ideal output tensor (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode test - results, respectively. + The "baseline" query tensor is received as an argument. The "baseline" + query/key/value tensors are passed to an ideal reference cross-attention + implementation to generate a "baseline" ideal output tensor. This tensor is + split into the "Prefill" ideal output tensor (all but the last element of + each output sequence) and the "Decode" ideal output tensor (*only* the last + element of each output sequence); the "Prefill" and "Decode" ideal output + tensors can be used to validate the prefill and decode test results, + respectively. This function also constructs the cross-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts - at block_base_addr. + (slot mapping and block table), ensuring that the block table starts at + block_base_addr. Arguments: - * query: pre-existing "baseline" query; batch_size x padded_seq_len x num_heads x head_size + + * query: pre-existing "baseline" query; batch_size x padded_seq_len x + num_heads x head_size * q_prompt_lens: list of token-counts for each "baseline" query sequence - * prefill_q_prompt_lens: list of token-counts for each "prefill" query sequence + * prefill_q_prompt_lens: list of token-counts for each "prefill" query + sequence * batch_size * num_heads: Number of attention heads * head_size: Head dimension * block_size: Number of offsets per KV cache block * scale: attention scale parameter * max_q_prompt_len: upper limit on query length for synthetic test vectors - * max_kv_prompt_len: upper limit on key/value length for synthetic test vectors + * max_kv_prompt_len: upper limit on key/value length for synthetic test + vectors * block_base_addr: cross-attention block table base address Returns: + * packed_key: cross-attention key; number_of_tokens x num_heads x head_size - * packed_value: cross-attention value; number_of_tokens x num_heads x head_size - * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x num_heads x head_size - * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x num_heads x head_size + * packed_value: cross-attention value; number_of_tokens x num_heads x + head_size + * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x + num_heads x head_size + * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x + num_heads x head_size * kv_prompt_lens: list of token-counts for each key/value * decode_block_tables: fake decode-phase block tables * decode_slot_mapping: fake decode-phase slot mapping @@ -988,32 +1055,35 @@ def test_prefill_decode_self_and_cross_attention( max_kv_prompt_len: int) -> None: ''' Test: + * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention attributes + * Construct attention metadata structure with self- and cross-attention + attributes * Test self- and cross-attention in the following order + * Prefill self-attention * Prefill cross-attention * Decode self-attention * Decode cross-attention - * This order would exacerbate any accidental overlap in the self-/cross-attention block tables, - which we attempt to avoid - * Validate output correctness against ideal reference attention implementation - - Block tables are constructed such that cross-attention KV cache is in a higher, non-intersecting - address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V tensors. Self-attention - K/Vs must have the same seq len as Q while cross-attention K/Vs are allowed to differ in seq - len, as is often the case for cross-attention. + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. ''' # Num KV cache blocks num_blocks = 4096 - # Attention scale factor, - # attention backend instance, - # attention wrapper instance, - # KV cache init + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init scale, \ attn_backend, \ attn, \ From 014a751da86ca922049a4229e753b63d0c5ad75e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 01:04:25 -0400 Subject: [PATCH 028/443] formatting fixes --- tests/layer/test_self_and_cross_attn.py | 67 +++++++++++++++++++------ vllm/attention/backends/xformers.py | 10 ++-- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 4fe9862de99a2..d1d0d0def15e9 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -307,12 +307,10 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): for bdx, (prompt_len, start_loc) in enumerate(zip(prompt_lens, start_loc_list)): - try: - packed_tensor[start_loc:( - start_loc + - prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] - except: - assert False, f"{start_loc} ; {prompt_len} ; {packed_tensor.shape} ; {unpacked_tensor.shape}" + + packed_tensor[start_loc:( + start_loc + + prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] return packed_tensor, start_loc_list @@ -358,7 +356,11 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): packed_key.shape[-1] * packed_key.shape[-2]) packed_value = packed_value.view( -1, packed_value.shape[-1] * packed_value.shape[-2]) - return packed_query, packed_key, packed_value, q_start_loc_list, kv_start_loc_list + return packed_query, \ + packed_key, \ + packed_value, \ + q_start_loc_list, \ + kv_start_loc_list def make_backend(backend_name: str) -> AttentionBackend: @@ -376,7 +378,8 @@ def make_backend(backend_name: str) -> AttentionBackend: ''' if backend_name == "xformers": return XFormersBackend() - assert False, f"Unrecognized backend_name {backend_name} for unit test" + raise AssertionError( + f"Unrecognized backend_name {backend_name} for unit test") def make_metadata_tensors(is_prompt: bool, @@ -568,7 +571,13 @@ def make_block_tables_slot_mapping(block_size, dtype=torch.long, device=device) - return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx + return decode_block_tables_tensor, \ + decode_slot_mapping_tensor, \ + prefill_slot_mapping_tensor, \ + prefill_block_tables_tensor, \ + slot_mapping_tensor, \ + empty_slot_mapping_tensor, \ + max_block_idx def make_metadata_self_cross( @@ -836,7 +845,12 @@ def self_attn_setup(batch_size, prefill_q_prompt_lens, \ prefill_kv_prompt_lens, \ decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) + decode_kv_prompt_lens = make_qkv(batch_size, + max_q_prompt_len, + max_kv_prompt_len, + num_heads, + head_size, + is_cross_attn=False) causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) @@ -862,14 +876,26 @@ def self_attn_setup(batch_size, decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)]) - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _, max_block_idx = make_block_tables_slot_mapping( + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + _, \ + _, \ + max_block_idx = make_block_tables_slot_mapping( block_size, q_prompt_lens, block_base_addr=block_base_addr) - prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, _, _ = pack_qkv( prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, prefill_kv_prompt_lens) - decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( + decode_packed_query, \ + decode_packed_key, \ + decode_packed_value, \ + _, \ + _ = pack_qkv( decode_query, decode_key, decode_value, decode_q_prompt_lens, decode_kv_prompt_lens) @@ -983,7 +1009,12 @@ def cross_attn_setup_reuses_query(query, _, \ _, \ _, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=True) + _ = make_qkv(batch_size, + max_q_prompt_len, + max_kv_prompt_len, + num_heads, + head_size, + is_cross_attn=True) ideal_output = ref_masked_attention(query, key, @@ -1008,7 +1039,13 @@ def cross_attn_setup_reuses_query(query, # Unlike self-attention: # - Prefill slot-mapping includes all key slots # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( + decode_block_tables, \ + _, \ + _, \ + prefill_block_tables, \ + prefill_slot_mapping, \ + decode_slot_mapping, \ + max_block_idx = make_block_tables_slot_mapping( block_size, kv_prompt_lens, block_base_addr=block_base_addr) # Packed key/value (query is already provided) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0c8db7e47a50d..b540f05c94d7a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -171,9 +171,12 @@ def do_cross_attn(self): def do_cross_attn(self, state: bool): if state: - assert self.has_valid_cross_attn_metadata, "Must have self.cross_seq_lens not None in order to enable cross-attention" + assert self.has_valid_cross_attn_metadata, \ + "Must have self.cross_seq_lens not None " + \ + "in order to enable cross-attention" - # Infer implicit cross-attention fields from user-provided fields, if needed + # Infer implicit cross-attention fields + # from user-provided fields, if needed if self.cross_seq_lens_tensor is None: assert self.seq_lens_tensor is not None self.cross_seq_lens_tensor = torch.tensor( @@ -439,7 +442,8 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. + # not cached. This happens during the initial memory + # profiling run. PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, updated_slot_mapping, From 8dc501af797a8b618e841756f3dfdd0a6037c25e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 01:06:28 -0400 Subject: [PATCH 029/443] isort --- tests/layer/test_self_and_cross_attn.py | 9 ++++----- vllm/attention/backends/xformers.py | 3 ++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index d1d0d0def15e9..6945d15be39c8 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -1,15 +1,14 @@ +import copy +import itertools import random from typing import List, Optional -import itertools import pytest import torch -import copy -from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.xformers import XFormersBackend +from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend - +from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b540f05c94d7a..36f1343e995df 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -4,8 +4,9 @@ import torch from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalMask, +from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalCausalMask, + BlockDiagonalMask, LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, From 3ea10ea4d6a06df3a48def0dfbeac57174a08c78 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 01:22:03 -0400 Subject: [PATCH 030/443] refactor: prompt -> seq where appropriate in test file --- tests/layer/test_self_and_cross_attn.py | 446 ++++++++++++------------ 1 file changed, 223 insertions(+), 223 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 6945d15be39c8..811878a347e97 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -24,26 +24,26 @@ BACKEND_NAMES = ["xformers"] CUDA_DEVICE = "cuda:0" -MAX_Q_PROMPT_LENS = [128] -MAX_K_PROMPT_LENS = [128] +MAX_Q_SEQ_LENS = [128] +MAX_K_SEQ_LENS = [128] -def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): +def build_causal_mask(q_max_seq_len, kv_max_seq_len): ''' - Create a q_max_prompt_len x kv_max_prompt_len causal mask + Create a q_max_seq_len x kv_max_seq_len causal mask Arguments: - * q_max_prompt_len: query max prompt len - * kv_max_prompt_len: key/value max prompt len + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len Returns: - * 2D tensor, q_max_prompt_len x kv_max_prompt_len + * 2D tensor, q_max_seq_len x kv_max_seq_len ''' # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 mask = mask.masked_fill(mask == 1, @@ -57,12 +57,12 @@ def ref_masked_attention( value: torch.Tensor, scale: float, custom_mask: Optional[torch.Tensor] = None, - q_prompt_lens: Optional[List] = None, - kv_prompt_lens: Optional[List] = None) -> torch.Tensor: + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: - * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out + * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out padding elements * Custom attention mask, which can force an arbitrary mask tensor, i.e. causal @@ -75,8 +75,8 @@ def ref_masked_attention( * scale: Attention scale factor * Custom mask: custom attention mask; good place to inject a causal attention mask - * q_prompt_lens: list of unpadded query seq_lens for each batch index - * kv_prompt_lens: list of unpadded key/value seq_lens for each batch index + * q_seq_lens: list of unpadded query seq_lens for each batch index + * kv_seq_lens: list of unpadded key/value seq_lens for each batch index Returns: @@ -84,19 +84,19 @@ def ref_masked_attention( ''' batch_size = query.shape[0] - assert (len(q_prompt_lens) == batch_size) - assert (len(kv_prompt_lens) == batch_size) + assert (len(q_seq_lens) == batch_size) + assert (len(kv_seq_lens) == batch_size) attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() - # Basic attention mask, derived from prompt lens - if (q_prompt_lens is not None) or (kv_prompt_lens is not None): + # Basic attention mask, derived from seq lens + if (q_seq_lens is not None) or (kv_seq_lens is not None): attn_mask = torch.zeros_like(attn_weights) - if q_prompt_lens is not None: - for bdx, plen in enumerate(q_prompt_lens): + if q_seq_lens is not None: + for bdx, plen in enumerate(q_seq_lens): attn_mask[bdx, :, plen:, :] = -torch.inf - if kv_prompt_lens is not None: - for bdx, plen in enumerate(kv_prompt_lens): + if kv_seq_lens is not None: + for bdx, plen in enumerate(kv_seq_lens): attn_mask[bdx, :, :, plen:] = -torch.inf attn_weights = attn_weights + attn_mask.float() @@ -111,8 +111,8 @@ def ref_masked_attention( def make_qkv(batch_size, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, num_heads, head_size, is_cross_attn=True, @@ -135,79 +135,79 @@ def make_qkv(batch_size, Arguments: * batch_size - * max_q_prompt_len: max query prompt len - * max_kv_prompt_len: max key/value prompt len + * max_q_seq_len: max query seq len + * max_kv_seq_len: max key/value seq len * num_heads * head_size * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match - at each batch index (max_kv_prompt_len is unused) - * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query - seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens - and max_kv_prompt_len, unless forced by is_cross_attn=False + at each batch index (max_kv_seq_len is unused) + * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query + seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens + and max_kv_seq_len, unless forced by is_cross_attn=False * device: CPU or CUDA device Returns: - * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x + * query: "baseline" query; batch_size x max_q_seq_len x num_heads x head_size - * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x + * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x head_size - * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x + * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x head_size - * prefill_query: batch_size x (max_q_prompt_len-1) x num_heads x head_size - * prefill_key: batch_size x (max_kv_prompt_len-1) x num_heads x head_size - * prefill_value: batch_size x (max_kv_prompt_len-1) x num_heads x head_size + * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size + * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size + * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size * decode_query: batch_size x 1 x num_heads x head_size * decode_key: batch_size x 1 x num_heads x head_size * decode_value: batch_size x 1 x num_heads x head_size - * q_prompt_lens: "baseline" query seqlen list - * kv_prompt_lens: "baseline" key/value seqlen list - * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= - max_q_prompt_len due to randomness) - * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may - be <= max_kv_prompt_len due to randomness) - * prefill_q_prompt_lens: "prefill" query seqlen list - * prefill_kv_prompt_lens: "prefill" key/value seqlen list - * decode_q_prompt_lens: "decode" query seqlen list (all ones) - * decode_kv_prompt_lens: "decode" key/value seqlen list + * q_seq_lens: "baseline" query seqlen list + * kv_seq_lens: "baseline" key/value seqlen list + * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= + max_q_seq_len due to randomness) + * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may + be <= max_kv_seq_len due to randomness) + * prefill_q_seq_lens: "prefill" query seqlen list + * prefill_kv_seq_lens: "prefill" key/value seqlen list + * decode_q_seq_lens: "decode" query seqlen list (all ones) + * decode_kv_seq_lens: "decode" key/value seqlen list ''' if force_max_len: - q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] + q_seq_lens = [max_q_seq_len for _ in range(batch_size)] else: - q_prompt_lens = [ - random.randint(2, max_q_prompt_len) for _ in range(batch_size) + q_seq_lens = [ + random.randint(2, max_q_seq_len) for _ in range(batch_size) ] - kv_prompt_lens = None + kv_seq_lens = None if not is_cross_attn: - # K,V prompt lens match Q for self-attention - kv_prompt_lens = q_prompt_lens + # K,V seq lens match Q for self-attention + kv_seq_lens = q_seq_lens else: - # K,V prompt lens are distinct from Q prompt lens & random + # K,V seq lens are distinct from Q seq lens & random if force_max_len: - kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] + kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] else: - kv_prompt_lens = [ - random.randint(2, max_kv_prompt_len) for _ in range(batch_size) + kv_seq_lens = [ + random.randint(2, max_kv_seq_len) for _ in range(batch_size) ] - actual_max_q_prompt_len = max(q_prompt_lens) - actual_max_kv_prompt_len = max(kv_prompt_lens) + actual_max_q_seq_len = max(q_seq_lens) + actual_max_kv_seq_len = max(kv_seq_lens) query = torch.rand( - (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads * head_size)).to(device) key = torch.rand( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) value = torch.rand( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) prefill_query = torch.zeros( - (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads * head_size)).to(device) prefill_key = torch.zeros( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) prefill_value = torch.zeros( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) decode_query = torch.zeros( (batch_size, 1, num_heads * head_size)).to(device) @@ -215,32 +215,32 @@ def make_qkv(batch_size, decode_value = torch.zeros( (batch_size, 1, num_heads * head_size)).to(device) - for bdx, (q_prompt_len, - kv_prompt_len) in enumerate(zip(q_prompt_lens, kv_prompt_lens)): - query[bdx, q_prompt_len:, :] = 0 - key[bdx, kv_prompt_len:, :] = 0 - value[bdx, kv_prompt_len:, :] = 0 + for bdx, (q_seq_len, + kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): + query[bdx, q_seq_len:, :] = 0 + key[bdx, kv_seq_len:, :] = 0 + value[bdx, kv_seq_len:, :] = 0 prefill_query[bdx, - 0:(q_prompt_len - 1), :] = query[bdx, - 0:(q_prompt_len - 1), :] + 0:(q_seq_len - 1), :] = query[bdx, + 0:(q_seq_len - 1), :] prefill_key[bdx, - 0:(kv_prompt_len - 1), :] = key[bdx, - 0:(kv_prompt_len - 1), :] - prefill_value[bdx, 0:(kv_prompt_len - - 1), :] = value[bdx, 0:(kv_prompt_len - 1), :] + 0:(kv_seq_len - 1), :] = key[bdx, + 0:(kv_seq_len - 1), :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :] = value[bdx, 0:(kv_seq_len - 1), :] decode_query[bdx, :, :] = query[bdx, - (q_prompt_len - 1):q_prompt_len, :] - decode_key[bdx, :, :] = key[bdx, (kv_prompt_len - 1):kv_prompt_len, :] + (q_seq_len - 1):q_seq_len, :] + decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :] decode_value[bdx, :, :] = value[bdx, - (kv_prompt_len - 1):kv_prompt_len, :] + (kv_seq_len - 1):kv_seq_len, :] - prefill_q_prompt_lens = [plen - 1 for plen in q_prompt_lens] - prefill_kv_prompt_lens = [plen - 1 for plen in kv_prompt_lens] + prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] + prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] - decode_q_prompt_lens = [1 for _ in q_prompt_lens] - decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] + decode_q_seq_lens = [1 for _ in q_seq_lens] + decode_kv_seq_lens = [1 for _ in kv_seq_lens] query = query.view(batch_size, query.shape[1], num_heads, head_size) key = key.view(batch_size, key.shape[1], num_heads, head_size) @@ -269,68 +269,68 @@ def make_qkv(batch_size, decode_query, \ decode_key, \ decode_value, \ - q_prompt_lens, \ - kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens + q_seq_lens, \ + kv_seq_lens, \ + actual_max_q_seq_len, \ + actual_max_kv_seq_len, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens -def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): +def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where - number_of_tokens = sum(prompt_lens) + number_of_tokens = sum(seq_lens) Arguments: * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size - * prompt_lens: list of token counts for each prompt + * seq_lens: list of token counts for each seq * device: CPU or CUDA device Returns * packed_tensor: number_of_tokens x num_heads x head_size * start_loc_list: start idx of each batch elt in packed_tensor; [0] + - list(itertools.accumulate(prompt_lens)) + list(itertools.accumulate(seq_lens)) ''' - num_tok = sum(prompt_lens) + num_tok = sum(seq_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] - start_loc_list = [0] + list(itertools.accumulate(prompt_lens)) + start_loc_list = [0] + list(itertools.accumulate(seq_lens)) packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - for bdx, (prompt_len, - start_loc) in enumerate(zip(prompt_lens, start_loc_list)): + for bdx, (seq_len, + start_loc) in enumerate(zip(seq_lens, start_loc_list)): packed_tensor[start_loc:( start_loc + - prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] return packed_tensor, start_loc_list -def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): +def pack_qkv(query, key, value, q_seq_lens, kv_seq_lens): ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x num_heads x head_size tensors. - For Q, number_of_tokens = sum(q_prompt_lens). + For Q, number_of_tokens = sum(q_seq_lens). - For K and V, number_of_tokens = sum(kv_prompt_lens) + For K and V, number_of_tokens = sum(kv_seq_lens) Arguments: * query: batch_size x padded_seq_len x num_heads x head_size * key: batch_size x padded_seq_len x num_heads x head_size * value: batch_size x padded_seq_len x num_heads x head_size - * q_prompt_lens: list of token counts for each query - * kv_prompt_lens: list of token counts for each key/value + * q_seq_lens: list of token counts for each query + * kv_seq_lens: list of token counts for each key/value Returns @@ -345,9 +345,9 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) - packed_key, kv_start_loc_list = pack_tensor(key, kv_prompt_lens) - packed_value, _ = pack_tensor(value, kv_prompt_lens) + packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens) + packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens) + packed_value, _ = pack_tensor(value, kv_seq_lens) if packed_query is not None: packed_query = packed_query.view( -1, packed_query.shape[-1] * packed_query.shape[-2]) @@ -382,7 +382,7 @@ def make_backend(backend_name: str) -> AttentionBackend: def make_metadata_tensors(is_prompt: bool, - prompt_lens: List[int], + seq_lens: List[int], context_lens: List[int], device=CUDA_DEVICE) -> tuple: ''' @@ -391,33 +391,33 @@ def make_metadata_tensors(is_prompt: bool, Arguments: * is_prompt: True -> Prefill, False -> Decode - * prompt_lens: list of token-counts for each prompt - * context_lens: list of context length values for each prompt + * seq_lens: list of token-counts for each seq + * context_lens: list of context length values for each seq * device: CPU or CUDA device Returns: - * prompt_lens_tensor: prompt_lens list, as tensor + * seq_lens_tensor: seq_lens list, as tensor * context_lens_tensor: context_lens list, as tensor - * max_query_len: max(prompt_lens) if is_prompt, o/w 1 + * max_query_len: max(seq_lens) if is_seq, o/w 1 * max_context_len: max(context_lens) - * max_prompt_len: max(prompt_lens) + * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - prompt_lens_tensor = torch.tensor(prompt_lens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) - max_prompt_len = None if prompt_lens is None else max(prompt_lens) + max_seq_len = None if seq_lens is None else max(seq_lens) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) @@ -425,18 +425,18 @@ def make_metadata_tensors(is_prompt: bool, if is_prompt: # Prefill: query_start_loc matches seq_start_loc query_start_loc = copy.deepcopy(seq_start_loc) - max_query_len = max_prompt_len + max_query_len = max_seq_len else: # Decode: one new query input token per batch element, thus # query_start_loc is the cumsum of [1,1,1,...] query_start_loc = list(range(len(seq_start_loc))) max_query_len = 1 - return prompt_lens_tensor, \ + return seq_lens_tensor, \ context_lens_tensor, \ max_query_len, \ max_context_len, \ - max_prompt_len, \ + max_seq_len, \ seq_start_loc, \ query_start_loc @@ -480,7 +480,7 @@ def num_tokens_to_min_blocks(num_tokens, block_size): def make_block_tables_slot_mapping(block_size, - prompt_lens, + seq_lens, block_base_addr=0, device=CUDA_DEVICE): ''' @@ -495,7 +495,7 @@ def make_block_tables_slot_mapping(block_size, Arguments: * block_size: number of offsets per block - * prompt_lens: list of token-counts for each sequence + * seq_lens: list of token-counts for each sequence * block_base_addr: the block table base address * device: CPU or CUDA device @@ -518,7 +518,7 @@ def make_block_tables_slot_mapping(block_size, # Over-provision block table blocks by 1 num_blocks_list = [ num_tokens_to_min_blocks(num_tokens, block_size) + 1 - for num_tokens in prompt_lens + for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -530,7 +530,7 @@ def make_block_tables_slot_mapping(block_size, block_base_idx = block_base_addr + sum( num_blocks_list) * 2 - 1 # Support more blocks than needed max_block_idx = block_base_idx - for sdx, num_tokens in enumerate(prompt_lens): + for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] block_table = list( range(block_base_idx, block_base_idx - num_blocks, -1)) @@ -582,7 +582,7 @@ def make_block_tables_slot_mapping(block_size, def make_metadata_self_cross( attn_backend: AttentionBackend, is_prompt: bool, - prompt_lens: List[int], + seq_lens: List[int], context_lens: List[int], block_tables, slot_mapping, @@ -603,7 +603,7 @@ def make_metadata_self_cross( * attn_backend: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode - * prompt_lens: list of token counts for each sequence + * seq_lens: list of token counts for each sequence * context_lens: list of context lengths for each sequence * block_tables: self-attention block tables * slot_mapping: self-attention slot_mapping @@ -619,18 +619,18 @@ def make_metadata_self_cross( ''' if is_prompt: - num_prefills = len(prompt_lens) - num_prefill_tokens = sum(prompt_lens) + num_prefills = len(seq_lens) + num_prefill_tokens = sum(seq_lens) num_decode_tokens = 0 - prompt_lens_tensor, \ + seq_lens_tensor, \ context_lens_tensor, \ max_query_len, \ _, \ _, \ seq_start_loc, \ query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, + seq_lens, context_lens, device=device) @@ -643,10 +643,10 @@ def make_metadata_self_cross( slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, - max_prefill_seq_len=max(prompt_lens), + max_prefill_seq_len=max(seq_lens), max_decode_seq_len=0, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, @@ -662,16 +662,16 @@ def make_metadata_self_cross( num_prefills = 0 num_prefill_tokens = 0 - num_decode_tokens = len(prompt_lens) + num_decode_tokens = len(seq_lens) - prompt_lens_tensor, \ + seq_lens_tensor, \ context_lens_tensor, \ max_query_len, \ _, \ _, \ seq_start_loc, \ query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, + seq_lens, context_lens, device=device) @@ -684,11 +684,11 @@ def make_metadata_self_cross( slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_prefill_seq_len=0, - max_decode_seq_len=max(prompt_lens), + max_decode_seq_len=max(seq_lens), query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -747,7 +747,7 @@ def self_attn_setup(batch_size, head_size, block_size, scale, - max_q_prompt_len, + max_q_seq_len, block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -784,7 +784,7 @@ def self_attn_setup(batch_size, * head_size: Head dimension * block_size: Number of offsets per KV cache block * scale: attention scale parameter - * max_q_prompt_len: upper limit on query length for synthetic test vectors + * max_q_seq_len: upper limit on query length for synthetic test vectors * block_base_addr: self-attention block table base address Returns: @@ -799,10 +799,10 @@ def self_attn_setup(batch_size, num_heads x head_size * prefill_packed_ideal_output: self-attn "prefill" ideal output; number_of_tokens x num_heads x head_size - * prefill_q_prompt_lens: list of token counts for each *prefill query* (one + * prefill_q_seq_lens: list of token counts for each *prefill query* (one less than baseline query) - * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill - key/value* (should match prefill_q_prompt_lens) + * prefill_kv_seq_lens: list of token counts for each self-attn *prefill + key/value* (should match prefill_q_seq_lens) * decode_packed_query: "decode" query; number_of_tokens x num_heads x head_size * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x @@ -811,13 +811,13 @@ def self_attn_setup(batch_size, x head_size * decode_packed_ideal_output: self-attn "decode" ideal output; number_of_tokens x num_heads x head_size - * decode_q_prompt_lens: list of token counts for each *decode query* (should + * decode_q_seq_lens: list of token counts for each *decode query* (should be 1) - * decode_kv_prompt_lens: list of token counts for each self-attn *decode - key/value* (should match decode_q_prompt_lens) - * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x + * decode_kv_seq_lens: list of token counts for each self-attn *decode + key/value* (should match decode_q_seq_lens) + * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x head_size - * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens + * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens x num_heads x head_size * decode_block_tables: fake self-attn decode-phase block table * decode_slot_mapping: fake self-attn decode-phase slot mapping @@ -826,7 +826,7 @@ def self_attn_setup(batch_size, * max_block_idx: highest block address in the self-attention block-table ''' - max_kv_prompt_len = max_q_prompt_len + max_kv_seq_len = max_q_seq_len query, \ key, \ @@ -837,41 +837,41 @@ def self_attn_setup(batch_size, decode_query, \ decode_key, \ decode_value, \ - q_prompt_lens, \ - kv_prompt_lens, \ + q_seq_lens, \ + kv_seq_lens, \ _, \ _, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size, - max_q_prompt_len, - max_kv_prompt_len, - num_heads, - head_size, - is_cross_attn=False) - - causal_mask = build_causal_mask(max_q_prompt_len, - max_kv_prompt_len).to(CUDA_DEVICE) + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + is_cross_attn=False) + + causal_mask = build_causal_mask(max_q_seq_len, + max_kv_seq_len).to(CUDA_DEVICE) ideal_output = ref_masked_attention(query, key, value, scale=scale, custom_mask=causal_mask, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) + q_seq_lens=q_seq_lens, + kv_seq_lens=kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( + prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) + prefill_q_seq_lens) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)]) @@ -882,37 +882,37 @@ def self_attn_setup(batch_size, _, \ _, \ max_block_idx = make_block_tables_slot_mapping( - block_size, q_prompt_lens, block_base_addr=block_base_addr) + block_size, q_seq_lens, block_base_addr=block_base_addr) prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, - prefill_kv_prompt_lens) + prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, + prefill_kv_seq_lens) decode_packed_query, \ decode_packed_key, \ decode_packed_value, \ _, \ _ = pack_qkv( - decode_query, decode_key, decode_value, decode_q_prompt_lens, - decode_kv_prompt_lens) + decode_query, decode_key, decode_value, decode_q_seq_lens, + decode_kv_seq_lens) return query, \ prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, \ prefill_packed_ideal_output, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ decode_packed_query, \ decode_packed_key, \ decode_packed_value, \ decode_packed_ideal_output, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens, \ - q_prompt_lens, \ - kv_prompt_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens, \ + q_seq_lens, \ + kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ @@ -921,15 +921,15 @@ def self_attn_setup(batch_size, def cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, + q_seq_lens, + prefill_q_seq_lens, batch_size, num_heads, head_size, block_size, scale, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, block_base_addr=0): ''' Set up test vectors & data structures for cross-attention test. @@ -962,16 +962,16 @@ def cross_attn_setup_reuses_query(query, * query: pre-existing "baseline" query; batch_size x padded_seq_len x num_heads x head_size - * q_prompt_lens: list of token-counts for each "baseline" query sequence - * prefill_q_prompt_lens: list of token-counts for each "prefill" query + * q_seq_lens: list of token-counts for each "baseline" query sequence + * prefill_q_seq_lens: list of token-counts for each "prefill" query sequence * batch_size * num_heads: Number of attention heads * head_size: Head dimension * block_size: Number of offsets per KV cache block * scale: attention scale parameter - * max_q_prompt_len: upper limit on query length for synthetic test vectors - * max_kv_prompt_len: upper limit on key/value length for synthetic test + * max_q_seq_len: upper limit on query length for synthetic test vectors + * max_kv_seq_len: upper limit on key/value length for synthetic test vectors * block_base_addr: cross-attention block table base address @@ -984,7 +984,7 @@ def cross_attn_setup_reuses_query(query, num_heads x head_size * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x num_heads x head_size - * kv_prompt_lens: list of token-counts for each key/value + * kv_seq_lens: list of token-counts for each key/value * decode_block_tables: fake decode-phase block tables * decode_slot_mapping: fake decode-phase slot mapping * prefill_slot_mapping: fake prefill-phase slot mapping @@ -1002,15 +1002,15 @@ def cross_attn_setup_reuses_query(query, _, \ _, \ _, \ - kv_prompt_lens, \ + kv_seq_lens, \ _, \ _, \ _, \ _, \ _, \ _ = make_qkv(batch_size, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, num_heads, head_size, is_cross_attn=True) @@ -1019,19 +1019,19 @@ def cross_attn_setup_reuses_query(query, key, value, scale=scale, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) + q_seq_lens=q_seq_lens, + kv_seq_lens=kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( + prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) + prefill_q_seq_lens) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)]) @@ -1045,17 +1045,17 @@ def cross_attn_setup_reuses_query(query, prefill_slot_mapping, \ decode_slot_mapping, \ max_block_idx = make_block_tables_slot_mapping( - block_size, kv_prompt_lens, block_base_addr=block_base_addr) + block_size, kv_seq_lens, block_base_addr=block_base_addr) # Packed key/value (query is already provided) _, packed_key, packed_value, _, _ = pack_qkv(None, key, value, None, - kv_prompt_lens) + kv_seq_lens) return packed_key, \ packed_value, \ prefill_packed_ideal_output, \ decode_packed_ideal_output, \ - kv_prompt_lens, \ + kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ @@ -1083,12 +1083,12 @@ def run_cross_attention_test(attn, packed_query, packed_key, packed_value, @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len", MAX_Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len", MAX_K_PROMPT_LENS) +@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) +@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) def test_prefill_decode_self_and_cross_attention( num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_prompt_len: int, - max_kv_prompt_len: int) -> None: + block_size: int, max_q_seq_len: int, + max_kv_seq_len: int) -> None: ''' Test: @@ -1138,15 +1138,15 @@ def test_prefill_decode_self_and_cross_attention( self_prefill_packed_key, \ self_prefill_packed_value, \ self_prefill_packed_ideal_output, \ - prefill_q_prompt_lens, \ - self_prefill_kv_prompt_lens, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ decode_packed_query, \ self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ _, \ _, \ - q_prompt_lens, \ + q_seq_lens, \ _, \ self_decode_block_tables, \ self_decode_slot_mapping, \ @@ -1157,7 +1157,7 @@ def test_prefill_decode_self_and_cross_attention( head_size, block_size, scale, - max_q_prompt_len, + max_q_seq_len, block_base_addr=self_block_base_addr) # Cross-attention setup @@ -1166,21 +1166,21 @@ def test_prefill_decode_self_and_cross_attention( cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - cross_kv_prompt_lens, \ + cross_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ _ = cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, + q_seq_lens, + prefill_q_seq_lens, batch_size, num_heads, head_size, block_size, scale, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, block_base_addr=cross_block_base_addr) # PREFILL: self- and cross-attention tests @@ -1190,11 +1190,11 @@ def test_prefill_decode_self_and_cross_attention( prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, True, - prefill_q_prompt_lens, + prefill_q_seq_lens, context_lens, self_prefill_block_tables, self_prefill_slot_mapping, - cross_seq_lens=cross_kv_prompt_lens, + cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, ) @@ -1219,18 +1219,18 @@ def test_prefill_decode_self_and_cross_attention( cross_prefill_packed_actual_output.view_as( cross_prefill_packed_ideal_output)) - context_lens = copy.deepcopy(self_prefill_kv_prompt_lens) + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) # DECODE: self- and cross-attention tests decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, False, - q_prompt_lens, + q_seq_lens, context_lens, self_decode_block_tables, self_decode_slot_mapping, - cross_seq_lens=cross_kv_prompt_lens, + cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, ) From 2ced012a3e51a77abbbab2268d88730fdffa4a3f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 06:19:19 -0400 Subject: [PATCH 031/443] fix wording nits (ben->been, decoder->encoder/decoder) --- tests/core/test_block_manager.py | 2 +- vllm/core/block_manager_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index d6ab246699903..81e3444815d4e 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -314,7 +314,7 @@ def test_swap_encoder_decoder(): assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks decoder_prompt.status = SequenceStatus.SWAPPED - # Swap decoder seq group from CPU -> GPU. + # Swap encoder/decoder seq group from CPU -> GPU. decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 4ae3361e7b234..978acd915b69b 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -222,7 +222,7 @@ def free(self, seq: Sequence) -> None: def free_cross(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id if request_id not in self.cross_block_tables: - # Already freed or hasn't ben scheduled yet. + # Already freed or hasn't been scheduled yet. return self.cross_block_tables[request_id].free() del self.cross_block_tables[request_id] From 8286b4cfbe57001767617a9ee33066945f6baa3d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 13:46:58 -0400 Subject: [PATCH 032/443] changed two block manager tests to construct fake prompts that are equal in length to the bock size, rather than half the block size (which had been the case --- tests/core/test_block_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 81e3444815d4e..9dc1c88819b70 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -116,7 +116,7 @@ def test_allocate_encoder_decoder(): watermark=1 / num_gpu_blocks) for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), block_size // 2, block_size // 2) + str(i), block_size, block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -365,8 +365,8 @@ def test_free_encoder_decoder(): decoder_prompt, encoder_prompt, seq_group = \ create_dummy_prompt_encoder_decoder( "1", - decoder_prompt_length=block_size // 2, - encoder_prompt_length=block_size // 2) + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) block_manager.allocate(seq_group) # Free allocated seq. From eba551cd7e1d53911cb392d773eec05cfe40cc4f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 13:50:04 -0400 Subject: [PATCH 033/443] keyword args for dummy prompt construction in block manager encoder/decoder tests --- tests/core/test_block_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 9dc1c88819b70..19dfc09dbb001 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -103,7 +103,9 @@ def test_allocate_encoder_decoder(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), block_size, block_size) + str(i), + decoder_prompt_length=block_size, + decoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -116,7 +118,9 @@ def test_allocate_encoder_decoder(): watermark=1 / num_gpu_blocks) for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), block_size, block_size) + str(i), + decoder_prompt_length=block_size, + decoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK From a7c8b192cd7c6e6c815caf5acbbd4ed24b16925d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:00:05 -0400 Subject: [PATCH 034/443] bugfix - decoder prompt kwarg repeated in lieu of encoder prompt kwarg --- tests/core/test_block_manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 19dfc09dbb001..29956ff028143 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -73,7 +73,7 @@ def test_allocate(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -85,7 +85,7 @@ def test_allocate(): watermark=1 / num_gpu_blocks) for i in range(num_gpu_blocks - 1): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -105,8 +105,8 @@ def test_allocate_encoder_decoder(): _, _, seq_group = create_dummy_prompt_encoder_decoder( str(i), decoder_prompt_length=block_size, - decoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -120,8 +120,8 @@ def test_allocate_encoder_decoder(): _, _, seq_group = create_dummy_prompt_encoder_decoder( str(i), decoder_prompt_length=block_size, - decoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK From 9feb994966e365fac63bbec526cafb24cf00dcde Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:09:42 -0400 Subject: [PATCH 035/443] In block manager test which used with block to detect error - created a second with block for encoder-related call that previously shared a with block with the corresponding decoder-related call --- tests/core/test_block_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 29956ff028143..808b0a5e651eb 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -386,6 +386,9 @@ def test_free_encoder_decoder(): # Block table for freed encoder & decoder seq's are deleted. with pytest.raises(KeyError): block_manager.get_block_table(decoder_prompt) + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): block_manager.get_block_table(encoder_prompt) From 5eb0032bfaaf5bc43fab66f1fc8bea30045915b7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:17:50 -0400 Subject: [PATCH 036/443] refactoring block manager v1/v2 swap in/swap out functions --- vllm/core/block_manager_v1.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 03eba2e80c78d..119e444df1b11 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -570,12 +570,7 @@ def swap_in(self, self.cpu_allocator.free(cpu_block) self.cross_block_tables[request_id] = new_block_table - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) @@ -621,12 +616,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator.free(gpu_block) self.cross_block_tables[request_id] = new_block_table - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up From 0644cde2aced6d7fb6c279025b2a4a3d8f5625d2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:23:50 -0400 Subject: [PATCH 037/443] formatting; changed blocktable type specifier from Dict to List[int] --- tests/core/test_block_manager.py | 6 +++--- vllm/core/block_manager_v1.py | 6 ++++-- vllm/sequence.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 808b0a5e651eb..cdaf2f22115e8 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -103,7 +103,7 @@ def test_allocate_encoder_decoder(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), + str(i), decoder_prompt_length=block_size, encoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) == AllocStatus.OK @@ -118,8 +118,8 @@ def test_allocate_encoder_decoder(): watermark=1 / num_gpu_blocks) for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), - decoder_prompt_length=block_size, + str(i), + decoder_prompt_length=block_size, encoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 119e444df1b11..2482cf17956f2 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -570,7 +570,8 @@ def swap_in(self, self.cpu_allocator.free(cpu_block) self.cross_block_tables[request_id] = new_block_table - return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) @@ -616,7 +617,8 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator.free(gpu_block) self.cross_block_tables[request_id] = new_block_table - return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up diff --git a/vllm/sequence.py b/vllm/sequence.py index a11c411876ea8..6b07a00f09c6f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -527,8 +527,8 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] - def get_encoder_seq(self) -> Sequence: - return self.encoder_seq # type: ignore + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq def get_unfinished_seqs(self) -> List[Sequence]: return [ @@ -635,7 +635,7 @@ def __init__( state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, encoder_seq_data: Optional[SequenceData] = None, - cross_block_table: Optional[Dict[int, List[int]]] = None, + cross_block_table: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt From 19ed7413e315ce665cc07722d72fb874a362fafd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:39:50 -0400 Subject: [PATCH 038/443] prefixed internal method with _ --- vllm/core/block_manager_v1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 2482cf17956f2..648ff843fd4e5 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -260,7 +260,7 @@ def __init__( # request ID self.cross_block_tables: Dict[str, BlockTable] = {} - def get_seq_num_required_blocks(self, seq: Sequence) -> int: + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: if seq is None: return 0 return len(seq.logical_token_blocks) @@ -269,9 +269,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - self_num_required_blocks = self.get_seq_num_required_blocks( + self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self.get_seq_num_required_blocks( + cross_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_encoder_seq()) num_required_blocks = self_num_required_blocks + \ cross_num_required_blocks From a5579729928c4151e501138f82340c0afa2dc327 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 17:47:19 -0400 Subject: [PATCH 039/443] refactored self-/cross-attention allocation functions into a single helper function --- vllm/core/block_manager_v1.py | 57 ++++++++++++----------------------- 1 file changed, 19 insertions(+), 38 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 648ff843fd4e5..9f08d4a7939aa 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -290,11 +290,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: - # NOTE: Here we assume that all sequences in the group have the same - # decoder prompt. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - + def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -304,7 +300,7 @@ def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() elif self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), @@ -312,47 +308,32 @@ def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() block_table.append(block) - # Assign the decoder block table for each sequence. - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - self.block_tables[seq.seq_id] = block_table.copy() + return block_table - def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: + def allocate(self, seq_group: SequenceGroup) -> None: + # Allocate decoder sequences + # # NOTE: Here we assume that all sequences in the group have the same - # encoder prompt. + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + block_table: BlockTable = self._allocate_sequence(seq, seq_group.num_seqs()) - # Allocate new physical token blocks that will store the prompt tokens. - seq = seq_group.get_encoder_seq() - if seq is not None: - block_table: BlockTable = [] - num_prompt_blocks = len(seq.logical_token_blocks) - for logical_idx in range(num_prompt_blocks): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % - self.block_sliding_window] - # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() - elif self.enable_caching: - block = self.gpu_allocator.allocate( - seq.hash_of_block(logical_idx), - seq.num_hashed_tokens_of_block(logical_idx)) - else: - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - # TODO: feature not supported with encoder/decoder - block.ref_count = seq_group.num_seqs() - block_table.append(block) + # Assign the self-attention block tables for each sequence. + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + self.block_tables[seq.seq_id] = block_table.copy() + # Allocate encoder sequence + encoder_seq = seq_group.get_encoder_seq() + if encoder_seq is not None: + # A SequenceGroup has only a single encoder sequence (at most), + # thus allocate with a ref count of 1 + block_table: BlockTable = self._allocate_sequence(encoder_seq, 1) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table - def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_self_block_tables(seq_group) - self.allocate_cross_block_table(seq_group) - def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: From e48bebf727ae67ffbdff206d168eab3e77b988da Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 17:59:44 -0400 Subject: [PATCH 040/443] Refactored block manager v2 self-/cross-block-table alloc functions together --- vllm/core/block_manager_v2.py | 38 ++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 978acd915b69b..a8085f54ac79d 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -121,7 +121,18 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + ) + assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -129,43 +140,34 @@ def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. seq = waiting_seqs[0] - - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - ) - assert self.block_sliding_window is None - block_table.allocate(seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(seq) self.block_tables[seq.seq_id] = block_table # Assign the block table for each sequence. for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() - def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: + # Allocate cross-attention block table for encoder sequence + # # NOTE: Here we assume that all sequences in the group have the same - # prompt. + # encoder prompt. request_id = seq_group.request_id - seq = seq_group.encoder_seq + encoder_seq = seq_group.encoder_seq assert (request_id not in self.cross_block_tables), \ "block table already exists" - seq = seq_group.get_encoder_seq() - if seq is not None: + encoder_seq = seq_group.get_encoder_seq() + if encoder_seq is not None: block_table = BlockTable( block_size=self.block_size, block_allocator=self.block_allocator, ) assert self.block_sliding_window is None - block_table.allocate(seq.get_token_ids()) + block_table.allocate(encoder_seq.get_token_ids()) self.cross_block_tables[request_id] = block_table - def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_self_block_tables(seq_group) - self.allocate_cross_block_table(seq_group) - def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue From ac2da978c786d998247cfe55a3d2a788109b71e4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 18:11:58 -0400 Subject: [PATCH 041/443] formatting --- vllm/core/block_manager_v1.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 9f08d4a7939aa..fa53b3cd33229 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -319,7 +319,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - block_table: BlockTable = self._allocate_sequence(seq, seq_group.num_seqs()) + block_table: BlockTable = \ + self._allocate_sequence(seq, seq_group.num_seqs()) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -330,7 +331,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: if encoder_seq is not None: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table: BlockTable = self._allocate_sequence(encoder_seq, 1) + block_table = self._allocate_sequence(encoder_seq, 1) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table From e985a2f05080a0e311f52adf119447993322541f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 18:40:07 -0400 Subject: [PATCH 042/443] refactored out block manager v1 swap_n/swap_out helper functions --- vllm/core/block_manager_v1.py | 116 ++++++++++++++++------------------ 1 file changed, 54 insertions(+), 62 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index fa53b3cd33229..dd6d8d702fae0 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -300,7 +300,7 @@ def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() elif self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), @@ -308,7 +308,7 @@ def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() block_table.append(block) return block_table @@ -507,6 +507,26 @@ def can_swap_in(self, else: return AllocStatus.LATER + def _swap_in_block_table( + self, block_table: BlockTable, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + new_block_table = [] + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate( + cpu_block.block_hash, cpu_block.num_hashed_tokens) + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. + self.cpu_allocator.free(cpu_block) + + return new_block_table + def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: @@ -519,38 +539,14 @@ def swap_in(self, # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.block_tables[seq.seq_id] = new_block_table + self.block_tables[seq.seq_id] = \ + self._swap_in_block_table(self.block_tables[seq.seq_id], + mapping) if seq_group.encoder_seq is not None: - new_block_table = [] - block_table = self.cross_block_tables[request_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.cross_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = \ + self._swap_in_block_table(self.cross_block_tables[request_id], + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] @@ -559,6 +555,26 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + def _swap_out_block_table( + self, block_table: BlockTable, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + + new_block_table: BlockTable = [] + for gpu_block in block_table: + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate( + gpu_block.block_hash, gpu_block.num_hashed_tokens) + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + self.gpu_allocator.free(gpu_block) + + return new_block_table + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id @@ -566,38 +582,14 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.block_tables[seq.seq_id] = new_block_table + self.block_tables[seq.seq_id] = \ + self._swap_out_block_table(self.block_tables[seq.seq_id], + mapping) if seq_group.encoder_seq is not None: - new_block_table = [] - block_table = self.cross_block_tables[request_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.cross_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = \ + self._swap_out_block_table(self.cross_block_tables[request_id], + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] From 98c5863ef946dbd52221b6b83517e483f48b3848 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 18:53:40 -0400 Subject: [PATCH 043/443] Help function avoids prefix caching code in encoder/decoder scenarios; alloc method asserts no prefix caching + enc/dec; refactoring --- vllm/core/block_manager_v1.py | 36 +++++++++++++++++------------------ vllm/core/block_manager_v2.py | 16 ---------------- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index dd6d8d702fae0..40274bd29e9b0 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -290,7 +290,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: + def _allocate_sequence(self, \ + seq: Sequence, \ + ref_count: int, \ + decoder_only: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -300,27 +303,36 @@ def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() - elif self.enable_caching: + block.ref_count = ref_count + elif decoder_only and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() + block.ref_count = ref_count block_table.append(block) return block_table def allocate(self, seq_group: SequenceGroup) -> None: + decoder_only = \ + seq_group.get_encoder_seq() is None + + assert decoder_only or (not self.enable_caching), \ + "Automatic prefix caching currently not " + \ + "supported for encoder/decoder models." + # Allocate decoder sequences # # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] block_table: BlockTable = \ - self._allocate_sequence(seq, seq_group.num_seqs()) + self._allocate_sequence(seq, + seq_group.num_seqs(), + decoder_only) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -331,7 +343,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: if encoder_seq is not None: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(encoder_seq, 1) + block_table = self._allocate_sequence(encoder_seq, 1, decoder_only) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table @@ -661,18 +673,6 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time - def access_all_cross_blocks_in_seq_group( - self, - seq_group: SequenceGroup, - access_time: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.cross_block_tables[seq_group.request_id] - for block in block_table: - block.last_accessed = access_time - def compute_full_blocks_in_seq(self, seq: Sequence): if seq.seq_id not in self.block_tables: return diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index a8085f54ac79d..31d1a60657832 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -260,22 +260,6 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids, # type: ignore now) - def access_all_cross_blocks_in_seq_group( - self, - seq_group: SequenceGroup, - now: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.cross_block_tables[seq_group.request_id] - block_ids = [] - for block_id in block_table.physical_block_ids: - block_ids.append(block_id) - self.block_allocator.mark_blocks_as_accessed( - block_ids, # type: ignore - now) - def mark_blocks_as_computed(self, seq_group: SequenceGroup): # The only need for mark block as computed is for prefix caching, # while currently we could determine whether one block is computed From 611bcec382ae4291f3a963d9eaa75889fe897251 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 11:17:48 -0400 Subject: [PATCH 044/443] fixed bugs introduced by merge --- tests/layer/test_self_and_cross_attn.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 811878a347e97..6cc365faa4ea5 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -1063,19 +1063,18 @@ def cross_attn_setup_reuses_query(query, max_block_idx -def run_self_attention_test(attn, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata, scale): +def run_self_attention_test(attn: Attention, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = False return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata, scale) + attn_metadata) -def run_cross_attention_test(attn, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata, - scale): +def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = True return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata, scale) + attn_metadata) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1201,7 +1200,7 @@ def test_prefill_decode_self_and_cross_attention( self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, - self_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) + self_prefill_packed_value, kv_cache, prefill_attn_metadata) # - Prefill self-attention correct? assert torch.allclose( @@ -1211,7 +1210,7 @@ def test_prefill_decode_self_and_cross_attention( cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( @@ -1237,7 +1236,7 @@ def test_prefill_decode_self_and_cross_attention( self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( attn, decode_packed_query, self_decode_packed_key, - self_decode_packed_value, kv_cache, decode_attn_metadata, scale) + self_decode_packed_value, kv_cache, decode_attn_metadata) # - Decode self-attention correct? assert torch.allclose( @@ -1246,8 +1245,7 @@ def test_prefill_decode_self_and_cross_attention( self_decode_packed_ideal_output)) cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata, - scale) + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( From ed2f56deee922614f296f591afb57321387d6112 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 11:20:22 -0400 Subject: [PATCH 045/443] moved enc/dec test into tests/kernels so that it will be run automatically using existing buildkite config --- tests/{layer => kernels}/test_self_and_cross_attn.py | 0 tests/layer/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/{layer => kernels}/test_self_and_cross_attn.py (100%) delete mode 100644 tests/layer/__init__.py diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py similarity index 100% rename from tests/layer/test_self_and_cross_attn.py rename to tests/kernels/test_self_and_cross_attn.py diff --git a/tests/layer/__init__.py b/tests/layer/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 From b4ec9c6de46a114986f58b8f83608e8bad1ec755 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 11:21:07 -0400 Subject: [PATCH 046/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 76 +++++++++++------------ 1 file changed, 36 insertions(+), 40 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 6cc365faa4ea5..04dc52d51af19 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -43,22 +43,20 @@ def build_causal_mask(q_max_seq_len, kv_max_seq_len): ''' # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), - diagonal=1) + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) return mask -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[List] = None, - kv_seq_lens: Optional[List] = None) -> torch.Tensor: +def ref_masked_attention(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: @@ -215,26 +213,23 @@ def make_qkv(batch_size, decode_value = torch.zeros( (batch_size, 1, num_heads * head_size)).to(device) - for bdx, (q_seq_len, - kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, + kv_seq_lens)): query[bdx, q_seq_len:, :] = 0 key[bdx, kv_seq_len:, :] = 0 value[bdx, kv_seq_len:, :] = 0 - prefill_query[bdx, - 0:(q_seq_len - 1), :] = query[bdx, - 0:(q_seq_len - 1), :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :] = key[bdx, - 0:(kv_seq_len - 1), :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :] = value[bdx, 0:(kv_seq_len - 1), :] - - decode_query[bdx, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :] + prefill_query[bdx, 0:(q_seq_len - 1), :] = query[bdx, + 0:(q_seq_len - 1), :] + prefill_key[bdx, 0:(kv_seq_len - 1), :] = key[bdx, + 0:(kv_seq_len - 1), :] + prefill_value[bdx, + 0:(kv_seq_len - 1), :] = value[bdx, + 0:(kv_seq_len - 1), :] + + decode_query[bdx, :, :] = query[bdx, (q_seq_len - 1):q_seq_len, :] decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :] - decode_value[bdx, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :] + decode_value[bdx, :, :] = value[bdx, (kv_seq_len - 1):kv_seq_len, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -304,12 +299,10 @@ def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): start_loc_list = [0] + list(itertools.accumulate(seq_lens)) packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - for bdx, (seq_len, - start_loc) in enumerate(zip(seq_lens, start_loc_list)): + for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): packed_tensor[start_loc:( - start_loc + - seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] return packed_tensor, start_loc_list @@ -405,9 +398,7 @@ def make_metadata_tensors(is_prompt: bool, * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) @@ -1063,15 +1054,17 @@ def cross_attn_setup_reuses_query(query, max_block_idx -def run_self_attention_test(attn: Attention, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata): +def run_self_attention_test(attn: Attention, packed_query, packed_key, + packed_value, kv_cache, + attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = False return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) -def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata): +def run_cross_attention_test(attn: Attention, packed_query, packed_key, + packed_value, kv_cache, + attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = True return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) @@ -1084,10 +1077,13 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_v @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_prefill_decode_self_and_cross_attention( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_prefill_decode_self_and_cross_attention(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_q_seq_len: int, + max_kv_seq_len: int) -> None: ''' Test: From 84f5510a0a4e7d0b81b32e772e1cf710be83112b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 13:49:30 -0400 Subject: [PATCH 047/443] block manager v1 NotImplementError's for sliding window and automatic prefix caching --- vllm/core/block_manager_v1.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 40274bd29e9b0..d5da128f1a691 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -277,6 +277,11 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: + if seq_group.get_encoder_seq() is not None: + raise NotImplementedError( + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported.") + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -320,9 +325,16 @@ def allocate(self, seq_group: SequenceGroup) -> None: decoder_only = \ seq_group.get_encoder_seq() is None - assert decoder_only or (not self.enable_caching), \ - "Automatic prefix caching currently not " + \ - "supported for encoder/decoder models." + if (self.block_sliding_window is not None) and \ + (not decoder_only): + raise NotImplementedError( + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported.") + + if self.enable_caching and (not decoder_only): + raise NotImplementedError( + "Automatic prefix caching currently not " + \ + "supported for encoder/decoder models.") # Allocate decoder sequences # From cc61959d2075816ee49fa7a802e3c2240e737546 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 13:56:11 -0400 Subject: [PATCH 048/443] Fixes --- vllm/core/block_manager_v2.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 31d1a60657832..9c6466de468e5 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -152,7 +152,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # encoder prompt. request_id = seq_group.request_id - encoder_seq = seq_group.encoder_seq assert (request_id not in self.cross_block_tables), \ @@ -160,12 +159,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() if encoder_seq is not None: - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - ) - assert self.block_sliding_window is None - block_table.allocate(encoder_seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, @@ -229,8 +223,6 @@ def free_cross(self, seq_group: SequenceGroup) -> None: self.cross_block_tables[request_id].free() del self.cross_block_tables[request_id] - del self.cross_block_tables[seq_group.request_id] - def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids From dcb9abe115cfd6bfa8f2131c645cbc0bb6acb2ab Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 13:58:30 -0400 Subject: [PATCH 049/443] formatting --- vllm/core/block_manager_v1.py | 2 +- vllm/core/block_manager_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index d5da128f1a691..95e9e5e20940d 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -281,7 +281,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") - + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 9c6466de468e5..b89f1cd05d1c1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -159,7 +159,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() if encoder_seq is not None: - block_table: BlockTable = self._allocate_sequence(encoder_seq) + block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, From 5cd154102ab645bc246dd734d797e0d5d8a1652f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 16:23:24 -0400 Subject: [PATCH 050/443] Added explanatory comment to XFormersImpl.forward() --- vllm/attention/backends/xformers.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 36f1343e995df..f9d0d924b395e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -395,14 +395,29 @@ def __init__( def forward( self, query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] From f3c430b3a226e249c611630ce776530d48971f0a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 16:37:32 -0400 Subject: [PATCH 051/443] Explanatory comment about sequence argument. --- vllm/sequence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6b07a00f09c6f..9670786f8f16c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -420,7 +420,8 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. - encoder_seq: Optional, the single encoder sequence. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( From f2564e0f1cc95fec5880847aa380a462ecd3d0bf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:02:37 -0400 Subject: [PATCH 052/443] clarifying comment --- vllm/sequence.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 9670786f8f16c..9c8fcccab75ae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -614,11 +614,15 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. - encoder_seq_data: Optional, the sequence data - for the single encoder prompt. - cross_block_table: Optional, the cross-attention - block table associated with - the single encoder prompt. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( From e8c40fcf152c5d2f6514830644c8eb683eee7aa9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:08:00 -0400 Subject: [PATCH 053/443] explanatory comment --- vllm/sequence.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6b07a00f09c6f..a456ecc111e4c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -613,11 +613,15 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. - encoder_seq_data: Optional, the sequence data - for the single encoder prompt. - cross_block_table: Optional, the cross-attention - block table associated with - the single encoder prompt. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( From 5ccb70be1209521d0aa1e3d7cae7bf7707ac2fd8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:18:03 -0400 Subject: [PATCH 054/443] various fixes according to reviews --- vllm/core/block_manager_v1.py | 2 +- vllm/core/block_manager_v2.py | 14 ++++++++++++++ vllm/sequence.py | 3 ++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 95e9e5e20940d..1c81edb7a2df3 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -352,7 +352,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Allocate encoder sequence encoder_seq = seq_group.get_encoder_seq() - if encoder_seq is not None: + if not decoder_only: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 block_table = self._allocate_sequence(encoder_seq, 1, decoder_only) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b89f1cd05d1c1..f094bf99e3201 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -132,6 +132,9 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: return block_table def allocate(self, seq_group: SequenceGroup) -> None: + decoder_only = \ + seq_group.get_encoder_seq() is None + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) @@ -157,6 +160,17 @@ def allocate(self, seq_group: SequenceGroup) -> None: not in self.cross_block_tables), \ "block table already exists" + if (self.block_sliding_window is not None) and \ + (not decoder_only): + raise NotImplementedError( + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported.") + + if self.enable_caching and (not decoder_only): + raise NotImplementedError( + "Automatic prefix caching currently not " + \ + "supported for encoder/decoder models.") + encoder_seq = seq_group.get_encoder_seq() if encoder_seq is not None: block_table = self._allocate_sequence(encoder_seq) diff --git a/vllm/sequence.py b/vllm/sequence.py index a456ecc111e4c..9c8fcccab75ae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -420,7 +420,8 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. - encoder_seq: Optional, the single encoder sequence. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( From dfcc28b19188a11c74aee06265051eb8fbbe599f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:22:41 -0400 Subject: [PATCH 055/443] slight refactoring --- vllm/core/block_manager_v1.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 1c81edb7a2df3..2daf45182bba9 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -322,8 +322,8 @@ def _allocate_sequence(self, \ return block_table def allocate(self, seq_group: SequenceGroup) -> None: - decoder_only = \ - seq_group.get_encoder_seq() is None + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None if (self.block_sliding_window is not None) and \ (not decoder_only): @@ -351,7 +351,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table.copy() # Allocate encoder sequence - encoder_seq = seq_group.get_encoder_seq() if not decoder_only: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 From 8d3ad05a9f7d568f16eea6e090f6803869fc5443 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:26:54 -0400 Subject: [PATCH 056/443] small refactor --- vllm/core/block_manager_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f094bf99e3201..6e02359f51782 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -132,8 +132,9 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: return block_table def allocate(self, seq_group: SequenceGroup) -> None: + encoder_seq = seq_group.get_encoder_seq() decoder_only = \ - seq_group.get_encoder_seq() is None + encoder_seq is None # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -171,8 +172,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: "Automatic prefix caching currently not " + \ "supported for encoder/decoder models.") - encoder_seq = seq_group.get_encoder_seq() - if encoder_seq is not None: + if not decoder_only: block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table From 5a7697976a964cf23d6141d9e432abb63d3f9e9d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:34:34 -0400 Subject: [PATCH 057/443] replaced all encoder_seq is not None with not decoder_only --- vllm/core/block_manager_v1.py | 19 +++++++++++++++---- vllm/core/block_manager_v2.py | 5 ++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 2daf45182bba9..2e5d531565379 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -496,6 +496,10 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + encoder_seq = seq_group.get_encoder_seq() + decoder_only = \ + encoder_seq is None + # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. request_id = seq_group.request_id @@ -505,7 +509,7 @@ def _get_physical_blocks( continue blocks.update(self.block_tables[seq.seq_id]) # Cross-attention blocks - if seq_group.encoder_seq is not None: + if not decoder_only: blocks.update(self.cross_block_tables[request_id]) return list(blocks) @@ -514,9 +518,12 @@ def can_swap_in(self, num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None + blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if seq_group.encoder_seq is not None: + if not decoder_only: num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate @@ -556,6 +563,8 @@ def swap_in(self, assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None request_id = seq_group.request_id # CPU block -> GPU block. @@ -566,7 +575,7 @@ def swap_in(self, self._swap_in_block_table(self.block_tables[seq.seq_id], mapping) - if seq_group.encoder_seq is not None: + if not decoder_only: self.cross_block_tables[request_id] = \ self._swap_in_block_table(self.cross_block_tables[request_id], mapping) @@ -600,6 +609,8 @@ def _swap_out_block_table( def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` @@ -609,7 +620,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self._swap_out_block_table(self.block_tables[seq.seq_id], mapping) - if seq_group.encoder_seq is not None: + if not decoder_only: self.cross_block_tables[request_id] = \ self._swap_out_block_table(self.cross_block_tables[request_id], mapping) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6e02359f51782..a8090c1f93b5a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -91,6 +91,9 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -98,7 +101,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - if seq_group.encoder_seq is not None: + if not decoder_only: num_required_blocks += BlockTable.get_num_required_blocks( seq_group.encoder_seq.get_token_ids(), block_size=self.block_size, From 09ae4adb656b79897d62d28015f968b0c7471d8e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:51:23 -0400 Subject: [PATCH 058/443] added is_encoder_decoder() method to sequence group --- vllm/core/block_manager_v1.py | 36 ++++++++++++++--------------------- vllm/core/block_manager_v2.py | 16 ++++++---------- vllm/sequence.py | 3 +++ 3 files changed, 23 insertions(+), 32 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 2e5d531565379..69a280c8bf9c6 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -277,7 +277,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: - if seq_group.get_encoder_seq() is not None: + if seq_group.is_encoder_decoder(): raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") @@ -298,7 +298,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def _allocate_sequence(self, \ seq: Sequence, \ ref_count: int, \ - decoder_only: bool = True) -> BlockTable: + is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -309,7 +309,7 @@ def _allocate_sequence(self, \ block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. block.ref_count = ref_count - elif decoder_only and self.enable_caching: + elif not is_encoder_decoder and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) @@ -323,15 +323,15 @@ def _allocate_sequence(self, \ def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None + is_encoder_decoder = seq_group.is_encoder_decoder() if (self.block_sliding_window is not None) and \ - (not decoder_only): + is_encoder_decoder: raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") - if self.enable_caching and (not decoder_only): + if self.enable_caching and is_encoder_decoder: raise NotImplementedError( "Automatic prefix caching currently not " + \ "supported for encoder/decoder models.") @@ -344,17 +344,18 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table: BlockTable = \ self._allocate_sequence(seq, seq_group.num_seqs(), - decoder_only) + is_encoder_decoder) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() # Allocate encoder sequence - if not decoder_only: + if is_encoder_decoder: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(encoder_seq, 1, decoder_only) + block_table = self._allocate_sequence(encoder_seq, 1, + is_encoder_decoder) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table @@ -496,9 +497,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: - encoder_seq = seq_group.get_encoder_seq() - decoder_only = \ - encoder_seq is None # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. @@ -509,7 +507,7 @@ def _get_physical_blocks( continue blocks.update(self.block_tables[seq.seq_id]) # Cross-attention blocks - if not decoder_only: + if seq_group.is_encoder_decoder(): blocks.update(self.cross_block_tables[request_id]) return list(blocks) @@ -518,12 +516,10 @@ def can_swap_in(self, num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if not decoder_only: + if seq_group.is_encoder_decoder(): num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate @@ -563,8 +559,6 @@ def swap_in(self, assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None request_id = seq_group.request_id # CPU block -> GPU block. @@ -575,7 +569,7 @@ def swap_in(self, self._swap_in_block_table(self.block_tables[seq.seq_id], mapping) - if not decoder_only: + if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ self._swap_in_block_table(self.cross_block_tables[request_id], mapping) @@ -609,8 +603,6 @@ def _swap_out_block_table( def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` @@ -620,7 +612,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self._swap_out_block_table(self.block_tables[seq.seq_id], mapping) - if not decoder_only: + if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ self._swap_out_block_table(self.cross_block_tables[request_id], mapping) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index a8090c1f93b5a..0dd2ffcd182ec 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -91,19 +91,16 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) - if not decoder_only: + if seq_group.is_encoder_decoder(): num_required_blocks += BlockTable.get_num_required_blocks( - seq_group.encoder_seq.get_token_ids(), + seq_group.get_encoder_seq().get_token_ids(), block_size=self.block_size, ) @@ -136,8 +133,7 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() - decoder_only = \ - encoder_seq is None + is_encoder_decoder = seq_group.is_encoder_decoder() # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -165,17 +161,17 @@ def allocate(self, seq_group: SequenceGroup) -> None: "block table already exists" if (self.block_sliding_window is not None) and \ - (not decoder_only): + is_encoder_decoder: raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") - if self.enable_caching and (not decoder_only): + if self.enable_caching and is_encoder_decoder: raise NotImplementedError( "Automatic prefix caching currently not " + \ "supported for encoder/decoder models.") - if not decoder_only: + if is_encoder_decoder: block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table diff --git a/vllm/sequence.py b/vllm/sequence.py index 9c8fcccab75ae..ad6c8d54974c3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -528,6 +528,9 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq From ecd1a998579ac171ce1936444fe9f7c8a6a09c92 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 18:59:03 -0400 Subject: [PATCH 059/443] tests for NotImplemented errors when encoder/decoder models are used with prefix cache or SWA --- tests/core/block/test_block_manager_v2.py | 103 +++++++++++++++++++++- tests/core/test_block_manager.py | 64 +++++++++++++- vllm/core/block_manager_v1.py | 29 +++--- vllm/core/block_manager_v2.py | 28 ++++-- 4 files changed, 205 insertions(+), 19 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 06c3389cfa0f0..cf423d292a25e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,6 +1,8 @@ import pytest -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.block_manager_v2 import (BlockSpaceManagerV2, + str_not_impl_enc_dec_prefix_cache, + str_not_impl_enc_dec_swa) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list @@ -103,6 +105,105 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_allocate_encoder_decoder_fails_with_swa(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + ''' + SWA short for Sliding Window Attention. + + At time of writing block manager v2 does not support SWA. + + However even when SWA is implemented for block manager v2, + there will still most likely be a separate workstream required + to enable SWA for encoder/decoder models. + + Therefore this test enforces that one of the following cases + hold true: + 1. Block manager v2 does not support SWA at all (true at time of writing) + 2. Block manager v2 fails with NotImplementError when SWA is enabled + AND a SequenceGroup with an encoder sequence (i.e. in support of an + encoder/decoder model) is passed into can_allocate() as an argument + + The setup for this test is stripped down version of + test_can_allocate_seq_group_encoder_decoder() + ''' + + with pytest.raises((NotImplementedError, AssertionError)) as exc_info: + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + sliding_window=5 # SWA + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + block_manager.can_allocate(seq_group) + + # Assert that either + # 1. Block manager v2 constructor fails with assertion that sliding window + # is not yet supported (most likely near-term outcome at time of + # writing), or + # 2. can_allocate() fails with NotImplementedError due to combiantion of + # encoder/decoder and sliding window attention + if isinstance(exc_info.value, NotImplementedError): + assert str(exc_info.value) == str_not_impl_enc_dec_swa + elif isinstance(exc_info.value, AssertionError): + assert str(exc_info.value) == "Sliding window not yet supported" + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_allocate_encoder_decoder_fails_with_prefix_cache( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, + watermark: float): + + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + enable_caching=True # Prefix cache + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + # Assert that either can_allocate() fails with NotImplementedError + # due to combination of encoder/decoder and prefix cache + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index cdaf2f22115e8..6039f568fcf1e 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -7,7 +7,9 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, - UncachedBlockAllocator) + UncachedBlockAllocator, + str_not_impl_enc_dec_prefix_cache, + str_not_impl_enc_dec_swa) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -126,6 +128,66 @@ def test_allocate_encoder_decoder(): assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder_fails_with_swa(): + # SWA short for sliding window attention + + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + sliding_window=5) # swa + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_swa + + # Assert that allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_swa + + +def test_allocate_encoder_decoder_fails_with_prefix_caching(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=True) # Prefix cache + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + + # Assert that allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 69a280c8bf9c6..904b12cd97b01 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -15,6 +15,17 @@ from vllm.utils import Device logger = init_logger(__name__) +''' +Exception strings for non-implemented encoder/decoder scenarios +''' + +str_not_impl_enc_dec_swa = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +str_not_impl_enc_dec_prefix_cache = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." class BlockAllocatorBase(ABC): @@ -269,6 +280,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + is_encoder_decoder = seq_group.is_encoder_decoder() + if self.enable_caching and is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) cross_num_required_blocks = self._get_seq_num_required_blocks( @@ -277,10 +292,8 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: - if seq_group.is_encoder_decoder(): - raise NotImplementedError( - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported.") + if is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_swa) num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -327,14 +340,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError( - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported.") + raise NotImplementedError(str_not_impl_enc_dec_swa) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError( - "Automatic prefix caching currently not " + \ - "supported for encoder/decoder models.") + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) # Allocate decoder sequences # diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 0dd2ffcd182ec..d2dadd9a63dc2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -8,6 +8,17 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +''' +Exception strings for non-implemented encoder/decoder scenarios +''' + +str_not_impl_enc_dec_swa = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +str_not_impl_enc_dec_prefix_cache = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." SeqId = int EncoderSeqId = str @@ -92,13 +103,20 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + is_encoder_decoder = seq_group.is_encoder_decoder() + if self.enable_caching and is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + + if self.block_sliding_window is not None and is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_swa) + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) - if seq_group.is_encoder_decoder(): + if is_encoder_decoder: num_required_blocks += BlockTable.get_num_required_blocks( seq_group.get_encoder_seq().get_token_ids(), block_size=self.block_size, @@ -162,14 +180,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError( - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported.") + raise NotImplementedError(str_not_impl_enc_dec_swa) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError( - "Automatic prefix caching currently not " + \ - "supported for encoder/decoder models.") + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) if is_encoder_decoder: block_table = self._allocate_sequence(encoder_seq) From d3935f73b5038ba7acc75fff07282b7f7fda6ed5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 19:05:36 -0400 Subject: [PATCH 060/443] rename tests --- tests/core/block/test_block_manager_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index cf423d292a25e..c893bc8f4209e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -109,10 +109,10 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, @pytest.mark.parametrize("num_gpu_blocks", [16]) @pytest.mark.parametrize("num_seqs_per_group", [1]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_allocate_encoder_decoder_fails_with_swa(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): +def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): ''' SWA short for Sliding Window Attention. @@ -172,7 +172,7 @@ def test_allocate_encoder_decoder_fails_with_swa(block_size: int, @pytest.mark.parametrize("num_gpu_blocks", [16]) @pytest.mark.parametrize("num_seqs_per_group", [1]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_allocate_encoder_decoder_fails_with_prefix_cache( +def test_can_allocate_encoder_decoder_fails_with_prefix_cache( block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): From e6a7125383488af42dd5020b65824394c9c112e9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 19:10:35 -0400 Subject: [PATCH 061/443] spelling error --- tests/core/block/test_block_manager_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index c893bc8f4209e..19ea89d01ca7a 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -160,7 +160,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, # 1. Block manager v2 constructor fails with assertion that sliding window # is not yet supported (most likely near-term outcome at time of # writing), or - # 2. can_allocate() fails with NotImplementedError due to combiantion of + # 2. can_allocate() fails with NotImplementedError due to combination of # encoder/decoder and sliding window attention if isinstance(exc_info.value, NotImplementedError): assert str(exc_info.value) == str_not_impl_enc_dec_swa From 68b476203ba9c8342e3f6ba5d9db5e7d369a7a52 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 19:14:25 -0400 Subject: [PATCH 062/443] isort --- vllm/core/block_manager_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index d2dadd9a63dc2..b43f39a8ffaef 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -8,6 +8,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device + ''' Exception strings for non-implemented encoder/decoder scenarios ''' From a80325dcbe4af189e3542f00ffe92a11a7243e92 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 21:45:13 -0400 Subject: [PATCH 063/443] return output of SequenceGroup constructor --- tests/core/utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 376af0f0eac4f..fb53b6cc5e18b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -145,14 +145,11 @@ def create_seq_group_encoder_decoder( block_size=16, ) - seq_group = SequenceGroup(request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq) - - return seq_group - + return SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size From 8b387767512a657fd0051c674f4a594159b67eee Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 21:56:25 -0400 Subject: [PATCH 064/443] capitalize constants --- tests/core/block/test_block_manager_v2.py | 8 ++++---- tests/core/test_block_manager.py | 12 ++++++------ vllm/core/block_manager_v1.py | 17 ++++++++--------- vllm/core/block_manager_v2.py | 12 ++++++------ 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 19ea89d01ca7a..3aed0c58bd264 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,8 +1,8 @@ import pytest from vllm.core.block_manager_v2 import (BlockSpaceManagerV2, - str_not_impl_enc_dec_prefix_cache, - str_not_impl_enc_dec_swa) + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list @@ -163,7 +163,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, # 2. can_allocate() fails with NotImplementedError due to combination of # encoder/decoder and sliding window attention if isinstance(exc_info.value, NotImplementedError): - assert str(exc_info.value) == str_not_impl_enc_dec_swa + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA elif isinstance(exc_info.value, AssertionError): assert str(exc_info.value) == "Sliding window not yet supported" @@ -201,7 +201,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( # due to combination of encoder/decoder and prefix cache with pytest.raises(NotImplementedError) as exc_info: block_manager.can_allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE @pytest.mark.parametrize("block_size", [1, 8]) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 6039f568fcf1e..7e487a021d3c2 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -8,8 +8,8 @@ from vllm.block import PhysicalTokenBlock from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, UncachedBlockAllocator, - str_not_impl_enc_dec_prefix_cache, - str_not_impl_enc_dec_swa) + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -150,13 +150,13 @@ def test_allocate_encoder_decoder_fails_with_swa(): with pytest.raises(NotImplementedError) as exc_info: block_manager.can_allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_swa + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA # Assert that allocate() fails due to SWA with pytest.raises(NotImplementedError) as exc_info: block_manager.allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_swa + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA def test_allocate_encoder_decoder_fails_with_prefix_caching(): @@ -179,13 +179,13 @@ def test_allocate_encoder_decoder_fails_with_prefix_caching(): with pytest.raises(NotImplementedError) as exc_info: block_manager.can_allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE # Assert that allocate() fails due to prefix caching with pytest.raises(NotImplementedError) as exc_info: block_manager.allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE def test_append_slot_single_seq(): diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 904b12cd97b01..312690ee45893 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -19,11 +19,11 @@ Exception strings for non-implemented encoder/decoder scenarios ''' -str_not_impl_enc_dec_swa = \ +STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ "is not currently supported." -str_not_impl_enc_dec_prefix_cache = \ +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ "is not currently supported." @@ -272,9 +272,8 @@ def __init__( self.cross_block_tables: Dict[str, BlockTable] = {} def _get_seq_num_required_blocks(self, seq: Sequence) -> int: - if seq is None: - return 0 - return len(seq.logical_token_blocks) + return 0 if seq is None \ + else len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -282,7 +281,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: is_encoder_decoder = seq_group.is_encoder_decoder() if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) @@ -293,7 +292,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: if self.block_sliding_window is not None: if is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -340,10 +339,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) # Allocate decoder sequences # diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b43f39a8ffaef..6113561032dd1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -13,11 +13,11 @@ Exception strings for non-implemented encoder/decoder scenarios ''' -str_not_impl_enc_dec_swa = \ +STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ "is not currently supported." -str_not_impl_enc_dec_prefix_cache = \ +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ "is not currently supported." @@ -106,10 +106,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: is_encoder_decoder = seq_group.is_encoder_decoder() if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) if self.block_sliding_window is not None and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -181,10 +181,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) if is_encoder_decoder: block_table = self._allocate_sequence(encoder_seq) From f39c3132af87d410507644c9ea86aec1156f3533 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:20:06 -0400 Subject: [PATCH 065/443] refactored swap-block-table functionality --- vllm/core/block_manager_v1.py | 68 +++++++++++++++-------------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 312690ee45893..90a485b39e9d6 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -541,23 +541,25 @@ def can_swap_in(self, else: return AllocStatus.LATER - def _swap_in_block_table( + def _swap_block_table( self, block_table: BlockTable, + src_allocator: BlockAllocatorBase, + dest_allocator: BlockAllocatorBase, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock]) -> BlockTable: new_block_table = [] - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 + for from_block in block_table: + if from_block in mapping: + to_block = mapping[from_block] + to_block.ref_count += 1 else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) + to_block = dest_allocator.allocate( + from_block.block_hash, from_block.num_hashed_tokens) + mapping[from_block] = to_block + new_block_table.append(to_block) + # Free the source block swapped in to destination. + src_allocator.free(from_block) return new_block_table @@ -574,13 +576,17 @@ def swap_in(self, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): self.block_tables[seq.seq_id] = \ - self._swap_in_block_table(self.block_tables[seq.seq_id], - mapping) + self._swap_block_table(self.block_tables[seq.seq_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ - self._swap_in_block_table(self.cross_block_tables[request_id], - mapping) + self._swap_block_table(self.cross_block_tables[request_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] @@ -589,26 +595,6 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def _swap_out_block_table( - self, block_table: BlockTable, - mapping: Dict[PhysicalTokenBlock, - PhysicalTokenBlock]) -> BlockTable: - - new_block_table: BlockTable = [] - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - - return new_block_table - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id @@ -617,13 +603,17 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): self.block_tables[seq.seq_id] = \ - self._swap_out_block_table(self.block_tables[seq.seq_id], - mapping) + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ - self._swap_out_block_table(self.cross_block_tables[request_id], - mapping) + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] From 90b5a0e5303c937e56c5b8893fc0cbaeb985ac3f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:51:09 -0400 Subject: [PATCH 066/443] Refactored block manager + enc dec + unsupported feature checks into utils --- tests/core/block/test_block_manager_v2.py | 6 ++-- tests/core/test_block_manager.py | 6 ++-- tests/core/utils.py | 1 + vllm/core/block/utils.py | 41 +++++++++++++++++++++++ vllm/core/block_manager_v1.py | 34 ++++--------------- vllm/core/block_manager_v2.py | 35 ++++--------------- 6 files changed, 60 insertions(+), 63 deletions(-) create mode 100644 vllm/core/block/utils.py diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 3aed0c58bd264..f1488916b508a 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,8 +1,8 @@ import pytest -from vllm.core.block_manager_v2 import (BlockSpaceManagerV2, - STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 7e487a021d3c2..2264fe80c9c03 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -6,10 +6,10 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, - UncachedBlockAllocator, - STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) + UncachedBlockAllocator) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device diff --git a/tests/core/utils.py b/tests/core/utils.py index fb53b6cc5e18b..7ac565c0eccf1 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -151,5 +151,6 @@ def create_seq_group_encoder_decoder( arrival_time=time.time(), encoder_seq=encoder_seq) + def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000000000..6599011771cea --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,41 @@ +"""Block manager utils.""" +from typing import Union + +from vllm.core.block_manager_v1 import BlockSpaceManagerV1 +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.sequence import SequenceGroup + +''' +Exception strings for non-implemented block manager encoder/decoder scenarios +''' + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +def check_no_caching_or_swa_for_blckmgr_encdec( + block_mgr: Union[BlockSpaceManagerV1, + BlockSpaceManagerV2], + seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. + + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if block_mgr.block_sliding_window is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) \ No newline at end of file diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 90a485b39e9d6..fa64b96a5e7dc 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -15,17 +16,6 @@ from vllm.utils import Device logger = init_logger(__name__) -''' -Exception strings for non-implemented encoder/decoder scenarios -''' - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." class BlockAllocatorBase(ABC): @@ -279,9 +269,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - is_encoder_decoder = seq_group.is_encoder_decoder() - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) @@ -291,8 +279,6 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: - if is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -334,15 +320,8 @@ def _allocate_sequence(self, \ return block_table def allocate(self, seq_group: SequenceGroup) -> None: - encoder_seq = seq_group.get_encoder_seq() is_encoder_decoder = seq_group.is_encoder_decoder() - - if (self.block_sliding_window is not None) and \ - is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) # Allocate decoder sequences # @@ -362,8 +341,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: if is_encoder_decoder: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(encoder_seq, 1, - is_encoder_decoder) + block_table = self._allocate_sequence(seq_group.get_encoder_seq(), + 1, is_encoder_decoder) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table @@ -542,8 +521,7 @@ def can_swap_in(self, return AllocStatus.LATER def _swap_block_table( - self, block_table: BlockTable, - src_allocator: BlockAllocatorBase, + self, block_table: BlockTable, src_allocator: BlockAllocatorBase, dest_allocator: BlockAllocatorBase, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock]) -> BlockTable: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6113561032dd1..246ab9c297c5b 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,22 +5,11 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -''' -Exception strings for non-implemented encoder/decoder scenarios -''' - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." - SeqId = int EncoderSeqId = str @@ -104,12 +93,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - is_encoder_decoder = seq_group.is_encoder_decoder() - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) - - if self.block_sliding_window is not None and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -117,7 +101,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - if is_encoder_decoder: + if seq_group.is_encoder_decoder(): num_required_blocks += BlockTable.get_num_required_blocks( seq_group.get_encoder_seq().get_token_ids(), block_size=self.block_size, @@ -151,8 +135,6 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: return block_table def allocate(self, seq_group: SequenceGroup) -> None: - encoder_seq = seq_group.get_encoder_seq() - is_encoder_decoder = seq_group.is_encoder_decoder() # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -179,15 +161,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: not in self.cross_block_tables), \ "block table already exists" - if (self.block_sliding_window is not None) and \ - is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) - if is_encoder_decoder: - block_table = self._allocate_sequence(encoder_seq) + if seq_group.is_encoder_decoder(): + block_table = self._allocate_sequence(seq_group.get_encoder_seq()) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, From 9ee2582172b2b273ede9cb0e3ced9d9f197ecc0b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:57:02 -0400 Subject: [PATCH 067/443] removed circular import --- vllm/core/block/utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 6599011771cea..14b99496b12dc 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,10 +1,5 @@ """Block manager utils.""" -from typing import Union - -from vllm.core.block_manager_v1 import BlockSpaceManagerV1 -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.sequence import SequenceGroup - ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' @@ -17,10 +12,9 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." + def check_no_caching_or_swa_for_blckmgr_encdec( - block_mgr: Union[BlockSpaceManagerV1, - BlockSpaceManagerV2], - seq_group: SequenceGroup) -> None: + block_mgr, seq_group: SequenceGroup) -> None: ''' Enforce that prefix caching & sliding-window attention (SWA) are currently unsupported *specifically* for encoder/decoder models. @@ -38,4 +32,4 @@ def check_no_caching_or_swa_for_blckmgr_encdec( raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) \ No newline at end of file + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) From 5d0ac231b751466771f25e9275acede785bf4344 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:58:09 -0400 Subject: [PATCH 068/443] apparently isort has to run last? --- vllm/core/block/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 14b99496b12dc..4113f7e52b84f 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,5 +1,6 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup + ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' From 1bcc949c7c4634da50d80d7bc4b47185e6ac6f18 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 26 May 2024 12:20:12 -0400 Subject: [PATCH 069/443] slight name change --- vllm/core/block/utils.py | 2 +- vllm/core/block_manager_v1.py | 6 +++--- vllm/core/block_manager_v2.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 4113f7e52b84f..3dee7ff16dd84 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -14,7 +14,7 @@ "is not currently supported." -def check_no_caching_or_swa_for_blckmgr_encdec( +def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: ''' Enforce that prefix caching & sliding-window attention (SWA) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index fa64b96a5e7dc..201cba309f6ef 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,7 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -269,7 +269,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) @@ -321,7 +321,7 @@ def _allocate_sequence(self, \ def allocate(self, seq_group: SequenceGroup) -> None: is_encoder_decoder = seq_group.is_encoder_decoder() - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) # Allocate decoder sequences # diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 246ab9c297c5b..6185a65983d3a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,7 +5,7 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -93,7 +93,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -161,7 +161,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: not in self.cross_block_tables), \ "block table already exists" - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) if seq_group.is_encoder_decoder(): block_table = self._allocate_sequence(seq_group.get_encoder_seq()) From 1bece71b45331ed5e371a3842e5a1bba5fe7a160 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:27:47 -0400 Subject: [PATCH 070/443] wip merge --- vllm/core/block_manager_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b19f4b184db94..cad42ab3c1ba2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -138,7 +138,6 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_allocator=self.block_allocator, max_block_sliding_window=self.max_block_sliding_window, ) - assert self.block_sliding_window is None block_table.allocate(seq.get_token_ids()) return block_table From 1d882ca8d5825ab68988740e81796abadd083b06 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:38:45 -0400 Subject: [PATCH 071/443] fixed utils to correctly handle encoder/decoder unsupported scenarios --- vllm/core/block/utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 3dee7ff16dd84..dd9345ab52d40 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -13,6 +13,26 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." +def _get_block_mgr_sliding_window_attr(block_mgr): + ''' + BlockManagerV1 and BlockManagerV2 have slightly different + members related to sliding window attention (SWA). This + function extracts the appropriate member to use for determining + whether SWA is enabled. + + Arguments: + + * block_mgr: BlockManagerV1 or BlockManagerV2 instance + ''' + + if hasattr(block_mgr, 'block_sliding_window'): + return block_mgr.block_sliding_window + if hasattr(block_mgr, 'max_block_sliding_window'): + return block_mgr.max_block_sliding_window + + raise AttributeError("Block manager instance has neither " + \ + "block_sliding_window nor " + \ + "max_block_sliding_window attributes.") def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: @@ -29,7 +49,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec( ''' if seq_group.is_encoder_decoder(): - if block_mgr.block_sliding_window is not None: + if _get_block_mgr_sliding_window_attr(block_mgr) is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: From dfd94692e0b35343e64aace3cd4a496564be5809 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:39:17 -0400 Subject: [PATCH 072/443] formatting --- vllm/core/block/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index dd9345ab52d40..c582ab270473c 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -13,6 +13,7 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." + def _get_block_mgr_sliding_window_attr(block_mgr): ''' BlockManagerV1 and BlockManagerV2 have slightly different @@ -34,6 +35,7 @@ def _get_block_mgr_sliding_window_attr(block_mgr): "block_sliding_window nor " + \ "max_block_sliding_window attributes.") + def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: ''' From 3c3687e9f59269268264e9f058ef82220fbac4ea Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:52:02 -0400 Subject: [PATCH 073/443] renamed xformers metadata is_cross_attn to is_encoder_decoder_attn --- vllm/attention/backends/xformers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5059fd8dc265b..b45eda9a68dd5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -119,7 +119,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; # otherwise, self-attention data structures will be returned. - is_cross_attn: bool = False + is_encoder_decoder_attn: bool = False # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention @@ -166,7 +166,7 @@ def has_valid_cross_attn_metadata(self): @property def do_cross_attn(self): - return self.is_cross_attn + return self.is_encoder_decoder_attn @do_cross_attn.setter def do_cross_attn(self, state: bool): @@ -188,9 +188,9 @@ def do_cross_attn(self, state: bool): assert self.cross_seq_lens is not None self.max_cross_seq_len = max(self.cross_seq_lens) - self.is_cross_attn = True + self.is_encoder_decoder_attn = True else: - self.is_cross_attn = False + self.is_encoder_decoder_attn = False @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -225,7 +225,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -261,7 +261,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -297,7 +297,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -328,7 +328,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -593,7 +593,7 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.is_cross_attn: + if attn_metadata.is_encoder_decoder_attn: attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: From 6f07c77ef3f1367369a2b5d96b5d0ed576b0a5ff Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 13:04:36 -0400 Subject: [PATCH 074/443] wip getting tests to pass after merge --- tests/kernels/test_self_and_cross_attn.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 04dc52d51af19..3cc60e5412d11 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -14,7 +14,8 @@ # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # # TODO: FlashAttention forward only supports head dimension at most 128 -# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d0 +# 37782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -113,7 +114,7 @@ def make_qkv(batch_size, max_kv_seq_len, num_heads, head_size, - is_cross_attn=True, + is_encoder_decoder_attn=True, force_max_len=False, device=CUDA_DEVICE): ''' @@ -137,12 +138,12 @@ def make_qkv(batch_size, * max_kv_seq_len: max key/value seq len * num_heads * head_size - * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as + * is_encoder_decoder_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_seq_len is unused) * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens - and max_kv_seq_len, unless forced by is_cross_attn=False + and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False * device: CPU or CUDA device Returns: @@ -178,7 +179,7 @@ def make_qkv(batch_size, random.randint(2, max_q_seq_len) for _ in range(batch_size) ] kv_seq_lens = None - if not is_cross_attn: + if not is_encoder_decoder_attn: # K,V seq lens match Q for self-attention kv_seq_lens = q_seq_lens else: @@ -644,7 +645,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_cross_attn=False, + is_encoder_decoder_attn=False, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -685,7 +686,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_cross_attn=False, + is_encoder_decoder_attn=False, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -840,7 +841,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - is_cross_attn=False) + is_encoder_decoder_attn=False) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -1004,7 +1005,7 @@ def cross_attn_setup_reuses_query(query, max_kv_seq_len, num_heads, head_size, - is_cross_attn=True) + is_encoder_decoder_attn=True) ideal_output = ref_masked_attention(query, key, From 481c6463e8b7f7f744fd799bb945301a52182118 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 14:54:50 -0400 Subject: [PATCH 075/443] passing tests; formatting --- tests/kernels/test_self_and_cross_attn.py | 7 ++++--- vllm/attention/backends/xformers.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 3cc60e5412d11..d99a246712425 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -138,9 +138,10 @@ def make_qkv(batch_size, * max_kv_seq_len: max key/value seq len * num_heads * head_size - * is_encoder_decoder_attn: if True, query seqlen may differ from key/value seqlen (as - is often the case for cross-attention); o/w, query/key/value seqlens match - at each batch index (max_kv_seq_len is unused) + * is_encoder_decoder_attn: if True, query seqlen may differ from + key/value seqlen (as is often the case for cross-attention); + o/w, query/key/value seqlens match at each batch index + (max_kv_seq_len is unused) * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b45eda9a68dd5..6886b4836bd87 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -225,7 +225,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn= + False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -261,7 +262,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn= + True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -297,7 +299,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn= + False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -328,7 +331,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn= + True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, From 9c8e19d3bc8a56b1ad31c58d786a9e4c25c593b2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 15:21:29 -0400 Subject: [PATCH 076/443] removed overprovisioning from make_block_tables_slot_mapping() --- tests/kernels/test_self_and_cross_attn.py | 25 +++++++---------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index d99a246712425..83576c52b8688 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -510,7 +510,7 @@ def make_block_tables_slot_mapping(block_size, # Over-provision block table blocks by 1 num_blocks_list = [ - num_tokens_to_min_blocks(num_tokens, block_size) + 1 + num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) @@ -521,7 +521,7 @@ def make_block_tables_slot_mapping(block_size, decode_slot_mapping = [] slot_mapping = [] block_base_idx = block_base_addr + sum( - num_blocks_list) * 2 - 1 # Support more blocks than needed + num_blocks_list) # Support more blocks than needed max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] @@ -692,21 +692,6 @@ def make_metadata_self_cross( cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) - -def make_attention(num_heads: int, head_size: int, scale: float): - ''' - Construct an instance of the Attention wrapper, suited to the number of - attention heads and head dimension (num_heads and head_size respectively) as - well as the attention scale factor (scale) - ''' - - return Attention( - num_heads, - head_size, - scale=scale, - ) - - def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): ''' Compute & build entities required for the self-/cross-attention test. @@ -730,7 +715,11 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): scale = float(1.0 / (head_size**0.5)) attn_backend = make_backend(backend_name) - attn = make_attention(num_heads, head_size, scale) + attn = Attention( + num_heads, + head_size, + scale=scale, + ) kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache From ed17ee38478c6b67cdbb63d6cb7f929a9bd2a08b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 15:25:06 -0400 Subject: [PATCH 077/443] comments' --- tests/kernels/test_self_and_cross_attn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 83576c52b8688..f132bf571defa 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -481,7 +481,7 @@ def make_block_tables_slot_mapping(block_size, The first block is at - block_base_addr + sum(num_blocks_list) * 2 - 1 + block_base_addr + sum(min. block count for each seq_len) and subsequent blocks count downward toward block_base_addr @@ -508,7 +508,7 @@ def make_block_tables_slot_mapping(block_size, * max_block_idx: the highest block address within this block table ''' - # Over-provision block table blocks by 1 + # Provision minimum number of KV cache blocks num_blocks_list = [ num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens @@ -692,6 +692,7 @@ def make_metadata_self_cross( cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) + def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): ''' Compute & build entities required for the self-/cross-attention test. @@ -716,10 +717,10 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): scale = float(1.0 / (head_size**0.5)) attn_backend = make_backend(backend_name) attn = Attention( - num_heads, - head_size, - scale=scale, - ) + num_heads, + head_size, + scale=scale, + ) kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache From d630aa8090463bdc12554e63d79dda1ed7caa253 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 15:34:08 -0400 Subject: [PATCH 078/443] clarified block table address formula --- tests/kernels/test_self_and_cross_attn.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index f132bf571defa..a4379b1ece49f 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -479,11 +479,23 @@ def make_block_tables_slot_mapping(block_size, ''' Construct fake block tables & slot mappings. - The first block is at + For a sequence with num_tokens tokens the minimum number + of required KV cache blocks is - block_base_addr + sum(min. block count for each seq_len) + num_blocks = (num_tokens + block_size) // block_size - and subsequent blocks count downward toward block_base_addr + Then the minimum KV cache size in blocks is + + total_cache_blocks = sum(num_blocks for all seqs) + + Then, the blocktable mapping counts downward from + + block_base_addr + total_cache_blocks + + to + + block_base_addr + Arguments: @@ -520,8 +532,9 @@ def make_block_tables_slot_mapping(block_size, prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = block_base_addr + sum( - num_blocks_list) # Support more blocks than needed + # Compute uppermost address of block table + total_cache_blocks = sum(num_blocks_list) + block_base_idx = block_base_addr + total_cache_blocks max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] From b664806905cee4697470907557f323bc25fd9ddb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 22:22:00 -0400 Subject: [PATCH 079/443] wip changing cross attention flag --- tests/kernels/test_self_and_cross_attn.py | 11 ++-- vllm/attention/backends/abstract.py | 7 +++ vllm/attention/backends/xformers.py | 77 +++++++++++++---------- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index a4379b1ece49f..e2dcf5b02f165 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -7,7 +7,8 @@ import torch from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionType) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad @@ -114,7 +115,7 @@ def make_qkv(batch_size, max_kv_seq_len, num_heads, head_size, - is_encoder_decoder_attn=True, + attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len=False, device=CUDA_DEVICE): ''' @@ -180,7 +181,7 @@ def make_qkv(batch_size, random.randint(2, max_q_seq_len) for _ in range(batch_size) ] kv_seq_lens = None - if not is_encoder_decoder_attn: + if attn_type != AttentionType.ENCODER_DECODER: # K,V seq lens match Q for self-attention kv_seq_lens = q_seq_lens else: @@ -845,7 +846,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - is_encoder_decoder_attn=False) + attn_type=False) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -1009,7 +1010,7 @@ def cross_attn_setup_reuses_query(query, max_kv_seq_len, num_heads, head_size, - is_encoder_decoder_attn=True) + attn_type=True) ideal_output = ref_masked_attention(query, key, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 6396103bf5efa..15e9a7fa5af3a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -5,6 +5,13 @@ import torch +from enum import Enum, auto + +class AttentionType(Enum): + DECODER = auto() # Decoder attention between previously layer Q/K/V + ENCODER = auto() # Encoder attention between previously layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + class AttentionBackend(ABC): """Abstract class for attention backends.""" diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6886b4836bd87..144cb68bbff0b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -10,11 +10,12 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger + logger = init_logger(__name__) @@ -119,7 +120,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; # otherwise, self-attention data structures will be returned. - is_encoder_decoder_attn: bool = False + _attn_type: AttentionType = AttentionType.DECODER # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention @@ -165,13 +166,13 @@ def has_valid_cross_attn_metadata(self): return True @property - def do_cross_attn(self): - return self.is_encoder_decoder_attn + def attention_type(self) -> AttentionType: + return self._attn_type - @do_cross_attn.setter - def do_cross_attn(self, state: bool): + @attention_type.setter + def attention_type(self, atype: AttentionType) -> None: - if state: + if atype == AttentionType.ENCODER_DECODER: assert self.has_valid_cross_attn_metadata, \ "Must have self.cross_seq_lens not None " + \ "in order to enable cross-attention" @@ -188,17 +189,18 @@ def do_cross_attn(self, state: bool): assert self.cross_seq_lens is not None self.max_cross_seq_len = max(self.cross_seq_lens) - self.is_encoder_decoder_attn = True + self._attn_type = AttentionType.ENCODER_DECODER else: - self.is_encoder_decoder_attn = False + # AttentionType.{ENCODER,DECODER} + self._attn_type = atype @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if not self.do_cross_attn: - # Self-attention prefill + if self._attn_type != AttentionType.ENCODER_DECODER: + # Decoder or encoder self-attention prefill if self._self_cached_prefill_metadata is not None: return self._self_cached_prefill_metadata @@ -225,8 +227,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn= - False, # Begin cross-attention fields below... + _attn_type= + self._attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -235,7 +237,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return self._self_cached_prefill_metadata else: - # Cross-attention prefill + # Encoder/decoder cross-attention prefill if self._cross_cached_prefill_metadata is not None: return self._cross_cached_prefill_metadata @@ -262,8 +264,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn= - True, # Begin cross-attention fields below... + _attn_type= + AttentionType.ENCODER_DECODER, + # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -276,8 +279,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if not self.do_cross_attn: - # Self-attention decode + if self._attn_type != AttentionType.ENCODER_DECODER: + # Decoder or encoder self-attention prefill if self._self_cached_decode_metadata is not None: return self._self_cached_decode_metadata @@ -299,8 +302,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn= - False, # Begin cross-attention fields below... + _attn_type= + self._attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -309,7 +312,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return self._self_cached_decode_metadata else: - # Cross-attention decode + # Encoder/decoder cross-attention decode if self._cross_cached_decode_metadata is not None: return self._cross_cached_decode_metadata @@ -331,8 +334,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn= - True, # Begin cross-attention fields below... + _attn_type= + AttentionType.ENCODER_DECODER, + # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -443,7 +447,7 @@ def forward( # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - do_cross_attn = attn_metadata.do_cross_attn + attn_type = attn_metadata._attn_type if (kv_cache is not None): # Even if there are no new key/value pairs to cache, @@ -454,7 +458,7 @@ def forward( if (key is not None) and (value is not None): - if do_cross_attn: + if attn_type == AttentionType.ENCODER_DECODER: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, # preventing this IF-statement branch from running @@ -476,9 +480,9 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert do_cross_attn or (key.shape[0] + assert attn_type == AttentionType.ENCODER_DECODER or (key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert do_cross_attn or (value.shape[0] + assert attn_type == AttentionType.ENCODER_DECODER or (value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) @@ -487,7 +491,9 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if not do_cross_attn and key is not None and value is not None: + if attn_type != AttentionType.ENCODER_DECODER \ + and key is not None and value is not None: + key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -529,7 +535,7 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: - if do_cross_attn: + if attn_type == AttentionType.ENCODER_DECODER: # Paged attention against cross-attention KV cache seq_lens_arg = decode_meta.cross_seq_lens_tensor max_seq_len_arg = decode_meta.max_cross_seq_len @@ -597,12 +603,19 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.is_encoder_decoder_attn: + if attn_metadata.attention_type() == AttentionType.ENCODER_DECODER: + # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + if attn_metadata.attention_type() == AttentionType.ENCODER: + # Default encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens) + else: + # Default decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) From 611df433882c1e10235084426d63fd817466dd19 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 22:27:41 -0400 Subject: [PATCH 080/443] yapf fix --- vllm/core/block/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index c582ab270473c..4da5a965616ac 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,6 +1,5 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup - ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' From 8ee49dde309a93fd309f0117f74cde4949e958e4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 22:30:12 -0400 Subject: [PATCH 081/443] yapf fix --- vllm/core/block/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 4da5a965616ac..2c412a8f472e0 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,8 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup -''' -Exception strings for non-implemented block manager encoder/decoder scenarios -''' + +# Exception strings for non-implemented block manager enc/dec scenarios STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ From 039c25eb6661f2aa89b4239235451f2c6f61d63d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 23:03:44 -0400 Subject: [PATCH 082/443] upstream merge --- tests/core/utils.py | 36 +++++++++++++++++++++++++++--------- vllm/core/block/utils.py | 1 + 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 1ccc5c3cc0a8e..cd2045b8a1889 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -55,12 +55,24 @@ def create_dummy_prompt_encoder_decoder( # and prompt "0 ... block_size". decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) - decoder_prompt = Sequence(int(request_id), decoder_prompt_str, - decoder_prompt_tokens, block_size) + + decoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - encoder_prompt = Sequence(int(request_id), encoder_prompt_str, - encoder_prompt_tokens, block_size) + encoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams( @@ -134,8 +146,11 @@ def create_seq_group_encoder_decoder( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) @@ -149,8 +164,11 @@ def create_seq_group_encoder_decoder( # Encoder sequence encoder_seq = Sequence( seq_id=seq_id_start + len(seq_output_lens), - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) @@ -162,4 +180,4 @@ def create_seq_group_encoder_decoder( def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size + return (seq_len + block_size - 1) // block_size \ No newline at end of file diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 4da5a965616ac..c582ab270473c 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,5 +1,6 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup + ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' From 8e9ef5bb5ae7bc3ece7ae527e591df093ff7f31e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 23:06:08 -0400 Subject: [PATCH 083/443] fix formatting issue --- vllm/core/block/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index c582ab270473c..372bfb5ed2f9e 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,9 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup -''' -Exception strings for non-implemented block manager encoder/decoder scenarios -''' +# Exception strings for non-implemented block manager encoder/decoder scenarios STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ From 19d1ca5a6471a55603f257b7c6f6f1364b9d9b0e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 11:32:16 -0400 Subject: [PATCH 084/443] passing tests with new attention type enum --- tests/kernels/test_self_and_cross_attn.py | 64 ++++++++++++++++------- vllm/attention/backends/xformers.py | 6 +-- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e2dcf5b02f165..b36212abf01d1 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -593,6 +593,7 @@ def make_metadata_self_cross( context_lens: List[int], block_tables, slot_mapping, + is_encoder_only_test: bool, device=CUDA_DEVICE, cross_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, @@ -625,6 +626,9 @@ def make_metadata_self_cross( * AttentionMetadata structure supporting self- and cross-attention ''' + default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ + else AttentionType.DECODER + if is_prompt: num_prefills = len(seq_lens) num_prefill_tokens = sum(seq_lens) @@ -660,7 +664,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_encoder_decoder_attn=False, + _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -701,7 +705,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_encoder_decoder_attn=False, + _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -745,6 +749,7 @@ def self_attn_setup(batch_size, block_size, scale, max_q_seq_len, + attn_type: AttentionType, block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -846,7 +851,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - attn_type=False) + attn_type=attn_type) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -1010,7 +1015,7 @@ def cross_attn_setup_reuses_query(query, max_kv_seq_len, num_heads, head_size, - attn_type=True) + attn_type=AttentionType.ENCODER_DECODER) ideal_output = ref_masked_attention(query, key, @@ -1062,8 +1067,9 @@ def cross_attn_setup_reuses_query(query, def run_self_attention_test(attn: Attention, packed_query, packed_key, packed_value, kv_cache, - attn_metadata: AttentionMetadata): - attn_metadata.do_cross_attn = False + attn_metadata: AttentionMetadata, + attn_type: AttentionType): + attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) @@ -1071,10 +1077,27 @@ def run_self_attention_test(attn: Attention, packed_query, packed_key, def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_value, kv_cache, attn_metadata: AttentionMetadata): - attn_metadata.do_cross_attn = True + attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) +@pytest.mark.skip() +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) +@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) +def test_encoder_attention(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_q_seq_len: int, + max_kv_seq_len: int) -> None: + + pass @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1083,15 +1106,15 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_prefill_decode_self_and_cross_attention(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_q_seq_len: int, + max_kv_seq_len: int) -> None: ''' - Test: + Encoder/decoder attention test: * Construct fake test vectors for self- and cross-attention * Construct attention metadata structure with self- and cross-attention @@ -1159,6 +1182,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, block_size, scale, max_q_seq_len, + attn_type=AttentionType.DECODER, block_base_addr=self_block_base_addr) # Cross-attention setup @@ -1195,6 +1219,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, context_lens, self_prefill_block_tables, self_prefill_slot_mapping, + is_encoder_only_test=False, cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, @@ -1202,7 +1227,8 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, - self_prefill_packed_value, kv_cache, prefill_attn_metadata) + self_prefill_packed_value, kv_cache, prefill_attn_metadata, + attn_type=AttentionType.DECODER) # - Prefill self-attention correct? assert torch.allclose( @@ -1231,6 +1257,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, context_lens, self_decode_block_tables, self_decode_slot_mapping, + is_encoder_only_test=False, cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, @@ -1238,7 +1265,8 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( attn, decode_packed_query, self_decode_packed_key, - self_decode_packed_value, kv_cache, decode_attn_metadata) + self_decode_packed_value, kv_cache, decode_attn_metadata, + attn_type=AttentionType.DECODER) # - Decode self-attention correct? assert torch.allclose( @@ -1253,4 +1281,4 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, assert torch.allclose( cross_decode_packed_ideal_output, cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) + cross_decode_packed_ideal_output)) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 144cb68bbff0b..152f339e27485 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -447,7 +447,7 @@ def forward( # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - attn_type = attn_metadata._attn_type + attn_type = attn_metadata.attention_type if (kv_cache is not None): # Even if there are no new key/value pairs to cache, @@ -603,12 +603,12 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.attention_type() == AttentionType.ENCODER_DECODER: + if attn_metadata.attention_type == AttentionType.ENCODER_DECODER: # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: - if attn_metadata.attention_type() == AttentionType.ENCODER: + if attn_metadata.attention_type == AttentionType.ENCODER: # Default encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens) From 700b6dca120d859a5b2a8d89f4b88a2e51187a86 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 11:33:39 -0400 Subject: [PATCH 085/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 43 ++++++++++++----------- vllm/attention/backends/abstract.py | 8 ++--- vllm/attention/backends/xformers.py | 26 +++++++------- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index b36212abf01d1..c9ae74e788754 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -7,8 +7,7 @@ import torch from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionType) +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad @@ -1081,6 +1080,7 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) + @pytest.mark.skip() @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1089,16 +1089,13 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_encoder_attention(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_q_seq_len: int, max_kv_seq_len: int) -> None: pass + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1106,13 +1103,9 @@ def test_encoder_attention(num_heads: int, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_enc_dec_self_and_cross_attention_prefill_decode_phases( + num_heads: int, head_size: int, backend_name: str, batch_size: int, + block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: ''' Encoder/decoder attention test: @@ -1226,8 +1219,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, ) self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( - attn, prefill_packed_query, self_prefill_packed_key, - self_prefill_packed_value, kv_cache, prefill_attn_metadata, + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, attn_type=AttentionType.DECODER) # - Prefill self-attention correct? @@ -1264,8 +1261,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, ) self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( - attn, decode_packed_query, self_decode_packed_key, - self_decode_packed_value, kv_cache, decode_attn_metadata, + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, attn_type=AttentionType.DECODER) # - Decode self-attention correct? @@ -1281,4 +1282,4 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, assert torch.allclose( cross_decode_packed_ideal_output, cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) \ No newline at end of file + cross_decode_packed_ideal_output)) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 15e9a7fa5af3a..cffd2d577777c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,16 +1,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields +from enum import Enum, auto from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) import torch -from enum import Enum, auto class AttentionType(Enum): - DECODER = auto() # Decoder attention between previously layer Q/K/V - ENCODER = auto() # Encoder attention between previously layer Q/K/V - ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + DECODER = auto() # Decoder attention between previously layer Q/K/V + ENCODER = auto() # Encoder attention between previously layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V class AttentionBackend(ABC): diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 152f339e27485..3e6fe0717b0e7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -15,7 +15,6 @@ PagedAttentionMetadata) from vllm.logger import init_logger - logger = init_logger(__name__) @@ -227,8 +226,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - _attn_type= - self._attn_type, # Begin cross-attention fields below... + _attn_type=self. + _attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -264,8 +263,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - _attn_type= - AttentionType.ENCODER_DECODER, + _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, @@ -302,8 +300,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - _attn_type= - self._attn_type, # Begin cross-attention fields below... + _attn_type=self. + _attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -334,8 +332,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - _attn_type= - AttentionType.ENCODER_DECODER, + _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, @@ -480,10 +477,10 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert attn_type == AttentionType.ENCODER_DECODER or (key.shape[0] - == num_prefill_tokens + num_decode_tokens) - assert attn_type == AttentionType.ENCODER_DECODER or (value.shape[0] - == num_prefill_tokens + num_decode_tokens) + assert attn_type == AttentionType.ENCODER_DECODER or ( + key.shape[0] == num_prefill_tokens + num_decode_tokens) + assert attn_type == AttentionType.ENCODER_DECODER or ( + value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. @@ -603,7 +600,8 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.attention_type == AttentionType.ENCODER_DECODER: + if attn_metadata.attention_type == \ + AttentionType.ENCODER_DECODER: # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) From 76c639a461ff30762e23001e73ada922ee8d7c3f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 11:40:27 -0400 Subject: [PATCH 086/443] wip encoder test --- tests/kernels/test_self_and_cross_attn.py | 189 +++++++++++++++++++++- 1 file changed, 187 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index c9ae74e788754..02c58a1d57909 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -718,8 +718,9 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): * num_heads: Number of attention heads * head_size: Head dimension - * num_blocks: Number of KV cache blocks + * num_blocks: Number of KV cache blocks (no KV cache if None) * block_size: Number of offsets within a KV cache block + (no KV cache if None) * backend_name: selection of backend Returns: @@ -729,6 +730,7 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): * attn: Attention wrapper instance * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * head_size) + * None if num_blocks or block_size is None ''' scale = float(1.0 / (head_size**0.5)) @@ -738,8 +740,14 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): head_size, scale=scale, ) + if num_blocks is None or num_heads is None: + # Caller does not require a KV cache + return scale, attn_backend, attn, None + + # Construct KV cache kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache + def self_attn_setup(batch_size, @@ -1093,7 +1101,184 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: - pass + ''' + Encoder-only attention test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Num KV cache blocks + # num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + None, + None, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + attn_type=AttentionType.DECODER, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Prefill self-attention correct? + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + + # - Prefill cross-attention correct? + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + False, + q_seq_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Decode self-attention correct? + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) + + # - Decode cross-attention correct? + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) From 882640e51a3727c058503e0fc04c91c32f2e11bf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 12:09:33 -0400 Subject: [PATCH 087/443] first pass at encoder attention test --- tests/kernels/test_self_and_cross_attn.py | 333 ++++++++++++---------- 1 file changed, 190 insertions(+), 143 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 02c58a1d57909..64f7ec0eaac40 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -748,16 +748,154 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache +def encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_q_seq_len): + ''' + Set up test vectors & data structures for encoder attention test. + + A triplet of synthetic query/key/value tensors are constructed ("baseline" + query/key/value). Given this is a self-attention test, the key & value + sequences will have the same length as the corresponding queries. + + "Prefill" query/key/value tensors are derived by masking out the last value + in each baseline query/key/value. These tensors are used to test prefill & + populate KV cache for a subsequent decode test. + + "Decode" query/key/value tensors are derived by extracting *only* the last + value from each baseline query/key/value (i.e. complement of the prefill + tensors.) These tensors are used to test decode, conditional on the kv cache + being populated during the prefill test. + + The baseline query/key/value tensors are passed to an ideal reference + self-attention implementation to generate a "Baseline" ideal output tensor. + This tensor is split into the "Prefill" ideal output tensor (all but the + last element of each output sequence) and the "Decode" ideal output tensor + (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode + test results, respectively. + + This function also constructs the self-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts at + block_base_addr + + Arguments: + + * batch_size + * num_heads: Number of attention heads + * head_size: Head dimension + * block_size: Number of offsets per KV cache block + * scale: attention scale parameter + * max_q_seq_len: upper limit on query length for synthetic test vectors + * block_base_addr: self-attention block table base address + + Returns: + + * query: "baseline" query; batch_size x padded_seq_len x num_heads x + head_size + * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x + head_size + * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads + x head_size + * prefill_packed_value: self-attn "prefill" value; number_of_tokens x + num_heads x head_size + * prefill_packed_ideal_output: self-attn "prefill" ideal output; + number_of_tokens x num_heads x head_size + * prefill_q_seq_lens: list of token counts for each *prefill query* (one + less than baseline query) + * prefill_kv_seq_lens: list of token counts for each self-attn *prefill + key/value* (should match prefill_q_seq_lens) + * decode_packed_query: "decode" query; number_of_tokens x num_heads x + head_size + * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x + head_size + * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads + x head_size + * decode_packed_ideal_output: self-attn "decode" ideal output; + number_of_tokens x num_heads x head_size + * decode_q_seq_lens: list of token counts for each *decode query* (should + be 1) + * decode_kv_seq_lens: list of token counts for each self-attn *decode + key/value* (should match decode_q_seq_lens) + * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x + head_size + * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens + x num_heads x head_size + * decode_block_tables: fake self-attn decode-phase block table + * decode_slot_mapping: fake self-attn decode-phase slot mapping + * prefill_slot_mapping: fake self-attn prefill-phase slot mapping + * prefill_block_tables: fake self-attn prefill-phase block table + * max_block_idx: highest block address in the self-attention block-table + ''' + max_kv_seq_len = max_q_seq_len -def self_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - attn_type: AttentionType, - block_base_addr=0): + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_seq_lens, \ + kv_seq_lens, \ + _, \ + _, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER) + + # No attention mask + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + q_seq_lens=q_seq_lens, + kv_seq_lens=kv_seq_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_seq_lens) + + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, _, _ = pack_qkv( + prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, + prefill_kv_seq_lens) + + return query, \ + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, \ + prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens, \ + q_seq_lens, \ + kv_seq_lens + +def decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -858,7 +996,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - attn_type=attn_type) + attn_type=AttentionType.DECODER) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -929,17 +1067,17 @@ def self_attn_setup(batch_size, max_block_idx -def cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=0): +def enc_dec_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=0): ''' Set up test vectors & data structures for cross-attention test. @@ -1135,7 +1273,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, + _ = basic_setup(num_heads, head_size, None, None, @@ -1146,140 +1284,50 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, self_block_base_addr = 0 query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ + packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ + prefill_kv_seq_lens, \ _, \ _, \ q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = self_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - attn_type=AttentionType.DECODER, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests + kv_seq_lens = encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_q_seq_len) context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, True, prefill_q_seq_lens, context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( - attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Prefill self-attention correct? - assert torch.allclose( - self_prefill_packed_ideal_output, - self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) - - cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # - Prefill cross-attention correct? - assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) - - # DECODE: self- and cross-attention tests - - decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( - attn_backend, - False, - q_seq_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, + None, + None, + is_encoder_only_test=True, + cross_seq_lens=None, + cross_block_tables=None, + cross_slot_mapping=None, ) - self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + packed_actual_output: torch.Tensor = run_self_attention_test( attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Decode self-attention correct? + packed_query, + packed_key, + packed_value, + None, + attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) - - cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) - - # - Decode cross-attention correct? - assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) - - + packed_ideal_output, + packed_actual_output.view_as( + packed_ideal_output)) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1354,13 +1402,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = self_attn_setup(batch_size, + cross_block_base_addr = decoder_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_seq_len, - attn_type=AttentionType.DECODER, block_base_addr=self_block_base_addr) # Cross-attention setup @@ -1374,7 +1421,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = cross_attn_setup_reuses_query(query, + _ = enc_dec_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, From 584297e391915d4eae3ae73a2fea47d79a58cf95 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 12:46:07 -0400 Subject: [PATCH 088/443] wip encoder attention test --- tests/kernels/test_self_and_cross_attn.py | 39 +++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 64f7ec0eaac40..027f17826aacf 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -751,8 +751,10 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): def encoder_attn_setup(batch_size, num_heads, head_size, + block_size, scale, - max_q_seq_len): + max_q_seq_len, + block_base_addr=0): ''' Set up test vectors & data structures for encoder attention test. @@ -871,6 +873,15 @@ def encoder_attn_setup(batch_size, prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, prefill_q_seq_lens) + _, \ + _, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + _, \ + _, \ + _ = make_block_tables_slot_mapping( + block_size, q_seq_lens, block_base_addr=block_base_addr) + prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, _, _ = pack_qkv( @@ -884,10 +895,8 @@ def encoder_attn_setup(batch_size, prefill_packed_ideal_output, \ prefill_q_seq_lens, \ prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens, \ - q_seq_lens, \ - kv_seq_lens + prefill_slot_mapping, \ + prefill_block_tables def decoder_attn_setup(batch_size, num_heads, @@ -1227,7 +1236,7 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, attn_metadata) -@pytest.mark.skip() +#@pytest.mark.skip() @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1280,22 +1289,20 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, backend_name) # Self-attention setup - - self_block_base_addr = 0 - - query, \ + # Let encoder_attn_setup() choose default block table + # base address + _, \ packed_query, \ packed_key, \ packed_value, \ packed_ideal_output, \ prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - _, \ _, \ - q_seq_lens, \ - kv_seq_lens = encoder_attn_setup(batch_size, + slot_mapping, \ + block_tables = encoder_attn_setup(batch_size, num_heads, head_size, + block_size, scale, max_q_seq_len) @@ -1306,8 +1313,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, True, prefill_q_seq_lens, context_lens, - None, - None, + block_tables, + slot_mapping, is_encoder_only_test=True, cross_seq_lens=None, cross_block_tables=None, From a89c7c678b965cce38fb74ee3688cc74d218aee1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 13:10:49 -0400 Subject: [PATCH 089/443] encoder attention test passes! --- tests/kernels/test_self_and_cross_attn.py | 80 +++++++++++------------ 1 file changed, 38 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 027f17826aacf..def055d356d6a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -837,25 +837,25 @@ def encoder_attn_setup(batch_size, query, \ key, \ value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _, \ q_seq_lens, \ kv_seq_lens, \ _, \ _, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER) + _, \ + _, \ + _, \ + _ = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER) # No attention mask ideal_output = ref_masked_attention(query, @@ -865,38 +865,36 @@ def encoder_attn_setup(batch_size, q_seq_lens=q_seq_lens, kv_seq_lens=kv_seq_lens) - prefill_ideal_output = torch.zeros_like(ideal_output) - for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): - prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] + # prefill_ideal_output = torch.zeros_like(ideal_output) + # for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + # prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + # bdx, :prefill_q_seq_len] - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, + q_seq_lens) + block_tables, \ _, \ _, \ - prefill_slot_mapping, \ - prefill_block_tables, \ _, \ + slot_mapping, \ _, \ _ = make_block_tables_slot_mapping( block_size, q_seq_lens, block_base_addr=block_base_addr) - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens) + packed_query, \ + packed_key, \ + packed_value, _, _ = pack_qkv( + query, key, value, q_seq_lens, + kv_seq_lens) - return query, \ - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, \ - prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - prefill_slot_mapping, \ - prefill_block_tables + return packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ + block_tables, \ + slot_mapping, \ + q_seq_lens def decoder_attn_setup(batch_size, num_heads, @@ -1291,15 +1289,13 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, # Self-attention setup # Let encoder_attn_setup() choose default block table # base address - _, \ packed_query, \ packed_key, \ packed_value, \ packed_ideal_output, \ - prefill_q_seq_lens, \ - _, \ + block_tables, \ slot_mapping, \ - block_tables = encoder_attn_setup(batch_size, + q_seq_lens = encoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -1311,7 +1307,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, True, - prefill_q_seq_lens, + q_seq_lens, context_lens, block_tables, slot_mapping, From 0bbd0db0f260d1b027fb46f47a417ab4d3532600 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 13:15:47 -0400 Subject: [PATCH 090/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index def055d356d6a..f3c5c3fd08cc5 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -747,7 +747,8 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): # Construct KV cache kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache - + + def encoder_attn_setup(batch_size, num_heads, head_size, @@ -870,8 +871,7 @@ def encoder_attn_setup(batch_size, # prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ # bdx, :prefill_q_seq_len] - packed_ideal_output, _ = pack_tensor(ideal_output, - q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens) block_tables, \ _, \ @@ -896,6 +896,7 @@ def encoder_attn_setup(batch_size, slot_mapping, \ q_seq_lens + def decoder_attn_setup(batch_size, num_heads, head_size, @@ -1245,7 +1246,6 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: - ''' Encoder-only attention test: @@ -1327,10 +1327,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as( - packed_ideal_output)) + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) From af998ca8afad19e82320d15119d342bfbbca31bb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 13:45:51 -0400 Subject: [PATCH 091/443] encoder test arguments --- tests/kernels/test_self_and_cross_attn.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index def055d356d6a..b0c8a143c45f3 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1233,18 +1233,15 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) - -#@pytest.mark.skip() @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) -@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) +@pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, - max_q_seq_len: int, max_kv_seq_len: int) -> None: + max_seq_len: int) -> None: ''' Encoder-only attention test: @@ -1300,7 +1297,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, head_size, block_size, scale, - max_q_seq_len) + max_seq_len) context_lens = [0 for _ in range(batch_size)] From 78c678add4d25a2c13fd7e5ae8966a83bafee933 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 14:42:02 -0400 Subject: [PATCH 092/443] type hints; formatting --- tests/kernels/test_self_and_cross_attn.py | 349 +++++++++++----------- 1 file changed, 176 insertions(+), 173 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index b0c8a143c45f3..6e2a3ed19ef76 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1,7 +1,7 @@ import copy import itertools import random -from typing import List, Optional +from typing import List, Optional, Union import pytest import torch @@ -29,7 +29,8 @@ MAX_K_SEQ_LENS = [128] -def build_causal_mask(q_max_seq_len, kv_max_seq_len): +def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ + -> torch.Tensor: ''' Create a q_max_seq_len x kv_max_seq_len causal mask @@ -109,14 +110,14 @@ def ref_masked_attention(query: torch.Tensor, return out -def make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, +def make_qkv(batch_size: int, + max_q_seq_len: int, + max_kv_seq_len: int, + num_heads: int, + head_size: int, attn_type: AttentionType = AttentionType.ENCODER_DECODER, - force_max_len=False, - device=CUDA_DEVICE): + force_max_len: bool = False, + device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: ''' Construct QKV test tensors for self- and cross-attention. @@ -276,7 +277,9 @@ def make_qkv(batch_size, decode_kv_seq_lens -def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): +def pack_tensor(unpacked_tensor: torch.Tensor, + seq_lens: List[int], + device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -309,7 +312,8 @@ def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): return packed_tensor, start_loc_list -def pack_qkv(query, key, value, q_seq_lens, kv_seq_lens): +def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + q_seq_lens: List[int], kv_seq_lens: List[int]) -> tuple: ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x @@ -379,7 +383,8 @@ def make_backend(backend_name: str) -> AttentionBackend: def make_metadata_tensors(is_prompt: bool, seq_lens: List[int], context_lens: List[int], - device=CUDA_DEVICE) -> tuple: + device: Union[torch.device, str] = \ + CUDA_DEVICE) -> tuple: ''' Build scalar & tensor values required to build attention metadata structure. @@ -434,12 +439,13 @@ def make_metadata_tensors(is_prompt: bool, query_start_loc -def make_kv_cache(num_blocks, - num_heads, - head_size, - block_size, - device=CUDA_DEVICE, - default_val=0.0): +def make_kv_cache(num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str] = \ + CUDA_DEVICE, + default_val: float=0.0) -> torch.Tensor: ''' Create a fake KV cache. @@ -464,7 +470,7 @@ def make_kv_cache(num_blocks, return kv_cache -def num_tokens_to_min_blocks(num_tokens, block_size): +def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' Compute the minimum number of blocks required to hold num_tokens tokens, given block_size @@ -472,10 +478,11 @@ def num_tokens_to_min_blocks(num_tokens, block_size): return (num_tokens + block_size) // block_size -def make_block_tables_slot_mapping(block_size, - seq_lens, - block_base_addr=0, - device=CUDA_DEVICE): +def make_block_tables_slot_mapping(block_size: int, + seq_lens: List, + block_base_addr: int=0, + device: Union[torch.device, str] = \ + CUDA_DEVICE) -> tuple: ''' Construct fake block tables & slot mappings. @@ -585,15 +592,15 @@ def make_block_tables_slot_mapping(block_size, max_block_idx -def make_metadata_self_cross( +def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, seq_lens: List[int], context_lens: List[int], - block_tables, - slot_mapping, + block_tables: torch.Tensor, + slot_mapping: torch.Tensor, is_encoder_only_test: bool, - device=CUDA_DEVICE, + device: Union[torch.device, str] = CUDA_DEVICE, cross_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, cross_slot_mapping: Optional[List[int]] = None, @@ -602,6 +609,10 @@ def make_metadata_self_cross( Construct fake attention metadata for a combined self-/cross-attention scenario i.e. an encoder/decoder model. + is_encoder_only_test=True causes the default attention metadata attention + type to be AttentionType.ENCODER. False causes the default to + be AttentionType.DECODER. + Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both @@ -614,6 +625,8 @@ def make_metadata_self_cross( * context_lens: list of context lengths for each sequence * block_tables: self-attention block tables * slot_mapping: self-attention slot_mapping + * is_encoder_only_test: True if testing encoder; False if testing + decoder self-attention or encoder/decoder cross-attention. * device: CPU or CUDA device * cross_seq_lens: list of token counts for each encoder sequence, if any exist @@ -644,13 +657,9 @@ def make_metadata_self_cross( context_lens, device=device) - slot_mapping_tensor = slot_mapping - - cross_slot_mapping_tensor = cross_slot_mapping - return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, + slot_mapping=slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -665,7 +674,7 @@ def make_metadata_self_cross( use_cuda_graph=False, _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, - cross_slot_mapping=cross_slot_mapping_tensor, + cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) else: # not is_prompt @@ -685,13 +694,9 @@ def make_metadata_self_cross( context_lens, device=device) - slot_mapping_tensor = slot_mapping - - cross_slot_mapping_tensor = cross_slot_mapping - return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, + slot_mapping=slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -706,11 +711,12 @@ def make_metadata_self_cross( use_cuda_graph=False, _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, - cross_slot_mapping=cross_slot_mapping_tensor, + cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) -def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): +def basic_setup(num_heads: int, head_size: int, num_blocks: int, + block_size: int, backend_name: str) -> tuple: ''' Compute & build entities required for the self-/cross-attention test. @@ -747,37 +753,24 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): # Construct KV cache kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache - -def encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=0): + + +def encoder_attn_setup(batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for encoder attention test. - A triplet of synthetic query/key/value tensors are constructed ("baseline" - query/key/value). Given this is a self-attention test, the key & value + A triplet of synthetic query/key/value tensors are constructed. + Given this is an encoder attention test, the key & value sequences will have the same length as the corresponding queries. - "Prefill" query/key/value tensors are derived by masking out the last value - in each baseline query/key/value. These tensors are used to test prefill & - populate KV cache for a subsequent decode test. - - "Decode" query/key/value tensors are derived by extracting *only* the last - value from each baseline query/key/value (i.e. complement of the prefill - tensors.) These tensors are used to test decode, conditional on the kv cache - being populated during the prefill test. - - The baseline query/key/value tensors are passed to an ideal reference - self-attention implementation to generate a "Baseline" ideal output tensor. - This tensor is split into the "Prefill" ideal output tensor (all but the - last element of each output sequence) and the "Decode" ideal output tensor - (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode - test results, respectively. + The query/key/value tensors are passed to an ideal reference + self-attention implementation to generate an ideal output tensor. This function also constructs the self-attention KV cache memory mapping (slot mapping and block table), ensuring that the block table starts at @@ -794,42 +787,14 @@ def encoder_attn_setup(batch_size, * block_base_addr: self-attention block table base address Returns: - - * query: "baseline" query; batch_size x padded_seq_len x num_heads x - head_size - * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x - head_size - * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads - x head_size - * prefill_packed_value: self-attn "prefill" value; number_of_tokens x - num_heads x head_size - * prefill_packed_ideal_output: self-attn "prefill" ideal output; - number_of_tokens x num_heads x head_size - * prefill_q_seq_lens: list of token counts for each *prefill query* (one - less than baseline query) - * prefill_kv_seq_lens: list of token counts for each self-attn *prefill - key/value* (should match prefill_q_seq_lens) - * decode_packed_query: "decode" query; number_of_tokens x num_heads x - head_size - * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x - head_size - * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads - x head_size - * decode_packed_ideal_output: self-attn "decode" ideal output; - number_of_tokens x num_heads x head_size - * decode_q_seq_lens: list of token counts for each *decode query* (should - be 1) - * decode_kv_seq_lens: list of token counts for each self-attn *decode - key/value* (should match decode_q_seq_lens) - * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x - head_size - * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens - x num_heads x head_size - * decode_block_tables: fake self-attn decode-phase block table - * decode_slot_mapping: fake self-attn decode-phase slot mapping - * prefill_slot_mapping: fake self-attn prefill-phase slot mapping - * prefill_block_tables: fake self-attn prefill-phase block table - * max_block_idx: highest block address in the self-attention block-table + + * packed_query: number_of_tokens x num_heads x head_size + * packed_key: number_of_tokens x num_heads x head_size + * packed_value: number_of_tokens x num_heads x head_size + * packed_ideal_output: number_of_tokens x num_heads x head_size + * block_tables: fake self-attn decode-phase block table + * slot_mapping: fake self-attn decode-phase slot mapping + * q_seq_lens: list of query sequence lengths ''' max_kv_seq_len = max_q_seq_len @@ -857,7 +822,7 @@ def encoder_attn_setup(batch_size, head_size, attn_type=AttentionType.ENCODER) - # No attention mask + # No causal attention mask ideal_output = ref_masked_attention(query, key, value, @@ -865,13 +830,7 @@ def encoder_attn_setup(batch_size, q_seq_lens=q_seq_lens, kv_seq_lens=kv_seq_lens) - # prefill_ideal_output = torch.zeros_like(ideal_output) - # for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): - # prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - # bdx, :prefill_q_seq_len] - - packed_ideal_output, _ = pack_tensor(ideal_output, - q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens) block_tables, \ _, \ @@ -896,13 +855,14 @@ def encoder_attn_setup(batch_size, slot_mapping, \ q_seq_lens -def decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=0): + +def decoder_attn_setup(batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for self-attention test. @@ -1074,17 +1034,18 @@ def decoder_attn_setup(batch_size, max_block_idx -def enc_dec_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=0): +def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, + q_seq_lens: List, + prefill_q_seq_lens: List, + batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + max_kv_seq_len: int, + block_base_addr: Optional[int]=0) \ + -> tuple: ''' Set up test vectors & data structures for cross-attention test. @@ -1092,7 +1053,7 @@ def enc_dec_attn_setup_reuses_query(query, ("baseline" key/value). Given this is a cross-attention test, we assume query tensors were already synthesized for a prior self-attention test and will be reused for cross-attention. The key & value sequences generated here - will may have a different length than the corresponding queries (as is often + may have a different length than the corresponding queries (as is often the case for cross-attention between decoder and encoder sequences.) Cross attention key & value tensors do not grow during autoregressive @@ -1217,22 +1178,63 @@ def enc_dec_attn_setup_reuses_query(query, max_block_idx -def run_self_attention_test(attn: Attention, packed_query, packed_key, - packed_value, kv_cache, +def run_self_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: AttentionType): + attn_type: AttentionType) -> torch.Tensor: + ''' + Run encoder attention or decoder self-attention test. + + attn_metadata.attention_type is assigned attn_type in order to configure + the kernel invocation for either encoder or decoder self-attention. + + Arguments: + + * attn: Attention wrapper instance + * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * kv_cache + * attn_metadata: attention metadata for encoder/decoder-self attention + * attn_type: AttentionType.DECODER or AttentionType.ENCODER + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' + assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) -def run_cross_attention_test(attn: Attention, packed_query, packed_key, - packed_value, kv_cache, - attn_metadata: AttentionMetadata): +def run_cross_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + ''' + Run encoder/decoder cross-attention test. + + attn_metadata.attention_type is assigned AttentionType.ENCODER_DECODER + in order to configure the kernel invocation for encoder/decoder cross- + attention. + + Arguments: + + * attn: Attention wrapper instance + * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * kv_cache + * attn_metadata: attention metadata for encoder/decoder-self attention + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1242,50 +1244,49 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_seq_len: int) -> None: - ''' Encoder-only attention test: - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order - - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid + * Construct fake test vectors for encoder attention + * Construct attention metadata structure with encoder-attention- + specific attributes + * Run encoder attention with metadata structure & test vectors * Validate output correctness against ideal reference attention implementation - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. + Encoder attention (by default) does not restrict which sequence offsets + may attend to each other. Thus the reference ideal reference + implementation does not employ a causal attention mask. - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. - ''' + Encoder attention does not utilize KV cache however the XFormer backend + requires block_tables & slot_mapping to be non-None and have a valid + structure, thus this test constructs dummy memory-mapping structures. - # Num KV cache blocks - # num_blocks = 4096 + Encoder attention is basically structured like decoder self-attention + in that Q/K/V are all derived from the previous layer output & have + the same sequence length (in contrast to encoder/decoder cross- + attention where K/V are drawn from the encoder hidden states and + may have a different length than Q derived from decoder previous + layer output.) + ''' # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init + # instance. Encoder attention does not require KV cache. scale, \ attn_backend, \ attn, \ _ = basic_setup(num_heads, - head_size, - None, - None, - backend_name) + head_size, + None, + None, + backend_name) # Self-attention setup # Let encoder_attn_setup() choose default block table - # base address + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. packed_query, \ packed_key, \ packed_value, \ @@ -1301,7 +1302,13 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, context_lens = [0 for _ in range(batch_size)] - attn_metadata: AttentionMetadata = make_metadata_self_cross( + # Metadata config for encoder attention: + # + # * Use prefill kernel + # * Signal that this is an encoder-only test so that + # metadata attention_type property is correctly + # configured as AttentionType.ENCODER + attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, q_seq_lens, @@ -1309,9 +1316,6 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, block_tables, slot_mapping, is_encoder_only_test=True, - cross_seq_lens=None, - cross_block_tables=None, - cross_slot_mapping=None, ) packed_actual_output: torch.Tensor = run_self_attention_test( @@ -1324,10 +1328,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as( - packed_ideal_output)) + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1421,7 +1424,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = enc_dec_attn_setup_reuses_query(query, + _ = enc_dec_cross_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, @@ -1437,7 +1440,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, prefill_q_seq_lens, @@ -1479,7 +1482,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # DECODE: self- and cross-attention tests - decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( + decode_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, q_seq_lens, From 641f43139bc0b7935e759ef9eb89b2f0d3889484 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 14:46:32 -0400 Subject: [PATCH 093/443] typo --- vllm/attention/backends/abstract.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index cffd2d577777c..ece0da25ee6f2 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -8,8 +8,8 @@ class AttentionType(Enum): - DECODER = auto() # Decoder attention between previously layer Q/K/V - ENCODER = auto() # Encoder attention between previously layer Q/K/V + DECODER = auto() # Decoder attention between previous layer Q/K/V + ENCODER = auto() # Encoder attention between previous layer Q/K/V ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V From c7f54907ba3ecaecee761512d9266bc47c17e310 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 14:53:04 -0400 Subject: [PATCH 094/443] changed helper function naming convention --- tests/kernels/test_self_and_cross_attn.py | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 6e2a3ed19ef76..e45ea629a2a44 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1178,11 +1178,11 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_block_idx -def run_self_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: +def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: ''' Run encoder attention or decoder self-attention test. @@ -1207,11 +1207,11 @@ def run_self_attention_test(attn: Attention, packed_query: torch.Tensor, attn_metadata) -def run_cross_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: +def run_encoder_decoder_cross_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -1318,7 +1318,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, is_encoder_only_test=True, ) - packed_actual_output: torch.Tensor = run_self_attention_test( + packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( attn, packed_query, packed_key, @@ -1453,7 +1453,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_prefill_slot_mapping, ) - self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( + self_prefill_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, @@ -1468,7 +1468,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output.view_as( self_prefill_packed_ideal_output)) - cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( + cross_prefill_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, prefill_attn_metadata) @@ -1495,7 +1495,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_decode_slot_mapping, ) - self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + self_decode_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -1510,7 +1510,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_actual_output.view_as( self_decode_packed_ideal_output)) - cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( + cross_decode_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? From 9c78f8555dce717521e0bf0c77306c0444f44824 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:05:18 -0400 Subject: [PATCH 095/443] check we are not testing decode-phase/encoder attention --- tests/kernels/test_self_and_cross_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e45ea629a2a44..a211b7e2cc210 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1202,6 +1202,7 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to & attn_metadata ''' assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] + assert attn_metadata.is_prompt or attn_type != AttentionType.ENCODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) From bf93a9eba5c70f82d7e66c6b2165c6e673467a49 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:08:55 -0400 Subject: [PATCH 096/443] refactoring --- tests/kernels/test_self_and_cross_attn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index a211b7e2cc210..9e6b3c16ee86a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1189,6 +1189,10 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to attn_metadata.attention_type is assigned attn_type in order to configure the kernel invocation for either encoder or decoder self-attention. + attn_type must be AttentionType.ENCODER or DECODER; if ENCODER, + attn_metadata.num_decode_tokens must be 0 (i.e. there is no such thing as + "decode-phase enocder attention".) + Arguments: * attn: Attention wrapper instance @@ -1202,7 +1206,7 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to & attn_metadata ''' assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] - assert attn_metadata.is_prompt or attn_type != AttentionType.ENCODER + assert attn_metadata.num_decode_tokens==0 or attn_type != AttentionType.ENCODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) From cd759f2819db3848dad6eda4d78fe0da9932ce3e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:21:59 -0400 Subject: [PATCH 097/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 37 +++++++++++++---------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 9e6b3c16ee86a..28348924ca64b 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1178,11 +1178,11 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_block_idx -def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: +def run_encoder_or_decoder_self_attention_test( + attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: ''' Run encoder attention or decoder self-attention test. @@ -1206,17 +1206,17 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to & attn_metadata ''' assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] - assert attn_metadata.num_decode_tokens==0 or attn_type != AttentionType.ENCODER + assert attn_metadata.num_decode_tokens == 0 or \ + attn_type != AttentionType.ENCODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) -def run_encoder_decoder_cross_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: +def run_encoder_decoder_cross_attention_test( + attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -1323,7 +1323,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, is_encoder_only_test=True, ) - packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( + packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( attn, packed_query, packed_key, @@ -1458,7 +1459,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_prefill_slot_mapping, ) - self_prefill_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( + self_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, @@ -1473,7 +1475,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output.view_as( self_prefill_packed_ideal_output)) - cross_prefill_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( + cross_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, prefill_attn_metadata) @@ -1500,7 +1503,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_decode_slot_mapping, ) - self_decode_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( + self_decode_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -1515,7 +1519,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_actual_output.view_as( self_decode_packed_ideal_output)) - cross_decode_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( + cross_decode_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? From 1af36258ea3f80d7c89bc62fd1c91c445a789ee5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:23:57 -0400 Subject: [PATCH 098/443] removing unnecessary check --- vllm/attention/backends/xformers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e6fe0717b0e7..90ed07b029b6a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -488,8 +488,7 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if attn_type != AttentionType.ENCODER_DECODER \ - and key is not None and value is not None: + if key is not None and value is not None: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] From eb5cf0cbd5f4ae3135f030b0bab4db1333ac61a6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:45:46 -0400 Subject: [PATCH 099/443] unit test for encoder/decoder+chunked prefill non-support; added attention utils file for error strings --- tests/kernels/test_self_and_cross_attn.py | 13 +++++++++++++ vllm/attention/backends/utils.py | 5 +++++ vllm/attention/backends/xformers.py | 17 +++++++++++++---- 3 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 vllm/attention/backends/utils.py diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 28348924ca64b..5128108f52ad6 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -8,6 +8,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad @@ -1528,3 +1529,15 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_ideal_output, cross_decode_packed_actual_output.view_as( cross_decode_packed_ideal_output)) + + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + decode_attn_metadata.num_prefill_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + None, None, kv_cache, + decode_attn_metadata) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000000000..f893460cce06e --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,5 @@ +"""Attention utils""" + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ +"Encoder/decoder models " + \ +"currently do not support chunked prefill." \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 90ed07b029b6a..ecd5413fba507 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -477,10 +478,18 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert attn_type == AttentionType.ENCODER_DECODER or ( - key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert attn_type == AttentionType.ENCODER_DECODER or ( - value.shape[0] == num_prefill_tokens + num_decode_tokens) + if attn_type == AttentionType.ENCODER_DECODER: + # Encoder/decoder models are currently incompatible + # with chunked prefill. + if num_prefill_tokens > 0 and num_decode_tokens > 0: + raise NotImplementedError( \ + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) + else: + # This is a decoder self-attention scenario; + # ensure key/value shape match total number of + # tokens to process + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. From afcb42e125b16a8afa5bb05740783e684b53f6fd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:46:39 -0400 Subject: [PATCH 100/443] explanatory comment --- vllm/attention/backends/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index f893460cce06e..6141f3ead64ad 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,5 +1,8 @@ """Attention utils""" +# Error string(s) for encoder/decoder +# unsupported attention scenarios + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Encoder/decoder models " + \ "currently do not support chunked prefill." \ No newline at end of file From d13e08e72dafce77ba15bd2bb5df610506a274f3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:52:37 -0400 Subject: [PATCH 101/443] refactoring --- tests/kernels/test_self_and_cross_attn.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 5128108f52ad6..bdbb083397a93 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1530,9 +1530,20 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_actual_output.view_as( cross_decode_packed_ideal_output)) + # The following test conditions could in principle be a + # standalone test, however the test setup is so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- # Set up a contrived scenario where the attention metadata # is configured for chunked prefill & encoder/decoder cross- # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: run_encoder_decoder_cross_attention_test(attn, decode_packed_query, From ab92fb0bd237a1e3ef382751254f5b884dbaf549 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:56:18 -0400 Subject: [PATCH 102/443] spelling fix --- tests/kernels/test_self_and_cross_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index bdbb083397a93..187af811874b3 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1192,7 +1192,7 @@ def run_encoder_or_decoder_self_attention_test( attn_type must be AttentionType.ENCODER or DECODER; if ENCODER, attn_metadata.num_decode_tokens must be 0 (i.e. there is no such thing as - "decode-phase enocder attention".) + "decode-phase encoder attention".) Arguments: From 582a0f5cdf0928d8e662e0df6e0326dc606d567f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:59:49 -0400 Subject: [PATCH 103/443] rename --- vllm/attention/backends/xformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index ecd5413fba507..f0d9e3576bfc6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -147,7 +147,7 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None @property - def has_valid_cross_attn_metadata(self): + def is_all_cross_attn_metadata_set(self): # No cross-attention metadata is present whatsoever no_md = (self.cross_seq_lens is None) and (self.cross_slot_mapping is @@ -173,7 +173,7 @@ def attention_type(self) -> AttentionType: def attention_type(self, atype: AttentionType) -> None: if atype == AttentionType.ENCODER_DECODER: - assert self.has_valid_cross_attn_metadata, \ + assert self.is_all_cross_attn_metadata_set, \ "Must have self.cross_seq_lens not None " + \ "in order to enable cross-attention" From a20be6da315d8c3df2b4f2048d5ef0785157b344 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 17:12:29 -0400 Subject: [PATCH 104/443] skip enc/dec tests if HIP --- tests/kernels/test_self_and_cross_attn.py | 9 ++++++--- vllm/attention/backends/utils.py | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 187af811874b3..e45e14f2342ad 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -8,10 +8,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL +from vllm.attention.backends.utils import (STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad +from vllm.utils import is_hip + # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # # TODO: FlashAttention forward only supports head dimension at most 128 @@ -1240,7 +1243,7 @@ def run_encoder_decoder_cross_attention_test( return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) - +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1338,7 +1341,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, assert torch.allclose(packed_ideal_output, packed_actual_output.view_as(packed_ideal_output)) - +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 6141f3ead64ad..727921641cd55 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -5,4 +5,8 @@ STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Encoder/decoder models " + \ -"currently do not support chunked prefill." \ No newline at end of file +"currently do not support chunked prefill." + +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ +"Encoder/decoder models currently" + \ +"do not support ROCm/HIP." \ No newline at end of file From a9a162da1adbbcb5dc89742d2fc3a5c764aec6d7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 17:13:50 -0400 Subject: [PATCH 105/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e45e14f2342ad..36f9616432e3a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -8,12 +8,10 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.backends.utils import (STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +from vllm.attention.backends.utils import ( + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import make_tensor_with_pad - -from vllm.utils import is_hip +from vllm.utils import is_hip, make_tensor_with_pad # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # @@ -1243,6 +1241,7 @@ def run_encoder_decoder_cross_attention_test( return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1341,6 +1340,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, assert torch.allclose(packed_ideal_output, packed_actual_output.view_as(packed_ideal_output)) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) From 6e3cfe141af8894a49d56aff2ec77235a82c3276 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 17:54:03 -0400 Subject: [PATCH 106/443] Refactored checks into utils file --- vllm/attention/backends/utils.py | 42 ++++++++++++++++++++++++++++- vllm/attention/backends/xformers.py | 21 +++++++-------- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 727921641cd55..5616a28ae9f73 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,5 +1,10 @@ """Attention utils""" +from vllm.utils import is_hip +from vllm.attention import AttentionMetadata +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.xformers import XFormersMetadata + # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -9,4 +14,39 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ "Encoder/decoder models currently" + \ -"do not support ROCm/HIP." \ No newline at end of file +"do not support ROCm/HIP." + +STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ +"Encoder/decoder models currently support only the XFormers backend." + +# Check for unsupported encoder/decoder scenarios + +def check_hip_or_chunked_prefill_attention_encdec( + attn_metadata: AttentionMetadata): + ''' + Check for unsupported encoder/decoder scenarios when invoking + attention. + + Arguments: + + * attn_metadata: Attention metadata structure + ''' + if is_hip(): + # AMD ROCm/HIP support currently not implemented for + # encoder/decoder models + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) + + if not isinstance(attn_metadata,XFormersMetadata): + # Right now encoder/decoder support is only implemented + # for the XFormers backend. Pretty unlikely to encounter + # this case currently given this function will be invoked inside + # xFormers backend. + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) + + if attn_metadata.attention_type != AttentionType.DECODER: + # Encoder/decoder models are currently incompatible + # with chunked prefill. + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: + raise NotImplementedError( \ + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index f0d9e3576bfc6..c23cabd0a07ba 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -447,6 +446,13 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type + if attn_type != AttentionType.DECODER: + # Raise NotImplementedError for unsupported encoder/decoder + # scenarios + from vllm.attention.backends.utils import \ + check_hip_or_chunked_prefill_attention_encdec + check_hip_or_chunked_prefill_attention_encdec(attn_metadata) + if (kv_cache is not None): # Even if there are no new key/value pairs to cache, # we still need to break out key_cache and value_cache @@ -478,16 +484,9 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - if attn_type == AttentionType.ENCODER_DECODER: - # Encoder/decoder models are currently incompatible - # with chunked prefill. - if num_prefill_tokens > 0 and num_decode_tokens > 0: - raise NotImplementedError( \ - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) - else: - # This is a decoder self-attention scenario; - # ensure key/value shape match total number of - # tokens to process + if attn_type != AttentionType.ENCODER_DECODER: + # Only enforce this shape-constraint for decoder + # self-attention assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens From 622ce09f19d8d21080b2a72377c929c0e149eb8f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 18:14:06 -0400 Subject: [PATCH 107/443] format --- tests/kernels/test_self_and_cross_attn.py | 134 ++++++++++++++++++++++ vllm/attention/backends/utils.py | 19 +-- vllm/attention/backends/xformers.py | 4 +- 3 files changed, 146 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 36f9616432e3a..c12b19b929147 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1555,3 +1555,137 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + +@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +@pytest.mark.parametrize("num_heads", [256]) +@pytest.mark.parametrize("head_size", [16]) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("max_q_seq_len", [64]) +@pytest.mark.parametrize("max_kv_seq_len", [64]) +def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_seq_len: int, + max_kv_seq_len: int) -> None: + ''' + Encoder/decoder not-implemented-for-ROCm-HIP test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) + + # "Encoder decoder models do not currently support ROCm/HIP" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 5616a28ae9f73..ad88b4f964a54 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,9 +1,9 @@ """Attention utils""" -from vllm.utils import is_hip from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.xformers import XFormersMetadata +from vllm.utils import is_hip # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -21,6 +21,7 @@ # Check for unsupported encoder/decoder scenarios + def check_hip_or_chunked_prefill_attention_encdec( attn_metadata: AttentionMetadata): ''' @@ -35,18 +36,18 @@ def check_hip_or_chunked_prefill_attention_encdec( # AMD ROCm/HIP support currently not implemented for # encoder/decoder models raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) - - if not isinstance(attn_metadata,XFormersMetadata): + + if not isinstance(attn_metadata, XFormersMetadata): # Right now encoder/decoder support is only implemented # for the XFormers backend. Pretty unlikely to encounter # this case currently given this function will be invoked inside # xFormers backend. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) - - if attn_metadata.attention_type != AttentionType.DECODER: + + if attn_metadata.attention_type != AttentionType.DECODER \ + and attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible # with chunked prefill. - if attn_metadata.num_prefill_tokens > 0 and \ - attn_metadata.num_decode_tokens > 0: - raise NotImplementedError( \ - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) + raise NotImplementedError( \ + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index c23cabd0a07ba..6c5bd8fa3726a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -449,8 +449,8 @@ def forward( if attn_type != AttentionType.DECODER: # Raise NotImplementedError for unsupported encoder/decoder # scenarios - from vllm.attention.backends.utils import \ - check_hip_or_chunked_prefill_attention_encdec + from vllm.attention.backends.utils import ( + check_hip_or_chunked_prefill_attention_encdec) check_hip_or_chunked_prefill_attention_encdec(attn_metadata) if (kv_cache is not None): From 4d88a898184f09b8ad69fe498bb23629fd99f338 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 19:09:18 -0400 Subject: [PATCH 108/443] wip trying to combine attention metadata caches --- vllm/attention/backends/xformers.py | 234 ++++++++++++---------------- 1 file changed, 102 insertions(+), 132 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6c5bd8fa3726a..e74b5fa1052db 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -198,149 +198,119 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_prefill_metadata is not None: - return self._self_cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - self._self_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, - cross_block_tables=None, - cross_slot_mapping=None) + target_attention_type = self.attention_type + + if self._self_cached_prefill_metadata is not None: + self._self_cached_prefill_metadata.attention_type = \ + target_attention_type return self._self_cached_prefill_metadata + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + if self.is_all_cross_attn_metadata_set: + # This attention metadata structure could support + # encoder/decoder cross-attention; make sure to + # set the appropriate fields + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables else: - # Encoder/decoder cross-attention prefill - - if self._cross_cached_prefill_metadata is not None: - return self._cross_cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - self._cross_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_prefill_metadata + # This attention metadata structure supports + # decoder-only self-attention; there are no fields + # to support encoder/decoder cross-attention + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_slot_mapping=None, + cross_block_tables=None + + self._self_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + _attn_type=self. + attention_type, # Begin cross-attention fields below... + cross_seq_lens=cross_seq_lens, + cross_seq_lens_tensor=cross_seq_lens_tensor, + max_cross_seq_len=max_cross_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) + return self._self_cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_decode_metadata is not None: - return self._self_cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._self_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, - cross_block_tables=None, - cross_slot_mapping=None) - return self._self_cached_decode_metadata + target_attention_type = self.attention_type + if self._self_cached_decode_metadata is not None: + self._self_cached_decode_metadata.attention_type = \ + target_attention_type + return self._self_cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + if self.is_all_cross_attn_metadata_set: + # This attention metadata structure could support + # encoder/decoder cross-attention; make sure to + # set the appropriate fields + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables else: - # Encoder/decoder cross-attention decode - - if self._cross_cached_decode_metadata is not None: - return self._cross_cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cross_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_decode_metadata - + # This attention metadata structure supports + # decoder-only self-attention; there are no fields + # to support encoder/decoder cross-attention + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_slot_mapping=None, + cross_block_tables=None + + self._self_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + _attn_type=target_attention_type, + # Begin cross-attention fields below... + cross_seq_lens=cross_seq_lens, + cross_seq_lens_tensor=cross_seq_lens_tensor, + max_cross_seq_len=max_cross_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) + return self._self_cached_decode_metadata class XFormersImpl(AttentionImpl[XFormersMetadata]): """ From 3dfcb556f43b4b753fea5e85e780a4fda85e99f9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 19:36:22 -0400 Subject: [PATCH 109/443] wip trying to merge self/cross caches; trying to fix attn_bias issues; just tried having xformers backend clear mask when changing target attention type --- vllm/attention/backends/xformers.py | 32 +++++++++++++++++------------ 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e74b5fa1052db..71a34170abc3e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -162,8 +162,22 @@ def is_all_cross_attn_metadata_set(self): assert ( not invalid_md_if_not_no_md), "Invalid cross-attention metadata" + self._maybe_infer_implicit_cross_attention_metadata() return True + def _maybe_infer_implicit_cross_attention_metadata(self): + # Infer implicit cross-attention fields + # from user-provided fields, if needed + if self.cross_seq_lens_tensor is None: + assert self.seq_lens_tensor is not None + self.cross_seq_lens_tensor = torch.tensor( + self.cross_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) + if self.max_cross_seq_len is None: + assert self.cross_seq_lens is not None + self.max_cross_seq_len = max(self.cross_seq_lens) + @property def attention_type(self) -> AttentionType: return self._attn_type @@ -175,19 +189,7 @@ def attention_type(self, atype: AttentionType) -> None: assert self.is_all_cross_attn_metadata_set, \ "Must have self.cross_seq_lens not None " + \ "in order to enable cross-attention" - - # Infer implicit cross-attention fields - # from user-provided fields, if needed - if self.cross_seq_lens_tensor is None: - assert self.seq_lens_tensor is not None - self.cross_seq_lens_tensor = torch.tensor( - self.cross_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) - if self.max_cross_seq_len is None: - assert self.cross_seq_lens is not None - self.max_cross_seq_len = max(self.cross_seq_lens) - + self._maybe_infer_implicit_cross_attention_metadata() self._attn_type = AttentionType.ENCODER_DECODER else: # AttentionType.{ENCODER,DECODER} @@ -263,6 +265,10 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: target_attention_type = self.attention_type if self._self_cached_decode_metadata is not None: + if self._self_cached_decode_metadata.attention_type != \ + target_attention_type: + self._self_cached_decode_metadata.attn_bias = None + self._self_cached_decode_metadata.attention_type = \ target_attention_type return self._self_cached_decode_metadata From 696072392eca111c7a463ad7ad920528dc0c0e6a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 19:38:48 -0400 Subject: [PATCH 110/443] wip --- vllm/attention/backends/xformers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 71a34170abc3e..d97f6d306e5bd 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -203,6 +203,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: target_attention_type = self.attention_type if self._self_cached_prefill_metadata is not None: + if self._self_cached_prefill_metadata.attention_type != \ + target_attention_type: + self._self_cached_prefill_metadata.attn_bias = None + self._self_cached_prefill_metadata.attention_type = \ target_attention_type return self._self_cached_prefill_metadata From 31275ccba58545b1c090f24e600507b04adeec0c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 30 May 2024 09:32:28 -0400 Subject: [PATCH 111/443] wip merging attention metadata --- vllm/attention/backends/xformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d97f6d306e5bd..6f21f7caf83e1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -111,8 +111,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None _self_cached_decode_metadata: Optional["XFormersMetadata"] = None # Cross-attention prefill/decode metadata cache - _cross_cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None # Begin cross-attention fields... From a643436c0aac5db823a009bef65146d1610a7522 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 30 May 2024 10:35:29 -0400 Subject: [PATCH 112/443] simplied is_all_cross_attn_metadata_set() --- vllm/attention/backends/xformers.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6c5bd8fa3726a..4ef052424f176 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,21 +148,9 @@ def __post_init__(self): @property def is_all_cross_attn_metadata_set(self): # No cross-attention metadata is present whatsoever - no_md = (self.cross_seq_lens is - None) and (self.cross_slot_mapping is - None) and (self.cross_block_tables is None) - # If any cross-attention metadata is present, it is invalid - invalid_md_if_not_no_md = (self.cross_seq_lens is None) or ( - self.cross_slot_mapping is None) or (self.cross_block_tables is - None) - - if no_md: - return False - - assert ( - not invalid_md_if_not_no_md), "Invalid cross-attention metadata" - - return True + return (self.cross_seq_lens is not None) and \ + (self.cross_slot_mapping is not None) and \ + (self.cross_block_tables is not None) @property def attention_type(self) -> AttentionType: @@ -173,8 +161,8 @@ def attention_type(self, atype: AttentionType) -> None: if atype == AttentionType.ENCODER_DECODER: assert self.is_all_cross_attn_metadata_set, \ - "Must have self.cross_seq_lens not None " + \ - "in order to enable cross-attention" + "Must enable self.cross_seq_lens, self.cross_slot_mapping, " + \ + "self.cross_block_tables in order to perform cross-attention" # Infer implicit cross-attention fields # from user-provided fields, if needed From 2a1d84ac71fbbdfd50262574c505e34bdc464157 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 2 Jun 2024 23:06:00 -0400 Subject: [PATCH 113/443] test: envs.VLLM_ATTENTION_BACKEND --- tests/kernels/test_self_and_cross_attn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index c12b19b929147..5d92d73bfef38 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -13,6 +13,9 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad +from vllm.logger import init_logger +logger = init_logger(__name__) + # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # # TODO: FlashAttention forward only supports head dimension at most 128 @@ -1278,6 +1281,10 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, layer output.) ''' + import vllm.envs as envs + print("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) + logger.info("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) + # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. scale, \ From f6e0310f1a955d7cf6e08e27a6b4f7fa12f80c88 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 2 Jun 2024 23:08:01 -0400 Subject: [PATCH 114/443] formatting' --- tests/kernels/test_self_and_cross_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 5d92d73bfef38..73625f6f68501 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -11,9 +11,9 @@ from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend +from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad -from vllm.logger import init_logger logger = init_logger(__name__) # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] @@ -1282,8 +1282,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, ''' import vllm.envs as envs - print("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) - logger.info("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) + print("envs.VLLM_ATTENTION_BACKEND: " + str(envs.VLLM_ATTENTION_BACKEND)) + logger.info("envs.VLLM_ATTENTION_BACKEND: ", + str(envs.VLLM_ATTENTION_BACKEND)) # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. From 60c01c3b67121e6fb7bdfb074f7fb7f39fef0023 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 2 Jun 2024 23:45:30 -0400 Subject: [PATCH 115/443] attempted to fix issue whereby selector test doesn't cleanup environment variables --- tests/kernels/test_attention_selector.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index f439afa9b7d2b..cefd856898ac7 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -34,6 +34,9 @@ def test_env(name: str, device: str): if name_backup is not None: os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + else: + # VLLM_ATTENTION_BACKEND was unset + os.environ.pop('VLLM_ATTENTION_BACKEND', None) def test_flash_attn(): @@ -73,6 +76,9 @@ def test_flash_attn(): if name_backup is not None: os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + else: + # VLLM_ATTENTION_BACKEND was unset + os.environ.pop('VLLM_ATTENTION_BACKEND', None) def test_invalid_env(): @@ -81,4 +87,9 @@ def test_invalid_env(): os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + else: + # VLLM_ATTENTION_BACKEND was unset + os.environ.pop('VLLM_ATTENTION_BACKEND', None) From 5c94166b203f22ad20d7909661f5229c6954e09c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 00:48:56 -0400 Subject: [PATCH 116/443] (1) In top-level tests utils.py added env var context manager, (2) in tests/kernels added utils.py w/ vLLM backend context manager, (3) all unit tests for backend selector & enc/dec use backend context manager --- tests/kernels/test_attention_selector.py | 101 ++- tests/kernels/test_self_and_cross_attn.py | 772 +++++++++++----------- tests/kernels/utils.py | 25 + tests/utils.py | 26 + 4 files changed, 484 insertions(+), 440 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index cefd856898ac7..110137d2820b7 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,10 +1,10 @@ -import os from unittest.mock import patch import pytest import torch from vllm.attention.selector import which_attn_to_use +from tests.kernels.utils import backend_override_fixture @pytest.mark.parametrize( @@ -14,82 +14,63 @@ def test_env(name: str, device: str): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = name - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): + with backend_override_fixture(name): + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == name - - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup - else: - # VLLM_ATTENTION_BACKEND was unset - os.environ.pop('VLLM_ATTENTION_BACKEND', None) + assert backend.name == name def test_flash_attn(): """Test FlashAttn validation.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + with backend_override_fixture("FLASH_ATTN"): - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup - else: - # VLLM_ATTENTION_BACKEND was unset - os.environ.pop('VLLM_ATTENTION_BACKEND', None) + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" def test_invalid_env(): """Throw an exception if the backend name is invalid.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" - with pytest.raises(ValueError): - which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup - else: - # VLLM_ATTENTION_BACKEND was unset - os.environ.pop('VLLM_ATTENTION_BACKEND', None) + + with backend_override_fixture("INVALID"), pytest.raises(ValueError): + which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) \ No newline at end of file diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 73625f6f68501..684a86cfc4586 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -13,6 +13,7 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad +from tests.kernels.utils import backend_override_fixture logger = init_logger(__name__) @@ -27,7 +28,7 @@ BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = ["xformers"] +BACKEND_NAMES = ["XFORMERS"] CUDA_DEVICE = "cuda:0" MAX_Q_SEQ_LENS = [128] @@ -371,15 +372,22 @@ def make_backend(backend_name: str) -> AttentionBackend: Construct the backend instance determined by the backend_name string argument. - "xformers" -> construct xformers backend + "XFORMERS" -> construct xformers backend + + TODO: other backends + + Note: at time of writing the Attention wrapper automatically selects + its own backend for Attention.forward(); so the backend instance which + you generate with this function is not meant to be used for *running* + inference, but rather for generating compatible metadata structures + using backend.make_metadata() - TODO: flash attention backend Returns: * Backend instance ''' - if backend_name == "xformers": + if backend_name == "XFORMERS": return XFormersBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") @@ -1281,72 +1289,70 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, layer output.) ''' - import vllm.envs as envs - print("envs.VLLM_ATTENTION_BACKEND: " + str(envs.VLLM_ATTENTION_BACKEND)) - logger.info("envs.VLLM_ATTENTION_BACKEND: ", - str(envs.VLLM_ATTENTION_BACKEND)) - - # Attention scale factor, attention backend instance, attention wrapper - # instance. Encoder attention does not require KV cache. - scale, \ - attn_backend, \ - attn, \ - _ = basic_setup(num_heads, - head_size, - None, - None, - backend_name) - - # Self-attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - block_tables, \ - slot_mapping, \ - q_seq_lens = encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_seq_len) - - context_lens = [0 for _ in range(batch_size)] - - # Metadata config for encoder attention: - # - # * Use prefill kernel - # * Signal that this is an encoder-only test so that - # metadata attention_type property is correctly - # configured as AttentionType.ENCODER - attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - q_seq_lens, - context_lens, - block_tables, - slot_mapping, - is_encoder_only_test=True, - ) - - packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - packed_query, - packed_key, - packed_value, - None, - attn_metadata, - attn_type=AttentionType.ENCODER) - - # - Is encoder attention result correct? - assert torch.allclose(packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + with backend_override_fixture(backend_name): + # Force Attention wrapper backend + + # Attention scale factor, attention backend instance, attention wrapper + # instance. Encoder attention does not require KV cache. + scale, \ + attn_backend, \ + attn, \ + _ = basic_setup(num_heads, + head_size, + None, + None, + backend_name) + + # Self-attention setup + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ + block_tables, \ + slot_mapping, \ + q_seq_lens = encoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_seq_len) + + context_lens = [0 for _ in range(batch_size)] + + # Metadata config for encoder attention: + # + # * Use prefill kernel + # * Signal that this is an encoder-only test so that + # metadata attention_type property is correctly + # configured as AttentionType.ENCODER + attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + q_seq_lens, + context_lens, + block_tables, + slot_mapping, + is_encoder_only_test=True, + ) + + packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + packed_query, + packed_key, + packed_value, + None, + attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -1386,314 +1392,320 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( for cross-attention. ''' - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - self_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Prefill self-attention correct? - assert torch.allclose( - self_prefill_packed_ideal_output, - self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) - - cross_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # - Prefill cross-attention correct? - assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) - - # DECODE: self- and cross-attention tests - - decode_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - False, - q_seq_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, - ) - - self_decode_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Decode self-attention correct? - assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) - - cross_decode_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) - - # - Decode cross-attention correct? - assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) - - # The following test conditions could in principle be a - # standalone test, however the test setup is so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - decode_attn_metadata.num_prefill_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - -@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -@pytest.mark.parametrize("num_heads", [256]) -@pytest.mark.parametrize("head_size", [16]) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("max_q_seq_len", [64]) -@pytest.mark.parametrize("max_kv_seq_len", [64]) -def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, - max_kv_seq_len: int) -> None: - ''' - Encoder/decoder not-implemented-for-ROCm-HIP test: - - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order + with backend_override_fixture(backend_name): + # Force Attention wrapper backend + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Prefill self-attention correct? + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + + # - Prefill cross-attention correct? + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + False, + q_seq_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Decode self-attention correct? + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) + + # - Decode cross-attention correct? + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) + + # The following test conditions could in principle be a + # standalone test, however the test setup is so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + decode_attn_metadata.num_prefill_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + None, None, kv_cache, + decode_attn_metadata) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + +# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +# @pytest.mark.parametrize("num_heads", [256]) +# @pytest.mark.parametrize("head_size", [16]) +# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) +# @pytest.mark.parametrize("batch_size", [16]) +# @pytest.mark.parametrize("block_size", [16]) +# @pytest.mark.parametrize("max_q_seq_len", [64]) +# @pytest.mark.parametrize("max_kv_seq_len", [64]) +# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, +# backend_name: str, batch_size: int, +# block_size: int, max_q_seq_len: int, +# max_kv_seq_len: int) -> None: +# ''' +# Encoder/decoder not-implemented-for-ROCm-HIP test: + +# * Construct fake test vectors for self- and cross-attention +# * Construct attention metadata structure with self- and cross-attention +# attributes +# * Test self- and cross-attention in the following order - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation - - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. - ''' - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) - - # "Encoder decoder models do not currently support ROCm/HIP" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP +# * Prefill self-attention +# * Prefill cross-attention +# * Decode self-attention +# * Decode cross-attention +# * This order would exacerbate any accidental overlap in the +# self-/cross-attention block tables, which we attempt to avoid +# * Validate output correctness against ideal reference attention +# implementation + +# Block tables are constructed such that cross-attention KV cache is in a +# higher, non-intersecting address-space than self-attention KV cache. + +# Self- and cross-attention share the same query tensor but not the K/V +# tensors. Self-attention K/Vs must have the same seq len as Q while +# cross-attention K/Vs are allowed to differ in seq len, as is often the case +# for cross-attention. +# ''' + +# with backend_override_fixture(backend_name): +# # Force Attention wrapper backend + +# # Num KV cache blocks +# num_blocks = 4096 + +# # Attention scale factor, attention backend instance, attention wrapper +# # instance, KV cache init +# scale, \ +# attn_backend, \ +# attn, \ +# kv_cache = basic_setup(num_heads, +# head_size, +# num_blocks, +# block_size, +# backend_name) + +# # Self-attention setup + +# self_block_base_addr = 0 + +# query, \ +# prefill_packed_query, \ +# self_prefill_packed_key, \ +# self_prefill_packed_value, \ +# self_prefill_packed_ideal_output, \ +# prefill_q_seq_lens, \ +# self_prefill_kv_seq_lens, \ +# decode_packed_query, \ +# self_decode_packed_key, \ +# self_decode_packed_value, \ +# self_decode_packed_ideal_output, \ +# _, \ +# _, \ +# q_seq_lens, \ +# _, \ +# self_decode_block_tables, \ +# self_decode_slot_mapping, \ +# self_prefill_slot_mapping, \ +# self_prefill_block_tables, \ +# cross_block_base_addr = decoder_attn_setup(batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# block_base_addr=self_block_base_addr) + +# # Cross-attention setup + +# cross_prefill_packed_key, \ +# cross_prefill_packed_value, \ +# cross_prefill_packed_ideal_output, \ +# cross_decode_packed_ideal_output, \ +# cross_kv_seq_lens, \ +# cross_decode_block_tables, \ +# cross_decode_slot_mapping, \ +# cross_prefill_slot_mapping, \ +# cross_prefill_block_tables, \ +# _ = enc_dec_cross_attn_setup_reuses_query(query, +# q_seq_lens, +# prefill_q_seq_lens, +# batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# max_kv_seq_len, +# block_base_addr=cross_block_base_addr) + +# # PREFILL: self- and cross-attention tests + +# context_lens = [0 for _ in range(batch_size)] + +# prefill_attn_metadata: AttentionMetadata = make_test_metadata( +# attn_backend, +# True, +# prefill_q_seq_lens, +# context_lens, +# self_prefill_block_tables, +# self_prefill_slot_mapping, +# is_encoder_only_test=False, +# cross_seq_lens=cross_kv_seq_lens, +# cross_block_tables=cross_prefill_block_tables, +# cross_slot_mapping=cross_prefill_slot_mapping, +# ) + +# with pytest.raises(NotImplementedError) as exc_info: +# run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, +# cross_prefill_packed_key, +# cross_prefill_packed_value, +# kv_cache, +# prefill_attn_metadata) + +# # "Encoder decoder models do not currently support ROCm/HIP" +# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 0000000000000..1d54a513f287c --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,25 @@ +"""Kernel test utils""" + +from tests.utils import env_var_fixture +from contextlib import contextmanager +from typing import Iterator + +# Configure + +@contextmanager +def backend_override_fixture(backend_name: str) -> Iterator[None]: + ''' + Text fixture, temporarily configures the vLLM backend by setting + VLLM_ATTENTION_BACKEND, then resets the environment outside of + the fixture. + + Usage: + + with backend_override_fixture("backend_name"): + # code that depends on vLLM backend + + # VLLM_ATTENTION_BACKEND is returned to original value + # or unset + ''' + with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name): + yield # Control is yielded to the enclosed block, environment variable is set \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 329842911e159..7c8740236956a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,8 @@ import warnings from contextlib import contextmanager +from typing import Iterator + import ray import requests @@ -101,3 +103,27 @@ def error_on_warning(): warnings.simplefilter("error") yield + +@contextmanager +def env_var_fixture(var_name: str, value: str) -> Iterator[None]: + ''' + Text fixture, temporarily assigns value var_name environment variable, + then resets environment variable outside of test fixture. + + Usage: + + with env_var_fixture("my_var","my_val"): + # code that depends on my_val == "my_val" + + # my_var is returned to original value or unset + ''' + original_value = os.environ.get(var_name) # Store the original value + os.environ[var_name] = value # Set the new value + try: + yield + finally: + # Restore the original value + if original_value is None: + del os.environ[var_name] + else: + os.environ[var_name] = original_value \ No newline at end of file From eaa627fd3da6f3491a08cf06d8d8fd1eee89f991 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 01:19:39 -0400 Subject: [PATCH 117/443] wip tests --- tests/kernels/test_self_and_cross_attn.py | 268 +++++++++++----------- 1 file changed, 134 insertions(+), 134 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 684a86cfc4586..67772b1286e2a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1574,138 +1574,138 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL -# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -# @pytest.mark.parametrize("num_heads", [256]) -# @pytest.mark.parametrize("head_size", [16]) -# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) -# @pytest.mark.parametrize("batch_size", [16]) -# @pytest.mark.parametrize("block_size", [16]) -# @pytest.mark.parametrize("max_q_seq_len", [64]) -# @pytest.mark.parametrize("max_kv_seq_len", [64]) -# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, -# backend_name: str, batch_size: int, -# block_size: int, max_q_seq_len: int, -# max_kv_seq_len: int) -> None: -# ''' -# Encoder/decoder not-implemented-for-ROCm-HIP test: - -# * Construct fake test vectors for self- and cross-attention -# * Construct attention metadata structure with self- and cross-attention -# attributes -# * Test self- and cross-attention in the following order +@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +@pytest.mark.parametrize("num_heads", [256]) +@pytest.mark.parametrize("head_size", [16]) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("max_q_seq_len", [64]) +@pytest.mark.parametrize("max_kv_seq_len", [64]) +def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_seq_len: int, + max_kv_seq_len: int) -> None: + ''' + Encoder/decoder not-implemented-for-ROCm-HIP test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order -# * Prefill self-attention -# * Prefill cross-attention -# * Decode self-attention -# * Decode cross-attention -# * This order would exacerbate any accidental overlap in the -# self-/cross-attention block tables, which we attempt to avoid -# * Validate output correctness against ideal reference attention -# implementation - -# Block tables are constructed such that cross-attention KV cache is in a -# higher, non-intersecting address-space than self-attention KV cache. - -# Self- and cross-attention share the same query tensor but not the K/V -# tensors. Self-attention K/Vs must have the same seq len as Q while -# cross-attention K/Vs are allowed to differ in seq len, as is often the case -# for cross-attention. -# ''' - -# with backend_override_fixture(backend_name): -# # Force Attention wrapper backend - -# # Num KV cache blocks -# num_blocks = 4096 - -# # Attention scale factor, attention backend instance, attention wrapper -# # instance, KV cache init -# scale, \ -# attn_backend, \ -# attn, \ -# kv_cache = basic_setup(num_heads, -# head_size, -# num_blocks, -# block_size, -# backend_name) - -# # Self-attention setup - -# self_block_base_addr = 0 - -# query, \ -# prefill_packed_query, \ -# self_prefill_packed_key, \ -# self_prefill_packed_value, \ -# self_prefill_packed_ideal_output, \ -# prefill_q_seq_lens, \ -# self_prefill_kv_seq_lens, \ -# decode_packed_query, \ -# self_decode_packed_key, \ -# self_decode_packed_value, \ -# self_decode_packed_ideal_output, \ -# _, \ -# _, \ -# q_seq_lens, \ -# _, \ -# self_decode_block_tables, \ -# self_decode_slot_mapping, \ -# self_prefill_slot_mapping, \ -# self_prefill_block_tables, \ -# cross_block_base_addr = decoder_attn_setup(batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# block_base_addr=self_block_base_addr) - -# # Cross-attention setup - -# cross_prefill_packed_key, \ -# cross_prefill_packed_value, \ -# cross_prefill_packed_ideal_output, \ -# cross_decode_packed_ideal_output, \ -# cross_kv_seq_lens, \ -# cross_decode_block_tables, \ -# cross_decode_slot_mapping, \ -# cross_prefill_slot_mapping, \ -# cross_prefill_block_tables, \ -# _ = enc_dec_cross_attn_setup_reuses_query(query, -# q_seq_lens, -# prefill_q_seq_lens, -# batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# max_kv_seq_len, -# block_base_addr=cross_block_base_addr) - -# # PREFILL: self- and cross-attention tests - -# context_lens = [0 for _ in range(batch_size)] - -# prefill_attn_metadata: AttentionMetadata = make_test_metadata( -# attn_backend, -# True, -# prefill_q_seq_lens, -# context_lens, -# self_prefill_block_tables, -# self_prefill_slot_mapping, -# is_encoder_only_test=False, -# cross_seq_lens=cross_kv_seq_lens, -# cross_block_tables=cross_prefill_block_tables, -# cross_slot_mapping=cross_prefill_slot_mapping, -# ) - -# with pytest.raises(NotImplementedError) as exc_info: -# run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, -# cross_prefill_packed_key, -# cross_prefill_packed_value, -# kv_cache, -# prefill_attn_metadata) - -# # "Encoder decoder models do not currently support ROCm/HIP" -# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + with backend_override_fixture(backend_name): + # Force Attention wrapper backend + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) + + # "Encoder decoder models do not currently support ROCm/HIP" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP From 9c597c4b26cd5ecc0d38798818eaab99752ab03b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 01:29:22 -0400 Subject: [PATCH 118/443] FIX: test_attention_selector.py was leaking VLLM_ATTENTION_BACKEND values; fixed with backend context manager --- tests/kernels/test_attention_selector.py | 91 +++++++++++------------- tests/kernels/utils.py | 23 ++++++ tests/utils.py | 27 +++++++ 3 files changed, 93 insertions(+), 48 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index f439afa9b7d2b..1726f58cee088 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,10 +1,10 @@ -import os from unittest.mock import patch import pytest import torch from vllm.attention.selector import which_attn_to_use +from tests.kernels.utils import backend_override_fixture @pytest.mark.parametrize( @@ -14,71 +14,66 @@ def test_env(name: str, device: str): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = name - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): + with backend_override_fixture(name): + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == name - - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + assert backend.name == name def test_flash_attn(): """Test FlashAttn validation.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + with backend_override_fixture("FLASH_ATTN"): - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, + 16) + assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, + 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, + 16) + assert backend.name != "FLASH_ATTN" - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" def test_invalid_env(): """Throw an exception if the backend name is invalid.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" - with pytest.raises(ValueError): + + with backend_override_fixture("INVALID"), pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 0000000000000..955a96bae2a80 --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,23 @@ +"""Kernel test utils""" + +from tests.utils import env_var_fixture +from contextlib import contextmanager +from typing import Iterator + +@contextmanager +def backend_override_fixture(backend_name: str) -> Iterator[None]: + ''' + Text fixture, temporarily configures the vLLM backend by setting + VLLM_ATTENTION_BACKEND, then resets the environment outside of + the fixture. + + Usage: + + with backend_override_fixture("backend_name"): + # code that depends on vLLM backend + + # VLLM_ATTENTION_BACKEND is returned to original value + # or unset + ''' + with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name): + yield diff --git a/tests/utils.py b/tests/utils.py index 329842911e159..48666ca652dd7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,8 @@ import warnings from contextlib import contextmanager +from typing import Iterator + import ray import requests @@ -101,3 +103,28 @@ def error_on_warning(): warnings.simplefilter("error") yield + + +@contextmanager +def env_var_fixture(var_name: str, value: str) -> Iterator[None]: + ''' + Text fixture, temporarily assigns value var_name environment variable, + then resets environment variable outside of test fixture. + + Usage: + + with env_var_fixture("my_var","my_val"): + # code that depends on my_val == "my_val" + + # my_var is returned to original value or unset + ''' + original_value = os.environ.get(var_name) # Store the original value + os.environ[var_name] = value # Set the new value + try: + yield + finally: + # Restore the original value + if original_value is None: + del os.environ[var_name] + else: + os.environ[var_name] = original_value From 9831ce63077cd4b531979dde442b4185a18b16b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 03:06:50 -0400 Subject: [PATCH 119/443] formatting --- tests/kernels/test_attention_selector.py | 2 +- tests/kernels/utils.py | 4 +++- tests/utils.py | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 1726f58cee088..b0b383974904c 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,8 +3,8 @@ import pytest import torch -from vllm.attention.selector import which_attn_to_use from tests.kernels.utils import backend_override_fixture +from vllm.attention.selector import which_attn_to_use @pytest.mark.parametrize( diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 955a96bae2a80..8ebc2fc5905aa 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,9 +1,11 @@ """Kernel test utils""" -from tests.utils import env_var_fixture from contextlib import contextmanager from typing import Iterator +from tests.utils import env_var_fixture + + @contextmanager def backend_override_fixture(backend_name: str) -> Iterator[None]: ''' diff --git a/tests/utils.py b/tests/utils.py index 48666ca652dd7..adbff8e8dc1c6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,6 @@ import time import warnings from contextlib import contextmanager - from typing import Iterator import ray From faf9118554e2677634b3b3bec9d25c5506d1a7d5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 03:37:21 -0400 Subject: [PATCH 120/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 25 ++++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 67772b1286e2a..9b1d22e19c57d 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -6,6 +6,7 @@ import pytest import torch +from tests.kernels.utils import backend_override_fixture from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -13,7 +14,6 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad -from tests.kernels.utils import backend_override_fixture logger = init_logger(__name__) @@ -1351,8 +1351,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose(packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + assert torch.allclose( + packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -1542,7 +1543,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_actual_output: torch.Tensor = \ run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) + attn, decode_packed_query, None, + None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( @@ -1551,7 +1553,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_ideal_output)) # The following test conditions could in principle be a - # standalone test, however the test setup is so involved that it is easier + # standalone test, however the test setup is + # so involved that it is easier # to piggyback off of the test vectors & other data structures # created for testing decode-phase encoder/decoder cross- # attention above. @@ -1567,8 +1570,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) + None, None, kv_cache, + decode_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @@ -1701,11 +1704,9 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, ) with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) + run_encoder_decoder_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) # "Encoder decoder models do not currently support ROCm/HIP" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP From 61d63bd8ea84bbce3889aebb5e60d68316af4f22 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 03:57:28 -0400 Subject: [PATCH 121/443] removed comment about supported head_size's, which is not relevant under current encoder/decoder test conditions --- tests/kernels/test_self_and_cross_attn.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 9b1d22e19c57d..7171ba0c2d84c 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -17,11 +17,6 @@ logger = init_logger(__name__) -# If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] -# -# TODO: FlashAttention forward only supports head dimension at most 128 -# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d0 -# 37782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] From b9b604821899f11539041bb51ebed5cffbf1c3a0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 04:03:54 -0400 Subject: [PATCH 122/443] small refactors --- tests/kernels/test_self_and_cross_attn.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 7171ba0c2d84c..37bcc582b8bab 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -17,6 +17,7 @@ logger = init_logger(__name__) + HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -348,13 +349,6 @@ def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens) packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens) packed_value, _ = pack_tensor(value, kv_seq_lens) - if packed_query is not None: - packed_query = packed_query.view( - -1, packed_query.shape[-1] * packed_query.shape[-2]) - packed_key = packed_key.view(-1, - packed_key.shape[-1] * packed_key.shape[-2]) - packed_value = packed_value.view( - -1, packed_value.shape[-1] * packed_value.shape[-2]) return packed_query, \ packed_key, \ packed_value, \ From e2e208234f6d690d7cbbec217bac4e9dbd4f5d37 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 04:44:39 -0400 Subject: [PATCH 123/443] refactoring --- tests/kernels/test_self_and_cross_attn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 37bcc582b8bab..5b2b0f718860e 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -12,12 +12,8 @@ from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend -from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad -logger = init_logger(__name__) - - HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] From b2238738b9a83ea21c58b2142daad80be0a8a659 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 05:20:45 -0400 Subject: [PATCH 124/443] make_qkv() tensors are 4D --- tests/kernels/test_self_and_cross_attn.py | 69 +++++++++-------------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 5b2b0f718860e..79d98142c5121 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -195,42 +195,45 @@ def make_qkv(batch_size: int, actual_max_kv_seq_len = max(kv_seq_lens) query = torch.rand( - (batch_size, max_q_seq_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) key = torch.rand( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) value = torch.rand( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) decode_query = torch.zeros( - (batch_size, 1, num_heads * head_size)).to(device) - decode_key = torch.zeros((batch_size, 1, num_heads * head_size)).to(device) + (batch_size, 1, num_heads, head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) decode_value = torch.zeros( - (batch_size, 1, num_heads * head_size)).to(device) + (batch_size, 1, num_heads, head_size)).to(device) for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): - query[bdx, q_seq_len:, :] = 0 - key[bdx, kv_seq_len:, :] = 0 - value[bdx, kv_seq_len:, :] = 0 - - prefill_query[bdx, 0:(q_seq_len - 1), :] = query[bdx, - 0:(q_seq_len - 1), :] - prefill_key[bdx, 0:(kv_seq_len - 1), :] = key[bdx, - 0:(kv_seq_len - 1), :] - prefill_value[bdx, - 0:(kv_seq_len - 1), :] = value[bdx, - 0:(kv_seq_len - 1), :] - - decode_query[bdx, :, :] = query[bdx, (q_seq_len - 1):q_seq_len, :] - decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :] - decode_value[bdx, :, :] = value[bdx, (kv_seq_len - 1):kv_seq_len, :] + query[bdx, q_seq_len:, :, :] = 0 + key[bdx, kv_seq_len:, :, :] = 0 + value[bdx, kv_seq_len:, :, :] = 0 + + prefill_query[bdx, + 0:(q_seq_len - 1), :, :] = query[bdx, + 0:(q_seq_len - 1), :, :] + prefill_key[bdx, + 0:(kv_seq_len - 1), :, :] = key[bdx, + 0:(kv_seq_len - 1), :, :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] + + decode_query[bdx, :, :, :] = query[bdx, + (q_seq_len - 1):q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, + (kv_seq_len - 1):kv_seq_len, :, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -238,24 +241,6 @@ def make_qkv(batch_size: int, decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - query = query.view(batch_size, query.shape[1], num_heads, head_size) - key = key.view(batch_size, key.shape[1], num_heads, head_size) - value = value.view(batch_size, value.shape[1], num_heads, head_size) - - prefill_query = prefill_query.view(batch_size, prefill_query.shape[1], - num_heads, head_size) - prefill_key = prefill_key.view(batch_size, prefill_key.shape[1], num_heads, - head_size) - prefill_value = prefill_value.view(batch_size, prefill_value.shape[1], - num_heads, head_size) - - decode_query = decode_query.view(batch_size, decode_query.shape[1], - num_heads, head_size) - decode_key = decode_key.view(batch_size, decode_key.shape[1], num_heads, - head_size) - decode_value = decode_value.view(batch_size, decode_value.shape[1], - num_heads, head_size) - return query, \ key, \ value, \ From 2ea335cf656962208c8d989c2af07ea0dd66ce20 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 05:25:27 -0400 Subject: [PATCH 125/443] combined seq_start_loc init with cumsum --- tests/kernels/test_self_and_cross_attn.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 79d98142c5121..da4d3c757e4ea 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -393,15 +393,9 @@ def make_metadata_tensors(is_prompt: bool, context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) + + seq_start_loc = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32)]) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) if is_prompt: # Prefill: query_start_loc matches seq_start_loc From 02875abd2532cd08abbf457c544e9f8081320479 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 06:22:25 -0400 Subject: [PATCH 126/443] xformers metadata allows unspecified values for most Optional members; xformers forward only enforces non-none query start locs for prefix caching scenario; simplify self/cross attention test metadata building --- tests/kernels/test_self_and_cross_attn.py | 45 ++++++----------------- vllm/attention/backends/xformers.py | 43 ++++++++++++++-------- 2 files changed, 39 insertions(+), 49 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index da4d3c757e4ea..7cf6736b213d1 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -363,8 +363,7 @@ def make_backend(backend_name: str) -> AttentionBackend: f"Unrecognized backend_name {backend_name} for unit test") -def make_metadata_tensors(is_prompt: bool, - seq_lens: List[int], +def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], device: Union[torch.device, str] = \ CUDA_DEVICE) -> tuple: @@ -393,27 +392,17 @@ def make_metadata_tensors(is_prompt: bool, context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) - - seq_start_loc = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32)]) - - if is_prompt: - # Prefill: query_start_loc matches seq_start_loc - query_start_loc = copy.deepcopy(seq_start_loc) - max_query_len = max_seq_len - else: - # Decode: one new query input token per batch element, thus - # query_start_loc is the cumsum of [1,1,1,...] - query_start_loc = list(range(len(seq_start_loc))) - max_query_len = 1 + seq_start_loc = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) + ]) return seq_lens_tensor, \ context_lens_tensor, \ - max_query_len, \ max_context_len, \ max_seq_len, \ - seq_start_loc, \ - query_start_loc + seq_start_loc def make_kv_cache(num_blocks: int, @@ -625,14 +614,11 @@ def make_test_metadata( seq_lens_tensor, \ context_lens_tensor, \ - max_query_len, \ _, \ _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - seq_lens, - context_lens, - device=device) + seq_start_loc = make_metadata_tensors(seq_lens, + context_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -641,10 +627,8 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, max_prefill_seq_len=max(seq_lens), max_decode_seq_len=0, - query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, @@ -662,14 +646,11 @@ def make_test_metadata( seq_lens_tensor, \ context_lens_tensor, \ - max_query_len, \ _, \ _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - seq_lens, - context_lens, - device=device) + seq_start_loc = make_metadata_tensors(seq_lens, + context_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -678,10 +659,8 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), - query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 4ef052424f176..bbff0b2dac906 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -80,8 +80,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -89,23 +87,29 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] + seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None # Self-attention prefill/decode metadata cache _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None @@ -194,10 +198,12 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None assert self.context_lens_tensor is not None assert self.block_tables is not None + query_start_loc = None if self.query_start_loc is None \ + else self.query_start_loc[:self.num_prefills + 1] + self._self_cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -208,7 +214,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], + query_start_loc=query_start_loc, seq_start_loc=None, context_lens_tensor=self.context_lens_tensor[:self. num_prefills], @@ -231,10 +237,12 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None assert self.context_lens_tensor is not None assert self.block_tables is not None + query_start_loc = None if self.query_start_loc is None \ + else self.query_start_loc[:self.num_prefills + 1] + self._cross_cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -245,7 +253,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], + query_start_loc=query_start_loc, seq_start_loc=None, context_lens_tensor=self.context_lens_tensor[:self. num_prefills], @@ -503,6 +511,9 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, From 79f307d218f87ccac90776859446d34690b6591b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 06:42:41 -0400 Subject: [PATCH 127/443] refactored slot mapping logic --- tests/kernels/test_self_and_cross_attn.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 7cf6736b213d1..d28c173f510fa 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -513,17 +513,13 @@ def make_block_tables_slot_mapping(block_size: int, num_blocks = num_blocks_list[sdx] block_table = list( range(block_base_idx, block_base_idx - num_blocks, -1)) - for idx in range(num_tokens - 1): - prefill_slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * - block_size) - slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * block_size) - idx = num_tokens - 1 - decode_slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * block_size) - slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * block_size) + for idx in range(num_tokens): + mapping_value = (idx % block_size) + block_table[idx // block_size] * block_size + slot_mapping.append(mapping_value) + if idx < num_tokens - 1: + prefill_slot_mapping.append(mapping_value) + elif idx == num_tokens - 1: + decode_slot_mapping.append(mapping_value) block_base_idx -= num_blocks block_tables.append(block_table) From e790a00fbde0175d81bab5806d378f68d3d08b9e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 06:43:35 -0400 Subject: [PATCH 128/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index d28c173f510fa..c127cfbfa7e93 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -514,7 +514,8 @@ def make_block_tables_slot_mapping(block_size: int, block_table = list( range(block_base_idx, block_base_idx - num_blocks, -1)) for idx in range(num_tokens): - mapping_value = (idx % block_size) + block_table[idx // block_size] * block_size + mapping_value = ( + idx % block_size) + block_table[idx // block_size] * block_size slot_mapping.append(mapping_value) if idx < num_tokens - 1: prefill_slot_mapping.append(mapping_value) From aae601b88788eb8db2394793990cdab47c4c2264 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:01:41 -0400 Subject: [PATCH 129/443] selective renaming of cross -> encoder --- tests/kernels/test_self_and_cross_attn.py | 18 ++++---- vllm/attention/backends/xformers.py | 50 +++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index c127cfbfa7e93..e4a9993a143b3 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -564,7 +564,7 @@ def make_test_metadata( slot_mapping: torch.Tensor, is_encoder_only_test: bool, device: Union[torch.device, str] = CUDA_DEVICE, - cross_seq_lens: Optional[List[int]] = None, + encoder_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, cross_slot_mapping: Optional[List[int]] = None, ) -> AttentionMetadata: @@ -591,7 +591,7 @@ def make_test_metadata( * is_encoder_only_test: True if testing encoder; False if testing decoder self-attention or encoder/decoder cross-attention. * device: CPU or CUDA device - * cross_seq_lens: list of token counts for each encoder sequence, if any + * encoder_seq_lens: list of token counts for each encoder sequence, if any exist * cross_block_tables: cross-attention block tables, if required * cross_slot_mapping: cross-attention slot mapping, if required @@ -631,7 +631,7 @@ def make_test_metadata( block_tables=block_tables, use_cuda_graph=False, _attn_type=default_attn_type, - cross_seq_lens=cross_seq_lens, + encoder_seq_lens=encoder_seq_lens, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) @@ -663,7 +663,7 @@ def make_test_metadata( block_tables=block_tables, use_cuda_graph=False, _attn_type=default_attn_type, - cross_seq_lens=cross_seq_lens, + encoder_seq_lens=encoder_seq_lens, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) @@ -1387,7 +1387,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ + encoder_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ @@ -1416,7 +1416,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, + encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, ) @@ -1460,7 +1460,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, + encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, ) @@ -1609,7 +1609,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ + encoder_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ @@ -1638,7 +1638,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, + encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index bbff0b2dac906..fca06a1eeebb4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -128,13 +128,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention # computation. None if this is self-attention - cross_seq_lens: Optional[List[int]] = None - cross_seq_lens_tensor: Optional[torch.Tensor] = None + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None # The maximum cross-sequence-length, if cross_seq_lens is specified. # Note that for cross-attention there is no difference in key/value # sequence length between prefill and decode - max_cross_seq_len: Optional[int] = None + max_encoder_seq_len: Optional[int] = None # Cross-attention memory-mapping data structures: slot mapping # and block tables @@ -152,7 +152,7 @@ def __post_init__(self): @property def is_all_cross_attn_metadata_set(self): # No cross-attention metadata is present whatsoever - return (self.cross_seq_lens is not None) and \ + return (self.encoder_seq_lens is not None) and \ (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) @@ -170,15 +170,15 @@ def attention_type(self, atype: AttentionType) -> None: # Infer implicit cross-attention fields # from user-provided fields, if needed - if self.cross_seq_lens_tensor is None: + if self.encoder_seq_lens_tensor is None: assert self.seq_lens_tensor is not None - self.cross_seq_lens_tensor = torch.tensor( - self.cross_seq_lens, + self.encoder_seq_lens_tensor = torch.tensor( + self.encoder_seq_lens, dtype=self.seq_lens_tensor.dtype, device=self.seq_lens_tensor.device) - if self.max_cross_seq_len is None: - assert self.cross_seq_lens is not None - self.max_cross_seq_len = max(self.cross_seq_lens) + if self.max_encoder_seq_len is None: + assert self.encoder_seq_lens is not None + self.max_encoder_seq_len = max(self.encoder_seq_lens) self._attn_type = AttentionType.ENCODER_DECODER else: @@ -222,9 +222,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=False, _attn_type=self. _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, + encoder_seq_lens=None, + encoder_seq_lens_tensor=None, + max_encoder_seq_len=None, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_prefill_metadata @@ -261,9 +261,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=False, _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) return self._cross_cached_prefill_metadata @@ -298,9 +298,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=self.use_cuda_graph, _attn_type=self. _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, + encoder_seq_lens=None, + encoder_seq_lens_tensor=None, + max_encoder_seq_len=None, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_decode_metadata @@ -330,9 +330,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=self.use_cuda_graph, _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) return self._cross_cached_decode_metadata @@ -540,8 +540,8 @@ def forward( if decode_meta := attn_metadata.decode_metadata: if attn_type == AttentionType.ENCODER_DECODER: # Paged attention against cross-attention KV cache - seq_lens_arg = decode_meta.cross_seq_lens_tensor - max_seq_len_arg = decode_meta.max_cross_seq_len + seq_lens_arg = decode_meta.encoder_seq_lens_tensor + max_seq_len_arg = decode_meta.max_encoder_seq_len block_tables_arg = decode_meta.cross_block_tables else: # Paged attention against self-attention KV cache @@ -610,7 +610,7 @@ def _run_memory_efficient_xformers_forward( AttentionType.ENCODER_DECODER: # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, attn_metadata.cross_seq_lens) + attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) else: if attn_metadata.attention_type == AttentionType.ENCODER: # Default encoder self-attention mask is non-causal From d871c9fd8eb9774a9c3699240d34ab05cc0df6ef Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:04:18 -0400 Subject: [PATCH 130/443] added encoder, enc/dec cross-attention bias members --- vllm/attention/backends/xformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index fca06a1eeebb4..d653c1d501e71 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,6 +148,8 @@ def __post_init__(self): # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None @property def is_all_cross_attn_metadata_set(self): From 90d5c0dfcbadd6dba3005095455338029965f167 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:20:29 -0400 Subject: [PATCH 131/443] xformers metadata now uses a different attn_bias for self, encoder and cross --- vllm/attention/backends/xformers.py | 34 +++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d653c1d501e71..0432ec4e87e85 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -339,6 +339,29 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._cross_cached_decode_metadata +def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ + Optional[List[Optional[AttentionBias]]]: + attn_type = attn_metadata.attention_type + if attn_type == AttentionType.DECODER: + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") + +def _set_attn_bias(attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]]) -> None: + attn_type = attn_metadata.attention_type + if attn_type == AttentionType.DECODER: + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") class XFormersImpl(AttentionImpl[XFormersMetadata]): """ @@ -606,7 +629,8 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - if attn_metadata.attn_bias is None: + attn_bias = _get_attn_bias(attn_metadata) + if attn_bias is None: if self.alibi_slopes is None: if attn_metadata.attention_type == \ AttentionType.ENCODER_DECODER: @@ -625,12 +649,14 @@ def _run_memory_efficient_xformers_forward( if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) - attn_metadata.attn_bias = [attn_bias] + attn_bias = [attn_bias] else: - attn_metadata.attn_bias = _make_alibi_bias( + attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.seq_lens) + _set_attn_bias(attn_metadata,attn_bias) + # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -643,7 +669,7 @@ def _run_memory_efficient_xformers_forward( query, key, value, - attn_bias=attn_metadata.attn_bias[0], + attn_bias=attn_bias[0], p=0.0, scale=self.scale) return out.view_as(original_query) From a973c2be7240f50adc95a093bbdd684d6c7cfc07 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:24:03 -0400 Subject: [PATCH 132/443] refactoring --- vllm/attention/backends/xformers.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0432ec4e87e85..368170d3cc4ca 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -341,6 +341,20 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Optional[List[Optional[AttentionBias]]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Depends on attn_metadata having a valid attention_type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + + Returns: + * Appropriate attention bias value + ''' + attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: return attn_metadata.attn_bias @@ -353,6 +367,18 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ def _set_attn_bias(attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]]) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Depends on attn_metadata having a valid attention_type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + ''' + attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: attn_metadata.attn_bias = attn_bias From d8d284ed372a863ae0ec95f469f3fa28aadefca8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:34:38 -0400 Subject: [PATCH 133/443] wip typing issues --- vllm/attention/backends/xformers.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 368170d3cc4ca..74ffca1ee4b6a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -363,9 +363,11 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ elif attn_type == AttentionType.ENCODER_DECODER: return attn_metadata.cross_attn_bias else: - raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") + raise AttributeError( + f"Invalid attn_metadata.attention_type {attn_type}") -def _set_attn_bias(attn_metadata: XFormersMetadata, + +def _set_attn_bias(attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]]) -> None: ''' Update appropriate attention bias field of attention metadata, @@ -387,7 +389,9 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, elif attn_type == AttentionType.ENCODER_DECODER: attn_metadata.cross_attn_bias = attn_bias else: - raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") + raise AttributeError( + f"Invalid attn_metadata.attention_type {attn_type}") + class XFormersImpl(AttentionImpl[XFormersMetadata]): """ @@ -677,11 +681,11 @@ def _run_memory_efficient_xformers_forward( self.sliding_window) attn_bias = [attn_bias] else: - attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) - _set_attn_bias(attn_metadata,attn_bias) + _set_attn_bias(attn_metadata, attn_bias) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -712,7 +716,7 @@ def _run_memory_efficient_xformers_forward( query[None, start:end], key[None, start:end], value[None, start:end], - attn_bias=attn_metadata.attn_bias[i], + attn_bias=attn_bias[i], p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. From 27dc095e8c6b10b2c4470b6d4028a1b31053c7b1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:53:13 -0400 Subject: [PATCH 134/443] added paged attention args collection, conditional on metadata attention type --- vllm/attention/backends/xformers.py | 55 ++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 74ffca1ee4b6a..eb5a5f0032ed6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -364,7 +364,7 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ return attn_metadata.cross_attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {attn_type}") + f"Invalid attn_metadata.attention_type {str(attn_type)}") def _set_attn_bias(attn_metadata: XFormersMetadata, @@ -390,7 +390,44 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, attn_metadata.cross_attn_bias = attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {attn_type}") + f"Invalid attn_metadata.attention_type {str(attn_type)}") + + +def _get_paged_attention_args(attn_metadata: XFormersMetadata) -> tuple: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Depends on attn_metadata having a valid attention_type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + + Returns: + * Appropriate attention bias value + ''' + + attn_type = attn_metadata.attention_type + if attn_type == AttentionType.DECODER: + # Decoder self-attention + return attn_metadata.seq_lens_tensor, \ + attn_metadata.max_decode_seq_len, \ + attn_metadata.block_tables + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return attn_metadata.encoder_seq_lens_tensor, \ + attn_metadata.max_encoder_seq_len, \ + attn_metadata.cross_block_tables + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return attn_metadata.encoder_seq_lens_tensor, \ + attn_metadata.max_encoder_seq_len, \ + None + else: + raise AttributeError( + f"Invalid attn_metadata.attention_type {str(attn_type)}") class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -593,16 +630,10 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: - if attn_type == AttentionType.ENCODER_DECODER: - # Paged attention against cross-attention KV cache - seq_lens_arg = decode_meta.encoder_seq_lens_tensor - max_seq_len_arg = decode_meta.max_encoder_seq_len - block_tables_arg = decode_meta.cross_block_tables - else: - # Paged attention against self-attention KV cache - seq_lens_arg = decode_meta.seq_lens_tensor - max_seq_len_arg = decode_meta.max_decode_seq_len - block_tables_arg = decode_meta.block_tables + + seq_lens_arg, \ + max_seq_len_arg,\ + block_tables_arg = _get_paged_attention_args(decode_meta) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, From 24459051c1295a73ebfaf26e19c15836cd1c8aca Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 09:41:51 -0400 Subject: [PATCH 135/443] logic to support encoder-specific sequence length usage in xformers --- tests/kernels/test_self_and_cross_attn.py | 1 + vllm/attention/backends/xformers.py | 80 ++++++++++++++++------- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e4a9993a143b3..239c4a6139c64 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1278,6 +1278,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, block_tables, slot_mapping, is_encoder_only_test=True, + encoder_seq_lens=q_seq_lens ) packed_actual_output: torch.Tensor = \ diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index eb5a5f0032ed6..82a38f3af8123 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -150,11 +150,25 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None + + if self.is_all_encoder_attn_metadata_set: + self._maybe_compute_implicit_encoder_attrs() + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return self.encoder_seq_lens is not None @property def is_all_cross_attn_metadata_set(self): - # No cross-attention metadata is present whatsoever - return (self.encoder_seq_lens is not None) and \ + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return self.is_all_encoder_attn_metadata_set and \ (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) @@ -162,27 +176,40 @@ def is_all_cross_attn_metadata_set(self): def attention_type(self) -> AttentionType: return self._attn_type + def _maybe_compute_implicit_encoder_attrs(self): + ''' + Encoder attention and cross-attention require some encoder-related + metadata attributes which may or may not be been provided by the user. + This method infers the implicit attributes from provided attributes + ''' + if self.encoder_seq_lens_tensor is None: + assert self.seq_lens_tensor is not None + self.encoder_seq_lens_tensor = torch.tensor( + self.encoder_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) + if self.max_encoder_seq_len is None: + assert self.encoder_seq_lens is not None + self.max_encoder_seq_len = max(self.encoder_seq_lens) + @attention_type.setter def attention_type(self, atype: AttentionType) -> None: if atype == AttentionType.ENCODER_DECODER: assert self.is_all_cross_attn_metadata_set, \ - "Must enable self.cross_seq_lens, self.cross_slot_mapping, " + \ + "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - # Infer implicit cross-attention fields - # from user-provided fields, if needed - if self.encoder_seq_lens_tensor is None: - assert self.seq_lens_tensor is not None - self.encoder_seq_lens_tensor = torch.tensor( - self.encoder_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) - if self.max_encoder_seq_len is None: - assert self.encoder_seq_lens is not None - self.max_encoder_seq_len = max(self.encoder_seq_lens) + self._maybe_compute_implicit_encoder_attrs() self._attn_type = AttentionType.ENCODER_DECODER + elif atype == AttentionType.ENCODER: + assert self.is_all_encoder_attn_metadata_set, \ + "Must set self.encoder_seq_lens in order to perform cross-attention" + + self._maybe_compute_implicit_encoder_attrs() + + self._attn_type = AttentionType.ENCODER else: # AttentionType.{ENCODER,DECODER} self._attn_type = atype @@ -224,9 +251,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=False, _attn_type=self. _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=None, - encoder_seq_lens_tensor=None, - max_encoder_seq_len=None, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_prefill_metadata @@ -300,9 +327,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=self.use_cuda_graph, _attn_type=self. _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=None, - encoder_seq_lens_tensor=None, - max_encoder_seq_len=None, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_decode_metadata @@ -393,7 +420,7 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, f"Invalid attn_metadata.attention_type {str(attn_type)}") -def _get_paged_attention_args(attn_metadata: XFormersMetadata) -> tuple: +def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata) -> tuple: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -633,7 +660,7 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_paged_attention_args(decode_meta) + block_tables_arg = _get_seq_len_block_table_args(decode_meta) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -672,7 +699,14 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.seq_lens is not None + + # Enforce that the appropriate *_seq_lens attribute of attn_metadata + # (seq_lens or encoder_seq_lens) is set. + seq_lens, \ + _,\ + _ = _get_seq_len_block_table_args(attn_metadata) + assert seq_lens is not None + original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. From 8dabdc2abf59179064667461fb339ef4eb758247 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 09:43:26 -0400 Subject: [PATCH 136/443] formatting --- tests/kernels/test_self_and_cross_attn.py | 3 +-- vllm/attention/backends/xformers.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 239c4a6139c64..c0d62f7531b93 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1278,8 +1278,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, block_tables, slot_mapping, is_encoder_only_test=True, - encoder_seq_lens=q_seq_lens - ) + encoder_seq_lens=q_seq_lens) packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 82a38f3af8123..648706b4bf07c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -150,7 +150,7 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None - + if self.is_all_encoder_attn_metadata_set: self._maybe_compute_implicit_encoder_attrs() @@ -172,10 +172,6 @@ def is_all_cross_attn_metadata_set(self): (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) - @property - def attention_type(self) -> AttentionType: - return self._attn_type - def _maybe_compute_implicit_encoder_attrs(self): ''' Encoder attention and cross-attention require some encoder-related @@ -192,6 +188,10 @@ def _maybe_compute_implicit_encoder_attrs(self): assert self.encoder_seq_lens is not None self.max_encoder_seq_len = max(self.encoder_seq_lens) + @property + def attention_type(self) -> AttentionType: + return self._attn_type + @attention_type.setter def attention_type(self, atype: AttentionType) -> None: From c6200e676bf89c0ef1116b3d18f520448c5e912f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:34:51 -0400 Subject: [PATCH 137/443] test name change; encoder functionality can tolerate being provided with only encoder metadata --- ...and_decoder_self_and_encdec_cross_attn.py} | 36 ++++++++++++------ vllm/attention/backends/xformers.py | 38 ++++++++++--------- 2 files changed, 45 insertions(+), 29 deletions(-) rename tests/kernels/{test_self_and_cross_attn.py => test_encoder_and_decoder_self_and_encdec_cross_attn.py} (98%) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py similarity index 98% rename from tests/kernels/test_self_and_cross_attn.py rename to tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py index c0d62f7531b93..8e1b8849c8a9f 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py @@ -387,16 +387,18 @@ def make_metadata_tensors(seq_lens: List[int], * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + seq_lens_tensor = None if seq_lens is None else \ + torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) - seq_start_loc = torch.cat([ - torch.tensor([0], dtype=torch.int32, device=device), - torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) - ]) + seq_start_loc = None + # seq_start_loc = torch.cat([ + # torch.tensor([0], dtype=torch.int32, device=device), + # torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) + # ]) return seq_lens_tensor, \ context_lens_tensor, \ @@ -563,6 +565,8 @@ def make_test_metadata( block_tables: torch.Tensor, slot_mapping: torch.Tensor, is_encoder_only_test: bool, + num_prefills_or_decodes: int, + num_prefill_or_decode_tokens: int, device: Union[torch.device, str] = CUDA_DEVICE, encoder_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, @@ -605,8 +609,8 @@ def make_test_metadata( else AttentionType.DECODER if is_prompt: - num_prefills = len(seq_lens) - num_prefill_tokens = sum(seq_lens) + num_prefills = num_prefills_or_decodes + num_prefill_tokens = num_prefill_or_decode_tokens num_decode_tokens = 0 seq_lens_tensor, \ @@ -624,9 +628,9 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=max(seq_lens), + max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, - seq_start_loc=seq_start_loc, + # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -639,7 +643,7 @@ def make_test_metadata( num_prefills = 0 num_prefill_tokens = 0 - num_decode_tokens = len(seq_lens) + num_decode_tokens = num_prefill_or_decode_tokens seq_lens_tensor, \ context_lens_tensor, \ @@ -658,7 +662,7 @@ def make_test_metadata( seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), - seq_start_loc=seq_start_loc, + # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -1273,11 +1277,13 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - q_seq_lens, + None, context_lens, block_tables, slot_mapping, is_encoder_only_test=True, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=sum(q_seq_lens), encoder_seq_lens=q_seq_lens) packed_actual_output: torch.Tensor = \ @@ -1416,6 +1422,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, @@ -1460,6 +1468,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=len(q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, @@ -1638,6 +1648,8 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 648706b4bf07c..a7bdbb4d03839 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -179,13 +179,11 @@ def _maybe_compute_implicit_encoder_attrs(self): This method infers the implicit attributes from provided attributes ''' if self.encoder_seq_lens_tensor is None: - assert self.seq_lens_tensor is not None self.encoder_seq_lens_tensor = torch.tensor( self.encoder_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) + dtype=torch.int32, + device="cuda:0") if self.max_encoder_seq_len is None: - assert self.encoder_seq_lens is not None self.max_encoder_seq_len = max(self.encoder_seq_lens) @property @@ -225,8 +223,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self._self_cached_prefill_metadata is not None: return self._self_cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens is not None) or \ + (self.encoder_seq_lens is not None) + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) assert self.context_lens_tensor is not None assert self.block_tables is not None @@ -238,8 +238,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -264,8 +264,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self._cross_cached_prefill_metadata is not None: return self._cross_cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens is not None) or \ + (self.encoder_seq_lens is not None) + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) assert self.context_lens_tensor is not None assert self.block_tables is not None @@ -277,8 +279,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -308,7 +310,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self._self_cached_decode_metadata is not None: return self._self_cached_decode_metadata assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) self._self_cached_decode_metadata = XFormersMetadata( num_prefills=0, @@ -316,7 +319,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -340,7 +343,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self._cross_cached_decode_metadata is not None: return self._cross_cached_decode_metadata assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) self._cross_cached_decode_metadata = XFormersMetadata( num_prefills=0, @@ -348,7 +352,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -736,7 +740,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attention_type == AttentionType.ENCODER: # Default encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens) + attn_metadata.encoder_seq_lens) else: # Default decoder self-attention mask is causal attn_bias = BlockDiagonalCausalMask.from_seqlens( From 0dc197b794de728dc5f40ce6d0e2058d8099397a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:35:35 -0400 Subject: [PATCH 138/443] formatting --- vllm/attention/backends/xformers.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a7bdbb4d03839..d3af144c742ba 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -179,10 +179,9 @@ def _maybe_compute_implicit_encoder_attrs(self): This method infers the implicit attributes from provided attributes ''' if self.encoder_seq_lens_tensor is None: - self.encoder_seq_lens_tensor = torch.tensor( - self.encoder_seq_lens, - dtype=torch.int32, - device="cuda:0") + self.encoder_seq_lens_tensor = torch.tensor(self.encoder_seq_lens, + dtype=torch.int32, + device="cuda:0") if self.max_encoder_seq_len is None: self.max_encoder_seq_len = max(self.encoder_seq_lens) @@ -238,8 +237,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else + self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -279,8 +280,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else + self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -319,7 +322,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -352,7 +356,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, From 3d3c04ff2aa6be6e93fb0dbb8aca24739d848668 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:48:31 -0400 Subject: [PATCH 139/443] prefill supports shared metadata structure --- vllm/attention/backends/xformers.py | 210 +++++++++++++++++----------- 1 file changed, 125 insertions(+), 85 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d3af144c742ba..dc3d3fdff4af4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -115,10 +115,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None _self_cached_decode_metadata: Optional["XFormersMetadata"] = None # Cross-attention prefill/decode metadata cache - _cross_cached_prefill_metadata: Optional["XFormersMetadata"] = None _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None - # Begin cross-attention fields... + # Begin encoder attn & enc/dec cross-attn fields... # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; @@ -216,91 +215,132 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_prefill_metadata is not None: - return self._self_cached_prefill_metadata - - assert (self.seq_lens is not None) or \ - (self.encoder_seq_lens is not None) - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - query_start_loc = None if self.query_start_loc is None \ - else self.query_start_loc[:self.num_prefills + 1] - - self._self_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else - self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_block_tables=None, - cross_slot_mapping=None) + if self._self_cached_prefill_metadata is not None: + self._self_cached_prefill_metadata.attention_type = self.attention_type return self._self_cached_prefill_metadata - else: - # Encoder/decoder cross-attention prefill - - if self._cross_cached_prefill_metadata is not None: - return self._cross_cached_prefill_metadata - - assert (self.seq_lens is not None) or \ - (self.encoder_seq_lens is not None) - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - query_start_loc = None if self.query_start_loc is None \ - else self.query_start_loc[:self.num_prefills + 1] - - self._cross_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else - self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_prefill_metadata + assert (self.seq_lens is not None) or \ + (self.encoder_seq_lens is not None) + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + query_start_loc = None if self.query_start_loc is None \ + else self.query_start_loc[:self.num_prefills + 1] + + self._self_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=None if self.seq_lens is None else + self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + _attn_type=self.attention_type, + # Begin cross-attention fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._self_cached_prefill_metadata + + # if self._attn_type != AttentionType.ENCODER_DECODER: + # # Decoder or encoder self-attention prefill + + # if self._self_cached_prefill_metadata is not None: + # return self._self_cached_prefill_metadata + + # assert (self.seq_lens is not None) or \ + # (self.encoder_seq_lens is not None) + # assert (self.seq_lens_tensor is not None) or \ + # (self.encoder_seq_lens_tensor is not None) + # assert self.context_lens_tensor is not None + # assert self.block_tables is not None + + # query_start_loc = None if self.query_start_loc is None \ + # else self.query_start_loc[:self.num_prefills + 1] + + # self._self_cached_prefill_metadata = XFormersMetadata( + # num_prefills=self.num_prefills, + # num_prefill_tokens=self.num_prefill_tokens, + # num_decode_tokens=0, + # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + # seq_lens=None if self.seq_lens is None else + # self.seq_lens[:self.num_prefills], + # seq_lens_tensor=None if self.seq_lens_tensor is None else + # self.seq_lens_tensor[:self.num_prefills], + # max_query_len=self.max_query_len, + # max_prefill_seq_len=self.max_prefill_seq_len, + # max_decode_seq_len=0, + # query_start_loc=query_start_loc, + # seq_start_loc=None, + # context_lens_tensor=self.context_lens_tensor[:self. + # num_prefills], + # block_tables=self.block_tables[:self.num_prefills], + # use_cuda_graph=False, + # _attn_type=self. + # _attn_type, # Begin cross-attention fields below... + # encoder_seq_lens=self.encoder_seq_lens, + # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + # max_encoder_seq_len=self.max_encoder_seq_len, + # cross_block_tables=None, + # cross_slot_mapping=None) + # return self._self_cached_prefill_metadata + + # else: + # # Encoder/decoder cross-attention prefill + + # if self._cross_cached_prefill_metadata is not None: + # return self._cross_cached_prefill_metadata + + # assert (self.seq_lens is not None) or \ + # (self.encoder_seq_lens is not None) + # assert (self.seq_lens_tensor is not None) or \ + # (self.encoder_seq_lens_tensor is not None) + # assert self.context_lens_tensor is not None + # assert self.block_tables is not None + + # query_start_loc = None if self.query_start_loc is None \ + # else self.query_start_loc[:self.num_prefills + 1] + + # self._cross_cached_prefill_metadata = XFormersMetadata( + # num_prefills=self.num_prefills, + # num_prefill_tokens=self.num_prefill_tokens, + # num_decode_tokens=0, + # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + # seq_lens=None if self.seq_lens is None else + # self.seq_lens[:self.num_prefills], + # seq_lens_tensor=None if self.seq_lens_tensor is None else + # self.seq_lens_tensor[:self.num_prefills], + # max_query_len=self.max_query_len, + # max_prefill_seq_len=self.max_prefill_seq_len, + # max_decode_seq_len=0, + # query_start_loc=query_start_loc, + # seq_start_loc=None, + # context_lens_tensor=self.context_lens_tensor[:self. + # num_prefills], + # block_tables=self.block_tables[:self.num_prefills], + # use_cuda_graph=False, + # _attn_type=AttentionType.ENCODER_DECODER, + # # Begin cross-attention fields below... + # encoder_seq_lens=self.encoder_seq_lens, + # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + # max_encoder_seq_len=self.max_encoder_seq_len, + # cross_slot_mapping=self.cross_slot_mapping, + # cross_block_tables=self.cross_block_tables) + # return self._cross_cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: From c132caaa3be6bf67eacd628f0d21816de23e9dc7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:49:23 -0400 Subject: [PATCH 140/443] formatting --- vllm/attention/backends/xformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dc3d3fdff4af4..d52b1e8cd000c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -216,7 +216,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return None if self._self_cached_prefill_metadata is not None: - self._self_cached_prefill_metadata.attention_type = self.attention_type + self._self_cached_prefill_metadata.attention_type = \ + self.attention_type return self._self_cached_prefill_metadata assert (self.seq_lens is not None) or \ @@ -234,8 +235,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else - self.seq_lens[:self.num_prefills], + seq_lens=None + if self.seq_lens is None else self.seq_lens[:self.num_prefills], seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -243,8 +244,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=0, query_start_loc=query_start_loc, seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, _attn_type=self.attention_type, From 39ee51a508b82693e8bf03634667011a88d2388d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:54:34 -0400 Subject: [PATCH 141/443] full generalization of prefill & decode metadata structures --- vllm/attention/backends/xformers.py | 181 +++++----------------------- 1 file changed, 30 insertions(+), 151 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d52b1e8cd000c..56a57ac5c4465 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -256,164 +256,43 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._self_cached_prefill_metadata - # if self._attn_type != AttentionType.ENCODER_DECODER: - # # Decoder or encoder self-attention prefill - - # if self._self_cached_prefill_metadata is not None: - # return self._self_cached_prefill_metadata - - # assert (self.seq_lens is not None) or \ - # (self.encoder_seq_lens is not None) - # assert (self.seq_lens_tensor is not None) or \ - # (self.encoder_seq_lens_tensor is not None) - # assert self.context_lens_tensor is not None - # assert self.block_tables is not None - - # query_start_loc = None if self.query_start_loc is None \ - # else self.query_start_loc[:self.num_prefills + 1] - - # self._self_cached_prefill_metadata = XFormersMetadata( - # num_prefills=self.num_prefills, - # num_prefill_tokens=self.num_prefill_tokens, - # num_decode_tokens=0, - # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - # seq_lens=None if self.seq_lens is None else - # self.seq_lens[:self.num_prefills], - # seq_lens_tensor=None if self.seq_lens_tensor is None else - # self.seq_lens_tensor[:self.num_prefills], - # max_query_len=self.max_query_len, - # max_prefill_seq_len=self.max_prefill_seq_len, - # max_decode_seq_len=0, - # query_start_loc=query_start_loc, - # seq_start_loc=None, - # context_lens_tensor=self.context_lens_tensor[:self. - # num_prefills], - # block_tables=self.block_tables[:self.num_prefills], - # use_cuda_graph=False, - # _attn_type=self. - # _attn_type, # Begin cross-attention fields below... - # encoder_seq_lens=self.encoder_seq_lens, - # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - # max_encoder_seq_len=self.max_encoder_seq_len, - # cross_block_tables=None, - # cross_slot_mapping=None) - # return self._self_cached_prefill_metadata - - # else: - # # Encoder/decoder cross-attention prefill - - # if self._cross_cached_prefill_metadata is not None: - # return self._cross_cached_prefill_metadata - - # assert (self.seq_lens is not None) or \ - # (self.encoder_seq_lens is not None) - # assert (self.seq_lens_tensor is not None) or \ - # (self.encoder_seq_lens_tensor is not None) - # assert self.context_lens_tensor is not None - # assert self.block_tables is not None - - # query_start_loc = None if self.query_start_loc is None \ - # else self.query_start_loc[:self.num_prefills + 1] - - # self._cross_cached_prefill_metadata = XFormersMetadata( - # num_prefills=self.num_prefills, - # num_prefill_tokens=self.num_prefill_tokens, - # num_decode_tokens=0, - # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - # seq_lens=None if self.seq_lens is None else - # self.seq_lens[:self.num_prefills], - # seq_lens_tensor=None if self.seq_lens_tensor is None else - # self.seq_lens_tensor[:self.num_prefills], - # max_query_len=self.max_query_len, - # max_prefill_seq_len=self.max_prefill_seq_len, - # max_decode_seq_len=0, - # query_start_loc=query_start_loc, - # seq_start_loc=None, - # context_lens_tensor=self.context_lens_tensor[:self. - # num_prefills], - # block_tables=self.block_tables[:self.num_prefills], - # use_cuda_graph=False, - # _attn_type=AttentionType.ENCODER_DECODER, - # # Begin cross-attention fields below... - # encoder_seq_lens=self.encoder_seq_lens, - # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - # max_encoder_seq_len=self.max_encoder_seq_len, - # cross_slot_mapping=self.cross_slot_mapping, - # cross_block_tables=self.cross_block_tables) - # return self._cross_cached_prefill_metadata - @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_decode_metadata is not None: - return self._self_cached_decode_metadata - assert self.block_tables is not None - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - - self._self_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_block_tables=None, - cross_slot_mapping=None) + if self._self_cached_decode_metadata is not None: + self._self_cached_decode_metadata.attention_type = \ + self.attention_type return self._self_cached_decode_metadata + assert self.block_tables is not None + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) - else: - # Encoder/decoder cross-attention decode - - if self._cross_cached_decode_metadata is not None: - return self._cross_cached_decode_metadata - assert self.block_tables is not None - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - - self._cross_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_decode_metadata + self._self_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + _attn_type=self. + _attn_type, # Begin cross-attention fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._self_cached_decode_metadata def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Optional[List[Optional[AttentionBias]]]: From c41917bee6490dcb062e61c559ec3c17d7c951b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 11:02:13 -0400 Subject: [PATCH 142/443] renamed metdata caching structure --- vllm/attention/backends/xformers.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 56a57ac5c4465..b56e566a89fe1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -112,10 +112,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): query_start_loc: Optional[torch.Tensor] = None # Self-attention prefill/decode metadata cache - _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None - _self_cached_decode_metadata: Optional["XFormersMetadata"] = None - # Cross-attention prefill/decode metadata cache - _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None # Begin encoder attn & enc/dec cross-attn fields... @@ -215,10 +213,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._self_cached_prefill_metadata is not None: - self._self_cached_prefill_metadata.attention_type = \ + if self._cached_prefill_metadata is not None: + self._cached_prefill_metadata.attention_type = \ self.attention_type - return self._self_cached_prefill_metadata + return self._cached_prefill_metadata assert (self.seq_lens is not None) or \ (self.encoder_seq_lens is not None) @@ -230,7 +228,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: query_start_loc = None if self.query_start_loc is None \ else self.query_start_loc[:self.num_prefills + 1] - self._self_cached_prefill_metadata = XFormersMetadata( + self._cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, @@ -254,22 +252,22 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) - return self._self_cached_prefill_metadata + return self._cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._self_cached_decode_metadata is not None: - self._self_cached_decode_metadata.attention_type = \ + if self._cached_decode_metadata is not None: + self._cached_decode_metadata.attention_type = \ self.attention_type - return self._self_cached_decode_metadata + return self._cached_decode_metadata assert self.block_tables is not None assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - self._self_cached_decode_metadata = XFormersMetadata( + self._cached_decode_metadata = XFormersMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, @@ -292,7 +290,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) - return self._self_cached_decode_metadata + return self._cached_decode_metadata def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Optional[List[Optional[AttentionBias]]]: From e738fb4ee2c98f6af912e61c9a4a99acfe887dcf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 11:47:40 -0400 Subject: [PATCH 143/443] reverted my custom env var patch impl --- tests/kernels/utils.py | 25 ------------------------- tests/utils.py | 26 -------------------------- 2 files changed, 51 deletions(-) delete mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py deleted file mode 100644 index 8ebc2fc5905aa..0000000000000 --- a/tests/kernels/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Kernel test utils""" - -from contextlib import contextmanager -from typing import Iterator - -from tests.utils import env_var_fixture - - -@contextmanager -def backend_override_fixture(backend_name: str) -> Iterator[None]: - ''' - Text fixture, temporarily configures the vLLM backend by setting - VLLM_ATTENTION_BACKEND, then resets the environment outside of - the fixture. - - Usage: - - with backend_override_fixture("backend_name"): - # code that depends on vLLM backend - - # VLLM_ATTENTION_BACKEND is returned to original value - # or unset - ''' - with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name): - yield diff --git a/tests/utils.py b/tests/utils.py index adbff8e8dc1c6..329842911e159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,6 @@ import time import warnings from contextlib import contextmanager -from typing import Iterator import ray import requests @@ -102,28 +101,3 @@ def error_on_warning(): warnings.simplefilter("error") yield - - -@contextmanager -def env_var_fixture(var_name: str, value: str) -> Iterator[None]: - ''' - Text fixture, temporarily assigns value var_name environment variable, - then resets environment variable outside of test fixture. - - Usage: - - with env_var_fixture("my_var","my_val"): - # code that depends on my_val == "my_val" - - # my_var is returned to original value or unset - ''' - original_value = os.environ.get(var_name) # Store the original value - os.environ[var_name] = value # Set the new value - try: - yield - finally: - # Restore the original value - if original_value is None: - del os.environ[var_name] - else: - os.environ[var_name] = original_value From dfe9c10389beccfc43b2bddf687075e07e7283b9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:01:52 -0400 Subject: [PATCH 144/443] monkeypatch works --- tests/kernels/test_attention_selector.py | 93 ++++++++++++------------ 1 file changed, 47 insertions(+), 46 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index b0b383974904c..ebd2d460dc45e 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,79 +1,80 @@ +import os from unittest.mock import patch import pytest import torch -from tests.kernels.utils import backend_override_fixture from vllm.attention.selector import which_attn_to_use +_backend_env_var = "VLLM_ATTENTION_BACKEND" +_flash_attn_val = "FLASH_ATTN" +_invalid_val = "INVALID" + @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) -def test_env(name: str, device: str): +def test_env(name: str, device: str, monkeypatch): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ - with backend_override_fixture(name): - - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: + monkeypatch.setenv(_backend_env_var,name) + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == name + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == name -def test_flash_attn(): +def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - with backend_override_fixture("FLASH_ATTN"): - - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, - 16) - assert backend.name != "FLASH_ATTN" + monkeypatch.setenv(_backend_env_var,_flash_attn_val) - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, - 16) + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, - 16) - assert backend.name != "FLASH_ATTN" + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" -def test_invalid_env(): - """Throw an exception if the backend name is invalid.""" - with backend_override_fixture("INVALID"), pytest.raises(ValueError): +def test_invalid_env(monkeypatch): + """Throw an exception if the backend name is invalid.""" + monkeypatch.setenv(_backend_env_var,_invalid_val) + with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) From 822175834eb2846c826a63357f731966ed83abce Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:02:31 -0400 Subject: [PATCH 145/443] formatting --- tests/kernels/test_attention_selector.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index ebd2d460dc45e..0b4bb7c353cc3 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,4 +1,3 @@ -import os from unittest.mock import patch import pytest @@ -19,7 +18,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(_backend_env_var,name) + monkeypatch.setenv(_backend_env_var, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -40,7 +39,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(_backend_env_var,_flash_attn_val) + monkeypatch.setenv(_backend_env_var, _flash_attn_val) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -75,6 +74,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(_backend_env_var,_invalid_val) + monkeypatch.setenv(_backend_env_var, _invalid_val) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) From db2b2d23f8e17683f03ca36e0b9ea41e0ce0c74f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:05:23 -0400 Subject: [PATCH 146/443] wip monkeypatch --- .../test_encoder_and_decoder_self_and_encdec_cross_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py index 8e1b8849c8a9f..1b830a1dde63a 100644 --- a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py +++ b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py @@ -6,7 +6,6 @@ import pytest import torch -from tests.kernels.utils import backend_override_fixture from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -14,6 +13,8 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad +_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" + HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] From cbb89b1dd912cb816d3c7d721bb129baa08c83b3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:08:56 -0400 Subject: [PATCH 147/443] refactored constants into tests/kernels/utils.py --- tests/kernels/test_attention_selector.py | 12 +++++------- tests/kernels/utils.py | 5 +++++ 2 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 0b4bb7c353cc3..7bc0439f3ee82 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,12 +3,10 @@ import pytest import torch +from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_INVALID_VAL) from vllm.attention.selector import which_attn_to_use -_backend_env_var = "VLLM_ATTENTION_BACKEND" -_flash_attn_val = "FLASH_ATTN" -_invalid_val = "INVALID" - @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) @@ -18,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(_backend_env_var, name) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -39,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(_backend_env_var, _flash_attn_val) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -74,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(_backend_env_var, _invalid_val) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 0000000000000..74ad9d8256e3f --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,5 @@ +"""Kernel test utils""" + +STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" +STR_FLASH_ATTN_VAL = "FLASH_ATTN" +STR_INVALID_VAL = "INVALID" \ No newline at end of file From c9ce86be2f3b973e9df5057b6ef5ab956224f28e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:14:33 -0400 Subject: [PATCH 148/443] wip enc/dec monkeypatch integration --- ..._and_decoder_self_and_encdec_cross_attn.py | 144 +++++++++--------- 1 file changed, 73 insertions(+), 71 deletions(-) diff --git a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py index 1b830a1dde63a..d87e548376257 100644 --- a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py +++ b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import copy import itertools import random @@ -6,6 +8,8 @@ import pytest import torch +from tests.kernels.utils import STR_BACKEND_ENV_VAR + from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -13,7 +17,6 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad -_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" HEAD_SIZES = [64, 256] @@ -1195,8 +1198,7 @@ def run_encoder_decoder_cross_attention_test( ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata) - + attn_metadata) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1207,7 +1209,7 @@ def run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, - max_seq_len: int) -> None: + max_seq_len: int, monkeypatch) -> None: ''' Encoder-only attention test: @@ -1234,73 +1236,73 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, layer output.) ''' - with backend_override_fixture(backend_name): - # Force Attention wrapper backend - - # Attention scale factor, attention backend instance, attention wrapper - # instance. Encoder attention does not require KV cache. - scale, \ - attn_backend, \ - attn, \ - _ = basic_setup(num_heads, - head_size, - None, - None, - backend_name) - - # Self-attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - block_tables, \ - slot_mapping, \ - q_seq_lens = encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_seq_len) - - context_lens = [0 for _ in range(batch_size)] - - # Metadata config for encoder attention: - # - # * Use prefill kernel - # * Signal that this is an encoder-only test so that - # metadata attention_type property is correctly - # configured as AttentionType.ENCODER - attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - None, - context_lens, - block_tables, - slot_mapping, - is_encoder_only_test=True, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=sum(q_seq_lens), - encoder_seq_lens=q_seq_lens) - - packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - packed_query, - packed_key, - packed_value, - None, - attn_metadata, - attn_type=AttentionType.ENCODER) - - # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + # Force Attention wrapper backend + monkeypatch.setenv(STR_BACKEND_ENV_VAR,backend_name) + + # Attention scale factor, attention backend instance, attention wrapper + # instance. Encoder attention does not require KV cache. + scale, \ + attn_backend, \ + attn, \ + _ = basic_setup(num_heads, + head_size, + None, + None, + backend_name) + + # Self-attention setup + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ + block_tables, \ + slot_mapping, \ + q_seq_lens = encoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_seq_len) + + context_lens = [0 for _ in range(batch_size)] + + # Metadata config for encoder attention: + # + # * Use prefill kernel + # * Signal that this is an encoder-only test so that + # metadata attention_type property is correctly + # configured as AttentionType.ENCODER + attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + None, + context_lens, + block_tables, + slot_mapping, + is_encoder_only_test=True, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=sum(q_seq_lens), + encoder_seq_lens=q_seq_lens) + + packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + packed_query, + packed_key, + packed_value, + None, + attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + assert torch.allclose( + packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) From ca570e7a078540d8788dc6b7cf961039f699a761 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:17:57 -0400 Subject: [PATCH 149/443] a refactoring backend override functionality into tests/kernels/utils.py --- tests/kernels/test_attention_selector.py | 10 +++++----- tests/kernels/utils.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 7bc0439f3ee82..ea3ccb026ea2b 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,8 +3,8 @@ import pytest import torch -from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_INVALID_VAL) +from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, + override_backend) from vllm.attention.selector import which_attn_to_use @@ -16,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(STR_BACKEND_ENV_VAR, name) + override_backend(monkeypatch, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -37,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) + override_backend(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -72,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) + override_backend(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 74ad9d8256e3f..3874fad57ae43 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,4 +2,8 @@ STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" STR_FLASH_ATTN_VAL = "FLASH_ATTN" -STR_INVALID_VAL = "INVALID" \ No newline at end of file +STR_INVALID_VAL = "INVALID" + + +def override_backend(mpatch, backend_name): + mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) From b2e131f95143984ddc0a38f5b9b08e9abc421b45 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:26:07 -0400 Subject: [PATCH 150/443] test rename --- ...s_attn.py => test_encoder_decoder_attn.py} | 605 +++++++++--------- 1 file changed, 303 insertions(+), 302 deletions(-) rename tests/kernels/{test_encoder_and_decoder_self_and_encdec_cross_attn.py => test_encoder_decoder_attn.py} (80%) diff --git a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py b/tests/kernels/test_encoder_decoder_attn.py similarity index 80% rename from tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py rename to tests/kernels/test_encoder_decoder_attn.py index d87e548376257..0a835940c3d75 100644 --- a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1,4 +1,10 @@ -from unittest.mock import patch +""" +Test + +* Encoder attention +* Decoder self-attention +* Encoder/decoder cross-attention +""" import copy import itertools @@ -8,8 +14,7 @@ import pytest import torch -from tests.kernels.utils import STR_BACKEND_ENV_VAR - +from tests.kernels.utils import override_backend from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -17,7 +22,6 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad - HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -399,10 +403,6 @@ def make_metadata_tensors(seq_lens: List[int], max_seq_len = None if seq_lens is None else max(seq_lens) seq_start_loc = None - # seq_start_loc = torch.cat([ - # torch.tensor([0], dtype=torch.int32, device=device), - # torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) - # ]) return seq_lens_tensor, \ context_lens_tensor, \ @@ -634,7 +634,6 @@ def make_test_metadata( seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, - # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -666,7 +665,6 @@ def make_test_metadata( seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), - # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -1198,7 +1196,8 @@ def run_encoder_decoder_cross_attention_test( ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata) + attn_metadata) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1208,8 +1207,8 @@ def run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, - max_seq_len: int, monkeypatch) -> None: + batch_size: int, block_size: int, max_seq_len: int, + monkeypatch) -> None: ''' Encoder-only attention test: @@ -1237,7 +1236,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, ''' # Force Attention wrapper backend - monkeypatch.setenv(STR_BACKEND_ENV_VAR,backend_name) + override_backend(monkeypatch, backend_name) # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. @@ -1300,9 +1299,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -1315,7 +1313,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: + block_size: int, max_q_seq_len: int, max_kv_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -1342,192 +1341,192 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( for cross-attention. ''' - with backend_override_fixture(backend_name): - # Force Attention wrapper backend - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - self_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Prefill self-attention correct? - assert torch.allclose( - self_prefill_packed_ideal_output, - self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) - - cross_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # - Prefill cross-attention correct? - assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) - - # DECODE: self- and cross-attention tests - - decode_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - False, - q_seq_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=len(q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, - ) - - self_decode_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Decode self-attention correct? - assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) - - cross_decode_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, - None, kv_cache, decode_attn_metadata) - - # - Decode cross-attention correct? - assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) - - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - decode_attn_metadata.num_prefill_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + # Force Attention wrapper backend + override_backend(monkeypatch, backend_name) + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + encoder_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), + encoder_seq_lens=encoder_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Prefill self-attention correct? + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + + # - Prefill cross-attention correct? + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + False, + q_seq_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + is_encoder_only_test=False, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=len(q_seq_lens), + encoder_seq_lens=encoder_kv_seq_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Decode self-attention correct? + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, decode_packed_query, None, + None, kv_cache, decode_attn_metadata) + + # - Decode cross-attention correct? + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) + + # The following test conditions could in principle be a + # standalone test, however the test setup is + # so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + decode_attn_metadata.num_prefill_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + None, None, kv_cache, + decode_attn_metadata) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") @@ -1541,7 +1540,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_seq_len: int, - max_kv_seq_len: int) -> None: + max_kv_seq_len: int, monkeypatch) -> None: ''' Encoder/decoder not-implemented-for-ROCm-HIP test: @@ -1568,100 +1567,102 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, for cross-attention. ''' - with backend_override_fixture(backend_name): - # Force Attention wrapper backend - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # "Encoder decoder models do not currently support ROCm/HIP" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP + # Force Attention wrapper backend + override_backend(monkeypatch, backend_name) + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + encoder_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), + encoder_seq_lens=encoder_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) + + # "Encoder decoder models do not currently support ROCm/HIP" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP From ed8f8b3aa6eba958e0e527510f50aa3cc94f66c4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:40:06 -0400 Subject: [PATCH 151/443] Comments & type hints --- tests/kernels/utils.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 3874fad57ae43..fb28924c5f9c4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,9 +1,21 @@ """Kernel test utils""" -STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" -STR_FLASH_ATTN_VAL = "FLASH_ATTN" -STR_INVALID_VAL = "INVALID" +import pytest +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" -def override_backend(mpatch, backend_name): + +def override_backend(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: + ''' + Override vLLM attention backend temporarily, + using pytest monkeypatch to ensure that the env vars get + reset once the test context exits. + + Arguments: + + * mpatch: pytest monkeypatch instance + * backend_name: attention backend name to force + ''' mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) From 8abe51c8b6091de871600ea83d6fc1837eb4db79 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 13:15:28 -0400 Subject: [PATCH 152/443] small refactors per @sroy745 suggestions --- tests/kernels/test_attention_selector.py | 8 ++++---- tests/kernels/utils.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index ea3ccb026ea2b..79e03c7478de0 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -4,7 +4,7 @@ import torch from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, - override_backend) + override_backend_env_variable) from vllm.attention.selector import which_attn_to_use @@ -16,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - override_backend(monkeypatch, name) + override_backend_env_variable(monkeypatch, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -37,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - override_backend(monkeypatch, STR_FLASH_ATTN_VAL) + override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -72,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - override_backend(monkeypatch, STR_INVALID_VAL) + override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fb28924c5f9c4..b401eb87d3ec3 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -7,9 +7,10 @@ STR_INVALID_VAL: str = "INVALID" -def override_backend(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: +def override_backend_env_variable(mpatch: pytest.MonkeyPatch, + backend_name: str) -> None: ''' - Override vLLM attention backend temporarily, + Override the environment variable indicating the vLLM backend temporarily, using pytest monkeypatch to ensure that the env vars get reset once the test context exits. From da1b64839dcd0c650410d83e520a26bf326d6913 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 13:58:11 -0400 Subject: [PATCH 153/443] merged backend env config --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0a835940c3d75..2bcb0668f83dc 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -14,7 +14,7 @@ import pytest import torch -from tests.kernels.utils import override_backend +from tests.kernels.utils import override_backend_env_variable from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -1236,7 +1236,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, ''' # Force Attention wrapper backend - override_backend(monkeypatch, backend_name) + override_backend_env_variable(monkeypatch, backend_name) # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. @@ -1342,7 +1342,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( ''' # Force Attention wrapper backend - override_backend(monkeypatch, backend_name) + override_backend_env_variable(monkeypatch, backend_name) # Num KV cache blocks num_blocks = 4096 @@ -1568,7 +1568,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, ''' # Force Attention wrapper backend - override_backend(monkeypatch, backend_name) + override_backend_env_variable(monkeypatch, backend_name) # Num KV cache blocks num_blocks = 4096 From 60a21e3cb7e4d4c96d6222be7d1fb0fa01634d47 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 14:35:23 -0400 Subject: [PATCH 154/443] fixed _get_seq_len_block_table_args() to change behavior based on is_prompt --- tests/kernels/test_encoder_decoder_attn.py | 5 +-- vllm/attention/backends/xformers.py | 42 ++++++++++++++-------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2bcb0668f83dc..fa4db135323fd 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -373,6 +373,7 @@ def make_backend(backend_name: str) -> AttentionBackend: def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], + encoder_seq_lens: List[int], device: Union[torch.device, str] = \ CUDA_DEVICE) -> tuple: ''' @@ -621,7 +622,7 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - seq_start_loc = make_metadata_tensors(seq_lens, + _ = make_metadata_tensors(seq_lens, context_lens, device=device) @@ -652,7 +653,7 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - seq_start_loc = make_metadata_tensors(seq_lens, + _ = make_metadata_tensors(seq_lens, context_lens, device=device) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b56e566a89fe1..01272c1d9628a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,8 +148,8 @@ def __post_init__(self): self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None - if self.is_all_encoder_attn_metadata_set: - self._maybe_compute_implicit_encoder_attrs() + # if self.is_all_encoder_attn_metadata_set: + # self._maybe_compute_implicit_encoder_attrs() @property def is_all_encoder_attn_metadata_set(self): @@ -194,14 +194,14 @@ def attention_type(self, atype: AttentionType) -> None: "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - self._maybe_compute_implicit_encoder_attrs() + # self._maybe_compute_implicit_encoder_attrs() self._attn_type = AttentionType.ENCODER_DECODER elif atype == AttentionType.ENCODER: assert self.is_all_encoder_attn_metadata_set, \ "Must set self.encoder_seq_lens in order to perform cross-attention" - self._maybe_compute_implicit_encoder_attrs() + # self._maybe_compute_implicit_encoder_attrs() self._attn_type = AttentionType.ENCODER else: @@ -346,26 +346,38 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, f"Invalid attn_metadata.attention_type {str(attn_type)}") -def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata) -> tuple: +def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, is_prompt: bool) -> tuple: ''' - Extract appropriate attention bias from attention metadata - according to attention type. - - Depends on attn_metadata having a valid attention_type. - + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + Arguments: - * attn_metadata: Attention metadata structure associated with attention + * attn_metadata: Attention metadata structure associated with attention op Returns: - * Appropriate attention bias value + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) ''' attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len return attn_metadata.seq_lens_tensor, \ - attn_metadata.max_decode_seq_len, \ + max_seq_len, \ attn_metadata.block_tables elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; @@ -586,7 +598,7 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta) + block_tables_arg = _get_seq_len_block_table_args(decode_meta, False) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -630,7 +642,7 @@ def _run_memory_efficient_xformers_forward( # (seq_lens or encoder_seq_lens) is set. seq_lens, \ _,\ - _ = _get_seq_len_block_table_args(attn_metadata) + _ = _get_seq_len_block_table_args(attn_metadata, True) assert seq_lens is not None original_query = query From 306ea5b75865ab05c1666a6dc2bab2b6b8eef4c9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 14:46:08 -0400 Subject: [PATCH 155/443] removed inference of encoder metadata attributes; removed guessing of encoder seq len tensor device --- tests/kernels/test_encoder_decoder_attn.py | 33 ++++++++++++++++------ vllm/attention/backends/xformers.py | 24 ++-------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index fa4db135323fd..93675880defc5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -397,19 +397,26 @@ def make_metadata_tensors(seq_lens: List[int], * query_start_loc: start idx of each query ''' seq_lens_tensor = None if seq_lens is None else \ - torch.tensor(seq_lens, dtype=torch.int, device=device) + torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) + encoder_seq_lens_tensor = None if encoder_seq_lens is None else \ + torch.tensor(encoder_seq_lens, dtype=torch.int, device=device) + max_encoder_seq_len = None if encoder_seq_lens is None else \ + max(encoder_seq_lens) + seq_start_loc = None return seq_lens_tensor, \ context_lens_tensor, \ max_context_len, \ max_seq_len, \ - seq_start_loc + seq_start_loc, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len def make_kv_cache(num_blocks: int, @@ -622,9 +629,12 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - _ = make_metadata_tensors(seq_lens, - context_lens, - device=device) + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -640,6 +650,8 @@ def make_test_metadata( use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) @@ -653,9 +665,12 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - _ = make_metadata_tensors(seq_lens, - context_lens, - device=device) + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -671,6 +686,8 @@ def make_test_metadata( use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 01272c1d9628a..75634b10443a2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,15 +148,14 @@ def __post_init__(self): self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None - # if self.is_all_encoder_attn_metadata_set: - # self._maybe_compute_implicit_encoder_attrs() - @property def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return self.encoder_seq_lens is not None + return (self.encoder_seq_lens is not None) and \ + (self.encoder_seq_lens_tensor is not None) and \ + (self.max_encoder_seq_len is not None) @property def is_all_cross_attn_metadata_set(self): @@ -169,19 +168,6 @@ def is_all_cross_attn_metadata_set(self): (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) - def _maybe_compute_implicit_encoder_attrs(self): - ''' - Encoder attention and cross-attention require some encoder-related - metadata attributes which may or may not be been provided by the user. - This method infers the implicit attributes from provided attributes - ''' - if self.encoder_seq_lens_tensor is None: - self.encoder_seq_lens_tensor = torch.tensor(self.encoder_seq_lens, - dtype=torch.int32, - device="cuda:0") - if self.max_encoder_seq_len is None: - self.max_encoder_seq_len = max(self.encoder_seq_lens) - @property def attention_type(self) -> AttentionType: return self._attn_type @@ -194,15 +180,11 @@ def attention_type(self, atype: AttentionType) -> None: "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - # self._maybe_compute_implicit_encoder_attrs() - self._attn_type = AttentionType.ENCODER_DECODER elif atype == AttentionType.ENCODER: assert self.is_all_encoder_attn_metadata_set, \ "Must set self.encoder_seq_lens in order to perform cross-attention" - # self._maybe_compute_implicit_encoder_attrs() - self._attn_type = AttentionType.ENCODER else: # AttentionType.{ENCODER,DECODER} From eda2273ed659e1fd00b23751a43b42e1bc19dd13 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:10:08 -0400 Subject: [PATCH 156/443] wip refactoring --- tests/kernels/test_encoder_decoder_attn.py | 75 ++++++++++++++++------ vllm/attention/backends/xformers.py | 3 +- 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 93675880defc5..163292e9ad4af 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -8,6 +8,7 @@ import copy import itertools +import numbers import random from typing import List, Optional, Union @@ -35,6 +36,46 @@ MAX_K_SEQ_LENS = [128] +def maybe_list_to_int_tensor(_list: List[int], + device: Union[torch.device, str] \ + = CUDA_DEVICE) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + +def maybe_list_to_long_tensor(_list: List[int], + device: Union[torch.device, str] \ + = CUDA_DEVICE) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: List) -> Optional[numbers.Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ -> torch.Tensor: ''' @@ -396,15 +437,13 @@ def make_metadata_tensors(seq_lens: List[int], * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - seq_lens_tensor = None if seq_lens is None else \ - torch.tensor(seq_lens, dtype=torch.int, device=device) - context_lens_tensor = None if context_lens is None else torch.tensor( - context_lens, dtype=torch.int, device=device) - max_context_len = None if context_lens is None else max(context_lens) - max_seq_len = None if seq_lens is None else max(seq_lens) - - encoder_seq_lens_tensor = None if encoder_seq_lens is None else \ - torch.tensor(encoder_seq_lens, dtype=torch.int, device=device) + seq_lens_tensor = maybe_list_to_int_tensor(seq_lens, device) + context_lens_tensor = maybe_list_to_int_tensor(context_lens, device) + max_context_len = maybe_max(context_lens) + max_seq_len = maybe_max(seq_lens) + + encoder_seq_lens_tensor = maybe_list_to_int_tensor(encoder_seq_lens, + device) max_encoder_seq_len = None if encoder_seq_lens is None else \ max(encoder_seq_lens) @@ -547,18 +586,12 @@ def make_block_tables_slot_mapping(block_size: int, dtype=torch.int, device=device, ) - prefill_slot_mapping_tensor = torch.tensor(prefill_slot_mapping, - dtype=torch.long, - device=device) - decode_slot_mapping_tensor = torch.tensor(decode_slot_mapping, - dtype=torch.long, - device=device) - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - empty_slot_mapping_tensor = torch.tensor([], - dtype=torch.long, - device=device) + prefill_slot_mapping_tensor = maybe_list_to_long_tensor( + prefill_slot_mapping, device) + decode_slot_mapping_tensor = maybe_list_to_long_tensor( + decode_slot_mapping, device) + slot_mapping_tensor = maybe_list_to_long_tensor(slot_mapping, device) + empty_slot_mapping_tensor = maybe_list_to_long_tensor([], device) return decode_block_tables_tensor, \ decode_slot_mapping_tensor, \ diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 75634b10443a2..f165f7922017f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -328,7 +328,8 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, f"Invalid attn_metadata.attention_type {str(attn_type)}") -def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, is_prompt: bool) -> tuple: +def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, + is_prompt: bool) -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent From 9425f0cd05069338ca5835196000fdcd878e9233 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:53:33 -0400 Subject: [PATCH 157/443] refactored helper functions into diffferent utils files --- tests/kernels/test_encoder_decoder_attn.py | 780 ++------------------- tests/kernels/utils.py | 638 +++++++++++++++++ vllm/utils.py | 61 ++ 3 files changed, 759 insertions(+), 720 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 163292e9ad4af..3c152c8988536 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -7,21 +7,20 @@ """ import copy -import itertools -import numbers -import random -from typing import List, Optional, Union +from typing import List, Optional import pytest import torch -from tests.kernels.utils import override_backend_env_variable +from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, + make_kv_cache, make_qkv, make_test_metadata, + override_backend_env_variable, pack_qkv, + pack_tensor, ref_masked_attention) from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import is_hip, make_tensor_with_pad +from vllm.utils import is_hip, make_causal_mask HEAD_SIZES = [64, 256] @@ -36,695 +35,6 @@ MAX_K_SEQ_LENS = [128] -def maybe_list_to_int_tensor(_list: List[int], - device: Union[torch.device, str] \ - = CUDA_DEVICE) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D int torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D int torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) - -def maybe_list_to_long_tensor(_list: List[int], - device: Union[torch.device, str] \ - = CUDA_DEVICE) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D long torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D long torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) - - -def maybe_max(_list: List) -> Optional[numbers.Number]: - ''' - Returns: - - * If _list is not None: max(_list) - * None otherwise - ''' - return None if _list is None else max(_list) - -def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ - -> torch.Tensor: - ''' - Create a q_max_seq_len x kv_max_seq_len causal mask - - Arguments: - - * q_max_seq_len: query max seq len - * kv_max_seq_len: key/value max seq len - - Returns: - - * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' - - # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) - # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) - return mask - - -def ref_masked_attention(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[List] = None, - kv_seq_lens: Optional[List] = None) -> torch.Tensor: - ''' - "Golden" masked attention reference. Supports two types of masking: - - * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out - padding elements - * Custom attention mask, which can force an arbitrary mask tensor, i.e. - causal - - Arguments: - - * query: batch_size x q_padded_seq_len x num_heads x head_size - * key: batch_size x kv_padded_seq_len x num_heads x head_size - * value: batch_size x kv_padded_seq_len x num_heads x head_size - * scale: Attention scale factor - * Custom mask: custom attention mask; good place to inject a causal - attention mask - * q_seq_lens: list of unpadded query seq_lens for each batch index - * kv_seq_lens: list of unpadded key/value seq_lens for each batch index - - Returns: - - * Attention result, batch_size x q_padded_seq_len x num_heads x head_size - ''' - - batch_size = query.shape[0] - assert (len(q_seq_lens) == batch_size) - assert (len(kv_seq_lens) == batch_size) - - attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() - - # Basic attention mask, derived from seq lens - if (q_seq_lens is not None) or (kv_seq_lens is not None): - attn_mask = torch.zeros_like(attn_weights) - if q_seq_lens is not None: - for bdx, plen in enumerate(q_seq_lens): - attn_mask[bdx, :, plen:, :] = -torch.inf - if kv_seq_lens is not None: - for bdx, plen in enumerate(kv_seq_lens): - attn_mask[bdx, :, :, plen:] = -torch.inf - - attn_weights = attn_weights + attn_mask.float() - - # Custom attention mask - if custom_mask is not None: - attn_weights = attn_weights + custom_mask.float() - - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) - return out - - -def make_qkv(batch_size: int, - max_q_seq_len: int, - max_kv_seq_len: int, - num_heads: int, - head_size: int, - attn_type: AttentionType = AttentionType.ENCODER_DECODER, - force_max_len: bool = False, - device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: - ''' - Construct QKV test tensors for self- and cross-attention. - - Generates three query/key/value triplets: - - * "Baseline" query/key/value (for input to reference attention function) - * "Prefill" query/key/value (last sequence offset zero'd out, for use as - input to prefill kernel) - * "Decode" query/key/value (only the last sequence offset from baseline, - for use as input to decode kernel) - - Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v - seqlens - - Arguments: - - * batch_size - * max_q_seq_len: max query seq len - * max_kv_seq_len: max key/value seq len - * num_heads - * head_size - * is_encoder_decoder_attn: if True, query seqlen may differ from - key/value seqlen (as is often the case for cross-attention); - o/w, query/key/value seqlens match at each batch index - (max_kv_seq_len is unused) - * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query - seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens - and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False - * device: CPU or CUDA device - - Returns: - - * query: "baseline" query; batch_size x max_q_seq_len x num_heads x - head_size - * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x - head_size - * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x - head_size - * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size - * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * decode_query: batch_size x 1 x num_heads x head_size - * decode_key: batch_size x 1 x num_heads x head_size - * decode_value: batch_size x 1 x num_heads x head_size - * q_seq_lens: "baseline" query seqlen list - * kv_seq_lens: "baseline" key/value seqlen list - * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= - max_q_seq_len due to randomness) - * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may - be <= max_kv_seq_len due to randomness) - * prefill_q_seq_lens: "prefill" query seqlen list - * prefill_kv_seq_lens: "prefill" key/value seqlen list - * decode_q_seq_lens: "decode" query seqlen list (all ones) - * decode_kv_seq_lens: "decode" key/value seqlen list - ''' - - if force_max_len: - q_seq_lens = [max_q_seq_len for _ in range(batch_size)] - else: - q_seq_lens = [ - random.randint(2, max_q_seq_len) for _ in range(batch_size) - ] - kv_seq_lens = None - if attn_type != AttentionType.ENCODER_DECODER: - # K,V seq lens match Q for self-attention - kv_seq_lens = q_seq_lens - else: - # K,V seq lens are distinct from Q seq lens & random - if force_max_len: - kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] - else: - kv_seq_lens = [ - random.randint(2, max_kv_seq_len) for _ in range(batch_size) - ] - - actual_max_q_seq_len = max(q_seq_lens) - actual_max_kv_seq_len = max(kv_seq_lens) - - query = torch.rand( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - key = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - value = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - decode_query = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) - decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - decode_value = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) - - for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, - kv_seq_lens)): - query[bdx, q_seq_len:, :, :] = 0 - key[bdx, kv_seq_len:, :, :] = 0 - value[bdx, kv_seq_len:, :, :] = 0 - - prefill_query[bdx, - 0:(q_seq_len - 1), :, :] = query[bdx, - 0:(q_seq_len - 1), :, :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :, :] = key[bdx, - 0:(kv_seq_len - 1), :, :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] - - decode_query[bdx, :, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :, :] - decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] - decode_value[bdx, :, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :, :] - - prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] - prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] - - decode_q_seq_lens = [1 for _ in q_seq_lens] - decode_kv_seq_lens = [1 for _ in kv_seq_lens] - - return query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_seq_lens, \ - kv_seq_lens, \ - actual_max_q_seq_len, \ - actual_max_kv_seq_len, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens - - -def pack_tensor(unpacked_tensor: torch.Tensor, - seq_lens: List[int], - device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: - ''' - Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an - unpadded number_of_tokens x num_heads x head_size tensor, where - number_of_tokens = sum(seq_lens) - - Arguments: - - * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size - * seq_lens: list of token counts for each seq - * device: CPU or CUDA device - - Returns - - * packed_tensor: number_of_tokens x num_heads x head_size - * start_loc_list: start idx of each batch elt in packed_tensor; [0] + - list(itertools.accumulate(seq_lens)) - ''' - - num_tok = sum(seq_lens) - num_heads = unpacked_tensor.shape[-2] - head_size = unpacked_tensor.shape[-1] - start_loc_list = [0] + list(itertools.accumulate(seq_lens)) - packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - - for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): - - packed_tensor[start_loc:( - start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] - - return packed_tensor, start_loc_list - - -def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - q_seq_lens: List[int], kv_seq_lens: List[int]) -> tuple: - ''' - Individually pack each of Q, K and V, each with dimensions batch_size x - padded_seq_len x num_heads x head_size, into respective number_of_tokens x - num_heads x head_size tensors. - - For Q, number_of_tokens = sum(q_seq_lens). - - For K and V, number_of_tokens = sum(kv_seq_lens) - - Arguments: - - * query: batch_size x padded_seq_len x num_heads x head_size - * key: batch_size x padded_seq_len x num_heads x head_size - * value: batch_size x padded_seq_len x num_heads x head_size - * q_seq_lens: list of token counts for each query - * kv_seq_lens: list of token counts for each key/value - - Returns - - * packed_query: number_of_tokens x num_heads x head_size - * packed_key: number_of_tokens x num_heads x head_size - * packed_value: number_of_tokens x num_heads x head_size - * q_start_loc_list: start idx of each query in packed_query - * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} - ''' - - if query is None: - packed_query = None - q_start_loc_list = None - else: - packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens) - packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens) - packed_value, _ = pack_tensor(value, kv_seq_lens) - return packed_query, \ - packed_key, \ - packed_value, \ - q_start_loc_list, \ - kv_start_loc_list - - -def make_backend(backend_name: str) -> AttentionBackend: - ''' - Construct the backend instance determined by the backend_name string - argument. - - "XFORMERS" -> construct xformers backend - - TODO: other backends - - Note: at time of writing the Attention wrapper automatically selects - its own backend for Attention.forward(); so the backend instance which - you generate with this function is not meant to be used for *running* - inference, but rather for generating compatible metadata structures - using backend.make_metadata() - - - Returns: - - * Backend instance - ''' - if backend_name == "XFORMERS": - return XFormersBackend() - raise AssertionError( - f"Unrecognized backend_name {backend_name} for unit test") - - -def make_metadata_tensors(seq_lens: List[int], - context_lens: List[int], - encoder_seq_lens: List[int], - device: Union[torch.device, str] = \ - CUDA_DEVICE) -> tuple: - ''' - Build scalar & tensor values required to build attention metadata structure. - - Arguments: - - * is_prompt: True -> Prefill, False -> Decode - * seq_lens: list of token-counts for each seq - * context_lens: list of context length values for each seq - * device: CPU or CUDA device - - Returns: - - * seq_lens_tensor: seq_lens list, as tensor - * context_lens_tensor: context_lens list, as tensor - * max_query_len: max(seq_lens) if is_seq, o/w 1 - * max_context_len: max(context_lens) - * max_seq_len: max(seq_lens) - * seq_start_loc: start idx of each sequence - * query_start_loc: start idx of each query - ''' - seq_lens_tensor = maybe_list_to_int_tensor(seq_lens, device) - context_lens_tensor = maybe_list_to_int_tensor(context_lens, device) - max_context_len = maybe_max(context_lens) - max_seq_len = maybe_max(seq_lens) - - encoder_seq_lens_tensor = maybe_list_to_int_tensor(encoder_seq_lens, - device) - max_encoder_seq_len = None if encoder_seq_lens is None else \ - max(encoder_seq_lens) - - seq_start_loc = None - - return seq_lens_tensor, \ - context_lens_tensor, \ - max_context_len, \ - max_seq_len, \ - seq_start_loc, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len - - -def make_kv_cache(num_blocks: int, - num_heads: int, - head_size: int, - block_size: int, - device: Union[torch.device, str] = \ - CUDA_DEVICE, - default_val: float=0.0) -> torch.Tensor: - ''' - Create a fake KV cache. - - Arguments: - - * num_blocks: number of blocks in the KV cache - * num_heads: number of attention heads - * head_size: head dimension - * block_size: number of offsets within a block - * device: CPU or CUDA device - * default_val: initialization value for KV cache elements - - Returns: - - * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) - ''' - - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) - if default_val is not None: - kv_cache[:, :, :] = default_val - return kv_cache - - -def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: - ''' - Compute the minimum number of blocks required to hold num_tokens tokens, - given block_size - ''' - return (num_tokens + block_size) // block_size - - -def make_block_tables_slot_mapping(block_size: int, - seq_lens: List, - block_base_addr: int=0, - device: Union[torch.device, str] = \ - CUDA_DEVICE) -> tuple: - ''' - Construct fake block tables & slot mappings. - - For a sequence with num_tokens tokens the minimum number - of required KV cache blocks is - - num_blocks = (num_tokens + block_size) // block_size - - Then the minimum KV cache size in blocks is - - total_cache_blocks = sum(num_blocks for all seqs) - - Then, the blocktable mapping counts downward from - - block_base_addr + total_cache_blocks - - to - - block_base_addr - - - Arguments: - - * block_size: number of offsets per block - * seq_lens: list of token-counts for each sequence - * block_base_addr: the block table base address - * device: CPU or CUDA device - - Return: - - * decode_block_tables_tensor: fake the state of the block tables during - decode - * decode_slot_mapping_tensor: fake the state of the slot mapping during - decode - * prefill_slot_mapping_tensor: fake the state of the slot mapping during - prefill - * prefill_block_tables_tensor: fake the state of the block tables during - prefill - * slot_mapping_tensor: union of prefill and decode slot mappings - * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase - cross attention) - * max_block_idx: the highest block address within this block table - ''' - - # Provision minimum number of KV cache blocks - num_blocks_list = [ - num_tokens_to_min_blocks(num_tokens, block_size) - for num_tokens in seq_lens - ] - max_block_table_len = max(num_blocks_list) - block_table_pad_tokens = 10 - - block_tables = [] - prefill_slot_mapping = [] - decode_slot_mapping = [] - slot_mapping = [] - # Compute uppermost address of block table - total_cache_blocks = sum(num_blocks_list) - block_base_idx = block_base_addr + total_cache_blocks - max_block_idx = block_base_idx - for sdx, num_tokens in enumerate(seq_lens): - num_blocks = num_blocks_list[sdx] - block_table = list( - range(block_base_idx, block_base_idx - num_blocks, -1)) - for idx in range(num_tokens): - mapping_value = ( - idx % block_size) + block_table[idx // block_size] * block_size - slot_mapping.append(mapping_value) - if idx < num_tokens - 1: - prefill_slot_mapping.append(mapping_value) - elif idx == num_tokens - 1: - decode_slot_mapping.append(mapping_value) - - block_base_idx -= num_blocks - block_tables.append(block_table) - - prefill_block_tables_tensor = torch.tensor([], device=CUDA_DEVICE) - decode_block_tables_tensor = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len + block_table_pad_tokens, - pad=0, - dtype=torch.int, - device=device, - ) - prefill_slot_mapping_tensor = maybe_list_to_long_tensor( - prefill_slot_mapping, device) - decode_slot_mapping_tensor = maybe_list_to_long_tensor( - decode_slot_mapping, device) - slot_mapping_tensor = maybe_list_to_long_tensor(slot_mapping, device) - empty_slot_mapping_tensor = maybe_list_to_long_tensor([], device) - - return decode_block_tables_tensor, \ - decode_slot_mapping_tensor, \ - prefill_slot_mapping_tensor, \ - prefill_block_tables_tensor, \ - slot_mapping_tensor, \ - empty_slot_mapping_tensor, \ - max_block_idx - - -def make_test_metadata( - attn_backend: AttentionBackend, - is_prompt: bool, - seq_lens: List[int], - context_lens: List[int], - block_tables: torch.Tensor, - slot_mapping: torch.Tensor, - is_encoder_only_test: bool, - num_prefills_or_decodes: int, - num_prefill_or_decode_tokens: int, - device: Union[torch.device, str] = CUDA_DEVICE, - encoder_seq_lens: Optional[List[int]] = None, - cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None, -) -> AttentionMetadata: - ''' - Construct fake attention metadata for a combined self-/cross-attention - scenario i.e. an encoder/decoder model. - - is_encoder_only_test=True causes the default attention metadata attention - type to be AttentionType.ENCODER. False causes the default to - be AttentionType.DECODER. - - Assumptions: - - * No chunked prefill -> a batch is 100% prefill or 100% decode, never both - - Arguments: - - * attn_backend: Backend for sourcing attention kernels - * is_prompt: prefill if True, o/w decode - * seq_lens: list of token counts for each sequence - * context_lens: list of context lengths for each sequence - * block_tables: self-attention block tables - * slot_mapping: self-attention slot_mapping - * is_encoder_only_test: True if testing encoder; False if testing - decoder self-attention or encoder/decoder cross-attention. - * device: CPU or CUDA device - * encoder_seq_lens: list of token counts for each encoder sequence, if any - exist - * cross_block_tables: cross-attention block tables, if required - * cross_slot_mapping: cross-attention slot mapping, if required - - Return: - - * AttentionMetadata structure supporting self- and cross-attention - ''' - - default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ - else AttentionType.DECODER - - if is_prompt: - num_prefills = num_prefills_or_decodes - num_prefill_tokens = num_prefill_or_decode_tokens - num_decode_tokens = 0 - - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=None if seq_lens is None else max(seq_lens), - max_decode_seq_len=0, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - _attn_type=default_attn_type, - encoder_seq_lens=encoder_seq_lens, - encoder_seq_lens_tensor=encoder_seq_lens_tensor, - max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) - - else: # not is_prompt - - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = num_prefill_or_decode_tokens - - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=0, - max_decode_seq_len=max(seq_lens), - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - _attn_type=default_attn_type, - encoder_seq_lens=encoder_seq_lens, - encoder_seq_lens_tensor=encoder_seq_lens_tensor, - max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) - - def basic_setup(num_heads: int, head_size: int, num_blocks: int, block_size: int, backend_name: str) -> tuple: ''' @@ -761,7 +71,11 @@ def basic_setup(num_heads: int, head_size: int, num_blocks: int, return scale, attn_backend, attn, None # Construct KV cache - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + kv_cache = make_kv_cache(num_blocks, + num_heads, + head_size, + block_size, + device=CUDA_DEVICE) return scale, attn_backend, attn, kv_cache @@ -830,7 +144,8 @@ def encoder_attn_setup(batch_size: int, max_kv_seq_len, num_heads, head_size, - attn_type=AttentionType.ENCODER) + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) # No causal attention mask ideal_output = ref_masked_attention(query, @@ -840,7 +155,9 @@ def encoder_attn_setup(batch_size: int, q_seq_lens=q_seq_lens, kv_seq_lens=kv_seq_lens) - packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, + q_seq_lens, + device=CUDA_DEVICE) block_tables, \ _, \ @@ -849,13 +166,17 @@ def encoder_attn_setup(batch_size: int, slot_mapping, \ _, \ _ = make_block_tables_slot_mapping( - block_size, q_seq_lens, block_base_addr=block_base_addr) + block_size, + q_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) packed_query, \ packed_key, \ packed_value, _, _ = pack_qkv( query, key, value, q_seq_lens, - kv_seq_lens) + kv_seq_lens, + device=CUDA_DEVICE) return packed_query, \ packed_key, \ @@ -973,10 +294,11 @@ def decoder_attn_setup(batch_size: int, max_kv_seq_len, num_heads, head_size, - attn_type=AttentionType.DECODER) + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) - causal_mask = build_causal_mask(max_q_seq_len, - max_kv_seq_len).to(CUDA_DEVICE) + causal_mask = make_causal_mask(max_q_seq_len, + max_kv_seq_len).to(CUDA_DEVICE) ideal_output = ref_masked_attention(query, key, @@ -995,9 +317,11 @@ def decoder_attn_setup(batch_size: int, prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens) + prefill_q_seq_lens, + device=CUDA_DEVICE) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) + [1 for _ in range(batch_size)], + device=CUDA_DEVICE) decode_block_tables, \ decode_slot_mapping, \ @@ -1006,13 +330,17 @@ def decoder_attn_setup(batch_size: int, _, \ _, \ max_block_idx = make_block_tables_slot_mapping( - block_size, q_seq_lens, block_base_addr=block_base_addr) + block_size, + q_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, _, _ = pack_qkv( prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens) + prefill_kv_seq_lens, + device=CUDA_DEVICE) decode_packed_query, \ decode_packed_key, \ @@ -1020,7 +348,8 @@ def decoder_attn_setup(batch_size: int, _, \ _ = pack_qkv( decode_query, decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens) + decode_kv_seq_lens, + device=CUDA_DEVICE) return query, \ prefill_packed_query, \ @@ -1138,7 +467,8 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_kv_seq_len, num_heads, head_size, - attn_type=AttentionType.ENCODER_DECODER) + attn_type=AttentionType.ENCODER_DECODER, + device=CUDA_DEVICE) ideal_output = ref_masked_attention(query, key, @@ -1156,9 +486,11 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens) + prefill_q_seq_lens, + device=CUDA_DEVICE) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) + [1 for _ in range(batch_size)], + device=CUDA_DEVICE) # Unlike self-attention: # - Prefill slot-mapping includes all key slots @@ -1170,11 +502,18 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, prefill_slot_mapping, \ decode_slot_mapping, \ max_block_idx = make_block_tables_slot_mapping( - block_size, kv_seq_lens, block_base_addr=block_base_addr) + block_size, + kv_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) # Packed key/value (query is already provided) - _, packed_key, packed_value, _, _ = pack_qkv(None, key, value, None, - kv_seq_lens) + _, packed_key, packed_value, _, _ = pack_qkv(None, + key, + value, + None, + kv_seq_lens, + device=CUDA_DEVICE) return packed_key, \ packed_value, \ @@ -1337,7 +676,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, is_encoder_only_test=True, num_prefills_or_decodes=len(q_seq_lens), num_prefill_or_decode_tokens=sum(q_seq_lens), - encoder_seq_lens=q_seq_lens) + encoder_seq_lens=q_seq_lens, + device=CUDA_DEVICE) packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( @@ -1480,7 +820,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, - ) + device=CUDA_DEVICE) self_prefill_packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( @@ -1526,7 +866,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, - ) + device=CUDA_DEVICE) self_decode_packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( @@ -1706,7 +1046,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, - ) + device=CUDA_DEVICE) with pytest.raises(NotImplementedError) as exc_info: run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b401eb87d3ec3..2b752e4cbcd76 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,6 +1,17 @@ """Kernel test utils""" +import itertools +import random +from typing import List, Optional, Union + import pytest +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.backends.xformers import XFormersBackend +from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, + maybe_make_long_tensor, maybe_max) STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" @@ -20,3 +31,630 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch, * backend_name: attention backend name to force ''' mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) + + +def ref_masked_attention(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: + ''' + "Golden" masked attention reference. Supports two types of masking: + + * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out + padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. + causal + + Arguments: + + * query: batch_size x q_padded_seq_len x num_heads x head_size + * key: batch_size x kv_padded_seq_len x num_heads x head_size + * value: batch_size x kv_padded_seq_len x num_heads x head_size + * scale: Attention scale factor + * Custom mask: custom attention mask; good place to inject a causal + attention mask + * q_seq_lens: list of unpadded query seq_lens for each batch index + * kv_seq_lens: list of unpadded key/value seq_lens for each batch index + + Returns: + + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size + ''' + + batch_size = query.shape[0] + assert (len(q_seq_lens) == batch_size) + assert (len(kv_seq_lens) == batch_size) + + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() + + # Basic attention mask, derived from seq lens + if (q_seq_lens is not None) or (kv_seq_lens is not None): + attn_mask = torch.zeros_like(attn_weights) + if q_seq_lens is not None: + for bdx, plen in enumerate(q_seq_lens): + attn_mask[bdx, :, plen:, :] = -torch.inf + if kv_seq_lens is not None: + for bdx, plen in enumerate(kv_seq_lens): + attn_mask[bdx, :, :, plen:] = -torch.inf + + attn_weights = attn_weights + attn_mask.float() + + # Custom attention mask + if custom_mask is not None: + attn_weights = attn_weights + custom_mask.float() + + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) + return out + + +def make_qkv( + batch_size: int, + max_q_seq_len: int, + max_kv_seq_len: int, + num_heads: int, + head_size: int, + device: Union[torch.device, str], + attn_type: AttentionType = AttentionType.ENCODER_DECODER, + force_max_len: bool = False, +) -> tuple: + ''' + Construct QKV test tensors for self- and cross-attention. + + Generates three query/key/value triplets: + + * "Baseline" query/key/value (for input to reference attention function) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as + input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, + for use as input to decode kernel) + + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v + seqlens + + Arguments: + + * batch_size + * max_q_seq_len: max query seq len + * max_kv_seq_len: max key/value seq len + * num_heads + * head_size + * is_encoder_decoder_attn: if True, query seqlen may differ from + key/value seqlen (as is often the case for cross-attention); + o/w, query/key/value seqlens match at each batch index + (max_kv_seq_len is unused) + * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query + seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens + and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False + * device: CPU or CUDA device + + Returns: + + * query: "baseline" query; batch_size x max_q_seq_len x num_heads x + head_size + * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x + head_size + * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x + head_size + * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size + * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size + * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size + * decode_query: batch_size x 1 x num_heads x head_size + * decode_key: batch_size x 1 x num_heads x head_size + * decode_value: batch_size x 1 x num_heads x head_size + * q_seq_lens: "baseline" query seqlen list + * kv_seq_lens: "baseline" key/value seqlen list + * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= + max_q_seq_len due to randomness) + * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may + be <= max_kv_seq_len due to randomness) + * prefill_q_seq_lens: "prefill" query seqlen list + * prefill_kv_seq_lens: "prefill" key/value seqlen list + * decode_q_seq_lens: "decode" query seqlen list (all ones) + * decode_kv_seq_lens: "decode" key/value seqlen list + ''' + + if force_max_len: + q_seq_lens = [max_q_seq_len for _ in range(batch_size)] + else: + q_seq_lens = [ + random.randint(2, max_q_seq_len) for _ in range(batch_size) + ] + kv_seq_lens = None + if attn_type != AttentionType.ENCODER_DECODER: + # K,V seq lens match Q for self-attention + kv_seq_lens = q_seq_lens + else: + # K,V seq lens are distinct from Q seq lens & random + if force_max_len: + kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] + else: + kv_seq_lens = [ + random.randint(2, max_kv_seq_len) for _ in range(batch_size) + ] + + actual_max_q_seq_len = max(q_seq_lens) + actual_max_kv_seq_len = max(kv_seq_lens) + + query = torch.rand( + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) + key = torch.rand( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + value = torch.rand( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + prefill_query = torch.zeros( + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) + prefill_key = torch.zeros( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + prefill_value = torch.zeros( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + decode_query = torch.zeros( + (batch_size, 1, num_heads, head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) + decode_value = torch.zeros( + (batch_size, 1, num_heads, head_size)).to(device) + + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, + kv_seq_lens)): + query[bdx, q_seq_len:, :, :] = 0 + key[bdx, kv_seq_len:, :, :] = 0 + value[bdx, kv_seq_len:, :, :] = 0 + + prefill_query[bdx, + 0:(q_seq_len - 1), :, :] = query[bdx, + 0:(q_seq_len - 1), :, :] + prefill_key[bdx, + 0:(kv_seq_len - 1), :, :] = key[bdx, + 0:(kv_seq_len - 1), :, :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] + + decode_query[bdx, :, :, :] = query[bdx, + (q_seq_len - 1):q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, + (kv_seq_len - 1):kv_seq_len, :, :] + + prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] + prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] + + decode_q_seq_lens = [1 for _ in q_seq_lens] + decode_kv_seq_lens = [1 for _ in kv_seq_lens] + + return query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_seq_lens, \ + kv_seq_lens, \ + actual_max_q_seq_len, \ + actual_max_kv_seq_len, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens + + +def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: + ''' + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where + number_of_tokens = sum(seq_lens) + + Arguments: + + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size + * seq_lens: list of token counts for each seq + * device: CPU or CUDA device + + Returns + + * packed_tensor: number_of_tokens x num_heads x head_size + * start_loc_list: start idx of each batch elt in packed_tensor; [0] + + list(itertools.accumulate(seq_lens)) + ''' + + num_tok = sum(seq_lens) + num_heads = unpacked_tensor.shape[-2] + head_size = unpacked_tensor.shape[-1] + start_loc_list = [0] + list(itertools.accumulate(seq_lens)) + packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) + + for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): + + packed_tensor[start_loc:( + start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + + return packed_tensor, start_loc_list + + +def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + q_seq_lens: List[int], kv_seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: + ''' + Individually pack each of Q, K and V, each with dimensions batch_size x + padded_seq_len x num_heads x head_size, into respective number_of_tokens x + num_heads x head_size tensors. + + For Q, number_of_tokens = sum(q_seq_lens). + + For K and V, number_of_tokens = sum(kv_seq_lens) + + Arguments: + + * query: batch_size x padded_seq_len x num_heads x head_size + * key: batch_size x padded_seq_len x num_heads x head_size + * value: batch_size x padded_seq_len x num_heads x head_size + * q_seq_lens: list of token counts for each query + * kv_seq_lens: list of token counts for each key/value + + Returns + + * packed_query: number_of_tokens x num_heads x head_size + * packed_key: number_of_tokens x num_heads x head_size + * packed_value: number_of_tokens x num_heads x head_size + * q_start_loc_list: start idx of each query in packed_query + * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} + ''' + + if query is None: + packed_query = None + q_start_loc_list = None + else: + packed_query, q_start_loc_list = pack_tensor(query, + q_seq_lens, + device=device) + packed_key, kv_start_loc_list = pack_tensor(key, + kv_seq_lens, + device=device) + packed_value, _ = pack_tensor(value, kv_seq_lens, device=device) + return packed_query, \ + packed_key, \ + packed_value, \ + q_start_loc_list, \ + kv_start_loc_list + + +def make_backend(backend_name: str) -> AttentionBackend: + ''' + Construct the backend instance determined by the backend_name string + argument. + + "XFORMERS" -> construct xformers backend + + TODO: other backends + + Note: at time of writing the Attention wrapper automatically selects + its own backend for Attention.forward(); so the backend instance which + you generate with this function is not meant to be used for *running* + inference, but rather for generating compatible metadata structures + using backend.make_metadata() + + + Returns: + + * Backend instance + ''' + if backend_name == "XFORMERS": + return XFormersBackend() + raise AssertionError( + f"Unrecognized backend_name {backend_name} for unit test") + + +def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], + encoder_seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: + ''' + Build scalar & tensor values required to build attention metadata structure. + + Arguments: + + * is_prompt: True -> Prefill, False -> Decode + * seq_lens: list of token-counts for each seq + * context_lens: list of context length values for each seq + * device: CPU or CUDA device + + Returns: + + * seq_lens_tensor: seq_lens list, as tensor + * context_lens_tensor: context_lens list, as tensor + * max_query_len: max(seq_lens) if is_seq, o/w 1 + * max_context_len: max(context_lens) + * max_seq_len: max(seq_lens) + * seq_start_loc: start idx of each sequence + * query_start_loc: start idx of each query + ''' + seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) + context_lens_tensor = maybe_make_int_tensor(context_lens, device) + max_context_len = maybe_max(context_lens) + max_seq_len = maybe_max(seq_lens) + + encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) + max_encoder_seq_len = None if encoder_seq_lens is None else \ + max(encoder_seq_lens) + + seq_start_loc = None + + return seq_lens_tensor, \ + context_lens_tensor, \ + max_context_len, \ + max_seq_len, \ + seq_start_loc, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len + + +def make_kv_cache(num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str], + default_val: float = 0.0) -> torch.Tensor: + ''' + Create a fake KV cache. + + Arguments: + + * num_blocks: number of blocks in the KV cache + * num_heads: number of attention heads + * head_size: head dimension + * block_size: number of offsets within a block + * device: CPU or CUDA device + * default_val: initialization value for KV cache elements + + Returns: + + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + ''' + + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + if default_val is not None: + kv_cache[:, :, :] = default_val + return kv_cache + + +def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: + ''' + Compute the minimum number of blocks required to hold num_tokens tokens, + given block_size + ''' + return (num_tokens + block_size) // block_size + + +def make_block_tables_slot_mapping(block_size: int, + seq_lens: List, + device: Union[torch.device, str], + block_base_addr: int = 0) -> tuple: + ''' + Construct fake block tables & slot mappings. + + For a sequence with num_tokens tokens the minimum number + of required KV cache blocks is + + num_blocks = (num_tokens + block_size) // block_size + + Then the minimum KV cache size in blocks is + + total_cache_blocks = sum(num_blocks for all seqs) + + Then, the blocktable mapping counts downward from + + block_base_addr + total_cache_blocks + + to + + block_base_addr + + + Arguments: + + * block_size: number of offsets per block + * seq_lens: list of token-counts for each sequence + * block_base_addr: the block table base address + * device: CPU or CUDA device + + Return: + + * decode_block_tables_tensor: fake the state of the block tables during + decode + * decode_slot_mapping_tensor: fake the state of the slot mapping during + decode + * prefill_slot_mapping_tensor: fake the state of the slot mapping during + prefill + * prefill_block_tables_tensor: fake the state of the block tables during + prefill + * slot_mapping_tensor: union of prefill and decode slot mappings + * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase + cross attention) + * max_block_idx: the highest block address within this block table + ''' + + # Provision minimum number of KV cache blocks + num_blocks_list = [ + num_tokens_to_min_blocks(num_tokens, block_size) + for num_tokens in seq_lens + ] + max_block_table_len = max(num_blocks_list) + block_table_pad_tokens = 10 + + block_tables = [] + prefill_slot_mapping = [] + decode_slot_mapping = [] + slot_mapping = [] + # Compute uppermost address of block table + total_cache_blocks = sum(num_blocks_list) + block_base_idx = block_base_addr + total_cache_blocks + max_block_idx = block_base_idx + for sdx, num_tokens in enumerate(seq_lens): + num_blocks = num_blocks_list[sdx] + block_table = list( + range(block_base_idx, block_base_idx - num_blocks, -1)) + for idx in range(num_tokens): + mapping_value = ( + idx % block_size) + block_table[idx // block_size] * block_size + slot_mapping.append(mapping_value) + if idx < num_tokens - 1: + prefill_slot_mapping.append(mapping_value) + elif idx == num_tokens - 1: + decode_slot_mapping.append(mapping_value) + + block_base_idx -= num_blocks + block_tables.append(block_table) + + prefill_block_tables_tensor = torch.tensor([], device=device) + decode_block_tables_tensor = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len + block_table_pad_tokens, + pad=0, + dtype=torch.int, + device=device, + ) + prefill_slot_mapping_tensor = maybe_make_long_tensor( + prefill_slot_mapping, device) + decode_slot_mapping_tensor = maybe_make_long_tensor( + decode_slot_mapping, device) + slot_mapping_tensor = maybe_make_long_tensor(slot_mapping, device) + empty_slot_mapping_tensor = maybe_make_long_tensor([], device) + + return decode_block_tables_tensor, \ + decode_slot_mapping_tensor, \ + prefill_slot_mapping_tensor, \ + prefill_block_tables_tensor, \ + slot_mapping_tensor, \ + empty_slot_mapping_tensor, \ + max_block_idx + + +def make_test_metadata( + attn_backend: AttentionBackend, + is_prompt: bool, + seq_lens: List[int], + context_lens: List[int], + block_tables: torch.Tensor, + slot_mapping: torch.Tensor, + is_encoder_only_test: bool, + num_prefills_or_decodes: int, + num_prefill_or_decode_tokens: int, + device: Union[torch.device, str], + encoder_seq_lens: Optional[List[int]] = None, + cross_block_tables: Optional[torch.Tensor] = None, + cross_slot_mapping: Optional[List[int]] = None, +) -> AttentionMetadata: + ''' + Construct fake attention metadata for a combined self-/cross-attention + scenario i.e. an encoder/decoder model. + + is_encoder_only_test=True causes the default attention metadata attention + type to be AttentionType.ENCODER. False causes the default to + be AttentionType.DECODER. + + Assumptions: + + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + + Arguments: + + * attn_backend: Backend for sourcing attention kernels + * is_prompt: prefill if True, o/w decode + * seq_lens: list of token counts for each sequence + * context_lens: list of context lengths for each sequence + * block_tables: self-attention block tables + * slot_mapping: self-attention slot_mapping + * is_encoder_only_test: True if testing encoder; False if testing + decoder self-attention or encoder/decoder cross-attention. + * device: CPU or CUDA device + * encoder_seq_lens: list of token counts for each encoder sequence, if any + exist + * cross_block_tables: cross-attention block tables, if required + * cross_slot_mapping: cross-attention slot mapping, if required + + Return: + + * AttentionMetadata structure supporting self- and cross-attention + ''' + + default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ + else AttentionType.DECODER + + if is_prompt: + num_prefills = num_prefills_or_decodes + num_prefill_tokens = num_prefill_or_decode_tokens + num_decode_tokens = 0 + + seq_lens_tensor, \ + context_lens_tensor, \ + _, \ + _, \ + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=None if seq_lens is None else max(seq_lens), + max_decode_seq_len=0, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + _attn_type=default_attn_type, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) + + else: # not is_prompt + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = num_prefill_or_decode_tokens + + seq_lens_tensor, \ + context_lens_tensor, \ + _, \ + _, \ + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=max(seq_lens), + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + _attn_type=default_attn_type, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) diff --git a/vllm/utils.py b/vllm/utils.py index 2781eceb7ba98..1986ba2b3d8c6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -12,6 +12,7 @@ import warnings from collections import defaultdict from functools import lru_cache, partial, wraps +from numbers import Number from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -674,3 +675,63 @@ def inner(*args, **kwargs): return inner # type: ignore return wrapper + +def maybe_make_int_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + +def maybe_make_long_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: List) -> Optional[Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + +def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ + -> torch.Tensor: + ''' + Create a q_max_seq_len x kv_max_seq_len causal mask + + Arguments: + + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len + + Returns: + + * 2D tensor, q_max_seq_len x kv_max_seq_len + ''' + + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) + return mask \ No newline at end of file From 62fb8d1a63cc1f501d5ee22948dcc4853f26df50 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:55:25 -0400 Subject: [PATCH 158/443] _ for private functions in test_encoder_decoder_attn --- tests/kernels/test_encoder_decoder_attn.py | 42 +++++++++++----------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3c152c8988536..33149cf38e866 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -35,7 +35,7 @@ MAX_K_SEQ_LENS = [128] -def basic_setup(num_heads: int, head_size: int, num_blocks: int, +def _basic_setup(num_heads: int, head_size: int, num_blocks: int, block_size: int, backend_name: str) -> tuple: ''' Compute & build entities required for the self-/cross-attention test. @@ -79,7 +79,7 @@ def basic_setup(num_heads: int, head_size: int, num_blocks: int, return scale, attn_backend, attn, kv_cache -def encoder_attn_setup(batch_size: int, +def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, block_size: int, @@ -187,7 +187,7 @@ def encoder_attn_setup(batch_size: int, q_seq_lens -def decoder_attn_setup(batch_size: int, +def _decoder_attn_setup(batch_size: int, num_heads: int, head_size: int, block_size: int, @@ -373,7 +373,7 @@ def decoder_attn_setup(batch_size: int, max_block_idx -def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, +def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, q_seq_lens: List, prefill_q_seq_lens: List, batch_size: int, @@ -527,7 +527,7 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_block_idx -def run_encoder_or_decoder_self_attention_test( +def _run_encoder_or_decoder_self_attention_test( attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, packed_value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, @@ -562,7 +562,7 @@ def run_encoder_or_decoder_self_attention_test( attn_metadata) -def run_encoder_decoder_cross_attention_test( +def _run_encoder_decoder_cross_attention_test( attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, packed_value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -633,7 +633,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, scale, \ attn_backend, \ attn, \ - _ = basic_setup(num_heads, + _ = _basic_setup(num_heads, head_size, None, None, @@ -651,7 +651,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, packed_ideal_output, \ block_tables, \ slot_mapping, \ - q_seq_lens = encoder_attn_setup(batch_size, + q_seq_lens = _encoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -680,7 +680,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, device=CUDA_DEVICE) packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( + _run_encoder_or_decoder_self_attention_test( attn, packed_query, packed_key, @@ -743,7 +743,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, + kv_cache = _basic_setup(num_heads, head_size, num_blocks, block_size, @@ -772,7 +772,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, + cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -791,7 +791,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, + _ = _enc_dec_cross_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, @@ -823,7 +823,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( device=CUDA_DEVICE) self_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( + _run_encoder_or_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, @@ -839,7 +839,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_ideal_output)) cross_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( + _run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, prefill_attn_metadata) @@ -869,7 +869,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( device=CUDA_DEVICE) self_decode_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( + _run_encoder_or_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -885,7 +885,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_ideal_output)) cross_decode_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( + _run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) @@ -912,7 +912,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + _run_encoder_decoder_cross_attention_test(attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) @@ -969,7 +969,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, + kv_cache = _basic_setup(num_heads, head_size, num_blocks, block_size, @@ -998,7 +998,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, + cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -1017,7 +1017,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, + _ = _enc_dec_cross_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, @@ -1049,7 +1049,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, device=CUDA_DEVICE) with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, From 2730daa37de2ea8ef9531f906661c5592c9b1307 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:58:46 -0400 Subject: [PATCH 159/443] _ refactor --- tests/kernels/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2b752e4cbcd76..54ffa4ff0e700 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -352,7 +352,7 @@ def make_backend(backend_name: str) -> AttentionBackend: f"Unrecognized backend_name {backend_name} for unit test") -def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], +def _make_metadata_tensors(seq_lens: List[int], context_lens: List[int], encoder_seq_lens: List[int], device: Union[torch.device, str]) -> tuple: ''' @@ -425,7 +425,7 @@ def make_kv_cache(num_blocks: int, return kv_cache -def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: +def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' Compute the minimum number of blocks required to hold num_tokens tokens, given block_size @@ -483,7 +483,7 @@ def make_block_tables_slot_mapping(block_size: int, # Provision minimum number of KV cache blocks num_blocks_list = [ - num_tokens_to_min_blocks(num_tokens, block_size) + _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) @@ -599,7 +599,7 @@ def make_test_metadata( _, \ _, \ encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, + max_encoder_seq_len = _make_metadata_tensors(seq_lens, context_lens, encoder_seq_lens, device=device) @@ -635,7 +635,7 @@ def make_test_metadata( _, \ _, \ encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, + max_encoder_seq_len = _make_metadata_tensors(seq_lens, context_lens, encoder_seq_lens, device=device) From 5face2ab0e6f77fc3ee98f307369e848fdea1a4f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:59:26 -0400 Subject: [PATCH 160/443] formatting --- tests/kernels/test_encoder_decoder_attn.py | 38 +++++++++++----------- tests/kernels/utils.py | 4 +-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 33149cf38e866..c1ff1327af423 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -36,7 +36,7 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, - block_size: int, backend_name: str) -> tuple: + block_size: int, backend_name: str) -> tuple: ''' Compute & build entities required for the self-/cross-attention test. @@ -80,12 +80,12 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for encoder attention test. @@ -188,12 +188,12 @@ def _encoder_attn_setup(batch_size: int, def _decoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for self-attention test. @@ -913,8 +913,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: _run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) + None, None, kv_cache, + decode_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @@ -1050,10 +1050,10 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, with pytest.raises(NotImplementedError) as exc_info: _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) # "Encoder decoder models do not currently support ROCm/HIP" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 54ffa4ff0e700..b7951d4b5da28 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -353,8 +353,8 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors(seq_lens: List[int], context_lens: List[int], - encoder_seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: + encoder_seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: ''' Build scalar & tensor values required to build attention metadata structure. From f39155a3b981fab2d53a18394a20211ce9d51dab Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 18:17:01 -0400 Subject: [PATCH 161/443] constructing attn md with minimum number of arguments --- tests/kernels/test_encoder_decoder_attn.py | 7 +----- tests/kernels/utils.py | 2 +- vllm/attention/backends/xformers.py | 28 ++++++++++------------ 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c1ff1327af423..85a29f61ab159 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -669,7 +669,6 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - None, context_lens, block_tables, slot_mapping, @@ -805,13 +804,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # PREFILL: self- and cross-attention tests - context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, prefill_q_seq_lens, - context_lens, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, @@ -857,10 +853,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_backend, False, q_seq_lens, - context_lens, self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, + context_lens=context_lens, num_prefills_or_decodes=len(q_seq_lens), num_prefill_or_decode_tokens=len(q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, @@ -1037,7 +1033,6 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, attn_backend, True, prefill_q_seq_lens, - context_lens, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b7951d4b5da28..74b33aaec5d2f 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -541,13 +541,13 @@ def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, seq_lens: List[int], - context_lens: List[int], block_tables: torch.Tensor, slot_mapping: torch.Tensor, is_encoder_only_test: bool, num_prefills_or_decodes: int, num_prefill_or_decode_tokens: int, device: Union[torch.device, str], + context_lens: Optional[List[int]] = None, encoder_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, cross_slot_mapping: Optional[List[int]] = None, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index f165f7922017f..81e31d4c38c2b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -67,9 +67,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] @@ -93,6 +91,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is @@ -204,7 +206,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: (self.encoder_seq_lens is not None) assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - assert self.context_lens_tensor is not None + #assert self.context_lens_tensor is not None assert self.block_tables is not None query_start_loc = None if self.query_start_loc is None \ @@ -216,19 +218,20 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=None - if self.seq_lens is None else self.seq_lens[:self.num_prefills], + if self.seq_lens is None \ + else self.seq_lens[:self.num_prefills], seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + context_lens_tensor=None if self.context_lens_tensor is None else \ + self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, _attn_type=self.attention_type, - # Begin cross-attention fields below... + # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, @@ -254,19 +257,14 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, + self.seq_lens_tensor[self.num_prefills:], max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, _attn_type=self. - _attn_type, # Begin cross-attention fields below... + _attn_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, From c7edbc6d962f702fc2d5c996e773a6fd8e121e59 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 09:24:21 -0400 Subject: [PATCH 162/443] Formatting --- tests/kernels/test_encoder_decoder_attn.py | 2 -- vllm/attention/backends/xformers.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 85a29f61ab159..ec8945f1c7257 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1027,8 +1027,6 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, # PREFILL: self- and cross-attention tests - context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 81e31d4c38c2b..6d4aff2bba37d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -258,7 +258,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], + self.seq_lens_tensor[self.num_prefills:], max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, block_tables=self.block_tables[self.num_prefills:], From b023557e87b3a826f92aa8ff39896e23974990b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 12:14:28 -0400 Subject: [PATCH 163/443] typing and formatting --- vllm/attention/backends/xformers.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6d4aff2bba37d..6948bfea33ce3 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -273,7 +273,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return self._cached_decode_metadata def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ - Optional[List[Optional[AttentionBias]]]: + Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -621,10 +621,10 @@ def _run_memory_efficient_xformers_forward( # Enforce that the appropriate *_seq_lens attribute of attn_metadata # (seq_lens or encoder_seq_lens) is set. - seq_lens, \ - _,\ - _ = _get_seq_len_block_table_args(attn_metadata, True) - assert seq_lens is not None + # seq_lens, \ + # _,\ + # _ = _get_seq_len_block_table_args(attn_metadata, True) + # assert seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: @@ -648,15 +648,22 @@ def _run_memory_efficient_xformers_forward( if self.alibi_slopes is None: if attn_metadata.attention_type == \ AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) else: if attn_metadata.attention_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + # Default encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.encoder_seq_lens) else: + assert attn_metadata.seq_lens is not None + # Default decoder self-attention mask is causal attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) @@ -665,6 +672,7 @@ def _run_memory_efficient_xformers_forward( self.sliding_window) attn_bias = [attn_bias] else: + assert attn_metadata.seq_lens is not None attn_bias = _make_alibi_bias(self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.seq_lens) @@ -692,6 +700,7 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None output = torch.empty_like(original_query) start = 0 for i, seq_len in enumerate(attn_metadata.seq_lens): From af0c0b94d7e374159ed343b9c125497a94571d6c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 18:27:02 -0400 Subject: [PATCH 164/443] refactored block table/slot mapping construction process for decoder into two steps --- tests/kernels/test_encoder_decoder_attn.py | 42 +++++++-- tests/kernels/utils.py | 105 ++++++++++++++------- 2 files changed, 102 insertions(+), 45 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ec8945f1c7257..42198a4e162db 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -15,7 +15,10 @@ from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, - pack_tensor, ref_masked_attention) + pack_tensor, ref_masked_attention, + make_empty_slot_mapping_tensor, + make_empty_block_tables_tensor, + split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( @@ -323,17 +326,36 @@ def _decoder_attn_setup(batch_size: int, [1 for _ in range(batch_size)], device=CUDA_DEVICE) + # Build prefill- & decode-phase data structures + # for decoder self-attention. Block tables and + # slot mapping must be in a format compatible + # with KV caching & attention kernels + # + # Prefill: + # + # * Empty block-tables tensor + # * Slot-mapping with entries for prompt tokens + # + # Decode: + # * Block-tables tensor with minimum number of blocks + # required by total num. tokens in the entirety of all sequences + # (including both prefill & decode) + # * Slot-mapping with entries for tokens that will be decoded in the + # current decode iteration + + prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) + decode_block_tables, \ - decode_slot_mapping, \ + slot_mapping_list, \ + max_block_idx = make_block_tables_slot_mapping(block_size, + q_seq_lens, + device=CUDA_DEVICE, + block_base_addr = block_base_addr) + prefill_slot_mapping, \ - prefill_block_tables, \ - _, \ - _, \ - max_block_idx = make_block_tables_slot_mapping( - block_size, - q_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) + decode_slot_mapping = split_slot_mapping(slot_mapping_list, + q_seq_lens, + device=CUDA_DEVICE) prefill_packed_query, \ prefill_packed_key, \ diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 74b33aaec5d2f..0ac052c42a75a 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -432,9 +432,70 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' return (num_tokens + block_size) // block_size +def make_empty_slot_mapping_tensor(device: Union[torch.device, str]): + return maybe_make_long_tensor([], device) + +def make_empty_block_tables_tensor(device: Union[torch.device, str]): + return torch.tensor([], device=device) + +def split_slot_mapping(slot_mapping_list: torch.Tensor, + seq_lens: List[int], + device: Union[torch.device, str]): + ''' + Split a slot mapping into valid prefill- and decode-phase slot mappings. + + Context: + * Your goal is to test (1) prefill of N prompts, with prompt-lengths + {K_i \forall i \in [0,N)}, followed by (2) decoding of a single token + for all N prompts (N tokens total); the resultant sequence lengths + after decode would be {K_i + 1 for i \in [0,N)} + * The test you want to do requires (1) having the prefill slot mapping + for all tokens present during prefill, the number of which is + M = \sum_i{K_i}, and (2) having the decode slot mapping for all N + decoded tokens + + This function consumes a single 1D slot mapping, which is the + concatenation of N slot mappings each of length K_i + 1 (corresponding + to the sequence lengths after decode), with a total length of + P = \sum_i{K_i + 1} = M + N + + The prefill-phase slot mapping results from excising the (K_i + 1)-th entry + from each of the N subsequences in the slot mapping (i.e. omitting the + decoded token's mapping.) + + The N excised entries are appended to obtain the decode-phase slot mapping + + Arguments: + + * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N + post-decode sequences + * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the + description above) + * device: cuda, cpu, etc. + + Returns: + + * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor) + reflecting all N prefill prompts + * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting + all N decoded tokens + ''' + + prefill_slot_mapping = [] + decode_slot_mapping = [] + + base_idx=0 + for seq_len in seq_lens: + prefill_slot_mapping.extend( + slot_mapping_list[range(base_idx,base_idx+seq_len-1)]) + decode_slot_mapping.append(slot_mapping_list[base_idx+seq_len-1]) + base_idx += seq_len + + return maybe_make_long_tensor(prefill_slot_mapping, device), \ + maybe_make_long_tensor(decode_slot_mapping, device) def make_block_tables_slot_mapping(block_size: int, - seq_lens: List, + seq_lens: List[int], device: Union[torch.device, str], block_base_addr: int = 0) -> tuple: ''' @@ -467,17 +528,8 @@ def make_block_tables_slot_mapping(block_size: int, Return: - * decode_block_tables_tensor: fake the state of the block tables during - decode - * decode_slot_mapping_tensor: fake the state of the slot mapping during - decode - * prefill_slot_mapping_tensor: fake the state of the slot mapping during - prefill - * prefill_block_tables_tensor: fake the state of the block tables during - prefill - * slot_mapping_tensor: union of prefill and decode slot mappings - * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase - cross attention) + * block_tables_tensor: block table for sequence + * slot_mapping_list: slot mapping for sequence * max_block_idx: the highest block address within this block table ''' @@ -490,9 +542,7 @@ def make_block_tables_slot_mapping(block_size: int, block_table_pad_tokens = 10 block_tables = [] - prefill_slot_mapping = [] - decode_slot_mapping = [] - slot_mapping = [] + slot_mapping_list = [] # Compute uppermost address of block table total_cache_blocks = sum(num_blocks_list) block_base_idx = block_base_addr + total_cache_blocks @@ -504,36 +554,21 @@ def make_block_tables_slot_mapping(block_size: int, for idx in range(num_tokens): mapping_value = ( idx % block_size) + block_table[idx // block_size] * block_size - slot_mapping.append(mapping_value) - if idx < num_tokens - 1: - prefill_slot_mapping.append(mapping_value) - elif idx == num_tokens - 1: - decode_slot_mapping.append(mapping_value) + slot_mapping_list.append(mapping_value) block_base_idx -= num_blocks block_tables.append(block_table) - prefill_block_tables_tensor = torch.tensor([], device=device) - decode_block_tables_tensor = make_tensor_with_pad( + block_tables_tensor = make_tensor_with_pad( block_tables, max_len=max_block_table_len + block_table_pad_tokens, pad=0, dtype=torch.int, device=device, ) - prefill_slot_mapping_tensor = maybe_make_long_tensor( - prefill_slot_mapping, device) - decode_slot_mapping_tensor = maybe_make_long_tensor( - decode_slot_mapping, device) - slot_mapping_tensor = maybe_make_long_tensor(slot_mapping, device) - empty_slot_mapping_tensor = maybe_make_long_tensor([], device) - - return decode_block_tables_tensor, \ - decode_slot_mapping_tensor, \ - prefill_slot_mapping_tensor, \ - prefill_block_tables_tensor, \ - slot_mapping_tensor, \ - empty_slot_mapping_tensor, \ + + return block_tables_tensor, \ + slot_mapping_list, \ max_block_idx From 50bca0886cf988dfc43712d3ad65ab9adce9c228 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 19:01:18 -0400 Subject: [PATCH 165/443] finished breaking block table/slot mapping construction into steps; formatting --- tests/kernels/test_encoder_decoder_attn.py | 62 ++++++++++++++-------- tests/kernels/utils.py | 15 +++--- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 42198a4e162db..fdf48f12e33aa 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -13,17 +13,17 @@ import torch from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, - make_kv_cache, make_qkv, make_test_metadata, + make_empty_block_tables_tensor, + make_empty_slot_mapping_tensor, make_kv_cache, + make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, pack_tensor, ref_masked_attention, - make_empty_slot_mapping_tensor, - make_empty_block_tables_tensor, split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -from vllm.utils import is_hip, make_causal_mask +from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] @@ -163,11 +163,7 @@ def _encoder_attn_setup(batch_size: int, device=CUDA_DEVICE) block_tables, \ - _, \ - _, \ - _, \ slot_mapping, \ - _, \ _ = make_block_tables_slot_mapping( block_size, q_seq_lens, @@ -331,12 +327,12 @@ def _decoder_attn_setup(batch_size: int, # slot mapping must be in a format compatible # with KV caching & attention kernels # - # Prefill: - # + # Prefill-phase: + # # * Empty block-tables tensor # * Slot-mapping with entries for prompt tokens # - # Decode: + # Decode-phase: # * Block-tables tensor with minimum number of blocks # required by total num. tokens in the entirety of all sequences # (including both prefill & decode) @@ -353,8 +349,8 @@ def _decoder_attn_setup(batch_size: int, block_base_addr = block_base_addr) prefill_slot_mapping, \ - decode_slot_mapping = split_slot_mapping(slot_mapping_list, - q_seq_lens, + decode_slot_mapping = split_slot_mapping(slot_mapping_list, + q_seq_lens, device=CUDA_DEVICE) prefill_packed_query, \ @@ -514,21 +510,45 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, [1 for _ in range(batch_size)], device=CUDA_DEVICE) - # Unlike self-attention: - # - Prefill slot-mapping includes all key slots - # - Decode slot-mapping is empty + # Build prefill- & decode-phase data structures + # for encoder/decoder cross-attention. Block tables and + # slot mapping must be in a format compatible + # with KV caching & attention kernels + # + # Whereas decoder self-attention extracts relationships between + # equal-length Q/K/V sequences, which mutually grow in length + # with each decoded token, cross-attention relates the Q sequence + # - which grows with each new decoded token - to fixed-length + # K and V sequences derived from the encoder hidden states. + # + # Prefill-phase: + # + # * Empty block-tables tensor + # * Slot-mapping with as many entries as there are tokens in the encoder + # prompt. + # + # Decode-phase: + # * Block-tables tensor with minimum number of blocks to + # accommodate K & V tensors which are equal in lnegth + # to the encoder prompt length + # * Empty slot-mapping tensor (since K & V are fixed in size, + # new decoded tokens are not KV-cached and require no slot- + # mapping) + + prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) + decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) + decode_block_tables, \ - _, \ - _, \ - prefill_block_tables, \ - prefill_slot_mapping, \ - decode_slot_mapping, \ + prefill_slot_mapping_list, \ max_block_idx = make_block_tables_slot_mapping( block_size, kv_seq_lens, block_base_addr=block_base_addr, device=CUDA_DEVICE) + prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, + device=CUDA_DEVICE) + # Packed key/value (query is already provided) _, packed_key, packed_value, _, _ = pack_qkv(None, key, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 0ac052c42a75a..575e97f24c0c4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -432,14 +432,16 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' return (num_tokens + block_size) // block_size + def make_empty_slot_mapping_tensor(device: Union[torch.device, str]): return maybe_make_long_tensor([], device) + def make_empty_block_tables_tensor(device: Union[torch.device, str]): return torch.tensor([], device=device) -def split_slot_mapping(slot_mapping_list: torch.Tensor, - seq_lens: List[int], + +def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], device: Union[torch.device, str]): ''' Split a slot mapping into valid prefill- and decode-phase slot mappings. @@ -484,16 +486,17 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, prefill_slot_mapping = [] decode_slot_mapping = [] - base_idx=0 + base_idx = 0 for seq_len in seq_lens: - prefill_slot_mapping.extend( - slot_mapping_list[range(base_idx,base_idx+seq_len-1)]) - decode_slot_mapping.append(slot_mapping_list[base_idx+seq_len-1]) + prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx + + seq_len - 1)]) + decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len return maybe_make_long_tensor(prefill_slot_mapping, device), \ maybe_make_long_tensor(decode_slot_mapping, device) + def make_block_tables_slot_mapping(block_size: int, seq_lens: List[int], device: Union[torch.device, str], From 90610daa1b4163872d5a243b95d7127309e1d91e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 19:19:32 -0400 Subject: [PATCH 166/443] slight refactor --- tests/kernels/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 575e97f24c0c4..ddeafbc654073 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -522,6 +522,10 @@ def make_block_tables_slot_mapping(block_size: int, block_base_addr + The constructed block-tables and slot-mapping are sized to the + lengths of the sequences in their entirety (as reflected by seq_lens), + i.e. the total of prefill prompt tokens + decoded tokens. + Arguments: * block_size: number of offsets per block From a006cc892b9438d77f8d3f992075bc73dfcb9bb5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 22:18:00 -0400 Subject: [PATCH 167/443] refactored encoder test into the cross-attention test --- tests/kernels/test_encoder_decoder_attn.py | 353 +++++++++------------ tests/kernels/utils.py | 12 +- vllm/attention/backends/xformers.py | 14 +- 3 files changed, 169 insertions(+), 210 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index fdf48f12e33aa..7d9d6d32c0c40 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -34,8 +34,8 @@ BACKEND_NAMES = ["XFORMERS"] CUDA_DEVICE = "cuda:0" -MAX_Q_SEQ_LENS = [128] -MAX_K_SEQ_LENS = [128] +MAX_DEC_SEQ_LENS = [128] +MAX_ENC_SEQ_LENS = [128] def _basic_setup(num_heads: int, head_size: int, num_blocks: int, @@ -82,13 +82,8 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, return scale, attn_backend, attn, kv_cache -def _encoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: +def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, + scale: float, max_q_seq_len: int) -> tuple: ''' Set up test vectors & data structures for encoder attention test. @@ -162,14 +157,6 @@ def _encoder_attn_setup(batch_size: int, q_seq_lens, device=CUDA_DEVICE) - block_tables, \ - slot_mapping, \ - _ = make_block_tables_slot_mapping( - block_size, - q_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) - packed_query, \ packed_key, \ packed_value, _, _ = pack_qkv( @@ -181,8 +168,6 @@ def _encoder_attn_setup(batch_size: int, packed_key, \ packed_value, \ packed_ideal_output, \ - block_tables, \ - slot_mapping, \ q_seq_lens @@ -380,10 +365,7 @@ def _decoder_attn_setup(batch_size: int, decode_packed_key, \ decode_packed_value, \ decode_packed_ideal_output, \ - decode_q_seq_lens, \ - decode_kv_seq_lens, \ q_seq_lens, \ - kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ @@ -392,15 +374,16 @@ def _decoder_attn_setup(batch_size: int, def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, - q_seq_lens: List, + decoder_seq_lens: List[int], + encoder_seq_lens: Optional[List[int]], prefill_q_seq_lens: List, batch_size: int, num_heads: int, head_size: int, block_size: int, scale: float, - max_q_seq_len: int, - max_kv_seq_len: int, + max_decoder_seq_len: int, + max_encoder_seq_len: int, block_base_addr: Optional[int]=0) \ -> tuple: ''' @@ -481,10 +464,11 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, _, \ _, \ _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, + max_decoder_seq_len, + max_encoder_seq_len, num_heads, head_size, + force_kv_seq_lens=encoder_seq_lens, attn_type=AttentionType.ENCODER_DECODER, device=CUDA_DEVICE) @@ -492,7 +476,7 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, key, value, scale=scale, - q_seq_lens=q_seq_lens, + q_seq_lens=decoder_seq_lens, kv_seq_lens=kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) @@ -561,28 +545,57 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, packed_value, \ prefill_packed_ideal_output, \ decode_packed_ideal_output, \ - kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ - prefill_block_tables, \ - max_block_idx + prefill_block_tables -def _run_encoder_or_decoder_self_attention_test( - attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: +def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: ''' - Run encoder attention or decoder self-attention test. + Run encoder attention. attn_metadata.attention_type is assigned attn_type in order to configure - the kernel invocation for either encoder or decoder self-attention. + the kernel invocation for either encoder attention - attn_type must be AttentionType.ENCODER or DECODER; if ENCODER, - attn_metadata.num_decode_tokens must be 0 (i.e. there is no such thing as - "decode-phase encoder attention".) + attn_type must be AttentionType.ENCODER + + Arguments: + + * attn: Attention wrapper instance + * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * attn_metadata: attention metadata for encoder/decoder-self attention + * attn_type: AttentionType.DECODER or AttentionType.ENCODER + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' + assert attn_type == AttentionType.ENCODER + assert attn_metadata.num_decode_tokens == 0 + attn_metadata.attention_type = attn_type + return attn.forward(packed_query, packed_key, packed_value, None, + attn_metadata) + + +def _run_decoder_self_attention_test(attn: Attention, + packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: + ''' + Run decoder self-attention test. + + attn_metadata.attention_type is assigned attn_type in order to configure + the kernel invocation for decoder self-attention. + + attn_type must be AttentionType.DECODER Arguments: @@ -596,9 +609,7 @@ def _run_encoder_or_decoder_self_attention_test( * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' - assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] - assert attn_metadata.num_decode_tokens == 0 or \ - attn_type != AttentionType.ENCODER + assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) @@ -637,115 +648,11 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) -def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, max_seq_len: int, - monkeypatch) -> None: - ''' - Encoder-only attention test: - - * Construct fake test vectors for encoder attention - * Construct attention metadata structure with encoder-attention- - specific attributes - * Run encoder attention with metadata structure & test vectors - * Validate output correctness against ideal reference attention - implementation - - Encoder attention (by default) does not restrict which sequence offsets - may attend to each other. Thus the reference ideal reference - implementation does not employ a causal attention mask. - - Encoder attention does not utilize KV cache however the XFormer backend - requires block_tables & slot_mapping to be non-None and have a valid - structure, thus this test constructs dummy memory-mapping structures. - - Encoder attention is basically structured like decoder self-attention - in that Q/K/V are all derived from the previous layer output & have - the same sequence length (in contrast to encoder/decoder cross- - attention where K/V are drawn from the encoder hidden states and - may have a different length than Q derived from decoder previous - layer output.) - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - # Attention scale factor, attention backend instance, attention wrapper - # instance. Encoder attention does not require KV cache. - scale, \ - attn_backend, \ - attn, \ - _ = _basic_setup(num_heads, - head_size, - None, - None, - backend_name) - - # Self-attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - block_tables, \ - slot_mapping, \ - q_seq_lens = _encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_seq_len) - - context_lens = [0 for _ in range(batch_size)] - - # Metadata config for encoder attention: - # - # * Use prefill kernel - # * Signal that this is an encoder-only test so that - # metadata attention_type property is correctly - # configured as AttentionType.ENCODER - attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - context_lens, - block_tables, - slot_mapping, - is_encoder_only_test=True, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=sum(q_seq_lens), - encoder_seq_lens=q_seq_lens, - device=CUDA_DEVICE) - - packed_actual_output: torch.Tensor = \ - _run_encoder_or_decoder_self_attention_test( - attn, - packed_query, - packed_key, - packed_value, - None, - attn_metadata, - attn_type=AttentionType.ENCODER) - - # - Is encoder attention result correct? - assert torch.allclose(packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) - - -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) -@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, max_kv_seq_len: int, + block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -790,25 +697,37 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( block_size, backend_name) - # Self-attention setup + # Encoder attention setup - self_block_base_addr = 0 + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + enc_packed_query, \ + enc_packed_key, \ + enc_packed_value, \ + enc_packed_ideal_output, \ + encoder_seq_lens = _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) + + # Decoder self-attention setup query, \ prefill_packed_query, \ self_prefill_packed_key, \ self_prefill_packed_value, \ self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ + prefill_decoder_seq_lens, \ + self_prefill_encoder_seq_lens, \ decode_packed_query, \ self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ + decoder_seq_lens, \ self_decode_block_tables, \ self_decode_slot_mapping, \ self_prefill_slot_mapping, \ @@ -818,8 +737,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( head_size, block_size, scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) + max_dec_seq_len) # Cross-attention setup @@ -827,47 +745,68 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = _enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests + = _enc_dec_cross_attn_setup_reuses_query(query, + decoder_seq_lens, + encoder_seq_lens, + prefill_decoder_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_dec_seq_len, + max_enc_seq_len, + block_base_addr = \ + cross_block_base_addr) + + # Shared prefill metadata structure + # - prefill_attn_metadata: AttentionMetadata = make_test_metadata( + enc_and_dec_prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - prefill_q_seq_lens, + prefill_decoder_seq_lens, self_prefill_block_tables, self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, + is_encoder_only_test=True, + num_prefills_or_decodes=len(prefill_decoder_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_decoder_seq_lens), + encoder_seq_lens=encoder_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, device=CUDA_DEVICE) + # PREFILL: encoder attention + # * Use prefill kernel + + enc_packed_actual_output: torch.Tensor = \ + _run_encoder_attention_test( + attn, + enc_packed_query, + enc_packed_key, + enc_packed_value, + enc_and_dec_prefill_attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + assert torch.allclose( + enc_packed_ideal_output, + enc_packed_actual_output.view_as(enc_packed_ideal_output)) + + # PREFILL: self-attention test + self_prefill_packed_actual_output: torch.Tensor = \ - _run_encoder_or_decoder_self_attention_test( + _run_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, self_prefill_packed_value, kv_cache, - prefill_attn_metadata, + enc_and_dec_prefill_attn_metadata, attn_type=AttentionType.DECODER) # - Prefill self-attention correct? @@ -876,10 +815,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output.view_as( self_prefill_packed_ideal_output)) + # PREFILL: cross-attention test + cross_prefill_packed_actual_output: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + cross_prefill_packed_value, kv_cache, enc_and_dec_prefill_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( @@ -887,27 +828,29 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_prefill_packed_actual_output.view_as( cross_prefill_packed_ideal_output)) - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + context_lens = copy.deepcopy(self_prefill_encoder_seq_lens) - # DECODE: self- and cross-attention tests + # DECODE: build decode-phase attention metadata decode_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, - q_seq_lens, + decoder_seq_lens, self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, context_lens=context_lens, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=len(q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, + num_prefills_or_decodes=len(decoder_seq_lens), + num_prefill_or_decode_tokens=len(decoder_seq_lens), + encoder_seq_lens=encoder_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, device=CUDA_DEVICE) + # DECODE: self-attention test + self_decode_packed_actual_output: torch.Tensor = \ - _run_encoder_or_decoder_self_attention_test( + _run_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -922,6 +865,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_actual_output.view_as( self_decode_packed_ideal_output)) + # DECODE: cross-attention test + cross_decode_packed_actual_output: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, @@ -1028,10 +973,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ - _, \ - _, \ q_seq_lens, \ - _, \ self_decode_block_tables, \ self_decode_slot_mapping, \ self_prefill_slot_mapping, \ @@ -1055,17 +997,18 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = _enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) + = _enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr = \ + cross_block_base_addr) # PREFILL: self- and cross-attention tests diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ddeafbc654073..14e9549287604 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -94,10 +94,11 @@ def ref_masked_attention(query: torch.Tensor, def make_qkv( batch_size: int, max_q_seq_len: int, - max_kv_seq_len: int, + max_kv_seq_len: Optional[int], num_heads: int, head_size: int, device: Union[torch.device, str], + force_kv_seq_lens: List[int] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, ) -> tuple: @@ -126,6 +127,8 @@ def make_qkv( key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_seq_len is unused) + * force_kv_seq_lens: if not None, overrides kv sequence lengths + * attn_type: encoder, decoder self, or enc/dec cross attention * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False @@ -146,7 +149,8 @@ def make_qkv( * decode_key: batch_size x 1 x num_heads x head_size * decode_value: batch_size x 1 x num_heads x head_size * q_seq_lens: "baseline" query seqlen list - * kv_seq_lens: "baseline" key/value seqlen list + * kv_seq_lens: "baseline" key/value seqlen list; overridden by non-None + force_encoder_kv_seq_lens * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= max_q_seq_len due to randomness) * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may @@ -164,7 +168,9 @@ def make_qkv( random.randint(2, max_q_seq_len) for _ in range(batch_size) ] kv_seq_lens = None - if attn_type != AttentionType.ENCODER_DECODER: + if force_kv_seq_lens is not None: + kv_seq_lens = force_kv_seq_lens + elif attn_type != AttentionType.ENCODER_DECODER: # K,V seq lens match Q for self-attention kv_seq_lens = q_seq_lens else: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6948bfea33ce3..cc240242f7eae 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -515,8 +515,18 @@ def forward( self.kv_cache_dtype, kv_scale) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + if attn_metadata.attention_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + num_prefill_tokens = query.shape[0] + num_decode_tokens = 0 if attn_type != AttentionType.ENCODER_DECODER: # Only enforce this shape-constraint for decoder From 20b95b00b4e85516ee077c08ba5fbe2c5dbd2811 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 22:24:49 -0400 Subject: [PATCH 168/443] slight refactoring --- tests/kernels/test_encoder_decoder_attn.py | 8 ++------ tests/kernels/utils.py | 2 -- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 7d9d6d32c0c40..6f91f055a80bc 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,6 +12,8 @@ import pytest import torch +import collections + from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, make_kv_cache, @@ -135,8 +137,6 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, _, \ _, \ _, \ - _, \ - _, \ _ = make_qkv(batch_size, max_q_seq_len, max_kv_seq_len, @@ -268,8 +268,6 @@ def _decoder_attn_setup(batch_size: int, decode_value, \ q_seq_lens, \ kv_seq_lens, \ - _, \ - _, \ prefill_q_seq_lens, \ prefill_kv_seq_lens, \ decode_q_seq_lens, \ @@ -461,8 +459,6 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, _, \ _, \ _, \ - _, \ - _, \ _ = make_qkv(batch_size, max_decoder_seq_len, max_encoder_seq_len, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 14e9549287604..c36a969149e73 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -243,8 +243,6 @@ def make_qkv( decode_value, \ q_seq_lens, \ kv_seq_lens, \ - actual_max_q_seq_len, \ - actual_max_kv_seq_len, \ prefill_q_seq_lens, \ prefill_kv_seq_lens, \ decode_q_seq_lens, \ From 6d52d606e7eaf650cbff541e691e9e831cab63cc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:33:14 -0400 Subject: [PATCH 169/443] QKVInputs and PackedQKVInputs named tuple integration to simplify test logic --- tests/kernels/test_encoder_decoder_attn.py | 695 ++++++++++----------- tests/kernels/utils.py | 114 ++-- 2 files changed, 378 insertions(+), 431 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 6f91f055a80bc..431fe42cd6efe 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -20,7 +20,8 @@ make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, pack_tensor, ref_masked_attention, - split_slot_mapping) + split_slot_mapping, QKVInputs, + PackedQKVInputs) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( @@ -85,7 +86,8 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, - scale: float, max_q_seq_len: int) -> tuple: + scale: float, max_q_seq_len: int) \ + -> tuple[PackedQKVInputs,torch.Tensor]: ''' Set up test vectors & data structures for encoder attention test. @@ -122,53 +124,33 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, ''' max_kv_seq_len = max_q_seq_len - - query, \ - key, \ - value, \ - _, \ - _, \ - _, \ - _, \ - _, \ - _, \ - q_seq_lens, \ - kv_seq_lens, \ - _, \ - _, \ - _, \ - _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) + + qkv_in, _, _ = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) # No causal attention mask - ideal_output = ref_masked_attention(query, - key, - value, + ideal_output = ref_masked_attention(qkv_in.query, + qkv_in.key, + qkv_in.value, scale=scale, - q_seq_lens=q_seq_lens, - kv_seq_lens=kv_seq_lens) + q_seq_lens=qkv_in.q_seq_lens, + kv_seq_lens=qkv_in.kv_seq_lens) packed_ideal_output, _ = pack_tensor(ideal_output, - q_seq_lens, + qkv_in.q_seq_lens, device=CUDA_DEVICE) - packed_query, \ - packed_key, \ - packed_value, _, _ = pack_qkv( - query, key, value, q_seq_lens, - kv_seq_lens, + packed_qkv = pack_qkv( + qkv_in, device=CUDA_DEVICE) - return packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - q_seq_lens + return packed_qkv, \ + packed_ideal_output def _decoder_attn_setup(batch_size: int, @@ -257,49 +239,37 @@ def _decoder_attn_setup(batch_size: int, max_kv_seq_len = max_q_seq_len - query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_seq_lens, \ - kv_seq_lens, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) + qkv, \ + prefill_qkv, \ + decode_qkv = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) causal_mask = make_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention(query, - key, - value, + ideal_output = ref_masked_attention(qkv.query, + qkv.key, + qkv.value, scale=scale, custom_mask=causal_mask, - q_seq_lens=q_seq_lens, - kv_seq_lens=kv_seq_lens) + q_seq_lens=qkv.q_seq_lens, + kv_seq_lens=qkv.kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ bdx, :prefill_q_seq_len] decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens, + prefill_qkv.q_seq_lens, device=CUDA_DEVICE) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)], @@ -327,62 +297,43 @@ def _decoder_attn_setup(batch_size: int, decode_block_tables, \ slot_mapping_list, \ max_block_idx = make_block_tables_slot_mapping(block_size, - q_seq_lens, + qkv.q_seq_lens, device=CUDA_DEVICE, block_base_addr = block_base_addr) prefill_slot_mapping, \ decode_slot_mapping = split_slot_mapping(slot_mapping_list, - q_seq_lens, + qkv.q_seq_lens, device=CUDA_DEVICE) - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens, - device=CUDA_DEVICE) + prefill_pckd_qkv = pack_qkv(prefill_qkv, + device=CUDA_DEVICE) - decode_packed_query, \ - decode_packed_key, \ - decode_packed_value, \ - _, \ - _ = pack_qkv( - decode_query, decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens, - device=CUDA_DEVICE) + decode_pckd_qkv = pack_qkv(decode_qkv, + device=CUDA_DEVICE) - return query, \ - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, \ + return qkv, \ + prefill_pckd_qkv, \ prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_packed_query, \ - decode_packed_key, \ - decode_packed_value, \ + decode_pckd_qkv, \ decode_packed_ideal_output, \ - q_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ prefill_block_tables, \ max_block_idx - -def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, - decoder_seq_lens: List[int], - encoder_seq_lens: Optional[List[int]], - prefill_q_seq_lens: List, - batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_decoder_seq_len: int, - max_encoder_seq_len: int, - block_base_addr: Optional[int]=0) \ +def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, + encoder_packed_qkv: PackedQKVInputs, + prefill_phase_decoder_packed_qkv: PackedQKVInputs, + batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_decoder_seq_len: int, + max_encoder_seq_len: int, + block_base_addr: Optional[int]=0) \ -> tuple: ''' Set up test vectors & data structures for cross-attention test. @@ -445,19 +396,13 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, * max_block_idx: highest block address in the cross-attention block-table ''' - _, \ - key, \ - value, \ - _, \ - _, \ - _, \ - _, \ - _, \ - _, \ - _, \ - kv_seq_lens, \ - _, \ - _, \ + decoder_query = decoder_qkv.query + decoder_seq_lens = decoder_qkv.q_seq_lens + encoder_seq_lens = encoder_packed_qkv.q_seq_lens + prefill_q_seq_lens = prefill_phase_decoder_packed_qkv.q_seq_lens + + + cross_kv, \ _, \ _ = make_qkv(batch_size, max_decoder_seq_len, @@ -468,12 +413,12 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, attn_type=AttentionType.ENCODER_DECODER, device=CUDA_DEVICE) - ideal_output = ref_masked_attention(query, - key, - value, + ideal_output = ref_masked_attention(decoder_query, + cross_kv.key, + cross_kv.value, scale=scale, q_seq_lens=decoder_seq_lens, - kv_seq_lens=kv_seq_lens) + kv_seq_lens=cross_kv.kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) @@ -520,9 +465,9 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, decode_block_tables, \ prefill_slot_mapping_list, \ - max_block_idx = make_block_tables_slot_mapping( + _ = make_block_tables_slot_mapping( block_size, - kv_seq_lens, + cross_kv.kv_seq_lens, block_base_addr=block_base_addr, device=CUDA_DEVICE) @@ -530,26 +475,20 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, device=CUDA_DEVICE) # Packed key/value (query is already provided) - _, packed_key, packed_value, _, _ = pack_qkv(None, - key, - value, - None, - kv_seq_lens, - device=CUDA_DEVICE) + packed_cross_kv = pack_qkv(cross_kv, + device=CUDA_DEVICE) - return packed_key, \ - packed_value, \ - prefill_packed_ideal_output, \ - decode_packed_ideal_output, \ - decode_block_tables, \ - decode_slot_mapping, \ - prefill_slot_mapping, \ - prefill_block_tables + return packed_cross_kv, \ + prefill_packed_ideal_output, \ + decode_packed_ideal_output, \ + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables -def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, +def _run_encoder_attention_test(attn: Attention, + pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -563,7 +502,7 @@ def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, Arguments: * attn: Attention wrapper instance - * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * pckd_qkv: Packed query/key/value inputs * attn_metadata: attention metadata for encoder/decoder-self attention * attn_type: AttentionType.DECODER or AttentionType.ENCODER @@ -574,14 +513,15 @@ def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type - return attn.forward(packed_query, packed_key, packed_value, None, + return attn.forward(pckd_qkv.query, + pckd_qkv.key, + pckd_qkv.value, + None, attn_metadata) def _run_decoder_self_attention_test(attn: Attention, - packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, + pckd_qkv: PackedQKVInputs, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: @@ -596,7 +536,7 @@ def _run_decoder_self_attention_test(attn: Attention, Arguments: * attn: Attention wrapper instance - * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * pckd_qkv: Packed query/key/value inputs * kv_cache * attn_metadata: attention metadata for encoder/decoder-self attention * attn_type: AttentionType.DECODER or AttentionType.ENCODER @@ -607,13 +547,18 @@ def _run_decoder_self_attention_test(attn: Attention, ''' assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type - return attn.forward(packed_query, packed_key, packed_value, kv_cache, + return attn.forward(pckd_qkv.query, + pckd_qkv.key, + pckd_qkv.value, + kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( - attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn: Attention, + dec_pckd_qkv: PackedQKVInputs, + cross_pckd_qkv: PackedQKVInputs, + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -634,7 +579,14 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER - return attn.forward(packed_query, packed_key, packed_value, kv_cache, + key = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.key + value = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.value + return attn.forward(dec_pckd_qkv.query, + key, + value, + kv_cache, attn_metadata) @@ -700,34 +652,33 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # tensors are not actually utilized by encoder attention # anyway but are required to be present & valid by the # backend. - enc_packed_query, \ - enc_packed_key, \ - enc_packed_value, \ - enc_packed_ideal_output, \ - encoder_seq_lens = _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + + # encoder_packed_query, \ + # enc_packed_key, \ + # enc_packed_value, \ + # encoder_packed_ideal_output, \ + # encoder_seq_lens = + + + enc_pckd_qkv, \ + enc_pckd_idl_out = \ + _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) # Decoder self-attention setup - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_decoder_seq_lens, \ - self_prefill_encoder_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - decoder_seq_lens, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ + dec_qkv, \ + prephase_dec_pckd_qkv, \ + prephase_dec_pckd_idl_out, \ + decphase_dec_pckd_qkv, \ + decphase_dec_pckd_idl_out, \ + decphase_dec_blk_tbls, \ + decphase_dec_slt_map, \ + prephase_dec_slt_map, \ + prephase_dec_blk_tbls, \ cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, @@ -737,18 +688,16 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Cross-attention setup - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ + prephase_cross_pckd_qkv, \ + prephase_cross_pckd_idl_out, \ + decphase_cross_pckd_idl_out, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - = _enc_dec_cross_attn_setup_reuses_query(query, - decoder_seq_lens, - encoder_seq_lens, - prefill_decoder_seq_lens, + = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, + enc_pckd_qkv, + prephase_dec_pckd_qkv, batch_size, num_heads, head_size, @@ -762,16 +711,16 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Shared prefill metadata structure # - enc_and_dec_prefill_attn_metadata: AttentionMetadata = make_test_metadata( + prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - prefill_decoder_seq_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=True, - num_prefills_or_decodes=len(prefill_decoder_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_decoder_seq_lens), - encoder_seq_lens=encoder_seq_lens, + prephase_dec_pckd_qkv.q_seq_lens, + prephase_dec_blk_tbls, + prephase_dec_slt_map, + default_attn_type=AttentionType.ENCODER, + num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), + num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), + encoder_seq_lens=enc_pckd_qkv.q_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, device=CUDA_DEVICE) @@ -782,97 +731,99 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( attn, - enc_packed_query, - enc_packed_key, - enc_packed_value, - enc_and_dec_prefill_attn_metadata, + enc_pckd_qkv, + prephase_attn_metadata, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? assert torch.allclose( - enc_packed_ideal_output, - enc_packed_actual_output.view_as(enc_packed_ideal_output)) + enc_pckd_idl_out, + enc_packed_actual_output.view_as(enc_pckd_idl_out)) # PREFILL: self-attention test self_prefill_packed_actual_output: torch.Tensor = \ _run_decoder_self_attention_test( attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, + prephase_dec_pckd_qkv, kv_cache, - enc_and_dec_prefill_attn_metadata, + prephase_attn_metadata, attn_type=AttentionType.DECODER) # - Prefill self-attention correct? assert torch.allclose( - self_prefill_packed_ideal_output, + prephase_dec_pckd_idl_out, self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) + prephase_dec_pckd_idl_out)) # PREFILL: cross-attention test - cross_prefill_packed_actual_output: torch.Tensor = \ + prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, enc_and_dec_prefill_attn_metadata) + attn, + prephase_dec_pckd_qkv, + prephase_cross_pckd_qkv, + kv_cache, + prephase_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_encoder_seq_lens) + prephase_cross_pckd_idl_out, + prephase_cross_pckd_act_out.view_as( + prephase_cross_pckd_idl_out)) # DECODE: build decode-phase attention metadata - decode_attn_metadata: AttentionMetadata = make_test_metadata( + # - Cross-attention KV context is equal in length to + # encoder input + context_lens = copy.deepcopy(enc_pckd_qkv.q_seq_lens) + + decphase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, - decoder_seq_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, + dec_qkv.q_seq_lens, + decphase_dec_blk_tbls, + decphase_dec_slt_map, + default_attn_type=AttentionType.DECODER, context_lens=context_lens, - num_prefills_or_decodes=len(decoder_seq_lens), - num_prefill_or_decode_tokens=len(decoder_seq_lens), - encoder_seq_lens=encoder_seq_lens, + num_prefills_or_decodes=len(dec_qkv.q_seq_lens), + num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), + encoder_seq_lens=enc_pckd_qkv.q_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, device=CUDA_DEVICE) # DECODE: self-attention test - self_decode_packed_actual_output: torch.Tensor = \ + decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, + decphase_dec_pckd_qkv, kv_cache, - decode_attn_metadata, + decphase_attn_metadata, attn_type=AttentionType.DECODER) # - Decode self-attention correct? assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) + decphase_dec_pckd_idl_out, + decphase_dec_pckd_act_out.view_as( + decphase_dec_pckd_idl_out)) # DECODE: cross-attention test - cross_decode_packed_actual_output: torch.Tensor = \ + decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, - None, kv_cache, decode_attn_metadata) + attn, + decphase_dec_pckd_qkv, + None, + kv_cache, + decphase_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) + decphase_cross_pckd_idl_out, + decphase_cross_pckd_act_out.view_as( + decphase_cross_pckd_idl_out)) # The following test conditions could in principle be a # standalone test, however the test setup is @@ -889,145 +840,147 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # already; the line below sets up a chunked prefill # metadata configuration where there is nominally a mix # of prefill and decode tokens. - decode_attn_metadata.num_prefill_tokens = 1 + decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) + _run_encoder_decoder_cross_attention_test(attn, + decphase_dec_pckd_qkv, + None, + kv_cache, + decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL -@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -@pytest.mark.parametrize("num_heads", [256]) -@pytest.mark.parametrize("head_size", [16]) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("max_q_seq_len", [64]) -@pytest.mark.parametrize("max_kv_seq_len", [64]) -def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, - max_kv_seq_len: int, monkeypatch) -> None: - ''' - Encoder/decoder not-implemented-for-ROCm-HIP test: - - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order +# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +# @pytest.mark.parametrize("num_heads", [256]) +# @pytest.mark.parametrize("head_size", [16]) +# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) +# @pytest.mark.parametrize("batch_size", [16]) +# @pytest.mark.parametrize("block_size", [16]) +# @pytest.mark.parametrize("max_q_seq_len", [64]) +# @pytest.mark.parametrize("max_kv_seq_len", [64]) +# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, +# backend_name: str, batch_size: int, +# block_size: int, max_q_seq_len: int, +# max_kv_seq_len: int, monkeypatch) -> None: +# ''' +# Encoder/decoder not-implemented-for-ROCm-HIP test: + +# * Construct fake test vectors for self- and cross-attention +# * Construct attention metadata structure with self- and cross-attention +# attributes +# * Test self- and cross-attention in the following order - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation - - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = _basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - q_seq_lens, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = _decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - = _enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr = \ - cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - device=CUDA_DEVICE) - - with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) - - # "Encoder decoder models do not currently support ROCm/HIP" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP +# * Prefill self-attention +# * Prefill cross-attention +# * Decode self-attention +# * Decode cross-attention +# * This order would exacerbate any accidental overlap in the +# self-/cross-attention block tables, which we attempt to avoid +# * Validate output correctness against ideal reference attention +# implementation + +# Block tables are constructed such that cross-attention KV cache is in a +# higher, non-intersecting address-space than self-attention KV cache. + +# Self- and cross-attention share the same query tensor but not the K/V +# tensors. Self-attention K/Vs must have the same seq len as Q while +# cross-attention K/Vs are allowed to differ in seq len, as is often the case +# for cross-attention. +# ''' + +# # Force Attention wrapper backend +# override_backend_env_variable(monkeypatch, backend_name) + +# # Num KV cache blocks +# num_blocks = 4096 + +# # Attention scale factor, attention backend instance, attention wrapper +# # instance, KV cache init +# scale, \ +# attn_backend, \ +# attn, \ +# kv_cache = _basic_setup(num_heads, +# head_size, +# num_blocks, +# block_size, +# backend_name) + +# # Self-attention setup + +# self_block_base_addr = 0 + +# query, \ +# prefill_packed_query, \ +# self_prefill_packed_key, \ +# self_prefill_packed_value, \ +# self_prefill_packed_ideal_output, \ +# prefill_q_seq_lens, \ +# self_prefill_kv_seq_lens, \ +# decode_packed_query, \ +# self_decode_packed_key, \ +# self_decode_packed_value, \ +# self_decode_packed_ideal_output, \ +# q_seq_lens, \ +# self_decode_block_tables, \ +# self_decode_slot_mapping, \ +# self_prefill_slot_mapping, \ +# self_prefill_block_tables, \ +# cross_block_base_addr = _decoder_attn_setup(batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# block_base_addr=self_block_base_addr) + +# # Cross-attention setup + +# cross_prefill_packed_key, \ +# cross_prefill_packed_value, \ +# cross_prefill_packed_ideal_output, \ +# cross_decode_packed_ideal_output, \ +# encoder_kv_seq_lens, \ +# cross_decode_block_tables, \ +# cross_decode_slot_mapping, \ +# cross_prefill_slot_mapping, \ +# cross_prefill_block_tables, \ +# = _enc_dec_cross_attn_setup_reuses_query(query, +# q_seq_lens, +# prefill_q_seq_lens, +# batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# max_kv_seq_len, +# block_base_addr = \ +# cross_block_base_addr) + +# # PREFILL: self- and cross-attention tests + +# prefill_attn_metadata: AttentionMetadata = make_test_metadata( +# attn_backend, +# True, +# prefill_q_seq_lens, +# self_prefill_block_tables, +# self_prefill_slot_mapping, +# is_encoder_only_test=False, +# num_prefills_or_decodes=len(prefill_q_seq_lens), +# num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), +# encoder_seq_lens=encoder_kv_seq_lens, +# cross_block_tables=cross_prefill_block_tables, +# cross_slot_mapping=cross_prefill_slot_mapping, +# device=CUDA_DEVICE) + +# with pytest.raises(NotImplementedError) as exc_info: +# _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, +# cross_prefill_packed_key, +# cross_prefill_packed_value, +# kv_cache, +# prefill_attn_metadata) + +# # "Encoder decoder models do not currently support ROCm/HIP" +# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c36a969149e73..142bc64d34b99 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -7,6 +7,8 @@ import pytest import torch +from collections import namedtuple + from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend @@ -90,6 +92,23 @@ def ref_masked_attention(query: torch.Tensor, out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) return out +# batch_size x max_q_seq_len x num_heads x head_size +QKVInputs = namedtuple("QKVInputs", + ["query", + "key", + "value", + "q_seq_lens", + "kv_seq_lens"]) + +# total_num_tokens x (num_heads*head_size) +PackedQKVInputs = namedtuple("PackedQKVInputs", + ["query", + "key", + "value", + "q_start_loc_list", + "kv_start_loc_list", + "q_seq_lens", + "kv_seq_lens"]) def make_qkv( batch_size: int, @@ -101,7 +120,7 @@ def make_qkv( force_kv_seq_lens: List[int] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> tuple: +) -> tuple[QKVInputs,QKVInputs,QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -136,29 +155,7 @@ def make_qkv( Returns: - * query: "baseline" query; batch_size x max_q_seq_len x num_heads x - head_size - * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x - head_size - * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x - head_size - * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size - * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * decode_query: batch_size x 1 x num_heads x head_size - * decode_key: batch_size x 1 x num_heads x head_size - * decode_value: batch_size x 1 x num_heads x head_size - * q_seq_lens: "baseline" query seqlen list - * kv_seq_lens: "baseline" key/value seqlen list; overridden by non-None - force_encoder_kv_seq_lens - * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= - max_q_seq_len due to randomness) - * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may - be <= max_kv_seq_len due to randomness) - * prefill_q_seq_lens: "prefill" query seqlen list - * prefill_kv_seq_lens: "prefill" key/value seqlen list - * decode_q_seq_lens: "decode" query seqlen list (all ones) - * decode_kv_seq_lens: "decode" key/value seqlen list + * QKVInputs structure ''' if force_max_len: @@ -182,9 +179,6 @@ def make_qkv( random.randint(2, max_kv_seq_len) for _ in range(batch_size) ] - actual_max_q_seq_len = max(q_seq_lens) - actual_max_kv_seq_len = max(kv_seq_lens) - query = torch.rand( (batch_size, max_q_seq_len, num_heads, head_size)).to(device) key = torch.rand( @@ -232,21 +226,22 @@ def make_qkv( decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - return query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_seq_lens, \ - kv_seq_lens, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens + return QKVInputs(query, + key, + value, + q_seq_lens, + kv_seq_lens), \ + QKVInputs(prefill_query, + prefill_key, + prefill_value, + prefill_q_seq_lens, + prefill_kv_seq_lens), \ + QKVInputs( + decode_query, + decode_key, + decode_value, + decode_q_seq_lens, + decode_kv_seq_lens) def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], @@ -283,9 +278,8 @@ def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], return packed_tensor, start_loc_list -def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - q_seq_lens: List[int], kv_seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: +def pack_qkv(qkv: QKVInputs, + device: Union[torch.device, str]) -> PackedQKVInputs: ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x @@ -312,22 +306,25 @@ def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} ''' - if query is None: + if qkv.query is None: packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(query, - q_seq_lens, + packed_query, q_start_loc_list = pack_tensor(qkv.query, + qkv.q_seq_lens, device=device) - packed_key, kv_start_loc_list = pack_tensor(key, - kv_seq_lens, + packed_key, kv_start_loc_list = pack_tensor(qkv.key, + qkv.kv_seq_lens, device=device) - packed_value, _ = pack_tensor(value, kv_seq_lens, device=device) - return packed_query, \ - packed_key, \ - packed_value, \ - q_start_loc_list, \ - kv_start_loc_list + packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) + return PackedQKVInputs(packed_query, \ + packed_key, \ + packed_value, \ + q_start_loc_list, \ + kv_start_loc_list, \ + None if q_start_loc_list is None else \ + qkv.q_seq_lens, \ + qkv.kv_seq_lens) def make_backend(backend_name: str) -> AttentionBackend: @@ -589,7 +586,7 @@ def make_test_metadata( seq_lens: List[int], block_tables: torch.Tensor, slot_mapping: torch.Tensor, - is_encoder_only_test: bool, + default_attn_type: AttentionType, num_prefills_or_decodes: int, num_prefill_or_decode_tokens: int, device: Union[torch.device, str], @@ -631,9 +628,6 @@ def make_test_metadata( * AttentionMetadata structure supporting self- and cross-attention ''' - default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ - else AttentionType.DECODER - if is_prompt: num_prefills = num_prefills_or_decodes num_prefill_tokens = num_prefill_or_decode_tokens From a81712b14126ff7563ed0e6ce145a8401c0efaec Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:36:49 -0400 Subject: [PATCH 170/443] refactoring --- tests/kernels/test_encoder_decoder_attn.py | 246 ++++----------------- tests/kernels/utils.py | 31 +-- 2 files changed, 56 insertions(+), 221 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 431fe42cd6efe..955fa1629440f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -7,21 +7,19 @@ """ import copy -from typing import List, Optional +from typing import Optional import pytest import torch -import collections - -from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, +from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, + make_block_tables_slot_mapping, make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, pack_tensor, ref_masked_attention, - split_slot_mapping, QKVInputs, - PackedQKVInputs) + split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( @@ -124,14 +122,14 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, ''' max_kv_seq_len = max_q_seq_len - + qkv_in, _, _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) # No causal attention mask ideal_output = ref_masked_attention(qkv_in.query, @@ -145,9 +143,7 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, qkv_in.q_seq_lens, device=CUDA_DEVICE) - packed_qkv = pack_qkv( - qkv_in, - device=CUDA_DEVICE) + packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) return packed_qkv, \ packed_ideal_output @@ -306,11 +302,9 @@ def _decoder_attn_setup(batch_size: int, qkv.q_seq_lens, device=CUDA_DEVICE) - prefill_pckd_qkv = pack_qkv(prefill_qkv, - device=CUDA_DEVICE) + prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) - decode_pckd_qkv = pack_qkv(decode_qkv, - device=CUDA_DEVICE) + decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) return qkv, \ prefill_pckd_qkv, \ @@ -324,8 +318,10 @@ def _decoder_attn_setup(batch_size: int, max_block_idx def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, - encoder_packed_qkv: PackedQKVInputs, - prefill_phase_decoder_packed_qkv: PackedQKVInputs, + encoder_packed_qkv: + PackedQKVInputs, + prefill_phase_decoder_packed_qkv: + PackedQKVInputs, batch_size: int, num_heads: int, head_size: int, @@ -475,8 +471,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, device=CUDA_DEVICE) # Packed key/value (query is already provided) - packed_cross_kv = pack_qkv(cross_kv, - device=CUDA_DEVICE) + packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) return packed_cross_kv, \ prefill_packed_ideal_output, \ @@ -487,8 +482,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, prefill_block_tables -def _run_encoder_attention_test(attn: Attention, - pckd_qkv: PackedQKVInputs, +def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -513,10 +507,7 @@ def _run_encoder_attention_test(attn: Attention, assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, - pckd_qkv.key, - pckd_qkv.value, - None, + return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, None, attn_metadata) @@ -547,18 +538,13 @@ def _run_decoder_self_attention_test(attn: Attention, ''' assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, - pckd_qkv.key, - pckd_qkv.value, - kv_cache, + return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( - attn: Attention, - dec_pckd_qkv: PackedQKVInputs, - cross_pckd_qkv: PackedQKVInputs, - kv_cache: torch.Tensor, + attn: Attention, dec_pckd_qkv: PackedQKVInputs, + cross_pckd_qkv: PackedQKVInputs, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -583,10 +569,7 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(dec_pckd_qkv.query, - key, - value, - kv_cache, + return attn.forward(dec_pckd_qkv.query, key, value, kv_cache, attn_metadata) @@ -657,9 +640,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # enc_packed_key, \ # enc_packed_value, \ # encoder_packed_ideal_output, \ - # encoder_seq_lens = - - + # encoder_seq_lens = + + enc_pckd_qkv, \ enc_pckd_idl_out = \ _encoder_attn_setup(batch_size, @@ -736,9 +719,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - enc_pckd_idl_out, - enc_packed_actual_output.view_as(enc_pckd_idl_out)) + assert torch.allclose(enc_pckd_idl_out, + enc_packed_actual_output.view_as(enc_pckd_idl_out)) # PREFILL: self-attention test @@ -753,24 +735,22 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Prefill self-attention correct? assert torch.allclose( prephase_dec_pckd_idl_out, - self_prefill_packed_actual_output.view_as( - prephase_dec_pckd_idl_out)) + self_prefill_packed_actual_output.view_as(prephase_dec_pckd_idl_out)) # PREFILL: cross-attention test prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, - prephase_dec_pckd_qkv, - prephase_cross_pckd_qkv, - kv_cache, + attn, + prephase_dec_pckd_qkv, + prephase_cross_pckd_qkv, + kv_cache, prephase_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( prephase_cross_pckd_idl_out, - prephase_cross_pckd_act_out.view_as( - prephase_cross_pckd_idl_out)) + prephase_cross_pckd_act_out.view_as(prephase_cross_pckd_idl_out)) # DECODE: build decode-phase attention metadata @@ -806,24 +786,22 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Decode self-attention correct? assert torch.allclose( decphase_dec_pckd_idl_out, - decphase_dec_pckd_act_out.view_as( - decphase_dec_pckd_idl_out)) + decphase_dec_pckd_act_out.view_as(decphase_dec_pckd_idl_out)) # DECODE: cross-attention test decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, + attn, decphase_dec_pckd_qkv, - None, - kv_cache, + None, + kv_cache, decphase_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( decphase_cross_pckd_idl_out, - decphase_cross_pckd_act_out.view_as( - decphase_cross_pckd_idl_out)) + decphase_cross_pckd_act_out.view_as(decphase_cross_pckd_idl_out)) # The following test conditions could in principle be a # standalone test, however the test setup is @@ -842,145 +820,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, - decphase_dec_pckd_qkv, - None, - kv_cache, + _run_encoder_decoder_cross_attention_test(attn, decphase_dec_pckd_qkv, + None, kv_cache, decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - -# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -# @pytest.mark.parametrize("num_heads", [256]) -# @pytest.mark.parametrize("head_size", [16]) -# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) -# @pytest.mark.parametrize("batch_size", [16]) -# @pytest.mark.parametrize("block_size", [16]) -# @pytest.mark.parametrize("max_q_seq_len", [64]) -# @pytest.mark.parametrize("max_kv_seq_len", [64]) -# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, -# backend_name: str, batch_size: int, -# block_size: int, max_q_seq_len: int, -# max_kv_seq_len: int, monkeypatch) -> None: -# ''' -# Encoder/decoder not-implemented-for-ROCm-HIP test: - -# * Construct fake test vectors for self- and cross-attention -# * Construct attention metadata structure with self- and cross-attention -# attributes -# * Test self- and cross-attention in the following order - -# * Prefill self-attention -# * Prefill cross-attention -# * Decode self-attention -# * Decode cross-attention -# * This order would exacerbate any accidental overlap in the -# self-/cross-attention block tables, which we attempt to avoid -# * Validate output correctness against ideal reference attention -# implementation - -# Block tables are constructed such that cross-attention KV cache is in a -# higher, non-intersecting address-space than self-attention KV cache. - -# Self- and cross-attention share the same query tensor but not the K/V -# tensors. Self-attention K/Vs must have the same seq len as Q while -# cross-attention K/Vs are allowed to differ in seq len, as is often the case -# for cross-attention. -# ''' - -# # Force Attention wrapper backend -# override_backend_env_variable(monkeypatch, backend_name) - -# # Num KV cache blocks -# num_blocks = 4096 - -# # Attention scale factor, attention backend instance, attention wrapper -# # instance, KV cache init -# scale, \ -# attn_backend, \ -# attn, \ -# kv_cache = _basic_setup(num_heads, -# head_size, -# num_blocks, -# block_size, -# backend_name) - -# # Self-attention setup - -# self_block_base_addr = 0 - -# query, \ -# prefill_packed_query, \ -# self_prefill_packed_key, \ -# self_prefill_packed_value, \ -# self_prefill_packed_ideal_output, \ -# prefill_q_seq_lens, \ -# self_prefill_kv_seq_lens, \ -# decode_packed_query, \ -# self_decode_packed_key, \ -# self_decode_packed_value, \ -# self_decode_packed_ideal_output, \ -# q_seq_lens, \ -# self_decode_block_tables, \ -# self_decode_slot_mapping, \ -# self_prefill_slot_mapping, \ -# self_prefill_block_tables, \ -# cross_block_base_addr = _decoder_attn_setup(batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# block_base_addr=self_block_base_addr) - -# # Cross-attention setup - -# cross_prefill_packed_key, \ -# cross_prefill_packed_value, \ -# cross_prefill_packed_ideal_output, \ -# cross_decode_packed_ideal_output, \ -# encoder_kv_seq_lens, \ -# cross_decode_block_tables, \ -# cross_decode_slot_mapping, \ -# cross_prefill_slot_mapping, \ -# cross_prefill_block_tables, \ -# = _enc_dec_cross_attn_setup_reuses_query(query, -# q_seq_lens, -# prefill_q_seq_lens, -# batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# max_kv_seq_len, -# block_base_addr = \ -# cross_block_base_addr) - -# # PREFILL: self- and cross-attention tests - -# prefill_attn_metadata: AttentionMetadata = make_test_metadata( -# attn_backend, -# True, -# prefill_q_seq_lens, -# self_prefill_block_tables, -# self_prefill_slot_mapping, -# is_encoder_only_test=False, -# num_prefills_or_decodes=len(prefill_q_seq_lens), -# num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), -# encoder_seq_lens=encoder_kv_seq_lens, -# cross_block_tables=cross_prefill_block_tables, -# cross_slot_mapping=cross_prefill_slot_mapping, -# device=CUDA_DEVICE) - -# with pytest.raises(NotImplementedError) as exc_info: -# _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, -# cross_prefill_packed_key, -# cross_prefill_packed_value, -# kv_cache, -# prefill_attn_metadata) - -# # "Encoder decoder models do not currently support ROCm/HIP" -# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 142bc64d34b99..e57d6c412537c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,13 +2,12 @@ import itertools import random +from collections import namedtuple from typing import List, Optional, Union import pytest import torch -from collections import namedtuple - from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend @@ -92,23 +91,17 @@ def ref_masked_attention(query: torch.Tensor, out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) return out + # batch_size x max_q_seq_len x num_heads x head_size -QKVInputs = namedtuple("QKVInputs", - ["query", - "key", - "value", - "q_seq_lens", - "kv_seq_lens"]) +QKVInputs = namedtuple("QKVInputs", + ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) # total_num_tokens x (num_heads*head_size) -PackedQKVInputs = namedtuple("PackedQKVInputs", - ["query", - "key", - "value", - "q_start_loc_list", - "kv_start_loc_list", - "q_seq_lens", - "kv_seq_lens"]) +PackedQKVInputs = namedtuple("PackedQKVInputs", [ + "query", "key", "value", "q_start_loc_list", "kv_start_loc_list", + "q_seq_lens", "kv_seq_lens" +]) + def make_qkv( batch_size: int, @@ -120,7 +113,7 @@ def make_qkv( force_kv_seq_lens: List[int] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> tuple[QKVInputs,QKVInputs,QKVInputs]: +) -> tuple[QKVInputs, QKVInputs, QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -278,8 +271,8 @@ def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], return packed_tensor, start_loc_list -def pack_qkv(qkv: QKVInputs, - device: Union[torch.device, str]) -> PackedQKVInputs: +def pack_qkv(qkv: QKVInputs, device: Union[torch.device, + str]) -> PackedQKVInputs: ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x From d35ea41a5fd8d55c902dcbfeaac505e8362ef227 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:40:26 -0400 Subject: [PATCH 171/443] format --- tests/kernels/test_encoder_decoder_attn.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 955fa1629440f..db0dd3ec9ade8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -636,13 +636,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - # encoder_packed_query, \ - # enc_packed_key, \ - # enc_packed_value, \ - # encoder_packed_ideal_output, \ - # encoder_seq_lens = - - enc_pckd_qkv, \ enc_pckd_idl_out = \ _encoder_attn_setup(batch_size, @@ -692,7 +685,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_block_base_addr) # Shared prefill metadata structure - # prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, @@ -825,4 +817,4 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL \ No newline at end of file + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL From 27782df3ac02df531fe71730000e1ad40e34c975 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:42:56 -0400 Subject: [PATCH 172/443] yapf fix --- tests/kernels/test_encoder_decoder_attn.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index db0dd3ec9ade8..bc475599163c2 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,14 +12,11 @@ import pytest import torch -from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, - make_block_tables_slot_mapping, - make_empty_block_tables_tensor, - make_empty_slot_mapping_tensor, make_kv_cache, - make_qkv, make_test_metadata, - override_backend_env_variable, pack_qkv, - pack_tensor, ref_masked_attention, - split_slot_mapping) +from tests.kernels.utils import ( + PackedQKVInputs, QKVInputs, make_backend, make_block_tables_slot_mapping, + make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, + make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, + pack_qkv, pack_tensor, ref_masked_attention, split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( From c3a2e7afb82448299ffd3caf351d5b3d3b11cfdd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:44:26 -0400 Subject: [PATCH 173/443] import reorg --- tests/kernels/test_encoder_decoder_attn.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index bc475599163c2..db0dd3ec9ade8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,11 +12,14 @@ import pytest import torch -from tests.kernels.utils import ( - PackedQKVInputs, QKVInputs, make_backend, make_block_tables_slot_mapping, - make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, - make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, - pack_qkv, pack_tensor, ref_masked_attention, split_slot_mapping) +from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, + make_block_tables_slot_mapping, + make_empty_block_tables_tensor, + make_empty_slot_mapping_tensor, make_kv_cache, + make_qkv, make_test_metadata, + override_backend_env_variable, pack_qkv, + pack_tensor, ref_masked_attention, + split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( From 8babfda338e5f941054b630cf6973bfb6d56f747 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:47:25 -0400 Subject: [PATCH 174/443] switched to star import to avoid unsatisfiable formatting constraints --- tests/kernels/test_encoder_decoder_attn.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index db0dd3ec9ade8..0b1546905926b 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,14 +12,7 @@ import pytest import torch -from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, - make_block_tables_slot_mapping, - make_empty_block_tables_tensor, - make_empty_slot_mapping_tensor, make_kv_cache, - make_qkv, make_test_metadata, - override_backend_env_variable, pack_qkv, - pack_tensor, ref_masked_attention, - split_slot_mapping) +from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( From ce2422be980c7bac12355a6bf8f4fbac8c471136 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 01:09:37 -0400 Subject: [PATCH 175/443] progress on memory map structure integration --- tests/kernels/test_encoder_decoder_attn.py | 22 ++++++++++------------ tests/kernels/utils.py | 14 ++++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0b1546905926b..3690666371625 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -302,12 +302,14 @@ def _decoder_attn_setup(batch_size: int, return qkv, \ prefill_pckd_qkv, \ prefill_packed_ideal_output, \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping), \ decode_pckd_qkv, \ decode_packed_ideal_output, \ - decode_block_tables, \ - decode_slot_mapping, \ - prefill_slot_mapping, \ - prefill_block_tables, \ + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping), \ max_block_idx def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, @@ -642,12 +644,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( dec_qkv, \ prephase_dec_pckd_qkv, \ prephase_dec_pckd_idl_out, \ + prephase_dec_kv_mmap, \ decphase_dec_pckd_qkv, \ decphase_dec_pckd_idl_out, \ - decphase_dec_blk_tbls, \ - decphase_dec_slt_map, \ - prephase_dec_slt_map, \ - prephase_dec_blk_tbls, \ + decphase_dec_kv_mmap, \ cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, @@ -683,8 +683,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_backend, True, prephase_dec_pckd_qkv.q_seq_lens, - prephase_dec_blk_tbls, - prephase_dec_slt_map, + prephase_dec_kv_mmap, default_attn_type=AttentionType.ENCODER, num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), @@ -747,8 +746,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_backend, False, dec_qkv.q_seq_lens, - decphase_dec_blk_tbls, - decphase_dec_slt_map, + decphase_dec_kv_mmap, default_attn_type=AttentionType.DECODER, context_lens=context_lens, num_prefills_or_decodes=len(dec_qkv.q_seq_lens), diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e57d6c412537c..15baa68c8fb41 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -102,6 +102,9 @@ def ref_masked_attention(query: torch.Tensor, "q_seq_lens", "kv_seq_lens" ]) +KVMemoryMap = namedtuple("KVMemoryMap", [ + "block_tables", "slot_mapping" +]) def make_qkv( batch_size: int, @@ -577,8 +580,7 @@ def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, seq_lens: List[int], - block_tables: torch.Tensor, - slot_mapping: torch.Tensor, + kv_mmap: KVMemoryMap, default_attn_type: AttentionType, num_prefills_or_decodes: int, num_prefill_or_decode_tokens: int, @@ -639,7 +641,7 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=kv_mmap.slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -647,7 +649,7 @@ def make_test_metadata( max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, - block_tables=block_tables, + block_tables=kv_mmap.block_tables, use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, @@ -675,7 +677,7 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=kv_mmap.slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -683,7 +685,7 @@ def make_test_metadata( max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), context_lens_tensor=context_lens_tensor, - block_tables=block_tables, + block_tables=kv_mmap.block_tables, use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, From 00450517ee415b8193c612faae041291ce9daccc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 01:19:50 -0400 Subject: [PATCH 176/443] completed integration of KVMemoryMap into tests --- tests/kernels/test_encoder_decoder_attn.py | 23 ++++++++++------------ tests/kernels/utils.py | 15 ++++++++------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3690666371625..c53eb9ed85970 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -470,12 +470,13 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, return packed_cross_kv, \ prefill_packed_ideal_output, \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping), \ decode_packed_ideal_output, \ - decode_block_tables, \ - decode_slot_mapping, \ - prefill_slot_mapping, \ - prefill_block_tables - + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping), \ def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, @@ -659,11 +660,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_cross_pckd_qkv, \ prephase_cross_pckd_idl_out, \ + prephase_cross_kv_mmap, \ decphase_cross_pckd_idl_out, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ + decphase_cross_kv_mmap \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_pckd_qkv, prephase_dec_pckd_qkv, @@ -688,8 +687,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), encoder_seq_lens=enc_pckd_qkv.q_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, + cross_kv_mmap=prephase_cross_kv_mmap, device=CUDA_DEVICE) # PREFILL: encoder attention @@ -752,8 +750,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_prefills_or_decodes=len(dec_qkv.q_seq_lens), num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), encoder_seq_lens=enc_pckd_qkv.q_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, + cross_kv_mmap=decphase_cross_kv_mmap, device=CUDA_DEVICE) # DECODE: self-attention test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 15baa68c8fb41..d0f5f20e12b1d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -587,8 +587,7 @@ def make_test_metadata( device: Union[torch.device, str], context_lens: Optional[List[int]] = None, encoder_seq_lens: Optional[List[int]] = None, - cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None, + cross_kv_mmap: Optional[KVMemoryMap] = None, ) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention @@ -655,8 +654,10 @@ def make_test_metadata( encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) + cross_slot_mapping=None if cross_kv_mmap is None else \ + cross_kv_mmap.slot_mapping, + cross_block_tables=None if cross_kv_mmap is None else \ + cross_kv_mmap.block_tables) else: # not is_prompt @@ -691,5 +692,7 @@ def make_test_metadata( encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) + cross_slot_mapping=None if cross_kv_mmap is None else \ + cross_kv_mmap.slot_mapping, + cross_block_tables=None if cross_kv_mmap is None else \ + cross_kv_mmap.block_tables) From 91eb0671a960a6f52ca8de22d2c6a604c4ae477e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 01:25:57 -0400 Subject: [PATCH 177/443] first step toward QKVO integration into tests --- tests/kernels/test_encoder_decoder_attn.py | 34 +++++++++++----------- tests/kernels/utils.py | 12 ++++++++ 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c53eb9ed85970..1cdd02ad5aa87 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -78,7 +78,7 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, scale: float, max_q_seq_len: int) \ - -> tuple[PackedQKVInputs,torch.Tensor]: + -> PackedQKVO: ''' Set up test vectors & data structures for encoder attention test. @@ -138,8 +138,9 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) - return packed_qkv, \ - packed_ideal_output + return PackedQKVO( + packed_qkv, \ + packed_ideal_output) def _decoder_attn_setup(batch_size: int, @@ -632,13 +633,11 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_pckd_qkv, \ - enc_pckd_idl_out = \ - _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + enc_pckd_qkvo = _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) # Decoder self-attention setup @@ -664,7 +663,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_cross_pckd_idl_out, \ decphase_cross_kv_mmap \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_pckd_qkv, + enc_pckd_qkvo.packed_qkv, prephase_dec_pckd_qkv, batch_size, num_heads, @@ -686,7 +685,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( default_attn_type=AttentionType.ENCODER, num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkv.q_seq_lens, + encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, cross_kv_mmap=prephase_cross_kv_mmap, device=CUDA_DEVICE) @@ -696,13 +695,14 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( attn, - enc_pckd_qkv, + enc_pckd_qkvo.packed_qkv, prephase_attn_metadata, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose(enc_pckd_idl_out, - enc_packed_actual_output.view_as(enc_pckd_idl_out)) + assert torch.allclose(enc_pckd_qkvo.ideal_output, + enc_packed_actual_output + .view_as(enc_pckd_qkvo.ideal_output)) # PREFILL: self-attention test @@ -738,7 +738,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Cross-attention KV context is equal in length to # encoder input - context_lens = copy.deepcopy(enc_pckd_qkv.q_seq_lens) + context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) decphase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, @@ -749,7 +749,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( context_lens=context_lens, num_prefills_or_decodes=len(dec_qkv.q_seq_lens), num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkv.q_seq_lens, + encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, cross_kv_mmap=decphase_cross_kv_mmap, device=CUDA_DEVICE) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index d0f5f20e12b1d..14ae12a061e55 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -96,12 +96,24 @@ def ref_masked_attention(query: torch.Tensor, QKVInputs = namedtuple("QKVInputs", ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) +QKVO = namedtuple("QKVO", + [ + "qkv", + "ideal_output" + ]) + # total_num_tokens x (num_heads*head_size) PackedQKVInputs = namedtuple("PackedQKVInputs", [ "query", "key", "value", "q_start_loc_list", "kv_start_loc_list", "q_seq_lens", "kv_seq_lens" ]) +PackedQKVO = namedtuple("PackedQKVO", + [ + "packed_qkv", + "ideal_output" + ]) + KVMemoryMap = namedtuple("KVMemoryMap", [ "block_tables", "slot_mapping" ]) From a6aee8002125115ae41dc22a0cbf2a9a210bd5ce Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 15:28:44 -0400 Subject: [PATCH 178/443] wip test params structure integration --- tests/kernels/test_encoder_decoder_attn.py | 107 +++++++++++---------- tests/kernels/utils.py | 5 + 2 files changed, 63 insertions(+), 49 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 1cdd02ad5aa87..0ac2e47660db5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -78,7 +78,7 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, scale: float, max_q_seq_len: int) \ - -> PackedQKVO: + -> PhaseTestParameters: ''' Set up test vectors & data structures for encoder attention test. @@ -138,9 +138,13 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) - return PackedQKVO( - packed_qkv, \ - packed_ideal_output) + return PhaseTestParameters( + PackedQKVO( + packed_qkv, \ + packed_ideal_output), + + None # No KV cache + ) def _decoder_attn_setup(batch_size: int, @@ -149,7 +153,10 @@ def _decoder_attn_setup(batch_size: int, block_size: int, scale: float, max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: + block_base_addr: int = 0) -> tuple[QKVInputs, + PhaseTestParameters, + PhaseTestParameters, + int]: ''' Set up test vectors & data structures for self-attention test. @@ -301,23 +308,27 @@ def _decoder_attn_setup(batch_size: int, decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) return qkv, \ - prefill_pckd_qkv, \ - prefill_packed_ideal_output, \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping), \ - decode_pckd_qkv, \ - decode_packed_ideal_output, \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping), \ - max_block_idx + PhaseTestParameters( # Prefill test params + PackedQKVO( + prefill_pckd_qkv, \ + prefill_packed_ideal_output), \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping)), \ + PhaseTestParameters( # Decode test params + PackedQKVO( + decode_pckd_qkv, \ + decode_packed_ideal_output), \ + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping)), \ + max_block_idx def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, - encoder_packed_qkv: - PackedQKVInputs, - prefill_phase_decoder_packed_qkv: - PackedQKVInputs, + encoder_test_params: + PhaseTestParameters, + prefill_phase_test_params: + PhaseTestParameters, batch_size: int, num_heads: int, head_size: int, @@ -390,8 +401,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens - encoder_seq_lens = encoder_packed_qkv.q_seq_lens - prefill_q_seq_lens = prefill_phase_decoder_packed_qkv.q_seq_lens + encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens cross_kv, \ @@ -469,15 +480,20 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, # Packed key/value (query is already provided) packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) - return packed_cross_kv, \ - prefill_packed_ideal_output, \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping), \ - decode_packed_ideal_output, \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping), \ + return PhaseTestParameters( # Prefill-phase test params + PackedQKVO( + packed_cross_kv, \ + prefill_packed_ideal_output), \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping)), \ + PhaseTestParameters( # Decode-phase test params + PackedQKVO( + None, + decode_packed_ideal_output), \ + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping)) def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, @@ -633,21 +649,17 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_pckd_qkvo = _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + enc_test_params = _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) # Decoder self-attention setup dec_qkv, \ - prephase_dec_pckd_qkv, \ - prephase_dec_pckd_idl_out, \ - prephase_dec_kv_mmap, \ - decphase_dec_pckd_qkv, \ - decphase_dec_pckd_idl_out, \ - decphase_dec_kv_mmap, \ + prephase_dec_test_params, \ + decphase_dec_test_params, \ cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, @@ -657,14 +669,11 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Cross-attention setup - prephase_cross_pckd_qkv, \ - prephase_cross_pckd_idl_out, \ - prephase_cross_kv_mmap, \ - decphase_cross_pckd_idl_out, \ - decphase_cross_kv_mmap \ + prephase_cross_test_params, \ + decphase_cross_test_params, \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_pckd_qkvo.packed_qkv, - prephase_dec_pckd_qkv, + enc_test_params, + prephase_dec_test_params, batch_size, num_heads, head_size, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 14ae12a061e55..e2d09522db601 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -118,6 +118,11 @@ def ref_masked_attention(query: torch.Tensor, "block_tables", "slot_mapping" ]) +PhaseTestParameters = namedtuple("PhaseTestParameters", [ + "packed_qkvo", + "kv_mmap" +]) + def make_qkv( batch_size: int, max_q_seq_len: int, From ee512605696e5170fabd3b17f6e7db86143230a7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 19:01:45 -0400 Subject: [PATCH 179/443] prephase md struct using test params --- tests/kernels/test_encoder_decoder_attn.py | 9 ++--- tests/kernels/utils.py | 38 ++++++++++++++++++---- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0ac2e47660db5..e9136909f8af6 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -689,13 +689,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - prephase_dec_pckd_qkv.q_seq_lens, - prephase_dec_kv_mmap, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, default_attn_type=AttentionType.ENCODER, - num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), - num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, - cross_kv_mmap=prephase_cross_kv_mmap, device=CUDA_DEVICE) # PREFILL: encoder attention diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e2d09522db601..2d1c274a49c07 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -596,15 +596,11 @@ def make_block_tables_slot_mapping(block_size: int, def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, - seq_lens: List[int], - kv_mmap: KVMemoryMap, + decoder_test_params: PhaseTestParameters, default_attn_type: AttentionType, - num_prefills_or_decodes: int, - num_prefill_or_decode_tokens: int, device: Union[torch.device, str], - context_lens: Optional[List[int]] = None, - encoder_seq_lens: Optional[List[int]] = None, - cross_kv_mmap: Optional[KVMemoryMap] = None, + encoder_test_params: Optional[PhaseTestParameters]=None, + cross_test_params: Optional[PhaseTestParameters]=None ) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention @@ -639,6 +635,34 @@ def make_test_metadata( * AttentionMetadata structure supporting self- and cross-attention ''' + # Extract + # * Decoder input sequence lengths (seq_lens) + # * Decoder self-attention slot mapping & block tables (kv_mmap) + seq_lens = decoder_test_params.packed_qkvo.packed_qkv.seq_lens + kv_mmap = decoder_test_params.kv_mmap + + # is_prompt determines whether input tokens are treated + # as 100% prefill or 100% decode. In either case, + # the number of {prefills, decodes} and the number of + # {prefill, decode} tokens can be inferred from seq_lens + num_prefills_or_decodes = len(seq_lens) + num_prefill_or_decode_tokens = sum(seq_lens) + + if encoder_test_params is None: + encoder_seq_lens = None + else: + # Encoder/decoder models only: + # * Extract encoder input sequence lengths + encoder_seq_lens = encoder_test_params.q_seq_lens + + if cross_test_params is None: + cross_kv_mmap = None + else: + # Encoder/decoder models only: + # * Extract *cross-attention* slot_mapping and block table + # (kv_mmap) + cross_kv_mmap = cross_test_params.kv_mmap + if is_prompt: num_prefills = num_prefills_or_decodes num_prefill_tokens = num_prefill_or_decode_tokens From 50a45cce0a5114ae12dedeca0b39de349e960ebc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 19:22:26 -0400 Subject: [PATCH 180/443] correctness check helper function --- tests/kernels/test_encoder_decoder_attn.py | 59 +++++++++++++--------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e9136909f8af6..19e44273c37f4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -495,7 +495,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decode_block_tables, \ decode_slot_mapping)) -def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, +def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -520,12 +520,13 @@ def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, None, + packed_qkv = encoder_test_params.packed_qkvo.packed_qkv + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, attn_metadata) def _run_decoder_self_attention_test(attn: Attention, - pckd_qkv: PackedQKVInputs, + decoder_test_params: PhaseTestParameters, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: @@ -551,13 +552,14 @@ def _run_decoder_self_attention_test(attn: Attention, ''' assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, kv_cache, + packed_qkv = decoder_test_params.packed_qkvo.packed_qkv + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( - attn: Attention, dec_pckd_qkv: PackedQKVInputs, - cross_pckd_qkv: PackedQKVInputs, kv_cache: torch.Tensor, + attn: Attention, decoder_test_params: PhaseTestParameters, + cross_test_params: PhaseTestParameters, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -578,13 +580,29 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER + cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = None if cross_pckd_qkv is None else \ cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(dec_pckd_qkv.query, key, value, kv_cache, + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) +def _assert_actual_match_ideal(test_params: PhaseTestParameters, + output_under_test: torch.Tensor) -> None: + ''' + Assert that observed output matches the ideal output + contained in the test parameters data structure. + + Arguments: + + * test_params: Test parameters including packed ideal output + * output_under_test: actually observed output value + ''' + ideal_output = test_params.packed_qkvo.ideal_output + assert torch.allclose(ideal_output, + output_under_test + .view_as(ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -701,14 +719,13 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( attn, - enc_pckd_qkvo.packed_qkv, + enc_test_params, prephase_attn_metadata, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose(enc_pckd_qkvo.ideal_output, - enc_packed_actual_output - .view_as(enc_pckd_qkvo.ideal_output)) + _assert_actual_match_ideal(enc_test_params, + enc_packed_actual_output) # PREFILL: self-attention test @@ -721,9 +738,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Prefill self-attention correct? - assert torch.allclose( - prephase_dec_pckd_idl_out, - self_prefill_packed_actual_output.view_as(prephase_dec_pckd_idl_out)) + _assert_actual_match_ideal(prephase_dec_test_params, + self_prefill_packed_actual_output) # PREFILL: cross-attention test @@ -736,9 +752,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata) # - Prefill cross-attention correct? - assert torch.allclose( - prephase_cross_pckd_idl_out, - prephase_cross_pckd_act_out.view_as(prephase_cross_pckd_idl_out)) + _assert_actual_match_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata @@ -770,9 +785,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Decode self-attention correct? - assert torch.allclose( - decphase_dec_pckd_idl_out, - decphase_dec_pckd_act_out.view_as(decphase_dec_pckd_idl_out)) + _assert_actual_match_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) # DECODE: cross-attention test @@ -785,9 +799,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_attn_metadata) # - Decode cross-attention correct? - assert torch.allclose( - decphase_cross_pckd_idl_out, - decphase_cross_pckd_act_out.view_as(decphase_cross_pckd_idl_out)) + _assert_actual_match_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) # The following test conditions could in principle be a # standalone test, however the test setup is From cd0a1aaee6ec9af5f543e7f7e8e57acad1577038 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 19:30:59 -0400 Subject: [PATCH 181/443] wip --- tests/kernels/test_encoder_decoder_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 19e44273c37f4..cd72645adcb12 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -732,7 +732,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output: torch.Tensor = \ _run_decoder_self_attention_test( attn, - prephase_dec_pckd_qkv, + prephase_dec_test_params, kv_cache, prephase_attn_metadata, attn_type=AttentionType.DECODER) @@ -746,8 +746,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, - prephase_dec_pckd_qkv, - prephase_cross_pckd_qkv, + prephase_dec_test_params, + prephase_cross_test_params, kv_cache, prephase_attn_metadata) From ec5977d76fdce71e8a7ef7955446a52846c6806c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 20:00:31 -0400 Subject: [PATCH 182/443] debugging test params integration --- tests/kernels/test_encoder_decoder_attn.py | 18 ++++++++---------- tests/kernels/utils.py | 9 +++++++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index cd72645adcb12..23400819de398 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -707,6 +707,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, decoder_test_params=prephase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=prephase_cross_test_params, @@ -759,19 +760,16 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Cross-attention KV context is equal in length to # encoder input - context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) + # context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) decphase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, dec_qkv.q_seq_lens, - decphase_dec_kv_mmap, + decoder_test_params=decphase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=decphase_cross_test_params, default_attn_type=AttentionType.DECODER, - context_lens=context_lens, - num_prefills_or_decodes=len(dec_qkv.q_seq_lens), - num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, - cross_kv_mmap=decphase_cross_kv_mmap, device=CUDA_DEVICE) # DECODE: self-attention test @@ -779,7 +777,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( attn, - decphase_dec_pckd_qkv, + decphase_dec_test_params, kv_cache, decphase_attn_metadata, attn_type=AttentionType.DECODER) @@ -793,7 +791,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, - decphase_dec_pckd_qkv, + decphase_dec_test_params, None, kv_cache, decphase_attn_metadata) @@ -819,7 +817,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, decphase_dec_pckd_qkv, + _run_encoder_decoder_cross_attention_test(attn, decphase_dec_test_params, None, kv_cache, decphase_attn_metadata) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2d1c274a49c07..38d363e7adcce 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -596,6 +596,7 @@ def make_block_tables_slot_mapping(block_size: int, def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, + seq_lens: List[int], decoder_test_params: PhaseTestParameters, default_attn_type: AttentionType, device: Union[torch.device, str], @@ -638,7 +639,7 @@ def make_test_metadata( # Extract # * Decoder input sequence lengths (seq_lens) # * Decoder self-attention slot mapping & block tables (kv_mmap) - seq_lens = decoder_test_params.packed_qkvo.packed_qkv.seq_lens + #seq_lens = decoder_test_params.packed_qkvo.packed_qkv.q_seq_lens kv_mmap = decoder_test_params.kv_mmap # is_prompt determines whether input tokens are treated @@ -648,12 +649,16 @@ def make_test_metadata( num_prefills_or_decodes = len(seq_lens) num_prefill_or_decode_tokens = sum(seq_lens) + # Seems for non-prefix-caching scenarios context_lens + # is never needed + context_lens = None + if encoder_test_params is None: encoder_seq_lens = None else: # Encoder/decoder models only: # * Extract encoder input sequence lengths - encoder_seq_lens = encoder_test_params.q_seq_lens + encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens if cross_test_params is None: cross_kv_mmap = None From 1f7b2ebe2d36c56693b75e1bdffe1a24e2e62dd7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 20:21:10 -0400 Subject: [PATCH 183/443] passing tests with test params integration --- tests/kernels/test_encoder_decoder_attn.py | 14 +++++++++----- tests/kernels/utils.py | 8 +++++++- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 23400819de398..6b287ce4bc72f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -580,11 +580,15 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER - cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.key - value = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.value + if cross_test_params is None: + key = None + value = None + else: + cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv + key = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.key + value = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.value return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 38d363e7adcce..e18de66264204 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -647,7 +647,13 @@ def make_test_metadata( # the number of {prefills, decodes} and the number of # {prefill, decode} tokens can be inferred from seq_lens num_prefills_or_decodes = len(seq_lens) - num_prefill_or_decode_tokens = sum(seq_lens) + if is_prompt: + # Prefill: operate on total num. of prompt + # tokens + num_prefill_or_decode_tokens = sum(seq_lens) + else: + # Decode: operate on one token per seq + num_prefill_or_decode_tokens = len(seq_lens) # Seems for non-prefix-caching scenarios context_lens # is never needed From 76b0b9ea8bb1247ee6324a03b3a1b0fa11578b02 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 20:32:06 -0400 Subject: [PATCH 184/443] format --- tests/kernels/test_encoder_decoder_attn.py | 57 +++++++++++----------- tests/kernels/utils.py | 40 ++++++--------- 2 files changed, 42 insertions(+), 55 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 6b287ce4bc72f..8462b1580794b 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -6,7 +6,6 @@ * Encoder/decoder cross-attention """ -import copy from typing import Optional import pytest @@ -147,16 +146,15 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, ) -def _decoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple[QKVInputs, - PhaseTestParameters, - PhaseTestParameters, - int]: +def _decoder_attn_setup( + batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0 +) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ''' Set up test vectors & data structures for self-attention test. @@ -402,7 +400,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - prefill_q_seq_lens = prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = \ + prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens cross_kv, \ @@ -495,7 +494,9 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decode_block_tables, \ decode_slot_mapping)) -def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestParameters, + +def _run_encoder_attention_test(attn: Attention, + encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -521,8 +522,8 @@ def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestP assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type packed_qkv = encoder_test_params.packed_qkvo.packed_qkv - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, - attn_metadata) + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, + None, attn_metadata) def _run_decoder_self_attention_test(attn: Attention, @@ -553,8 +554,8 @@ def _run_decoder_self_attention_test(attn: Attention, assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type packed_qkv = decoder_test_params.packed_qkvo.packed_qkv - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, - attn_metadata) + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, + kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( @@ -589,8 +590,9 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, - attn_metadata) + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, + value, kv_cache, attn_metadata) + def _assert_actual_match_ideal(test_params: PhaseTestParameters, output_under_test: torch.Tensor) -> None: @@ -605,8 +607,8 @@ def _assert_actual_match_ideal(test_params: PhaseTestParameters, ''' ideal_output = test_params.packed_qkvo.ideal_output assert torch.allclose(ideal_output, - output_under_test - .view_as(ideal_output)) + output_under_test.view_as(ideal_output)) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -671,11 +673,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_test_params = _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + enc_test_params = _encoder_attn_setup(batch_size, num_heads, head_size, + scale, max_enc_seq_len) # Decoder self-attention setup @@ -729,8 +728,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, - enc_packed_actual_output) + _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test @@ -821,7 +819,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, decphase_dec_test_params, + _run_encoder_decoder_cross_attention_test(attn, + decphase_dec_test_params, None, kv_cache, decphase_attn_metadata) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e18de66264204..cf4b41b96996a 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -96,11 +96,7 @@ def ref_masked_attention(query: torch.Tensor, QKVInputs = namedtuple("QKVInputs", ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) -QKVO = namedtuple("QKVO", - [ - "qkv", - "ideal_output" - ]) +QKVO = namedtuple("QKVO", ["qkv", "ideal_output"]) # total_num_tokens x (num_heads*head_size) PackedQKVInputs = namedtuple("PackedQKVInputs", [ @@ -108,20 +104,13 @@ def ref_masked_attention(query: torch.Tensor, "q_seq_lens", "kv_seq_lens" ]) -PackedQKVO = namedtuple("PackedQKVO", - [ - "packed_qkv", - "ideal_output" - ]) +PackedQKVO = namedtuple("PackedQKVO", ["packed_qkv", "ideal_output"]) -KVMemoryMap = namedtuple("KVMemoryMap", [ - "block_tables", "slot_mapping" -]) +KVMemoryMap = namedtuple("KVMemoryMap", ["block_tables", "slot_mapping"]) + +PhaseTestParameters = namedtuple("PhaseTestParameters", + ["packed_qkvo", "kv_mmap"]) -PhaseTestParameters = namedtuple("PhaseTestParameters", [ - "packed_qkvo", - "kv_mmap" -]) def make_qkv( batch_size: int, @@ -600,8 +589,8 @@ def make_test_metadata( decoder_test_params: PhaseTestParameters, default_attn_type: AttentionType, device: Union[torch.device, str], - encoder_test_params: Optional[PhaseTestParameters]=None, - cross_test_params: Optional[PhaseTestParameters]=None + encoder_test_params: Optional[PhaseTestParameters] = None, + cross_test_params: Optional[PhaseTestParameters] = None ) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention @@ -647,13 +636,12 @@ def make_test_metadata( # the number of {prefills, decodes} and the number of # {prefill, decode} tokens can be inferred from seq_lens num_prefills_or_decodes = len(seq_lens) - if is_prompt: - # Prefill: operate on total num. of prompt - # tokens - num_prefill_or_decode_tokens = sum(seq_lens) - else: - # Decode: operate on one token per seq - num_prefill_or_decode_tokens = len(seq_lens) + + # Prefill: operate on total num. of prompt + # tokens + # Decode: operate on one token per seq + num_prefill_or_decode_tokens = \ + sum(seq_lens) if is_prompt else len(seq_lens) # Seems for non-prefix-caching scenarios context_lens # is never needed From aa5363a589209a306b78dc3037eb0e278416386e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 21:32:57 -0400 Subject: [PATCH 185/443] test points and test resources structures integrated --- tests/kernels/test_encoder_decoder_attn.py | 188 +++++++++++++-------- 1 file changed, 113 insertions(+), 75 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 8462b1580794b..65bc7e55ee503 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -10,7 +10,7 @@ import pytest import torch - +from collections import namedtuple from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType @@ -31,8 +31,25 @@ MAX_ENC_SEQ_LENS = [128] -def _basic_setup(num_heads: int, head_size: int, num_blocks: int, - block_size: int, backend_name: str) -> tuple: +TestPoint = namedtuple("TestPoint",[ + "num_heads", + "head_size", + "backend_name", + "batch_size", + "block_size", + "max_dec_seq_len", + "max_enc_seq_len", + "num_blocks" +]) + +TestResources = namedtuple("TestResources",[ + "scale", + "attn_backend", + "attn", + "kv_cache" +]) + +def _make_test_resources(test_pt: TestPoint) -> TestResources: ''' Compute & build entities required for the self-/cross-attention test. @@ -55,29 +72,45 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, * None if num_blocks or block_size is None ''' - scale = float(1.0 / (head_size**0.5)) - attn_backend = make_backend(backend_name) + scale = float(1.0 / (test_pt.head_size**0.5)) + attn_backend = make_backend(test_pt.backend_name) attn = Attention( - num_heads, - head_size, + test_pt.num_heads, + test_pt.head_size, scale=scale, ) - if num_blocks is None or num_heads is None: + if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache - return scale, attn_backend, attn, None + return TestResources(scale, + attn_backend, + attn, + None) # Construct KV cache - kv_cache = make_kv_cache(num_blocks, - num_heads, - head_size, - block_size, + kv_cache = make_kv_cache(test_pt.num_blocks, + test_pt.num_heads, + test_pt.head_size, + test_pt.block_size, device=CUDA_DEVICE) - return scale, attn_backend, attn, kv_cache + return TestResources(scale, + attn_backend, + attn, + kv_cache) -def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, - scale: float, max_q_seq_len: int) \ +def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: + (num_heads, + head_size, + _, + batch_size, + _, + _, + max_q_seq_len, + _) = test_pt + + scale=test_rsrcs.scale + ''' Set up test vectors & data structures for encoder attention test. @@ -147,14 +180,11 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, def _decoder_attn_setup( - batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0 + test_pt: TestPoint, + test_rsrcs: TestResources, + block_base_addr: int = 0, ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: + ''' Set up test vectors & data structures for self-attention test. @@ -232,6 +262,17 @@ def _decoder_attn_setup( * max_block_idx: highest block address in the self-attention block-table ''' + (num_heads, + head_size, + _, + batch_size, + block_size, + max_q_seq_len, + _, + _) = test_pt + + scale = test_rsrcs.scale + max_kv_seq_len = max_q_seq_len qkv, \ @@ -325,17 +366,13 @@ def _decoder_attn_setup( def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, encoder_test_params: PhaseTestParameters, - prefill_phase_test_params: + prefill_decoder_phase_test_params: PhaseTestParameters, - batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_decoder_seq_len: int, - max_encoder_seq_len: int, + test_pt: TestPoint, + test_rsrcs: TestResources, block_base_addr: Optional[int]=0) \ -> tuple: + ''' Set up test vectors & data structures for cross-attention test. @@ -397,11 +434,22 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, * max_block_idx: highest block address in the cross-attention block-table ''' + (num_heads, + head_size, + _, + batch_size, + block_size, + max_decoder_seq_len, + max_encoder_seq_len, + _) = test_pt + + scale = test_rsrcs.scale + decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens prefill_q_seq_lens = \ - prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens cross_kv, \ @@ -526,11 +574,11 @@ def _run_encoder_attention_test(attn: Attention, None, attn_metadata) -def _run_decoder_self_attention_test(attn: Attention, +def _run_decoder_self_attention_test(test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: + ''' Run decoder self-attention test. @@ -552,6 +600,8 @@ def _run_decoder_self_attention_test(attn: Attention, & attn_metadata ''' assert attn_type == AttentionType.DECODER + attn = test_rsrcs.attn + kv_cache = test_rsrcs.kv_cache attn_metadata.attention_type = attn_type packed_qkv = decoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, @@ -559,9 +609,11 @@ def _run_decoder_self_attention_test(attn: Attention, def _run_encoder_decoder_cross_attention_test( - attn: Attention, decoder_test_params: PhaseTestParameters, - cross_test_params: PhaseTestParameters, kv_cache: torch.Tensor, + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + cross_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata) -> torch.Tensor: + ''' Run encoder/decoder cross-attention test. @@ -581,6 +633,8 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER + attn = test_rsrcs.attn + kv_cache = test_rsrcs.kv_cache if cross_test_params is None: key = None value = None @@ -609,7 +663,6 @@ def _assert_actual_match_ideal(test_params: PhaseTestParameters, assert torch.allclose(ideal_output, output_under_test.view_as(ideal_output)) - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -622,6 +675,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: + ''' Encoder/decoder attention test: @@ -651,19 +705,18 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - # Num KV cache blocks - num_blocks = 4096 + test_pt = TestPoint(num_heads, + head_size, + backend_name, + batch_size, + block_size, + max_dec_seq_len, + max_dec_seq_len, + 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = _basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) + test_rsrcs = _make_test_resources(test_pt) # Encoder attention setup @@ -673,20 +726,14 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_test_params = _encoder_attn_setup(batch_size, num_heads, head_size, - scale, max_enc_seq_len) + enc_test_params = _encoder_attn_setup(test_pt,test_rsrcs) # Decoder self-attention setup dec_qkv, \ prephase_dec_test_params, \ decphase_dec_test_params, \ - cross_block_base_addr = _decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_dec_seq_len) + cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) # Cross-attention setup @@ -695,20 +742,15 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_test_params, prephase_dec_test_params, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_dec_seq_len, - max_enc_seq_len, + test_pt, + test_rsrcs, block_base_addr = \ cross_block_base_addr) # Shared prefill metadata structure prephase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, + test_rsrcs.attn_backend, True, prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, decoder_test_params=prephase_dec_test_params, @@ -722,7 +764,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( - attn, + test_rsrcs.attn, enc_test_params, prephase_attn_metadata, attn_type=AttentionType.ENCODER) @@ -734,9 +776,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output: torch.Tensor = \ _run_decoder_self_attention_test( - attn, + test_rsrcs, prephase_dec_test_params, - kv_cache, prephase_attn_metadata, attn_type=AttentionType.DECODER) @@ -748,10 +789,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - kv_cache, prephase_attn_metadata) # - Prefill cross-attention correct? @@ -765,7 +805,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) decphase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, + test_rsrcs.attn_backend, False, dec_qkv.q_seq_lens, decoder_test_params=decphase_dec_test_params, @@ -778,9 +818,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( - attn, + test_rsrcs, decphase_dec_test_params, - kv_cache, decphase_attn_metadata, attn_type=AttentionType.DECODER) @@ -792,10 +831,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, + test_rsrcs, decphase_dec_test_params, None, - kv_cache, decphase_attn_metadata) # - Decode cross-attention correct? @@ -819,9 +857,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, + _run_encoder_decoder_cross_attention_test(test_rsrcs, decphase_dec_test_params, - None, kv_cache, + None, decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" From 53514169c2d59b81b97d834f5264634ebd84687a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 21:34:30 -0400 Subject: [PATCH 186/443] formatting --- tests/kernels/test_encoder_decoder_attn.py | 91 ++++++---------------- 1 file changed, 22 insertions(+), 69 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 65bc7e55ee503..e14c4d16b141e 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -6,11 +6,12 @@ * Encoder/decoder cross-attention """ +from collections import namedtuple from typing import Optional import pytest import torch -from collections import namedtuple + from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType @@ -30,24 +31,14 @@ MAX_DEC_SEQ_LENS = [128] MAX_ENC_SEQ_LENS = [128] - -TestPoint = namedtuple("TestPoint",[ - "num_heads", - "head_size", - "backend_name", - "batch_size", - "block_size", - "max_dec_seq_len", - "max_enc_seq_len", - "num_blocks" +TestPoint = namedtuple("TestPoint", [ + "num_heads", "head_size", "backend_name", "batch_size", "block_size", + "max_dec_seq_len", "max_enc_seq_len", "num_blocks" ]) -TestResources = namedtuple("TestResources",[ - "scale", - "attn_backend", - "attn", - "kv_cache" -]) +TestResources = namedtuple("TestResources", + ["scale", "attn_backend", "attn", "kv_cache"]) + def _make_test_resources(test_pt: TestPoint) -> TestResources: ''' @@ -81,10 +72,7 @@ def _make_test_resources(test_pt: TestPoint) -> TestResources: ) if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache - return TestResources(scale, - attn_backend, - attn, - None) + return TestResources(scale, attn_backend, attn, None) # Construct KV cache kv_cache = make_kv_cache(test_pt.num_blocks, @@ -92,25 +80,14 @@ def _make_test_resources(test_pt: TestPoint) -> TestResources: test_pt.head_size, test_pt.block_size, device=CUDA_DEVICE) - return TestResources(scale, - attn_backend, - attn, - kv_cache) + return TestResources(scale, attn_backend, attn, kv_cache) def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: - (num_heads, - head_size, - _, - batch_size, - _, - _, - max_q_seq_len, - _) = test_pt - - scale=test_rsrcs.scale + (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt + scale = test_rsrcs.scale ''' Set up test vectors & data structures for encoder attention test. @@ -184,7 +161,6 @@ def _decoder_attn_setup( test_rsrcs: TestResources, block_base_addr: int = 0, ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: - ''' Set up test vectors & data structures for self-attention test. @@ -262,13 +238,7 @@ def _decoder_attn_setup( * max_block_idx: highest block address in the self-attention block-table ''' - (num_heads, - head_size, - _, - batch_size, - block_size, - max_q_seq_len, - _, + (num_heads, head_size, _, batch_size, block_size, max_q_seq_len, _, _) = test_pt scale = test_rsrcs.scale @@ -372,7 +342,6 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, test_rsrcs: TestResources, block_base_addr: Optional[int]=0) \ -> tuple: - ''' Set up test vectors & data structures for cross-attention test. @@ -434,14 +403,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, * max_block_idx: highest block address in the cross-attention block-table ''' - (num_heads, - head_size, - _, - batch_size, - block_size, - max_decoder_seq_len, - max_encoder_seq_len, - _) = test_pt + (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, + max_encoder_seq_len, _) = test_pt scale = test_rsrcs.scale @@ -578,7 +541,6 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: - ''' Run decoder self-attention test. @@ -601,7 +563,7 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, ''' assert attn_type == AttentionType.DECODER attn = test_rsrcs.attn - kv_cache = test_rsrcs.kv_cache + kv_cache = test_rsrcs.kv_cache attn_metadata.attention_type = attn_type packed_qkv = decoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, @@ -609,11 +571,9 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, def _run_encoder_decoder_cross_attention_test( - test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, + test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, cross_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata) -> torch.Tensor: - ''' Run encoder/decoder cross-attention test. @@ -663,6 +623,7 @@ def _assert_actual_match_ideal(test_params: PhaseTestParameters, assert torch.allclose(ideal_output, output_under_test.view_as(ideal_output)) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -675,7 +636,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: - ''' Encoder/decoder attention test: @@ -705,14 +665,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - test_pt = TestPoint(num_heads, - head_size, - backend_name, - batch_size, - block_size, - max_dec_seq_len, - max_dec_seq_len, - 4096) + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_dec_seq_len, 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -726,7 +680,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_test_params = _encoder_attn_setup(test_pt,test_rsrcs) + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) # Decoder self-attention setup @@ -859,8 +813,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( with pytest.raises(NotImplementedError) as exc_info: _run_encoder_decoder_cross_attention_test(test_rsrcs, decphase_dec_test_params, - None, - decphase_attn_metadata) + None, decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL From 8d390a075bb7435aed10d1af97fada08aacea489 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:00:52 -0400 Subject: [PATCH 187/443] first attempt at chunked prefill failure test --- tests/kernels/test_encoder_decoder_attn.py | 133 +++++++++++++++++++++ vllm/attention/backends/utils.py | 3 +- vllm/attention/backends/xformers.py | 4 +- 3 files changed, 136 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e14c4d16b141e..fd0f96658e24f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -817,3 +817,136 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + + + +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_backend_fails_for_chunked_prefill_enc_dec( + num_heads: int, head_size: int, backend_name: str, batch_size: int, + block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch) -> None: + ''' + Encoder/decoder attention test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_dec_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Encoder attention setup + + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Decoder self-attention setup + + dec_qkv, \ + prephase_dec_test_params, \ + decphase_dec_test_params, \ + cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) + + # Cross-attention setup + + prephase_cross_test_params, \ + decphase_cross_test_params, \ + = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr = \ + cross_block_base_addr) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, + default_attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + # * Use prefill kernel + + enc_packed_actual_output: torch.Tensor = \ + _run_encoder_attention_test( + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + + + + # PREFILL: self-attention test + + # The following test conditions could in principle be a + # standalone test, however the test setup is + # so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + prephase_attn_metadata.num_decode_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + + _run_decoder_self_attention_test( + test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + attn_type=AttentionType.DECODER) \ No newline at end of file diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index ad88b4f964a54..66916c28d7685 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -44,8 +44,7 @@ def check_hip_or_chunked_prefill_attention_encdec( # xFormers backend. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) - if attn_metadata.attention_type != AttentionType.DECODER \ - and attn_metadata.num_prefill_tokens > 0 and \ + if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible # with chunked prefill. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cc240242f7eae..31630f0dbecc3 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -480,7 +480,7 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - if attn_type != AttentionType.DECODER: + if attn_metadata.is_all_encoder_attn_metadata_set: # Raise NotImplementedError for unsupported encoder/decoder # scenarios from vllm.attention.backends.utils import ( @@ -528,7 +528,7 @@ def forward( num_prefill_tokens = query.shape[0] num_decode_tokens = 0 - if attn_type != AttentionType.ENCODER_DECODER: + if attn_type == AttentionType.DECODER: # Only enforce this shape-constraint for decoder # self-attention assert key.shape[0] == num_prefill_tokens + num_decode_tokens From 68b6d4b29ffb787713beacd66ee922b1d8ec9326 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:06:58 -0400 Subject: [PATCH 188/443] narrowed the space of test-cases for unsupported scenarios --- tests/kernels/test_encoder_decoder_attn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index fd0f96658e24f..14feb1f948b2a 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -31,6 +31,10 @@ MAX_DEC_SEQ_LENS = [128] MAX_ENC_SEQ_LENS = [128] +# Narrow teest-cases for unsupported-scenario +# tests +HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] + TestPoint = namedtuple("TestPoint", [ "num_heads", "head_size", "backend_name", "batch_size", "block_size", "max_dec_seq_len", "max_enc_seq_len", "num_blocks" @@ -823,7 +827,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) From 5923002add95a787a6ae5604bd8a5aa2f14c66d1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:12:59 -0400 Subject: [PATCH 189/443] format --- tests/kernels/test_encoder_decoder_attn.py | 52 +++++++--------------- vllm/attention/backends/utils.py | 1 - 2 files changed, 15 insertions(+), 38 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 14feb1f948b2a..70cc807277af8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -798,32 +798,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( _assert_actual_match_ideal(decphase_cross_test_params, decphase_cross_pckd_act_out) - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - decphase_attn_metadata.num_prefill_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(test_rsrcs, - decphase_dec_test_params, - None, decphase_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -833,10 +807,14 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_backend_fails_for_chunked_prefill_enc_dec( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -927,8 +905,6 @@ def test_backend_fails_for_chunked_prefill_enc_dec( # - Is encoder attention result correct? _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) - - # PREFILL: self-attention test # The following test conditions could in principle be a @@ -949,8 +925,10 @@ def test_backend_fails_for_chunked_prefill_enc_dec( prephase_attn_metadata.num_decode_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_decoder_self_attention_test( - test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) \ No newline at end of file + _run_decoder_self_attention_test(test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + attn_type=AttentionType.DECODER) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 66916c28d7685..9f38147799045 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ """Attention utils""" from vllm.attention import AttentionMetadata -from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.xformers import XFormersMetadata from vllm.utils import is_hip From c3e5d2aa0520f371a55de158af542c2ae22b4373 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:26:25 -0400 Subject: [PATCH 190/443] skeleton of encdec prefix cache failure test; fixed bug where max enc seq len was unused --- tests/kernels/test_encoder_decoder_attn.py | 139 ++++++++++++++++++++- 1 file changed, 137 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 70cc807277af8..b71934cb116ae 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -670,7 +670,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( override_backend_env_variable(monkeypatch, backend_name) test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_dec_seq_len, 4096) + block_size, max_dec_seq_len, max_enc_seq_len, 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -845,7 +845,142 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, override_backend_env_variable(monkeypatch, backend_name) test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_dec_seq_len, 4096) + block_size, max_dec_seq_len, max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Encoder attention setup + + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Decoder self-attention setup + + dec_qkv, \ + prephase_dec_test_params, \ + decphase_dec_test_params, \ + cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) + + # Cross-attention setup + + prephase_cross_test_params, \ + decphase_cross_test_params, \ + = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr = \ + cross_block_base_addr) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, + default_attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + # * Use prefill kernel + + enc_packed_actual_output: torch.Tensor = \ + _run_encoder_attention_test( + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + + # PREFILL: self-attention test + + # The following test conditions could in principle be a + # standalone test, however the test setup is + # so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + prephase_attn_metadata.num_decode_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + + _run_decoder_self_attention_test(test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + attn_type=AttentionType.DECODER) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch) -> None: + ''' + Encoder/decoder attention test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_enc_seq_len, 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init From 739ab3ca7ccc945a569c7f05a270b8eb2de78317 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 10:29:23 -0400 Subject: [PATCH 191/443] wip prefill test --- tests/kernels/test_encoder_decoder_attn.py | 24 ++++++++++--- vllm/attention/backends/utils.py | 39 ++++++++++++++-------- vllm/attention/backends/xformers.py | 14 +++++--- 3 files changed, 54 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index b71934cb116ae..2d8942afb7adb 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,7 +16,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, + STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] @@ -1006,7 +1008,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # Cross-attention setup prephase_cross_test_params, \ - decphase_cross_test_params, \ + _, \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_test_params, prephase_dec_test_params, @@ -1057,8 +1059,22 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # already; the line below sets up a chunked prefill # metadata configuration where there is nominally a mix # of prefill and decode tokens. - prephase_attn_metadata.num_decode_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: + # Fake a non-empty block_tables + # prephase_dec_test_params.kv_mmap.block_tables = \ + # decphase_dec_test_params.kv_mmap.block_tables + + # prefix_block_tables = decphase_dec_test_params.kv_mmap.block_tables + + # prefix_kv_mmap = KVMemoryMap(prefix_block_tables, + # prephase_dec_test_params.kv_mmap.slot_mapping) + + # prefix_test_params = PhaseTestParameters( + # prephase_dec_test_params.packed_qkvo, + # prefix_kv_mmap + # ) + + prephase_attn_metadata.block_tables = decphase_dec_test_params.kv_mmap.block_tables _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, @@ -1066,4 +1082,4 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, attn_type=AttentionType.DECODER) # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 9f38147799045..3200ef9113820 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,28 +1,39 @@ """Attention utils""" -from vllm.attention import AttentionMetadata -from vllm.attention.backends.xformers import XFormersMetadata +# from vllm.attention import AttentionMetadata +# from vllm.attention.backends.xformers import XFormersMetadata from vllm.utils import is_hip # Error string(s) for encoder/decoder # unsupported attention scenarios STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ -"Encoder/decoder models " + \ -"currently do not support chunked prefill." +"Chunked prefill is not currently " + \ +"supported with encoder/decoder models." STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ -"Encoder/decoder models currently" + \ -"do not support ROCm/HIP." +"ROCm/HIP is not currently supported" + \ +"with encoder/decoder models." STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ -"Encoder/decoder models currently support only the XFormers backend." +"Currently only the XFormers backend " + \ + "supports encoder/decoder models." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING = \ +"Prefix caching is not currently supported " + \ +"with encoder/decoder models" # Check for unsupported encoder/decoder scenarios +def is_encoder_decoder_metadata(attn_metadata) -> bool: + return attn_metadata.is_all_encoder_attn_metadata_set + +def fail_encoder_decoder_prefix_caching() -> None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) + def check_hip_or_chunked_prefill_attention_encdec( - attn_metadata: AttentionMetadata): + attn_metadata) -> None: ''' Check for unsupported encoder/decoder scenarios when invoking attention. @@ -36,12 +47,12 @@ def check_hip_or_chunked_prefill_attention_encdec( # encoder/decoder models raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) - if not isinstance(attn_metadata, XFormersMetadata): - # Right now encoder/decoder support is only implemented - # for the XFormers backend. Pretty unlikely to encounter - # this case currently given this function will be invoked inside - # xFormers backend. - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) + # if not isinstance(attn_metadata, XFormersMetadata): + # # Right now encoder/decoder support is only implemented + # # for the XFormers backend. Pretty unlikely to encounter + # # this case currently given this function will be invoked inside + # # xFormers backend. + # raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 31630f0dbecc3..6b94ef10764ea 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -15,6 +15,11 @@ PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.attention.backends.utils import ( + check_hip_or_chunked_prefill_attention_encdec, + is_encoder_decoder_metadata, + fail_encoder_decoder_prefix_caching) + logger = init_logger(__name__) @@ -480,11 +485,9 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - if attn_metadata.is_all_encoder_attn_metadata_set: + if is_encoder_decoder_metadata(attn_metadata): # Raise NotImplementedError for unsupported encoder/decoder # scenarios - from vllm.attention.backends.utils import ( - check_hip_or_chunked_prefill_attention_encdec) check_hip_or_chunked_prefill_attention_encdec(attn_metadata) if (kv_cache is not None): @@ -562,12 +565,13 @@ def forward( assert prefill_meta.query_start_loc is not None assert prefill_meta.max_query_len is not None + if is_encoder_decoder_metadata(attn_metadata): + fail_encoder_decoder_prefix_caching() + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - # - # TODO(afeldman-nm): support cross-attention out = PagedAttention.forward_prefix( query, key, From 22652249deac4d19ab9116b0b943767501d538f8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 10:53:38 -0400 Subject: [PATCH 192/443] passing prefix cache failure test --- tests/kernels/test_encoder_decoder_attn.py | 5 ++++- vllm/attention/backends/xformers.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2d8942afb7adb..c5a888c8be8a3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1074,8 +1074,11 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # prefix_kv_mmap # ) - prephase_attn_metadata.block_tables = decphase_dec_test_params.kv_mmap.block_tables + num_seqs = len(prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) + prephase_attn_metadata._cached_prefill_metadata.block_tables = torch.randint(0,10,(num_seqs,1)) + + _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6b94ef10764ea..a930643afb0d1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -562,12 +562,12 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - assert prefill_meta.query_start_loc is not None - assert prefill_meta.max_query_len is not None - if is_encoder_decoder_metadata(attn_metadata): fail_encoder_decoder_prefix_caching() + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, From d72aaa9098694fda0771477429735bc35985d6c0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 10:55:00 -0400 Subject: [PATCH 193/443] format --- tests/kernels/test_encoder_decoder_attn.py | 29 +++++++++++----------- vllm/attention/backends/utils.py | 5 ++-- vllm/attention/backends/xformers.py | 8 +++--- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c5a888c8be8a3..4e1f9a2ec7709 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,8 +16,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor @@ -945,13 +944,13 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, - monkeypatch) -> None: + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -1063,7 +1062,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # Fake a non-empty block_tables # prephase_dec_test_params.kv_mmap.block_tables = \ # decphase_dec_test_params.kv_mmap.block_tables - + # prefix_block_tables = decphase_dec_test_params.kv_mmap.block_tables # prefix_kv_mmap = KVMemoryMap(prefix_block_tables, @@ -1072,13 +1071,15 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # prefix_test_params = PhaseTestParameters( # prephase_dec_test_params.packed_qkvo, # prefix_kv_mmap - # ) + # ) - num_seqs = len(prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) + num_seqs = len( + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) - prephase_attn_metadata._cached_prefill_metadata.block_tables = torch.randint(0,10,(num_seqs,1)) + prephase_attn_metadata._cached_prefill_metadata.block_tables = \ + torch.randint( + 0, 10, (num_seqs, 1)) - _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3200ef9113820..3423587be4889 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -29,11 +29,12 @@ def is_encoder_decoder_metadata(attn_metadata) -> bool: return attn_metadata.is_all_encoder_attn_metadata_set + def fail_encoder_decoder_prefix_caching() -> None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) -def check_hip_or_chunked_prefill_attention_encdec( - attn_metadata) -> None: + +def check_hip_or_chunked_prefill_attention_encdec(attn_metadata) -> None: ''' Check for unsupported encoder/decoder scenarios when invoking attention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a930643afb0d1..149de709141a5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,15 +11,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import ( + check_hip_or_chunked_prefill_attention_encdec, + fail_encoder_decoder_prefix_caching, is_encoder_decoder_metadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger -from vllm.attention.backends.utils import ( - check_hip_or_chunked_prefill_attention_encdec, - is_encoder_decoder_metadata, - fail_encoder_decoder_prefix_caching) - logger = init_logger(__name__) From 1c19d36f5f9dde347909bb3b5f38d9421d772461 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 11:31:45 -0400 Subject: [PATCH 194/443] type annotations; formatting --- tests/kernels/test_encoder_decoder_attn.py | 45 ++++++---------------- tests/kernels/utils.py | 22 ++++++----- vllm/attention/backends/utils.py | 13 ++----- vllm/attention/backends/xformers.py | 2 - 4 files changed, 27 insertions(+), 55 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 4e1f9a2ec7709..26a9c49f8f069 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1,9 +1,11 @@ """ -Test +Tests: + +* E2E Encoder attention + Decoder self-attention + + Encoder/decoder cross-attention +* Confirm enc/dec models will fail for chunked prefill +* Confirm enc/dec models will fail for prefix caching -* Encoder attention -* Decoder self-attention -* Encoder/decoder cross-attention """ from collections import namedtuple @@ -346,7 +348,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, test_pt: TestPoint, test_rsrcs: TestResources, block_base_addr: Optional[int]=0) \ - -> tuple: + -> tuple[PhaseTestParameters, + PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. @@ -866,13 +869,13 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, dec_qkv, \ prephase_dec_test_params, \ - decphase_dec_test_params, \ + _, \ cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) # Cross-attention setup prephase_cross_test_params, \ - decphase_cross_test_params, \ + _, \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_test_params, prephase_dec_test_params, @@ -908,13 +911,6 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, # PREFILL: self-attention test - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- # Set up a contrived scenario where the attention metadata # is configured for chunked prefill & encoder/decoder cross- # attention. Required that this triggers a NotImplementedError. @@ -1001,7 +997,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, dec_qkv, \ prephase_dec_test_params, \ - decphase_dec_test_params, \ + _, \ cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) # Cross-attention setup @@ -1043,13 +1039,6 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # PREFILL: self-attention test - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- # Set up a contrived scenario where the attention metadata # is configured for chunked prefill & encoder/decoder cross- # attention. Required that this triggers a NotImplementedError. @@ -1060,18 +1049,6 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # of prefill and decode tokens. with pytest.raises(NotImplementedError) as exc_info: # Fake a non-empty block_tables - # prephase_dec_test_params.kv_mmap.block_tables = \ - # decphase_dec_test_params.kv_mmap.block_tables - - # prefix_block_tables = decphase_dec_test_params.kv_mmap.block_tables - - # prefix_kv_mmap = KVMemoryMap(prefix_block_tables, - # prephase_dec_test_params.kv_mmap.slot_mapping) - - # prefix_test_params = PhaseTestParameters( - # prephase_dec_test_params.packed_qkvo, - # prefix_kv_mmap - # ) num_seqs = len( prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index cf4b41b96996a..1680231c41fbf 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -246,8 +246,9 @@ def make_qkv( decode_kv_seq_lens) -def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: +def pack_tensor( + unpacked_tensor: torch.Tensor, seq_lens: List[int], + device: Union[torch.device, str]) -> tuple[torch.Tensor, List[int]]: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -355,9 +356,11 @@ def make_backend(backend_name: str) -> AttentionBackend: f"Unrecognized backend_name {backend_name} for unit test") -def _make_metadata_tensors(seq_lens: List[int], context_lens: List[int], - encoder_seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: +def _make_metadata_tensors( + seq_lens: List[int], context_lens: List[int], encoder_seq_lens: List[int], + device: Union[torch.device, str] +) -> tuple[torch.Tensor, torch.Tensor, int, int, Optional[List[int]], + torch.Tensor, int]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -500,10 +503,11 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], maybe_make_long_tensor(decode_slot_mapping, device) -def make_block_tables_slot_mapping(block_size: int, - seq_lens: List[int], - device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple: +def make_block_tables_slot_mapping( + block_size: int, + seq_lens: List[int], + device: Union[torch.device, str], + block_base_addr: int = 0) -> tuple[torch.Tensor, List[int], int]: ''' Construct fake block tables & slot mappings. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3423587be4889..d67251dd17b23 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ """Attention utils""" -# from vllm.attention import AttentionMetadata -# from vllm.attention.backends.xformers import XFormersMetadata +from vllm.attention import AttentionMetadata from vllm.utils import is_hip # Error string(s) for encoder/decoder @@ -34,7 +33,8 @@ def fail_encoder_decoder_prefix_caching() -> None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) -def check_hip_or_chunked_prefill_attention_encdec(attn_metadata) -> None: +def check_hip_or_chunked_prefill_attention_encdec( + attn_metadata: AttentionMetadata) -> None: ''' Check for unsupported encoder/decoder scenarios when invoking attention. @@ -48,13 +48,6 @@ def check_hip_or_chunked_prefill_attention_encdec(attn_metadata) -> None: # encoder/decoder models raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) - # if not isinstance(attn_metadata, XFormersMetadata): - # # Right now encoder/decoder support is only implemented - # # for the XFormers backend. Pretty unlikely to encounter - # # this case currently given this function will be invoked inside - # # xFormers backend. - # raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) - if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 149de709141a5..72fe333021f08 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -281,8 +281,6 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Extract appropriate attention bias from attention metadata according to attention type. - Depends on attn_metadata having a valid attention_type. - Arguments: * attn_metadata: Attention metadata structure associated with attention From e10340d83e98de0588b2f3a83d6ebc95de7ec608 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 11:49:14 -0400 Subject: [PATCH 195/443] completely replaced collections.namedtuple with typing.NamedTuple w/ type annotations; formatting --- tests/kernels/test_encoder_decoder_attn.py | 27 ++++++---- tests/kernels/utils.py | 61 ++++++++++++++-------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 26a9c49f8f069..a923ba4e18a4e 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -8,15 +8,14 @@ """ -from collections import namedtuple -from typing import Optional +from typing import NamedTuple, Optional import pytest import torch from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -38,13 +37,23 @@ # tests HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] -TestPoint = namedtuple("TestPoint", [ - "num_heads", "head_size", "backend_name", "batch_size", "block_size", - "max_dec_seq_len", "max_enc_seq_len", "num_blocks" -]) -TestResources = namedtuple("TestResources", - ["scale", "attn_backend", "attn", "kv_cache"]) +class TestPoint(NamedTuple): + num_heads: int + head_size: int + backend_name: str + batch_size: int + block_size: int + max_dec_seq_len: int + max_enc_seq_len: int + num_blocks: int + + +class TestResources(NamedTuple): + scale: float + attn_backend: AttentionBackend + attn: Attention + kv_cache: torch.Tensor def _make_test_resources(test_pt: TestPoint) -> TestResources: diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 1680231c41fbf..2d802645de060 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,8 +2,7 @@ import itertools import random -from collections import namedtuple -from typing import List, Optional, Union +from typing import List, NamedTuple, Optional, Union import pytest import torch @@ -19,6 +18,44 @@ STR_INVALID_VAL: str = "INVALID" +class QKVInputs(NamedTuple): + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + q_seq_lens: List[int] + kv_seq_lens: List[int] + + +class QKVO(NamedTuple): + qkv: QKVInputs + ideal_output: torch.Tensor + + +class PackedQKVInputs(NamedTuple): + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + q_start_loc_list: List[int] + kv_start_loc_list: List[int] + q_seq_lens: List[int] + kv_seq_lens: List[int] + + +class PackedQKVO(NamedTuple): + packed_qkv: PackedQKVInputs + ideal_output: torch.Tensor + + +class KVMemoryMap(NamedTuple): + block_tables: torch.Tensor + slot_mapping: torch.Tensor + + +class PhaseTestParameters(NamedTuple): + packed_qkvo: PackedQKVO + kv_mmap: KVMemoryMap + + def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: ''' @@ -92,26 +129,6 @@ def ref_masked_attention(query: torch.Tensor, return out -# batch_size x max_q_seq_len x num_heads x head_size -QKVInputs = namedtuple("QKVInputs", - ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) - -QKVO = namedtuple("QKVO", ["qkv", "ideal_output"]) - -# total_num_tokens x (num_heads*head_size) -PackedQKVInputs = namedtuple("PackedQKVInputs", [ - "query", "key", "value", "q_start_loc_list", "kv_start_loc_list", - "q_seq_lens", "kv_seq_lens" -]) - -PackedQKVO = namedtuple("PackedQKVO", ["packed_qkv", "ideal_output"]) - -KVMemoryMap = namedtuple("KVMemoryMap", ["block_tables", "slot_mapping"]) - -PhaseTestParameters = namedtuple("PhaseTestParameters", - ["packed_qkvo", "kv_mmap"]) - - def make_qkv( batch_size: int, max_q_seq_len: int, From 67ab576fdbc13815b4e7ce60520c48d84da4bb05 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 12:08:31 -0400 Subject: [PATCH 196/443] removed HIP check; clarified assumptions about supported backends in enc/dec supported feature checks --- tests/kernels/test_encoder_decoder_attn.py | 38 ++++++------------ tests/kernels/utils.py | 16 ++++++++ vllm/attention/backends/utils.py | 46 +++++++++++++++++----- vllm/attention/backends/xformers.py | 16 ++++---- 4 files changed, 73 insertions(+), 43 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index a923ba4e18a4e..9e07611cba5c7 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -625,22 +625,6 @@ def _run_encoder_decoder_cross_attention_test( value, kv_cache, attn_metadata) -def _assert_actual_match_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor) -> None: - ''' - Assert that observed output matches the ideal output - contained in the test parameters data structure. - - Arguments: - - * test_params: Test parameters including packed ideal output - * output_under_test: actually observed output value - ''' - ideal_output = test_params.packed_qkvo.ideal_output - assert torch.allclose(ideal_output, - output_under_test.view_as(ideal_output)) - - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -741,7 +725,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test @@ -753,8 +737,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Prefill self-attention correct? - _assert_actual_match_ideal(prephase_dec_test_params, - self_prefill_packed_actual_output) + assert_actual_matches_ideal(prephase_dec_test_params, + self_prefill_packed_actual_output) # PREFILL: cross-attention test @@ -766,8 +750,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata) # - Prefill cross-attention correct? - _assert_actual_match_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + assert_actual_matches_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata @@ -795,8 +779,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Decode self-attention correct? - _assert_actual_match_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + assert_actual_matches_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) # DECODE: cross-attention test @@ -808,8 +792,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_attn_metadata) # - Decode cross-attention correct? - _assert_actual_match_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + assert_actual_matches_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -916,7 +900,7 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test @@ -1044,7 +1028,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2d802645de060..5e6d5cb2c9bd6 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -757,3 +757,19 @@ def make_test_metadata( cross_kv_mmap.slot_mapping, cross_block_tables=None if cross_kv_mmap is None else \ cross_kv_mmap.block_tables) + + +def assert_actual_matches_ideal(test_params: PhaseTestParameters, + output_under_test: torch.Tensor) -> None: + ''' + Assert that observed output matches the ideal output + contained in the test parameters data structure. + + Arguments: + + * test_params: Test parameters including packed ideal output + * output_under_test: actually observed output value + ''' + ideal_output = test_params.packed_qkvo.ideal_output + assert torch.allclose(ideal_output, + output_under_test.view_as(ideal_output)) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index d67251dd17b23..45a6f4af37d13 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ """Attention utils""" from vllm.attention import AttentionMetadata -from vllm.utils import is_hip # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -25,28 +24,57 @@ # Check for unsupported encoder/decoder scenarios -def is_encoder_decoder_metadata(attn_metadata) -> bool: +def is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata) -> bool: + ''' + Return True of the attn_metadata argument contains + the metadata fields that would be required for + encoder attention, which proves that the user is + not running a purely decoder-only model. + + Assumes attn_metadata is derived from a backend that supports + encoder/decoder models. + + Arguments: + + * attn_metadata: instance of supported backend metadata. + Type annotation omitted to avoid circular import. + + + Returns: + + * True if attn_metadata is configured for an encoder/decoder model + ''' return attn_metadata.is_all_encoder_attn_metadata_set def fail_encoder_decoder_prefix_caching() -> None: + ''' + Fail with NotImplementedError & a message indicating + enc/dec + prefix caching is unsupported + ''' raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) -def check_hip_or_chunked_prefill_attention_encdec( +def assert_no_encdec_chunked_prefill_assuming_supported_backend( attn_metadata: AttentionMetadata) -> None: ''' - Check for unsupported encoder/decoder scenarios when invoking - attention. + Fail if encoder/decoder model is being executed with + chunked prefill. + Assumes we already know that the particular attention + backend in-use is supported. + Arguments: * attn_metadata: Attention metadata structure ''' - if is_hip(): - # AMD ROCm/HIP support currently not implemented for - # encoder/decoder models - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) + + if not is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata): + # Only care about encoder/decoder + # scenarios. + return if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 72fe333021f08..32ea44f74d106 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -12,8 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import ( - check_hip_or_chunked_prefill_attention_encdec, - fail_encoder_decoder_prefix_caching, is_encoder_decoder_metadata) + assert_no_encdec_chunked_prefill_assuming_supported_backend, + fail_encoder_decoder_prefix_caching, + is_encoder_decoder_metadata_assuming_supported_backend) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -481,10 +482,10 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - if is_encoder_decoder_metadata(attn_metadata): - # Raise NotImplementedError for unsupported encoder/decoder - # scenarios - check_hip_or_chunked_prefill_attention_encdec(attn_metadata) + # Raise NotImplementedError for unsupported encoder/decoder + # scenarios (has no effect on decoder-only models) + assert_no_encdec_chunked_prefill_assuming_supported_backend( + attn_metadata) if (kv_cache is not None): # Even if there are no new key/value pairs to cache, @@ -558,7 +559,8 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - if is_encoder_decoder_metadata(attn_metadata): + if is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata): fail_encoder_decoder_prefix_caching() assert prefill_meta.query_start_loc is not None From dc7d3c8371a4a186abcc660ee6b9b1df8ed3298c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 8 Jun 2024 15:01:56 -0400 Subject: [PATCH 197/443] wip --- tests/kernels/test_encoder_decoder_attn.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 9e07611cba5c7..04412f78d8d84 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -39,6 +39,22 @@ class TestPoint(NamedTuple): + """ + Encapsulates the attributes which define the + test_enc_dec_self_and_cross_attention_prefill_decode_phases() + test + + Attributes: + num_heads: The number of heads in the model. + head_size: Head dimension + backend_name: Name of the backend framework used. + batch_size: Number of samples per batch. + block_size: Size of each block of data processed. + max_dec_seq_len: Maximum sequence length for the decoder. + max_enc_seq_len: Maximum sequence length for the encoder. + num_blocks: Number of blocks in the model. + """ + num_heads: int head_size: int backend_name: str From e9c2a8571241e1d256001c90ed492b3439b7de1f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 07:06:09 -0400 Subject: [PATCH 198/443] wip comments --- tests/kernels/test_encoder_decoder_attn.py | 40 ++++++++++++++++++---- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 04412f78d8d84..3a86d07277b02 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1,8 +1,9 @@ """ Tests: -* E2E Encoder attention + Decoder self-attention + - Encoder/decoder cross-attention +* E2E test of Encoder attention + Decoder self-attention + + Encoder/decoder cross-attention (collectively + "encoder/decoder attention") * Confirm enc/dec models will fail for chunked prefill * Confirm enc/dec models will fail for prefix caching @@ -40,9 +41,8 @@ class TestPoint(NamedTuple): """ - Encapsulates the attributes which define the - test_enc_dec_self_and_cross_attention_prefill_decode_phases() - test + Encapsulates the attributes which define a single invocation + of the test_e2e_enc_dec_attn() test Attributes: num_heads: The number of heads in the model. @@ -66,6 +66,34 @@ class TestPoint(NamedTuple): class TestResources(NamedTuple): + ''' + Encapsuates key components for performing an + encoder/decoder attention test + + Note that + (1) attn automatically selects an attention backend + based on platform info & a set of canned + heuristics + (2) attn_backend is thus *not the same backend + instance* used by attn, but rather it is + intended to be a + *different instance* of the *same backend class*; + it is assumed that the user of TestResources + will leverage attn_backend for the purpose of + constructing backend-compatible attention + metadata instances + + Attributes: + + * scale: 1/sqrt(d) scale factor for attn + * attn_backend: implementatino of abstraction + attention interface using + a particular kernel library + i.e. XFormers + * attn: Attention layer instance + * kv_cache: shared key/value cache for all attention + ''' + scale: float attn_backend: AttentionBackend attn: Attention @@ -649,7 +677,7 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_enc_dec_self_and_cross_attention_prefill_decode_phases( +def test_e2e_enc_dec_attn( num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: From 57910284fe69f9dea9bf1783ce889a305f12c7c7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 07:06:25 -0400 Subject: [PATCH 199/443] small fix --- tests/kernels/test_encoder_decoder_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3a86d07277b02..ee76064ec6005 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -67,7 +67,7 @@ class TestPoint(NamedTuple): class TestResources(NamedTuple): ''' - Encapsuates key components for performing an + Encapsulates key components for performing an encoder/decoder attention test Note that From f0cd5eab267107978559facdcbe6ba6fa014657a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 08:00:07 -0400 Subject: [PATCH 200/443] formatting --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ee76064ec6005..1d042b9fa6fe4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -677,10 +677,10 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: From 2a7fd866e313842de13371dbfae813fc7714cd55 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 09:36:45 -0400 Subject: [PATCH 201/443] enc/dec test comment updates; some function arg changes; formatting --- tests/kernels/test_encoder_decoder_attn.py | 512 +++++++++++---------- tests/kernels/utils.py | 2 +- 2 files changed, 264 insertions(+), 250 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 1d042b9fa6fe4..950b8244365a3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -102,25 +102,29 @@ class TestResources(NamedTuple): def _make_test_resources(test_pt: TestPoint) -> TestResources: ''' - Compute & build entities required for the self-/cross-attention test. + Build key components for performing encoder/decoder attention test. + + Note that + (1) The Attention instance constructed here, automatically selects + an attention backend class based on platform info & a set of canned + heuristics, so + (2) The attention backend instance constructed here is thus *not + the same backend instance* used by attn, but rather it is + intended to be a *different instance* of the *same backend class*; + therefore, + (3) This function requires that test_pt.backend_name matches the backend + class that Attention will automatically select when it is constructed. + Arguments: - * num_heads: Number of attention heads - * head_size: Head dimension - * num_blocks: Number of KV cache blocks (no KV cache if None) - * block_size: Number of offsets within a KV cache block - (no KV cache if None) - * backend_name: selection of backend + * test_pt: TestPoint data structure; this function relies on the + following fields: num_heads, head_size, num_blocks, + block_size, backend_name Returns: - * scale: 1/sqrt(head_size) - * attn_backend: backend instance - * attn: Attention wrapper instance - * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * - head_size) - * None if num_blocks or block_size is None + * TestResources data structure. ''' scale = float(1.0 / (test_pt.head_size**0.5)) @@ -158,33 +162,29 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ The query/key/value tensors are passed to an ideal reference self-attention implementation to generate an ideal output tensor. - This function also constructs the self-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts at - block_base_addr + Encoder inference does not populate the KV cache, therefore + no KV cache memory mapping is constructed Arguments: - * batch_size - * num_heads: Number of attention heads - * head_size: Head dimension - * block_size: Number of offsets per KV cache block - * scale: attention scale parameter - * max_q_seq_len: upper limit on query length for synthetic test vectors - * block_base_addr: self-attention block table base address + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + Returns: - * packed_query: number_of_tokens x num_heads x head_size - * packed_key: number_of_tokens x num_heads x head_size - * packed_value: number_of_tokens x num_heads x head_size - * packed_ideal_output: number_of_tokens x num_heads x head_size - * block_tables: fake self-attn decode-phase block table - * slot_mapping: fake self-attn decode-phase slot mapping - * q_seq_lens: list of query sequence lengths + * PhaseTestParameters data structure comprising (1) packed query/key/value + tensors, (2) the ideal output of attention computed using a naive + implementation, and (3) KVCache field set to None ''' max_kv_seq_len = max_q_seq_len + # Make test tensors + qkv_in, _, _ = make_qkv(batch_size, max_q_seq_len, max_kv_seq_len, @@ -193,7 +193,9 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ attn_type=AttentionType.ENCODER, device=CUDA_DEVICE) - # No causal attention mask + # Compute correct answer using naive non-causal attention + # implementation + ideal_output = ref_masked_attention(qkv_in.query, qkv_in.key, qkv_in.value, @@ -251,51 +253,30 @@ def _decoder_attn_setup( Arguments: - * batch_size - * num_heads: Number of attention heads - * head_size: Head dimension - * block_size: Number of offsets per KV cache block - * scale: attention scale parameter - * max_q_seq_len: upper limit on query length for synthetic test vectors - * block_base_addr: self-attention block table base address + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + * block_base_addr: decoder self-attention block-table base address Returns: - - * query: "baseline" query; batch_size x padded_seq_len x num_heads x - head_size - * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x - head_size - * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads - x head_size - * prefill_packed_value: self-attn "prefill" value; number_of_tokens x - num_heads x head_size - * prefill_packed_ideal_output: self-attn "prefill" ideal output; - number_of_tokens x num_heads x head_size - * prefill_q_seq_lens: list of token counts for each *prefill query* (one - less than baseline query) - * prefill_kv_seq_lens: list of token counts for each self-attn *prefill - key/value* (should match prefill_q_seq_lens) - * decode_packed_query: "decode" query; number_of_tokens x num_heads x - head_size - * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x - head_size - * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads - x head_size - * decode_packed_ideal_output: self-attn "decode" ideal output; - number_of_tokens x num_heads x head_size - * decode_q_seq_lens: list of token counts for each *decode query* (should - be 1) - * decode_kv_seq_lens: list of token counts for each self-attn *decode - key/value* (should match decode_q_seq_lens) - * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x - head_size - * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens - x num_heads x head_size - * decode_block_tables: fake self-attn decode-phase block table - * decode_slot_mapping: fake self-attn decode-phase slot mapping - * prefill_slot_mapping: fake self-attn prefill-phase slot mapping - * prefill_block_tables: fake self-attn prefill-phase block table - * max_block_idx: highest block address in the self-attention block-table + * qkv: Unpacked (batch_size x padded_seq_len x num_heads x + head_size) query/key/value tensors + * Prefill-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data + structures appropriate for prefill phase. + * Decode-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data + structures appropriate for decode phase. + * max_block_idx: max physical address in decoder self-attention block-table + (intended to be used as the base address for the encoder/ + decoder cross-attention block-table, which is not + constructed in this function) ''' (num_heads, head_size, _, batch_size, block_size, max_q_seq_len, _, @@ -305,6 +286,8 @@ def _decoder_attn_setup( max_kv_seq_len = max_q_seq_len + # Build test tensors + qkv, \ prefill_qkv, \ decode_qkv = make_qkv(batch_size, @@ -315,6 +298,9 @@ def _decoder_attn_setup( attn_type=AttentionType.DECODER, device=CUDA_DEVICE) + # Compute correct answer using naive attention implementation + # with causal attention mask + causal_mask = make_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -326,6 +312,8 @@ def _decoder_attn_setup( q_seq_lens=qkv.q_seq_lens, kv_seq_lens=qkv.kv_seq_lens) + # Split out the prefill- & decode-phase ideal answers & pack them + prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): @@ -357,6 +345,9 @@ def _decoder_attn_setup( # (including both prefill & decode) # * Slot-mapping with entries for tokens that will be decoded in the # current decode iteration + # + # Note: the format described above is simply mirroring what ModelRunner + # produces prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) @@ -432,36 +423,38 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, Arguments: - * query: pre-existing "baseline" query; batch_size x padded_seq_len x - num_heads x head_size - * q_seq_lens: list of token-counts for each "baseline" query sequence - * prefill_q_seq_lens: list of token-counts for each "prefill" query - sequence - * batch_size - * num_heads: Number of attention heads - * head_size: Head dimension - * block_size: Number of offsets per KV cache block - * scale: attention scale parameter - * max_q_seq_len: upper limit on query length for synthetic test vectors - * max_kv_seq_len: upper limit on key/value length for synthetic test - vectors - * block_base_addr: cross-attention block table base address + * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x + num_heads x head_size) decoder self-attention inputs; + this function relies on the query and q_seq_lens + fields + * encoder_test_params: PhaseTestParameters data structure which was + used for encoder inference; KV cache field + is not used by this function + * prefill_decoder_phase_test_params: PhaseTestParameters data structure + used for prefill-phase decoder + self-attention; all fields + including KV cache required + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + * block_base_addr: decoder self-attention block-table base address Returns: - * packed_key: cross-attention key; number_of_tokens x num_heads x head_size - * packed_value: cross-attention value; number_of_tokens x num_heads x - head_size - * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x - num_heads x head_size - * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x - num_heads x head_size - * kv_seq_lens: list of token-counts for each key/value - * decode_block_tables: fake decode-phase block tables - * decode_slot_mapping: fake decode-phase slot mapping - * prefill_slot_mapping: fake prefill-phase slot mapping - * prefill_block_tables: fake prefill-phase block tables - * max_block_idx: highest block address in the cross-attention block-table + * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed + (number_of_tokens x num_heads x head_size) query/key/value tensors + along with (2) ideal attention output computed using a + naive implementation, and (3) memory-mapping data structures appropriate + for prefill phase. + * Decode-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed + (number_of_tokens x num_heads x head_size) query/key/value tensors + along with (2) ideal attention output computed using a + naive implementation, and (3) memory-mapping data structures appropriate + for decode phase. ''' (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, @@ -533,6 +526,9 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, # * Empty slot-mapping tensor (since K & V are fixed in size, # new decoded tokens are not KV-cached and require no slot- # mapping) + # + # Note: the format above is simply an extension of what ModelRunner + # produces for decoder-only models prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) @@ -569,30 +565,32 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: + attn_metadata: AttentionMetadata) \ + -> torch.Tensor: ''' Run encoder attention. - attn_metadata.attention_type is assigned attn_type in order to configure - the kernel invocation for either encoder attention + attn_metadata.attention_type is assigned AttentionType.ENCODER in order + to configure the kernel invocation for encoder attention - attn_type must be AttentionType.ENCODER + Requires attn_metadata.num_decode_tokens == 0 + (There is no encoder execution in the decode-phase) Arguments: * attn: Attention wrapper instance - * pckd_qkv: Packed query/key/value inputs + * encoder_test_params: encoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention - * attn_type: AttentionType.DECODER or AttentionType.ENCODER Returns: - * Attention.forward() applied to packed_{query,key,value}, kv_cache + * Attention.forward() applied to packed {query,key,value} and & attn_metadata ''' - assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 - attn_metadata.attention_type = attn_type + attn_metadata.attention_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, attn_metadata) @@ -600,32 +598,32 @@ def _run_encoder_attention_test(attn: Attention, def _run_decoder_self_attention_test(test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: + attn_metadata: AttentionMetadata) \ + -> torch.Tensor: ''' Run decoder self-attention test. - attn_metadata.attention_type is assigned attn_type in order to configure - the kernel invocation for decoder self-attention. - - attn_type must be AttentionType.DECODER + attn_metadata.attention_type is assigned AttentionType.DECODER + in order to configure the kernel invocation for decoder self-attention. Arguments: - * attn: Attention wrapper instance - * pckd_qkv: Packed query/key/value inputs - * kv_cache - * attn_metadata: attention metadata for encoder/decoder-self attention - * attn_type: AttentionType.DECODER or AttentionType.ENCODER + * test_rsrcs: TestResources instance; this function relies on the kv_cache + and attn (Attention wrapper instance) fields + * decoder_test_params: decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query/key/value fields + * attn_metadata: attention metadata for decoder-self attention + (contains KV cache memory-mapping) Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' - assert attn_type == AttentionType.DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache - attn_metadata.attention_type = attn_type + attn_metadata.attention_type = AttentionType.DECODER packed_qkv = decoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, attn_metadata) @@ -633,20 +631,34 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, def _run_encoder_decoder_cross_attention_test( test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - cross_test_params: PhaseTestParameters, + cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. + Via PhaseTestParameters data structures, consumes the same query utilized + for decoder self-attention, plus a key/value specific to cross-attention. + + if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv + is None, this reflects that in decode-phase cross attention there + is no growth in the key and value tensors. + attn_metadata.attention_type is assigned AttentionType.ENCODER_DECODER in order to configure the kernel invocation for encoder/decoder cross- attention. Arguments: - * attn: Attention wrapper instance - * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) - * kv_cache + * test_rsrcs: TestResources instance; this function relies on the kv_cache + and attn (Attention wrapper instance) fields + * decoder_test_params: decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query field + * cross_test_params: encoder/decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention Returns: @@ -682,19 +694,25 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: ''' - Encoder/decoder attention test: - - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order + End-to-end encoder/decoder test: + + * Construct fake test vectors for (1) encoder attention, + (2) decoder self-attention, and (3) encoder/decoder cross-attention + * Construct (1) attention metadata structure with self- and cross-attention + attributes for prefill-phase, and (2) an analogous attention metadata + structure but for decode-phase + * Test attention steps in the following order + * Encoder attention * Prefill self-attention * Prefill cross-attention * Decode self-attention * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid + * Besides being reflective of realistic use-cases, this order would + exacerbate any accidental overlap in the self-/cross-attention + block tables, which one hopes to avoid + + * Validate output correctness against ideal reference attention implementation @@ -705,11 +723,32 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, tensors. Self-attention K/Vs must have the same seq len as Q while cross-attention K/Vs are allowed to differ in seq len, as is often the case for cross-attention. + + This test utilizes PyTest monkey patching to force the attention backend + via an environment variable. + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). + + Note on metadata: there is a single attention metadata structure shared by + all prefill-phase attention operations (encoder, decoder, enc/dec cross), + and a single one shared by all decode-phase attention operations + (decoder & enc/dec cross.) This is intended to reflect the behavior + of ModelRunner, which constructs a single attention metadata structure for + each prefill or decode run. A realistic scenario would rely on the + attention backend to utilize the appropriate attention metadata fields + according to the value of attn_metadata.attention_type. Thus, this test is + organized so as to confirm that the backend-under-test can handle a + shared prefill attention metadata structure & a shared decode attention + metadata structure. ''' # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, block_size, max_dec_seq_len, max_enc_seq_len, 4096) @@ -717,24 +756,24 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # instance, KV cache init test_rsrcs = _make_test_resources(test_pt) - # Encoder attention setup - - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. + # Construct encoder attention test params (only used + # during prefill) enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - # Decoder self-attention setup + # Construct Decoder self-attention prefill-phase & decode-phase + # test params, including query/key/value tensors, decoder self-attention + # memory-mapping. cross_block_base_addr is the uppermost address in the + # decoder self-attention block-table, i.e. a base address which the + # encoder/decoder cross-attention block-table may build downward toward. dec_qkv, \ prephase_dec_test_params, \ decphase_dec_test_params, \ cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) - # Cross-attention setup + # Construct encoder/decoder cross-attention prefill-phase & decode-phase + # test params, including key/value tensors, cross-attention memory-mapping prephase_cross_test_params, \ decphase_cross_test_params, \ @@ -759,32 +798,29 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, device=CUDA_DEVICE) # PREFILL: encoder attention - # * Use prefill kernel - enc_packed_actual_output: torch.Tensor = \ + enc_pckd_act_out: torch.Tensor = \ _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, - prephase_attn_metadata, - attn_type=AttentionType.ENCODER) + prephase_attn_metadata) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) - # PREFILL: self-attention test + # PREFILL: decoder self-attention test - self_prefill_packed_actual_output: torch.Tensor = \ + prephase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( test_rsrcs, prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) + prephase_attn_metadata) - # - Prefill self-attention correct? + # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, - self_prefill_packed_actual_output) + prephase_dec_pckd_act_out) - # PREFILL: cross-attention test + # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( @@ -793,16 +829,12 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, prephase_cross_test_params, prephase_attn_metadata) - # - Prefill cross-attention correct? + # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata - # - Cross-attention KV context is equal in length to - # encoder input - # context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) - decphase_attn_metadata: AttentionMetadata = make_test_metadata( test_rsrcs.attn_backend, False, @@ -813,20 +845,19 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, default_attn_type=AttentionType.DECODER, device=CUDA_DEVICE) - # DECODE: self-attention test + # DECODE: decoder self-attention test decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( test_rsrcs, decphase_dec_test_params, - decphase_attn_metadata, - attn_type=AttentionType.DECODER) + decphase_attn_metadata) - # - Decode self-attention correct? + # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, decphase_dec_pckd_act_out) - # DECODE: cross-attention test + # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( @@ -835,7 +866,7 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, None, decphase_attn_metadata) - # - Decode cross-attention correct? + # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, decphase_cross_pckd_act_out) @@ -857,29 +888,25 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, max_enc_seq_len: int, monkeypatch) -> None: ''' - Encoder/decoder attention test: + Confirm encoder/decoder models will fail with NotImplemented + if chunked prefill is enabled. - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order - - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation + This test + 1. Executes a subset of test setup code from + test_e2e_enc_dec_attn() (everything up to encoder + execution); see test_e2e_enc_dec_attn() for more context + on how this code works. - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. + 2. Modifies the prefill-phase attention metadata structure + to imply a chunked-prefill scenario - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. + 3. Attempts to execute decoder self-attention + + 4. Asserts that that decoder self-attention fails & with the correct + error message + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). ''' # Force Attention wrapper backend @@ -894,12 +921,6 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, # Encoder attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) # Decoder self-attention setup @@ -934,37 +955,35 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, device=CUDA_DEVICE) # PREFILL: encoder attention - # * Use prefill kernel enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, - prephase_attn_metadata, - attn_type=AttentionType.ENCODER) + prephase_attn_metadata) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - # PREFILL: self-attention test - + # Meat of the test: require that chunked prefill triggers failure. + # # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- + # is configured for chunked prefill & decoder self- # attention. Required that this triggers a NotImplementedError. # - # We assume that decode_attn_metadata.num_decode_tokens > 1 + # We assume that decode_attn_metadata.num_prefill_tokens > 1 # already; the line below sets up a chunked prefill # metadata configuration where there is nominally a mix # of prefill and decode tokens. prephase_attn_metadata.num_decode_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_decoder_self_attention_test(test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) + # Doomed decoder self-attention + _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, + prephase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" + # or something to that effect assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @@ -985,29 +1004,25 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, max_enc_seq_len: int, monkeypatch) -> None: ''' - Encoder/decoder attention test: + Confirm encoder/decoder models will fail with NotImplemented + if prefix caching is enabled. - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order - - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation + This test + 1. Executes a subset of test setup code from + test_e2e_enc_dec_attn() (everything up to encoder + execution); see test_e2e_enc_dec_attn() for more context + on how this code works. - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. + 2. Modifies the prefill-phase attention metadata structure + to imply a prefix caching scenario - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. + 3. Attempts to execute decoder self-attention + + 4. Asserts that that decoder self-attention fails & with the correct + error message + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). ''' # Force Attention wrapper backend @@ -1022,12 +1037,6 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # Encoder attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) # Decoder self-attention setup @@ -1062,30 +1071,36 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, device=CUDA_DEVICE) # PREFILL: encoder attention - # * Use prefill kernel enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, - prephase_attn_metadata, - attn_type=AttentionType.ENCODER) + prephase_attn_metadata) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - # PREFILL: self-attention test - - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. + # Meat of the test: require that prefix caching triggers failure. # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. + # Set up a contrived scenario where the attention metadata + # is configured for prefix caching & decoder self- + # attention. Require that this triggers a NotImplementedError. with pytest.raises(NotImplementedError) as exc_info: - # Fake a non-empty block_tables + # In XFormers backend, the trigger for utilizing the + # prefix caching kernel is + # + # kv_cache is not None and prefill_meta.block_tables.numel() > 0 + # + # We can shallowly emulate a prefix caching scenario by passing + # in a non-None KV cache in test_rsrcs (already the + # case) and then tweaking the cached prefill attention metadata + # from the encoder run to have a non-empty (gibberish) block + # table. This block table will never actually be used, because + # its presence will signify to the backend a prefix-caching + # scenario and (given that the attention metadata structure + # is configured for an encoder/decoder scenario too) trigger + # a NotImplemented a exception. num_seqs = len( prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) @@ -1094,10 +1109,9 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, torch.randint( 0, 10, (num_seqs, 1)) - _run_decoder_self_attention_test(test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) + _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, + prephase_attn_metadata) - # "Encoder decoder models do not currently support chunked prefill" + # "Encoder decoder models do not currently support prefix caching" + # or something to that effect assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5e6d5cb2c9bd6..4f06088da909e 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -92,7 +92,7 @@ def ref_masked_attention(query: torch.Tensor, * key: batch_size x kv_padded_seq_len x num_heads x head_size * value: batch_size x kv_padded_seq_len x num_heads x head_size * scale: Attention scale factor - * Custom mask: custom attention mask; good place to inject a causal + * custom_mask: custom attention mask; good place to inject a causal attention mask * q_seq_lens: list of unpadded query seq_lens for each batch index * kv_seq_lens: list of unpadded key/value seq_lens for each batch index From 8e1daa16b4fdceb9a10f138888cec2303673e501 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 09:40:22 -0400 Subject: [PATCH 202/443] formatting --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ee76064ec6005..1d042b9fa6fe4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -677,10 +677,10 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: From a2a7ac5874083192c348853dba529c8eed4260ea Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 09:45:29 -0400 Subject: [PATCH 203/443] fixed attention selector test to use FLASH_ATTN string constant var in all relevant locations --- tests/kernels/test_attention_selector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 79e03c7478de0..d9000e58d1d43 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -42,32 +42,32 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported sliding window backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL def test_invalid_env(monkeypatch): From d3575687c7ac4966e64f4c92c4fef9389aaa9afa Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 10:41:54 -0400 Subject: [PATCH 204/443] additional commenting & added string constants for other backends --- tests/kernels/test_encoder_decoder_attn.py | 2 +- tests/kernels/utils.py | 168 +++++++++++++++------ 2 files changed, 126 insertions(+), 44 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 950b8244365a3..922e3ddb43dd8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -28,7 +28,7 @@ BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = ["XFORMERS"] +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] CUDA_DEVICE = "cuda:0" MAX_DEC_SEQ_LENS = [128] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 4f06088da909e..ea226878e4b33 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,12 +13,34 @@ from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, maybe_make_long_tensor, maybe_max) +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" class QKVInputs(NamedTuple): + ''' + Data structure for representing unpacked attention inputs, + query/key/value. + + Attributes: + + * {query,key,value}: unpacked (batch_size x padded_seq_len x + num_heads x head_size) attention inputs + * q_seq_lens: query sequence lengths list + * kv_seq_lens: shared key/value sequence lengths list + ''' + query: torch.Tensor key: torch.Tensor value: torch.Tensor @@ -27,11 +49,37 @@ class QKVInputs(NamedTuple): class QKVO(NamedTuple): + ''' + Data structure for representing unpacked attention inputs, + alongside unpacked known-correct attention output + + Attributes: + + * qkv: unpacked (batch_size x padded_seq_len x + num_heads x head_size) attention inputs + * ideal_output: unpacked (batch_size x padded_seq_len x + num_heads x head_size) known-correct attention output + ''' + qkv: QKVInputs ideal_output: torch.Tensor class PackedQKVInputs(NamedTuple): + ''' + Data structure for representing packed attention inputs + + Attributes: + + * {query,key,value}: packed (number_of_tokens x num_heads + x head_size) attention inputs + * q_seq_lens: list of query start locations within packed tensor + * kv_seq_lens: shared list of key/value start locations within + packed tensor + * q_seq_lens: query sequence lengths list + * kv_seq_lens: shared key/value sequence lengths list + ''' + query: torch.Tensor key: torch.Tensor value: torch.Tensor @@ -42,16 +90,51 @@ class PackedQKVInputs(NamedTuple): class PackedQKVO(NamedTuple): + ''' + Data structure for representing packed attention inputs, + alongside packed known-correct attention output + + Attributes: + + * packed_qkv: packed (number_of_tokens x num_heads + x head_size) attention inputs + * ideal_output: packed (number_of_tokens x num_heads + x head_size) known-correct attention output + ''' + packed_qkv: PackedQKVInputs ideal_output: torch.Tensor class KVMemoryMap(NamedTuple): + ''' + Data structure for encapsulating KV cache memory mapping. + + Attributes: + + * block_tables: KV cache block tables + * slot_mapping: mapping of sequence offset to physical address + ''' + block_tables: torch.Tensor slot_mapping: torch.Tensor class PhaseTestParameters(NamedTuple): + ''' + Data structure for encapsulating the test parameters + for a given test "phase" (prefill or decode phase) and attention + scenario (encoder, decoder-self, encoder/decoder-cross) + + Attributes: + + * packed_qkvo: packed (number_of_tokens x num_heads + x head_size) attention inputs & known-correct + output + * kv_mmap: KV cache memory mapping, specific to this test phase & + attention scenario + ''' + packed_qkvo: PackedQKVO kv_mmap: KVMemoryMap @@ -174,7 +257,9 @@ def make_qkv( Returns: - * QKVInputs structure + * Overall QKVInputs structure (containing full unpacked Q/K/V tensors) + * Prefill QKVInputs structure (containing all but the last sequence offset) + * Decode QKVInputs structure (containing all only the last sequence offset) ''' if force_max_len: @@ -245,18 +330,18 @@ def make_qkv( decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - return QKVInputs(query, + return QKVInputs(query, # Overall QKV inputs key, value, q_seq_lens, kv_seq_lens), \ - QKVInputs(prefill_query, + QKVInputs(prefill_query, # Prefill subset of QKV sequences prefill_key, prefill_value, prefill_q_seq_lens, prefill_kv_seq_lens), \ QKVInputs( - decode_query, + decode_query, # Decode subset of KV sequences decode_key, decode_value, decode_q_seq_lens, @@ -311,19 +396,14 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, Arguments: - * query: batch_size x padded_seq_len x num_heads x head_size - * key: batch_size x padded_seq_len x num_heads x head_size - * value: batch_size x padded_seq_len x num_heads x head_size - * q_seq_lens: list of token counts for each query - * kv_seq_lens: list of token counts for each key/value + * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size) + attention inputs + * device: CPU or CUDA device Returns - * packed_query: number_of_tokens x num_heads x head_size - * packed_key: number_of_tokens x num_heads x head_size - * packed_value: number_of_tokens x num_heads x head_size - * q_start_loc_list: start idx of each query in packed_query - * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} + * Packed (number_of_tokens x num_heads x head_size) QKV inputs + derived from unpacked inputs ''' if qkv.query is None: @@ -367,7 +447,7 @@ def make_backend(backend_name: str) -> AttentionBackend: * Backend instance ''' - if backend_name == "XFORMERS": + if backend_name == STR_XFORMERS_ATTN_VAL: return XFormersBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") @@ -383,20 +463,19 @@ def _make_metadata_tensors( Arguments: - * is_prompt: True -> Prefill, False -> Decode - * seq_lens: list of token-counts for each seq + * seq_lens: list of token-counts for each decoder input seq * context_lens: list of context length values for each seq + * encoder_seq_lens: list of token-counts for each encoder input seq * device: CPU or CUDA device Returns: - * seq_lens_tensor: seq_lens list, as tensor + * seq_lens_tensor: decoder seq_lens list, as tensor * context_lens_tensor: context_lens list, as tensor - * max_query_len: max(seq_lens) if is_seq, o/w 1 * max_context_len: max(context_lens) * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence - * query_start_loc: start idx of each query + * max_encoder_seq_len: encoder seq_lens list, as tensor ''' seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) context_lens_tensor = maybe_make_int_tensor(context_lens, device) @@ -614,12 +693,15 @@ def make_test_metadata( cross_test_params: Optional[PhaseTestParameters] = None ) -> AttentionMetadata: ''' - Construct fake attention metadata for a combined self-/cross-attention - scenario i.e. an encoder/decoder model. + Construct fake attention metadata for a given test phase + (prefill-phase or decode-phase). - is_encoder_only_test=True causes the default attention metadata attention - type to be AttentionType.ENCODER. False causes the default to - be AttentionType.DECODER. + encoder_test_params and cross_test_params arguments all encoder + attention and enc/dec cross-attention to use distinct metadata values + from decoder self-attention (decoder_test_params.) + + if encoder_test_params and cross_test_params are None, the attention + metadata will support decoder-only scenario. Assumptions: @@ -630,32 +712,29 @@ def make_test_metadata( * attn_backend: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode * seq_lens: list of token counts for each sequence - * context_lens: list of context lengths for each sequence - * block_tables: self-attention block tables - * slot_mapping: self-attention slot_mapping - * is_encoder_only_test: True if testing encoder; False if testing - decoder self-attention or encoder/decoder cross-attention. + * decoder_test_params: decoder self-attention test params; + this function requires + kv_mmap (memory mapping) field + * default_attn_type: value of attn_metadata.attention_type at + construction time * device: CPU or CUDA device - * encoder_seq_lens: list of token counts for each encoder sequence, if any - exist - * cross_block_tables: cross-attention block tables, if required - * cross_slot_mapping: cross-attention slot mapping, if required + * encoder_test_params: encoder attention test params; + this function requires encoder query + sequence lengths field. If None, + encoder query sequence lengths are + treated as None + * cross_test_params: enc/dec cross-attention test params; + this function requires kv_mmap field. + If None, KV cache memory map data + structures are treated as None Return: - * AttentionMetadata structure supporting self- and cross-attention + * AttentionMetadata structure ''' - # Extract - # * Decoder input sequence lengths (seq_lens) - # * Decoder self-attention slot mapping & block tables (kv_mmap) - #seq_lens = decoder_test_params.packed_qkvo.packed_qkv.q_seq_lens kv_mmap = decoder_test_params.kv_mmap - # is_prompt determines whether input tokens are treated - # as 100% prefill or 100% decode. In either case, - # the number of {prefills, decodes} and the number of - # {prefill, decode} tokens can be inferred from seq_lens num_prefills_or_decodes = len(seq_lens) # Prefill: operate on total num. of prompt @@ -684,6 +763,8 @@ def make_test_metadata( cross_kv_mmap = cross_test_params.kv_mmap if is_prompt: + # Prefill-phase scenario + num_prefills = num_prefills_or_decodes num_prefill_tokens = num_prefill_or_decode_tokens num_decode_tokens = 0 @@ -721,6 +802,7 @@ def make_test_metadata( cross_kv_mmap.block_tables) else: # not is_prompt + # Decode-phase scenario num_prefills = 0 num_prefill_tokens = 0 From ecea911f023be201b64f468829bd6e32bfc5a08f Mon Sep 17 00:00:00 2001 From: laishzh Date: Thu, 6 Jun 2024 23:29:31 +0800 Subject: [PATCH 205/443] feat: bert embedding --- examples/bert_demo.py | 23 +++++ examples/offline_inference_bert_embedding.py | 17 ++++ vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/bert_embedding.py | 88 ++++++++++++++++++++ 4 files changed, 129 insertions(+) create mode 100644 examples/bert_demo.py create mode 100644 examples/offline_inference_bert_embedding.py create mode 100644 vllm/model_executor/models/bert_embedding.py diff --git a/examples/bert_demo.py b/examples/bert_demo.py new file mode 100644 index 0000000000000..b49e4cbb94005 --- /dev/null +++ b/examples/bert_demo.py @@ -0,0 +1,23 @@ +import torch +from transformers import BertTokenizer, BertModel + +# 初始化 BERT tokenizer 和模型 +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +model = BertModel.from_pretrained('bert-base-uncased') + +# 输入句子 +sentence = "This is an example sentence." + +# 对输入句子进行编码 +inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) + +# 获取模型输出 +with torch.no_grad(): + outputs = model(**inputs) + +# 提取句子向量(这里我们使用 [CLS] token 的向量作为句子向量) +sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() + +print(sentence_vector) + +print(sentence_vector.shape) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py new file mode 100644 index 0000000000000..3a4921651b726 --- /dev/null +++ b/examples/offline_inference_bert_embedding.py @@ -0,0 +1,17 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +model = LLM(model="bert-base-uncased", enforce_eager=True) +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.encode(prompts) +# Print the outputs. +for output in outputs: + print(output.outputs.embedding) # list of 4096 floats diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4446914c67c8e..9307783beed60 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -63,6 +63,7 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), + "BertForMaskedLM": ("bert_embedding", "BertEmbeddingModel"), } _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py new file mode 100644 index 0000000000000..23e16edcc1a4f --- /dev/null +++ b/vllm/model_executor/models/bert_embedding.py @@ -0,0 +1,88 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import PoolerOutput + + +class BertEmbeddingModel(nn.Module): + # TODO(): change the doc + """A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From e35ad0210b32d8ce24a7db189eb633288ebeb24e Mon Sep 17 00:00:00 2001 From: laishzh Date: Sat, 8 Jun 2024 01:18:03 +0800 Subject: [PATCH 206/443] feat: implements BertEmbeddingModel --- examples/bert_demo.py | 8 +-- examples/offline_inference_bert_embedding.py | 8 +-- vllm/model_executor/models/bert_embedding.py | 60 ++++++-------------- 3 files changed, 21 insertions(+), 55 deletions(-) diff --git a/examples/bert_demo.py b/examples/bert_demo.py index b49e4cbb94005..58427379e6a33 100644 --- a/examples/bert_demo.py +++ b/examples/bert_demo.py @@ -1,23 +1,19 @@ import torch from transformers import BertTokenizer, BertModel -# 初始化 BERT tokenizer 和模型 +# Init BERT tokenizer and model tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') -# 输入句子 sentence = "This is an example sentence." -# 对输入句子进行编码 inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) -# 获取模型输出 with torch.no_grad(): outputs = model(**inputs) -# 提取句子向量(这里我们使用 [CLS] token 的向量作为句子向量) +# Get the sentence vector. Here use [CLS] token as the sentence vector. sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() print(sentence_vector) - print(sentence_vector.shape) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index 3a4921651b726..ae06a0763947a 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -2,10 +2,7 @@ # Sample prompts. prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + "This is an example sentence.", ] # Create an LLM. @@ -14,4 +11,5 @@ outputs = model.encode(prompts) # Print the outputs. for output in outputs: - print(output.outputs.embedding) # list of 4096 floats + print(output.outputs.embedding) # list of 768 floats + print(len(output.outputs.embedding)) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 23e16edcc1a4f..6e32ef75de9e4 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -3,23 +3,23 @@ import torch from torch import nn +from transformers import BertModel + from vllm.attention import AttentionMetadata from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput class BertEmbeddingModel(nn.Module): - # TODO(): change the doc - """A model that uses Llama with additional embedding functionalities. + """A model that uses Bert to provide embedding functionalities. - This class encapsulates the LlamaModel and provides an interface for + This class encapsulates the BertModel and provides an interface for embedding operations and customized pooling functions. Attributes: - model: An instance of LlamaModel used for forward operations. + model: An instance of BertModel used for forward operations. _pooler: An instance of Pooler used for pooling operations. """ @@ -28,7 +28,7 @@ def __init__( **kwargs, ) -> None: super().__init__() - self.model = LlamaModel(**kwargs) + self.model = BertModel(kwargs["config"], add_pooling_layer=False) self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) def forward( @@ -39,8 +39,13 @@ def forward( attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.model.forward(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds) + ts = self.model.forward(input_ids=input_ids.unsqueeze(0), + position_ids=positions, + past_key_values=None, + inputs_embeds=inputs_embeds, + return_dict=False, + ) + return ts[0].squeeze(0) def pooler( self, @@ -50,39 +55,6 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.model.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # TODO: load weights + for name, ts in weights: + print("Parameter: ", name) From 8fbf419747656de46fccd452865e2971d232ec60 Mon Sep 17 00:00:00 2001 From: laishzh Date: Sat, 8 Jun 2024 17:53:00 +0800 Subject: [PATCH 207/443] feat: reimplements the BertModel --- vllm/model_executor/models/bert_embedding.py | 172 ++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 6e32ef75de9e4..a98a21d2de981 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -2,11 +2,22 @@ import torch from torch import nn +from torch.nn import LayerNorm from transformers import BertModel +from transformers import BertConfig -from vllm.attention import AttentionMetadata +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput @@ -58,3 +69,162 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # TODO: load weights for name, ts in weights: print("Parameter: ", name) + +class BertEmbedding(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + + self.size = config.hidden_size + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type is supported") + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids: torch.Tensor, + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + seq_length = input_shape.size(1) + + # input embeddings + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # position embeddings + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + # token type embeddings + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEncoder(nn.Module): + pass + +class BertLayer(nn.Module): + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super(BertLayer, self).__init__() + self.attention = BertAttention(config=config, cache_config=cache_config, quant_config=quant_config) + TODO: + +class BertAttention(nn.Module): + + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.postition_embedding = config.position_embedding_type + self.max_position_embeddings = config.max_position_embeddings + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, bias = self.query_key_value(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + atten_output = self.atten(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(atten_output) + return output + + +class BertOutput(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertModel(nn.Module): + def __init__(self, **kwargs) -> None: + super(BertModel, self).__init__(**kwargs) + + self.embedding = BertEmbedding() + self.encoder = BertEncoder() + self.softmax = + + + def forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask): \ No newline at end of file From c2fecf3a7db233157fc4d3c41fe0198f5032f0eb Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 10 Jun 2024 00:40:38 +0800 Subject: [PATCH 208/443] feat: implements bert --- vllm/model_executor/models/bert_embedding.py | 233 ++++++++++++++----- 1 file changed, 177 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index a98a21d2de981..b100603ecca8c 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -2,23 +2,17 @@ import torch from torch import nn -from torch.nn import LayerNorm - -from transformers import BertModel from transformers import BertConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput @@ -39,7 +33,9 @@ def __init__( **kwargs, ) -> None: super().__init__() - self.model = BertModel(kwargs["config"], add_pooling_layer=False) + self.model = BertModel(config=kwargs["config"], + cache_config=kwargs.get("cache_config", None), + quant_config=kwargs.get("quant_config", None)) self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) def forward( @@ -50,13 +46,11 @@ def forward( attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - ts = self.model.forward(input_ids=input_ids.unsqueeze(0), - position_ids=positions, - past_key_values=None, - inputs_embeds=inputs_embeds, - return_dict=False, - ) - return ts[0].squeeze(0) + return self.model.forward(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + attn_metadata=attn_metadata) def pooler( self, @@ -70,27 +64,58 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, ts in weights: print("Parameter: ", name) + +class BertModel(nn.Module): + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embedding = BertEmbedding(config) + self.encoder = BertEncoder(config, cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.embedding(input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds) + output = self.encoder(hidden_states, kv_caches, attn_metadata) + return output + + class BertEmbedding(nn.Module): def __init__(self, config: BertConfig): super().__init__() - self.size = config.hidden_size - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - + self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": - raise ValueError("Only 'absolute' position_embedding_type is supported") - + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + def forward( self, input_ids: torch.Tensor, - token_type_ids: torch.Tensor, - position_ids: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if input_ids is not None: @@ -107,23 +132,54 @@ def forward( # position embeddings if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) position_embeddings = self.position_embeddings(position_ids) # token type embeddings if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + token_type_ids = torch.zeros(input_shape, dtype=torch.long, + device=device) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = inputs_embeds + token_type_embeddings + embeddings += position_embeddings embeddings = self.norm(embeddings) embeddings = self.dropout(embeddings) return embeddings - + class BertEncoder(nn.Module): - pass + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.layers = nn.ModuleList([ + BertLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + hidden_states, + kv_caches[i], + attn_metadata, + ) + return hidden_states + class BertLayer(nn.Module): def __init__( @@ -133,10 +189,58 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super(BertLayer, self).__init__() - self.attention = BertAttention(config=config, cache_config=cache_config, quant_config=quant_config) - TODO: + self.attention = BertAttention(config=config, + cache_config=cache_config, + quant_config=quant_config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + ): + self_attention_outputs = self.attention( + hidden_states, + kv_cache, + attn_metadata, + ) + + output = self.feed_forward(self_attention_outputs) + return output + + def feed_forward(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + class BertAttention(nn.Module): + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.self_attn = BertSelfAttention(config=config, + cache_config=cache_config, + quant_config=quant_config) + self.output = BertSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + self_outputs = self.self_attn(hidden_states, kv_cache, attn_metadata) + attn_output = self.output(self_outputs[0], hidden_states) + return attn_output + + +class BertSelfAttention(nn.Module): def __init__( self, @@ -147,19 +251,19 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() - + self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = self.total_num_heads self.head_dim = self.hidden_size // self.total_num_heads assert self.head_dim * self.total_num_heads == self.hidden_size - + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - + self.postition_embedding = config.position_embedding_type self.max_position_embeddings = config.max_position_embeddings - + self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -180,7 +284,7 @@ def __init__( bias=False, quant_config=quant_config, ) - + self.attn = Attention( num_heads=self.num_heads, head_size=self.head_dim, @@ -189,42 +293,59 @@ def __init__( cache_config=cache_config, quant_config=quant_config, ) - def forward( self, - positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - qkv, bias = self.query_key_value(hidden_states) + qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) atten_output = self.atten(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(atten_output) return output - + +class BertSelfOutput(nn.Module): + def __init__(self, config: BertConfig): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.norm(hidden_states + input_tensor) + return hidden_states + + +class BertIntermediate(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + class BertOutput(nn.Module): def __init__(self, config: BertConfig): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.norm(hidden_states + input_tensor) return hidden_states - -class BertModel(nn.Module): - def __init__(self, **kwargs) -> None: - super(BertModel, self).__init__(**kwargs) - - self.embedding = BertEmbedding() - self.encoder = BertEncoder() - self.softmax = - - - def forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask): \ No newline at end of file From efa7391773124b188bdd84ac367a5e1e7218c7f4 Mon Sep 17 00:00:00 2001 From: laishzh Date: Tue, 11 Jun 2024 00:51:50 +0800 Subject: [PATCH 209/443] fix bug --- examples/offline_inference_bert_embedding.py | 1 + vllm/model_executor/models/bert_embedding.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index ae06a0763947a..e66fb5876e5f4 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -3,6 +3,7 @@ # Sample prompts. prompts = [ "This is an example sentence.", + "Another sentence.", ] # Create an LLM. diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index b100603ecca8c..8a427778ad2bb 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -124,7 +124,7 @@ def forward( else: input_shape = inputs_embeds.size()[:-1] device = inputs_embeds.device - seq_length = input_shape.size(1) + seq_length = input_shape[0] # input embeddings if inputs_embeds is None: @@ -302,7 +302,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - atten_output = self.atten(q, k, v, kv_cache, attn_metadata) + atten_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(atten_output) return output From 0e9e4313af703f1d73dd8ba5a0e9153a66c4efb6 Mon Sep 17 00:00:00 2001 From: laishzh Date: Tue, 11 Jun 2024 16:31:57 +0800 Subject: [PATCH 210/443] feat: implements --- examples/bert_demo.py | 24 +++++++++----- vllm/model_executor/models/bert_embedding.py | 35 ++++++++++++++++++-- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/examples/bert_demo.py b/examples/bert_demo.py index 58427379e6a33..91a44330d0d41 100644 --- a/examples/bert_demo.py +++ b/examples/bert_demo.py @@ -5,15 +5,23 @@ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') -sentence = "This is an example sentence." +print("Model: ") +print(model) -inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) +print("Parameter: ") +for name, weight in model.named_parameters(): + print(f"Name: {name}".ljust(60) + f"Weight: {weight.shape}".ljust(40) + f"dtype: {weight.dtype}".ljust(20)) -with torch.no_grad(): - outputs = model(**inputs) -# Get the sentence vector. Here use [CLS] token as the sentence vector. -sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() +# sentence = "This is an example sentence." -print(sentence_vector) -print(sentence_vector.shape) +# inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) + +# with torch.no_grad(): +# outputs = model(**inputs) + +# # Get the sentence vector. Here use [CLS] token as the sentence vector. +# sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() + +# print(sentence_vector) +# print(sentence_vector.shape) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 8a427778ad2bb..3a3ea3de7f86d 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput @@ -60,9 +61,37 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # TODO: load weights - for name, ts in weights: - print("Parameter: ", name) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v"), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + # TODO: check + ## Skip loading extra bias for GPTQ models. + #if name.endswith(".bias") and name not in params_dict: + # continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # TODO: check + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) class BertModel(nn.Module): From 015c4bc04b84114d72468e65c4c534f78c6ac3f3 Mon Sep 17 00:00:00 2001 From: laishzh Date: Tue, 11 Jun 2024 22:09:41 +0800 Subject: [PATCH 211/443] feat: add base_model_prefix --- examples/bert_demo.py | 2 +- vllm/model_executor/models/bert_embedding.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/bert_demo.py b/examples/bert_demo.py index 91a44330d0d41..ea778c32bd244 100644 --- a/examples/bert_demo.py +++ b/examples/bert_demo.py @@ -2,7 +2,7 @@ from transformers import BertTokenizer, BertModel # Init BERT tokenizer and model -tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') print("Model: ") diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 3a3ea3de7f86d..487189ef34d70 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -34,6 +34,7 @@ def __init__( **kwargs, ) -> None: super().__init__() + self.base_model_prefix = "bert" self.model = BertModel(config=kwargs["config"], cache_config=kwargs.get("cache_config", None), quant_config=kwargs.get("quant_config", None)) @@ -68,7 +69,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "value", "v"), ] params_dict = dict(self.model.named_parameters()) + _prefix = f"{self.base_model_prefix}." for name, loaded_weight in weights: + name = name[len(_prefix) :] if name.startswith(_prefix) else name + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue @@ -88,6 +92,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # # Skip loading extra bias for GPTQ models. # if name.endswith(".bias") and name not in params_dict: # continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From df2f75c5c239b73896b38dda29c9b17a38637af2 Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 12 Jun 2024 01:15:37 +0800 Subject: [PATCH 212/443] feat: fix Bert Embedding --- examples/bert_demo.py | 15 ++--- examples/offline_inference_bert_embedding.py | 7 ++- vllm/model_executor/models/bert_embedding.py | 64 ++++++++++++-------- 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/examples/bert_demo.py b/examples/bert_demo.py index ea778c32bd244..dbac2a34c4b28 100644 --- a/examples/bert_demo.py +++ b/examples/bert_demo.py @@ -2,7 +2,7 @@ from transformers import BertTokenizer, BertModel # Init BERT tokenizer and model -# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained('bert-base-uncased') print("Model: ") @@ -14,14 +14,15 @@ # sentence = "This is an example sentence." +sentence = "Another sentence." -# inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) +inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) -# with torch.no_grad(): -# outputs = model(**inputs) +with torch.no_grad(): + outputs = model(**inputs) -# # Get the sentence vector. Here use [CLS] token as the sentence vector. -# sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() +# Get the sentence vector. Here use [CLS] token as the sentence vector. +sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() -# print(sentence_vector) +print(sentence_vector) # print(sentence_vector.shape) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index e66fb5876e5f4..ab2e52cfa1346 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -2,12 +2,15 @@ # Sample prompts. prompts = [ - "This is an example sentence.", - "Another sentence.", + # "This is an example sentence.", + # "Another sentence.", + "今天天气怎么样?好一些了吧?" ] # Create an LLM. model = LLM(model="bert-base-uncased", enforce_eager=True) +# model = LLM(model="google-bert/bert-base-multilingual-uncased", enforce_eager=True) +# model = LLM(model="google-bert/bert-large-uncased", enforce_eager=True) # Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 487189ef34d70..fe11353d991e8 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -62,6 +62,13 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + def _fix_key(key): + if "beta" in key: + return key.replace("beta", "bias") + if "gamma" in key: + return key.replace("gamma", "weight") + return key stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), @@ -71,27 +78,34 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.model.named_parameters()) _prefix = f"{self.base_model_prefix}." for name, loaded_weight in weights: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + name = name[len(_prefix) :] if name.startswith(_prefix) else name + name = _fix_key(name) + + # use Pooler instead. + if name.startswith('pooler.'): + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - # TODO: check - ## Skip loading extra bias for GPTQ models. - #if name.endswith(".bias") and name not in params_dict: - # continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - # TODO: check - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", @@ -107,7 +121,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.embedding = BertEmbedding(config) + self.embeddings = BertEmbedding(config) self.encoder = BertEncoder(config, cache_config, quant_config) def forward( @@ -118,9 +132,9 @@ def forward( attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embedding(input_ids=input_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds) + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds) output = self.encoder(hidden_states, kv_caches, attn_metadata) return output @@ -137,7 +151,7 @@ def __init__(self, config: BertConfig): config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.position_embedding_type = config.position_embedding_type @@ -179,7 +193,7 @@ def forward( embeddings = inputs_embeds + token_type_embeddings embeddings += position_embeddings - embeddings = self.norm(embeddings) + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -192,7 +206,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.layers = nn.ModuleList([ + self.layer = nn.ModuleList([ BertLayer(config=config, cache_config=cache_config, quant_config=quant_config) @@ -205,8 +219,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - for i in range(len(self.layers)): - layer = self.layers[i] + for i in range(len(self.layer)): + layer = self.layer[i] hidden_states = layer( hidden_states, kv_caches[i], @@ -258,7 +272,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.self_attn = BertSelfAttention(config=config, + self.self = BertSelfAttention(config=config, cache_config=cache_config, quant_config=quant_config) self.output = BertSelfOutput(config) @@ -269,7 +283,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - self_outputs = self.self_attn(hidden_states, kv_cache, attn_metadata) + self_outputs = self.self(hidden_states, kv_cache, attn_metadata) attn_output = self.output(self_outputs[0], hidden_states) return attn_output @@ -303,7 +317,7 @@ def __init__( self.scaling = self.head_dim**-0.5 - self.query_key_value = QKVParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size=self.hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, @@ -334,7 +348,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - qkv, _ = self.query_key_value(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) atten_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(atten_output) @@ -345,13 +359,13 @@ class BertSelfOutput(nn.Module): def __init__(self, config: BertConfig): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.norm(hidden_states + input_tensor) + hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states @@ -371,7 +385,7 @@ class BertOutput(nn.Module): def __init__(self, config: BertConfig): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -381,5 +395,5 @@ def forward( ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.norm(hidden_states + input_tensor) + hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states From da231fd5a945c4f8aedfcec6cef63342ef2b1abc Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 12 Jun 2024 13:11:12 +0800 Subject: [PATCH 213/443] feat: remote o_proj --- vllm/model_executor/models/bert_embedding.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index fe11353d991e8..c14b4125aafc8 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -326,13 +326,6 @@ def __init__( quant_config=quant_config ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - ) - self.attn = Attention( num_heads=self.num_heads, head_size=self.head_dim, @@ -350,8 +343,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - atten_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(atten_output) + output = self.attn(q, k, v, kv_cache, attn_metadata) return output From cf11d480c43cb415fafab069d6f5f44ddb67c60c Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 12 Jun 2024 16:44:20 +0800 Subject: [PATCH 214/443] feat: set qkv_proj.bias = True --- vllm/model_executor/models/bert_embedding.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index c14b4125aafc8..9c26839455311 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -48,11 +48,11 @@ def forward( attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.model.forward(input_ids=input_ids, - position_ids=positions, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds, - attn_metadata=attn_metadata) + return self.model(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + attn_metadata=attn_metadata) def pooler( self, @@ -273,8 +273,8 @@ def __init__( ): super().__init__() self.self = BertSelfAttention(config=config, - cache_config=cache_config, - quant_config=quant_config) + cache_config=cache_config, + quant_config=quant_config) self.output = BertSelfOutput(config) def forward( @@ -322,7 +322,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, - bias=False, + bias=True, quant_config=quant_config ) From 25f7783cdb7f7315b643e3e9fc16b64bad7b6ac5 Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 12 Jun 2024 17:33:08 +0800 Subject: [PATCH 215/443] chore: remove files --- examples/bert_demo.py | 28 -------------------- examples/offline_inference_bert_embedding.py | 8 ++---- vllm/model_executor/models/bert_embedding.py | 6 +---- 3 files changed, 3 insertions(+), 39 deletions(-) delete mode 100644 examples/bert_demo.py diff --git a/examples/bert_demo.py b/examples/bert_demo.py deleted file mode 100644 index dbac2a34c4b28..0000000000000 --- a/examples/bert_demo.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from transformers import BertTokenizer, BertModel - -# Init BERT tokenizer and model -tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') -model = BertModel.from_pretrained('bert-base-uncased') - -print("Model: ") -print(model) - -print("Parameter: ") -for name, weight in model.named_parameters(): - print(f"Name: {name}".ljust(60) + f"Weight: {weight.shape}".ljust(40) + f"dtype: {weight.dtype}".ljust(20)) - - -# sentence = "This is an example sentence." -sentence = "Another sentence." - -inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding='max_length', max_length=128) - -with torch.no_grad(): - outputs = model(**inputs) - -# Get the sentence vector. Here use [CLS] token as the sentence vector. -sentence_vector = outputs.last_hidden_state[:, 0, :].squeeze() - -print(sentence_vector) -# print(sentence_vector.shape) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index ab2e52cfa1346..30982316e55b0 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -2,17 +2,13 @@ # Sample prompts. prompts = [ - # "This is an example sentence.", - # "Another sentence.", - "今天天气怎么样?好一些了吧?" + "This is an example sentence.", ] # Create an LLM. model = LLM(model="bert-base-uncased", enforce_eager=True) -# model = LLM(model="google-bert/bert-base-multilingual-uncased", enforce_eager=True) -# model = LLM(model="google-bert/bert-large-uncased", enforce_eager=True) -# Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) + # Print the outputs. for output in outputs: print(output.outputs.embedding) # list of 768 floats diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 9c26839455311..9bb87aebf7a2a 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -302,6 +302,7 @@ def __init__( self.total_num_heads = config.num_attention_heads assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = self.total_num_heads self.head_dim = self.hidden_size // self.total_num_heads @@ -309,14 +310,9 @@ def __init__( self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.postition_embedding = config.position_embedding_type - self.max_position_embeddings = config.max_position_embeddings - self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.qkv_proj = QKVParallelLinear( hidden_size=self.hidden_size, head_size=self.head_dim, From b81fb8a015330eca25a5a3a6d97ef2384777550a Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 12 Jun 2024 18:00:56 +0800 Subject: [PATCH 216/443] chore: fix code style issue --- vllm/model_executor/models/bert_embedding.py | 32 ++++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 9bb87aebf7a2a..93c4068e8b1c2 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -8,8 +8,7 @@ from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -69,6 +68,7 @@ def _fix_key(key): if "gamma" in key: return key.replace("gamma", "weight") return key + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), @@ -82,7 +82,7 @@ def _fix_key(key): if name.startswith('cls.'): continue - name = name[len(_prefix) :] if name.startswith(_prefix) else name + name = name[len(_prefix):] if name.startswith(_prefix) else name name = _fix_key(name) # use Pooler instead. @@ -114,6 +114,7 @@ def _fix_key(key): class BertModel(nn.Module): + def __init__( self, config: BertConfig, @@ -140,6 +141,7 @@ def forward( class BertEmbedding(nn.Module): + def __init__(self, config: BertConfig): super().__init__() self.size = config.hidden_size @@ -151,7 +153,8 @@ def __init__(self, config: BertConfig): config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.position_embedding_type = config.position_embedding_type @@ -180,14 +183,16 @@ def forward( # position embeddings if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, + position_ids = torch.arange(seq_length, + dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) position_embeddings = self.position_embeddings(position_ids) # token type embeddings if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, device=device) token_type_embeddings = self.token_type_embeddings(token_type_ids) @@ -199,6 +204,7 @@ def forward( class BertEncoder(nn.Module): + def __init__( self, config: BertConfig, @@ -230,6 +236,7 @@ def forward( class BertLayer(nn.Module): + def __init__( self, config: BertConfig, @@ -265,6 +272,7 @@ def feed_forward(self, attention_output): class BertAttention(nn.Module): + def __init__( self, config: BertConfig, @@ -319,8 +327,7 @@ def __init__( total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, bias=True, - quant_config=quant_config - ) + quant_config=quant_config) self.attn = Attention( num_heads=self.num_heads, @@ -344,10 +351,12 @@ def forward( class BertSelfOutput(nn.Module): + def __init__(self, config: BertConfig): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): @@ -358,6 +367,7 @@ def forward(self, hidden_states, input_tensor): class BertIntermediate(nn.Module): + def __init__(self, config: BertConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -370,10 +380,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertOutput(nn.Module): + def __init__(self, config: BertConfig): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( From 389bb7aff431e46ba718916711a7729ee2c5453d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 12 Jun 2024 12:05:23 -0400 Subject: [PATCH 217/443] wip setting up encoder/decoder model runner --- vllm/worker/enc_dec_model_runner.py | 954 ++++++++++++++++++++++++++++ 1 file changed, 954 insertions(+) create mode 100644 vllm/worker/enc_dec_model_runner.py diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py new file mode 100644 index 0000000000000..c870718f09c96 --- /dev/null +++ b/vllm/worker/enc_dec_model_runner.py @@ -0,0 +1,954 @@ +import gc +import time +import warnings +from collections import defaultdict +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.communication_op import graph_capture +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, + is_pin_memory_available, make_tensor_with_pad) +from vllm.worker.model_runner import (_PAD_SLOT_ID, + LORA_WARMUP_RANK, + _BATCH_SIZE_ALIGNMENT, + _BATCH_SIZES_TO_CAPTURE, + _NUM_WARMUP_ITERS, + ModelInput, + ModelRunner) + +logger = init_logger(__name__) + +class EncoderDecoderModelInput(ModelInput): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[AttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + lora_mapping: Optional[LoRAMapping] + lora_requests: Set[LoRARequest] + multi_modal_kwargs: Dict[str, torch.Tensor] + slot_mapping: torch.Tensor + num_prefill_tokens: int + num_decode_tokens: int + num_prefills: int + + @classmethod + def empty(cls, device): + return ModelInput( + input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_mapping=None, + lora_requests=set(), + multi_modal_kwargs={}, + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, + ) + + +class EncoderDecoderModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + self.vision_language_config = vision_language_config + + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool: Optional[Tuple[ + int, int]] = None # Set during graph capture. + # When using CUDA graph, the input block tables must be padded to + # max_seq_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Create processor for multi-modal data + if self.vision_language_config is not None: + self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ + .create_input_processor( + self.model_config, + self.vision_language_config, + ) + else: + self.multi_modal_input_processor = None + + # Lazy initialization + self.model: nn.Module # Set after load_model + # Set if the backend is flashinfer. + self.flashinfer_workspace_buffer: torch.Tensor + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + + def load_model(self) -> None: + with CudaMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert hasattr(self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, ( + "Model does not support LoRA") + assert hasattr( + self.model, + "embedding_modules"), "Model does not have embedding_modules" + assert hasattr(self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=self.model.config. + max_position_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) + self.model.load_kv_cache_scales( + self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__) + else: + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) + + def get_max_block_per_batch(self) -> int: + block_size = self.block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size + + def _prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + + seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + block_tables: List[List[int]] = [] + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list) + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 + + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len: List[int] = [] + + if len(seq_group_metadata_list) == 0: + return ModelInput.empty(self.device) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window + self.block_size - + 1) // self.block_size + block_aligned_sliding_window = \ + sliding_window_blocks * self.block_size + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.seq_data[seq_id] + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + curr_sliding_window_blocks = sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) + + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(sliding_seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * query_len + lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None else 1)) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + if is_prompt: + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = ( + decode_only and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.max_seq_len_to_capture) + if use_captured_graph: + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + for _ in range(graph_batch_size - batch_size): + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) + seq_lens.append(1) + block_tables.append([]) + lora_index_mapping.append(0) + batch_size = graph_batch_size + num_decode_tokens = batch_size + + if use_captured_graph: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.graph_block_tables[:batch_size] + for i, block_table in enumerate(block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=self.device) + else: + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + if self.attn_backend.get_name() == "flashinfer": + if not hasattr(self, "flashinfer_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, + dtype=torch.int, + device=self.device) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, dtype=torch.int, device=self.device) + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, + self.model_config.dtype) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + workspace_buffer=self.flashinfer_workspace_buffer, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=16, + seq_start_loc=seq_start_loc, + data_type=kv_cache_dtype) + else: + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(self.device) + for k, v in multi_modal_kwargs_list.items() + } + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + ) + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + # Prepare input tensors. + ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + lora_mapping, + lora_requests, + multi_modal_kwargs, + slot_mapping, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_kwargs": multi_modal_kwargs, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + } + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + attn_metadata = None + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + + return (input_tokens, input_positions, attn_metadata, + sampling_metadata, lora_requests, lora_mapping, + multi_modal_kwargs) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata, + lora_requests, lora_mapping, multi_modal_kwargs + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **multi_modal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + return output + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] + if self.lora_config: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision encoding, which needs + # to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + model_config = self.model_config + vlm_config = self.vision_language_config + + if vlm_config: + max_num_seqs = min( + max_num_seqs, + int(max_num_batched_tokens / vlm_config.image_feature_size)) + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + + if vlm_config is None: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + else: + seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ + .dummy_data_for_profiling(seq_len, model_config, vlm_config) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + multi_modal_data=dummy_multi_modal_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() + return + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_loras() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_loras() + + @torch.inference_mode() + def capture_model(self, kv_caches: List[torch.Tensor]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode. " + "You can also reduce the `max_num_seqs` as needed " + "to decrease memory usage.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + + # Prepare buffer for outputs. These will be reused for all batch sizes. + # It will be filled after the first graph capture. + hidden_states: Optional[torch.Tensor] = None + + graph_batch_size = _get_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + + with graph_capture() as graph_capture_context: + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + # Create dummy attn_metadata. + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner(self.model) + hidden_states = graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + hidden_states[:batch_size] + if hidden_states is not None else None, + kv_caches, + attn_metadata, + memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. + logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() \ No newline at end of file From ee4a275f8a3e17aa66edd3b7d71637cd1c1cdc4e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 12 Jun 2024 16:05:22 -0400 Subject: [PATCH 218/443] wip --- vllm/worker/enc_dec_model_runner.py | 292 +++------------------------- vllm/worker/model_runner.py | 28 ++- 2 files changed, 55 insertions(+), 265 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index c870718f09c96..259a95c2ba794 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -25,16 +25,26 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_runner import (_PAD_SLOT_ID, - LORA_WARMUP_RANK, +from vllm.worker.model_runner import (_PAD_SLOT_ID, LORA_WARMUP_RANK, _BATCH_SIZE_ALIGNMENT, _BATCH_SIZES_TO_CAPTURE, - _NUM_WARMUP_ITERS, - ModelInput, - ModelRunner) + _NUM_WARMUP_ITERS, ModelInput, + ModelRunner, _is_block_tables_empty, + _get_graph_batch_size, CUDAGraphRunner, + _is_encoder_decoder_model) +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) logger = init_logger(__name__) +# Error message if EncoderDecoderModelRunner is used with +# a non-encoder/decoder model (i.e. decoder-only) +STR_ENCDECMR_ENCODER_DECODER_REQUIRED = "Only encoder/decoder models may be executed using EncoderDecoderModelRunner" + +# Error message if EncoderDecoderModelRunner is used with +# CUDAGraph +STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = "Currently CUDAGraph is not supported for encoder/decoder models" + class EncoderDecoderModelInput(ModelInput): input_tokens: torch.Tensor input_positions: torch.Tensor @@ -67,7 +77,7 @@ def empty(cls, device): ) -class EncoderDecoderModelRunner: +class EncoderDecoderModelRunner(ModelRunner): def __init__( self, @@ -82,145 +92,18 @@ def __init__( is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.is_driver_worker = is_driver_worker - self.vision_language_config = vision_language_config - - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool: Optional[Tuple[ - int, int]] = None # Set during graph capture. - # When using CUDA graph, the input block tables must be padded to - # max_seq_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) - self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), - self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_sliding_window(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - ) - - # Create processor for multi-modal data - if self.vision_language_config is not None: - self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ - .create_input_processor( - self.model_config, - self.vision_language_config, - ) - else: - self.multi_modal_input_processor = None - - # Lazy initialization - self.model: nn.Module # Set after load_model - # Set if the backend is flashinfer. - self.flashinfer_workspace_buffer: torch.Tensor - # Set after load_model. - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - - def load_model(self) -> None: - with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + super().__init__(model_config, parallel_config, scheduler_config, + device_config, cache_config, load_config, lora_config, + kv_cache_dtype, is_driver_worker, + vision_language_config) - self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) + if not self._is_encoder_decoder_model(): + # Fail if EncoderDecoderModelRunner is constructed for a + # non-encoder/decoder model i.e. decoder-only + raise AttributeError(STR_ENCDECMR_ENCODER_DECODER_REQUIRED) - if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") - assert hasattr( - self.model, - "embedding_modules"), "Model does not have embedding_modules" - assert hasattr(self.model, "embedding_padding_modules" - ), "Model does not have embedding_padding_modules" - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=self.model.config. - max_position_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently only ROCm accepts kv-cache scaling factors - # via quantization_param_path and this will be deprecated - # in the future. - if self.model_config.quantization_param_path is not None: - if callable(getattr(self.model, "load_kv_cache_scales", None)): - warnings.warn( - "Loading kv cache scaling factor from JSON is " - "deprecated and will be removed. Please include " - "kv cache scaling factors in the model checkpoint.", - FutureWarning, - stacklevel=2) - self.model.load_kv_cache_scales( - self.model_config.quantization_param_path) - logger.info("Loaded KV cache scaling factors from %s", - self.model_config.quantization_param_path) - else: - raise RuntimeError( - "Using FP8 KV cache and scaling factors provided but " - "model %s does not support loading scaling factors.", - self.model.__class__) - else: - logger.warning( - "Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - from vllm.model_executor.model_loader.loader import ShardedStateLoader - ShardedStateLoader.save_model( - self.model, - path, - pattern=pattern, - max_size=max_size, - ) - - def get_max_block_per_batch(self) -> int: - block_size = self.block_size - return (self.max_seq_len_to_capture + block_size - 1) // block_size + if self.scheduler_config.chunked_prefill_enabled: + raise NotImplementedError() def _prepare_model_input( self, @@ -830,125 +713,6 @@ def profile_run(self) -> None: torch.cuda.synchronize() return - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_loras() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_loras(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_lora(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_loras() - @torch.inference_mode() - def capture_model(self, kv_caches: List[torch.Tensor]) -> None: - """Cuda graph capture a model. - - Note that CUDA graph's performance gain is negligible if number - of batched tokens are larger than 200. And since CUDA graph - requires fixed sized tensors, supporting large/variable batch - size requires high GPU memory overhead. Thus, vLLM only captures - decoding requests. Mixed batch (chunked prefill + decoding) or - prefill requests are not captured. - - Since it is used for decoding-only, it assumes there's only 1 token - per sequence in the batch. - """ - assert not self.model_config.enforce_eager - logger.info("Capturing the model for CUDA graphs. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI.") - logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " - "If you are running out of memory, consider decreasing " - "`gpu_memory_utilization` or enforcing eager mode. " - "You can also reduce the `max_num_seqs` as needed " - "to decrease memory usage.") - start_time = time.perf_counter() - - # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() - slot_mapping.fill_(_PAD_SLOT_ID) - seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() - block_tables = torch.from_numpy(self.graph_block_tables).cuda() - - # Prepare buffer for outputs. These will be reused for all batch sizes. - # It will be filled after the first graph capture. - hidden_states: Optional[torch.Tensor] = None - - graph_batch_size = _get_graph_batch_size( - self.scheduler_config.max_num_seqs) - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - - with graph_capture() as graph_capture_context: - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. - for batch_size in reversed(batch_size_capture_list): - # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) - - if self.lora_config: - lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) - self.set_active_loras(set(), lora_mapping) - - graph_runner = CUDAGraphRunner(self.model) - hidden_states = graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - hidden_states[:batch_size] - if hidden_states is not None else None, - kv_caches, - attn_metadata, - memory_pool=self.graph_memory_pool, - stream=graph_capture_context.stream, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner - - end_time = time.perf_counter() - elapsed_time = end_time - start_time - # This usually takes < 10 seconds. - logger.info("Graph capturing finished in %.0f secs.", elapsed_time) - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() \ No newline at end of file + def capture_model(self, _: List[torch.Tensor]) -> None: + raise NotImplementedError(STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED) \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 99b12293a0244..d732bcca01b98 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -38,6 +38,8 @@ ] _NUM_WARMUP_ITERS = 2 +# Error message if ModelRunner is used with an encoder/decoder model +STR_MR_ENCODER_DECODER_UNSUPPORTED = "Encoder/decoder model must be executed using EncoderDecoderModelRunner" class ModelInput(NamedTuple): input_tokens: torch.Tensor @@ -142,6 +144,14 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + if self._am_not_child() and self._is_encoder_decoder_model(): + # Fail if ModelRunner is constructed for an + # encoder/decoder model + # + # Bypass this check if this constructor is being invoked by a child + # class (i.e. type(self) is not ModelRunner) + raise AttributeError(STR_MR_ENCODER_DECODER_UNSUPPORTED) + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -860,6 +870,22 @@ def list_loras(self) -> Set[int]: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_loras() + def _is_encoder_decoder_model(self): + ''' + Identify encoder/decoder models using the is_encoder_decoder + field of the HF config, if this field is present; otherwise + return False. + ''' + return False if self.model_config is None else \ + getattr(self.model_config.hf_config, "is_encoder_decoder", False) + + def _am_not_child(self): + ''' + True if self is an instance of the ModelRunner + base class, False otherwise (i.e. child class) + ''' + return type(self) is not ModelRunner + @torch.inference_mode() def capture_model(self, kv_caches: List[torch.Tensor]) -> None: """Cuda graph capture a model. @@ -1084,4 +1110,4 @@ def _is_block_tables_empty(block_tables: Union[None, Dict]): if isinstance(block_tables, dict) and all( value is None for value in block_tables.values()): return True - return False + return False \ No newline at end of file From 4f85b6e1dc7f208878d74ebe4b01d8409d09d251 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 12 Jun 2024 16:06:06 -0400 Subject: [PATCH 219/443] wip --- vllm/worker/enc_dec_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 259a95c2ba794..979af9c3a7416 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -34,6 +34,7 @@ _is_encoder_decoder_model) from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) +from vllm.attention.backends.utils import STR logger = init_logger(__name__) From 97cad0b96dafe5cabe004bc4a119f00d7d9db5a6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 12 Jun 2024 16:59:48 -0400 Subject: [PATCH 220/443] encoder-only unit test passes --- tests/kernels/test_encoder_decoder_attn.py | 55 +++++++++++++++++++++- tests/kernels/utils.py | 41 +++++++++++----- vllm/attention/backends/utils.py | 7 +++ vllm/attention/backends/xformers.py | 27 ++++++----- 4 files changed, 103 insertions(+), 27 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 922e3ddb43dd8..c8c72c21f5cc3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -680,6 +680,59 @@ def _run_encoder_decoder_cross_attention_test( return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_encoder_only(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Construct encoder attention test params (only used + # during prefill) + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + None, + decoder_test_params=None, + encoder_test_params=enc_test_params, + cross_test_params=None, + default_attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + + enc_pckd_act_out: torch.Tensor = \ + _run_encoder_attention_test( + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata) + + # - Is encoder attention result correct? + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1114,4 +1167,4 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # "Encoder decoder models do not currently support prefix caching" # or something to that effect - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ea226878e4b33..3dfc70a46eeb4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -733,15 +733,28 @@ def make_test_metadata( * AttentionMetadata structure ''' - kv_mmap = decoder_test_params.kv_mmap - - num_prefills_or_decodes = len(seq_lens) - - # Prefill: operate on total num. of prompt - # tokens - # Decode: operate on one token per seq - num_prefill_or_decode_tokens = \ - sum(seq_lens) if is_prompt else len(seq_lens) + # Decoder self-attention memory mapping + # decoder_test_params is None signals encoder-only + # scenario, so kv_mmap is None + kv_mmap = None if decoder_test_params is None else \ + decoder_test_params.kv_mmap + + # This function constructs metadata assuming no chunked prefill, + # i.e. 100% prefill tokens or 100% decode tokens + # + # - If is_prompt, num_prefills_or_decodes is the number of prefills + # and num_prefill_or_decode_tokens is the number of prefill tokens + # - If not is_prompt, num_prefills_or_decodes is the number of decodes + # and num_prefill_or_decode_tokens is the number of decode tokens + # + # seq_lens is None signals encoder-only + # scenario, in which case num_prefills_or_decodes and + # num_prefill_or_decode_tokens are unused + num_prefills_or_decodes = None if seq_lens is None else \ + len(seq_lens) + + num_prefill_or_decode_tokens = None if seq_lens is None else \ + (sum(seq_lens) if is_prompt else len(seq_lens)) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -750,14 +763,14 @@ def make_test_metadata( if encoder_test_params is None: encoder_seq_lens = None else: - # Encoder/decoder models only: + # Encoder/decoder or encoder-only models only: # * Extract encoder input sequence lengths encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens if cross_test_params is None: cross_kv_mmap = None else: - # Encoder/decoder models only: + # Encoder/decoder or encoder-only models only: # * Extract *cross-attention* slot_mapping and block table # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap @@ -782,7 +795,8 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=kv_mmap.slot_mapping, + slot_mapping=None if kv_mmap is None else \ + kv_mmap.slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -790,7 +804,8 @@ def make_test_metadata( max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, - block_tables=kv_mmap.block_tables, + block_tables=None if kv_mmap is None else \ + kv_mmap.block_tables, use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 45a6f4af37d13..c34b35cf77c8a 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -76,6 +76,13 @@ def assert_no_encdec_chunked_prefill_assuming_supported_backend( # scenarios. return + if attn_metadata.num_prefill_tokens is None or \ + attn_metadata.num_decode_tokens is None: + # The metadata which would be + # indicative of chunked prefill is unset; + # this may be the case for encoder-only models + return + if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 32ea44f74d106..32fed02d3adc9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -211,7 +211,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) #assert self.context_lens_tensor is not None - assert self.block_tables is not None query_start_loc = None if self.query_start_loc is None \ else self.query_start_loc[:self.num_prefills + 1] @@ -220,19 +219,20 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None - if self.seq_lens is None \ + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[:self.num_prefill_tokens], + seq_lens=None if self.seq_lens is None \ else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else \ + self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, context_lens_tensor=None if self.context_lens_tensor is None else \ - self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + self.context_lens_tensor[:self.num_prefills], + block_tables=None if self.block_tables is None else \ + self.block_tables[:self.num_prefills], use_cuda_graph=False, _attn_type=self.attention_type, # Begin encoder & cross attn fields below... @@ -252,7 +252,6 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: self._cached_decode_metadata.attention_type = \ self.attention_type return self._cached_decode_metadata - assert self.block_tables is not None assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) @@ -260,15 +259,17 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[self.num_prefill_tokens:], seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], + self.seq_lens_tensor[self.num_prefills:], max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - block_tables=self.block_tables[self.num_prefills:], + block_tables=None if self.block_tables is None else \ + self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, _attn_type=self. - _attn_type, # Begin encoder & cross attn fields below... + _attn_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, From 29fa1af416f1a6e18abe76ad05a68170f960cc6e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Jun 2024 07:58:14 -0400 Subject: [PATCH 221/443] refactoring --- vllm/attention/backends/utils.py | 42 ++++++++++++++++------ vllm/attention/backends/xformers.py | 56 ++++++++++++++++++----------- 2 files changed, 68 insertions(+), 30 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index c34b35cf77c8a..b0c05fca285f8 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -7,33 +7,33 @@ STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Chunked prefill is not currently " + \ -"supported with encoder/decoder models." +"supported with encoder/decoder or encoder-only models." STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ "ROCm/HIP is not currently supported" + \ -"with encoder/decoder models." +"with encoder/decoder or encoder-only models." STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ "Currently only the XFormers backend " + \ - "supports encoder/decoder models." + "supports encoder/decoder and encoder-only models." STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING = \ "Prefix caching is not currently supported " + \ -"with encoder/decoder models" +"with encoder/decoder or encoder-only models" # Check for unsupported encoder/decoder scenarios -def is_encoder_decoder_metadata_assuming_supported_backend( +def is_encoder_metadata_assuming_supported_backend( attn_metadata) -> bool: ''' - Return True of the attn_metadata argument contains + Return True if the attn_metadata argument contains the metadata fields that would be required for encoder attention, which proves that the user is - not running a purely decoder-only model. + not running a purely decoder-only model Assumes attn_metadata is derived from a backend that supports - encoder/decoder models. + encoder-only or encoder/decoder models. Arguments: @@ -43,10 +43,32 @@ def is_encoder_decoder_metadata_assuming_supported_backend( Returns: - * True if attn_metadata is configured for an encoder/decoder model + * True if attn_metadata is configured for an encoder-only model ''' return attn_metadata.is_all_encoder_attn_metadata_set +def is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata) -> bool: + ''' + Return True if the attn_metadata argument contains + the metadata fields that would be required for + encoder/decoder attention, which proves that the user is + running an encoder/decoder model + + Assumes attn_metadata is derived from a backend that supports + encoder-only or encoder/decoder models. + + Arguments: + + * attn_metadata: instance of supported backend metadata. + Type annotation omitted to avoid circular import. + + + Returns: + + * True if attn_metadata is configured for an encoder/decoder model + ''' + return attn_metadata.is_all_encoder_decoder_attn_metadata_set def fail_encoder_decoder_prefix_caching() -> None: ''' @@ -70,7 +92,7 @@ def assert_no_encdec_chunked_prefill_assuming_supported_backend( * attn_metadata: Attention metadata structure ''' - if not is_encoder_decoder_metadata_assuming_supported_backend( + if not is_encoder_metadata_assuming_supported_backend( attn_metadata): # Only care about encoder/decoder # scenarios. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 32fed02d3adc9..63fe10dc2502f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -14,7 +14,7 @@ from vllm.attention.backends.utils import ( assert_no_encdec_chunked_prefill_assuming_supported_backend, fail_encoder_decoder_prefix_caching, - is_encoder_decoder_metadata_assuming_supported_backend) + is_encoder_metadata_assuming_supported_backend) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -202,6 +202,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return None if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure self._cached_prefill_metadata.attention_type = \ self.attention_type return self._cached_prefill_metadata @@ -210,29 +212,35 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: (self.encoder_seq_lens is not None) assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - #assert self.context_lens_tensor is not None + # Compute some attn_metadata fields which default to None query_start_loc = None if self.query_start_loc is None \ else self.query_start_loc[:self.num_prefills + 1] - + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[:self.num_prefill_tokens] + seq_lens=None if self.seq_lens is None \ + else self.seq_lens[:self.num_prefills] + seq_lens_tensor=None if self.seq_lens_tensor is None else \ + self.seq_lens_tensor[:self.num_prefills] + context_lens_tensor=None if self.context_lens_tensor is None else \ + self.context_lens_tensor[:self.num_prefills] + block_tables=None if self.block_tables is None else \ + self.block_tables[:self.num_prefills] + + # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None \ - else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else \ - self.seq_lens_tensor[:self.num_prefills], + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, - context_lens_tensor=None if self.context_lens_tensor is None else \ - self.context_lens_tensor[:self.num_prefills], - block_tables=None if self.block_tables is None else \ - self.block_tables[:self.num_prefills], + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, _attn_type=self.attention_type, # Begin encoder & cross attn fields below... @@ -249,24 +257,32 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return None if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure self._cached_decode_metadata.attention_type = \ self.attention_type return self._cached_decode_metadata assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) + # Compute some attn_metadata fields which default to None + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[self.num_prefill_tokens:] + seq_lens_tensor=None if self.seq_lens_tensor is None else \ + self.seq_lens_tensor[self.num_prefills:] + block_tables=None if self.block_tables is None else \ + self.block_tables[self.num_prefills:] + + # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[self.num_prefill_tokens:], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], + slot_mapping=slot_mapping, + seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - block_tables=None if self.block_tables is None else \ - self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, _attn_type=self. _attn_type, # Begin encoder & cross attn fields below... @@ -560,7 +576,7 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - if is_encoder_decoder_metadata_assuming_supported_backend( + if is_encoder_metadata_assuming_supported_backend( attn_metadata): fail_encoder_decoder_prefix_caching() From 71098ce57ee100f7a43f9caa0221be631ca260f0 Mon Sep 17 00:00:00 2001 From: laishzh Date: Thu, 13 Jun 2024 21:11:42 +0800 Subject: [PATCH 222/443] feat: refactor the load_weights(), remove Dropout, use VocabParallelEmbedding and rename to layernorm. --- vllm/model_executor/models/bert_embedding.py | 114 ++++++++++--------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 93c4068e8b1c2..37f7be827817a 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -12,6 +12,8 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import PoolerOutput @@ -28,6 +30,27 @@ class BertEmbeddingModel(nn.Module): _pooler: An instance of Pooler used for pooling operations. """ + stacked_params_mapping = { + "query": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "key": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "value": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + def __init__( self, **kwargs, @@ -62,56 +85,49 @@ def pooler( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - def _fix_key(key): - if "beta" in key: - return key.replace("beta", "bias") - if "gamma" in key: - return key.replace("gamma", "weight") - return key - - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "query", "q"), - ("qkv_proj", "key", "k"), - ("qkv_proj", "value", "v"), - ] params_dict = dict(self.model.named_parameters()) - _prefix = f"{self.base_model_prefix}." + for name, loaded_weight in weights: + name = self._rename_key(name) + # Skip the specific downstream task weight. if name.startswith('cls.'): continue - - name = name[len(_prefix):] if name.startswith(_prefix) else name - name = _fix_key(name) - # use Pooler instead. if name.startswith('pooler.'): continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = param.weight_loader + name, shard_id = self._rename_stacked_param(name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: weight_loader(param, loaded_weight, shard_id) - break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) weight_loader(param, loaded_weight) + def _rekey_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + class BertModel(nn.Module): @@ -145,17 +161,14 @@ class BertEmbedding(nn.Module): def __init__(self, config: BertConfig): super().__init__() self.size = config.hidden_size - - self.word_embeddings = nn.Embedding(config.vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id) + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": @@ -198,8 +211,7 @@ def forward( embeddings = inputs_embeds + token_type_embeddings embeddings += position_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) + embeddings = self.layernorm(embeddings) return embeddings @@ -355,14 +367,12 @@ class BertSelfOutput(nn.Module): def __init__(self, config: BertConfig): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layernorm(hidden_states + input_tensor) return hidden_states @@ -384,9 +394,8 @@ class BertOutput(nn.Module): def __init__(self, config: BertConfig): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( self, @@ -394,6 +403,5 @@ def forward( input_tensor: torch.Tensor, ) -> torch.Tensor: hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = self.layernorm(hidden_states + input_tensor) return hidden_states From 861556d5448ab12c1cb0d055c348cc3596e793fb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Jun 2024 10:44:54 -0400 Subject: [PATCH 223/443] wip --- vllm/worker/model_runner.py | 414 +++++++++++++++++++++--------------- 1 file changed, 238 insertions(+), 176 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 476e9ba3bb463..fa9cddd446408 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -237,6 +237,214 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size + def _prepare_seq_model_input(self, + is_prompt:bool, + num_prefills:int, + num_prefill_tokens:int, + num_decode_tokens:int, + decode_only:bool, + block_tables:List[List[int]], + seq_lens:List[int], + slot_mapping:List[int], + context_lens:List[int], + query_lens:List[int], + input_tokens:List[int], + input_positions:List[int], + prefill_seq_lens:List[int], + decode_seq_lens:List[int], + seq_group_metadata:SequenceGroupMetadata, + seq_data:SequenceData, + computed_block_nums:Optional[List[int]], + block_table:Optional[List[int]], + paged_kv_indices:Optional[List[int]], + paged_kv_indptr:Optional[List[int]], + paged_kv_last_page_len:Optional[List[int]], + sliding_window_blocks:int = 0, + block_aligned_sliding_window:int = 0, + lora_index_mapping: List[int] = [], + lora_prompt_mapping: List[int] = [], + lora_requests: Set[LoRARequest] = set(), + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list)) -> tuple[bool,int,int,int]: + + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + curr_sliding_window_blocks = sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + + if self.attn_backend.get_name() != "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + #block_table = seq_group_metadata.block_tables[seq_id] + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + #block_table = seq_group_metadata.block_tables[seq_id] + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) + + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(sliding_seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * query_len + lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None else 1)) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + return decode_only,num_prefills,num_prefill_tokens,num_decode_tokens + + # Compute the slot mapping. + #block_table = seq_group_metadata.block_tables[seq_id] + + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + if is_prompt: + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + return decode_only,num_prefills,num_prefill_tokens,num_decode_tokens + def _prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -273,6 +481,9 @@ def _prepare_model_input( num_prefill_tokens = 0 num_decode_tokens = 0 + sliding_window_blocks = 0 + block_aligned_sliding_window = 0 + # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. @@ -315,183 +526,34 @@ def _prepare_model_input( "now.") seq_data = seq_group_metadata.seq_data[seq_id] - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 - - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len - - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - curr_sliding_window_blocks = sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_blocks += 1 - else: - sliding_seq_len = min(seq_len, self.sliding_window) - sliding_context_len = sliding_seq_len - 1 - - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): - if seq_group_metadata.block_tables is not None: - # chunked prefill or decode - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] - if self.attn_backend.get_name() == "flashinfer": - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + - len(block_table)) - last_page_len = seq_data.get_len( - ) % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) - else: - # Only happens when memory profiling runs. - block_table = [] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - block_tables.append(block_table) - - seq_lens.append(sliding_seq_len) - context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - query_lens.append(query_len) - input_tokens.extend(tokens) - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if is_prompt: - assert len(seq_ids) == 1 - num_prefills += 1 - num_prefill_tokens += len(tokens) - decode_only = False - prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - num_decode_tokens += query_len - decode_seq_lens.append(sliding_seq_len) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data is not None: - # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) - for k, v in mm_kwargs.items(): - multi_modal_kwargs_list[k].append(v) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + decode_only,num_prefills,num_prefill_tokens,num_decode_tokens=self._prepare_seq_model_input(is_prompt, + decode_only, + num_prefills, + num_prefill_tokens, + num_decode_tokens, + block_tables, + seq_lens, + slot_mapping, + context_lens, + query_lens, + input_tokens, + input_positions, + prefill_seq_lens, + decode_seq_lens, + seq_group_metadata, + seq_data, + computed_block_nums, + block_table, + paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_len, + sliding_window_blocks, + block_aligned_sliding_window, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_kwargs_list) batch_size = len(input_tokens) max_query_len = max(query_lens) From f111f7177d789d06eb3282c06706cafd6744cb35 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Jun 2024 10:48:34 -0400 Subject: [PATCH 224/443] refactor; still debugging --- vllm/worker/enc_dec_model_runner.py | 19 ++-- vllm/worker/model_runner.py | 136 ++++++++++++---------------- 2 files changed, 67 insertions(+), 88 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 979af9c3a7416..17bf68160f98f 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -25,27 +25,26 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_runner import (_PAD_SLOT_ID, LORA_WARMUP_RANK, - _BATCH_SIZE_ALIGNMENT, - _BATCH_SIZES_TO_CAPTURE, - _NUM_WARMUP_ITERS, ModelInput, - ModelRunner, _is_block_tables_empty, - _get_graph_batch_size, CUDAGraphRunner, - _is_encoder_decoder_model) +from vllm.worker.model_runner import ( + _PAD_SLOT_ID, LORA_WARMUP_RANK, _BATCH_SIZE_ALIGNMENT, + _BATCH_SIZES_TO_CAPTURE, _NUM_WARMUP_ITERS, ModelInput, ModelRunner, + _is_block_tables_empty, _get_graph_batch_size, CUDAGraphRunner, + _is_encoder_decoder_model) from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) from vllm.attention.backends.utils import STR logger = init_logger(__name__) -# Error message if EncoderDecoderModelRunner is used with +# Error message if EncoderDecoderModelRunner is used with # a non-encoder/decoder model (i.e. decoder-only) STR_ENCDECMR_ENCODER_DECODER_REQUIRED = "Only encoder/decoder models may be executed using EncoderDecoderModelRunner" -# Error message if EncoderDecoderModelRunner is used with +# Error message if EncoderDecoderModelRunner is used with # CUDAGraph STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = "Currently CUDAGraph is not supported for encoder/decoder models" + class EncoderDecoderModelInput(ModelInput): input_tokens: torch.Tensor input_positions: torch.Tensor @@ -716,4 +715,4 @@ def profile_run(self) -> None: @torch.inference_mode() def capture_model(self, _: List[torch.Tensor]) -> None: - raise NotImplementedError(STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED) \ No newline at end of file + raise NotImplementedError(STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 53e8ec9961e55..f5b7919d0cd39 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,6 +42,7 @@ # Error message if ModelRunner is used with an encoder/decoder model STR_MR_ENCODER_DECODER_UNSUPPORTED = "Encoder/decoder model must be executed using EncoderDecoderModelRunner" + class ModelInput(NamedTuple): input_tokens: torch.Tensor input_positions: torch.Tensor @@ -247,35 +248,37 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_seq_model_input(self, - is_prompt:bool, - num_prefills:int, - num_prefill_tokens:int, - num_decode_tokens:int, - decode_only:bool, - block_tables:List[List[int]], - seq_lens:List[int], - slot_mapping:List[int], - context_lens:List[int], - query_lens:List[int], - input_tokens:List[int], - input_positions:List[int], - prefill_seq_lens:List[int], - decode_seq_lens:List[int], - seq_group_metadata:SequenceGroupMetadata, - seq_data:SequenceData, - computed_block_nums:Optional[List[int]], - block_table:Optional[List[int]], - paged_kv_indices:Optional[List[int]], - paged_kv_indptr:Optional[List[int]], - paged_kv_last_page_len:Optional[List[int]], - sliding_window_blocks:int = 0, - block_aligned_sliding_window:int = 0, - lora_index_mapping: List[int] = [], - lora_prompt_mapping: List[int] = [], - lora_requests: Set[LoRARequest] = set(), - multi_modal_kwargs_list: Dict[str, - List[torch.Tensor]] = defaultdict(list)) -> tuple[bool,int,int,int]: + def _prepare_seq_model_input( + self, + is_prompt: bool, + num_prefills: int, + num_prefill_tokens: int, + num_decode_tokens: int, + decode_only: bool, + block_tables: List[List[int]], + seq_lens: List[int], + slot_mapping: List[int], + context_lens: List[int], + query_lens: List[int], + input_tokens: List[int], + input_positions: List[int], + prefill_seq_lens: List[int], + decode_seq_lens: List[int], + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, + computed_block_nums: Optional[List[int]], + block_table: Optional[List[int]], + paged_kv_indices: Optional[List[int]], + paged_kv_indptr: Optional[List[int]], + paged_kv_last_page_len: Optional[List[int]], + sliding_window_blocks: int = 0, + block_aligned_sliding_window: int = 0, + lora_index_mapping: List[int] = [], + lora_prompt_mapping: List[int] = [], + lora_requests: Set[LoRARequest] = set(), + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list) + ) -> tuple[bool, int, int, int]: if is_prompt: context_len = seq_data.get_num_computed_tokens() @@ -285,9 +288,8 @@ def _prepare_seq_model_input(self, # TODO(sang): Fix it. context_len = seq_data.get_len() - 1 - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) + seq_len = min(seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) if is_prompt: tokens = seq_data.get_token_ids()[context_len:seq_len] else: @@ -299,8 +301,7 @@ def _prepare_seq_model_input(self, # Prefix is not supported with sliding_window prefix_cache_hit = (computed_block_nums is not None and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) + and self.sliding_window is None and is_prompt) # These are seq_len/context_len capped to the sliding window. # They are passed to decode kernel. @@ -318,8 +319,8 @@ def _prepare_seq_model_input(self, if self.scheduler_config.use_v2_block_manager: # number of elements in last block suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, block_aligned_sliding_window + suff_len) + sliding_seq_len = min(seq_len, + block_aligned_sliding_window + suff_len) if suff_len > 0: curr_sliding_window_blocks += 1 else: @@ -347,20 +348,17 @@ def _prepare_seq_model_input(self, # provide a unified interface for different backends. #block_table = seq_group_metadata.block_tables[seq_id] block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): + elif (self.scheduler_config.chunked_prefill_enabled or not is_prompt): if seq_group_metadata.block_tables is not None: # chunked prefill or decode #block_table = seq_group_metadata.block_tables[seq_id] if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] + block_table = block_table[-curr_sliding_window_blocks:] if self.attn_backend.get_name() == "flashinfer": paged_kv_indices.extend(block_table) paged_kv_indptr.append(paged_kv_indptr[-1] + - len(block_table)) - last_page_len = seq_data.get_len( - ) % self.block_size + len(block_table)) + last_page_len = seq_data.get_len() % self.block_size if last_page_len == 0: last_page_len = self.block_size paged_kv_last_page_len.append(last_page_len) @@ -399,16 +397,15 @@ def _prepare_seq_model_input(self, lora_prompt_mapping.extend( [lora_id] * (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) + and seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) mm_data = seq_group_metadata.multi_modal_data if mm_data is not None: # Process multi-modal data if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") + raise ValueError("Multi-modal inputs are only supported by " + "vision language models.") mm_kwargs = self.multi_modal_input_processor(mm_data) for k, v in mm_kwargs.items(): @@ -420,7 +417,7 @@ def _prepare_seq_model_input(self, # slot mapping. # In embeddings, the block tables are {seq_id: None}. slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - return decode_only,num_prefills,num_prefill_tokens,num_decode_tokens + return decode_only, num_prefills, num_prefill_tokens, num_decode_tokens # Compute the slot mapping. #block_table = seq_group_metadata.block_tables[seq_id] @@ -453,7 +450,7 @@ def _prepare_seq_model_input(self, slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - return decode_only,num_prefills,num_prefill_tokens,num_decode_tokens + return decode_only, num_prefills, num_prefill_tokens, num_decode_tokens def _prepare_model_input( self, @@ -537,33 +534,16 @@ def _prepare_model_input( seq_data = seq_group_metadata.seq_data[seq_id] block_table = seq_group_metadata.block_tables[seq_id] - decode_only,num_prefills,num_prefill_tokens,num_decode_tokens=self._prepare_seq_model_input(is_prompt, - decode_only, - num_prefills, - num_prefill_tokens, - num_decode_tokens, - block_tables, - seq_lens, - slot_mapping, - context_lens, - query_lens, - input_tokens, - input_positions, - prefill_seq_lens, - decode_seq_lens, - seq_group_metadata, - seq_data, - computed_block_nums, - block_table, - paged_kv_indices, - paged_kv_indptr, - paged_kv_last_page_len, - sliding_window_blocks, - block_aligned_sliding_window, - lora_index_mapping, - lora_prompt_mapping, - lora_requests, - multi_modal_kwargs_list) + decode_only, num_prefills, num_prefill_tokens, num_decode_tokens = self._prepare_seq_model_input( + is_prompt, decode_only, num_prefills, num_prefill_tokens, + num_decode_tokens, block_tables, seq_lens, slot_mapping, + context_lens, query_lens, input_tokens, input_positions, + prefill_seq_lens, decode_seq_lens, seq_group_metadata, + seq_data, computed_block_nums, block_table, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + sliding_window_blocks, block_aligned_sliding_window, + lora_index_mapping, lora_prompt_mapping, lora_requests, + multi_modal_kwargs_list) batch_size = len(input_tokens) max_query_len = max(query_lens) @@ -1183,4 +1163,4 @@ def _is_block_tables_empty(block_tables: Union[None, Dict]): if isinstance(block_tables, dict) and all( value is None for value in block_tables.values()): return True - return False \ No newline at end of file + return False From ba363b7253c6447adec2d1c64b316dc2ef4d4973 Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 14 Jun 2024 17:42:39 +0800 Subject: [PATCH 225/443] fix: _rename_key --- vllm/model_executor/models/bert_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 37f7be827817a..745949a1309e5 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -109,7 +109,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: weight_loader(param, loaded_weight) - def _rekey_key(self, key: str): + def _rename_key(self, key: str): prefix = f"{self.base_model_prefix}." key = key[len(prefix):] if key.startswith(prefix) else key From 3ea38598efd623b60dfd82a5db40a7f1cbee9f85 Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 14 Jun 2024 18:46:14 +0800 Subject: [PATCH 226/443] fix: rename stacked_param first --- vllm/model_executor/models/bert_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 745949a1309e5..5ec64e5384091 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -89,6 +89,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) # Skip the specific downstream task weight. if name.startswith('cls.'): @@ -100,7 +101,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - name, shard_id = self._rename_stacked_param(name) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From df109cfb0f5cfb428b069dfa8176802d08fa52de Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 13:36:37 -0400 Subject: [PATCH 227/443] wip --- vllm/worker/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f5b7919d0cd39..162a641336f3a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -251,10 +251,10 @@ def get_max_block_per_batch(self) -> int: def _prepare_seq_model_input( self, is_prompt: bool, + decode_only: bool, num_prefills: int, num_prefill_tokens: int, num_decode_tokens: int, - decode_only: bool, block_tables: List[List[int]], seq_lens: List[int], slot_mapping: List[int], @@ -267,7 +267,7 @@ def _prepare_seq_model_input( seq_group_metadata: SequenceGroupMetadata, seq_data: SequenceData, computed_block_nums: Optional[List[int]], - block_table: Optional[List[int]], + original_block_table: Optional[List[int]], paged_kv_indices: Optional[List[int]], paged_kv_indptr: Optional[List[int]], paged_kv_last_page_len: Optional[List[int]], From 9cb8ee6858befd620be10a5b91f801fa15d561a6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 13:56:19 -0400 Subject: [PATCH 228/443] In ModelRunner, refactored the computation of model inputs for a given sequence into its own function --- vllm/worker/model_runner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 162a641336f3a..914ec29987e21 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -341,17 +341,19 @@ def _prepare_seq_model_input( "Prefix caching is not supported with sliding window" sliding_context_len = context_len - if self.attn_backend.get_name() != "flash-attn": + if self.attn_backend.get_name() == "flash-attn": # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. # TODO(woosuk): This is a temporary fix. We should # provide a unified interface for different backends. - #block_table = seq_group_metadata.block_tables[seq_id] + block_table = original_block_table + else: block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled or not is_prompt): if seq_group_metadata.block_tables is not None: # chunked prefill or decode - #block_table = seq_group_metadata.block_tables[seq_id] + block_table = original_block_table if curr_sliding_window_blocks is not None: block_table = block_table[-curr_sliding_window_blocks:] if self.attn_backend.get_name() == "flashinfer": @@ -420,7 +422,7 @@ def _prepare_seq_model_input( return decode_only, num_prefills, num_prefill_tokens, num_decode_tokens # Compute the slot mapping. - #block_table = seq_group_metadata.block_tables[seq_id] + block_table = original_block_table # Mask the [0, start_idx) tokens of the prompt with # _PAD_SLOT_ID, where start_idx is max(0, seq_len - From c9f11ff89f49915524284126fe2a80c0d356af05 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 15:40:45 -0400 Subject: [PATCH 229/443] comment fixes --- tests/kernels/utils.py | 6 +++--- vllm/attention/backends/xformers.py | 15 +++------------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ea226878e4b33..71cd93cd3df92 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -73,9 +73,9 @@ class PackedQKVInputs(NamedTuple): * {query,key,value}: packed (number_of_tokens x num_heads x head_size) attention inputs - * q_seq_lens: list of query start locations within packed tensor - * kv_seq_lens: shared list of key/value start locations within - packed tensor + * q_start_loc_list: list of query start locations within packed tensor + * kv_start_loc_list: shared list of key/value start locations within + packed tensor * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list ''' diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 18c324598bf9f..3e10b55fb6fcd 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -142,15 +142,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # _attn_type: AttentionType = AttentionType.DECODER - # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value - # sequence length (usually encoder sequence length) in the cross-attention - # computation. None if this is self-attention + # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # The maximum cross-sequence-length, if cross_seq_lens is specified. - # Note that for cross-attention there is no difference in key/value - # sequence length between prefill and decode + # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None # Cross-attention memory-mapping data structures: slot mapping @@ -200,15 +196,11 @@ def attention_type(self, atype: AttentionType) -> None: "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - self._attn_type = AttentionType.ENCODER_DECODER elif atype == AttentionType.ENCODER: assert self.is_all_encoder_attn_metadata_set, \ "Must set self.encoder_seq_lens in order to perform cross-attention" - self._attn_type = AttentionType.ENCODER - else: - # AttentionType.{ENCODER,DECODER} - self._attn_type = atype + self._attn_type = atype @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -224,7 +216,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: (self.encoder_seq_lens is not None) assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - #assert self.context_lens_tensor is not None assert self.block_tables is not None query_start_loc = None if self.query_start_loc is None \ From 196e671a2fa583c7b0beb65236b61c0df54bff4b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 15:45:03 -0400 Subject: [PATCH 230/443] removed unnecessary tests --- tests/kernels/test_encoder_decoder_attn.py | 254 +-------------------- tests/kernels/utils.py | 6 +- 2 files changed, 8 insertions(+), 252 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 922e3ddb43dd8..8b93bb8fe6782 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -149,9 +149,7 @@ class that Attention will automatically select when it is constructed. def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: - (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt - scale = test_rsrcs.scale ''' Set up test vectors & data structures for encoder attention test. @@ -181,6 +179,10 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ implementation, and (3) KVCache field set to None ''' + (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt + + scale = test_rsrcs.scale + max_kv_seq_len = max_q_seq_len # Make test tensors @@ -868,250 +870,4 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) - - -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, - monkeypatch) -> None: - ''' - Confirm encoder/decoder models will fail with NotImplemented - if chunked prefill is enabled. - - This test - 1. Executes a subset of test setup code from - test_e2e_enc_dec_attn() (everything up to encoder - execution); see test_e2e_enc_dec_attn() for more context - on how this code works. - - 2. Modifies the prefill-phase attention metadata structure - to imply a chunked-prefill scenario - - 3. Attempts to execute decoder self-attention - - 4. Asserts that that decoder self-attention fails & with the correct - error message - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if is_hip(). - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_enc_seq_len, 4096) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) - - # Encoder attention setup - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Decoder self-attention setup - - dec_qkv, \ - prephase_dec_test_params, \ - _, \ - cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) - - # Cross-attention setup - - prephase_cross_test_params, \ - _, \ - = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr = \ - cross_block_base_addr) - - # Shared prefill metadata structure - - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - default_attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_packed_actual_output: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - - # Meat of the test: require that chunked prefill triggers failure. - # - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & decoder self- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_prefill_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - prephase_attn_metadata.num_decode_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - - # Doomed decoder self-attention - _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, - prephase_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - # or something to that effect - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, - monkeypatch) -> None: - ''' - Confirm encoder/decoder models will fail with NotImplemented - if prefix caching is enabled. - - This test - 1. Executes a subset of test setup code from - test_e2e_enc_dec_attn() (everything up to encoder - execution); see test_e2e_enc_dec_attn() for more context - on how this code works. - - 2. Modifies the prefill-phase attention metadata structure - to imply a prefix caching scenario - - 3. Attempts to execute decoder self-attention - - 4. Asserts that that decoder self-attention fails & with the correct - error message - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if is_hip(). - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_enc_seq_len, 4096) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) - - # Encoder attention setup - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Decoder self-attention setup - - dec_qkv, \ - prephase_dec_test_params, \ - _, \ - cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) - - # Cross-attention setup - - prephase_cross_test_params, \ - _, \ - = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr = \ - cross_block_base_addr) - - # Shared prefill metadata structure - - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - default_attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_packed_actual_output: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - - # Meat of the test: require that prefix caching triggers failure. - # - # Set up a contrived scenario where the attention metadata - # is configured for prefix caching & decoder self- - # attention. Require that this triggers a NotImplementedError. - with pytest.raises(NotImplementedError) as exc_info: - # In XFormers backend, the trigger for utilizing the - # prefix caching kernel is - # - # kv_cache is not None and prefill_meta.block_tables.numel() > 0 - # - # We can shallowly emulate a prefix caching scenario by passing - # in a non-None KV cache in test_rsrcs (already the - # case) and then tweaking the cached prefill attention metadata - # from the encoder run to have a non-empty (gibberish) block - # table. This block table will never actually be used, because - # its presence will signify to the backend a prefix-caching - # scenario and (given that the attention metadata structure - # is configured for an encoder/decoder scenario too) trigger - # a NotImplemented a exception. - - num_seqs = len( - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) - - prephase_attn_metadata._cached_prefill_metadata.block_tables = \ - torch.randint( - 0, 10, (num_seqs, 1)) - - _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, - prephase_attn_metadata) - - # "Encoder decoder models do not currently support prefix caching" - # or something to that effect - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING + decphase_cross_pckd_act_out) \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 71cd93cd3df92..608a2b68fee81 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -696,9 +696,9 @@ def make_test_metadata( Construct fake attention metadata for a given test phase (prefill-phase or decode-phase). - encoder_test_params and cross_test_params arguments all encoder - attention and enc/dec cross-attention to use distinct metadata values - from decoder self-attention (decoder_test_params.) + encoder_test_params and cross_test_params arguments allow encoder + attention and enc/dec cross-attention (respectively) to use distinct + metadata values from decoder self-attention (decoder_test_params.) if encoder_test_params and cross_test_params are None, the attention metadata will support decoder-only scenario. From 03e5d8135731854460b62cac16149cbc8bb8ade9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:04:28 -0400 Subject: [PATCH 231/443] assert value None-ness matches key None-ness --- vllm/attention/backends/xformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e10b55fb6fcd..c62b782d928f5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -478,9 +478,11 @@ def forward( """ query = query.view(-1, self.num_heads, self.head_size) if key is not None: + assert value is not None key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which From 1f3874db00ad9d3dc17239d92428107d9701c700 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:08:16 -0400 Subject: [PATCH 232/443] comment fix --- tests/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 608a2b68fee81..58938493603dd 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -31,7 +31,7 @@ class QKVInputs(NamedTuple): ''' Data structure for representing unpacked attention inputs, - query/key/value. + query/key/values and their sequence lengths. Attributes: From 528b4a71e3824d1aba1da9004c91aa9df1a713e6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:19:25 -0400 Subject: [PATCH 233/443] Remove util fxns & error strings for unneeded tests --- tests/kernels/test_encoder_decoder_attn.py | 7 +- vllm/attention/backends/utils.py | 80 +--------------------- vllm/attention/backends/xformers.py | 16 +---- 3 files changed, 6 insertions(+), 97 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 8b93bb8fe6782..af3614aaa2221 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -17,9 +17,7 @@ from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.backends.utils import ( - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, - STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] @@ -149,7 +147,6 @@ class that Attention will automatically select when it is constructed. def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: - ''' Set up test vectors & data structures for encoder attention test. @@ -870,4 +867,4 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) \ No newline at end of file + decphase_cross_pckd_act_out) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 45a6f4af37d13..24c89fd7967a1 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,84 +1,8 @@ -"""Attention utils""" - -from vllm.attention import AttentionMetadata +"""Attention backend utils""" # Error string(s) for encoder/decoder # unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ -"Chunked prefill is not currently " + \ -"supported with encoder/decoder models." - STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ "ROCm/HIP is not currently supported" + \ -"with encoder/decoder models." - -STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ -"Currently only the XFormers backend " + \ - "supports encoder/decoder models." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING = \ -"Prefix caching is not currently supported " + \ -"with encoder/decoder models" - -# Check for unsupported encoder/decoder scenarios - - -def is_encoder_decoder_metadata_assuming_supported_backend( - attn_metadata) -> bool: - ''' - Return True of the attn_metadata argument contains - the metadata fields that would be required for - encoder attention, which proves that the user is - not running a purely decoder-only model. - - Assumes attn_metadata is derived from a backend that supports - encoder/decoder models. - - Arguments: - - * attn_metadata: instance of supported backend metadata. - Type annotation omitted to avoid circular import. - - - Returns: - - * True if attn_metadata is configured for an encoder/decoder model - ''' - return attn_metadata.is_all_encoder_attn_metadata_set - - -def fail_encoder_decoder_prefix_caching() -> None: - ''' - Fail with NotImplementedError & a message indicating - enc/dec + prefix caching is unsupported - ''' - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) - - -def assert_no_encdec_chunked_prefill_assuming_supported_backend( - attn_metadata: AttentionMetadata) -> None: - ''' - Fail if encoder/decoder model is being executed with - chunked prefill. - - Assumes we already know that the particular attention - backend in-use is supported. - - Arguments: - - * attn_metadata: Attention metadata structure - ''' - - if not is_encoder_decoder_metadata_assuming_supported_backend( - attn_metadata): - # Only care about encoder/decoder - # scenarios. - return - - if attn_metadata.num_prefill_tokens > 0 and \ - attn_metadata.num_decode_tokens > 0: - # Encoder/decoder models are currently incompatible - # with chunked prefill. - raise NotImplementedError( \ - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) +"with encoder/decoder models." \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index c62b782d928f5..cec6ee5867ae0 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import ( - assert_no_encdec_chunked_prefill_assuming_supported_backend, - fail_encoder_decoder_prefix_caching, - is_encoder_decoder_metadata_assuming_supported_backend) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -125,7 +121,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Attention type enum. # - # * Impact on XFormersImpl.forward(): + # * Impact on XFormersImpl.forward(): # # * DECODER: normal decoder-only behavior; # use decoder self-attention block table @@ -139,7 +135,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # will match encoder sequence lengths, pass encoder sequence # attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ # max_encoder_seq_len) - # + # _attn_type: AttentionType = AttentionType.DECODER # Encoder sequence lengths representation @@ -489,11 +485,6 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - # Raise NotImplementedError for unsupported encoder/decoder - # scenarios (has no effect on decoder-only models) - assert_no_encdec_chunked_prefill_assuming_supported_backend( - attn_metadata) - if (attn_type != AttentionType.ENCODER and \ kv_cache is not None): # KV-cache during decoder-self- or @@ -571,9 +562,6 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - if is_encoder_decoder_metadata_assuming_supported_backend( - attn_metadata): - fail_encoder_decoder_prefix_caching() assert prefill_meta.query_start_loc is not None assert prefill_meta.max_query_len is not None From b3c3411e26b7cf6f27604825d99a920c34605c9c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:39:35 -0400 Subject: [PATCH 234/443] formatting --- tests/kernels/test_encoder_decoder_attn.py | 9 +++++---- vllm/attention/backends/xformers.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3212f331c47b2..99a5ae7b5f808 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -679,6 +679,7 @@ def _run_encoder_decoder_cross_attention_test( return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -688,10 +689,9 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) def test_encoder_only(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, - max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch): - + batch_size: int, block_size: int, max_dec_seq_len: int, + max_enc_seq_len: int, monkeypatch): + # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) @@ -733,6 +733,7 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6745e3bba7c9f..c25957ea156a0 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -287,7 +287,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, _attn_type=self. - _attn_type, # Begin encoder & cross attn fields below... + _attn_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, From e229e0018138698bf13135f067eaf32a8cbf9167 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 16 Jun 2024 22:47:04 -0400 Subject: [PATCH 235/443] format --- tests/kernels/test_encoder_decoder_attn.py | 16 ++++++-- tests/kernels/utils.py | 45 +++++++++++++--------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 99a5ae7b5f808..ef6c0fa9876b1 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -221,7 +221,7 @@ def _decoder_attn_setup( test_pt: TestPoint, test_rsrcs: TestResources, block_base_addr: int = 0, -) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: +) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ''' Set up test vectors & data structures for self-attention test. @@ -390,8 +390,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, PhaseTestParameters, test_pt: TestPoint, test_rsrcs: TestResources, - block_base_addr: Optional[int]=0) \ - -> tuple[PhaseTestParameters, + block_base_addr: int=0) \ + -> Tuple[PhaseTestParameters, PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. @@ -456,6 +456,9 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, for decode phase. ''' + assert encoder_test_params.packed_qkvo.packed_qkv is not None + assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None + (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, max_encoder_seq_len, _) = test_pt @@ -467,6 +470,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, prefill_q_seq_lens = \ prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + assert prefill_q_seq_lens is not None cross_kv, \ _, \ @@ -591,6 +595,7 @@ def _run_encoder_attention_test(attn: Attention, assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv + assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, attn_metadata) @@ -624,6 +629,7 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, kv_cache = test_rsrcs.kv_cache attn_metadata.attention_type = AttentionType.DECODER packed_qkv = decoder_test_params.packed_qkvo.packed_qkv + assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, attn_metadata) @@ -664,6 +670,8 @@ def _run_encoder_decoder_cross_attention_test( * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' + assert decoder_test_params.packed_qkvo.packed_qkv is not None + attn_metadata.attention_type = AttentionType.ENCODER_DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache @@ -839,7 +847,7 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, cross_block_base_addr) # Shared prefill metadata structure - + assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None prephase_attn_metadata: AttentionMetadata = make_test_metadata( test_rsrcs.attn_backend, True, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 9ff07bc9a6264..ffa6b69ef2374 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,7 +2,7 @@ import itertools import random -from typing import List, NamedTuple, Optional, Union +from typing import Any, List, NamedTuple, Optional, Tuple, Union import pytest import torch @@ -83,10 +83,10 @@ class PackedQKVInputs(NamedTuple): query: torch.Tensor key: torch.Tensor value: torch.Tensor - q_start_loc_list: List[int] - kv_start_loc_list: List[int] - q_seq_lens: List[int] - kv_seq_lens: List[int] + q_start_loc_list: Optional[List[int]] + kv_start_loc_list: Optional[List[int]] + q_seq_lens: Optional[List[int]] + kv_seq_lens: Optional[List[int]] class PackedQKVO(NamedTuple): @@ -102,7 +102,7 @@ class PackedQKVO(NamedTuple): x head_size) known-correct attention output ''' - packed_qkv: PackedQKVInputs + packed_qkv: Optional[PackedQKVInputs] ideal_output: torch.Tensor @@ -136,7 +136,7 @@ class PhaseTestParameters(NamedTuple): ''' packed_qkvo: PackedQKVO - kv_mmap: KVMemoryMap + kv_mmap: Optional[KVMemoryMap] def override_backend_env_variable(mpatch: pytest.MonkeyPatch, @@ -185,6 +185,9 @@ def ref_masked_attention(query: torch.Tensor, * Attention result, batch_size x q_padded_seq_len x num_heads x head_size ''' + assert q_seq_lens is not None + assert kv_seq_lens is not None + batch_size = query.shape[0] assert (len(q_seq_lens) == batch_size) assert (len(kv_seq_lens) == batch_size) @@ -219,10 +222,10 @@ def make_qkv( num_heads: int, head_size: int, device: Union[torch.device, str], - force_kv_seq_lens: List[int] = None, + force_kv_seq_lens: Optional[List[int]] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> tuple[QKVInputs, QKVInputs, QKVInputs]: +) -> Tuple[QKVInputs, QKVInputs, QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -276,8 +279,9 @@ def make_qkv( kv_seq_lens = q_seq_lens else: # K,V seq lens are distinct from Q seq lens & random + assert max_kv_seq_len is not None if force_max_len: - kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] + kv_seq_lens = [max_kv_seq_len] * batch_size else: kv_seq_lens = [ random.randint(2, max_kv_seq_len) for _ in range(batch_size) @@ -350,7 +354,7 @@ def make_qkv( def pack_tensor( unpacked_tensor: torch.Tensor, seq_lens: List[int], - device: Union[torch.device, str]) -> tuple[torch.Tensor, List[int]]: + device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -454,10 +458,10 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( - seq_lens: List[int], context_lens: List[int], encoder_seq_lens: List[int], - device: Union[torch.device, str] -) -> tuple[torch.Tensor, torch.Tensor, int, int, Optional[List[int]], - torch.Tensor, int]: + seq_lens: Optional[List[int]], context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], + torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -603,7 +607,7 @@ def make_block_tables_slot_mapping( block_size: int, seq_lens: List[int], device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple[torch.Tensor, List[int], int]: + block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]: ''' Construct fake block tables & slot mappings. @@ -685,8 +689,8 @@ def make_block_tables_slot_mapping( def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, - seq_lens: List[int], - decoder_test_params: PhaseTestParameters, + seq_lens: Optional[List[int]], + decoder_test_params: Optional[PhaseTestParameters], default_attn_type: AttentionType, device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, @@ -765,6 +769,7 @@ def make_test_metadata( else: # Encoder/decoder or encoder-only models only: # * Extract encoder input sequence lengths + assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens if cross_test_params is None: @@ -819,6 +824,10 @@ def make_test_metadata( else: # not is_prompt # Decode-phase scenario + assert kv_mmap is not None + assert num_prefill_or_decode_tokens is not None + assert seq_lens is not None + num_prefills = 0 num_prefill_tokens = 0 num_decode_tokens = num_prefill_or_decode_tokens From 830a051267732f60b04b99a15552ea984b9f43f8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 17 Jun 2024 01:16:25 -0400 Subject: [PATCH 236/443] format --- vllm/worker/enc_dec_model_runner.py | 18 +++++++++++------- vllm/worker/model_runner.py | 26 +++++++++++++++++++++----- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 17bf68160f98f..df6afaf1fad75 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -25,11 +25,12 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_runner import ( - _PAD_SLOT_ID, LORA_WARMUP_RANK, _BATCH_SIZE_ALIGNMENT, - _BATCH_SIZES_TO_CAPTURE, _NUM_WARMUP_ITERS, ModelInput, ModelRunner, - _is_block_tables_empty, _get_graph_batch_size, CUDAGraphRunner, - _is_encoder_decoder_model) +from vllm.worker.model_runner import (_PAD_SLOT_ID, LORA_WARMUP_RANK, + _BATCH_SIZE_ALIGNMENT, + _BATCH_SIZES_TO_CAPTURE, + _NUM_WARMUP_ITERS, ModelInput, + ModelRunner, _is_block_tables_empty, + _get_graph_batch_size, CUDAGraphRunner) from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) from vllm.attention.backends.utils import STR @@ -38,11 +39,14 @@ # Error message if EncoderDecoderModelRunner is used with # a non-encoder/decoder model (i.e. decoder-only) -STR_ENCDECMR_ENCODER_DECODER_REQUIRED = "Only encoder/decoder models may be executed using EncoderDecoderModelRunner" +STR_ENCDECMR_ENCODER_DECODER_REQUIRED = \ + "Only encoder/decoder models may be executed " + \ + "using EncoderDecoderModelRunner" # Error message if EncoderDecoderModelRunner is used with # CUDAGraph -STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = "Currently CUDAGraph is not supported for encoder/decoder models" +STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = \ + "Currently CUDAGraph is not supported for encoder/decoder models" class EncoderDecoderModelInput(ModelInput): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 43775b101d7c3..7089b0e47e087 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -40,7 +40,8 @@ _NUM_WARMUP_ITERS = 2 # Error message if ModelRunner is used with an encoder/decoder model -STR_MR_ENCODER_DECODER_UNSUPPORTED = "Encoder/decoder model must be executed using EncoderDecoderModelRunner" +STR_MR_ENCODER_DECODER_UNSUPPORTED = \ + "Encoder/decoder model must be executed using EncoderDecoderModelRunner" class ModelInput(NamedTuple): @@ -278,7 +279,7 @@ def _prepare_seq_model_input( lora_requests: Set[LoRARequest] = set(), multi_modal_kwargs_list: Dict[str, List[torch.Tensor]] = defaultdict(list) - ) -> tuple[bool, int, int, int]: + ) -> Tuple[bool, int, int, int]: if is_prompt: context_len = seq_data.get_num_computed_tokens() @@ -354,9 +355,13 @@ def _prepare_seq_model_input( if seq_group_metadata.block_tables is not None: # chunked prefill or decode block_table = original_block_table + assert block_table is not None if curr_sliding_window_blocks is not None: block_table = block_table[-curr_sliding_window_blocks:] if self.attn_backend.get_name() == "flashinfer": + assert paged_kv_indices is not None + assert paged_kv_indptr is not None + assert paged_kv_last_page_len is not None paged_kv_indices.extend(block_table) paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) @@ -370,6 +375,8 @@ def _prepare_seq_model_input( else: # Prefill without chunked prefill or memory profiling. block_table = [] + + assert block_table is not None block_tables.append(block_table) seq_lens.append(sliding_seq_len) @@ -419,10 +426,14 @@ def _prepare_seq_model_input( # slot mapping. # In embeddings, the block tables are {seq_id: None}. slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - return decode_only, num_prefills, num_prefill_tokens, num_decode_tokens + return decode_only, \ + num_prefills, \ + num_prefill_tokens, \ + num_decode_tokens # Compute the slot mapping. block_table = original_block_table + assert block_table is not None # Mask the [0, start_idx) tokens of the prompt with # _PAD_SLOT_ID, where start_idx is max(0, seq_len - @@ -536,7 +547,10 @@ def _prepare_model_input( seq_data = seq_group_metadata.seq_data[seq_id] block_table = seq_group_metadata.block_tables[seq_id] - decode_only, num_prefills, num_prefill_tokens, num_decode_tokens = self._prepare_seq_model_input( + decode_only, \ + num_prefills, \ + num_prefill_tokens, \ + num_decode_tokens = self._prepare_seq_model_input( is_prompt, decode_only, num_prefills, num_prefill_tokens, num_decode_tokens, block_tables, seq_lens, slot_mapping, context_lens, query_lens, input_tokens, input_positions, @@ -932,7 +946,9 @@ def _is_encoder_decoder_model(self): return False. ''' return False if self.model_config is None else \ - getattr(self.model_config.hf_config, "is_encoder_decoder", False) + getattr(self.model_config.hf_config, + "is_encoder_decoder", + False) def _am_not_child(self): ''' From 89fdb811629bfe86ce5aaf85e078ce953e03e700 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 00:52:29 -0400 Subject: [PATCH 237/443] first pass at _prepare_encoder_model_input() --- vllm/worker/enc_dec_model_runner.py | 383 +++++----------------------- 1 file changed, 67 insertions(+), 316 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index df6afaf1fad75..4c743e5147396 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -109,21 +109,14 @@ def __init__( if self.scheduler_config.chunked_prefill_enabled: raise NotImplementedError() - def _prepare_model_input( + def _prepare_encoder_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], + attn_metadata: AttentionMetadata ) -> ModelInput: - """Prepare the model input based on a given sequence group. + """Prepare the encoder input based on a given sequence group. - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. + Encoder attention is an entirely prefill-phase operation. """ input_tokens: List[int] = [] input_positions: List[int] = [] @@ -145,6 +138,11 @@ def _prepare_model_input( num_prefill_tokens = 0 num_decode_tokens = 0 + sliding_window_blocks = 0 + block_aligned_sliding_window = 0 + + is_prompt = True + # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. @@ -173,241 +171,46 @@ def _prepare_model_input( sliding_window_blocks * self.block_size for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - is_prompt = seq_group_metadata.is_prompt - - for seq_id in seq_ids: - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - seq_data = seq_group_metadata.seq_data[seq_id] - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 - - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len - - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - curr_sliding_window_blocks = sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_blocks += 1 - else: - sliding_seq_len = min(seq_len, self.sliding_window) - sliding_context_len = sliding_seq_len - 1 - - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): - if seq_group_metadata.block_tables is not None: - # chunked prefill or decode - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] - if self.attn_backend.get_name() == "flashinfer": - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + - len(block_table)) - last_page_len = seq_data.get_len( - ) % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) - else: - # Only happens when memory profiling runs. - block_table = [] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - block_tables.append(block_table) - - seq_lens.append(sliding_seq_len) - context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - query_lens.append(query_len) - input_tokens.extend(tokens) - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if is_prompt: - assert len(seq_ids) == 1 - num_prefills += 1 - num_prefill_tokens += len(tokens) - decode_only = False - prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - num_decode_tokens += query_len - decode_seq_lens.append(sliding_seq_len) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data is not None: - # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) - for k, v in mm_kwargs.items(): - multi_modal_kwargs_list[k].append(v) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - batch_size = len(input_tokens) + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.encoder_seq_data + block_table = seq_group_metadata.cross_block_table + decode_only, \ + num_prefills, \ + num_prefill_tokens, \ + num_decode_tokens = self._prepare_seq_model_input( + is_prompt, decode_only, num_prefills, num_prefill_tokens, + num_decode_tokens, block_tables, seq_lens, slot_mapping, + context_lens, query_lens, input_tokens, input_positions, + prefill_seq_lens, decode_seq_lens, seq_group_metadata, + seq_data, computed_block_nums, block_table, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + sliding_window_blocks, block_aligned_sliding_window, + lora_index_mapping, lora_prompt_mapping, lora_requests, + multi_modal_kwargs_list) + max_query_len = max(query_lens) - max_prefill_seq_len = max(prefill_seq_lens, default=0) - max_decode_seq_len = max(decode_seq_lens, default=0) - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - use_captured_graph = ( - decode_only and not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_decode_seq_len <= self.max_seq_len_to_capture) - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - for _ in range(graph_batch_size - batch_size): - input_tokens.append(0) - input_positions.append(0) - slot_mapping.append(_PAD_SLOT_ID) - seq_lens.append(1) - block_tables.append([]) - lora_index_mapping.append(0) - batch_size = graph_batch_size - num_decode_tokens = batch_size - - if use_captured_graph: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.graph_block_tables[:batch_size] - for i, block_table in enumerate(block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=self.device) - else: - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + max_seq_len = max(prefill_seq_lens, default=0) + + # Assume Eager Mode + # TODO: CUDA Graph support + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) seq_lens_tensor = torch.tensor(seq_lens, @@ -431,75 +234,23 @@ def _prepare_model_input( slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - - if self.attn_backend.get_name() == "flashinfer": - if not hasattr(self, "flashinfer_workspace_buffer"): - # Allocate 16MB workspace buffer - # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.flashinfer_workspace_buffer = torch.empty( - 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, dtype=torch.int, device=self.device) - kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, - self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - use_cuda_graph=False, - max_prefill_seq_len=max_prefill_seq_len, - block_tables=block_tables, - workspace_buffer=self.flashinfer_workspace_buffer, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), - head_dim=self.model_config.get_head_size(), - page_size=16, - seq_start_loc=seq_start_loc, - data_type=kv_cache_dtype) - else: - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + attn_metadata.encoder_seq_lens = seq_lens + attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor + attn_metadata.max_encoder_seq_len = max_seq_len + attn_metadata.cross_slot_mapping = slot_mapping_tensor + attn_metadata.cross_block_tables = block_tables if self.lora_config: lora_mapping = LoRAMapping( @@ -651,8 +402,8 @@ def profile_run(self) -> None: # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request # passed in, which contains a lora from the lora warmup path. - dummy_lora_requests = [] - dummy_lora_requests_per_seq = [] + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] if self.lora_config: assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): From 7b9cb7f4339364b66180bf5cf7015f8fea67479d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 11:01:05 -0400 Subject: [PATCH 238/443] Replace attn_metadata.attention_type and attn_metadata._attn_type with attn_type argument to forward() --- tests/kernels/test_encoder_decoder_attn.py | 15 +-- tests/kernels/utils.py | 5 - vllm/attention/backends/xformers.py | 138 ++++++++++----------- vllm/attention/layer.py | 15 ++- 4 files changed, 85 insertions(+), 88 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ef6c0fa9876b1..de33840bf57dd 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -593,11 +593,11 @@ def _run_encoder_attention_test(attn: Attention, & attn_metadata ''' assert attn_metadata.num_decode_tokens == 0 - attn_metadata.attention_type = AttentionType.ENCODER + attn_type=AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - None, attn_metadata) + None, attn_metadata, attn_type=attn_type) def _run_decoder_self_attention_test(test_rsrcs: TestResources, @@ -625,13 +625,13 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' + attn_type = AttentionType.DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache - attn_metadata.attention_type = AttentionType.DECODER packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - kv_cache, attn_metadata) + kv_cache, attn_metadata, attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -672,7 +672,7 @@ def _run_encoder_decoder_cross_attention_test( ''' assert decoder_test_params.packed_qkvo.packed_qkv is not None - attn_metadata.attention_type = AttentionType.ENCODER_DECODER + attn_type = AttentionType.ENCODER_DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache if cross_test_params is None: @@ -685,7 +685,7 @@ def _run_encoder_decoder_cross_attention_test( value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, - value, kv_cache, attn_metadata) + value, kv_cache, attn_metadata, attn_type=attn_type) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -727,7 +727,6 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, decoder_test_params=None, encoder_test_params=enc_test_params, cross_test_params=None, - default_attn_type=AttentionType.ENCODER, device=CUDA_DEVICE) # PREFILL: encoder attention @@ -855,7 +854,6 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, decoder_test_params=prephase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=prephase_cross_test_params, - default_attn_type=AttentionType.ENCODER, device=CUDA_DEVICE) # PREFILL: encoder attention @@ -903,7 +901,6 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, decoder_test_params=decphase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=decphase_cross_test_params, - default_attn_type=AttentionType.DECODER, device=CUDA_DEVICE) # DECODE: decoder self-attention test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ffa6b69ef2374..49232b209a186 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -691,7 +691,6 @@ def make_test_metadata( is_prompt: bool, seq_lens: Optional[List[int]], decoder_test_params: Optional[PhaseTestParameters], - default_attn_type: AttentionType, device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, cross_test_params: Optional[PhaseTestParameters] = None @@ -719,8 +718,6 @@ def make_test_metadata( * decoder_test_params: decoder self-attention test params; this function requires kv_mmap (memory mapping) field - * default_attn_type: value of attn_metadata.attention_type at - construction time * device: CPU or CUDA device * encoder_test_params: encoder attention test params; this function requires encoder query @@ -812,7 +809,6 @@ def make_test_metadata( block_tables=None if kv_mmap is None else \ kv_mmap.block_tables, use_cuda_graph=False, - _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, @@ -855,7 +851,6 @@ def make_test_metadata( context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, - _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d03417f071510..832cd561c9932 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -119,25 +119,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Begin encoder attn & enc/dec cross-attn fields... - # Attention type enum. - # - # * Impact on XFormersImpl.forward(): - # - # * DECODER: normal decoder-only behavior; - # use decoder self-attention block table - # * ENCODER: no KV caching; pass encoder sequence - # attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - # max_encoder_seq_len) to kernel, in lieu of decoder - # sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) - # * ENCODER_DECODER: cross-attention behavior; - # use cross-attention block table for caching KVs derived - # from encoder hidden states; since KV sequence lengths - # will match encoder sequence lengths, pass encoder sequence - # attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - # max_encoder_seq_len) - # - _attn_type: AttentionType = AttentionType.DECODER - # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None @@ -180,24 +161,6 @@ def is_all_cross_attn_metadata_set(self): (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) - @property - def attention_type(self) -> AttentionType: - return self._attn_type - - @attention_type.setter - def attention_type(self, atype: AttentionType) -> None: - - if atype == AttentionType.ENCODER_DECODER: - assert self.is_all_cross_attn_metadata_set, \ - "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ - "self.cross_block_tables in order to perform cross-attention" - - elif atype == AttentionType.ENCODER: - assert self.is_all_encoder_attn_metadata_set, \ - "Must set self.encoder_seq_lens in order to perform cross-attention" - - self._attn_type = atype - @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: @@ -206,8 +169,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self._cached_prefill_metadata is not None: # Recover cached prefill-phase attention # metadata structure - self._cached_prefill_metadata.attention_type = \ - self.attention_type return self._cached_prefill_metadata assert (self.seq_lens is not None) or \ @@ -244,7 +205,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - _attn_type=self.attention_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -261,8 +221,6 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self._cached_decode_metadata is not None: # Recover cached decode-phase attention # metadata structure - self._cached_decode_metadata.attention_type = \ - self.attention_type return self._cached_decode_metadata assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) @@ -286,8 +244,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=self.max_decode_seq_len, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, - _attn_type=self. - _attn_type, # Begin encoder & cross attn fields below... + # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, @@ -295,7 +252,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata -def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ +def _get_attn_bias(attn_metadata: XFormersMetadata, + attn_type: AttentionType) -> \ Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata @@ -304,12 +262,13 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Arguments: * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention Returns: - * Appropriate attention bias value + * Appropriate attention bias value given the attention type ''' - attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: return attn_metadata.attn_bias elif attn_type == AttentionType.ENCODER: @@ -318,24 +277,24 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ return attn_metadata.cross_attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {str(attn_type)}") + f"Invalid attention type {str(attn_type)}") def _set_attn_bias(attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]]) -> None: + attn_bias: List[Optional[AttentionBias]], + attn_type: AttentionType) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. - Depends on attn_metadata having a valid attention_type. - Arguments: * attn_metadata: Attention metadata structure associated with attention * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention ''' - attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: attn_metadata.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: @@ -344,11 +303,13 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, attn_metadata.cross_attn_bias = attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {str(attn_type)}") + f"Invalid attention type {str(attn_type)}") def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, - is_prompt: bool) -> tuple: + is_prompt: bool, + attn_type: AttentionType) \ + -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent @@ -362,6 +323,9 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, Arguments: * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention Returns: @@ -370,7 +334,6 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, * Appropriate block tables (or None) ''' - attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run @@ -394,7 +357,7 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, None else: raise AttributeError( - f"Invalid attn_metadata.attention_type {str(attn_type)}") + f"Invalid attention type {str(attn_type)}") class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -463,6 +426,7 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -481,15 +445,48 @@ def forward( (2) cross-attention key and value tensors do not grow during decode + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ + + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if attn_type == AttentionType.ENCODER: + if not attn_metadata.is_all_encoder_attn_metadata_set: + raise AttributeError("Encoder attention requires setting " + \ + "encoder metadata attributes.") + elif attn_type == AttentionType.ENCODER_DECODER: + if not attn_metadata.is_all_cross_attn_metadata_set: + raise AttributeError("Encoder/decoder cross-attention " + \ + "requires setting cross-attention " + \ + "metadata attributes.") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -501,7 +498,6 @@ def forward( # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - attn_type = attn_metadata.attention_type if (attn_type != AttentionType.ENCODER and \ kv_cache is not None): @@ -536,7 +532,7 @@ def forward( self.kv_cache_dtype, kv_scale) - if attn_metadata.attention_type != AttentionType.ENCODER: + if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. # Encoder/decoder cross-attention requires no chunked # prefill (100% prefill or 100% decode tokens, no mix) @@ -576,7 +572,7 @@ def forward( # block tables are empty if the prompt does not have a cached # prefix. out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta) + query, key, value, prefill_meta, attn_type=attn_type) assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: @@ -609,7 +605,9 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta, False) + block_tables_arg = _get_seq_len_block_table_args(decode_meta, + False, + attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -634,6 +632,7 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -647,15 +646,12 @@ def _run_memory_efficient_xformers_forward( key: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally """ - # Enforce that the appropriate *_seq_lens attribute of attn_metadata - # (seq_lens or encoder_seq_lens) is set. - # seq_lens, \ - # _,\ - # _ = _get_seq_len_block_table_args(attn_metadata, True) - # assert seq_lens is not None - original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. @@ -673,10 +669,10 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - attn_bias = _get_attn_bias(attn_metadata) + attn_bias = _get_attn_bias(attn_metadata,attn_type) if attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.attention_type == \ + if attn_type == \ AttentionType.ENCODER_DECODER: assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens is not None @@ -685,7 +681,7 @@ def _run_memory_efficient_xformers_forward( attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) else: - if attn_metadata.attention_type == AttentionType.ENCODER: + if attn_type == AttentionType.ENCODER: assert attn_metadata.encoder_seq_lens is not None # Default encoder self-attention mask is non-causal @@ -707,7 +703,7 @@ def _run_memory_efficient_xformers_forward( self.num_kv_heads, query.dtype, attn_metadata.seq_lens) - _set_attn_bias(attn_metadata, attn_bias) + _set_attn_bias(attn_metadata, attn_bias, attn_type) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index db55a31476fed..77be19772601f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -85,9 +85,18 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, + attn_type: Optional[AttentionType] = None ) -> torch.Tensor: - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale) + if attn_type is None: + # Support backends without an attention type argument + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + self._kv_scale) + else: + # Backends with encoder/decoder support require attention + # type argument to distinguish between encoder attention, + # decoder self-attention, or encoder/decoder cross-attention + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + self._kv_scale, attn_type=attn_type) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore From 5f8c7f6cd6776cbda8289a5cee28e5cd8b858f4d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 11:26:24 -0400 Subject: [PATCH 239/443] Moved attention type for attn_metadata to attention forward(); added NotImplement failures to backends in non-decoder-only scenarios --- tests/kernels/test_encoder_decoder_attn.py | 26 ++++++--- vllm/attention/backends/abstract.py | 16 +++--- vllm/attention/backends/blocksparse_attn.py | 24 +++++--- vllm/attention/backends/flash_attn.py | 24 +++++--- vllm/attention/backends/flashinfer.py | 23 +++++--- vllm/attention/backends/ipex_attn.py | 23 +++++--- vllm/attention/backends/pallas.py | 23 +++++--- vllm/attention/backends/rocm_flash_attn.py | 24 +++++--- vllm/attention/backends/torch_sdpa.py | 23 +++++--- vllm/attention/backends/xformers.py | 63 ++++++++++----------- vllm/attention/layer.py | 35 ++++++------ 11 files changed, 173 insertions(+), 131 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index de33840bf57dd..f61b0a0dcc706 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -593,11 +593,15 @@ def _run_encoder_attention_test(attn: Attention, & attn_metadata ''' assert attn_metadata.num_decode_tokens == 0 - attn_type=AttentionType.ENCODER + attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - None, attn_metadata, attn_type=attn_type) + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + None, + attn_metadata, + attn_type=attn_type) def _run_decoder_self_attention_test(test_rsrcs: TestResources, @@ -630,8 +634,12 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - kv_cache, attn_metadata, attn_type=attn_type) + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -684,8 +692,12 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, - value, kv_cache, attn_metadata, attn_type=attn_type) + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ece0da25ee6f2..1ac8efc6b2584 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -124,12 +124,12 @@ def __init__( @abstractmethod def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: T, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index dce2b83615b7a..2afa4d286900e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -4,7 +4,7 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -321,14 +321,14 @@ def __init__( ) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -340,6 +340,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "BlocksparseFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 1c48e2a0bb33d..16098ca68213d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class FlashAttentionBackend(AttentionBackend): @@ -250,14 +250,14 @@ def __init__( f"Supported head sizes are: {support_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with FlashAttention. Args: @@ -269,6 +269,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "FlashAttentionImpl") + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7b7959d257fac..2a7db5a35382e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -8,7 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class FlashInferBackend(AttentionBackend): @@ -185,15 +185,20 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: FlashInferMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "FlashInferImpl") num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index f09b24f2a0304..c3328d6ed7665 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -7,7 +7,7 @@ from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -150,14 +150,14 @@ def split_kv_cache( return key_cache, value_cache def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: IpexAttnMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. Args: @@ -170,6 +170,11 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b203c5ec54c92..a453f642509a6 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -6,7 +6,7 @@ import torch_xla.experimental.dynamo_set_buffer_donor from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class PallasAttentionBackend(AttentionBackend): @@ -120,14 +120,14 @@ def __init__( self.megacore_mode = "batch" def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], - attn_metadata: PallasMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], + attn_metadata: PallasMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with Pallas attention. Args: @@ -141,6 +141,11 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9294068c64d1a..9a1d90107745d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -259,14 +259,14 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim)) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -278,6 +278,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "ROCmFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c01e0a0a3a19c..efafc233da7df 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu @@ -138,14 +138,14 @@ def __init__( "Please use xFormers backend instead.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: @@ -158,6 +158,11 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 832cd561c9932..bf4d755d2a72f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -276,8 +276,7 @@ def _get_attn_bias(attn_metadata: XFormersMetadata, elif attn_type == AttentionType.ENCODER_DECODER: return attn_metadata.cross_attn_bias else: - raise AttributeError( - f"Invalid attention type {str(attn_type)}") + raise AttributeError(f"Invalid attention type {str(attn_type)}") def _set_attn_bias(attn_metadata: XFormersMetadata, @@ -302,8 +301,7 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, elif attn_type == AttentionType.ENCODER_DECODER: attn_metadata.cross_attn_bias = attn_bias else: - raise AttributeError( - f"Invalid attention type {str(attn_type)}") + raise AttributeError(f"Invalid attention type {str(attn_type)}") def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, @@ -356,8 +354,7 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, attn_metadata.max_encoder_seq_len, \ None else: - raise AttributeError( - f"Invalid attention type {str(attn_type)}") + raise AttributeError(f"Invalid attention type {str(attn_type)}") class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -419,15 +416,14 @@ def __init__( f"Supported head sizes are: {suppored_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor], - attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor], + attn_metadata: "XFormersMetadata", + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. For decoder-only models: query, key and value must be non-None. @@ -477,15 +473,15 @@ def forward( # Check that appropriate attention metadata attributes are # selected for the desired attention type - if attn_type == AttentionType.ENCODER: - if not attn_metadata.is_all_encoder_attn_metadata_set: - raise AttributeError("Encoder attention requires setting " + \ - "encoder metadata attributes.") - elif attn_type == AttentionType.ENCODER_DECODER: - if not attn_metadata.is_all_cross_attn_metadata_set: - raise AttributeError("Encoder/decoder cross-attention " + \ - "requires setting cross-attention " + \ - "metadata attributes.") + if attn_type == AttentionType.ENCODER and \ + (not attn_metadata.is_all_encoder_attn_metadata_set): + raise AttributeError("Encoder attention requires setting " + \ + "encoder metadata attributes.") + elif attn_type == AttentionType.ENCODER_DECODER and \ + (not attn_metadata.is_all_cross_attn_metadata_set): + raise AttributeError("Encoder/decoder cross-attention " + \ + "requires setting cross-attention " + \ + "metadata attributes.") query = query.view(-1, self.num_heads, self.head_size) if key is not None: @@ -605,8 +601,8 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta, - False, + block_tables_arg = _get_seq_len_block_table_args(decode_meta, + False, attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( @@ -627,13 +623,12 @@ def forward( return output.view(-1, self.num_heads * self.head_size) def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -669,7 +664,7 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - attn_bias = _get_attn_bias(attn_metadata,attn_type) + attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: if attn_type == \ diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 77be19772601f..984c8d77b94e7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,25 +78,22 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params) - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, - attn_type: Optional[AttentionType] = None - ) -> torch.Tensor: - if attn_type is None: - # Support backends without an attention type argument - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale) - else: - # Backends with encoder/decoder support require attention - # type argument to distinguish between encoder attention, - # decoder self-attention, or encoder/decoder cross-attention - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale, attn_type=attn_type) + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + attn_type: AttentionType = AttentionType.DECODER) \ + -> torch.Tensor: + + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._kv_scale, + attn_type=attn_type) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore From 525303c7c61127900680ff06b6cc09610001b71e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 18:06:33 -0400 Subject: [PATCH 240/443] num encoder tokens --- tests/kernels/utils.py | 5 +++++ vllm/attention/backends/xformers.py | 11 +++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 49232b209a186..94e7379123c7c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -763,11 +763,14 @@ def make_test_metadata( if encoder_test_params is None: encoder_seq_lens = None + num_encoder_tokens = None else: # Encoder/decoder or encoder-only models only: # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens + num_encoder_tokens = None if encoder_seq_lens is None else \ + (sum(encoder_seq_lens)) if cross_test_params is None: cross_kv_mmap = None @@ -809,6 +812,7 @@ def make_test_metadata( block_tables=None if kv_mmap is None else \ kv_mmap.block_tables, use_cuda_graph=False, + num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, @@ -851,6 +855,7 @@ def make_test_metadata( context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, + num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index bf4d755d2a72f..2cd61d0161f9e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -68,9 +68,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): updated from `CUDAGraphRunner.forward` API. """ - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| @@ -78,6 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -126,6 +126,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + # Cross-attention memory-mapping data structures: slot mapping # and block tables cross_slot_mapping: Optional[torch.Tensor] = None @@ -538,7 +541,7 @@ def forward( # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens - num_prefill_tokens = query.shape[0] + num_prefill_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 if attn_type == AttentionType.DECODER: From ea37e17ab5ad7c084c13bf8e8492039d6a9bcdbf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 19:16:38 -0400 Subject: [PATCH 241/443] merge conflict; typing; formatting --- vllm/attention/backends/xformers.py | 1 + vllm/utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2cd61d0161f9e..a3f3d41a5491c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -541,6 +541,7 @@ def forward( # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 diff --git a/vllm/utils.py b/vllm/utils.py index edad75a6904b5..ebcd2181c1086 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -825,6 +825,7 @@ def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ float('-inf')).masked_fill(mask == 0, 0.0) return mask + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f): @@ -834,4 +835,4 @@ def wrapper(*args, **kwargs) -> Any: return f(*args, **kwargs) wrapper.has_run = False # type: ignore[attr-defined] - return wrapper \ No newline at end of file + return wrapper From e3ba61e368f0085fe64e8dae3d80494f5254164c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 22:44:23 -0400 Subject: [PATCH 242/443] wip --- vllm/worker/enc_dec_model_runner.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4c743e5147396..d61c80af0f297 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -110,10 +110,8 @@ def __init__( raise NotImplementedError() def _prepare_encoder_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - attn_metadata: AttentionMetadata - ) -> ModelInput: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + attn_metadata: AttentionMetadata) -> ModelInput: """Prepare the encoder input based on a given sequence group. Encoder attention is an entirely prefill-phase operation. @@ -175,7 +173,7 @@ def _prepare_encoder_model_input( if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled and not (computed_block_nums is None - or computed_block_nums == [])): + or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") @@ -235,16 +233,16 @@ def _prepare_encoder_model_input( dtype=torch.long, device=self.device) query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) + dtype=torch.int32, + device=self.device) torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) attn_metadata.encoder_seq_lens = seq_lens attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor From 37aeed66141b10b0d43c8e6d56613806dc7108ff Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 23:35:11 -0400 Subject: [PATCH 243/443] enc dec model runner testable if only for encoder decoder model --- .../test_encoder_decoder_model_runner.py | 174 ++++++++++++++++++ vllm/worker/enc_dec_model_runner.py | 118 +++++------- 2 files changed, 220 insertions(+), 72 deletions(-) create mode 100644 tests/worker/test_encoder_decoder_model_runner.py diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py new file mode 100644 index 0000000000000..04f37b8757509 --- /dev/null +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -0,0 +1,174 @@ +from typing import List + +import pytest +import torch + +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner + + +def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = EncoderDecoderModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +def test_prepare_prompt(batch_size): + model_runner = _create_model_runner( + "facebook/opt-125m", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + ) + + seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for seq_len in seq_lens: + expected_selected_token_indices.append(selected_token_start_idx + + seq_len - 1) + selected_token_start_idx += seq_len + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 + + # Test subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test seq start locs. Note that for normal prefill it is + # equivalent to query_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], + dtype=torch.int, + device=device)) + + expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(attn_metadata.block_tables, expected) + # Cuda graph should not be used for prerill. + assert attn_metadata.use_cuda_graph is False + + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) + torch.testing.assert_close(input_tokens, input_positions) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + torch.allclose(input_tokens, input_positions) + + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output.""" + model_runner = _create_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + enforce_eager=False, + ) + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_seq_lens) == 0 \ No newline at end of file diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d61c80af0f297..87ee9f583f6d4 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,39 +1,25 @@ -import gc -import time -import warnings from collections import defaultdict -from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union +from typing import Dict, List, NamedTuple, Optional, Set, Tuple -import numpy as np import torch -import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import AttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import graph_capture +# from vllm.distributed.communication_op import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata -from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_runner import (_PAD_SLOT_ID, LORA_WARMUP_RANK, - _BATCH_SIZE_ALIGNMENT, - _BATCH_SIZES_TO_CAPTURE, - _NUM_WARMUP_ITERS, ModelInput, - ModelRunner, _is_block_tables_empty, - _get_graph_batch_size, CUDAGraphRunner) -from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.attention.backends.utils import STR +from vllm.utils import make_tensor_with_pad +from vllm.worker.model_runner import (LORA_WARMUP_RANK, + ModelInput, + ModelRunner) logger = init_logger(__name__) @@ -49,35 +35,15 @@ "Currently CUDAGraph is not supported for encoder/decoder models" -class EncoderDecoderModelInput(ModelInput): +class EncoderInput(NamedTuple): input_tokens: torch.Tensor input_positions: torch.Tensor - attn_metadata: Optional[AttentionMetadata] - seq_lens: List[int] - query_lens: List[int] - lora_mapping: Optional[LoRAMapping] - lora_requests: Set[LoRARequest] - multi_modal_kwargs: Dict[str, torch.Tensor] - slot_mapping: torch.Tensor - num_prefill_tokens: int - num_decode_tokens: int - num_prefills: int @classmethod def empty(cls, device): return ModelInput( input_tokens=torch.empty(0, device=device), input_positions=torch.empty(0, device=device), - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_mapping=None, - lora_requests=set(), - multi_modal_kwargs={}, - slot_mapping=torch.empty(0, device=device), - num_prefill_tokens=0, - num_decode_tokens=0, - num_prefills=0, ) @@ -111,7 +77,7 @@ def __init__( def _prepare_encoder_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - attn_metadata: AttentionMetadata) -> ModelInput: + attn_metadata: AttentionMetadata) -> None: """Prepare the encoder input based on a given sequence group. Encoder attention is an entirely prefill-phase operation. @@ -223,12 +189,6 @@ def _prepare_encoder_model_input( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) @@ -244,38 +204,40 @@ def _prepare_encoder_model_input( dtype=query_start_loc.dtype, out=query_start_loc[1:]) + # Set encoder-oriented attention metadata fields + attn_metadata.num_encoder_tokens = num_prefill_tokens attn_metadata.encoder_seq_lens = seq_lens attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor attn_metadata.max_encoder_seq_len = max_seq_len attn_metadata.cross_slot_mapping = slot_mapping_tensor attn_metadata.cross_block_tables = block_tables - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, + if seq_group_metadata.is_prompt: + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + + return EncoderInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor ) + else: - lora_mapping = None - multi_modal_kwargs = { - k: torch.cat(v, dim=0).to(self.device) - for k, v in multi_modal_kwargs_list.items() - } + input_tokens_tensor = torch.tensor([], + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor([], + dtype=torch.long, + device=self.device) - return ModelInput( + return EncoderInput( input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, + input_positions=input_positions_tensor ) def prepare_input_tensors( @@ -300,6 +262,10 @@ def prepare_input_tensors( num_decode_tokens, num_prefills, ) = self._prepare_model_input(seq_group_metadata_list) + ( + encoder_input_tokens, + encoder_input_positions + ) = self._prepare_encoder_model_input(seq_group_metadata_list,attn_metadata) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) @@ -316,6 +282,8 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, + "encoder_input_tokens":encoder_input_tokens, + "encoder_input_positions":encoder_input_positions } if attn_metadata: metadata_dict.update(attn_metadata.asdict_zerocopy()) @@ -324,6 +292,8 @@ def prepare_input_tensors( metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") + encoder_input_tokens = metadata_dict.pop("encoder_input_tokens") + encoder_input_positions = metadata_dict.pop("encoder_input_positions") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") @@ -341,7 +311,8 @@ def prepare_input_tensors( num_prompts=0, ) - return (input_tokens, input_positions, attn_metadata, + return (input_tokens, input_positions, encoder_input_tokens, + encoder_input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_kwargs) @@ -351,7 +322,8 @@ def execute_model( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, + (input_tokens, input_positions, encoder_input_tokens, + encoder_input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_kwargs ) = self.prepare_input_tensors(seq_group_metadata_list) @@ -370,6 +342,8 @@ def execute_model( hidden_states = model_executable( input_ids=input_tokens, positions=input_positions, + encoder_input_ids=encoder_input_tokens, + encoder_positions=encoder_input_positions, kv_caches=kv_caches, attn_metadata=attn_metadata, **multi_modal_kwargs, From a8a52d2935d5a2ab969c05d498ec2423ae19507b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 23:39:15 -0400 Subject: [PATCH 244/443] some formatting fixes --- .../test_encoder_decoder_model_runner.py | 6 +- vllm/worker/enc_dec_model_runner.py | 64 +++++++++---------- 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 04f37b8757509..cb6d27ab8bec4 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -9,7 +9,8 @@ from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: +def _create_model_runner(model: str, *args, + **kwargs) -> EncoderDecoderModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( @@ -137,6 +138,7 @@ def test_prepare_prompt(batch_size): dtype=actual.dtype) torch.testing.assert_close(actual, expected) + def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output.""" model_runner = _create_model_runner( @@ -171,4 +173,4 @@ def test_empty_seq_group(): assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_seq_lens) == 0 \ No newline at end of file + assert len(return_seq_lens) == 0 diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 87ee9f583f6d4..b59c9041bd872 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -17,8 +17,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad -from vllm.worker.model_runner import (LORA_WARMUP_RANK, - ModelInput, +from vllm.worker.model_runner import (LORA_WARMUP_RANK, ModelInput, ModelRunner) logger = init_logger(__name__) @@ -41,7 +40,7 @@ class EncoderInput(NamedTuple): @classmethod def empty(cls, device): - return ModelInput( + return EncoderInput( input_tokens=torch.empty(0, device=device), input_positions=torch.empty(0, device=device), ) @@ -77,7 +76,7 @@ def __init__( def _prepare_encoder_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - attn_metadata: AttentionMetadata) -> None: + attn_metadata: AttentionMetadata) -> EncoderInput: """Prepare the encoder input based on a given sequence group. Encoder attention is an entirely prefill-phase operation. @@ -215,36 +214,33 @@ def _prepare_encoder_model_input( if seq_group_metadata.is_prompt: input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) - return EncoderInput( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor - ) + return EncoderInput(input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor) else: input_tokens_tensor = torch.tensor([], - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) input_positions_tensor = torch.tensor([], - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) - return EncoderInput( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor - ) + return EncoderInput(input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor) - def prepare_input_tensors( + def prepare_input_tensors_encoder_decoder( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + AttentionMetadata, SamplingMetadata, Set[LoRARequest], + LoRAMapping, Dict[str, torch.Tensor]]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. @@ -262,10 +258,9 @@ def prepare_input_tensors( num_decode_tokens, num_prefills, ) = self._prepare_model_input(seq_group_metadata_list) - ( - encoder_input_tokens, - encoder_input_positions - ) = self._prepare_encoder_model_input(seq_group_metadata_list,attn_metadata) + (encoder_input_tokens, + encoder_input_positions) = self._prepare_encoder_model_input( + seq_group_metadata_list, attn_metadata) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) @@ -282,8 +277,8 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "encoder_input_tokens":encoder_input_tokens, - "encoder_input_positions":encoder_input_positions + "encoder_input_tokens": encoder_input_tokens, + "encoder_input_positions": encoder_input_positions } if attn_metadata: metadata_dict.update(attn_metadata.asdict_zerocopy()) @@ -293,7 +288,8 @@ def prepare_input_tensors( input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") encoder_input_tokens = metadata_dict.pop("encoder_input_tokens") - encoder_input_positions = metadata_dict.pop("encoder_input_positions") + encoder_input_positions = metadata_dict.pop( + "encoder_input_positions") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") @@ -312,9 +308,8 @@ def prepare_input_tensors( ) return (input_tokens, input_positions, encoder_input_tokens, - encoder_input_positions, attn_metadata, - sampling_metadata, lora_requests, lora_mapping, - multi_modal_kwargs) + encoder_input_positions, attn_metadata, sampling_metadata, + lora_requests, lora_mapping, multi_modal_kwargs) @torch.inference_mode() def execute_model( @@ -325,7 +320,8 @@ def execute_model( (input_tokens, input_positions, encoder_input_tokens, encoder_input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_kwargs - ) = self.prepare_input_tensors(seq_group_metadata_list) + ) = \ + self.prepare_input_tensors_encoder_decoder(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) From fbec309f0cc8d94df6ba7ab3f71f172d30f73531 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Jun 2024 01:14:35 -0400 Subject: [PATCH 245/443] moved enc/dec error strings to top-level vllm utils --- .../test_encoder_decoder_model_runner.py | 2 +- vllm/core/block/utils.py | 12 +- vllm/model_executor/models/__init__.py | 10 +- vllm/model_executor/models/bart.py | 1842 +++++++++++++++++ vllm/utils.py | 10 + vllm/worker/model_runner.py | 4 +- 6 files changed, 1866 insertions(+), 14 deletions(-) create mode 100644 vllm/model_executor/models/bart.py diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index cb6d27ab8bec4..69672d8a55158 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -142,7 +142,7 @@ def test_prepare_prompt(batch_size): def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output.""" model_runner = _create_model_runner( - "facebook/opt-125m", + "facebook/bart-base", seed=0, dtype="float16", enforce_eager=False, diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 2c412a8f472e0..28839437c33c5 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,15 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup - -# Exception strings for non-implemented block manager enc/dec scenarios - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) def _get_block_mgr_sliding_window_attr(block_mgr): diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f9ec7209689e7..ba05340b93534 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -66,7 +66,15 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } -_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", ), +} + +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_CONDITIONAL_GENERATION_MODELS +} # Architecture -> type. # out of tree models diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000000000..bccc1aa964192 --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,1842 @@ +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PyTorch BART model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import MixtralConfig + +from vllm import _custom_ops as ops +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + per_tensor_dequantize, + per_tensor_quantize) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from .configuration_bart import BartConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/bart-base" +_CONFIG_FOR_DOC = "BartConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 768] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2" +_SEQ_CLASS_EXPECTED_LOSS = 0.0 +_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'" + +# QuestionAsnwering docstring +_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1" +_QA_EXPECTED_LOSS = 0.59 +_QA_EXPECTED_OUTPUT = "' nice puppet'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +class BartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[BartConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class BartFlashAttention2(BartAttention): + """ + Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("BartFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class BartSdpaAttention(BartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +BART_ATTENTION_CLASSES = { + "eager": BartAttention, + "sdpa": BartSdpaAttention, + "flash_attention_2": BartFlashAttention2, +} + + +class BartEncoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class BartDecoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class BartPreTrainedModel(PreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] + _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class PretrainedBartModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +class BartPretrainedModel(BartPreTrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +BART_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BartConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BART_GENERATION_EXAMPLE = r""" + Summarization example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' + ``` + + Mask filling example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") + + >>> TXT = "My friends are but they eat too many carbs." + >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['not', 'good', 'healthy', 'great', 'very'] + ``` +""" + +BART_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class BartEncoder(BartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BartEncoderLayer`]. + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + if self._use_flash_attention_2: + attention_mask = attention_mask if 0 in attention_mask else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BartDecoder(BartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + +class BartModel(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + +class BartForConditionalGeneration(BartPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: BartConfig): + super().__init__(config) + self.model = BartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past \ No newline at end of file diff --git a/vllm/utils.py b/vllm/utils.py index ebcd2181c1086..00ef1d9570d8c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -30,6 +30,16 @@ logger = init_logger(__name__) +# Exception strings for non-implemented encoder/decoder scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7089b0e47e087..81c1386a5397c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -147,7 +147,7 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - if self._am_not_child() and self._is_encoder_decoder_model(): + if (not self._am_child()) and self._is_encoder_decoder_model(): # Fail if ModelRunner is constructed for an # encoder/decoder model # @@ -950,7 +950,7 @@ def _is_encoder_decoder_model(self): "is_encoder_decoder", False) - def _am_not_child(self): + def _am_child(self): ''' True if self is an instance of the ModelRunner base class, False otherwise (i.e. child class) From 1581eb7f978a83690e0aaa2b390be491b42ffb15 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 19 Jun 2024 22:28:28 -0400 Subject: [PATCH 246/443] wip --- tests/worker/test_encoder_decoder_model_runner.py | 9 ++++++++- vllm/utils.py | 7 +++++++ vllm/worker/enc_dec_model_runner.py | 10 +++------- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 69672d8a55158..083d87cec6880 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -29,26 +29,33 @@ def _create_model_runner(model: str, *args, @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): model_runner = _create_model_runner( - "facebook/opt-125m", + "facebook/bart-base", max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, ) seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] block_tables = {0: [1]} + cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i+1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, seq_data={0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table ) assert seq_group_metadata.token_chunk_size == seq_data.get_len() seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/utils.py b/vllm/utils.py index 00ef1d9570d8c..f03a8cc7d1bf2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -40,6 +40,13 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ + "Chunked prefill for encoder/decoder models " + \ + "is not currently supported." + +STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = \ + "Currently CUDAGraph is not supported for encoder/decoder models" + STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index b59c9041bd872..38c3bc6a57f45 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -19,6 +19,8 @@ from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner import (LORA_WARMUP_RANK, ModelInput, ModelRunner) +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED) logger = init_logger(__name__) @@ -28,12 +30,6 @@ "Only encoder/decoder models may be executed " + \ "using EncoderDecoderModelRunner" -# Error message if EncoderDecoderModelRunner is used with -# CUDAGraph -STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = \ - "Currently CUDAGraph is not supported for encoder/decoder models" - - class EncoderInput(NamedTuple): input_tokens: torch.Tensor input_positions: torch.Tensor @@ -72,7 +68,7 @@ def __init__( raise AttributeError(STR_ENCDECMR_ENCODER_DECODER_REQUIRED) if self.scheduler_config.chunked_prefill_enabled: - raise NotImplementedError() + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) def _prepare_encoder_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], From f0094bd8a90cc26325f1ea7ca1506fc459a312c9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 20 Jun 2024 10:59:52 -0400 Subject: [PATCH 247/443] wip enc/dec modelrunner prepare_prompt test --- .../test_encoder_decoder_model_runner.py | 34 +- vllm/model_executor/models/bart.py | 510 +++++++++++------- vllm/worker/enc_dec_model_runner.py | 5 +- vllm/worker/model_runner.py | 11 +- 4 files changed, 348 insertions(+), 212 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 083d87cec6880..a304775f93baf 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -7,6 +7,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner +from tests.kernels.utils import (override_backend_env_variable, + STR_XFORMERS_ATTN_VAL) + +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] def _create_model_runner(model: str, *args, @@ -27,7 +31,12 @@ def _create_model_runner(model: str, *args, @pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_prompt(batch_size): +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +def test_prepare_prompt(batch_size, backend_name, monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + model_runner = _create_model_runner( "facebook/bart-base", max_num_batched_tokens=100000, @@ -45,7 +54,7 @@ def test_prepare_prompt(batch_size): seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) seq_data = SequenceData(list(range(seq_len))) - encoder_seq_len = (i+1) % (model_runner.block_size - 1) + 1 + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData(list(range(encoder_seq_len))) seq_group_metadata = SequenceGroupMetadata( @@ -55,8 +64,7 @@ def test_prepare_prompt(batch_size): sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table - ) + cross_block_table=cross_block_table) assert seq_group_metadata.token_chunk_size == seq_data.get_len() seq_group_metadata_list.append(seq_group_metadata) @@ -66,6 +74,8 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len + + # Decoder model input model_input = model_runner._prepare_model_input(seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions @@ -75,6 +85,15 @@ def test_prepare_prompt(batch_size): assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) + # Encoder model input + encoder_model_input = model_runner._prepare_encoder_model_input( + seq_group_metadata_list, attn_metadata) + encoder_input_tokens = encoder_model_input.input_tokens + encoder_input_positions = encoder_model_input.input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert len(encoder_input_tokens) == sum(encoder_seq_lens) + assert len(cross_slot_mapping) == len(encoder_input_tokens) + # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.num_prefills > 0 @@ -146,8 +165,13 @@ def test_prepare_prompt(batch_size): torch.testing.assert_close(actual, expected) -def test_empty_seq_group(): +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +def test_empty_seq_group(backend_name, monkeypatch): """Verify prepare prompt and decode returns empty output.""" + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + model_runner = _create_model_runner( "facebook/bart-base", seed=0, diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index bccc1aa964192..4b083b510df3b 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """PyTorch BART model.""" from typing import Iterable, List, Optional, Tuple @@ -84,12 +83,10 @@ ) from .configuration_bart import BartConfig - if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "facebook/bart-base" @@ -114,7 +111,8 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -122,7 +120,8 @@ def _get_unpad_data(attention_mask): ) -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, + decoder_start_token_id: int): """ Shift input ids one token to the right. """ @@ -149,13 +148,16 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, + input_ids: torch.Tensor, + past_key_values_length: int = 0): """`input_ids' shape is expected to be [bsz x seqlen].""" bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + positions = torch.arange(past_key_values_length, + past_key_values_length + seq_len, + dtype=torch.long, + device=self.weight.device).expand(bsz, -1) return super().forward(positions + self.offset) @@ -165,7 +167,11 @@ class BartScaledWordEmbedding(nn.Embedding): This module overrides nn.Embeddings' forward by multiplying with embeddings scale. """ - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: Optional[float] = 1.0): super().__init__(num_embeddings, embedding_dim, padding_idx) self.embed_scale = embed_scale @@ -196,8 +202,7 @@ def __init__( if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) + f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal @@ -208,7 +213,8 @@ def __init__( self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() def forward( self, @@ -218,7 +224,8 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -233,11 +240,8 @@ def forward( # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1]): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] @@ -267,7 +271,8 @@ def forward( past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -277,49 +282,57 @@ def forward( if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) + f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): + if layer_head_mask.size() != (self.num_heads, ): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + f" {layer_head_mask.size()}") + attn_weights = layer_head_mask.view( + 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) else: attn_weights_reshaped = None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) attn_output = torch.bmm(attn_probs, value_states) - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f" {attn_output.size()}") - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be @@ -343,9 +356,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( + ) def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -358,10 +372,13 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: # BartFlashAttention2 attention does not support output_attentions if output_attentions: - raise ValueError("BartFlashAttention2 attention does not support output_attentions") + raise ValueError( + "BartFlashAttention2 attention does not support output_attentions" + ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder @@ -375,24 +392,24 @@ def forward( # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1]): # reuse k,v, cross_attentions key_states = past_key_value[0].transpose(1, 2) value_states = past_key_value[1].transpose(1, 2) elif is_cross_attention: # cross_attentions key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, + bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + key_states = torch.cat( + [past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat( + [past_key_value[1].transpose(1, 2), value_states], dim=1) else: # self_attention key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) @@ -406,7 +423,8 @@ def forward( # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + past_key_value = (key_states.transpose(1, 2), + value_states.transpose(1, 2)) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -431,16 +449,18 @@ def forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + f" {target_dtype}.") query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout - ) + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -451,9 +471,14 @@ def forward( return attn_output, attn_weights, past_key_value # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. @@ -483,8 +508,8 @@ def _flash_attention_forward( if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) + query_states, key_states, value_states, attention_mask, + query_length) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -502,29 +527,35 @@ def _flash_attention_forward( causal=causal, ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k @@ -538,7 +569,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) return ( query_layer, @@ -551,6 +583,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class BartSdpaAttention(BartAttention): + def forward( self, hidden_states: torch.Tensor, @@ -559,7 +592,8 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. @@ -588,11 +622,8 @@ def forward( # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1]): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] @@ -642,8 +673,7 @@ def forward( if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2) @@ -664,6 +694,7 @@ def forward( class BartEncoderLayer(nn.Module): + def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model @@ -707,33 +738,42 @@ def forward( layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (attn_weights,) + outputs += (attn_weights, ) return outputs class BartDecoderLayer(nn.Module): + def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model @@ -751,13 +791,14 @@ def __init__(self, config: BartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - config=config, - ) + self.encoder_attn = BART_ATTENTION_CLASSES[ + config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) @@ -774,7 +815,8 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -797,7 +839,8 @@ def forward( # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -806,7 +849,9 @@ def forward( layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -817,7 +862,8 @@ def forward( residual = hidden_states # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -826,7 +872,9 @@ def forward( past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) @@ -836,19 +884,23 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (present_key_value, ) return outputs @@ -901,7 +953,8 @@ def _init_weights(self, module): @property def dummy_inputs(self): pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], + device=self.device) dummy_inputs = { "attention_mask": input_ids.ne(pad_token), "input_ids": input_ids, @@ -910,6 +963,7 @@ def dummy_inputs(self): class PretrainedBartModel(BartPreTrainedModel): + def __init_subclass__(self): warnings.warn( "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", @@ -918,6 +972,7 @@ def __init_subclass__(self): class BartPretrainedModel(BartPreTrainedModel): + def __init_subclass__(self): warnings.warn( "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", @@ -1092,7 +1147,9 @@ class BartEncoder(BartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, + config: BartConfig, + embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) self.dropout = config.dropout @@ -1103,9 +1160,10 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.max_source_positions = config.max_position_embeddings embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.embed_tokens = BartScaledWordEmbedding( - config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + self.padding_idx, + embed_scale=embed_scale) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1114,7 +1172,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList( + [BartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -1176,21 +1235,24 @@ def forward( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input = input_ids input_ids = input_ids.view(-1, input_ids.shape[-1]) elif inputs_embeds is not None: input = inputs_embeds[:, :, -1] else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + raise ValueError( + "You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1200,7 +1262,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) # expand attention_mask if attention_mask is not None: @@ -1210,10 +1274,12 @@ def forward( # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, inputs_embeds.dtype) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask( + attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1223,12 +1289,11 @@ def forward( if head_mask.size()[0] != (len(self.layers)): raise ValueError( f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." - ) + f" {head_mask.size()[0]}.") for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) to_drop = False if self.training: @@ -1251,23 +1316,26 @@ def forward( layer_outputs = encoder_layer( hidden_states, attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) class BartDecoder(BartPreTrainedModel): @@ -1279,17 +1347,21 @@ class BartDecoder(BartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, + config: BartConfig, + embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings - embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = BartScaledWordEmbedding( - config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + self.padding_idx, + embed_scale=embed_scale) if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight @@ -1298,7 +1370,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BartDecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" @@ -1395,15 +1468,17 @@ def forward( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) elif input_ids is not None: input = input_ids input_shape = input.shape @@ -1412,17 +1487,21 @@ def forward( input_shape = inputs_embeds.size()[:-1] input = inputs_embeds[:, :, -1] else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) if self._use_flash_attention_2: # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = attention_mask if ( + attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. @@ -1435,8 +1514,8 @@ def forward( else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + attention_mask, input_shape, inputs_embeds, + past_key_values_length) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -1454,8 +1533,9 @@ def forward( else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1]) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -1464,7 +1544,9 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) if self.gradient_checkpointing and self.training: if use_cache: @@ -1476,28 +1558,30 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + all_cross_attentions = () if ( + output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], + ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != (len(self.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." - ) + f" {head_mask.size()[0]}.") for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: - all_hidden_states += (hidden_states,) + all_hidden_states += (hidden_states, ) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = past_key_values[ + idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -1507,7 +1591,8 @@ def forward( encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None, None, output_attentions, use_cache, @@ -1518,10 +1603,11 @@ def forward( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask + is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1529,25 +1615,25 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache += ( + layer_outputs[3 if output_attentions else 1], ) if output_attentions: - all_self_attns += (layer_outputs[1],) + all_self_attns += (layer_outputs[1], ) if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) + all_cross_attentions += (layer_outputs[2], ) # add hidden states from the last decoder layer if output_hidden_states: - all_hidden_states += (hidden_states,) + all_hidden_states += (hidden_states, ) next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) + return tuple(v for v in [ + hidden_states, next_cache, all_hidden_states, all_self_attns, + all_cross_attentions + ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1556,8 +1642,11 @@ def forward( cross_attentions=all_cross_attentions, ) + class BartModel(BartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] def __init__(self, config: BartConfig): super().__init__(config) @@ -1619,13 +1708,13 @@ def forward( ) decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, self.config.decoder_start_token_id - ) + input_ids, self.config.pad_token_id, + self.config.decoder_start_token_id) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1643,8 +1732,10 @@ def forward( elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) @@ -1677,16 +1768,24 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) + class BartForConditionalGeneration(BartPreTrainedModel): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", + "lm_head.weight" + ] _keys_to_ignore_on_load_missing = ["final_logits_bias"] def __init__(self, config: BartConfig): super().__init__(config) self.model = BartModel(config) - self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) - self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + self.register_buffer( + "final_logits_bias", + torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, + self.model.shared.num_embeddings, + bias=False) # Initialize weights and apply final processing self.post_init() @@ -1697,8 +1796,12 @@ def get_encoder(self): def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + def resize_token_embeddings( + self, + new_num_tokens: int, + pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings( + new_num_tokens, pad_to_multiple_of) self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings @@ -1707,7 +1810,8 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None: if new_num_tokens <= old_num_tokens: new_bias = self.final_logits_bias[:, :new_num_tokens] else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), + device=self.final_logits_bias.device) new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) self.register_buffer("final_logits_bias", new_bias) @@ -1718,14 +1822,10 @@ def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata - ) -> Union[Tuple, Seq2SeqLMOutput]: + self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> Union[Tuple, Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1738,12 +1838,14 @@ def forward( if labels is not None: if use_cache: - logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + logger.warning( + "The `use_cache` argument is changed to `False` since `labels` is provided." + ) use_cache = False if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) + labels, self.config.pad_token_id, + self.config.decoder_start_token_id) outputs = self.model( input_ids, @@ -1770,11 +1872,13 @@ def forward( if labels is not None: labels = labels.to(lm_logits.device) loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + masked_lm_loss = loss_fct( + lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + output = (lm_logits, ) + outputs[1:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output return Seq2SeqLMOutput( loss=masked_lm_loss, @@ -1815,7 +1919,8 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed + "input_ids": + None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, @@ -1824,19 +1929,20 @@ def prepare_inputs_for_generation( "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + "use_cache": + use_cache, # change this to avoid caching (presumably for debugging) } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + return shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past \ No newline at end of file + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past[:2]) + layer_past[2:], ) + return reordered_past diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 38c3bc6a57f45..d04fc2052703b 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -30,6 +30,7 @@ "Only encoder/decoder models may be executed " + \ "using EncoderDecoderModelRunner" + class EncoderInput(NamedTuple): input_tokens: torch.Tensor input_positions: torch.Tensor @@ -130,7 +131,7 @@ def _prepare_encoder_model_input( sliding_window_blocks * self.block_size for seq_group_metadata in seq_group_metadata_list: - computed_block_nums = seq_group_metadata.computed_block_nums + computed_block_nums = None #seq_group_metadata.computed_block_nums if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled and not (computed_block_nums is None @@ -153,7 +154,7 @@ def _prepare_encoder_model_input( paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, sliding_window_blocks, block_aligned_sliding_window, lora_index_mapping, lora_prompt_mapping, lora_requests, - multi_modal_kwargs_list) + multi_modal_kwargs_list, is_encoder_seq=True) max_query_len = max(query_lens) max_seq_len = max(prefill_seq_lens, default=0) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 81c1386a5397c..89d8029c21c97 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -278,7 +278,8 @@ def _prepare_seq_model_input( lora_prompt_mapping: List[int] = [], lora_requests: Set[LoRARequest] = set(), multi_modal_kwargs_list: Dict[str, - List[torch.Tensor]] = defaultdict(list) + List[torch.Tensor]] = defaultdict(list), + is_encoder_seq: bool = False ) -> Tuple[bool, int, int, int]: if is_prompt: @@ -289,8 +290,12 @@ def _prepare_seq_model_input( # TODO(sang): Fix it. context_len = seq_data.get_len() - 1 - seq_len = min(seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) + if is_encoder_seq: + seq_len = seq_data.get_len() + else: + seq_len = min(seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: tokens = seq_data.get_token_ids()[context_len:seq_len] else: From a0068fc9112c5acefe69f5a8e30470c73a90a039 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 00:21:05 -0400 Subject: [PATCH 248/443] Encoder/decoder model runner passes prefill/decode/empty-SG tests --- .../test_encoder_decoder_model_runner.py | 186 ++++++++++++++++-- vllm/worker/enc_dec_model_runner.py | 13 +- vllm/worker/model_runner.py | 25 ++- 3 files changed, 200 insertions(+), 24 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index a304775f93baf..7ce27abbea2c0 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import pytest import torch @@ -10,8 +10,16 @@ from tests.kernels.utils import (override_backend_env_variable, STR_XFORMERS_ATTN_VAL) +# Backends under test +# +# Currently only XFormers is supported BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] +# CUDA graph scenarios to test +# +# Currently CUDA graph is not supported +ENFORCE_EAGER = [True] + def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: @@ -32,7 +40,8 @@ def _create_model_runner(model: str, *args, @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) -def test_prepare_prompt(batch_size, backend_name, monkeypatch): +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) @@ -42,6 +51,7 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, + enforce_eager=enforce_eager ) seq_lens: List[int] = [] @@ -91,10 +101,10 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): encoder_input_tokens = encoder_model_input.input_tokens encoder_input_positions = encoder_model_input.input_positions cross_slot_mapping = attn_metadata.cross_slot_mapping - assert len(encoder_input_tokens) == sum(encoder_seq_lens) assert len(cross_slot_mapping) == len(encoder_input_tokens) # Verify input metadata is correct for prompts. + # - Decoder attention metadata device = model_runner.device assert attn_metadata.num_prefills > 0 assert attn_metadata.num_decode_tokens == 0 @@ -104,8 +114,15 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): assert attn_metadata.seq_lens == seq_lens assert attn_metadata.max_prefill_seq_len == max(seq_lens) assert attn_metadata.max_decode_seq_len == 0 + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.allclose( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - # Test subquery start locs. + # Test decoder subquery start locs. start_idx = 0 start_loc = [start_idx] for seq_len in seq_lens: @@ -115,7 +132,7 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - # Test seq start locs. Note that for normal prefill it is + # Test decoder seq start locs. Note that for normal prefill it is # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] @@ -132,16 +149,33 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): dtype=torch.int, device=device)) + # Verify block tables are correct for prompts + # - Decoder self-attention expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], dtype=torch.int32, device=model_runner.device) assert torch.allclose(attn_metadata.block_tables, expected) - # Cuda graph should not be used for prerill. + # - Encoder/decoder cross-attention + # expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + # dtype=torch.int32, + # device=model_runner.device) + assert torch.allclose(attn_metadata.cross_block_tables, expected) + + # Cuda graph should not be used for prefill, regardless of + # enforce_eager setting assert attn_metadata.use_cuda_graph is False + # Verify the lengths of input tokens & positions + # - Decoder assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + torch.testing.assert_close(input_tokens, + input_positions) + # - Encoder + assert len(encoder_input_tokens) == sum(encoder_seq_lens) + assert len(encoder_input_tokens) == sum(encoder_seq_lens) + torch.testing.assert_close(encoder_input_tokens, + encoder_input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -149,8 +183,6 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -164,9 +196,141 @@ def test_prepare_prompt(batch_size, backend_name, monkeypatch): dtype=actual.dtype) torch.testing.assert_close(actual, expected) +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + model_runner = _create_model_runner( + "facebook/bart-base", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + # Decoder model input + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + + # Encoder model input + encoder_model_input = model_runner._prepare_encoder_model_input( + seq_group_metadata_list, attn_metadata) + encoder_input_tokens = encoder_model_input.input_tokens + encoder_input_positions = encoder_model_input.input_positions + return_encoder_seq_lens = attn_metadata.encoder_seq_lens + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert return_encoder_seq_lens == encoder_seq_lens + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify input metadata is correct for decode phase. + # - Decoder attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.allclose( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) + + # Test decoder subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + # 1 decoded token per sequence + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test decoder seq start locs. Note that for normal prefill it is + # equivalent to query_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor([seq_len-1 for seq_len in seq_lens], + dtype=torch.int, + device=device)) + + # Verify block tables are correct for prompts + # - Decoder self-attention + expected = torch.tensor([block_tables[0] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(attn_metadata.block_tables, expected) + # - Encoder/decoder cross-attention + expected = torch.tensor([cross_block_table for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(attn_metadata.cross_block_tables, expected) + + # Cuda graph should not be used for prefill. + assert attn_metadata.use_cuda_graph == (not enforce_eager) + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(seq_lens) + assert len(input_positions) == len(seq_lens) + torch.testing.assert_close(input_tokens, + input_positions) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_positions) == 0 @pytest.mark.parametrize("backend_name", BACKEND_NAMES) -def test_empty_seq_group(backend_name, monkeypatch): +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): """Verify prepare prompt and decode returns empty output.""" # Force Attention wrapper backend @@ -176,7 +340,7 @@ def test_empty_seq_group(backend_name, monkeypatch): "facebook/bart-base", seed=0, dtype="float16", - enforce_eager=False, + enforce_eager=enforce_eager, ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input(seq_group_metadata_list) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d04fc2052703b..bfa84c589745c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -101,8 +101,6 @@ def _prepare_encoder_model_input( sliding_window_blocks = 0 block_aligned_sliding_window = 0 - is_prompt = True - # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. @@ -132,6 +130,9 @@ def _prepare_encoder_model_input( for seq_group_metadata in seq_group_metadata_list: computed_block_nums = None #seq_group_metadata.computed_block_nums + + is_prompt = seq_group_metadata.is_prompt + if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled and not (computed_block_nums is None @@ -157,7 +158,9 @@ def _prepare_encoder_model_input( multi_modal_kwargs_list, is_encoder_seq=True) max_query_len = max(query_lens) - max_seq_len = max(prefill_seq_lens, default=0) + + max_seq_len = max(prefill_seq_lens, default=0) if is_prompt else \ + max(decode_seq_lens, default=0) # Assume Eager Mode # TODO: CUDA Graph support @@ -171,7 +174,7 @@ def _prepare_encoder_model_input( device=self.device, ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + assert (not is_prompt) or max_query_len > 0, ("query_lens: {}".format(query_lens)) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, @@ -201,7 +204,7 @@ def _prepare_encoder_model_input( out=query_start_loc[1:]) # Set encoder-oriented attention metadata fields - attn_metadata.num_encoder_tokens = num_prefill_tokens + attn_metadata.num_encoder_tokens = sum(seq_lens) attn_metadata.encoder_seq_lens = seq_lens attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor attn_metadata.max_encoder_seq_len = max_seq_len diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 89d8029c21c97..2a747e2e5c244 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -285,10 +285,16 @@ def _prepare_seq_model_input( if is_prompt: context_len = seq_data.get_num_computed_tokens() else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 + if is_encoder_seq: + # In decode phase, no new *encoder* tokens are + # introduced, so the context is always the full + # encoder sequence + context_len = seq_data.get_len() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 if is_encoder_seq: seq_len = seq_data.get_len() @@ -357,7 +363,7 @@ def _prepare_seq_model_input( block_table = computed_block_nums elif (self.scheduler_config.chunked_prefill_enabled or not is_prompt): - if seq_group_metadata.block_tables is not None: + if original_block_table is not None: # chunked prefill or decode block_table = original_block_table assert block_table is not None @@ -398,7 +404,9 @@ def _prepare_seq_model_input( decode_only = False prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( + # Except in encoder/decoder scenario, decode-phase + # query_len must be 1 + assert is_encoder_seq or query_len == 1, ( "seq_len: {}, context_len: {}, query_len: {}".format( seq_len, context_len, query_len)) num_decode_tokens += query_len @@ -425,7 +433,7 @@ def _prepare_seq_model_input( for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) - if _is_block_tables_empty(seq_group_metadata.block_tables): + if _is_block_tables_empty(original_block_table): # During memory profiling, the block tables are not # initialized yet. In this case, we just use a dummy # slot mapping. @@ -551,7 +559,8 @@ def _prepare_model_input( "now.") seq_data = seq_group_metadata.seq_data[seq_id] - block_table = seq_group_metadata.block_tables[seq_id] + block_table = None if seq_group_metadata.block_tables is None else \ + seq_group_metadata.block_tables[seq_id] decode_only, \ num_prefills, \ num_prefill_tokens, \ From f8569facd10b0cbf05689bfc364831a37bb48b45 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 00:35:24 -0400 Subject: [PATCH 249/443] formatting --- .../test_encoder_decoder_model_runner.py | 57 ++++++------ vllm/model_executor/models/bart.py | 3 +- vllm/worker/enc_dec_model_runner.py | 5 +- vllm/worker/model_runner.py | 87 +++++++++++-------- 4 files changed, 83 insertions(+), 69 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 7ce27abbea2c0..dd113a640f383 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List import pytest import torch @@ -46,13 +46,11 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - model_runner = _create_model_runner( - "facebook/bart-base", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=enforce_eager - ) + model_runner = _create_model_runner("facebook/bart-base", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager) seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] @@ -169,13 +167,11 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # - Decoder assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, - input_positions) + torch.testing.assert_close(input_tokens, input_positions) # - Encoder assert len(encoder_input_tokens) == sum(encoder_seq_lens) assert len(encoder_input_tokens) == sum(encoder_seq_lens) - torch.testing.assert_close(encoder_input_tokens, - encoder_input_positions) + torch.testing.assert_close(encoder_input_tokens, encoder_input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -196,6 +192,7 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): dtype=actual.dtype) torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) @@ -204,13 +201,11 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - model_runner = _create_model_runner( - "facebook/bart-base", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=enforce_eager - ) + model_runner = _create_model_runner("facebook/bart-base", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager) seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] @@ -299,20 +294,22 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) assert torch.allclose( attn_metadata.context_lens_tensor, - torch.tensor([seq_len-1 for seq_len in seq_lens], - dtype=torch.int, - device=device)) + torch.tensor([seq_len - 1 for seq_len in seq_lens], + dtype=torch.int, + device=device)) # Verify block tables are correct for prompts # - Decoder self-attention - expected = torch.tensor([block_tables[0] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) + expected = torch.tensor( + [block_tables[0] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) assert torch.allclose(attn_metadata.block_tables, expected) # - Encoder/decoder cross-attention - expected = torch.tensor([cross_block_table for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) + expected = torch.tensor( + [cross_block_table for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) assert torch.allclose(attn_metadata.cross_block_tables, expected) # Cuda graph should not be used for prefill. @@ -322,12 +319,12 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): # - Decoder assert len(input_tokens) == len(seq_lens) assert len(input_positions) == len(seq_lens) - torch.testing.assert_close(input_tokens, - input_positions) + torch.testing.assert_close(input_tokens, input_positions) # - Encoder assert len(encoder_input_tokens) == 0 assert len(encoder_input_positions) == 0 + @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 4b083b510df3b..4c3054dab3517 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -1,7 +1,8 @@ # Derived from BART implementation posted on HuggingFace; license below: # # coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index bfa84c589745c..19d1c0bbc781d 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -129,7 +129,7 @@ def _prepare_encoder_model_input( sliding_window_blocks * self.block_size for seq_group_metadata in seq_group_metadata_list: - computed_block_nums = None #seq_group_metadata.computed_block_nums + computed_block_nums = None #seq_group_metadata.computed_block_nums is_prompt = seq_group_metadata.is_prompt @@ -174,7 +174,8 @@ def _prepare_encoder_model_input( device=self.device, ) - assert (not is_prompt) or max_query_len > 0, ("query_lens: {}".format(query_lens)) + assert (not is_prompt) or max_query_len > 0, ( + "query_lens: {}".format(query_lens)) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 73122eb9fb3fa..2913d20cf8049 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -254,37 +254,48 @@ def get_max_block_per_batch(self) -> int: return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_seq_model_input( - self, - is_prompt: bool, - decode_only: bool, - num_prefills: int, - num_prefill_tokens: int, - num_decode_tokens: int, - block_tables: List[List[int]], - seq_lens: List[int], - slot_mapping: List[int], - context_lens: List[int], - query_lens: List[int], - input_tokens: List[int], - input_positions: List[int], - prefill_seq_lens: List[int], - decode_seq_lens: List[int], - seq_group_metadata: SequenceGroupMetadata, - seq_data: SequenceData, - computed_block_nums: Optional[List[int]], - original_block_table: Optional[List[int]], - paged_kv_indices: Optional[List[int]], - paged_kv_indptr: Optional[List[int]], - paged_kv_last_page_len: Optional[List[int]], - sliding_window_blocks: int = 0, - block_aligned_sliding_window: int = 0, - lora_index_mapping: List[int] = [], - lora_prompt_mapping: List[int] = [], - lora_requests: Set[LoRARequest] = set(), - multi_modal_kwargs_list: Dict[str, - List[torch.Tensor]] = defaultdict(list), - is_encoder_seq: bool = False - ) -> Tuple[bool, int, int, int]: + self, + is_prompt: bool, + decode_only: bool, + num_prefills: int, + num_prefill_tokens: int, + num_decode_tokens: int, + block_tables: List[List[int]], + seq_lens: List[int], + slot_mapping: List[int], + context_lens: List[int], + query_lens: List[int], + input_tokens: List[int], + input_positions: List[int], + prefill_seq_lens: List[int], + decode_seq_lens: List[int], + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, + computed_block_nums: Optional[List[int]], + original_block_table: Optional[List[int]], + paged_kv_indices: Optional[List[int]], + paged_kv_indptr: Optional[List[int]], + paged_kv_last_page_len: Optional[List[int]], + sliding_window_blocks: int = 0, + block_aligned_sliding_window: int = 0, + lora_index_mapping: Optional[List[int]] = None, + lora_prompt_mapping: Optional[List[int]] = None, + lora_requests: Optional[Set[LoRARequest]] = None, + multi_modal_kwargs_list: Optional[Dict[str, List[torch.Tensor]]] \ + = None, + is_encoder_seq: bool = False) -> Tuple[bool, int, int, int]: + + if lora_index_mapping is None: + lora_index_mapping = [] + + if lora_prompt_mapping is None: + lora_prompt_mapping = [] + + if lora_requests is None: + lora_requests = set() + + if multi_modal_kwargs_list is None: + multi_modal_kwargs_list = defaultdict(list) if is_prompt: context_len = seq_data.get_num_computed_tokens() @@ -304,8 +315,8 @@ def _prepare_seq_model_input( seq_len = seq_data.get_len() else: seq_len = min(seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) - + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: tokens = seq_data.get_token_ids()[context_len:seq_len] else: @@ -563,8 +574,8 @@ def _prepare_model_input( "now.") seq_data = seq_group_metadata.seq_data[seq_id] - block_table = None if seq_group_metadata.block_tables is None else \ - seq_group_metadata.block_tables[seq_id] + block_table = None if seq_group_metadata.block_tables is None \ + else seq_group_metadata.block_tables[seq_id] decode_only, \ num_prefills, \ num_prefill_tokens, \ @@ -1198,13 +1209,17 @@ def _get_graph_batch_size(batch_size: int) -> int: _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) -def _is_block_tables_empty(block_tables: Union[None, Dict]): +def _is_block_tables_empty(block_tables: Union[None, Dict, List]): """ Check if block_tables is None or a dictionary with all None values. """ if block_tables is None: return True + if isinstance(block_tables, List): + # A single block table passed in as a List + return False if isinstance(block_tables, dict) and all( value is None for value in block_tables.values()): + # Block tables dict where all block tables are None return True return False From de967174dcbbdb5e81d975edf158416bcbeb74cd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 02:25:36 -0400 Subject: [PATCH 250/443] wip bart test --- tests/models/test_bart.py | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/models/test_bart.py diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py new file mode 100644 index 0000000000000..175cfcc1126bc --- /dev/null +++ b/tests/models/test_bart.py @@ -0,0 +1,40 @@ +"""Compare the outputs of HF and vLLM for BART models using greedy sampling. + +Run `pytest tests/models/test_bart.py`. +""" +import pytest + +from .utils import check_logprobs_close + +MODELS = [ + "facebook/bart-base" +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + # TODO(sang): Sliding window should be tested separately. + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) From e9ecd25cb733b220785611056295ea9787b1ce47 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 05:48:50 -0400 Subject: [PATCH 251/443] added unoptimized BART example --- examples/offline_inference_encoder_decoder.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 examples/offline_inference_encoder_decoder.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py new file mode 100644 index 0000000000000..8e1be5b1f49fc --- /dev/null +++ b/examples/offline_inference_encoder_decoder.py @@ -0,0 +1,22 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/bart-base") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 2b2d2e9df2b1535883e36b8353a26d52200f7783 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 08:55:19 -0400 Subject: [PATCH 252/443] wip encoder/decoder API integration; WIP BART integration; WIP BART example --- examples/offline_inference_encoder_decoder.py | 13 ++- tests/models/test_bart.py | 4 +- vllm/engine/async_llm_engine.py | 4 +- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/llm.py | 6 +- vllm/inputs.py | 104 +++++++++++++++++- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/bart.py | 100 +++++++++++++++-- vllm/sequence.py | 30 ++++- vllm/utils.py | 20 ++++ vllm/worker/model_runner.py | 8 +- vllm/worker/worker.py | 20 +++- 12 files changed, 280 insertions(+), 37 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 8e1be5b1f49fc..3bf7f2e8660ee 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -1,12 +1,23 @@ from vllm import LLM, SamplingParams # Sample prompts. -prompts = [ +# - Encoder prompts +encoder_prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] +# - Decoder prompts +decoder_prompts = [ + "", + "", + "", + "", +] +# - Unified prompts +prompts = [enc_dec for enc_dec in zip(encoder_prompts,decoder_prompts)] + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 175cfcc1126bc..df76777a0de00 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -6,9 +6,7 @@ from .utils import check_logprobs_close -MODELS = [ - "facebook/bart-base" -] +MODELS = ["facebook/bart-base"] @pytest.mark.parametrize("model", MODELS) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index df25eb111e87f..6ff21a896e324 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -13,7 +13,7 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray -from vllm.inputs import LLMInputs, PromptInputs +from vllm.inputs import LLMInputs, PromptInputs, LLMInputsOptions from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -263,7 +263,7 @@ async def process_model_inputs_async( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ) -> LLMInputs: + ) -> LLMInputsOptions: if isinstance(inputs, str): inputs = {"prompt": inputs} diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 75d417f525e3a..7ef7e57fe678d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,7 +20,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import LLMInputs, PromptInputs +from vllm.inputs import LLMInputs, PromptInputs, LLMInputsOptions from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -453,7 +453,7 @@ def _get_eos_token_id( def _add_processed_request( self, request_id: str, - processed_inputs: LLMInputs, + processed_inputs: LLMInputsOptions, params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -497,7 +497,7 @@ def process_model_inputs( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ) -> LLMInputs: + ) -> LLMInputsOptions: if isinstance(inputs, str): inputs = {"prompt": inputs} diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e923493160ed..a511d14425e19 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -8,7 +8,8 @@ from vllm.engine.llm_engine import LLMEngine from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, TextTokensPrompt, TokensPrompt, - parse_and_batch_prompt) + PromptStrictInputsOptions, parse_and_batch_prompt, + EncoderDecoderStringPrompts) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -243,8 +244,7 @@ def generate( "instead.") def generate( self, - prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - Optional[Union[str, List[str]]]] = None, + prompts: PromptStrictInputsOptions = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, diff --git a/vllm/inputs.py b/vllm/inputs.py index 026903e19a26e..30851bff5e905 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -1,5 +1,5 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, - TypedDict, Union, cast, overload) + TypedDict, Union, cast, overload, Tuple) from typing_extensions import NotRequired @@ -79,6 +79,22 @@ class TextPrompt(TypedDict): """ +class EncoderDecoderTextPrompt(TypedDict): + """Schema for a dual text prompt (encoder & decoder prompts.)""" + + encoder_prompt: str + """The input text to be tokenized before passing to the encoder model.""" + + decoder_prompt: str + """The input text to be tokenized before passing to the decoder model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + class TokensPrompt(TypedDict): """Schema for a tokenized prompt.""" @@ -92,6 +108,22 @@ class TokensPrompt(TypedDict): """ +class EncoderDecoderTokensPrompt(TypedDict): + """Schema for a dual tokenized prompt (encoder & decoder prompts)""" + + encoder_prompt_token_ids: List[int] + """A list of token IDs to pass to the encoder model.""" + + decoder_prompt_token_ids: List[int] + """A list of token IDs to pass to the decoder model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + class TextTokensPrompt(TypedDict): """It is assumed that :attr:`prompt` is consistent with :attr:`prompt_token_ids`. This is currently used in @@ -111,7 +143,49 @@ class TextTokensPrompt(TypedDict): """ -PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +class EncoderDecoderTextTokensPrompt(TypedDict): + """It is assumed that :attr:`encoder_prompt` and :attr:`decoder_prompt` + are consistent with :attr:`encoder_prompt_token_ids` and + :attr:`decoder_prompt_token_ids`, respectively. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + encoder_prompt: str + """The encoder prompt text.""" + + encoder_prompt_token_ids: List[int] + """The token IDs of the encoder prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + decoder_prompt: str + """The decoder prompt text.""" + + decoder_prompt_token_ids: List[int] + """The token IDs of the decoder prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +EncoderDecoderStringPrompts = Tuple[str, str] + +EncoderDecoderPromptStrictInputs = Union[EncoderDecoderStringPrompts, + EncoderDecoderTextPrompt, + EncoderDecoderTokensPrompt] +""" +The inputs to the encoder/decoder LLM, +which can take one of the following forms: + +- A pair of encoder & decoder text prompts (:class:`tuple` of two :class:`str` + i.e. (encoder_prompt,decoder_prompt) or :class:`EncoderDecoderTextPrompt`) +- Tokenized encoder & decoder prompts (:class:`EncoderDecoderTokensPrompt`) +""" + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt, + EncoderDecoderPromptStrictInputs] """ The inputs to the LLM, which can take one of the following forms: @@ -119,12 +193,36 @@ class TextTokensPrompt(TypedDict): - A tokenized prompt (:class:`TokensPrompt`) """ -PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +EncoderDecoderPromptInputs = Union[EncoderDecoderStringPrompts, + EncoderDecoderTextPrompt, + EncoderDecoderTokensPrompt, + EncoderDecoderTextTokensPrompt] +"""Same as :const:`EncoderDecoderPromptStrictInputs` but additionally accepts +:class:`EncoderDecoderTextTokensPrompt`.""" + +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt, + EncoderDecoderPromptInputs] """Same as :const:`PromptStrictInputs` but additionally accepts :class:`TextTokensPrompt`.""" +PromptStrictInputsOptions = Union[ + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, EncoderDecoderStringPrompts, List[str], + List[EncoderDecoderStringPrompts]]]] + class LLMInputs(TypedDict): prompt_token_ids: List[int] prompt: NotRequired[Optional[str]] multi_modal_data: NotRequired[Optional["MultiModalData"]] + + +class EncoderDecoderLLMInputs(TypedDict): + encoder_prompt_token_ids: List[int] + encoder_prompt: NotRequired[Optional[str]] + decoder_prompt_token_ids: NotRequired[Optional[List[int]]] + decoder_prompt: NotRequired[Optional[str]] + multi_modal_data: NotRequired[Optional["MultiModalData"]] + + +LLMInputsOptions = Union[LLMInputs, EncoderDecoderLLMInputs] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b32d1967d87ac..cb049268db73d 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -68,7 +68,7 @@ } _CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", ), + "BartModel": ("bart", "BartForConditionalGeneration"), } _MODELS = { diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 4c3054dab3517..f7f12e2a79154 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -60,14 +60,14 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ...activations import ACT2FN -from ...modeling_attn_mask_utils import ( +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_outputs import ( +from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -76,13 +76,15 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, ) -from .configuration_bart import BartConfig +from transformers import BartConfig + +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -1778,7 +1780,7 @@ class BartForConditionalGeneration(BartPreTrainedModel): ] _keys_to_ignore_on_load_missing = ["final_logits_bias"] - def __init__(self, config: BartConfig): + def __init__(self, config: BartConfig, cache_config: CacheConfig, quant_config: QuantizationConfig): super().__init__(config) self.model = BartModel(config) self.register_buffer( @@ -1947,3 +1949,87 @@ def _reorder_cache(past_key_values, beam_idx): past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + return + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/vllm/sequence.py b/vllm/sequence.py index 287e1b9df6165..0517719264ff3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,12 +3,12 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import torch from vllm.block import LogicalTokenBlock -from vllm.inputs import LLMInputs +from vllm.inputs import LLMInputsOptions from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -212,7 +212,9 @@ class Sequence: Args: seq_id: The ID of the sequence. - inputs: The inputs of the sequence. + inputs: The inputs of the sequence. Note that for encoder/decoder + inputs, Sequence makes no use of the encoder prompt (which is + tracked at the level of the SequenceGroup) block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -221,7 +223,7 @@ class Sequence: def __init__( self, seq_id: int, - inputs: LLMInputs, + inputs: LLMInputsOptions, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, @@ -250,11 +252,27 @@ def __init__( @property def prompt(self) -> Optional[str]: - return self.inputs.get("prompt") + if "prompt" in self.inputs: + # Decoder-only prompt + return cast(Optional[str], self.inputs.get("prompt")) + elif "decoder_prompt" in self.inputs: + # In encoder/decoder scenario, self.prompt() + # returns the decoder prompt + return cast(Optional[str], self.inputs.get("decoder_prompt")) + else: + raise AttributeError("Invalid Sequence.inputs: {self.inputs}") @property def prompt_token_ids(self) -> List[int]: - return self.inputs["prompt_token_ids"] + if "prompt_token_ids" in self.inputs: + # Decoder-only prompt + return cast(List[int], self.inputs.get("prompt_token_ids")) + elif "decoder_prompt_token_ids" in self.inputs: + # In encoder/decoder scenario, self.prompt_token_ids() + # returns the decoder prompt + return cast(List[int], self.inputs.get("decoder_prompt_token_ids")) + else: + raise AttributeError("Invalid Sequence.inputs: {self.inputs}") @property def multi_modal_data(self) -> Optional["MultiModalData"]: diff --git a/vllm/utils.py b/vllm/utils.py index 92977d9b9961c..37cdfcf95662f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -872,3 +872,23 @@ def parse_args(self, args=None, namespace=None): processed_args.append(arg) return super().parse_args(processed_args, namespace) + +def is_encoder_decoder_model_config(model_config) -> bool: + ''' + Extract the HF encoder/decoder model flag from the ModelConfig instance. + + Return False if model_config is None. + ''' + return False if model_config is None else \ + getattr(model_config.hf_config, + "is_encoder_decoder", + False) + +def is_embedding_model_config(model_config) -> bool: + ''' + Extract the embedding model flag from the ModelConfig instance. + + Return False if model_config is None. + ''' + return False if model_config is None else \ + model_config.embedding_mode \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2913d20cf8049..56206ee1ecbef 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,7 +25,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available, make_tensor_with_pad, + is_encoder_decoder_model_config) logger = init_logger(__name__) @@ -982,10 +983,7 @@ def _is_encoder_decoder_model(self): field of the HF config, if this field is present; otherwise return False. ''' - return False if self.model_config is None else \ - getattr(self.model_config.hf_config, - "is_encoder_decoder", - False) + return is_encoder_decoder_model_config(self.model_config) def _am_child(self): ''' diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e334ffbb755bb..80586714abe1f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -20,8 +20,11 @@ from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.worker_base import WorkerBase - +from vllm.utils import (is_embedding_model_config, + is_encoder_decoder_model_config) + class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. @@ -78,8 +81,13 @@ def __init__( or (speculative_config.draft_model_config.hf_config.model_type != "mlp_speculator") else {"return_hidden_states": True} - ModelRunnerClass = (EmbeddingModelRunner if - self.model_config.embedding_mode else ModelRunner) + if is_embedding_model_config(self.model_config): + ModelRunnerClass = EmbeddingModelRunner + elif is_encoder_decoder_model_config(self.model_config): + ModelRunnerClass = EncoderDecoderModelRunner + else: + ModelRunnerClass = ModelRunner + self.model_runner = ModelRunnerClass( model_config, parallel_config, @@ -99,6 +107,12 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.gpu_cache: Optional[List[torch.tensor]] = None + def _is_encoder_decoder_model(self) -> bool: + return is_encoder_decoder_model_config(self.model_config) + + def _is_embedding_model(self) -> bool: + return is_embedding_model_config(self.model_config) + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until From 7000573396666a58cf5ca06d626f2b4c2e4f8bb2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 09:49:37 -0400 Subject: [PATCH 253/443] temporarily removing BART work --- examples/offline_inference_encoder_decoder.py | 33 - vllm/model_executor/models/__init__.py | 10 - vllm/model_executor/models/bart.py | 2035 ----------------- 3 files changed, 2078 deletions(-) delete mode 100644 examples/offline_inference_encoder_decoder.py delete mode 100644 vllm/model_executor/models/bart.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py deleted file mode 100644 index 3bf7f2e8660ee..0000000000000 --- a/examples/offline_inference_encoder_decoder.py +++ /dev/null @@ -1,33 +0,0 @@ -from vllm import LLM, SamplingParams - -# Sample prompts. -# - Encoder prompts -encoder_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# - Decoder prompts -decoder_prompts = [ - "", - "", - "", - "", -] -# - Unified prompts -prompts = [enc_dec for enc_dec in zip(encoder_prompts,decoder_prompts)] - -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="facebook/bart-base") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index cb049268db73d..8b45364d757cf 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -67,16 +67,6 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } -_CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", "BartForConditionalGeneration"), -} - -_MODELS = { - **_GENERATION_MODELS, - **_EMBEDDING_MODELS, - **_CONDITIONAL_GENERATION_MODELS -} - # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py deleted file mode 100644 index f7f12e2a79154..0000000000000 --- a/vllm/model_executor/models/bart.py +++ /dev/null @@ -1,2035 +0,0 @@ -# Derived from BART implementation posted on HuggingFace; license below: -# -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BART model.""" -from typing import Iterable, List, Optional, Tuple - -import torch -from torch import nn -from transformers import MixtralConfig - -from vllm import _custom_ops as ops -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, - per_tensor_dequantize, - per_tensor_quantize) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput -from vllm.utils import print_warning_once - -import copy -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_attention_mask, - _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - Seq2SeqLMOutput, - Seq2SeqModelOutput, - Seq2SeqQuestionAnsweringModelOutput, - Seq2SeqSequenceClassifierOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, -) -from transformers import BartConfig - -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "facebook/bart-base" -_CONFIG_FOR_DOC = "BartConfig" - -# Base model docstring -_EXPECTED_OUTPUT_SHAPE = [1, 8, 768] - -# SequenceClassification docstring -_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2" -_SEQ_CLASS_EXPECTED_LOSS = 0.0 -_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'" - -# QuestionAsnwering docstring -_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1" -_QA_EXPECTED_LOSS = 0.59 -_QA_EXPECTED_OUTPUT = "' nice puppet'" - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, - decoder_start_token_id: int): - """ - Shift input ids one token to the right. - """ - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() - shifted_input_ids[:, 0] = decoder_start_token_id - - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - - -class BartLearnedPositionalEmbedding(nn.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int): - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward(self, - input_ids: torch.Tensor, - past_key_values_length: int = 0): - """`input_ids' shape is expected to be [bsz x seqlen].""" - - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange(past_key_values_length, - past_key_values_length + seq_len, - dtype=torch.long, - device=self.weight.device).expand(bsz, -1) - - return super().forward(positions + self.offset) - - -class BartScaledWordEmbedding(nn.Embedding): - """ - This module overrides nn.Embeddings' forward by multiplying with embeddings scale. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int, - embed_scale: Optional[float] = 1.0): - super().__init__(num_embeddings, embedding_dim, padding_idx) - self.embed_scale = embed_scale - - def forward(self, input_ids: torch.Tensor): - return super().forward(input_ids) * self.embed_scale - - -class BartAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, - config: Optional[BartConfig] = None, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - self.is_causal = is_causal - - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1]): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, - bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}") - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, - src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, - src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads, ): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - attn_weights = layer_head_mask.view( - 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, - src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, - src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, - tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, - tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, - self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, - self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class BartFlashAttention2(BartAttention): - """ - Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( - ) - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # BartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "BartFlashAttention2 attention does not support output_attentions" - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1]): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, - bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat( - [past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat( - [past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), - value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward(query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout) - - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, - query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, - query_length) - else: - attn_output = flash_attn_func(query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) - - return attn_output - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, - query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( - attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, - head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class BartSdpaAttention(BartAttention): - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1]): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -BART_ATTENTION_CLASSES = { - "eager": BartAttention, - "sdpa": BartSdpaAttention, - "flash_attention_2": BartFlashAttention2, -} - - -class BartEncoderLayer(nn.Module): - - def __init__(self, config: BartConfig): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - config=config, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (attn_weights, ) - - return outputs - - -class BartDecoderLayer(nn.Module): - - def __init__(self, config: BartConfig): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - is_causal=True, - config=config, - ) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[ - config._attn_implementation]( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - config=config, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of - size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[: - 2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[ - -2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value, ) - - return outputs - - -class BartClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.dense = nn.Linear(input_dim, inner_dim) - self.dropout = nn.Dropout(p=pooler_dropout) - self.out_proj = nn.Linear(inner_dim, num_classes) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class BartPreTrainedModel(PreTrainedModel): - config_class = BartConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] - _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], - device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs - - -class PretrainedBartModel(BartPreTrainedModel): - - def __init_subclass__(self): - warnings.warn( - "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", - FutureWarning, - ) - - -class BartPretrainedModel(BartPreTrainedModel): - - def __init_subclass__(self): - warnings.warn( - "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", - FutureWarning, - ) - - -BART_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`BartConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BART_GENERATION_EXAMPLE = r""" - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, BartForConditionalGeneration - - >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - - >>> ARTICLE_TO_SUMMARIZE = ( - ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " - ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " - ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." - ... ) - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) - >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' - ``` - - Mask filling example: - - ```python - >>> from transformers import AutoTokenizer, BartForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") - >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") - - >>> TXT = "My friends are but they eat too many carbs." - >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] - >>> logits = model(input_ids).logits - - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = logits[0, masked_index].softmax(dim=0) - >>> values, predictions = probs.topk(5) - - >>> tokenizer.decode(predictions).split() - ['not', 'good', 'healthy', 'great', 'very'] - ``` -""" - -BART_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class BartEncoder(BartPreTrainedModel): - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`BartEncoderLayer`]. - - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, - config: BartConfig, - embed_tokens: Optional[nn.Embedding] = None): - super().__init__(config) - - self.dropout = config.dropout - self.layerdrop = config.encoder_layerdrop - - embed_dim = config.d_model - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - embed_dim, - self.padding_idx, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - embed_dim, - ) - self.layers = nn.ModuleList( - [BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" - self.layernorm_embedding = nn.LayerNorm(embed_dim) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input = input_ids - input_ids = input_ids.view(-1, input_ids.shape[-1]) - elif inputs_embeds is not None: - input = inputs_embeds[:, :, -1] - else: - raise ValueError( - "You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - embed_pos = self.embed_positions(input) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa( - attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask( - attention_mask, inputs_embeds.dtype) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") - - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - to_drop = False - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: # skip the layer - to_drop = True - - if to_drop: - layer_outputs = (None, None) - else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] - if head_mask is not None else None), - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1], ) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - if not return_dict: - return tuple( - v for v in [hidden_states, encoder_states, all_attentions] - if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions) - - -class BartDecoder(BartPreTrainedModel): - """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] - - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, - config: BartConfig, - embed_tokens: Optional[nn.Embedding] = None): - super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - config.d_model, - self.padding_idx, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - ) - self.layers = nn.ModuleList( - [BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" - - self.layernorm_embedding = nn.LayerNorm(config.d_model) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing - cross-attention on hidden heads. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input = input_ids - input_shape = input.shape - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[ - 2] if past_key_values is not None else 0 - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if ( - attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, - past_key_values_length) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1]) - - # embed positions - positions = self.embed_positions(input, past_key_values_length) - positions = positions.to(inputs_embeds.device) - - hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if ( - output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], - ["head_mask", "cross_attn_head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states, ) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue - - past_key_value = past_key_values[ - idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] - if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] - if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] - if cross_attn_head_mask - is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[3 if output_attentions else 1], ) - - if output_attentions: - all_self_attns += (layer_outputs[1], ) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2], ) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [ - hidden_states, next_cache, all_hidden_states, all_self_attns, - all_cross_attentions - ] if v is not None) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class BartModel(BartPreTrainedModel): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" - ] - - def __init__(self, config: BartConfig): - super().__init__(config) - - padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - - self.encoder = BartEncoder(config, self.shared) - self.decoder = BartDecoder(config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqModelOutput]: - # different to other models, Bart automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, - self.config.decoder_start_token_id) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] - if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] - if len(encoder_outputs) > 2 else None, - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class BartForConditionalGeneration(BartPreTrainedModel): - base_model_prefix = "model" - _tied_weights_keys = [ - "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", - "lm_head.weight" - ] - _keys_to_ignore_on_load_missing = ["final_logits_bias"] - - def __init__(self, config: BartConfig, cache_config: CacheConfig, quant_config: QuantizationConfig): - super().__init__(config) - self.model = BartModel(config) - self.register_buffer( - "final_logits_bias", - torch.zeros((1, self.model.shared.num_embeddings))) - self.lm_head = nn.Linear(config.d_model, - self.model.shared.num_embeddings, - bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - - def resize_token_embeddings( - self, - new_num_tokens: int, - pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings( - new_num_tokens, pad_to_multiple_of) - self._resize_final_logits_bias(new_embeddings.weight.shape[0]) - return new_embeddings - - def _resize_final_logits_bias(self, new_num_tokens: int) -> None: - old_num_tokens = self.final_logits_bias.shape[-1] - if new_num_tokens <= old_num_tokens: - new_bias = self.final_logits_bias[:, :new_num_tokens] - else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), - device=self.final_logits_bias.device) - new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) - self.register_buffer("final_logits_bias", new_bias) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def forward( - self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> Union[Tuple, Seq2SeqLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - if use_cache: - logger.warning( - "The `use_cache` argument is changed to `False` since `labels` is provided." - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, - self.config.decoder_start_token_id) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - lm_logits = self.lm_head(outputs[0]) - lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) - - masked_lm_loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct( - lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (lm_logits, ) + outputs[1:] - return ((masked_lm_loss, ) + - output) if masked_lm_loss is not None else output - - return Seq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": - None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": - use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, - self.config.decoder_start_token_id) - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past[:2]) + layer_past[2:], ) - return reordered_past - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - return - - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - expert_params_mapping = [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file From a1ab7a110c334f54dc451f1b273c3b0f0345332e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 09:50:37 -0400 Subject: [PATCH 254/443] removing BART test --- tests/models/test_bart.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) delete mode 100644 tests/models/test_bart.py diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py deleted file mode 100644 index df76777a0de00..0000000000000 --- a/tests/models/test_bart.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Compare the outputs of HF and vLLM for BART models using greedy sampling. - -Run `pytest tests/models/test_bart.py`. -""" -import pytest - -from .utils import check_logprobs_close - -MODELS = ["facebook/bart-base"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: - # TODO(sang): Sliding window should be tested separately. - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) From beec4f5717d5c8193d70449c066f2aa469bf50b0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 10:24:50 -0400 Subject: [PATCH 255/443] enc/dec support in LLMEngine._add_processed_request() --- vllm/engine/llm_engine.py | 187 +++++++++++++++++++------ vllm/model_executor/models/__init__.py | 5 + 2 files changed, 150 insertions(+), 42 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7ef7e57fe678d..472b3456a4832 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,7 +20,8 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import LLMInputs, PromptInputs, LLMInputsOptions +from vllm.inputs import (LLMInputs, PromptInputs, LLMInputsOptions, + EncoderDecoderStringPrompts, EncoderDecoderLLMInputs) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -38,7 +39,9 @@ get_tokenizer_group) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import (Counter, + is_encoder_decoder_model_config, + is_embedding_model_config) from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -239,7 +242,7 @@ def __init__( load_config=load_config, ) - if not self.model_config.embedding_mode: + if not self._is_embedding_model(): self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. @@ -316,6 +319,12 @@ def __init__( ), )) + def _is_encoder_decoder_model(self): + return is_encoder_decoder_model_config(self.model_config) + + def _is_embedding_model(self): + return is_embedding_model_config(self.model_config) + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -464,33 +473,78 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self._get_eos_token_id(lora_request) - seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) - - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - ) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - ) + if self._is_encoder_decoder_model(): + # Add encoder/decoder model request + encoder_seq_id = 0 # Encoder sequence id is not used + + processed_encoder_inputs = {"prompt": processed_inputs.get("encoder_prompt"), + "prompt_token_ids": + processed_inputs + .get("encoder_prompt_token_ids")} + + encoder_seq = Sequence(encoder_seq_id, processed_encoder_inputs, block_size, eos_token_id, + lora_request) + + decoder_seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + decoder_seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + encoder_seq=encoder_seq + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + decoder_seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + encoder_seq=encoder_seq + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + # Add decoder-only model request + decoder_seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + decoder_seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + decoder_seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) def process_model_inputs( self, @@ -498,22 +552,67 @@ def process_model_inputs( inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, ) -> LLMInputsOptions: - if isinstance(inputs, str): - inputs = {"prompt": inputs} + + if self._is_encoder_decoder_model(): + # Encoder/decoder model input + + if isinstance(inputs, str): + # Interpret a single input prompt as a single encoder input + # (leave decoder input to default) + inputs = {"encoder_prompt": inputs} + + if isinstance(inputs,EncoderDecoderStringPrompts): + # Interpret a tuple of input string prompts as a single + # encoder input and a single decoder input, respectively + inputs = {"encoder_prompt": inputs[0], + "decoder_prompt": inputs[1]} + + input_has_decoder_token_ids = "decoder_prompt_token_ids" in inputs + input_has_encoder_token_ids = "encoder_prompt_token_ids" in inputs + + if not (input_has_decoder_token_ids and \ + input_has_encoder_token_ids): + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + if not input_has_decoder_token_ids: + decoder_prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["decoder_prompt"], + lora_request=lora_request) + + encoder_prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["encoder_prompt"], + lora_request=lora_request) - if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + else: + decoder_prompt_token_ids = inputs["decoder_prompt_token_ids"] + encoder_prompt_token_ids = inputs["encoder_prompt_token_ids"] + + return EncoderDecoderLLMInputs(decoder_prompt_token_ids=decoder_prompt_token_ids, + decoder_prompt=inputs.get("decoder_prompt"), + encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt=inputs.get("encoder_prompt"), + multi_modal_data=inputs.get("multi_modal_data")) - prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=inputs["prompt"], - lora_request=lora_request) else: - prompt_token_ids = inputs["prompt_token_ids"] + # Decoder-only model input + + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, @@ -593,6 +692,7 @@ def _create_sequence_group_with_sampling( arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Dict[str, str]] = None, + encoder_seq: Optional[Sequence] = None ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -621,6 +721,7 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, + encoder_seq=encoder_seq ) return seq_group @@ -632,6 +733,7 @@ def _create_sequence_group_with_pooling( pooling_params: PoolingParams, arrival_time: float, lora_request: Optional[LoRARequest], + encoder_seq: Optional[Sequence] = None ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -641,7 +743,8 @@ def _create_sequence_group_with_pooling( seqs=[seq], arrival_time=arrival_time, lora_request=lora_request, - pooling_params=pooling_params) + pooling_params=pooling_params, + encoder_seq=encoder_seq) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -717,7 +820,7 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - if self.model_config.embedding_mode: + if self._is_embedding_model(): self._process_sequence_group_outputs(seq_group, outputs) continue diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 8b45364d757cf..cace072aac837 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -67,6 +67,11 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, +} + # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} From b6d4383e141e1fc23ee0c8c6bb9a7d172949266a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 10:46:15 -0400 Subject: [PATCH 256/443] enc/dec integrated in Scheduler.schedule() --- vllm/core/scheduler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 48c34625c08ae..736196f8cd88d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -960,6 +960,13 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data: Dict[int, SequenceData] = {} # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} + # Encoder associated with SequenceGroup + encoder_seq_data: SequenceData = seq_group.get_encoder_seq().data if seq_group.is_encoder_decoder() else \ + None + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table: List[int] = self.block_manager.get_cross_block_table(seq_group) if seq_group.is_encoder_decoder() else \ + None for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id @@ -1000,6 +1007,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, state=seq_group.state, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but From 614de4e13869f1b2938d1f30369bbb98752a20c6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 10:54:25 -0400 Subject: [PATCH 257/443] formatting --- .../test_encoder_decoder_model_runner.py | 4 +- vllm/core/scheduler.py | 12 +- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 140 +++++++++--------- vllm/entrypoints/llm.py | 6 +- vllm/inputs.py | 4 +- vllm/utils.py | 4 +- vllm/worker/enc_dec_model_runner.py | 9 +- vllm/worker/model_runner.py | 6 +- vllm/worker/worker.py | 16 +- 10 files changed, 105 insertions(+), 98 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index dd113a640f383..bce35e04d9f68 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -3,12 +3,12 @@ import pytest import torch +from tests.kernels.utils import (STR_XFORMERS_ATTN_VAL, + override_backend_env_variable) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from tests.kernels.utils import (override_backend_env_variable, - STR_XFORMERS_ATTN_VAL) # Backends under test # diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 736196f8cd88d..0ce0fa88095a0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -961,12 +961,16 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} # Encoder associated with SequenceGroup - encoder_seq_data: SequenceData = seq_group.get_encoder_seq().data if seq_group.is_encoder_decoder() else \ - None + encoder_seq_data: SequenceData = \ + seq_group.get_encoder_seq().data \ + if seq_group.is_encoder_decoder() else \ + None # Block table for cross-attention # Also managed at SequenceGroup level - cross_block_table: List[int] = self.block_manager.get_cross_block_table(seq_group) if seq_group.is_encoder_decoder() else \ - None + cross_block_table: List[int] = \ + self.block_manager.get_cross_block_table(seq_group) \ + if seq_group.is_encoder_decoder() else \ + None for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6ff21a896e324..90a2550bd4ca7 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -13,7 +13,7 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray -from vllm.inputs import LLMInputs, PromptInputs, LLMInputsOptions +from vllm.inputs import LLMInputs, LLMInputsOptions, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 472b3456a4832..39505ddbf1f9b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,8 +20,8 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (LLMInputs, PromptInputs, LLMInputsOptions, - EncoderDecoderStringPrompts, EncoderDecoderLLMInputs) +from vllm.inputs import (EncoderDecoderLLMInputs, EncoderDecoderStringPrompts, + LLMInputs, LLMInputsOptions, PromptInputs) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -39,9 +39,8 @@ get_tokenizer_group) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import (Counter, - is_encoder_decoder_model_config, - is_embedding_model_config) +from vllm.utils import (Counter, is_embedding_model_config, + is_encoder_decoder_model_config) from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -475,20 +474,20 @@ def _add_processed_request( if self._is_encoder_decoder_model(): # Add encoder/decoder model request - encoder_seq_id = 0 # Encoder sequence id is not used + encoder_seq_id = 0 # Encoder sequence id is not used - processed_encoder_inputs = {"prompt": processed_inputs.get("encoder_prompt"), - "prompt_token_ids": - processed_inputs - .get("encoder_prompt_token_ids")} - - encoder_seq = Sequence(encoder_seq_id, processed_encoder_inputs, block_size, eos_token_id, - lora_request) - - decoder_seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) + processed_encoder_inputs = { + "prompt": + processed_inputs.get("encoder_prompt"), + "prompt_token_ids": + processed_inputs.get("encoder_prompt_token_ids") + } + encoder_seq = Sequence(encoder_seq_id, processed_encoder_inputs, + block_size, eos_token_id, lora_request) + decoder_seq = Sequence(seq_id, processed_inputs, block_size, + eos_token_id, lora_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -499,8 +498,7 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - encoder_seq=encoder_seq - ) + encoder_seq=encoder_seq) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -508,8 +506,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - encoder_seq=encoder_seq - ) + encoder_seq=encoder_seq) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -518,8 +515,8 @@ def _add_processed_request( self.scheduler.add_seq_group(seq_group) else: # Add decoder-only model request - decoder_seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) + decoder_seq = Sequence(seq_id, processed_inputs, block_size, + eos_token_id, lora_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -552,7 +549,7 @@ def process_model_inputs( inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, ) -> LLMInputsOptions: - + if self._is_encoder_decoder_model(): # Encoder/decoder model input @@ -561,38 +558,44 @@ def process_model_inputs( # (leave decoder input to default) inputs = {"encoder_prompt": inputs} - if isinstance(inputs,EncoderDecoderStringPrompts): + if isinstance(inputs, EncoderDecoderStringPrompts): # Interpret a tuple of input string prompts as a single # encoder input and a single decoder input, respectively - inputs = {"encoder_prompt": inputs[0], - "decoder_prompt": inputs[1]} - + inputs = { + "encoder_prompt": inputs[0], + "decoder_prompt": inputs[1] + } + input_has_decoder_token_ids = "decoder_prompt_token_ids" in inputs input_has_encoder_token_ids = "encoder_prompt_token_ids" in inputs if not (input_has_decoder_token_ids and \ input_has_encoder_token_ids): - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + "prompts must be None if " + "skip_tokenizer_init is True") if not input_has_decoder_token_ids: - decoder_prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=inputs["decoder_prompt"], - lora_request=lora_request) + decoder_prompt_token_ids = tokenizer.encode( + request_id=request_id, + prompt=inputs["decoder_prompt"], + lora_request=lora_request) - encoder_prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=inputs["encoder_prompt"], - lora_request=lora_request) + encoder_prompt_token_ids = tokenizer.encode( + request_id=request_id, + prompt=inputs["encoder_prompt"], + lora_request=lora_request) else: decoder_prompt_token_ids = inputs["decoder_prompt_token_ids"] encoder_prompt_token_ids = inputs["encoder_prompt_token_ids"] - return EncoderDecoderLLMInputs(decoder_prompt_token_ids=decoder_prompt_token_ids, - decoder_prompt=inputs.get("decoder_prompt"), - encoder_prompt_token_ids=encoder_prompt_token_ids, - encoder_prompt=inputs.get("encoder_prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + return EncoderDecoderLLMInputs( + decoder_prompt_token_ids=decoder_prompt_token_ids, + decoder_prompt=inputs.get("decoder_prompt"), + encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt=inputs.get("encoder_prompt"), + multi_modal_data=inputs.get("multi_modal_data")) else: # Decoder-only model input @@ -601,8 +604,9 @@ def process_model_inputs( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + "prompts must be None if " + "skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], @@ -611,8 +615,8 @@ def process_model_inputs( prompt_token_ids = inputs["prompt_token_ids"] return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, @@ -685,15 +689,14 @@ def add_request( ) def _create_sequence_group_with_sampling( - self, - request_id: str, - seq: Sequence, - sampling_params: SamplingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Dict[str, str]] = None, - encoder_seq: Optional[Sequence] = None - ) -> SequenceGroup: + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Dict[str, str]] = None, + encoder_seq: Optional[Sequence] = None) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs if (sampling_params.logprobs @@ -714,27 +717,24 @@ def _create_sequence_group_with_sampling( self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - encoder_seq=encoder_seq - ) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + encoder_seq=encoder_seq) return seq_group def _create_sequence_group_with_pooling( - self, - request_id: str, - seq: Sequence, - pooling_params: PoolingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - encoder_seq: Optional[Sequence] = None - ) -> SequenceGroup: + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + encoder_seq: Optional[Sequence] = None) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler pooling_params = pooling_params.clone() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a511d14425e19..17abc1d99afc8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,10 +6,10 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, +from vllm.inputs import (PromptInputs, PromptStrictInputs, + PromptStrictInputsOptions, TextPrompt, TextTokensPrompt, TokensPrompt, - PromptStrictInputsOptions, parse_and_batch_prompt, - EncoderDecoderStringPrompts) + parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput diff --git a/vllm/inputs.py b/vllm/inputs.py index 30851bff5e905..d893551d89195 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -1,5 +1,5 @@ -from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, - TypedDict, Union, cast, overload, Tuple) +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, Tuple, + TypedDict, Union, cast, overload) from typing_extensions import NotRequired diff --git a/vllm/utils.py b/vllm/utils.py index 37cdfcf95662f..84802bc5b5dae 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -873,6 +873,7 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) + def is_encoder_decoder_model_config(model_config) -> bool: ''' Extract the HF encoder/decoder model flag from the ModelConfig instance. @@ -884,6 +885,7 @@ def is_encoder_decoder_model_config(model_config) -> bool: "is_encoder_decoder", False) + def is_embedding_model_config(model_config) -> bool: ''' Extract the embedding model flag from the ModelConfig instance. @@ -891,4 +893,4 @@ def is_embedding_model_config(model_config) -> bool: Return False if model_config is None. ''' return False if model_config is None else \ - model_config.embedding_mode \ No newline at end of file + model_config.embedding_mode diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 19d1c0bbc781d..76dccd1074d45 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,11 +16,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad -from vllm.worker.model_runner import (LORA_WARMUP_RANK, ModelInput, - ModelRunner) -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED) +from vllm.utils import (STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED, + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + make_tensor_with_pad) +from vllm.worker.model_runner import LORA_WARMUP_RANK, ModelInput, ModelRunner logger = init_logger(__name__) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 56206ee1ecbef..b49967e956943 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,9 +24,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad, - is_encoder_decoder_model_config) +from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, + is_encoder_decoder_model_config, is_hip, + is_pin_memory_available, make_tensor_with_pad) logger = init_logger(__name__) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 80586714abe1f..fe65f53275a49 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.distributed @@ -17,14 +17,14 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.utils import (is_embedding_model_config, + is_encoder_decoder_model_config) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.model_runner import ModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner +from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase -from vllm.utils import (is_embedding_model_config, - is_encoder_decoder_model_config) - + class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. @@ -81,12 +81,14 @@ def __init__( or (speculative_config.draft_model_config.hf_config.model_type != "mlp_speculator") else {"return_hidden_states": True} + ModelRunnerClass: Union[Type[EmbeddingModelRunner], + Type[EncoderDecoderModelRunner], + Type[ModelRunner]] = ModelRunner + if is_embedding_model_config(self.model_config): ModelRunnerClass = EmbeddingModelRunner elif is_encoder_decoder_model_config(self.model_config): ModelRunnerClass = EncoderDecoderModelRunner - else: - ModelRunnerClass = ModelRunner self.model_runner = ModelRunnerClass( model_config, From c15731710bd5c317638fef4d861567031d6126b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 11:30:25 -0400 Subject: [PATCH 258/443] free sequence groups --- vllm/core/scheduler.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ce0fa88095a0..725ddaec27659 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -356,6 +356,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: continue seq.status = SequenceStatus.FINISHED_ABORTED self.free_seq(seq) + self.free_seq_group(aborted_group) def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( @@ -1039,9 +1040,27 @@ def free_seq(self, seq: Sequence) -> None: """Free a sequence from a block table.""" self.block_manager.free(seq) + def free_seq_group(self, seq_group: SequenceGroup) \ + -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + self.block_manager.free_cross(seq_group) + def free_finished_seq_groups(self) -> None: - self.running = deque(seq_group for seq_group in self.running - if not seq_group.is_finished()) + new_running: deque = deque() + for seq_group in self.running: + if seq_group.is_finished(): + # For encoder/decoder models, free cross- + # attention block table associated with finished + # seq_group + self.free_seq_group(seq_group) + else: + # Maintain `running` deque without finished + # sequence groups + new_running.append(seq_group) + self.running = new_running def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) From 84c0dcc5fe2b653cb0517df523504a107055061a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 11:58:45 -0400 Subject: [PATCH 259/443] scheduler tests --- tests/core/test_scheduler.py | 91 ++++++++++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 4 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index bae958211cb7b..f93b0d89387a1 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -8,11 +8,14 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.policy import PolicyFactory -from vllm.core.scheduler import Scheduler, SchedulingBudget +from vllm.core.scheduler import (Scheduler, SchedulingBudget, + SchedulerOutputs) from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, SequenceGroup, SequenceStatus +from vllm.sequence import (Logprob, SequenceGroup, SequenceStatus, + SequenceGroupMetadata) -from .utils import create_dummy_prompt +from .utils import (create_dummy_prompt, + create_dummy_prompt_encoder_decoder) def get_sequence_groups(scheduler_output): @@ -26,7 +29,9 @@ def append_new_token(out, token_id: int): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) -def schedule_and_update_computed_tokens(scheduler): +def schedule_and_update_computed_tokens(scheduler) \ + -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + metas, out = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) @@ -54,6 +59,27 @@ def test_scheduler_add_seq_group(): scheduler.add_seq_group(seq_group) assert scheduler.get_num_unfinished_seq_groups() == i + 1 +# def test_scheduler_add_seq_group_encoder_decoder(): +# block_size = 4 +# scheduler_config = SchedulerConfig(100, 64, 1) +# cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") +# cache_config.num_cpu_blocks = 4 +# cache_config.num_gpu_blocks = 4 +# scheduler = Scheduler(scheduler_config, cache_config, None) + +# # Add seq group to scheduler. +# num_seq_group = 4 +# for i in range(num_seq_group): +# # _, seq_group = create_dummy_prompt(str(i), block_size) +# req_id = str(i) +# _, _, seq_group = create_dummy_prompt_encoder_decoder(req_id, +# block_size, +# block_size, +# block_size) +# scheduler.add_seq_group(seq_group) +# assert scheduler.get_num_unfinished_seq_groups() == i + 1 +# # Verify that cross-attention block-table has been registered +# #assert req_id in scheduler.block_manager.cross_block_tables def test_scheduler_abort_seq_group(): block_size = 4 @@ -113,6 +139,63 @@ def test_scheduler_schedule_simple(): assert len(seq_group_meta) == num_seq_group append_new_token(out, 1) +def test_scheduler_schedule_simple_encoder_decoder(): + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group + cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + req_id_list=[] + for i in range(num_seq_group): + # _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + req_id = str(i) + req_id_list.append(req_id) + _, _, seq_group = create_dummy_prompt_encoder_decoder(req_id, + block_size, + block_size, + block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group + seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) + # - Verify that sequence group cross-attention block tables are + # registered with the block manager + assert all([(req_id in scheduler.block_manager.cross_block_tables) for req_id in req_id_list]) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_tokens + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta_list) == num_seq_group + append_new_token(out, 1) + + # Schedule seq groups generation. + seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) + # - Verify that sequence group metadata includes encoder attention + # and cross-attention metadata + assert all([not ((seq_group_meta.encoder_seq_data is None) or \ + (seq_group_meta.cross_block_table is None)) \ + for seq_group_meta in seq_group_meta_list]) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta_list) == num_seq_group + append_new_token(out, 1) + + # Abort sequences + for req_id in req_id_list: + scheduler.abort_seq_group(req_id) + # - Verify that sequence group cross-attention block tables are + # NO LONGER registered with the block manager + assert req_id not in scheduler.block_manager.cross_block_tables def test_scheduler_prefill_prioritized(): """Verify running batched tokens are not applied to prefill requests.""" From 49c7162d70441963ec6c26430a8e36426fbfe1aa Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 12:01:59 -0400 Subject: [PATCH 260/443] formatting --- tests/core/test_scheduler.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index f93b0d89387a1..9f95853660eb6 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -8,14 +8,12 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.policy import PolicyFactory -from vllm.core.scheduler import (Scheduler, SchedulingBudget, - SchedulerOutputs) +from vllm.core.scheduler import Scheduler, SchedulerOutputs, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import (Logprob, SequenceGroup, SequenceStatus, - SequenceGroupMetadata) +from vllm.sequence import (Logprob, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) -from .utils import (create_dummy_prompt, - create_dummy_prompt_encoder_decoder) +from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder def get_sequence_groups(scheduler_output): @@ -59,6 +57,7 @@ def test_scheduler_add_seq_group(): scheduler.add_seq_group(seq_group) assert scheduler.get_num_unfinished_seq_groups() == i + 1 + # def test_scheduler_add_seq_group_encoder_decoder(): # block_size = 4 # scheduler_config = SchedulerConfig(100, 64, 1) @@ -72,15 +71,16 @@ def test_scheduler_add_seq_group(): # for i in range(num_seq_group): # # _, seq_group = create_dummy_prompt(str(i), block_size) # req_id = str(i) -# _, _, seq_group = create_dummy_prompt_encoder_decoder(req_id, -# block_size, -# block_size, +# _, _, seq_group = create_dummy_prompt_encoder_decoder(req_id, +# block_size, +# block_size, # block_size) # scheduler.add_seq_group(seq_group) # assert scheduler.get_num_unfinished_seq_groups() == i + 1 # # Verify that cross-attention block-table has been registered # #assert req_id in scheduler.block_manager.cross_block_tables + def test_scheduler_abort_seq_group(): block_size = 4 scheduler_config = SchedulerConfig(100, 64, 1) @@ -139,27 +139,26 @@ def test_scheduler_schedule_simple(): assert len(seq_group_meta) == num_seq_group append_new_token(out, 1) + def test_scheduler_schedule_simple_encoder_decoder(): block_size = 4 num_seq_group = 4 max_model_len = 16 scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group - cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group + cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group + cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. - req_id_list=[] + req_id_list = [] for i in range(num_seq_group): # _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) req_id = str(i) req_id_list.append(req_id) - _, _, seq_group = create_dummy_prompt_encoder_decoder(req_id, - block_size, - block_size, - block_size) + _, _, seq_group = create_dummy_prompt_encoder_decoder( + req_id, block_size, block_size, block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -168,7 +167,8 @@ def test_scheduler_schedule_simple_encoder_decoder(): seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) # - Verify that sequence group cross-attention block tables are # registered with the block manager - assert all([(req_id in scheduler.block_manager.cross_block_tables) for req_id in req_id_list]) + assert all([(req_id in scheduler.block_manager.cross_block_tables) + for req_id in req_id_list]) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in @@ -197,6 +197,7 @@ def test_scheduler_schedule_simple_encoder_decoder(): # NO LONGER registered with the block manager assert req_id not in scheduler.block_manager.cross_block_tables + def test_scheduler_prefill_prioritized(): """Verify running batched tokens are not applied to prefill requests.""" block_size = 4 From 213dc597274da4c963510b1d72166d0a8eddbc7b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 12:03:50 -0400 Subject: [PATCH 261/443] test_bart.py --- tests/models/test_bart.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/models/test_bart.py diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py new file mode 100644 index 0000000000000..df76777a0de00 --- /dev/null +++ b/tests/models/test_bart.py @@ -0,0 +1,38 @@ +"""Compare the outputs of HF and vLLM for BART models using greedy sampling. + +Run `pytest tests/models/test_bart.py`. +""" +import pytest + +from .utils import check_logprobs_close + +MODELS = ["facebook/bart-base"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + # TODO(sang): Sliding window should be tested separately. + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) From 28f0d2fff6752a90227aa8aa07ca32e43bee395d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 12:06:56 -0400 Subject: [PATCH 262/443] pulled in bart code --- examples/offline_inference_encoder_decoder.py | 33 +++++++++++++++++++ vllm/model_executor/models/__init__.py | 5 +++ 2 files changed, 38 insertions(+) create mode 100644 examples/offline_inference_encoder_decoder.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py new file mode 100644 index 0000000000000..3bf7f2e8660ee --- /dev/null +++ b/examples/offline_inference_encoder_decoder.py @@ -0,0 +1,33 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +# - Encoder prompts +encoder_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# - Decoder prompts +decoder_prompts = [ + "", + "", + "", + "", +] +# - Unified prompts +prompts = [enc_dec for enc_dec in zip(encoder_prompts,decoder_prompts)] + +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/bart-base") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index cace072aac837..cb049268db73d 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -67,9 +67,14 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", "BartForConditionalGeneration"), +} + _MODELS = { **_GENERATION_MODELS, **_EMBEDDING_MODELS, + **_CONDITIONAL_GENERATION_MODELS } # Architecture -> type. From ed610b0b9a6abcdaf874d16225a441509a207076 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 12:09:51 -0400 Subject: [PATCH 263/443] pulled in bart model code --- vllm/model_executor/models/bart.py | 2035 ++++++++++++++++++++++++++++ 1 file changed, 2035 insertions(+) create mode 100644 vllm/model_executor/models/bart.py diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000000000..f7f12e2a79154 --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,2035 @@ +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import MixtralConfig + +from vllm import _custom_ops as ops +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + per_tensor_dequantize, + per_tensor_quantize) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from transformers import BartConfig + +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/bart-base" +_CONFIG_FOR_DOC = "BartConfig" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 768] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2" +_SEQ_CLASS_EXPECTED_LOSS = 0.0 +_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'" + +# QuestionAsnwering docstring +_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1" +_QA_EXPECTED_LOSS = 0.59 +_QA_EXPECTED_OUTPUT = "' nice puppet'" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, + decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, + input_ids: torch.Tensor, + past_key_values_length: int = 0): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange(past_key_values_length, + past_key_values_length + seq_len, + dtype=torch.long, + device=self.weight.device).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +class BartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[BartConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1]): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads, ): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}") + attn_weights = layer_head_mask.view( + 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class BartFlashAttention2(BartAttention): + """ + Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( + ) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError( + "BartFlashAttention2 attention does not support output_attentions" + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1]): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, + bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat( + [past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat( + [past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), + value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}.") + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class BartSdpaAttention(BartAttention): + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if (is_cross_attention and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1]): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +BART_ATTENTION_CLASSES = { + "eager": BartAttention, + "sdpa": BartSdpaAttention, + "flash_attention_2": BartFlashAttention2, +} + + +class BartEncoderLayer(nn.Module): + + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class BartDecoderLayer(nn.Module): + + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = BART_ATTENTION_CLASSES[ + config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class BartPreTrainedModel(PreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] + _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], + device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class PretrainedBartModel(BartPreTrainedModel): + + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +class BartPretrainedModel(BartPreTrainedModel): + + def __init_subclass__(self): + warnings.warn( + "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + FutureWarning, + ) + + +BART_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BartConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BART_GENERATION_EXAMPLE = r""" + Summarization example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' + ``` + + Mask filling example: + + ```python + >>> from transformers import AutoTokenizer, BartForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") + + >>> TXT = "My friends are but they eat too many carbs." + >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['not', 'good', 'healthy', 'great', 'very'] + ``` +""" + +BART_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class BartEncoder(BartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BartEncoderLayer`]. + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + self.padding_idx, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList( + [BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + # expand attention_mask + if attention_mask is not None: + if self._use_flash_attention_2: + attention_mask = attention_mask if 0 in attention_mask else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, inputs_embeds.dtype) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + +class BartDecoder(BartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + self.padding_idx, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList( + [BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if ( + attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, + past_key_values_length) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if ( + output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], + ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states, ) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[ + idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask + is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[3 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2], ) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [ + hidden_states, next_cache, all_hidden_states, all_self_attns, + all_cross_attentions + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class BartModel(BartPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, config: BartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, + self.config.decoder_start_token_id) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class BartForConditionalGeneration(BartPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", + "lm_head.weight" + ] + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + + def __init__(self, config: BartConfig, cache_config: CacheConfig, quant_config: QuantizationConfig): + super().__init__(config) + self.model = BartModel(config) + self.register_buffer( + "final_logits_bias", + torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, + self.model.shared.num_embeddings, + bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, + new_num_tokens: int, + pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings( + new_num_tokens, pad_to_multiple_of) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), + device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning( + "The `use_cache` argument is changed to `False` since `labels` is provided." + ) + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits, ) + outputs[1:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + return { + "input_ids": + None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": + use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past[:2]) + layer_past[2:], ) + return reordered_past + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + return + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id) + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file From d2ad2328e41ad7a8898ddbb37db8c1bfaf2ae803 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 13:37:27 -0400 Subject: [PATCH 264/443] wip bart integration --- vllm/model_executor/models/bart.py | 411 ++++++++++++++++------------ vllm/worker/enc_dec_model_runner.py | 2 + 2 files changed, 241 insertions(+), 172 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index f7f12e2a79154..25a2e62eb82a3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -1151,7 +1151,10 @@ class BartEncoder(BartPreTrainedModel): """ def __init__(self, - config: BartConfig, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) @@ -1177,13 +1180,13 @@ def __init__(self, ) self.layers = nn.ModuleList( [BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" + # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + # self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() + # self.gradient_checkpointing = False + # # Initialize weights and apply final processing + # self.post_init() def get_input_embeddings(self): return self.embed_tokens @@ -1351,8 +1354,11 @@ class BartDecoder(BartPreTrainedModel): """ def __init__(self, - config: BartConfig, - embed_tokens: Optional[nn.Embedding] = None): + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None,): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -1375,14 +1381,14 @@ def __init__(self, ) self.layers = nn.ModuleList( [BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" + # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + # self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() + # self.gradient_checkpointing = False + # # Initialize weights and apply final processing + # self.post_init() def get_input_embeddings(self): return self.embed_tokens @@ -1651,17 +1657,53 @@ class BartModel(BartPreTrainedModel): "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" ] - def __init__(self, config: BartConfig): + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): super().__init__(config) - padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + # padding_idx, vocab_size = config.pad_token_id, config.vocab_size + # self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + # self.encoder = BartEncoder(config, self.shared) + # self.decoder = BartDecoder(config, self.shared) + + # # Initialize weights and apply final processing + # self.post_init() - self.encoder = BartEncoder(config, self.shared) - self.decoder = BartDecoder(config, self.shared) - # Initialize weights and apply final processing - self.post_init() + + + + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.encoder = BartEncoder(config, + cache_config, + quant_config=quant_config) + self.decoder = BartDecoder(config, + cache_config, + quant_config=quant_config) + + # self.layers = nn.ModuleList([ + # MixtralDecoderLayer(config, + # cache_config, + # quant_config=quant_config) + # for _ in range(config.num_hidden_layers) + # ]) + #self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def _tie_weights(self): if self.config.tie_word_embeddings: @@ -1683,93 +1725,85 @@ def get_decoder(self): return self.decoder def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata ) -> Union[Tuple, Seq2SeqModelOutput]: - # different to other models, Bart automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, - self.config.decoder_start_token_id) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] - if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] - if len(encoder_outputs) > 2 else None, - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) + + + assert False + + # # different to other models, Bart automatically creates decoder_input_ids from + # # input_ids if no decoder_input_ids are provided + # if decoder_input_ids is None and decoder_inputs_embeds is None: + # if input_ids is None: + # raise ValueError( + # "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + # "passed, `input_ids` cannot be `None`. Please pass either " + # "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + # ) + + # decoder_input_ids = shift_tokens_right( + # input_ids, self.config.pad_token_id, + # self.config.decoder_start_token_id) + + # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + # output_hidden_states = (output_hidden_states + # if output_hidden_states is not None else + # self.config.output_hidden_states) + # use_cache = use_cache if use_cache is not None else self.config.use_cache + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # if encoder_outputs is None: + # encoder_outputs = self.encoder( + # input_ids=input_ids, + # attention_mask=attention_mask, + # head_mask=head_mask, + # inputs_embeds=inputs_embeds, + # output_attentions=output_attentions, + # output_hidden_states=output_hidden_states, + # return_dict=return_dict, + # ) + # # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + # elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + # encoder_outputs = BaseModelOutput( + # last_hidden_state=encoder_outputs[0], + # hidden_states=encoder_outputs[1] + # if len(encoder_outputs) > 1 else None, + # attentions=encoder_outputs[2] + # if len(encoder_outputs) > 2 else None, + # ) + + # # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder_outputs = self.decoder( + # input_ids=decoder_input_ids, + # attention_mask=decoder_attention_mask, + # encoder_hidden_states=encoder_outputs[0], + # encoder_attention_mask=attention_mask, + # head_mask=decoder_head_mask, + # cross_attn_head_mask=cross_attn_head_mask, + # past_key_values=past_key_values, + # inputs_embeds=decoder_inputs_embeds, + # use_cache=use_cache, + # output_attentions=output_attentions, + # output_hidden_states=output_hidden_states, + # return_dict=return_dict, + # ) + + # if not return_dict: + # return decoder_outputs + encoder_outputs + + # return Seq2SeqModelOutput( + # last_hidden_state=decoder_outputs.last_hidden_state, + # past_key_values=decoder_outputs.past_key_values, + # decoder_hidden_states=decoder_outputs.hidden_states, + # decoder_attentions=decoder_outputs.attentions, + # cross_attentions=decoder_outputs.cross_attentions, + # encoder_last_hidden_state=encoder_outputs.last_hidden_state, + # encoder_hidden_states=encoder_outputs.hidden_states, + # encoder_attentions=encoder_outputs.attentions, + # ) class BartForConditionalGeneration(BartPreTrainedModel): @@ -1780,18 +1814,43 @@ class BartForConditionalGeneration(BartPreTrainedModel): ] _keys_to_ignore_on_load_missing = ["final_logits_bias"] - def __init__(self, config: BartConfig, cache_config: CacheConfig, quant_config: QuantizationConfig): + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): super().__init__(config) - self.model = BartModel(config) - self.register_buffer( - "final_logits_bias", - torch.zeros((1, self.model.shared.num_embeddings))) - self.lm_head = nn.Linear(config.d_model, - self.model.shared.num_embeddings, - bias=False) + # self.model = BartModel(config) + # self.register_buffer( + # "final_logits_bias", + # torch.zeros((1, self.model.shared.num_embeddings))) + # self.lm_head = nn.Linear(config.d_model, + # self.model.shared.num_embeddings, + # bias=False) - # Initialize weights and apply final processing - self.post_init() + # # Initialize weights and apply final processing + # self.post_init() + + self.config = config + self.model = BartModel(config, + cache_config, + quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() def get_encoder(self): return self.model.get_encoder() @@ -1837,63 +1896,71 @@ def forward( Returns: """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - if use_cache: - logger.warning( - "The `use_cache` argument is changed to `False` since `labels` is provided." - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, - self.config.decoder_start_token_id) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - lm_logits = self.lm_head(outputs[0]) - lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) - - masked_lm_loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct( - lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + hidden_states = self.model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + kv_caches, + attn_metadata) + return hidden_states - if not return_dict: - output = (lm_logits, ) + outputs[1:] - return ((masked_lm_loss, ) + - output) if masked_lm_loss is not None else output - - return Seq2SeqLMOutput( - loss=masked_lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # if labels is not None: + # if use_cache: + # logger.warning( + # "The `use_cache` argument is changed to `False` since `labels` is provided." + # ) + # use_cache = False + # if decoder_input_ids is None and decoder_inputs_embeds is None: + # decoder_input_ids = shift_tokens_right( + # labels, self.config.pad_token_id, + # self.config.decoder_start_token_id) + + # outputs = self.model( + # input_ids, + # attention_mask=attention_mask, + # decoder_input_ids=decoder_input_ids, + # encoder_outputs=encoder_outputs, + # decoder_attention_mask=decoder_attention_mask, + # head_mask=head_mask, + # decoder_head_mask=decoder_head_mask, + # cross_attn_head_mask=cross_attn_head_mask, + # past_key_values=past_key_values, + # inputs_embeds=inputs_embeds, + # decoder_inputs_embeds=decoder_inputs_embeds, + # use_cache=use_cache, + # output_attentions=output_attentions, + # output_hidden_states=output_hidden_states, + # return_dict=return_dict, + # ) + + # lm_logits = self.lm_head(outputs[0]) + # lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + # masked_lm_loss = None + # if labels is not None: + # labels = labels.to(lm_logits.device) + # loss_fct = CrossEntropyLoss() + # masked_lm_loss = loss_fct( + # lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + # if not return_dict: + # output = (lm_logits, ) + outputs[1:] + # return ((masked_lm_loss, ) + + # output) if masked_lm_loss is not None else output + + # return Seq2SeqLMOutput( + # loss=masked_lm_loss, + # logits=lm_logits, + # past_key_values=outputs.past_key_values, + # decoder_hidden_states=outputs.decoder_hidden_states, + # decoder_attentions=outputs.decoder_attentions, + # cross_attentions=outputs.cross_attentions, + # encoder_last_hidden_state=outputs.encoder_last_hidden_state, + # encoder_hidden_states=outputs.encoder_hidden_states, + # encoder_attentions=outputs.encoder_attentions, + # ) def prepare_inputs_for_generation( self, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 76dccd1074d45..9eb004d269499 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -426,6 +426,8 @@ def profile_run(self) -> None: lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, multi_modal_data=dummy_multi_modal_data, + encoder_seq_data=seq_data, + cross_block_table=None ) seqs.append(seq) From 30becae9d35d4b994bcd995c81603a97b93d0e3d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 13:45:48 -0400 Subject: [PATCH 265/443] profiling fix; wip bart --- examples/offline_inference_encoder_decoder.py | 2 +- vllm/model_executor/models/bart.py | 51 ++++++++----------- vllm/worker/enc_dec_model_runner.py | 3 +- 3 files changed, 23 insertions(+), 33 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 3bf7f2e8660ee..c9c0f642aac21 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -16,7 +16,7 @@ "", ] # - Unified prompts -prompts = [enc_dec for enc_dec in zip(encoder_prompts,decoder_prompts)] +prompts = [enc_dec for enc_dec in zip(encoder_prompts, decoder_prompts)] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 25a2e62eb82a3..de8abbadfca6c 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -1151,7 +1151,7 @@ class BartEncoder(BartPreTrainedModel): """ def __init__(self, - config: BartConfig, + config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -1353,12 +1353,14 @@ class BartDecoder(BartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None,): + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + ): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -1657,8 +1659,8 @@ class BartModel(BartPreTrainedModel): "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" ] - def __init__(self, - config: BartConfig, + def __init__(self, + config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None): @@ -1673,11 +1675,6 @@ def __init__(self, # # Initialize weights and apply final processing # self.post_init() - - - - - self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -1690,10 +1687,10 @@ def __init__(self, org_num_embeddings=config.vocab_size, ) - self.encoder = BartEncoder(config, + self.encoder = BartEncoder(config, cache_config, quant_config=quant_config) - self.decoder = BartDecoder(config, + self.decoder = BartDecoder(config, cache_config, quant_config=quant_config) @@ -1725,12 +1722,10 @@ def get_decoder(self): return self.decoder def forward( - self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata + self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, Seq2SeqModelOutput]: - assert False @@ -1814,8 +1809,8 @@ class BartForConditionalGeneration(BartPreTrainedModel): ] _keys_to_ignore_on_load_missing = ["final_logits_bias"] - def __init__(self, - config: BartConfig, + def __init__(self, + config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None): @@ -1896,12 +1891,8 @@ def forward( Returns: """ - hidden_states = self.model(input_ids, - positions, - encoder_input_ids, - encoder_positions, - kv_caches, - attn_metadata) + hidden_states = self.model(input_ids, positions, encoder_input_ids, + encoder_positions, kv_caches, attn_metadata) return hidden_states # return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -2099,4 +2090,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 9eb004d269499..7a76c6d1a9499 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -427,8 +427,7 @@ def profile_run(self) -> None: if dummy_lora_requests_per_seq else None, multi_modal_data=dummy_multi_modal_data, encoder_seq_data=seq_data, - cross_block_table=None - ) + cross_block_table=None) seqs.append(seq) # Run the model with the dummy inputs. From 45a53877dc815398f1f190fa7e7d513db7928b6f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 14:28:59 -0400 Subject: [PATCH 266/443] pruning out training functionality & unnecessary code from BART --- vllm/model_executor/models/bart.py | 948 ++++------------------------- 1 file changed, 109 insertions(+), 839 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index de8abbadfca6c..43da0791772b4 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -20,7 +20,6 @@ import torch from torch import nn -from transformers import MixtralConfig from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata @@ -36,29 +35,18 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, - per_tensor_dequantize, - per_tensor_quantize) -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once -import copy import math -import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import ( @@ -70,26 +58,16 @@ from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, - Seq2SeqQuestionAnsweringModelOutput, - Seq2SeqSequenceClassifierOutput, ) -from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, ) from transformers import BartConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "facebook/bart-base" @@ -183,13 +161,11 @@ def forward(self, input_ids: torch.Tensor): class BartAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, is_causal: bool = False, @@ -198,14 +174,14 @@ def __init__( super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads - self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") + f" and `num_heads`: {num_heads})." + ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal @@ -216,376 +192,7 @@ def __init__( self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1]): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, - bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}") - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, - src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, - src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads, ): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") - attn_weights = layer_head_mask.view( - 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, - src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, - src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, - tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, - tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, - self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, - self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class BartFlashAttention2(BartAttention): - """ - Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10( - ) - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # BartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "BartFlashAttention2 attention does not support output_attentions" - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1]): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, - bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat( - [past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat( - [past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), - value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward(query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout) - - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, - query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, - query_length) - else: - attn_output = flash_attn_func(query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) - - return attn_output - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, - query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( - attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, - head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class BartSdpaAttention(BartAttention): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, @@ -669,7 +276,7 @@ def forward( key_states, value_states, attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, + dropout_p= 0.0, is_causal=is_causal, ) @@ -688,30 +295,19 @@ def forward( return attn_output, None, past_key_value - -BART_ATTENTION_CLASSES = { - "eager": BartAttention, - "sdpa": BartSdpaAttention, - "flash_attention_2": BartFlashAttention2, -} - - class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) @@ -741,21 +337,15 @@ def forward( layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) + hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -781,24 +371,19 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, is_decoder=True, is_causal=True, config=config, ) - self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[ - config._attn_implementation]( + self.encoder_attn = BartAttention( self.embed_dim, config.decoder_attention_heads, - dropout=config.attention_dropout, is_decoder=True, config=config, ) @@ -852,9 +437,7 @@ def forward( layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -875,9 +458,7 @@ def forward( past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) @@ -887,13 +468,9 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) + hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -907,240 +484,7 @@ def forward( return outputs - -class BartClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.dense = nn.Linear(input_dim, inner_dim) - self.dropout = nn.Dropout(p=pooler_dropout) - self.out_proj = nn.Linear(inner_dim, num_classes) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -class BartPreTrainedModel(PreTrainedModel): - config_class = BartConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] - _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], - device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs - - -class PretrainedBartModel(BartPreTrainedModel): - - def __init_subclass__(self): - warnings.warn( - "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", - FutureWarning, - ) - - -class BartPretrainedModel(BartPreTrainedModel): - - def __init_subclass__(self): - warnings.warn( - "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", - FutureWarning, - ) - - -BART_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`BartConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -BART_GENERATION_EXAMPLE = r""" - Summarization example: - - ```python - >>> from transformers import AutoTokenizer, BartForConditionalGeneration - - >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - - >>> ARTICLE_TO_SUMMARIZE = ( - ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " - ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " - ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." - ... ) - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) - >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' - ``` - - Mask filling example: - - ```python - >>> from transformers import AutoTokenizer, BartForConditionalGeneration - - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") - >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") - - >>> TXT = "My friends are but they eat too many carbs." - >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] - >>> logits = model(input_ids).logits - - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = logits[0, masked_index].softmax(dim=0) - >>> values, predictions = probs.topk(5) - - >>> tokenizer.decode(predictions).split() - ['not', 'good', 'healthy', 'great', 'very'] - ``` -""" - -BART_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` - is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - - For translation and summarization training, `decoder_input_ids` should be provided. If no - `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right - for denoising pre-training following the paper. - decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class BartEncoder(BartPreTrainedModel): +class BartEncoder(nn.Module): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`BartEncoderLayer`]. @@ -1158,9 +502,6 @@ def __init__(self, embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.encoder_layerdrop - embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings @@ -1180,13 +521,9 @@ def __init__(self, ) self.layers = nn.ModuleList( [BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - # self._use_sdpa = config._attn_implementation == "sdpa" + self.layernorm_embedding = nn.LayerNorm(embed_dim) - # self.gradient_checkpointing = False - # # Initialize weights and apply final processing - # self.post_init() def get_input_embeddings(self): return self.embed_tokens @@ -1268,9 +605,6 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) # expand attention_mask if attention_mask is not None: @@ -1300,34 +634,16 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states, ) - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - to_drop = False - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: # skip the layer - to_drop = True - - if to_drop: - layer_outputs = (None, None) - else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] - if head_mask is not None else None), - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1], ) @@ -1344,7 +660,7 @@ def forward( attentions=all_attentions) -class BartDecoder(BartPreTrainedModel): +class BartDecoder(nn.Module): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] @@ -1362,8 +678,6 @@ def __init__( embed_tokens: Optional[nn.Embedding] = None, ): super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt( @@ -1383,15 +697,9 @@ def __init__( ) self.layers = nn.ModuleList( [BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - # self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - # self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) - # self.gradient_checkpointing = False - # # Initialize weights and apply final processing - # self.post_init() - def get_input_embeddings(self): return self.embed_tokens @@ -1555,17 +863,6 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1586,43 +883,25 @@ def forward( # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states, ) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue past_key_value = past_key_values[ idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] - if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] - if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] - if cross_attn_head_mask - is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask + is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: @@ -1654,7 +933,7 @@ def forward( ) -class BartModel(BartPreTrainedModel): +class BartModel(nn.Module): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" ] @@ -1694,14 +973,6 @@ def __init__(self, cache_config, quant_config=quant_config) - # self.layers = nn.ModuleList([ - # MixtralDecoderLayer(config, - # cache_config, - # quant_config=quant_config) - # for _ in range(config.num_hidden_layers) - # ]) - #self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def _tie_weights(self): if self.config.tie_word_embeddings: self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) @@ -1727,81 +998,80 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, Seq2SeqModelOutput]: - assert False - # # different to other models, Bart automatically creates decoder_input_ids from - # # input_ids if no decoder_input_ids are provided - # if decoder_input_ids is None and decoder_inputs_embeds is None: - # if input_ids is None: - # raise ValueError( - # "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - # "passed, `input_ids` cannot be `None`. Please pass either " - # "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - # ) + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) - # decoder_input_ids = shift_tokens_right( - # input_ids, self.config.pad_token_id, - # self.config.decoder_start_token_id) + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, + self.config.decoder_start_token_id) - # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - # output_hidden_states = (output_hidden_states - # if output_hidden_states is not None else - # self.config.output_hidden_states) - # use_cache = use_cache if use_cache is not None else self.config.use_cache - # return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # if encoder_outputs is None: - # encoder_outputs = self.encoder( - # input_ids=input_ids, - # attention_mask=attention_mask, - # head_mask=head_mask, - # inputs_embeds=inputs_embeds, - # output_attentions=output_attentions, - # output_hidden_states=output_hidden_states, - # return_dict=return_dict, - # ) - # # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - # elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - # encoder_outputs = BaseModelOutput( - # last_hidden_state=encoder_outputs[0], - # hidden_states=encoder_outputs[1] - # if len(encoder_outputs) > 1 else None, - # attentions=encoder_outputs[2] - # if len(encoder_outputs) > 2 else None, - # ) - - # # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - # decoder_outputs = self.decoder( - # input_ids=decoder_input_ids, - # attention_mask=decoder_attention_mask, - # encoder_hidden_states=encoder_outputs[0], - # encoder_attention_mask=attention_mask, - # head_mask=decoder_head_mask, - # cross_attn_head_mask=cross_attn_head_mask, - # past_key_values=past_key_values, - # inputs_embeds=decoder_inputs_embeds, - # use_cache=use_cache, - # output_attentions=output_attentions, - # output_hidden_states=output_hidden_states, - # return_dict=return_dict, - # ) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, + ) - # if not return_dict: - # return decoder_outputs + encoder_outputs - - # return Seq2SeqModelOutput( - # last_hidden_state=decoder_outputs.last_hidden_state, - # past_key_values=decoder_outputs.past_key_values, - # decoder_hidden_states=decoder_outputs.hidden_states, - # decoder_attentions=decoder_outputs.attentions, - # cross_attentions=decoder_outputs.cross_attentions, - # encoder_last_hidden_state=encoder_outputs.last_hidden_state, - # encoder_hidden_states=encoder_outputs.hidden_states, - # encoder_attentions=encoder_outputs.attentions, - # ) + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) -class BartForConditionalGeneration(BartPreTrainedModel): +class BartForConditionalGeneration(nn.Module): base_model_prefix = "model" _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", From 97cad4b875ee09ebeff455a20fdf351eef9d2f16 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 14:40:40 -0400 Subject: [PATCH 267/443] wip BART model cleanup --- vllm/model_executor/models/bart.py | 44 ++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 43da0791772b4..969f0be95c98a 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -500,7 +500,8 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None): - super().__init__(config) + #super().__init__(config) + super().__init__() embed_dim = config.d_model self.padding_idx = config.pad_token_id @@ -677,7 +678,8 @@ def __init__( lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None, ): - super().__init__(config) + #super().__init__(config) + super().__init__() self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt( @@ -943,7 +945,8 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None): - super().__init__(config) + #super().__init__(config) + super().__init__() # padding_idx, vocab_size = config.pad_token_id, config.vocab_size # self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) @@ -954,6 +957,8 @@ def __init__(self, # # Initialize weights and apply final processing # self.post_init() + self.config = config + self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -997,7 +1002,15 @@ def forward( encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, Seq2SeqModelOutput]: - + decoder_inputs_embeds = None + decoder_input_ids = input_ids + attention_mask = None + head_mask = None + inputs_embeds = None + decoder_attention_mask = None + decoder_head_mask = None + cross_attn_head_mask = None + past_key_values = None # different to other models, Bart automatically creates decoder_input_ids from # input_ids if no decoder_input_ids are provided @@ -1013,12 +1026,20 @@ def forward( input_ids, self.config.pad_token_id, self.config.decoder_start_token_id) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + #output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = self.config.output_attentions + + # output_hidden_states = (output_hidden_states + # if output_hidden_states is not None else + # self.config.output_hidden_states) + + output_hidden_states = self.config.output_hidden_states + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + use_cache = self.config.use_cache + return_dict = self.config.use_return_dict if encoder_outputs is None: encoder_outputs = self.encoder( @@ -1084,7 +1105,7 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None): - super().__init__(config) + #super().__init__(config) # self.model = BartModel(config) # self.register_buffer( # "final_logits_bias", @@ -1096,6 +1117,7 @@ def __init__(self, # # Initialize weights and apply final processing # self.post_init() + super().__init__() self.config = config self.model = BartModel(config, cache_config, From 2123517ef5fc8a5593e693b7d28d8c217c729282 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 15:13:36 -0400 Subject: [PATCH 268/443] formatting --- vllm/model_executor/models/bart.py | 33 +++++++++++++++--------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 969f0be95c98a..4ee1247b58d6f 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -62,8 +62,7 @@ Seq2SeqModelOutput, ) from transformers.utils import ( - logging, -) + logging, ) from transformers import BartConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -180,8 +179,7 @@ def __init__( if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) + f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal @@ -192,7 +190,8 @@ def __init__( self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() def forward( self, @@ -276,7 +275,7 @@ def forward( key_states, value_states, attn_mask=attention_mask, - dropout_p= 0.0, + dropout_p=0.0, is_causal=is_causal, ) @@ -295,6 +294,7 @@ def forward( return attn_output, None, past_key_value + class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): @@ -382,11 +382,11 @@ def __init__(self, config: BartConfig): self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn = BartAttention( - self.embed_dim, - config.decoder_attention_heads, - is_decoder=True, - config=config, - ) + self.embed_dim, + config.decoder_attention_heads, + is_decoder=True, + config=config, + ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) @@ -484,6 +484,7 @@ def forward( return outputs + class BartEncoder(nn.Module): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -525,7 +526,6 @@ def __init__(self, self.layernorm_embedding = nn.LayerNorm(embed_dim) - def get_input_embeddings(self): return self.embed_tokens @@ -640,7 +640,7 @@ def forward( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] - if head_mask is not None else None), + if head_mask is not None else None), output_attentions=output_attentions, ) @@ -889,17 +889,16 @@ def forward( past_key_value = past_key_values[ idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] - if head_mask is not None else None), + if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] - if cross_attn_head_mask - is not None else None), + if cross_attn_head_mask is not None + else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, From c11db0fd30e326d2273da95439c5087e83725b04 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 15:21:15 -0400 Subject: [PATCH 269/443] integrating BART weight loading code --- vllm/model_executor/models/bart.py | 199 ++++++++++++++++++----------- 1 file changed, 122 insertions(+), 77 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 4ee1247b58d6f..62ff1ea7ed818 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -1299,86 +1299,131 @@ def _reorder_cache(past_key_values, beam_idx): for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - return - - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - expert_params_mapping = [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - - params_dict = dict(self.named_parameters()) + + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: weight_loader(param, loaded_weight, shard_id) - break else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + weight_loader(param, loaded_weight) + + # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + # stacked_params_mapping = [ + # # (param_name, shard_name, shard_id) + # ("qkv_proj", "q_proj", "q"), + # ("qkv_proj", "k_proj", "k"), + # ("qkv_proj", "v_proj", "v"), + # ] + + # expert_params_mapping = [ + # # These are the weight scales for the experts + # # (param_name, weight_name, expert_id) + # ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + # f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + # for expert_id in range(self.config.num_local_experts) + # for weight_name in ["w1", "w2", "w3"] + # ] + [ + # # These are the weights for the experts + # # (param_name, weight_name, expert_id) + # ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + # f"experts.{expert_id}.{weight_name}.weight", expert_id) + # for expert_id in range(self.config.num_local_experts) + # for weight_name in ["w1", "w2", "w3"] + # ] + [ + # # These are the activation scales for the experts + # # (param_name, weight_name, expert_id) + # ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", + # f"experts.{expert_id}.{weight_name}.input_scale", expert_id) + # for expert_id in range(self.config.num_local_experts) + # for weight_name in ["w1", "w2", "w3"] + # ] + + # params_dict = dict(self.named_parameters()) + # for name, loaded_weight in weights: + # if "rotary_emb.inv_freq" in name: + # continue + + # for (param_name, weight_name, shard_id) in stacked_params_mapping: + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader(param, loaded_weight, shard_id) + # break + # else: + # for param_name, weight_name, expert_id in expert_params_mapping: + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader(param, + # loaded_weight, + # weight_name, + # expert_id=expert_id) + # break + # else: + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # # Remapping the name of FP8 kv-scale. + # if name.endswith("kv_scale"): + # remapped_kv_scale_name = name.replace( + # ".kv_scale", ".attn.kv_scale") + # if remapped_kv_scale_name not in params_dict: + # print_warning_once( + # "Found kv scale in the checkpoint " + # f"(e.g. {name}), but not found the expected " + # f"name in the model " + # f"(e.g. {remapped_kv_scale_name}). " + # "kv-scale is not loaded.") + # continue + # else: + # name = remapped_kv_scale_name + # param = params_dict[name] + # weight_loader = getattr(param, "weight_loader", + # default_weight_loader) + # weight_loader(param, loaded_weight) From 576c26c86a9b210fcca29180ed20fd15770f2660 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 15:35:11 -0400 Subject: [PATCH 270/443] first pass a BART load_weights; probably not handling qkv correctly --- vllm/model_executor/models/bart.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 62ff1ea7ed818..8d3085e1e53b8 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -1299,6 +1299,27 @@ def _reorder_cache(past_key_values, beam_idx): for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + stacked_params_mapping = { + "query": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "key": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "value": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + def _rename_key(self, key: str): prefix = f"{self.base_model_prefix}." key = key[len(prefix):] if key.startswith(prefix) else key @@ -1324,6 +1345,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.model.named_parameters()) for name, loaded_weight in weights: + if 'shared.weight' in name: + continue + name = self._rename_key(name) name, shard_id = self._rename_stacked_param(name) From 6219d9590dfae14c574d598ce879af58fe97177f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 15:36:36 -0400 Subject: [PATCH 271/443] Formatting --- vllm/model_executor/models/bart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 8d3085e1e53b8..43d6511e74342 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -1329,7 +1329,6 @@ def _rename_key(self, key: str): return key - def _rename_stacked_param( self, name: str, From 9ad5143ab290419d27fcde1287d9bea853a58be3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 16:00:15 -0400 Subject: [PATCH 272/443] refactored backend constants --- tests/kernels/test_attention_selector.py | 4 ++-- tests/kernels/test_encoder_decoder_attn.py | 6 ++++-- tests/kernels/utils.py | 18 ++-------------- vllm/utils.py | 24 +++++++++++++++++++++- vllm/worker/enc_dec_model_runner.py | 4 ++-- 5 files changed, 33 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index d9000e58d1d43..c27607912692b 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,9 +3,9 @@ import pytest import torch -from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, - override_backend_env_variable) +from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use +from vllm.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL) @pytest.mark.parametrize( diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index f61b0a0dcc706..6f2545d839b81 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -15,10 +15,12 @@ import torch from tests.kernels.utils import * +import copy from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor +from vllm.utils import (is_hip, make_causal_mask, maybe_make_long_tensor, + LIST_ENC_DEC_SUPPORTED_BACKENDS) HEAD_SIZES = [64, 256] @@ -26,7 +28,7 @@ BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] +BACKEND_NAMES = copy.copy(LIST_ENC_DEC_SUPPORTED_BACKENDS) CUDA_DEVICE = "cuda:0" MAX_DEC_SEQ_LENS = [128] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 94e7379123c7c..642ddeb014303 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -11,22 +11,8 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, - maybe_make_long_tensor, maybe_max) - -# String name of register which may be set in order to -# force auto-selection of attention backend by Attention -# wrapper -STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" - -# Possible string values of STR_BACKEND_ENV_VAR -# register, corresponding to possible backends -STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" -STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" -STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" -STR_XFORMERS_ATTN_VAL: str = "XFORMERS" -STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" -STR_INVALID_VAL: str = "INVALID" - + maybe_make_long_tensor, maybe_max, + STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL) class QKVInputs(NamedTuple): ''' diff --git a/vllm/utils.py b/vllm/utils.py index 84802bc5b5dae..f852c5e174fe1 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -45,9 +45,31 @@ "Chunked prefill for encoder/decoder models " + \ "is not currently supported." -STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED = \ +STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = \ "Currently CUDAGraph is not supported for encoder/decoder models" +STR_NOT_IMPL_ENC_DEC_BACKEND = \ + "This backend is currently unsupported for encoder/decoder models:" + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + +# List of support backends for encoder/decoder models +LIST_ENC_DEC_SUPPORTED_BACKENDS = [STR_XFORMERS_ATTN_VAL] + STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 7a76c6d1a9499..09b240295ea8b 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,7 +16,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED, +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, make_tensor_with_pad) from vllm.worker.model_runner import LORA_WARMUP_RANK, ModelInput, ModelRunner @@ -439,4 +439,4 @@ def profile_run(self) -> None: @torch.inference_mode() def capture_model(self, _: List[torch.Tensor]) -> None: - raise NotImplementedError(STR_ENCDECMR_CUDAGRAPH_UNSUPPORTED) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH) From 42c36443981dd89c9defaf2f51c1481ddb0a5751 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 16:24:26 -0400 Subject: [PATCH 273/443] encoder decoder model runner fails for unsupported scenarios --- tests/kernels/utils.py | 5 ++-- vllm/worker/enc_dec_model_runner.py | 37 ++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 642ddeb014303..b230bcf3c297d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -11,8 +11,9 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, - maybe_make_long_tensor, maybe_max, - STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL) + maybe_make_long_tensor, maybe_max, STR_BACKEND_ENV_VAR, + STR_XFORMERS_ATTN_VAL) + class QKVInputs(NamedTuple): ''' diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 09b240295ea8b..0f88075ab56fb 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,7 +18,9 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - make_tensor_with_pad) + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_BACKEND, STR_NOT_IMPL_ENC_DEC_SWA, + LIST_ENC_DEC_SUPPORTED_BACKENDS, make_tensor_with_pad) from vllm.worker.model_runner import LORA_WARMUP_RANK, ModelInput, ModelRunner logger = init_logger(__name__) @@ -62,14 +64,47 @@ def __init__( kv_cache_dtype, is_driver_worker, vision_language_config) + self._check_encoder_decoder_unsupported_scenarios() + + def _check_encoder_decoder_unsupported_scenarios(self): + ''' + Catch and raise NotImplemented errors if features unsupported + for encoder/decoder models are enabled, or if an otherwise + unsupported scenario is configured. + ''' + if not self._is_encoder_decoder_model(): # Fail if EncoderDecoderModelRunner is constructed for a # non-encoder/decoder model i.e. decoder-only raise AttributeError(STR_ENCDECMR_ENCODER_DECODER_REQUIRED) if self.scheduler_config.chunked_prefill_enabled: + # Fail if chunked prefill is enabled raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) + if self.cache_config.enable_prefix_caching: + # Fail if prefix caching is enabled + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + + if self.sliding_window is not None: + # Fail if sliding window is enabled + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if not self.model_config.enforce_eager: + # Fail if CUDA graph is enabled + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH) + + backend_name = self.attn_backend.get_name() + caps_backend_name = backend_name.upper() + if caps_backend_name not in LIST_ENC_DEC_SUPPORTED_BACKENDS: + # Fail if the selected backend is not supported for + # encoder decoder models. + msg = STR_NOT_IMPL_ENC_DEC_BACKEND + \ + f" {backend_name}; supported backends: " + \ + "{str(LIST_ENC_DEC_SUPPORTED_BACKENDS)}" + + raise NotImplementedError(msg) + def _prepare_encoder_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], attn_metadata: AttentionMetadata) -> EncoderInput: From 082be510533d1e39008db19ca8754a91aa4879d3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 19:36:46 -0400 Subject: [PATCH 274/443] loading tied weights --- examples/offline_inference_encoder_decoder.py | 3 +- .../test_encoder_decoder_model_runner.py | 4 +- vllm/config.py | 1 + vllm/model_executor/models/bart.py | 117 ++++++++++++------ 4 files changed, 86 insertions(+), 39 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index c9c0f642aac21..95ed705f47c7e 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -22,7 +22,8 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/bart-base") +llm = LLM(model="facebook/bart-base", + enforce_eager=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index bce35e04d9f68..8456a2ec8e7bd 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -3,8 +3,8 @@ import pytest import torch -from tests.kernels.utils import (STR_XFORMERS_ATTN_VAL, - override_backend_env_variable) +from vllm.utils import STR_XFORMERS_ATTN_VAL +from tests.kernels.utils import override_backend_env_variable from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata diff --git a/vllm/config.py b/vllm/config.py index 8d004902fe4ff..c6a7139c51282 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -137,6 +137,7 @@ def __init__( self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision, rope_scaling, rope_theta) + self.hf_config.tie_word_embeddings = False self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len( diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 43d6511e74342..c564714860949 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -23,6 +23,7 @@ from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -41,6 +42,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.utils import print_warning_once +from vllm.sequence import SamplerOutput + import math from typing import List, Optional, Tuple, Union @@ -964,11 +967,11 @@ def __init__(self, self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) + # self.embed_tokens = VocabParallelEmbedding( + # self.vocab_size, + # config.hidden_size, + # org_num_embeddings=config.vocab_size, + # ) self.encoder = BartEncoder(config, cache_config, @@ -1109,9 +1112,6 @@ def __init__(self, # self.register_buffer( # "final_logits_bias", # torch.zeros((1, self.model.shared.num_embeddings))) - # self.lm_head = nn.Linear(config.d_model, - # self.model.shared.num_embeddings, - # bias=False) # # Initialize weights and apply final processing # self.post_init() @@ -1122,18 +1122,24 @@ def __init__(self, cache_config, quant_config, lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) + + self.lm_head = nn.Linear(config.d_model, + config.vocab_size, + bias=False) + + # self.lm_head = ParallelLMHead( + # self.unpadded_vocab_size, + # config.hidden_size, + # org_num_embeddings=config.vocab_size, + # padding_size=DEFAULT_VOCAB_PADDING_SIZE + # # We need bigger padding if using lora for kernel + # # compatibility + # if not lora_config else lora_config.lora_vocab_padding_size, + # ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -1244,6 +1250,20 @@ def forward( # encoder_attentions=outputs.encoder_attentions, # ) + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + def prepare_inputs_for_generation( self, decoder_input_ids, @@ -1341,32 +1361,57 @@ def _rename_stacked_param( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - params_dict = dict(self.model.named_parameters()) + model_params_dict = dict(self.model.named_parameters()) + top_params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: - if 'shared.weight' in name: - continue + # if 'shared.weight' in name: + # continue name = self._rename_key(name) name, shard_id = self._rename_stacked_param(name) - # Skip the specific downstream task weight. - if name.startswith('cls.'): - continue - # use Pooler instead. - if name.startswith('pooler.'): - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - if shard_id: - weight_loader(param, loaded_weight, shard_id) + if 'shared.weight' in name: + encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", + default_weight_loader) + + lm_head_in_param = top_params_dict['lm_head.weight'] + lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", + default_weight_loader) + + if shard_id: + encoder_in_weight_loader(encoder_in_param, loaded_weight, shard_id) + decoder_in_weight_loader(decoder_in_param, loaded_weight, shard_id) + lm_head_in_weight_loader(lm_head_in_param, loaded_weight, shard_id) + else: + encoder_in_weight_loader(encoder_in_param, loaded_weight) + decoder_in_weight_loader(decoder_in_param, loaded_weight) + lm_head_in_weight_loader(lm_head_in_param, loaded_weight) + else: - weight_loader(param, loaded_weight) + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From f2dac1ce0ae1033b5143b8f1cd234e1eee5e67ee Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 20:13:05 -0400 Subject: [PATCH 275/443] wip --- vllm/model_executor/models/bart.py | 51 +++++++++++++----------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index c564714860949..73169a0013aa5 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -712,19 +712,9 @@ def set_input_embeddings(self, value): self.embed_tokens = value def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], encoder_positions: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1043,9 +1033,12 @@ def forward( use_cache = self.config.use_cache return_dict = self.config.use_return_dict - if encoder_outputs is None: - encoder_outputs = self.encoder( + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder( input_ids=input_ids, + attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, @@ -1053,21 +1046,21 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] - if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] - if len(encoder_outputs) > 2 else None, - ) + # # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + # elif return_dict and not isinstance(encoder_hidden_states, BaseModelOutput): + # encoder_hidden_states = BaseModelOutput( + # last_hidden_state=encoder_hidden_states[0], + # hidden_states=encoder_hidden_states[1] + # if len(encoder_hidden_states) > 1 else None, + # attentions=encoder_hidden_states[2] + # if len(encoder_hidden_states) > 2 else None, + # ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], + encoder_hidden_states=encoder_hidden_states[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1080,7 +1073,7 @@ def forward( ) if not return_dict: - return decoder_outputs + encoder_outputs + return decoder_outputs + encoder_hidden_states return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, @@ -1088,9 +1081,9 @@ def forward( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, + encoder_last_hidden_state=encoder_hidden_states.last_hidden_state, + encoder_hidden_states=encoder_hidden_states.hidden_states, + encoder_attentions=encoder_hidden_states.attentions, ) From 59caabecf2666c33306625843908b1d9dc2ffa8b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 21 Jun 2024 21:42:39 -0400 Subject: [PATCH 276/443] BART almost passing profile_run() --- vllm/model_executor/models/bart.py | 702 ++++++++++------------------- 1 file changed, 244 insertions(+), 458 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 73169a0013aa5..0aeedf2594375 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -18,6 +18,8 @@ """PyTorch BART model.""" from typing import Iterable, List, Optional, Tuple +from vllm.attention.backends.abstract import AttentionType + import torch from torch import nn @@ -161,21 +163,21 @@ def __init__(self, def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale - -class BartAttention(nn.Module): +class BartEncoderAttention(nn.Module): def __init__( self, embed_dim: int, num_heads: int, - is_decoder: bool = False, bias: bool = True, - is_causal: bool = False, config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads + self.num_kv_heads = self.num_heads self.head_dim = embed_dim // num_heads self.config = config @@ -184,130 +186,186 @@ def __init__( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + self, + hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) + q=self.q_proj(hidden_states) + k=self.k_proj(hidden_states) + v=self.v_proj(hidden_states) - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1]): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=is_causal, - ) + attn_output = self.attn(q, + k, + v, + kv_caches, + attn_metadata, + attn_type=AttentionType.ENCODER) + + output, _ = self.out_proj(attn_output) + return output + +class BartDecoderSelfAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = self.num_heads + self.head_dim = embed_dim // num_heads + self.config = config - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}") + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + q=self.q_proj(hidden_states) + k=self.k_proj(hidden_states) + v=self.v_proj(hidden_states) + + attn_output = self.attn(q, + k, + v, + kv_caches, + attn_metadata, + attn_type=AttentionType.DECODER) + + output, _ = self.out_proj(attn_output) + return output + +class BartCrossAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = self.num_heads + self.head_dim = embed_dim // num_heads + self.config = config - attn_output = attn_output.transpose(1, 2) + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - attn_output = self.out_proj(attn_output) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) - return attn_output, None, past_key_value + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor]=None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + q=self.q_proj(decoder_hidden_states) + k=None if encoder_hidden_states is None else \ + self.k_proj(encoder_hidden_states) + v=None if encoder_hidden_states is None else \ + self.v_proj(encoder_hidden_states) + + attn_output = self.attn(q, + k, + v, + kv_caches, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + + output, _ = self.out_proj(attn_output) + return output class BartEncoderLayer(nn.Module): - def __init__(self, config: BartConfig): + def __init__(self, config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None,): super().__init__() self.embed_dim = config.d_model - self.self_attn = BartAttention( + self.self_attn = BartEncoderAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, config=config, + cache_config=cache_config, + quant_config=quant_config ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = ACT2FN[config.activation_function] @@ -317,10 +375,9 @@ def __init__(self, config: BartConfig): def forward( self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, + hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: """ Args: @@ -336,9 +393,8 @@ def forward( residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, + kv_caches=kv_caches, + attn_metadata=attn_metadata ) hidden_states = residual + hidden_states @@ -360,34 +416,30 @@ def forward( min=-clamp_value, max=clamp_value) - outputs = (hidden_states, ) - - if output_attentions: - outputs += (attn_weights, ) - - return outputs + return hidden_states class BartDecoderLayer(nn.Module): - def __init__(self, config: BartConfig): + def __init__(self, config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None,): super().__init__() self.embed_dim = config.d_model - self.self_attn = BartAttention( + self.self_attn = BartDecoderSelfAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, - is_decoder=True, - is_causal=True, config=config, + cache_config=cache_config, + quant_config=quant_config ) self.activation_fn = ACT2FN[config.activation_function] self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BartAttention( + self.encoder_attn = BartCrossAttention( self.embed_dim, config.decoder_attention_heads, - is_decoder=True, config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -397,15 +449,10 @@ def __init__(self, config: BartConfig): def forward( self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, + decoder_hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor]=None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -429,44 +476,29 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[: - 2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, + hidden_states = self.self_attn( + hidden_states=decoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata ) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[ - -2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) + residual = hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states @@ -477,15 +509,15 @@ def forward( hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states, ) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value, ) + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) - return outputs + return hidden_states class BartEncoder(nn.Module): @@ -507,6 +539,9 @@ def __init__(self, #super().__init__(config) super().__init__() + self.cache_config=cache_config + self.quant_config=quant_config + self.lora_config=lora_config embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings @@ -525,7 +560,8 @@ def __init__(self, embed_dim, ) self.layers = nn.ModuleList( - [BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + [BartEncoderLayer(config,cache_config,quant_config) \ + for _ in range(config.encoder_layers)]) self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -536,14 +572,8 @@ def set_input_embeddings(self, value): self.embed_tokens = value def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, BaseModelOutput]: r""" Args: @@ -581,28 +611,11 @@ def forward( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input = input_ids - input_ids = input_ids.view(-1, input_ids.shape[-1]) - elif inputs_embeds is not None: - input = inputs_embeds[:, :, -1] - else: - raise ValueError( - "You have to specify either input_ids or inputs_embeds") - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input) embed_pos = embed_pos.to(inputs_embeds.device) @@ -610,58 +623,14 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa( - attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask( - attention_mask, inputs_embeds.dtype) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] - if head_mask is not None else None), - output_attentions=output_attentions, + hidden_states = encoder_layer( + hidden_states=hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1], ) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - if not return_dict: - return tuple( - v for v in [hidden_states, encoder_states, all_attentions] - if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions) + return hidden_states class BartDecoder(nn.Module): @@ -681,8 +650,10 @@ def __init__( lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None, ): - #super().__init__(config) super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt( @@ -700,8 +671,10 @@ def __init__( config.max_position_embeddings, config.d_model, ) + self.layers = nn.ModuleList( - [BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + [BartDecoderLayer(config,cache_config,quant_config) \ + for _ in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -712,7 +685,7 @@ def set_input_embeddings(self, value): self.embed_tokens = value def forward( - self, input_ids: torch.Tensor, positions: torch.Tensor, + self, decoder_input_ids: torch.Tensor, decoder_positions: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], encoder_positions: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: @@ -781,150 +754,32 @@ def forward( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input = input_ids - input_shape = input.shape - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) + input = decoder_input_ids + input_shape = input.shape + decoder_input_ids = decoder_input_ids.view(-1, input_shape[-1]) - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[ - 2] if past_key_values is not None else 0 - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if ( - attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, - past_key_values_length) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1]) + inputs_embeds = self.embed_tokens(input) # embed positions - positions = self.embed_positions(input, past_key_values_length) - positions = positions.to(inputs_embeds.device) + decoder_positions = self.embed_positions(input, past_key_values_length) + decoder_positions = decoder_positions.to(inputs_embeds.device) - hidden_states = inputs_embeds + positions + hidden_states = inputs_embeds + decoder_positions hidden_states = self.layernorm_embedding(hidden_states) # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if ( - output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], - ["head_mask", "cross_attn_head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - past_key_value = past_key_values[ - idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] - if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] - if cross_attn_head_mask is not None - else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[3 if output_attentions else 1], ) - - if output_attentions: - all_self_attns += (layer_outputs[1], ) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2], ) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [ - hidden_states, next_cache, all_hidden_states, all_self_attns, - all_cross_attentions - ] if v is not None) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) + # hidden_states = layer_outputs[0] + + return hidden_states class BartModel(nn.Module): @@ -994,97 +849,28 @@ def forward( encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, Seq2SeqModelOutput]: - decoder_inputs_embeds = None - decoder_input_ids = input_ids - attention_mask = None - head_mask = None - inputs_embeds = None - decoder_attention_mask = None - decoder_head_mask = None - cross_attn_head_mask = None - past_key_values = None - - # different to other models, Bart automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, - self.config.decoder_start_token_id) - - #output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_attentions = self.config.output_attentions - - # output_hidden_states = (output_hidden_states - # if output_hidden_states is not None else - # self.config.output_hidden_states) - - output_hidden_states = self.config.output_hidden_states - - # use_cache = use_cache if use_cache is not None else self.config.use_cache - # return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - use_cache = self.config.use_cache - return_dict = self.config.use_return_dict if encoder_input_ids.numel() > 0: # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder( - input_ids=input_ids, - - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + input_ids=encoder_input_ids, + positions=encoder_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata ) - # # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - # elif return_dict and not isinstance(encoder_hidden_states, BaseModelOutput): - # encoder_hidden_states = BaseModelOutput( - # last_hidden_state=encoder_hidden_states[0], - # hidden_states=encoder_hidden_states[1] - # if len(encoder_hidden_states) > 1 else None, - # attentions=encoder_hidden_states[2] - # if len(encoder_hidden_states) > 2 else None, - # ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_hidden_states[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states, + encoder_positions=encoder_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata ) - if not return_dict: - return decoder_outputs + encoder_hidden_states - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states.last_hidden_state, - encoder_hidden_states=encoder_hidden_states.hidden_states, - encoder_attentions=encoder_hidden_states.attentions, - ) + return decoder_outputs class BartForConditionalGeneration(nn.Module): @@ -1093,7 +879,7 @@ class BartForConditionalGeneration(nn.Module): "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight" ] - _keys_to_ignore_on_load_missing = ["final_logits_bias"] + # _keys_to_ignore_on_load_missing = ["final_logits_bias"] def __init__(self, config: BartConfig, From b8d5637c510b42a6503d9b0c4d810fe3568314dd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 24 Jun 2024 12:50:25 -0400 Subject: [PATCH 277/443] wip bart --- examples/offline_inference_encoder_decoder.py | 8 ++++---- vllm/model_executor/models/bart.py | 9 ++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 95ed705f47c7e..5d7cecb7a5ecc 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -10,10 +10,10 @@ ] # - Decoder prompts decoder_prompts = [ - "", - "", - "", - "", + "ad", + "b", + "cat", + "dabble", ] # - Unified prompts prompts = [enc_dec for enc_dec in zip(encoder_prompts, decoder_prompts)] diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 0aeedf2594375..ec5bed3205f0c 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -121,6 +121,13 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, return shifted_input_ids +def get_bsz_seq_len(input_ids): + shp = input_ids.shape + ndim = len(shp) + if ndim == 1: + return 1, input_ids.numel() + else: + return shp[:2] class BartLearnedPositionalEmbedding(nn.Embedding): """ @@ -138,7 +145,7 @@ def forward(self, past_key_values_length: int = 0): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] + bsz, seq_len = get_bsz_seq_len(input_ids) positions = torch.arange(past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, From 7d2fcf90a6516be432ffd39f4571ed0a524438b2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 24 Jun 2024 15:39:07 -0400 Subject: [PATCH 278/443] BART passes profile run --- vllm/model_executor/models/bart.py | 48 ++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index ec5bed3205f0c..c6ed6b641abfd 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -142,12 +142,30 @@ def __init__(self, num_embeddings: int, embedding_dim: int): def forward(self, input_ids: torch.Tensor, + attn_type: AttentionType, + attn_metadata: AttentionMetadata, past_key_values_length: int = 0): """`input_ids' shape is expected to be [bsz x seqlen].""" + assert attn_type != AttentionType.ENCODER_DECODER + bsz, seq_len = get_bsz_seq_len(input_ids) - positions = torch.arange(past_key_values_length, - past_key_values_length + seq_len, + # afeldman-nm: This BART implementation is designed for vLLM, which + # packs variable-length sequences into a single vector + # without padding + assert bsz == 1 + + if attn_type == AttentionType.ENCODER: + seq_lens=attn_metadata.encoder_seq_lens + else: + # AttentionType.DECODER + seq_lens=attn_metadata.seq_lens + + positions=[] + for seq_len in seq_lens: + positions.extend(list(range(seq_len))) + + positions = torch.tensor(positions, dtype=torch.long, device=self.weight.device).expand(bsz, -1) @@ -229,7 +247,7 @@ def forward( attn_metadata, attn_type=AttentionType.ENCODER) - output, _ = self.out_proj(attn_output) + output = self.out_proj(attn_output) return output class BartDecoderSelfAttention(nn.Module): @@ -291,7 +309,7 @@ def forward( attn_metadata, attn_type=AttentionType.DECODER) - output, _ = self.out_proj(attn_output) + output = self.out_proj(attn_output) return output class BartCrossAttention(nn.Module): @@ -356,7 +374,7 @@ def forward( attn_metadata, attn_type=AttentionType.ENCODER_DECODER) - output, _ = self.out_proj(attn_output) + output = self.out_proj(attn_output) return output class BartEncoderLayer(nn.Module): @@ -398,7 +416,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states = self.self_attn( hidden_states=hidden_states, kv_caches=kv_caches, attn_metadata=attn_metadata @@ -444,6 +462,12 @@ def __init__(self, config: BartConfig, self.activation_fn = ACT2FN[config.activation_function] self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + ''' + afeldman-nm: personally I would call this "cross-attention", + however I left the name as "encoder_attn" to maintain consistency + with the name of the pretrained weights. + ''' self.encoder_attn = BartCrossAttention( self.embed_dim, config.decoder_attention_heads, @@ -480,7 +504,7 @@ def forward( Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ - residual = hidden_states + residual = decoder_hidden_states # Self Attention hidden_states = self.self_attn( @@ -624,7 +648,7 @@ def forward( input_ids = input_ids.view(-1, input_ids.shape[-1]) inputs_embeds = self.embed_tokens(input_ids) - embed_pos = self.embed_positions(input) + embed_pos = self.embed_positions(input,AttentionType.ENCODER,attn_metadata) embed_pos = embed_pos.to(inputs_embeds.device) hidden_states = inputs_embeds + embed_pos @@ -633,7 +657,7 @@ def forward( for idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( hidden_states=hidden_states, - kv_caches=kv_caches, + kv_caches=kv_caches[idx], attn_metadata=attn_metadata, ) @@ -769,7 +793,7 @@ def forward( inputs_embeds = self.embed_tokens(input) # embed positions - decoder_positions = self.embed_positions(input, past_key_values_length) + decoder_positions = self.embed_positions(input,AttentionType.DECODER,attn_metadata) decoder_positions = decoder_positions.to(inputs_embeds.device) hidden_states = inputs_embeds + decoder_positions @@ -780,7 +804,7 @@ def forward( for idx, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( decoder_hidden_states=hidden_states, - kv_caches=kv_caches, + kv_caches=kv_caches[idx], attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -976,7 +1000,7 @@ def forward( """ hidden_states = self.model(input_ids, positions, encoder_input_ids, encoder_positions, kv_caches, attn_metadata) - return hidden_states + return hidden_states[0,:,:] # return_dict = return_dict if return_dict is not None else self.config.use_return_dict From 6fd4c020a9c5ee8ecbf6e086d8b9dfefb3f8396f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 24 Jun 2024 15:42:09 -0400 Subject: [PATCH 279/443] fixed prompt processing bug that was preventing inference from starting --- vllm/engine/llm_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a2f0c8eb6785f..449ba28105327 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -558,7 +558,8 @@ def process_model_inputs( # (leave decoder input to default) inputs = {"encoder_prompt": inputs} - if isinstance(inputs, EncoderDecoderStringPrompts): + if isinstance(inputs, tuple) and len(inputs) == 2: + # Detect input which is EncoderDecoderStringPrompts (i.e. Tuple[str,str]) # Interpret a tuple of input string prompts as a single # encoder input and a single decoder input, respectively inputs = { From 8f9ee625557ec34ec29787b6b66ec760ff390e77 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 24 Jun 2024 18:06:10 -0400 Subject: [PATCH 280/443] wip bart-cnn summarization example --- examples/offline_inference_encoder_decoder.py | 14 ++++++++------ tests/conftest.py | 6 ++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 5d7cecb7a5ecc..f5159b272d197 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -3,17 +3,19 @@ # Sample prompts. # - Encoder prompts encoder_prompts = [ - "Hello, my name is", + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.", "The president of the United States is", "The capital of France is", "The future of AI is", ] # - Decoder prompts decoder_prompts = [ - "ad", - "b", - "cat", - "dabble", + "", + "", + "", + "", ] # - Unified prompts prompts = [enc_dec for enc_dec in zip(encoder_prompts, decoder_prompts)] @@ -22,7 +24,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/bart-base", +llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/tests/conftest.py b/tests/conftest.py index 67885b93285c5..353e7a0e10671 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,6 +113,12 @@ def example_prompts() -> List[str]: prompts += _read_prompts(filename) return prompts +@pytest.fixture +def example_encoder_decoder_prompts() -> List[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts @pytest.fixture def example_long_prompts() -> List[str]: From 2d8429e1b0002eccb7deaa805d25ebb6d5616187 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 24 Jun 2024 18:47:19 -0400 Subject: [PATCH 281/443] fixed a number of bugs related to BART decode-phase; added support for the particular architecture alias used by bart-large-cnn --- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/bart.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index cb049268db73d..b5c12a1f5eb2e 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -69,6 +69,7 @@ _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), } _MODELS = { diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index c6ed6b641abfd..2101dfd0daade 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -159,7 +159,12 @@ def forward(self, seq_lens=attn_metadata.encoder_seq_lens else: # AttentionType.DECODER - seq_lens=attn_metadata.seq_lens + if attn_metadata.num_prefill_tokens > 0: + # Prefill + seq_lens=attn_metadata.seq_lens + else: + # Decode: infer one (1) new token per sequence + seq_lens=[1]*len(attn_metadata.seq_lens) positions=[] for seq_len in seq_lens: @@ -793,10 +798,10 @@ def forward( inputs_embeds = self.embed_tokens(input) # embed positions - decoder_positions = self.embed_positions(input,AttentionType.DECODER,attn_metadata) - decoder_positions = decoder_positions.to(inputs_embeds.device) + embed_pos = self.embed_positions(input,AttentionType.DECODER,attn_metadata) + embed_pos = embed_pos.to(inputs_embeds.device) - hidden_states = inputs_embeds + decoder_positions + hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) # decoder layers @@ -881,6 +886,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, Seq2SeqModelOutput]: + encoder_hidden_states = None + if encoder_input_ids.numel() > 0: # Run encoder attention if a non-zero number of encoder tokens # are provided as input From 919bf88f8925b2e60c765f309df655318c392c2e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:13:52 -0400 Subject: [PATCH 282/443] BART e2e test runs but does not pass --- tests/conftest.py | 153 ++++++++++++++++++++++++++++++++++---- tests/models/test_bart.py | 32 ++++++-- 2 files changed, 164 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 353e7a0e10671..c54d344e609e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,8 @@ import torch.nn.functional as F from PIL import Image from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, - AutoProcessor, AutoTokenizer, BatchEncoding) + AutoProcessor, AutoTokenizer, BatchEncoding, + AutoModelForSeq2SeqLM) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig @@ -21,6 +22,7 @@ from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs from vllm.utils import cuda_device_count_stateless, is_cpu +from vllm.outputs import RequestOutput logger = init_logger(__name__) @@ -114,11 +116,23 @@ def example_prompts() -> List[str]: return prompts @pytest.fixture -def example_encoder_decoder_prompts() -> List[str]: - prompts = [] +def example_encoder_decoder_prompts() -> Tuple[List[str],List[str]]: + ''' + Returns an encoder prompt list and a decoder prompt list, wherein each pair + of same-index entries in both lists corresponds to an (encoder prompt, + decoder prompt) tuple. + + Returns: + * Encoder prompt list + * Decoder prompt list (reverse of encoder prompt list) + ''' + encoder_prompts = [] for filename in _TEST_PROMPTS: - prompts += _read_prompts(filename) - return prompts + encoder_prompts += _read_prompts(filename) + + # Encoder prompts, decoder prompts + return encoder_prompts, \ + encoder_prompts[::-1] @pytest.fixture def example_long_prompts() -> List[str]: @@ -153,6 +167,7 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, is_vision_model: bool = False, + is_encoder_decoder_model: bool = False ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -170,6 +185,8 @@ def __init__( else: if is_vision_model: auto_cls = AutoModelForVision2Seq + elif is_encoder_decoder_model: + auto_cls = AutoModelForSeq2SeqLM else: auto_cls = AutoModelForCausalLM @@ -362,6 +379,72 @@ def generate_greedy_logprobs_limit( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def generate_encoder_decoder_greedy_logprobs_limit( + self, + encoder_decoder_prompts: Tuple[List[str],List[str]], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ''' + Greedy logprobs generation for vLLM encoder/decoder models + ''' + + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for encoder_prompt,decoder_prompt in zip(*encoder_decoder_prompts): + encoder_input_ids = self.tokenizer(encoder_prompt, return_tensors="pt").input_ids + decoder_input_ids = self.tokenizer(decoder_prompt, return_tensors="pt").input_ids + output = self.model.generate( + self.wrap_device(encoder_input_ids), + decoder_input_ids=self.wrap_device(decoder_input_ids), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + seq_logprobs: List[torch.Tensor] = [] + for _, decoder_hidden_states in enumerate(output.decoder_hidden_states): + last_hidden_states = decoder_hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", + None) is not None: + logits += self.model.get_output_embeddings( + ).bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - decoder_input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) @@ -442,6 +525,22 @@ def generate( outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs + def _final_steps_generate_w_logprobs(self, + req_outputs: List[RequestOutput]) \ + -> List[ + Tuple[List[int], + str, + Optional[ + SampleLogprobs]]]: + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + def generate_w_logprobs( self, prompts: List[str], @@ -451,14 +550,24 @@ def generate_w_logprobs( req_outputs = self.model.generate(prompts, sampling_params=sampling_params) - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] - for req_output in req_outputs: - for sample in req_output.outputs: - output_str = sample.text - output_ids = sample.token_ids - output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) - return outputs + return self._final_steps_generate_w_logprobs(req_outputs) + + def generate_encoder_decoder_w_logprobs( + self, + encoder_decoder_prompts: Tuple[List[str],List[str]], + sampling_params: SamplingParams, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ''' + Logprobs generation for vLLM encoder/decoder models + ''' + + assert sampling_params.logprobs is not None + + prompt_inputs = list(zip(encoder_decoder_prompts[0],encoder_decoder_prompts[1])) + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) + return self._final_steps_generate_w_logprobs(req_outputs) def generate_greedy( self, @@ -485,6 +594,24 @@ def generate_greedy_logprobs( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def generate_encoder_decoder_greedy_logprobs( + self, + encoder_decoder_prompts: Tuple[List[str],List[str]], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs) + ''' + Greedy logprobs generation for vLLM encoder/decoder models + ''' + + outputs = self.generate_encoder_decoder_w_logprobs(encoder_decoder_prompts, greedy_logprobs_params) + + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def generate_beam_search( self, prompts: List[str], diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index df76777a0de00..71880f2530031 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -5,31 +5,47 @@ import pytest from .utils import check_logprobs_close +from tests.kernels.utils import override_backend_env_variable +from vllm.utils import STR_XFORMERS_ATTN_VAL -MODELS = ["facebook/bart-base"] +MODELS = ["facebook/bart-large-cnn"] + +# Backends under test +# +# Currently only XFormers is supported +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("backend_name",BACKEND_NAMES) def test_models( hf_runner, vllm_runner, - example_prompts, + example_encoder_decoder_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, + backend_name: str, + monkeypatch, ) -> None: # TODO(sang): Sliding window should be tested separately. - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + with hf_runner(model, + dtype=dtype, + is_encoder_decoder_model=True) as hf_model: + hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( + example_encoder_decoder_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + example_encoder_decoder_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, From 597526a49e041ec99329add79ef272ce6e457b9e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:18:02 -0400 Subject: [PATCH 283/443] removed extra line --- vllm/attention/backends/xformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a3f3d41a5491c..d66d2ce9c277a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -556,9 +556,7 @@ def forward( decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - if key is not None and value is not None: - key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] From a178b7a8c9838665ee7e169471206b70d62e1b71 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:20:00 -0400 Subject: [PATCH 284/443] changed nested if/else to elif/else in xformers mask computation code --- vllm/attention/backends/xformers.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d66d2ce9c277a..1090c2e062bd5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -677,19 +677,18 @@ def _run_memory_efficient_xformers_forward( # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + elif attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + + # Default encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens) else: - if attn_type == AttentionType.ENCODER: - assert attn_metadata.encoder_seq_lens is not None - - # Default encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.encoder_seq_lens) - else: - assert attn_metadata.seq_lens is not None - - # Default decoder self-attention mask is causal - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + assert attn_metadata.seq_lens is not None + + # Default decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) From 06c7f7500140c574d20a12079dbd1ef83db29688 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:28:42 -0400 Subject: [PATCH 285/443] reorganized helper functions that were only being used for testing into tests/kernels/utils.py from vllm/utils.py --- tests/kernels/test_encoder_decoder_attn.py | 3 +- tests/kernels/utils.py | 81 ++++++++++++++++++++- vllm/utils.py | 84 +--------------------- 3 files changed, 82 insertions(+), 86 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index f61b0a0dcc706..44519e6c9ec3f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -18,7 +18,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor +from vllm.utils import is_hip +from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 94e7379123c7c..b7c51c1bcf5c7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -10,8 +10,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, - maybe_make_long_tensor, maybe_max) +import numpy as np +from numbers import Number # String name of register which may be set in order to # force auto-selection of attention backend by Attention @@ -138,6 +138,83 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] +def make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + return torch.tensor(padded_x, dtype=dtype, device=device) + +def maybe_make_int_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + +def maybe_make_long_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: List) -> Optional[Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + +def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ + -> torch.Tensor: + ''' + Create a q_max_seq_len x kv_max_seq_len causal mask + + Arguments: + + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len + + Returns: + + * 2D tensor, q_max_seq_len x kv_max_seq_len + ''' + + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) + return mask def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: diff --git a/vllm/utils.py b/vllm/utils.py index 7e3ebab01513f..7a4639950472d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -13,13 +13,12 @@ import warnings from collections import defaultdict from functools import lru_cache, partial, wraps -from numbers import Number from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) -import numpy as np + import psutil import torch import torch.types @@ -585,26 +584,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: "String must be a series of integers separated by commas " f"(e.g., 1, 2, 3). Given input: {s}") from e - -def make_tensor_with_pad( - x: List[List[int]], - max_len: int, - pad: int, - dtype: torch.dtype, - device: Optional[Union[str, torch.device]], -) -> torch.Tensor: - """Make a padded tensor of a 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb - return torch.tensor(padded_x, dtype=dtype, device=device) - - def async_tensor_h2d( data: list, dtype: torch.dtype, @@ -799,67 +778,6 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) -def maybe_make_int_tensor(_list: List[int], - device: Union[torch.device, str]) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D int torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D int torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) - -def maybe_make_long_tensor(_list: List[int], - device: Union[torch.device, str]) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D long torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D long torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) - - -def maybe_max(_list: List) -> Optional[Number]: - ''' - Returns: - - * If _list is not None: max(_list) - * None otherwise - ''' - return None if _list is None else max(_list) - -def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ - -> torch.Tensor: - ''' - Create a q_max_seq_len x kv_max_seq_len causal mask - - Arguments: - - * q_max_seq_len: query max seq len - * kv_max_seq_len: key/value max seq len - - Returns: - - * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' - - # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) - # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) - return mask - - #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f): From 47c9f396fdcd40895597423ebfefe585b014c2f3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:32:52 -0400 Subject: [PATCH 286/443] removed attention_type --- tests/kernels/test_encoder_decoder_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 44519e6c9ec3f..2421c022b0ec5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -574,7 +574,7 @@ def _run_encoder_attention_test(attn: Attention, ''' Run encoder attention. - attn_metadata.attention_type is assigned AttentionType.ENCODER in order + attn.forward() is passed attn_type=AttentionType.ENCODER in order to configure the kernel invocation for encoder attention Requires attn_metadata.num_decode_tokens == 0 @@ -612,7 +612,7 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, ''' Run decoder self-attention test. - attn_metadata.attention_type is assigned AttentionType.DECODER + attn.forward() is passed attn_type=AttentionType.DECODER in order to configure the kernel invocation for decoder self-attention. Arguments: @@ -657,7 +657,7 @@ def _run_encoder_decoder_cross_attention_test( is None, this reflects that in decode-phase cross attention there is no growth in the key and value tensors. - attn_metadata.attention_type is assigned AttentionType.ENCODER_DECODER + attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER in order to configure the kernel invocation for encoder/decoder cross- attention. From 2f0b05bb805513e73eb0609ea87b6367ec9d4803 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:35:34 -0400 Subject: [PATCH 287/443] typing and formatting --- tests/kernels/test_encoder_decoder_attn.py | 2 +- tests/kernels/utils.py | 12 +++++++----- vllm/utils.py | 3 ++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2421c022b0ec5..654cd621145c5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -15,11 +15,11 @@ import torch from tests.kernels.utils import * +from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.utils import is_hip -from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b7c51c1bcf5c7..3ae345bafa36e 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,16 +2,16 @@ import itertools import random +from numbers import Number from typing import Any, List, NamedTuple, Optional, Tuple, Union +import numpy as np import pytest import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend -import numpy as np -from numbers import Number # String name of register which may be set in order to # force auto-selection of attention backend by Attention @@ -138,6 +138,7 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] + def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -156,7 +157,7 @@ def make_tensor_with_pad( padded_x[ind, :len(blocktb)] = blocktb return torch.tensor(padded_x, dtype=dtype, device=device) -def maybe_make_int_tensor(_list: List[int], +def maybe_make_int_tensor(_list: Optional[List[int]], device: Union[torch.device, str]) \ -> torch.Tensor: ''' @@ -170,7 +171,7 @@ def maybe_make_int_tensor(_list: List[int], return None if _list is None else torch.tensor( _list, dtype=torch.int, device=device) -def maybe_make_long_tensor(_list: List[int], +def maybe_make_long_tensor(_list: Optional[List[int]], device: Union[torch.device, str]) \ -> torch.Tensor: ''' @@ -185,7 +186,7 @@ def maybe_make_long_tensor(_list: List[int], _list, dtype=torch.long, device=device) -def maybe_max(_list: List) -> Optional[Number]: +def maybe_max(_list: Optional[List]) -> Optional[Number]: ''' Returns: @@ -216,6 +217,7 @@ def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ float('-inf')).masked_fill(mask == 0, 0.0) return mask + def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: ''' diff --git a/vllm/utils.py b/vllm/utils.py index 7a4639950472d..cc11e4b00283f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -18,7 +18,6 @@ Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) - import psutil import torch import torch.types @@ -584,6 +583,7 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: "String must be a series of integers separated by commas " f"(e.g., 1, 2, 3). Given input: {s}") from e + def async_tensor_h2d( data: list, dtype: torch.dtype, @@ -778,6 +778,7 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f): From d23c28466765496049a1696d0a053a0a2505ce9a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:38:08 -0400 Subject: [PATCH 288/443] typing and formatting; fixed escape sequences in comments --- tests/kernels/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 3ae345bafa36e..45f56e364175e 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -633,18 +633,18 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], Context: * Your goal is to test (1) prefill of N prompts, with prompt-lengths - {K_i \forall i \in [0,N)}, followed by (2) decoding of a single token + {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token for all N prompts (N tokens total); the resultant sequence lengths - after decode would be {K_i + 1 for i \in [0,N)} + after decode would be {K_i + 1 for i \\in [0,N)} * The test you want to do requires (1) having the prefill slot mapping for all tokens present during prefill, the number of which is - M = \sum_i{K_i}, and (2) having the decode slot mapping for all N + M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N decoded tokens This function consumes a single 1D slot mapping, which is the concatenation of N slot mappings each of length K_i + 1 (corresponding to the sequence lengths after decode), with a total length of - P = \sum_i{K_i + 1} = M + N + P = \\sum_i{K_i + 1} = M + N The prefill-phase slot mapping results from excising the (K_i + 1)-th entry from each of the N subsequences in the slot mapping (i.e. omitting the From 1a6e5a31846e2ef886b66e9cc9216ffe983d0ec0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:52:04 -0400 Subject: [PATCH 289/443] moved make_tensor_with_pad() helper function back to vllm.utils --- tests/kernels/utils.py | 22 ++-------------------- vllm/utils.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 45f56e364175e..7e1e084f650ea 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -5,7 +5,6 @@ from numbers import Number from typing import Any, List, NamedTuple, Optional, Tuple, Union -import numpy as np import pytest import torch @@ -13,6 +12,8 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend +from vllm.utils import make_tensor_with_pad + # String name of register which may be set in order to # force auto-selection of attention backend by Attention # wrapper @@ -138,25 +139,6 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] - -def make_tensor_with_pad( - x: List[List[int]], - max_len: int, - pad: int, - dtype: torch.dtype, - device: Optional[Union[str, torch.device]], -) -> torch.Tensor: - """Make a padded tensor of a 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb - return torch.tensor(padded_x, dtype=dtype, device=device) - def maybe_make_int_tensor(_list: Optional[List[int]], device: Union[torch.device, str]) \ -> torch.Tensor: diff --git a/vllm/utils.py b/vllm/utils.py index cc11e4b00283f..1a778005bd867 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -27,6 +27,8 @@ from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger +import numpy as np + logger = init_logger(__name__) STR_DTYPE_TO_TORCH_DTYPE = { @@ -573,6 +575,23 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Force garbage collection gc.collect() +def make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + return torch.tensor(padded_x, dtype=dtype, device=device) def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" From e2a46e3b7b9f9d1a9cc751046c3cddd1522620ed Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:53:35 -0400 Subject: [PATCH 290/443] formatting --- tests/kernels/utils.py | 1 - vllm/utils.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 7e1e084f650ea..f0b0dd5dbaee6 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend - from vllm.utils import make_tensor_with_pad # String name of register which may be set in order to diff --git a/vllm/utils.py b/vllm/utils.py index 1a778005bd867..127f86733a852 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -18,6 +18,7 @@ Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) +import numpy as np import psutil import torch import torch.types @@ -27,8 +28,6 @@ from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger -import numpy as np - logger = init_logger(__name__) STR_DTYPE_TO_TORCH_DTYPE = { @@ -575,6 +574,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Force garbage collection gc.collect() + def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -593,6 +593,7 @@ def make_tensor_with_pad( padded_x[ind, :len(blocktb)] = blocktb return torch.tensor(padded_x, dtype=dtype, device=device) + def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" try: From 5169a2a6518d5ae338001eae0eae6dad64bf52eb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 03:25:40 -0400 Subject: [PATCH 291/443] removed unnecessary positions arguments from BART encoder, decoder forward() --- vllm/model_executor/models/bart.py | 283 ++--------------------------- 1 file changed, 16 insertions(+), 267 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 73538ec625943..d7efb17218233 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -82,8 +82,10 @@ class BartLearnedPositionalEmbedding(nn.Embedding): """ def __init__(self, num_embeddings: int, embedding_dim: int): - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack + # Bart is set up so that if padding_idx is + # specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. + # Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) @@ -551,7 +553,7 @@ def set_input_embeddings(self, value): self.embed_tokens = value def forward( - self, input_ids: torch.Tensor, positions: torch.Tensor, + self, input_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata) -> Union[Tuple, BaseModelOutput]: r""" @@ -665,9 +667,8 @@ def set_input_embeddings(self, value): self.embed_tokens = value def forward( - self, decoder_input_ids: torch.Tensor, decoder_positions: torch.Tensor, + self, decoder_input_ids: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], - encoder_positions: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" @@ -759,7 +760,6 @@ def forward( attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) - # hidden_states = layer_outputs[0] return hidden_states @@ -774,18 +774,8 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None): - #super().__init__(config) super().__init__() - # padding_idx, vocab_size = config.pad_token_id, config.vocab_size - # self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - - # self.encoder = BartEncoder(config, self.shared) - # self.decoder = BartDecoder(config, self.shared) - - # # Initialize weights and apply final processing - # self.post_init() - self.config = config self.padding_idx = config.pad_token_id @@ -794,12 +784,6 @@ def __init__(self, self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - # self.embed_tokens = VocabParallelEmbedding( - # self.vocab_size, - # config.hidden_size, - # org_num_embeddings=config.vocab_size, - # ) - self.encoder = BartEncoder(config, cache_config, quant_config=quant_config) @@ -838,16 +822,14 @@ def forward( # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, kv_caches=kv_caches, attn_metadata=attn_metadata) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( decoder_input_ids=input_ids, - decoder_positions=positions, encoder_hidden_states=encoder_hidden_states, - encoder_positions=encoder_positions, kv_caches=kv_caches, attn_metadata=attn_metadata) @@ -861,21 +843,11 @@ class BartForConditionalGeneration(nn.Module): "lm_head.weight" ] - # _keys_to_ignore_on_load_missing = ["final_logits_bias"] - def __init__(self, config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None): - #super().__init__(config) - # self.model = BartModel(config) - # self.register_buffer( - # "final_logits_bias", - # torch.zeros((1, self.model.shared.num_embeddings))) - - # # Initialize weights and apply final processing - # self.post_init() super().__init__() self.config = config @@ -890,15 +862,6 @@ def __init__(self, self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - # self.lm_head = ParallelLMHead( - # self.unpadded_vocab_size, - # config.hidden_size, - # org_num_embeddings=config.vocab_size, - # padding_size=DEFAULT_VOCAB_PADDING_SIZE - # # We need bigger padding if using lora for kernel - # # compatibility - # if not lora_config else lora_config.lora_vocab_padding_size, - # ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -940,10 +903,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata) -> Union[Tuple, Seq2SeqLMOutput]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, + *optional*): Labels for computing the masked language modeling + loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). + Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the + tokens with labels in `[0, ..., config.vocab_size]`. Returns: """ @@ -951,64 +917,6 @@ def forward( encoder_positions, kv_caches, attn_metadata) return hidden_states[0, :, :] - # return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # if labels is not None: - # if use_cache: - # logger.warning( - # "The `use_cache` argument is changed to `False` since `labels` is provided." - # ) - # use_cache = False - # if decoder_input_ids is None and decoder_inputs_embeds is None: - # decoder_input_ids = shift_tokens_right( - # labels, self.config.pad_token_id, - # self.config.decoder_start_token_id) - - # outputs = self.model( - # input_ids, - # attention_mask=attention_mask, - # decoder_input_ids=decoder_input_ids, - # encoder_outputs=encoder_outputs, - # decoder_attention_mask=decoder_attention_mask, - # head_mask=head_mask, - # decoder_head_mask=decoder_head_mask, - # cross_attn_head_mask=cross_attn_head_mask, - # past_key_values=past_key_values, - # inputs_embeds=inputs_embeds, - # decoder_inputs_embeds=decoder_inputs_embeds, - # use_cache=use_cache, - # output_attentions=output_attentions, - # output_hidden_states=output_hidden_states, - # return_dict=return_dict, - # ) - - # lm_logits = self.lm_head(outputs[0]) - # lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) - - # masked_lm_loss = None - # if labels is not None: - # labels = labels.to(lm_logits.device) - # loss_fct = CrossEntropyLoss() - # masked_lm_loss = loss_fct( - # lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) - - # if not return_dict: - # output = (lm_logits, ) + outputs[1:] - # return ((masked_lm_loss, ) + - # output) if masked_lm_loss is not None else output - - # return Seq2SeqLMOutput( - # loss=masked_lm_loss, - # logits=lm_logits, - # past_key_values=outputs.past_key_values, - # decoder_hidden_states=outputs.decoder_hidden_states, - # decoder_attentions=outputs.decoder_attentions, - # cross_attentions=outputs.cross_attentions, - # encoder_last_hidden_state=outputs.encoder_last_hidden_state, - # encoder_hidden_states=outputs.encoder_hidden_states, - # encoder_attentions=outputs.encoder_attentions, - # ) - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head.weight, hidden_states, @@ -1023,82 +931,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": - None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": - use_cache, # change this to avoid caching (presumably for debugging) - } - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right(labels, self.config.pad_token_id, - self.config.decoder_start_token_id) - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past[:2]) + layer_past[2:], ) - return reordered_past - - stacked_params_mapping = { - "query": { - "param_name": "qkv_proj", - "shard_id": "q", - }, - "key": { - "param_name": "qkv_proj", - "shard_id": "k", - }, - "value": { - "param_name": "qkv_proj", - "shard_id": "v", - }, - } - - params_mapping = { - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - } - def _rename_key(self, key: str): prefix = f"{self.base_model_prefix}." key = key[len(prefix):] if key.startswith(prefix) else key @@ -1177,87 +1009,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if shard_id: weight_loader(param, loaded_weight, shard_id) else: - weight_loader(param, loaded_weight) - - # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - - # stacked_params_mapping = [ - # # (param_name, shard_name, shard_id) - # ("qkv_proj", "q_proj", "q"), - # ("qkv_proj", "k_proj", "k"), - # ("qkv_proj", "v_proj", "v"), - # ] - - # expert_params_mapping = [ - # # These are the weight scales for the experts - # # (param_name, weight_name, expert_id) - # ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - # f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) - # for expert_id in range(self.config.num_local_experts) - # for weight_name in ["w1", "w2", "w3"] - # ] + [ - # # These are the weights for the experts - # # (param_name, weight_name, expert_id) - # ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - # f"experts.{expert_id}.{weight_name}.weight", expert_id) - # for expert_id in range(self.config.num_local_experts) - # for weight_name in ["w1", "w2", "w3"] - # ] + [ - # # These are the activation scales for the experts - # # (param_name, weight_name, expert_id) - # ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - # f"experts.{expert_id}.{weight_name}.input_scale", expert_id) - # for expert_id in range(self.config.num_local_experts) - # for weight_name in ["w1", "w2", "w3"] - # ] - - # params_dict = dict(self.named_parameters()) - # for name, loaded_weight in weights: - # if "rotary_emb.inv_freq" in name: - # continue - - # for (param_name, weight_name, shard_id) in stacked_params_mapping: - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader(param, loaded_weight, shard_id) - # break - # else: - # for param_name, weight_name, expert_id in expert_params_mapping: - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader(param, - # loaded_weight, - # weight_name, - # expert_id=expert_id) - # break - # else: - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # # Remapping the name of FP8 kv-scale. - # if name.endswith("kv_scale"): - # remapped_kv_scale_name = name.replace( - # ".kv_scale", ".attn.kv_scale") - # if remapped_kv_scale_name not in params_dict: - # print_warning_once( - # "Found kv scale in the checkpoint " - # f"(e.g. {name}), but not found the expected " - # f"name in the model " - # f"(e.g. {remapped_kv_scale_name}). " - # "kv-scale is not loaded.") - # continue - # else: - # name = remapped_kv_scale_name - # param = params_dict[name] - # weight_loader = getattr(param, "weight_loader", - # default_weight_loader) - # weight_loader(param, loaded_weight) + weight_loader(param, loaded_weight) \ No newline at end of file From 4400d7733f7dca2acffac916a00f5edc6a89e14e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 03:36:28 -0400 Subject: [PATCH 292/443] some reformatting --- vllm/model_executor/models/bart.py | 169 ++++++++--------------------- 1 file changed, 46 insertions(+), 123 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index d7efb17218233..f034378fee9c4 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -38,8 +38,6 @@ from transformers.activations import ACT2FN from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) @@ -50,23 +48,6 @@ logger = logging.get_logger(__name__) -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, - decoder_start_token_id: int): - """ - Shift input ids one token to the right. - """ - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() - shifted_input_ids[:, 0] = decoder_start_token_id - - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - - def get_bsz_seq_len(input_ids): shp = input_ids.shape ndim = len(shp) @@ -128,7 +109,8 @@ def forward(self, class BartScaledWordEmbedding(nn.Embedding): """ - This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + This module overrides nn.Embeddings' + forward by multiplying with embeddings scale. """ def __init__(self, @@ -162,9 +144,9 @@ def __init__( self.config = config if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -223,9 +205,9 @@ def __init__( self.config = config if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -284,9 +266,9 @@ def __init__( self.config = config if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -503,8 +485,8 @@ def forward( class BartEncoder(nn.Module): """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`BartEncoderLayer`]. + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. Args: config: BartConfig @@ -552,45 +534,27 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def forward( - self, input_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> Union[Tuple, BaseModelOutput]: + def forward(self, input_ids: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: r""" Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. + input_ids + (`torch.LongTensor` of shape `(total_num_tokens)`): - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + kv_caches: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + Layer-wise list of KV cache tensors - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + attn_metadata: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. + vLLM Attention metadata structure - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Returns: + Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds @@ -617,7 +581,8 @@ def forward( class BartDecoder(nn.Module): """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] Args: config: BartConfig @@ -666,75 +631,33 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def forward( - self, decoder_input_ids: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + def forward(self, decoder_input_ids: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: r""" Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + decoder_input_ids + (`torch.LongTensor` of shape `(total_num_tokens)`): - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + encoder_hidden_states: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. + Tensor of encoder output embeddings - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing - cross-attention on hidden heads. Mask values selected in `[0, 1]`: + kv_caches: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. + Layer-wise list of KV cache tensors - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + attn_metadata: - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + vLLM Attention metadata structure - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Returns: + Decoder output torch.Tensor """ input = decoder_input_ids @@ -1009,4 +932,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if shard_id: weight_loader(param, loaded_weight, shard_id) else: - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) From e61385d90e40b423e1e5d98839413a76ffcd11fb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 03:49:18 -0400 Subject: [PATCH 293/443] fixed bug caused by overzealous refactoring --- vllm/model_executor/models/bart.py | 120 ++++++++--------------------- 1 file changed, 31 insertions(+), 89 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index f034378fee9c4..47e7342fa54e3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -73,8 +73,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int): def forward(self, input_ids: torch.Tensor, attn_type: AttentionType, - attn_metadata: AttentionMetadata, - past_key_values_length: int = 0): + attn_metadata: AttentionMetadata) -> torch.Tensor: """`input_ids' shape is expected to be [bsz x seqlen].""" assert attn_type != AttentionType.ENCODER_DECODER @@ -121,7 +120,7 @@ def __init__(self, super().__init__(num_embeddings, embedding_dim, padding_idx) self.embed_scale = embed_scale - def forward(self, input_ids: torch.Tensor): + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: return super().forward(input_ids) * self.embed_scale @@ -161,15 +160,10 @@ def __init__( cache_config=cache_config, quant_config=quant_config) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) @@ -222,15 +216,10 @@ def __init__( cache_config=cache_config, quant_config=quant_config) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) @@ -283,18 +272,13 @@ def __init__( cache_config=cache_config, quant_config=quant_config) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - def forward( self, decoder_hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" q = self.q_proj(decoder_hidden_states) k=None if encoder_hidden_states is None else \ @@ -339,7 +323,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -419,8 +403,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: + ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -499,7 +482,6 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None): - #super().__init__(config) super().__init__() self.cache_config = cache_config @@ -528,12 +510,6 @@ def __init__(self, self.layernorm_embedding = nn.LayerNorm(embed_dim) - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - def forward(self, input_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata) -> torch.Tensor: r""" @@ -625,12 +601,6 @@ def __init__( self.layernorm_embedding = nn.LayerNorm(config.d_model) - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - def forward(self, decoder_input_ids: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor], kv_caches: List[torch.Tensor], @@ -714,19 +684,6 @@ def __init__(self, cache_config, quant_config=quant_config) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - def get_encoder(self): return self.encoder @@ -737,7 +694,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata - ) -> Union[Tuple, Seq2SeqModelOutput]: + ) -> torch.Tensor: encoder_hidden_states = None @@ -761,10 +718,6 @@ def forward( class BartForConditionalGeneration(nn.Module): base_model_prefix = "model" - _tied_weights_keys = [ - "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", - "lm_head.weight" - ] def __init__(self, config: BartConfig, @@ -789,42 +742,11 @@ def __init__(self, config.vocab_size) self.sampler = Sampler() - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - - def resize_token_embeddings( - self, - new_num_tokens: int, - pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings( - new_num_tokens, pad_to_multiple_of) - self._resize_final_logits_bias(new_embeddings.weight.shape[0]) - return new_embeddings - - def _resize_final_logits_bias(self, new_num_tokens: int) -> None: - old_num_tokens = self.final_logits_bias.shape[-1] - if new_num_tokens <= old_num_tokens: - new_bias = self.final_logits_bias[:, :new_num_tokens] - else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), - device=self.final_logits_bias.device) - new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) - self.register_buffer("final_logits_bias", new_bias) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> Union[Tuple, Seq2SeqLMOutput]: + attn_metadata: AttentionMetadata) -> torch.Tensor: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling @@ -835,6 +757,7 @@ def forward( tokens with labels in `[0, ..., config.vocab_size]`. Returns: + torch.Tensor inference result """ hidden_states = self.model(input_ids, positions, encoder_input_ids, encoder_positions, kv_caches, attn_metadata) @@ -854,6 +777,27 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + stacked_params_mapping = { + "query": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "key": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "value": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + def _rename_key(self, key: str): prefix = f"{self.base_model_prefix}." key = key[len(prefix):] if key.startswith(prefix) else key @@ -879,8 +823,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): top_params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # if 'shared.weight' in name: - # continue name = self._rename_key(name) name, shard_id = self._rename_stacked_param(name) From 41e31e861b01896a99fba2f2ea44b717164c4398 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 03:59:48 -0400 Subject: [PATCH 294/443] BART with new explanatory comments & passing formatting tests --- tests/conftest.py | 8 +- tests/kernels/test_attention_selector.py | 2 +- tests/kernels/test_encoder_decoder_attn.py | 4 +- tests/kernels/utils.py | 4 +- tests/models/test_bart.py | 3 +- .../test_encoder_decoder_model_runner.py | 2 +- vllm/model_executor/models/bart.py | 171 +++++++++--------- vllm/worker/enc_dec_model_runner.py | 7 +- 8 files changed, 105 insertions(+), 96 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 749d36141b4f8..bd6a10c14c77d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,9 +8,9 @@ import torch.nn as nn import torch.nn.functional as F from PIL import Image -from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, - AutoProcessor, AutoTokenizer, BatchEncoding, - AutoModelForSeq2SeqLM) +from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, + BatchEncoding) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig @@ -20,9 +20,9 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData +from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import cuda_device_count_stateless, is_cpu -from vllm.outputs import RequestOutput logger = init_logger(__name__) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index c27607912692b..f2d3cd59cc57d 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -5,7 +5,7 @@ from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use -from vllm.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL) +from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @pytest.mark.parametrize( diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e10a335cbe297..e9477d8016d21 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -9,18 +9,18 @@ """ +import copy from typing import NamedTuple, Optional import pytest import torch from tests.kernels.utils import * -import copy from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.utils import (is_hip, LIST_ENC_DEC_SUPPORTED_BACKENDS) +from vllm.utils import LIST_ENC_DEC_SUPPORTED_BACKENDS, is_hip HEAD_SIZES = [64, 256] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 30ae825c338d7..b8db8c485a16f 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import (make_tensor_with_pad, STR_BACKEND_ENV_VAR, - STR_XFORMERS_ATTN_VAL) +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad) class QKVInputs(NamedTuple): diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index f7ad5086abd58..001d1982777ee 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -4,10 +4,11 @@ """ import pytest -from .utils import check_logprobs_close from tests.kernels.utils import override_backend_env_variable from vllm.utils import STR_XFORMERS_ATTN_VAL +from .utils import check_logprobs_close + MODELS = ["facebook/bart-large-cnn"] # Backends under test diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 8456a2ec8e7bd..88b982bb8fdc2 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -3,11 +3,11 @@ import pytest import torch -from vllm.utils import STR_XFORMERS_ATTN_VAL from tests.kernels.utils import override_backend_env_variable from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import STR_XFORMERS_ATTN_VAL from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner # Backends under test diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 47e7342fa54e3..71abd579ad74c 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -16,35 +16,26 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch BART model.""" -from typing import Iterable, List, Optional, Tuple, Union - -from vllm.attention.backends.abstract import AttentionType +import math +from typing import Iterable, List, Optional, Tuple import torch from torch import nn +from transformers import BartConfig +from transformers.activations import ACT2FN +from transformers.utils import logging from vllm.attention import Attention, AttentionMetadata -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader - +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -import math - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - Seq2SeqLMOutput, - Seq2SeqModelOutput, -) -from transformers.utils import ( - logging, ) -from transformers import BartConfig - logger = logging.get_logger(__name__) @@ -70,9 +61,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, - input_ids: torch.Tensor, - attn_type: AttentionType, + def forward(self, input_ids: torch.Tensor, attn_type: AttentionType, attn_metadata: AttentionMetadata) -> torch.Tensor: """`input_ids' shape is expected to be [bsz x seqlen].""" @@ -160,10 +149,8 @@ def __init__( cache_config=cache_config, quant_config=quant_config) - def forward( - self, hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: """Input shape: Batch x Time x Channel""" q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) @@ -172,7 +159,7 @@ def forward( attn_output = self.attn(q, k, v, - kv_caches, + kv_cache, attn_metadata, attn_type=AttentionType.ENCODER) @@ -216,10 +203,8 @@ def __init__( cache_config=cache_config, quant_config=quant_config) - def forward( - self, hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: """Input shape: Batch x Time x Channel""" q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) @@ -228,7 +213,7 @@ def forward( attn_output = self.attn(q, k, v, - kv_caches, + kv_cache, attn_metadata, attn_type=AttentionType.DECODER) @@ -275,7 +260,7 @@ def __init__( def forward( self, decoder_hidden_states: torch.Tensor, - kv_caches: List[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -289,7 +274,7 @@ def forward( attn_output = self.attn(q, k, v, - kv_caches, + kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) @@ -320,24 +305,28 @@ def __init__( self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward( - self, hidden_states: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata - ) -> torch.Tensor: - """ + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + hidden_states + + torch.Tensor of *encoder* input embeddings. + + kv_cache: + + Layer-wise list of KV cache tensors + + attn_metadata: + + vLLM Attention metadata structure + + Returns: + Encoder layer output torch.Tensor """ residual = hidden_states hidden_states = self.self_attn(hidden_states=hidden_states, - kv_caches=kv_caches, + kv_caches=kv_cache, attn_metadata=attn_metadata) hidden_states = residual + hidden_states @@ -400,33 +389,36 @@ def __init__( def forward( self, decoder_hidden_states: torch.Tensor, - kv_caches: List[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ + r""" Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of - size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + decoder_hidden_states + + torch.Tensor of *decoder* input embeddings. + + kv_cache: + + KV cache tensor + + attn_metadata: + + vLLM Attention metadata structure + + encoder_hidden_states + + torch.Tensor of *encoder* input embeddings. + + Returns: + Decoder layer output torch.Tensor """ residual = decoder_hidden_states # Self Attention hidden_states = self.self_attn(hidden_states=decoder_hidden_states, - kv_caches=kv_caches, + kv_caches=kv_cache, attn_metadata=attn_metadata) hidden_states = residual + hidden_states @@ -438,7 +430,7 @@ def forward( hidden_states = self.encoder_attn( decoder_hidden_states=hidden_states, - kv_caches=kv_caches, + kv_caches=kv_cache, attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -690,11 +682,10 @@ def get_encoder(self): def get_decoder(self): return self.decoder - def forward( - self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata - ) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: encoder_hidden_states = None @@ -742,22 +733,38 @@ def __init__(self, config.vocab_size) self.sampler = Sampler() - def forward( - self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, - *optional*): Labels for computing the masked language modeling - loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). - Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the - tokens with labels in `[0, ..., config.vocab_size]`. + Args: + input_ids + + torch.Tensor of *decoder* input token ids. + + positions + + torch.Tensor of *decoder* position indices. + + encoder_input_ids + + torch.Tensor of *encoder* input token ids. + + encoder_positions + + torch.Tensor of *encoder* position indices + + kv_caches: + + Layer-wise list of KV cache tensors + + attn_metadata: + + vLLM Attention metadata structure Returns: - torch.Tensor inference result + Output torch.Tensor """ hidden_states = self.model(input_ids, positions, encoder_input_ids, encoder_positions, kv_caches, attn_metadata) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 0f88075ab56fb..488c06a4cd526 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,11 +16,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, +from vllm.utils import (LIST_ENC_DEC_SUPPORTED_BACKENDS, + STR_NOT_IMPL_ENC_DEC_BACKEND, STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_BACKEND, STR_NOT_IMPL_ENC_DEC_SWA, - LIST_ENC_DEC_SUPPORTED_BACKENDS, make_tensor_with_pad) + STR_NOT_IMPL_ENC_DEC_SWA, make_tensor_with_pad) from vllm.worker.model_runner import LORA_WARMUP_RANK, ModelInput, ModelRunner logger = init_logger(__name__) From ba4e2c12e6f1a03e3381cabda8902d55df9a292e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 04:05:23 -0400 Subject: [PATCH 295/443] Removed unnecessary position arguments from BART routine; formatting --- vllm/model_executor/models/bart.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 71abd579ad74c..cfc8d7d05a0b4 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -326,7 +326,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, """ residual = hidden_states hidden_states = self.self_attn(hidden_states=hidden_states, - kv_caches=kv_cache, + kv_cache=kv_cache, attn_metadata=attn_metadata) hidden_states = residual + hidden_states @@ -418,7 +418,7 @@ def forward( # Self Attention hidden_states = self.self_attn(hidden_states=decoder_hidden_states, - kv_caches=kv_cache, + kv_cache=kv_cache, attn_metadata=attn_metadata) hidden_states = residual + hidden_states @@ -430,7 +430,7 @@ def forward( hidden_states = self.encoder_attn( decoder_hidden_states=hidden_states, - kv_caches=kv_cache, + kv_cache=kv_cache, attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -540,7 +540,7 @@ def forward(self, input_ids: torch.Tensor, kv_caches: List[torch.Tensor], for idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( hidden_states=hidden_states, - kv_caches=kv_caches[idx], + kv_cache=kv_caches[idx], attn_metadata=attn_metadata, ) @@ -641,7 +641,7 @@ def forward(self, decoder_input_ids: torch.Tensor, for idx, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( decoder_hidden_states=hidden_states, - kv_caches=kv_caches[idx], + kv_cache=kv_caches[idx], attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) @@ -682,9 +682,8 @@ def get_encoder(self): def get_decoder(self): return self.decoder - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + def forward(self, input_ids: torch.Tensor, encoder_input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata) -> torch.Tensor: encoder_hidden_states = None @@ -766,8 +765,8 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Returns: Output torch.Tensor """ - hidden_states = self.model(input_ids, positions, encoder_input_ids, - encoder_positions, kv_caches, attn_metadata) + hidden_states = self.model(input_ids, encoder_input_ids, kv_caches, + attn_metadata) return hidden_states[0, :, :] def compute_logits(self, hidden_states: torch.Tensor, From 75756b91e3753a9c2a60dbae42b2e46d3612ece5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 27 Jun 2024 11:28:19 -0400 Subject: [PATCH 296/443] removed redundant elif --- vllm/attention/backends/xformers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 83a63b6b8bf23..b1daaefc9f3b5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -276,10 +276,9 @@ def _get_attn_bias(attn_metadata: XFormersMetadata, return attn_metadata.attn_bias elif attn_type == AttentionType.ENCODER: return attn_metadata.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return attn_metadata.cross_attn_bias else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") + # attn_type == AttentionType.ENCODER_DECODER + return attn_metadata.cross_attn_bias def _set_attn_bias(attn_metadata: XFormersMetadata, From a5018499e3b8475749a8d1af80e14c8d172cf2c7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 27 Jun 2024 18:57:56 -0400 Subject: [PATCH 297/443] reverted unnecessarily vllm/utils.py changes --- vllm/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 6cc4af98e5b9f..92abdb3fb9b14 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -579,6 +579,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() +def str_to_int_tuple(s: str) -> Tuple[int, ...]: + """Convert a string to a tuple of integers.""" + try: + return tuple(map(int, s.split(","))) + except ValueError as e: + raise ValueError( + "String must be a series of integers separated by commas " + f"(e.g., 1, 2, 3). Given input: {s}") from e + + def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -598,16 +608,6 @@ def make_tensor_with_pad( return torch.tensor(padded_x, dtype=dtype, device=device) -def str_to_int_tuple(s: str) -> Tuple[int, ...]: - """Convert a string to a tuple of integers.""" - try: - return tuple(map(int, s.split(","))) - except ValueError as e: - raise ValueError( - "String must be a series of integers separated by commas " - f"(e.g., 1, 2, 3). Given input: {s}") from e - - def async_tensor_h2d( data: list, dtype: torch.dtype, From 44c62708f3645f8a82b17a63849c1822a2dca645 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 3 Jul 2024 10:15:57 -0400 Subject: [PATCH 298/443] manually merged BART code in from previous modelrunner attempt, it won't work tho until new modelrunner is finished --- examples/offline_inference_encoder_decoder.py | 61 ++ tests/models/test_bart.py | 64 ++ .../test_encoder_decoder_model_runner.py | 368 ++++++++ vllm/model_executor/models/__init__.py | 11 +- vllm/model_executor/models/bart.py | 892 ++++++++++++++++++ 5 files changed, 1395 insertions(+), 1 deletion(-) create mode 100644 examples/offline_inference_encoder_decoder.py create mode 100644 tests/models/test_bart.py create mode 100644 tests/worker/test_encoder_decoder_model_runner.py create mode 100644 vllm/model_executor/models/bart.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py new file mode 100644 index 0000000000000..0426ec6e5a481 --- /dev/null +++ b/examples/offline_inference_encoder_decoder.py @@ -0,0 +1,61 @@ +from vllm import LLM, SamplingParams + +dtype = "float" + +# Sample prompts. +# - Encoder prompts +encoder_prompts = [ + "PG&E stated it scheduled the blackouts in " + "response to forecasts for high winds " + "amid dry conditions. The aim is to reduce " + "the risk of wildfires. Nearly 800 thousand customers were " + "scheduled to be affected by the shutoffs which " + "were expected to last through at least midday tomorrow.", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# - Decoder prompts +decoder_prompts = [ + "", + "", + "", + "", +] +# - Unified prompts +prompts = [enc_dec for enc_dec in zip(encoder_prompts, decoder_prompts)] + +print(prompts) + +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0, top_p=1.0, min_tokens=0, max_tokens=20,) +#sampling_params = SamplingParams(temperature=0, top_p=1.0, use_beam_search=True, best_of=2, min_tokens=0, max_tokens=20,) + +# Create an LLM. +llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype = dtype) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +from transformers import AutoTokenizer, BartForConditionalGeneration + +model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") +tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") + +ARTICLE_TO_SUMMARIZE = ( + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +) +inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") + +# Generate Summary +#summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) +summary_ids = model.generate(inputs["input_ids"], min_length=0, max_length=20) +print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py new file mode 100644 index 0000000000000..8ba22eb4cae8a --- /dev/null +++ b/tests/models/test_bart.py @@ -0,0 +1,64 @@ +"""Compare the outputs of HF and vLLM for BART models using greedy sampling. + +Run `pytest tests/models/test_bart.py`. +""" +import pytest + +from tests.kernels.utils import override_backend_env_variable +from vllm.utils import STR_XFORMERS_ATTN_VAL + +from .utils import check_logprobs_close, check_logprobs_close_encoder_decoder + +MODELS = ["facebook/bart-base","facebook/bart-large-cnn"] + +# Backends under test +# +# Currently only XFormers is supported +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float","bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +def test_models( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + backend_name: str, + monkeypatch, +) -> None: + # TODO(sang): Sliding window should be tested separately. + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + with hf_runner(model, dtype=dtype, + is_encoder_decoder_model=True) as hf_model: + hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( + example_encoder_decoder_prompts, max_tokens, num_logprobs) + + decoder_input_ids_list = [hf_model.tokenizer(decoder_prompt, + return_tensors="pt").input_ids + for decoder_prompt in example_encoder_decoder_prompts[1]] + + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + example_encoder_decoder_prompts, max_tokens, num_logprobs) + + # print(hf_outputs) + # print("\n\n\n\n\n") + # print(vllm_outputs) + + check_logprobs_close_encoder_decoder( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + decoder_input_ids_list=decoder_input_ids_list, + name_0="hf", + name_1="vllm" + ) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py new file mode 100644 index 0000000000000..88b982bb8fdc2 --- /dev/null +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -0,0 +1,368 @@ +from typing import List + +import pytest +import torch + +from tests.kernels.utils import override_backend_env_variable +from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import STR_XFORMERS_ATTN_VAL +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner + +# Backends under test +# +# Currently only XFormers is supported +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] + +# CUDA graph scenarios to test +# +# Currently CUDA graph is not supported +ENFORCE_EAGER = [True] + + +def _create_model_runner(model: str, *args, + **kwargs) -> EncoderDecoderModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = EncoderDecoderModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + model_runner = _create_model_runner("facebook/bart-base", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for seq_len in seq_lens: + expected_selected_token_indices.append(selected_token_start_idx + + seq_len - 1) + selected_token_start_idx += seq_len + + # Decoder model input + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + + # Encoder model input + encoder_model_input = model_runner._prepare_encoder_model_input( + seq_group_metadata_list, attn_metadata) + encoder_input_tokens = encoder_model_input.input_tokens + encoder_input_positions = encoder_model_input.input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify input metadata is correct for prompts. + # - Decoder attention metadata + device = model_runner.device + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.allclose( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) + + # Test decoder subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test decoder seq start locs. Note that for normal prefill it is + # equivalent to query_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], + dtype=torch.int, + device=device)) + + # Verify block tables are correct for prompts + # - Decoder self-attention + expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(attn_metadata.block_tables, expected) + # - Encoder/decoder cross-attention + # expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + # dtype=torch.int32, + # device=model_runner.device) + assert torch.allclose(attn_metadata.cross_block_tables, expected) + + # Cuda graph should not be used for prefill, regardless of + # enforce_eager setting + assert attn_metadata.use_cuda_graph is False + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) + torch.testing.assert_close(input_tokens, input_positions) + # - Encoder + assert len(encoder_input_tokens) == sum(encoder_seq_lens) + assert len(encoder_input_tokens) == sum(encoder_seq_lens) + torch.testing.assert_close(encoder_input_tokens, encoder_input_positions) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + torch.allclose(input_tokens, input_positions) + + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + model_runner = _create_model_runner("facebook/bart-base", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + # Decoder model input + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + + # Encoder model input + encoder_model_input = model_runner._prepare_encoder_model_input( + seq_group_metadata_list, attn_metadata) + encoder_input_tokens = encoder_model_input.input_tokens + encoder_input_positions = encoder_model_input.input_positions + return_encoder_seq_lens = attn_metadata.encoder_seq_lens + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert return_encoder_seq_lens == encoder_seq_lens + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify input metadata is correct for decode phase. + # - Decoder attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.allclose( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) + + # Test decoder subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + # 1 decoded token per sequence + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test decoder seq start locs. Note that for normal prefill it is + # equivalent to query_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor([seq_len - 1 for seq_len in seq_lens], + dtype=torch.int, + device=device)) + + # Verify block tables are correct for prompts + # - Decoder self-attention + expected = torch.tensor( + [block_tables[0] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(attn_metadata.block_tables, expected) + # - Encoder/decoder cross-attention + expected = torch.tensor( + [cross_block_table for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(attn_metadata.cross_block_tables, expected) + + # Cuda graph should not be used for prefill. + assert attn_metadata.use_cuda_graph == (not enforce_eager) + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(seq_lens) + assert len(input_positions) == len(seq_lens) + torch.testing.assert_close(input_tokens, input_positions) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_positions) == 0 + + +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): + """Verify prepare prompt and decode returns empty output.""" + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + enforce_eager=enforce_eager, + ) + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_seq_lens) == 0 diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a4fe18d52d608..0a0272f12fec9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -70,7 +70,16 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } -_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_CONDITIONAL_GENERATION_MODELS +} # Architecture -> type. # out of tree models diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000000000..332328302b91e --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,892 @@ +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model.""" +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import BartConfig +from transformers.activations import ACT2FN +from transformers.utils import logging + +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.backends.abstract import AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + +logger = logging.get_logger(__name__) + + +def get_bsz_seq_len(input_ids): + shp = input_ids.shape + ndim = len(shp) + if ndim == 1: + return 1, input_ids.numel() + else: + return shp[:2] + + +class BartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is + # specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. + # Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, attn_type: AttentionType, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """`input_ids' shape is expected to be [bsz x seqlen].""" + + assert attn_type != AttentionType.ENCODER_DECODER + + bsz, seq_len = get_bsz_seq_len(input_ids) + # afeldman-nm: This BART implementation is designed for vLLM, which + # packs variable-length sequences into a single vector + # without padding + assert bsz == 1 + + if attn_type == AttentionType.ENCODER: + seq_lens = attn_metadata.encoder_seq_lens + past_key_values_lens = [0]*len(seq_lens) + else: + # AttentionType.DECODER + if attn_metadata.num_prefill_tokens > 0: + # Prefill + seq_lens = attn_metadata.seq_lens + past_key_values_lens = [0]*len(seq_lens) + else: + # Decode: infer one (1) new token per sequence + seq_lens = [1] * len(attn_metadata.seq_lens) + past_key_values_lens = [seq_len-1 for seq_len in attn_metadata.seq_lens] + + positions = [] + for past_key_values_len,seq_len in zip(past_key_values_lens,seq_lens): + positions.extend(list(range(past_key_values_len,past_key_values_len+seq_len))) + + positions = torch.tensor(positions, + dtype=torch.long, + device=self.weight.device).expand(bsz, -1) + + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' + forward by multiplying with embeddings scale. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale + + +class BartEncoderAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = self.num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER) + + output = self.out_proj(attn_output) + return output + + +class BartDecoderSelfAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = self.num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.DECODER) + + output = self.out_proj(attn_output) + return output + + +class BartCrossAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = self.num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + q = self.q_proj(decoder_hidden_states) + k=None if encoder_hidden_states is None else \ + self.k_proj(encoder_hidden_states) + v=None if encoder_hidden_states is None else \ + self.v_proj(encoder_hidden_states) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + + output = self.out_proj(attn_output) + return output + + +class BartEncoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = ACT2FN[config.activation_function] + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. + kv_cache: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Encoder layer output torch.Tensor + """ + residual = hidden_states + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + + hidden_states = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + return hidden_states + + +class BartDecoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartDecoderSelfAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config) + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + ''' + afeldman-nm: personally I would call this "cross-attention", + however I left the name as "encoder_attn" to maintain consistency + with the name of the pretrained weights. + ''' + self.encoder_attn = BartCrossAttention( + self.embed_dim, + config.decoder_attention_heads, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_hidden_states + torch.Tensor of *decoder* input embeddings. + kv_cache: + KV cache tensor + attn_metadata: + vLLM Attention metadata structure + encoder_hidden_states + torch.Tensor of *encoder* input embeddings. + Returns: + Decoder layer output torch.Tensor + """ + residual = decoder_hidden_states + + # Self Attention + hidden_states = self.self_attn(hidden_states=decoder_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + + residual = hidden_states + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + + hidden_states = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + # if hidden_states.dtype == torch.float16 and ( + # torch.isinf(hidden_states).any() + # or torch.isnan(hidden_states).any()): + # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + # hidden_states = torch.clamp(hidden_states, + # min=-clamp_value, + # max=clamp_value) + + return hidden_states + + +class BartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__() + + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + self.padding_idx, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList( + [BartEncoderLayer(config,cache_config,quant_config) \ + for _ in range(config.encoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + def forward(self, input_ids: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + input_ids + (`torch.LongTensor` of shape `(total_num_tokens)`): + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input, AttentionType.ENCODER, + attn_metadata) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + + return hidden_states + + +class BartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + self.padding_idx, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [BartDecoderLayer(config,cache_config,quant_config) \ + for _ in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + def forward(self, decoder_input_ids: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + decoder_input_ids + (`torch.LongTensor` of shape `(total_num_tokens)`): + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + encoder_hidden_states: + Tensor of encoder output embeddings + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Decoder output torch.Tensor + """ + + input = decoder_input_ids + input_shape = input.shape + decoder_input_ids = decoder_input_ids.view(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input) + + # embed positions + embed_pos = self.embed_positions(input, AttentionType.DECODER, + attn_metadata) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + return hidden_states + + +class BartModel(nn.Module): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + super().__init__() + + self.config = config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = BartEncoder(config, + cache_config, + quant_config=quant_config) + self.decoder = BartDecoder(config, + cache_config, + quant_config=quant_config) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward(self, input_ids: torch.Tensor, encoder_input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + return decoder_outputs + + +class BartForConditionalGeneration(nn.Module): + base_model_prefix = "model" + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + + super().__init__() + self.config = config + self.model = BartModel(config, + cache_config, + quant_config, + lora_config=lora_config) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Output torch.Tensor + """ + hidden_states = self.model(input_ids, encoder_input_ids, kv_caches, + attn_metadata) + return hidden_states[0, :, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + stacked_params_mapping = { + "query": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "key": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "value": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + model_params_dict = dict(self.model.named_parameters()) + top_params_dict = dict(self.named_parameters()) + + weights_tuple_list = list(weights) + weight_names = [w[0] for w in weights_tuple_list] + + #has_shared_weight = any(['shared.weight' in wn for wn in weight_names]) + #has_encoder_embed_tokens_weight = any(['encoder.embed_tokens.weight' in wn for wn in weight_names]) + #has_decoder_embed_tokens_weight = any(['decoder.embed_tokens.weight' in wn for wn in weight_names]) + + shared_embedding_weight = None + shared_embedding_shard_id = None + + for name, loaded_weight in weights_tuple_list: + + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + if 'shared.weight' in name or \ + 'encoder.embed_tokens.weight' in name \ + or 'decoder.embed_tokens.weight' in name \ + or 'lm_head.weight' in name: + assert shared_embedding_weight is None, "Conflicting embedding weights." + shared_embedding_weight = loaded_weight + shared_embedding_shard_id = shard_id + + # encoder_in_param = model_params_dict[ + # 'encoder.embed_tokens.weight'] + # encoder_in_weight_loader = getattr(encoder_in_param, + # "weight_loader", + # default_weight_loader) + + # decoder_in_param = model_params_dict[ + # 'decoder.embed_tokens.weight'] + # decoder_in_weight_loader = getattr(decoder_in_param, + # "weight_loader", + # default_weight_loader) + + # lm_head_in_param = top_params_dict['lm_head.weight'] + # lm_head_in_weight_loader = getattr(lm_head_in_param, + # "weight_loader", + # default_weight_loader) + + # if shard_id: + # encoder_in_weight_loader(encoder_in_param, loaded_weight, + # shard_id) + # decoder_in_weight_loader(decoder_in_param, loaded_weight, + # shard_id) + # lm_head_in_weight_loader(lm_head_in_param, loaded_weight, + # shard_id) + # else: + # encoder_in_weight_loader(encoder_in_param, loaded_weight) + # decoder_in_weight_loader(decoder_in_param, loaded_weight) + # lm_head_in_weight_loader(lm_head_in_param, loaded_weight) + else: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + # Assign shared weight values + encoder_in_param = model_params_dict[ + 'encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, + "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict[ + 'decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, + "weight_loader", + default_weight_loader) + + lm_head_in_param = top_params_dict['lm_head.weight'] + lm_head_in_weight_loader = getattr(lm_head_in_param, + "weight_loader", + default_weight_loader) + + assert shared_embedding_weight is not None + + if shared_embedding_shard_id: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, + shared_embedding_shard_id) + else: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) \ No newline at end of file From ba09fbcd6b7efff359b1a0cef47c385d130b777d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 3 Jul 2024 11:32:18 -0400 Subject: [PATCH 299/443] refactored where a number of constants are stored, primarily constants related to encoder/decoder --- tests/kernels/test_attention_selector.py | 5 +- tests/kernels/test_encoder_decoder_attn.py | 5 +- tests/kernels/utils.py | 17 +------ vllm/core/block/utils.py | 12 +---- vllm/utils.py | 58 ++++++++++++++++++++++ 5 files changed, 67 insertions(+), 30 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index d9404e6442616..d15c5f6c91154 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,10 +3,9 @@ import pytest import torch -from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, - override_backend_env_variable) +from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use - +from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"]) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 654cd621145c5..9e55091545f63 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -9,6 +9,7 @@ """ +import copy from typing import NamedTuple, Optional import pytest @@ -19,7 +20,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.utils import is_hip +from vllm.utils import LIST_ENC_DEC_SUPPORTED_BACKENDS, is_hip HEAD_SIZES = [64, 256] @@ -27,7 +28,7 @@ BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] +BACKEND_NAMES = LIST_ENC_DEC_SUPPORTED_BACKENDS CUDA_DEVICE = "cuda:0" MAX_DEC_SEQ_LENS = [128] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index f0b0dd5dbaee6..b8db8c485a16f 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -11,21 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import make_tensor_with_pad - -# String name of register which may be set in order to -# force auto-selection of attention backend by Attention -# wrapper -STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" - -# Possible string values of STR_BACKEND_ENV_VAR -# register, corresponding to possible backends -STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" -STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" -STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" -STR_XFORMERS_ATTN_VAL: str = "XFORMERS" -STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" -STR_INVALID_VAL: str = "INVALID" +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad) class QKVInputs(NamedTuple): diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 2c412a8f472e0..28839437c33c5 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,15 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup - -# Exception strings for non-implemented block manager enc/dec scenarios - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) def _get_block_mgr_sliding_window_attr(block_mgr): diff --git a/vllm/utils.py b/vllm/utils.py index 854decc290fae..514814b193b73 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,6 +31,45 @@ logger = init_logger(__name__) +# Exception strings for non-implemented encoder/decoder scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ + "Chunked prefill for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = \ + "Currently CUDAGraph is not supported for encoder/decoder models" + +STR_NOT_IMPL_ENC_DEC_BACKEND = \ + "This backend is currently unsupported for encoder/decoder models:" + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + +# List of support backends for encoder/decoder models +LIST_ENC_DEC_SUPPORTED_BACKENDS = [STR_XFORMERS_ATTN_VAL] + STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, @@ -922,3 +961,22 @@ def parse_args(self, args=None, namespace=None): processed_args.append(arg) return super().parse_args(processed_args, namespace) + +def is_encoder_decoder_model_config(model_config) -> bool: + ''' + Extract the HF encoder/decoder model flag from the ModelConfig instance. + Return False if model_config is None. + ''' + return False if model_config is None else \ + getattr(model_config.hf_config, + "is_encoder_decoder", + False) + + +def is_embedding_model_config(model_config) -> bool: + ''' + Extract the embedding model flag from the ModelConfig instance. + Return False if model_config is None. + ''' + return False if model_config is None else \ + model_config.embedding_mode \ No newline at end of file From 5dbebbc6f3aafe706a5555119fefa519b71c4634 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Mon, 8 Jul 2024 09:32:43 -0400 Subject: [PATCH 300/443] Update vllm/attention/backends/torch_sdpa.py nit: This will reduce the number of line changes and make the code look better. Co-authored-by: Woosuk Kwon --- vllm/attention/backends/torch_sdpa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index eeef24ed4fb33..c2fefe5342362 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -145,7 +145,8 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: From 07df0e158a60b7d2a90407eecc868eaa10a58180 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Mon, 8 Jul 2024 09:33:03 -0400 Subject: [PATCH 301/443] Update vllm/attention/layer.py Co-authored-by: Woosuk Kwon --- vllm/attention/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5eee5914d3642..ae2607cf71dea 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -89,8 +89,8 @@ def forward(self, value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, - attn_type: AttentionType = AttentionType.DECODER) \ - -> torch.Tensor: + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: return self.impl.forward(query, key, From 7ce9a51d4fb3e286fdaa3a3ba12e60d0908d2d64 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 09:38:03 -0400 Subject: [PATCH 302/443] merged in first pieces of woosuk feedback & latest main; formatting --- vllm/attention/backends/torch_sdpa.py | 18 +++++++++--------- vllm/attention/layer.py | 17 +++++++++-------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c2fefe5342362..197981e47e921 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -138,15 +138,15 @@ def __init__( "Please use xFormers backend instead.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ae2607cf71dea..b8cc87be8c748 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -83,14 +83,15 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params) - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, - attn_type: AttentionType = AttentionType.DECODER, - ) -> torch.Tensor: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: return self.impl.forward(query, key, From 9ae6728ecfe48769f578b0fad3f8e3950daa683d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 09:46:58 -0400 Subject: [PATCH 303/443] fixed specific point-changes requested by woosuk --- vllm/attention/backends/torch_sdpa.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 197981e47e921..6b9b2c6f4b5a4 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -160,9 +160,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + + "encoder/decoder cross-attention " + + "are not implemented for " + "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. From a1bf65212cab0933b2520d8557a9d9132fff8c3d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:17:04 -0400 Subject: [PATCH 304/443] test_encoder_decoder_attn.py cleanup --- tests/kernels/test_encoder_decoder_attn.py | 315 +++++++++++---------- vllm/attention/backends/torch_sdpa.py | 6 +- 2 files changed, 166 insertions(+), 155 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 654cd621145c5..f25e7d480b6b3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -99,7 +99,7 @@ class TestResources(NamedTuple): kv_cache: torch.Tensor -def _make_test_resources(test_pt: TestPoint) -> TestResources: +def _make_test_resources(test_pt: TestPoint, ) -> TestResources: ''' Build key components for performing encoder/decoder attention test. @@ -146,8 +146,10 @@ class that Attention will automatically select when it is constructed. return TestResources(scale, attn_backend, attn, kv_cache) -def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ - -> PhaseTestParameters: +def _encoder_attn_setup( + test_pt: TestPoint, + test_rsrcs: TestResources, +) -> PhaseTestParameters: ''' Set up test vectors & data structures for encoder attention test. @@ -177,7 +179,16 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ implementation, and (3) KVCache field set to None ''' - (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt + ( + num_heads, + head_size, + _, + batch_size, + _, + _, + max_q_seq_len, + _, + ) = test_pt scale = test_rsrcs.scale @@ -210,12 +221,9 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) return PhaseTestParameters( - PackedQKVO( - packed_qkv, \ - packed_ideal_output), - - None # No KV cache - ) + PackedQKVO(packed_qkv, packed_ideal_output), + None # No KV cache + ) def _decoder_attn_setup( @@ -279,8 +287,16 @@ def _decoder_attn_setup( constructed in this function) ''' - (num_heads, head_size, _, batch_size, block_size, max_q_seq_len, _, - _) = test_pt + ( + num_heads, + head_size, + _, + batch_size, + block_size, + max_q_seq_len, + _, + _, + ) = test_pt scale = test_rsrcs.scale @@ -288,15 +304,17 @@ def _decoder_attn_setup( # Build test tensors - qkv, \ - prefill_qkv, \ - decode_qkv = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) + ( + qkv, + prefill_qkv, + decode_qkv, + ) = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) # Compute correct answer using naive attention implementation # with causal attention mask @@ -351,49 +369,45 @@ def _decoder_attn_setup( prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) - decode_block_tables, \ - slot_mapping_list, \ - max_block_idx = make_block_tables_slot_mapping(block_size, - qkv.q_seq_lens, - device=CUDA_DEVICE, - block_base_addr = block_base_addr) - - prefill_slot_mapping, \ - decode_slot_mapping = split_slot_mapping(slot_mapping_list, - qkv.q_seq_lens, - device=CUDA_DEVICE) + ( + decode_block_tables, + slot_mapping_list, + max_block_idx, + ) = make_block_tables_slot_mapping(block_size, + qkv.q_seq_lens, + device=CUDA_DEVICE, + block_base_addr=block_base_addr) + + ( + prefill_slot_mapping, + decode_slot_mapping, + ) = split_slot_mapping(slot_mapping_list, + qkv.q_seq_lens, + device=CUDA_DEVICE) prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) - return qkv, \ - PhaseTestParameters( # Prefill test params - PackedQKVO( - prefill_pckd_qkv, \ - prefill_packed_ideal_output), \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping)), \ - PhaseTestParameters( # Decode test params - PackedQKVO( - decode_pckd_qkv, \ - decode_packed_ideal_output), \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping)), \ - max_block_idx - -def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, - encoder_test_params: - PhaseTestParameters, - prefill_decoder_phase_test_params: - PhaseTestParameters, - test_pt: TestPoint, - test_rsrcs: TestResources, - block_base_addr: int=0) \ - -> Tuple[PhaseTestParameters, - PhaseTestParameters]: + return ( + qkv, + PhaseTestParameters( # Prefill test params + PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + PhaseTestParameters( # Decode test params + PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output), + KVMemoryMap(decode_block_tables, decode_slot_mapping)), + max_block_idx) + + +def _enc_dec_cross_attn_setup_reuses_query( + decoder_qkv: QKVInputs, + encoder_test_params: PhaseTestParameters, + prefill_decoder_phase_test_params: PhaseTestParameters, + test_pt: TestPoint, + test_rsrcs: TestResources, + block_base_addr: int = 0, +) -> Tuple[PhaseTestParameters, PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. @@ -460,22 +474,32 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, assert encoder_test_params.packed_qkvo.packed_qkv is not None assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None - (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, - max_encoder_seq_len, _) = test_pt + ( + num_heads, + head_size, + _, + batch_size, + block_size, + max_decoder_seq_len, + max_encoder_seq_len, + _, + ) = test_pt scale = test_rsrcs.scale decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - prefill_q_seq_lens = \ - prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = ( + prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens) assert prefill_q_seq_lens is not None - cross_kv, \ - _, \ - _ = make_qkv(batch_size, + ( + cross_kv, + _, + _, + ) = make_qkv(batch_size, max_decoder_seq_len, max_encoder_seq_len, num_heads, @@ -537,13 +561,14 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) - decode_block_tables, \ - prefill_slot_mapping_list, \ - _ = make_block_tables_slot_mapping( - block_size, - cross_kv.kv_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) + ( + decode_block_tables, + prefill_slot_mapping_list, + _, + ) = make_block_tables_slot_mapping(block_size, + cross_kv.kv_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, device=CUDA_DEVICE) @@ -551,26 +576,20 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, # Packed key/value (query is already provided) packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) - return PhaseTestParameters( # Prefill-phase test params - PackedQKVO( - packed_cross_kv, \ - prefill_packed_ideal_output), \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping)), \ - PhaseTestParameters( # Decode-phase test params - PackedQKVO( - None, - decode_packed_ideal_output), \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping)) - - -def _run_encoder_attention_test(attn: Attention, - encoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata) \ - -> torch.Tensor: + return ( + PhaseTestParameters( # Prefill-phase test params + PackedQKVO(packed_cross_kv, prefill_packed_ideal_output), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + PhaseTestParameters( # Decode-phase test params + PackedQKVO(None, decode_packed_ideal_output), + KVMemoryMap(decode_block_tables, decode_slot_mapping))) + + +def _run_encoder_attention_test( + attn: Attention, + encoder_test_params: PhaseTestParameters, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: ''' Run encoder attention. @@ -605,10 +624,11 @@ def _run_encoder_attention_test(attn: Attention, attn_type=attn_type) -def _run_decoder_self_attention_test(test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata) \ - -> torch.Tensor: +def _run_decoder_self_attention_test( + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: ''' Run decoder self-attention test. @@ -644,9 +664,11 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, def _run_encoder_decoder_cross_attention_test( - test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - cross_test_params: Optional[PhaseTestParameters], - attn_metadata: AttentionMetadata) -> torch.Tensor: + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + cross_test_params: Optional[PhaseTestParameters], + attn_metadata: AttentionMetadata, +) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -689,10 +711,8 @@ def _run_encoder_decoder_cross_attention_test( value = None else: cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.key - value = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.value + key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) + value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, @@ -744,11 +764,8 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, # PREFILL: encoder attention - enc_pckd_act_out: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) + enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( + test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) @@ -762,10 +779,16 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, - max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_e2e_enc_dec_attn( + num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch, +) -> None: ''' End-to-end encoder/decoder test: @@ -840,23 +863,26 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # decoder self-attention block-table, i.e. a base address which the # encoder/decoder cross-attention block-table may build downward toward. - dec_qkv, \ - prephase_dec_test_params, \ - decphase_dec_test_params, \ - cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) + ( + dec_qkv, + prephase_dec_test_params, + decphase_dec_test_params, + cross_block_base_addr, + ) = _decoder_attn_setup(test_pt, test_rsrcs) # Construct encoder/decoder cross-attention prefill-phase & decode-phase # test params, including key/value tensors, cross-attention memory-mapping - prephase_cross_test_params, \ - decphase_cross_test_params, \ - = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr = \ - cross_block_base_addr) + ( + prephase_cross_test_params, + decphase_cross_test_params, + ) = _enc_dec_cross_attn_setup_reuses_query( + dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr=cross_block_base_addr) # Shared prefill metadata structure assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None @@ -871,22 +897,17 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # PREFILL: encoder attention - enc_pckd_act_out: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) + enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) # PREFILL: decoder self-attention test - prephase_dec_pckd_act_out: torch.Tensor = \ - _run_decoder_self_attention_test( - test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata) + prephase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, @@ -894,11 +915,8 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # PREFILL: encoder/decoder cross-attention test - prephase_cross_pckd_act_out: torch.Tensor = \ - _run_encoder_decoder_cross_attention_test( - test_rsrcs, - prephase_dec_test_params, - prephase_cross_test_params, + prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, prephase_attn_metadata) # - Is prefill encoder/decoder cross-attention correct? @@ -918,11 +936,8 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # DECODE: decoder self-attention test - decphase_dec_pckd_act_out: torch.Tensor = \ - _run_decoder_self_attention_test( - test_rsrcs, - decphase_dec_test_params, - decphase_attn_metadata) + decphase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, @@ -930,12 +945,8 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # DECODE: encoder/decoder cross-attention test - decphase_cross_pckd_act_out: torch.Tensor = \ - _run_encoder_decoder_cross_attention_test( - test_rsrcs, - decphase_dec_test_params, - None, - decphase_attn_metadata) + decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 6b9b2c6f4b5a4..48418f24870f9 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -160,9 +160,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + - "encoder/decoder cross-attention " + - "are not implemented for " + + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. From 4f27946dcfb73f0a60420eb3ca6c9a74f6c6d3d1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:27:35 -0400 Subject: [PATCH 305/443] tests/kernels/utils.py cleanup --- tests/kernels/utils.py | 173 +++++++++++++++++++++-------------------- 1 file changed, 87 insertions(+), 86 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index f0b0dd5dbaee6..23d627820d247 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -138,9 +138,11 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] -def maybe_make_int_tensor(_list: Optional[List[int]], - device: Union[torch.device, str]) \ - -> torch.Tensor: + +def maybe_make_int_tensor( + _list: Optional[List[int]], + device: Union[torch.device, str], +) -> torch.Tensor: ''' Convert Python int list to a 1D int torch.Tensor on `device` @@ -152,9 +154,11 @@ def maybe_make_int_tensor(_list: Optional[List[int]], return None if _list is None else torch.tensor( _list, dtype=torch.int, device=device) -def maybe_make_long_tensor(_list: Optional[List[int]], - device: Union[torch.device, str]) \ - -> torch.Tensor: + +def maybe_make_long_tensor( + _list: Optional[List[int]], + device: Union[torch.device, str], +) -> torch.Tensor: ''' Convert Python int list to a 1D long torch.Tensor on `device` @@ -176,8 +180,11 @@ def maybe_max(_list: Optional[List]) -> Optional[Number]: ''' return None if _list is None else max(_list) -def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ - -> torch.Tensor: + +def make_causal_mask( + q_max_seq_len: int, + kv_max_seq_len: int, +) -> torch.Tensor: ''' Create a q_max_seq_len x kv_max_seq_len causal mask @@ -394,22 +401,25 @@ def make_qkv( decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - return QKVInputs(query, # Overall QKV inputs - key, - value, - q_seq_lens, - kv_seq_lens), \ - QKVInputs(prefill_query, # Prefill subset of QKV sequences - prefill_key, - prefill_value, - prefill_q_seq_lens, - prefill_kv_seq_lens), \ - QKVInputs( - decode_query, # Decode subset of KV sequences - decode_key, - decode_value, - decode_q_seq_lens, - decode_kv_seq_lens) + return ( + QKVInputs( + query, # Overall QKV inputs + key, + value, + q_seq_lens, + kv_seq_lens), + QKVInputs( + prefill_query, # Prefill subset of QKV sequences + prefill_key, + prefill_value, + prefill_q_seq_lens, + prefill_kv_seq_lens), + QKVInputs( + decode_query, # Decode subset of KV sequences + decode_key, + decode_value, + decode_q_seq_lens, + decode_kv_seq_lens)) def pack_tensor( @@ -481,14 +491,11 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, qkv.kv_seq_lens, device=device) packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) - return PackedQKVInputs(packed_query, \ - packed_key, \ - packed_value, \ - q_start_loc_list, \ - kv_start_loc_list, \ - None if q_start_loc_list is None else \ - qkv.q_seq_lens, \ - qkv.kv_seq_lens) + return PackedQKVInputs( + packed_query, packed_key, packed_value, q_start_loc_list, + kv_start_loc_list, + (None if q_start_loc_list is None else qkv.q_seq_lens), + qkv.kv_seq_lens) def make_backend(backend_name: str) -> AttentionBackend: @@ -547,18 +554,13 @@ def _make_metadata_tensors( max_seq_len = maybe_max(seq_lens) encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) - max_encoder_seq_len = None if encoder_seq_lens is None else \ - max(encoder_seq_lens) + max_encoder_seq_len = (None if encoder_seq_lens is None else + max(encoder_seq_lens)) seq_start_loc = None - return seq_lens_tensor, \ - context_lens_tensor, \ - max_context_len, \ - max_seq_len, \ - seq_start_loc, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len + return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, + seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -659,8 +661,8 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len - return maybe_make_long_tensor(prefill_slot_mapping, device), \ - maybe_make_long_tensor(decode_slot_mapping, device) + return (maybe_make_long_tensor(prefill_slot_mapping, device), + maybe_make_long_tensor(decode_slot_mapping, device)) def make_block_tables_slot_mapping( @@ -741,9 +743,7 @@ def make_block_tables_slot_mapping( device=device, ) - return block_tables_tensor, \ - slot_mapping_list, \ - max_block_idx + return (block_tables_tensor, slot_mapping_list, max_block_idx) def make_test_metadata( @@ -797,8 +797,8 @@ def make_test_metadata( # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None - kv_mmap = None if decoder_test_params is None else \ - decoder_test_params.kv_mmap + kv_mmap = (None + if decoder_test_params is None else decoder_test_params.kv_mmap) # This function constructs metadata assuming no chunked prefill, # i.e. 100% prefill tokens or 100% decode tokens @@ -811,11 +811,10 @@ def make_test_metadata( # seq_lens is None signals encoder-only # scenario, in which case num_prefills_or_decodes and # num_prefill_or_decode_tokens are unused - num_prefills_or_decodes = None if seq_lens is None else \ - len(seq_lens) + num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens)) - num_prefill_or_decode_tokens = None if seq_lens is None else \ - (sum(seq_lens) if is_prompt else len(seq_lens)) + num_prefill_or_decode_tokens = (None if seq_lens is None else ( + sum(seq_lens) if is_prompt else len(seq_lens))) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -829,8 +828,8 @@ def make_test_metadata( # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - num_encoder_tokens = None if encoder_seq_lens is None else \ - (sum(encoder_seq_lens)) + num_encoder_tokens = (None if encoder_seq_lens is None else + (sum(encoder_seq_lens))) if cross_test_params is None: cross_kv_mmap = None @@ -847,21 +846,22 @@ def make_test_metadata( num_prefill_tokens = num_prefill_or_decode_tokens num_decode_tokens = 0 - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ( + seq_lens_tensor, + context_lens_tensor, + _, + _, + _, + encoder_seq_lens_tensor, + max_encoder_seq_len, + ) = _make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=None if kv_mmap is None else \ - kv_mmap.slot_mapping, + slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -869,17 +869,16 @@ def make_test_metadata( max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, - block_tables=None if kv_mmap is None else \ - kv_mmap.block_tables, + block_tables=(None if kv_mmap is None else kv_mmap.block_tables), use_cuda_graph=False, num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=None if cross_kv_mmap is None else \ - cross_kv_mmap.slot_mapping, - cross_block_tables=None if cross_kv_mmap is None else \ - cross_kv_mmap.block_tables) + cross_slot_mapping=(None if cross_kv_mmap is None else + cross_kv_mmap.slot_mapping), + cross_block_tables=(None if cross_kv_mmap is None else + cross_kv_mmap.block_tables)) else: # not is_prompt # Decode-phase scenario @@ -892,16 +891,18 @@ def make_test_metadata( num_prefill_tokens = 0 num_decode_tokens = num_prefill_or_decode_tokens - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ( + seq_lens_tensor, + context_lens_tensor, + _, + _, + _, + encoder_seq_lens_tensor, + max_encoder_seq_len, + ) = _make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -919,10 +920,10 @@ def make_test_metadata( encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=None if cross_kv_mmap is None else \ - cross_kv_mmap.slot_mapping, - cross_block_tables=None if cross_kv_mmap is None else \ - cross_kv_mmap.block_tables) + cross_slot_mapping=(None if cross_kv_mmap is None else + cross_kv_mmap.slot_mapping), + cross_block_tables=(None if cross_kv_mmap is None else + cross_kv_mmap.block_tables)) def assert_actual_matches_ideal(test_params: PhaseTestParameters, From 5ee30fed1d27dbef98dc3e4512741c9ca301197c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:31:09 -0400 Subject: [PATCH 306/443] vllm/attention/backends/abstract.py cleanup --- vllm/attention/backends/abstract.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 8e386fd4e3ce8..adb8325168cdf 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -128,12 +128,13 @@ def __init__( @abstractmethod def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: T, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: raise NotImplementedError From 45fc9f71641bdd17c67997598463f12ead3998b2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:35:00 -0400 Subject: [PATCH 307/443] vllm/attention/backends/blocksparse_attn.py cleanup --- vllm/attention/backends/blocksparse_attn.py | 23 +++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 470b6339a3006..fe4c4a45dca0d 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -321,14 +321,15 @@ def __init__( ) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -341,9 +342,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "BlocksparseFlashAttentionImpl") num_tokens, hidden_size = query.shape From 097aff2029e4560ae28bd7a7acf0f20509f803fe Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:36:05 -0400 Subject: [PATCH 308/443] vllm/attention/backends/flash_attn.py cleanup --- vllm/attention/backends/flash_attn.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f9a04f63acbec..048abed48d2e9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -250,14 +250,15 @@ def __init__( f"Supported head sizes are: {support_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with FlashAttention. Args: @@ -270,9 +271,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "FlashAttentionImpl") # NOTE(woosuk): FlashAttention does not support FP8 KV cache. From d8a692b7dde0656696b726497030970aac0b53d3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:39:37 -0400 Subject: [PATCH 309/443] cleaning up a number of backends & backends utils.py --- vllm/attention/backends/flashinfer.py | 23 +++++++++++----------- vllm/attention/backends/ipex_attn.py | 23 +++++++++++----------- vllm/attention/backends/pallas.py | 23 +++++++++++----------- vllm/attention/backends/rocm_flash_attn.py | 23 +++++++++++----------- vllm/attention/backends/utils.py | 5 ++--- 5 files changed, 50 insertions(+), 47 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 615e427089865..b27e3e40f566d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -217,19 +217,20 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: FlashInferMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "FlashInferImpl") num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 2404ff68fd47f..6a1295b1000bc 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -150,14 +150,15 @@ def split_kv_cache( return key_cache, value_cache def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: IpexAttnMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. Args: @@ -171,9 +172,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index fbfba742fb643..7a6954ceb6d6a 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -125,14 +125,15 @@ def __init__( self.megacore_mode = "batch" def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], - attn_metadata: PallasMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], + attn_metadata: PallasMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with Pallas attention. Args: @@ -147,9 +148,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6107a3652b049..81b546c65c819 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -290,14 +290,15 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim)) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -310,9 +311,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "ROCmFlashAttentionImpl") num_tokens, hidden_size = query.shape diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 82a1f46db6e09..a3cfc6e20748b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -3,6 +3,5 @@ # Error string(s) for encoder/decoder # unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ -"ROCm/HIP is not currently supported" + \ -"with encoder/decoder models." +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") From 5df73fc708bf3370a5f6d7f85cce4772d5c679b5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:47:04 -0400 Subject: [PATCH 310/443] xformers backend cleanup --- vllm/attention/backends/xformers.py | 146 ++++++++++++++-------------- 1 file changed, 74 insertions(+), 72 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b1daaefc9f3b5..79aa8309bb225 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,9 +149,9 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return (self.encoder_seq_lens is not None) and \ - (self.encoder_seq_lens_tensor is not None) and \ - (self.max_encoder_seq_len is not None) + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) @property def is_all_cross_attn_metadata_set(self): @@ -160,9 +160,9 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. ''' - return self.is_all_encoder_attn_metadata_set and \ - (self.cross_slot_mapping is not None) and \ - (self.cross_block_tables is not None) + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -174,24 +174,24 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: # metadata structure return self._cached_prefill_metadata - assert (self.seq_lens is not None) or \ - (self.encoder_seq_lens is not None) - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) # Compute some attn_metadata fields which default to None - query_start_loc = None if self.query_start_loc is None \ - else self.query_start_loc[:self.num_prefills + 1] - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[:self.num_prefill_tokens] - seq_lens=None if self.seq_lens is None \ - else self.seq_lens[:self.num_prefills] - seq_lens_tensor=None if self.seq_lens_tensor is None else \ - self.seq_lens_tensor[:self.num_prefills] - context_lens_tensor=None if self.context_lens_tensor is None else \ - self.context_lens_tensor[:self.num_prefills] - block_tables=None if self.block_tables is None else \ - self.block_tables[:self.num_prefills] + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersMetadata( @@ -225,16 +225,16 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: # Recover cached decode-phase attention # metadata structure return self._cached_decode_metadata - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) # Compute some attn_metadata fields which default to None - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[self.num_prefill_tokens:] - seq_lens_tensor=None if self.seq_lens_tensor is None else \ - self.seq_lens_tensor[self.num_prefills:] - block_tables=None if self.block_tables is None else \ - self.block_tables[self.num_prefills:] + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersMetadata( @@ -255,9 +255,11 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata -def _get_attn_bias(attn_metadata: XFormersMetadata, - attn_type: AttentionType) -> \ - Optional[AttentionBias]: + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: AttentionType, +) -> Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -283,7 +285,8 @@ def _get_attn_bias(attn_metadata: XFormersMetadata, def _set_attn_bias(attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]], - attn_type: AttentionType) -> None: + attn_type: AttentionType, + ) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. @@ -306,10 +309,11 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, - is_prompt: bool, - attn_type: AttentionType) \ - -> tuple: +def _get_seq_len_block_table_args( + attn_metadata: XFormersMetadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent @@ -341,20 +345,18 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, max_seq_len = attn_metadata.max_prefill_seq_len else: max_seq_len = attn_metadata.max_decode_seq_len - return attn_metadata.seq_lens_tensor, \ - max_seq_len, \ - attn_metadata.block_tables + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables - return attn_metadata.encoder_seq_lens_tensor, \ - attn_metadata.max_encoder_seq_len, \ - attn_metadata.cross_block_tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention - return attn_metadata.encoder_seq_lens_tensor, \ - attn_metadata.max_encoder_seq_len, \ - None + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") @@ -418,14 +420,15 @@ def __init__( f"Supported head sizes are: {suppored_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor], - attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor], + attn_metadata: "XFormersMetadata", + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. For decoder-only models: query, key and value must be non-None. @@ -475,14 +478,14 @@ def forward( # Check that appropriate attention metadata attributes are # selected for the desired attention type - if attn_type == AttentionType.ENCODER and \ - (not attn_metadata.is_all_encoder_attn_metadata_set): - raise AttributeError("Encoder attention requires setting " + \ + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " "encoder metadata attributes.") - elif attn_type == AttentionType.ENCODER_DECODER and \ - (not attn_metadata.is_all_cross_attn_metadata_set): - raise AttributeError("Encoder/decoder cross-attention " + \ - "requires setting cross-attention " + \ + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " "metadata attributes.") query = query.view(-1, self.num_heads, self.head_size) @@ -497,8 +500,7 @@ def forward( # which KV cache memory-mapping & which # seqlen datastructures we utilize - if (attn_type != AttentionType.ENCODER and \ - kv_cache is not None): + if (attn_type != AttentionType.ENCODER and kv_cache is not None): # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -600,11 +602,11 @@ def forward( if decode_meta := attn_metadata.decode_metadata: - seq_lens_arg, \ - max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta, - False, - attn_type) + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -629,7 +631,8 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -668,8 +671,7 @@ def _run_memory_efficient_xformers_forward( attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: - if attn_type == \ - AttentionType.ENCODER_DECODER: + if (attn_type == AttentionType.ENCODER_DECODER): assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens is not None From 6cd595c3c879d4ee603bb6a5bc0f1724647a5135 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:47:20 -0400 Subject: [PATCH 311/443] formatting --- vllm/attention/backends/xformers.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 79aa8309bb225..6cc5f1d1477ae 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -283,10 +283,11 @@ def _get_attn_bias( return attn_metadata.cross_attn_bias -def _set_attn_bias(attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]], - attn_type: AttentionType, - ) -> None: +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: AttentionType, +) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. @@ -626,13 +627,13 @@ def forward( return output.view(-1, self.num_heads * self.head_size) def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. From bd14d29177dda7bd1f2ddd41ccba71703dbaa07d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 9 Jul 2024 16:17:24 -0400 Subject: [PATCH 312/443] wip scheduler --- vllm/core/scheduler.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9e626b2883975..c2a5be70342f0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -369,6 +369,17 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: seq.status = SequenceStatus.FINISHED_ABORTED self.free_seq(seq) + self._free_seq_group(aborted_group) + + def _free_seq_group(self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + self.block_manager.free_cross(seq_group) + def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 From 2c80185fb81602a9a39afe4137bc5f59bcb69f57 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 9 Jul 2024 16:36:11 -0400 Subject: [PATCH 313/443] formatting --- examples/offline_inference_encoder_decoder.py | 31 +-- tests/kernels/test_attention_selector.py | 1 + tests/kernels/test_encoder_decoder_attn.py | 1 - tests/kernels/utils.py | 178 ------------------ tests/models/test_bart.py | 25 +-- vllm/core/scheduler.py | 5 +- vllm/model_executor/models/bart.py | 62 +++--- vllm/utils.py | 3 +- 8 files changed, 59 insertions(+), 247 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 0426ec6e5a481..737221506dbd6 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -1,3 +1,5 @@ +from transformers import AutoTokenizer, BartForConditionalGeneration + from vllm import LLM, SamplingParams dtype = "float" @@ -28,11 +30,15 @@ print(prompts) # Create a sampling params object. -sampling_params = SamplingParams(temperature=0, top_p=1.0, min_tokens=0, max_tokens=20,) -#sampling_params = SamplingParams(temperature=0, top_p=1.0, use_beam_search=True, best_of=2, min_tokens=0, max_tokens=20,) +sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, +) # Create an LLM. -llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype = dtype) +llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -42,20 +48,17 @@ generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - -from transformers import AutoTokenizer, BartForConditionalGeneration - model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") -ARTICLE_TO_SUMMARIZE = ( - "PG&E stated it scheduled the blackouts in response to forecasts for high winds " - "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " - "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." -) -inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") +ARTICLE_TO_SUMMARIZE = encoder_prompts[0] +inputs = tokenizer([ARTICLE_TO_SUMMARIZE], + max_length=1024, + return_tensors="pt") # Generate Summary -#summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20) summary_ids = model.generate(inputs["input_ids"], min_length=0, max_length=20) -print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) +print( + tokenizer.batch_decode(summary_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False)) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index d15c5f6c91154..a20a741c27f74 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -7,6 +7,7 @@ from vllm.attention.selector import which_attn_to_use from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL + @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"]) @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"]) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 88c9c5978766e..e0880a051f834 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -9,7 +9,6 @@ """ -import copy from typing import NamedTuple, Optional import pytest diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 61cced6797055..f4dfbb977ab88 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -193,184 +193,6 @@ def make_causal_mask( return mask -class QKVInputs(NamedTuple): - ''' - Data structure for representing unpacked attention inputs, - query/key/values and their sequence lengths. - - Attributes: - - * {query,key,value}: unpacked (batch_size x padded_seq_len x - num_heads x head_size) attention inputs - * q_seq_lens: query sequence lengths list - * kv_seq_lens: shared key/value sequence lengths list - ''' - - query: torch.Tensor - key: torch.Tensor - value: torch.Tensor - q_seq_lens: List[int] - kv_seq_lens: List[int] - - -class QKVO(NamedTuple): - ''' - Data structure for representing unpacked attention inputs, - alongside unpacked known-correct attention output - - Attributes: - - * qkv: unpacked (batch_size x padded_seq_len x - num_heads x head_size) attention inputs - * ideal_output: unpacked (batch_size x padded_seq_len x - num_heads x head_size) known-correct attention output - ''' - - qkv: QKVInputs - ideal_output: torch.Tensor - - -class PackedQKVInputs(NamedTuple): - ''' - Data structure for representing packed attention inputs - - Attributes: - - * {query,key,value}: packed (number_of_tokens x num_heads - x head_size) attention inputs - * q_start_loc_list: list of query start locations within packed tensor - * kv_start_loc_list: shared list of key/value start locations within - packed tensor - * q_seq_lens: query sequence lengths list - * kv_seq_lens: shared key/value sequence lengths list - ''' - - query: torch.Tensor - key: torch.Tensor - value: torch.Tensor - q_start_loc_list: Optional[List[int]] - kv_start_loc_list: Optional[List[int]] - q_seq_lens: Optional[List[int]] - kv_seq_lens: Optional[List[int]] - - -class PackedQKVO(NamedTuple): - ''' - Data structure for representing packed attention inputs, - alongside packed known-correct attention output - - Attributes: - - * packed_qkv: packed (number_of_tokens x num_heads - x head_size) attention inputs - * ideal_output: packed (number_of_tokens x num_heads - x head_size) known-correct attention output - ''' - - packed_qkv: Optional[PackedQKVInputs] - ideal_output: torch.Tensor - - -class KVMemoryMap(NamedTuple): - ''' - Data structure for encapsulating KV cache memory mapping. - - Attributes: - - * block_tables: KV cache block tables - * slot_mapping: mapping of sequence offset to physical address - ''' - - block_tables: torch.Tensor - slot_mapping: torch.Tensor - - -class PhaseTestParameters(NamedTuple): - ''' - Data structure for encapsulating the test parameters - for a given test "phase" (prefill or decode phase) and attention - scenario (encoder, decoder-self, encoder/decoder-cross) - - Attributes: - - * packed_qkvo: packed (number_of_tokens x num_heads - x head_size) attention inputs & known-correct - output - * kv_mmap: KV cache memory mapping, specific to this test phase & - attention scenario - ''' - - packed_qkvo: PackedQKVO - kv_mmap: Optional[KVMemoryMap] - - -def maybe_make_int_tensor( - _list: Optional[List[int]], - device: Union[torch.device, str], -) -> torch.Tensor: - ''' - Convert Python int list to a 1D int torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D int torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) - - -def maybe_make_long_tensor( - _list: Optional[List[int]], - device: Union[torch.device, str], -) -> torch.Tensor: - ''' - Convert Python int list to a 1D long torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D long torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) - - -def maybe_max(_list: Optional[List]) -> Optional[Number]: - ''' - Returns: - - * If _list is not None: max(_list) - * None otherwise - ''' - return None if _list is None else max(_list) - - -def make_causal_mask( - q_max_seq_len: int, - kv_max_seq_len: int, -) -> torch.Tensor: - ''' - Create a q_max_seq_len x kv_max_seq_len causal mask - - Arguments: - - * q_max_seq_len: query max seq len - * kv_max_seq_len: key/value max seq len - - Returns: - - * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' - - # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) - # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) - return mask - - def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: ''' diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 8ba22eb4cae8a..2bf8c97131da3 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -7,9 +7,9 @@ from tests.kernels.utils import override_backend_env_variable from vllm.utils import STR_XFORMERS_ATTN_VAL -from .utils import check_logprobs_close, check_logprobs_close_encoder_decoder +from .utils import check_logprobs_close -MODELS = ["facebook/bart-base","facebook/bart-large-cnn"] +MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] # Backends under test # @@ -18,7 +18,7 @@ @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float","bfloat16"]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -43,22 +43,11 @@ def test_models( hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( example_encoder_decoder_prompts, max_tokens, num_logprobs) - decoder_input_ids_list = [hf_model.tokenizer(decoder_prompt, - return_tensors="pt").input_ids - for decoder_prompt in example_encoder_decoder_prompts[1]] - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( example_encoder_decoder_prompts, max_tokens, num_logprobs) - # print(hf_outputs) - # print("\n\n\n\n\n") - # print(vllm_outputs) - - check_logprobs_close_encoder_decoder( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - decoder_input_ids_list=decoder_input_ids_list, - name_0="hf", - name_1="vllm" - ) + check_logprobs_close(outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c2a5be70342f0..a677c013c33c6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -371,8 +371,9 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self._free_seq_group(aborted_group) - def _free_seq_group(self, - seq_group: SequenceGroup, + def _free_seq_group( + self, + seq_group: SequenceGroup, ) -> None: """ Free a sequence group from a cross-attention block table. diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 332328302b91e..8e892052dc396 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -75,21 +75,26 @@ def forward(self, input_ids: torch.Tensor, attn_type: AttentionType, if attn_type == AttentionType.ENCODER: seq_lens = attn_metadata.encoder_seq_lens - past_key_values_lens = [0]*len(seq_lens) + past_key_values_lens = [0] * len(seq_lens) else: # AttentionType.DECODER if attn_metadata.num_prefill_tokens > 0: # Prefill seq_lens = attn_metadata.seq_lens - past_key_values_lens = [0]*len(seq_lens) + past_key_values_lens = [0] * len(seq_lens) else: # Decode: infer one (1) new token per sequence seq_lens = [1] * len(attn_metadata.seq_lens) - past_key_values_lens = [seq_len-1 for seq_len in attn_metadata.seq_lens] + past_key_values_lens = [ + seq_len - 1 for seq_len in attn_metadata.seq_lens + ] positions = [] - for past_key_values_len,seq_len in zip(past_key_values_lens,seq_lens): - positions.extend(list(range(past_key_values_len,past_key_values_len+seq_len))) + for past_key_values_len, seq_len in zip(past_key_values_lens, + seq_lens): + positions.extend( + list(range(past_key_values_len, + past_key_values_len + seq_len))) positions = torch.tensor(positions, dtype=torch.long, @@ -790,11 +795,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): top_params_dict = dict(self.named_parameters()) weights_tuple_list = list(weights) - weight_names = [w[0] for w in weights_tuple_list] - - #has_shared_weight = any(['shared.weight' in wn for wn in weight_names]) - #has_encoder_embed_tokens_weight = any(['encoder.embed_tokens.weight' in wn for wn in weight_names]) - #has_decoder_embed_tokens_weight = any(['decoder.embed_tokens.weight' in wn for wn in weight_names]) shared_embedding_weight = None shared_embedding_shard_id = None @@ -804,11 +804,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = self._rename_key(name) name, shard_id = self._rename_stacked_param(name) - if 'shared.weight' in name or \ - 'encoder.embed_tokens.weight' in name \ - or 'decoder.embed_tokens.weight' in name \ - or 'lm_head.weight' in name: - assert shared_embedding_weight is None, "Conflicting embedding weights." + if ('shared.weight' in name + or 'encoder.embed_tokens.weight' in name + or 'decoder.embed_tokens.weight' in name + or 'lm_head.weight' in name): + assert shared_embedding_weight is None, ( + "Conflicting embedding weights.") shared_embedding_weight = loaded_weight shared_embedding_shard_id = shard_id @@ -860,33 +861,28 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # Assign shared weight values - encoder_in_param = model_params_dict[ - 'encoder.embed_tokens.weight'] - encoder_in_weight_loader = getattr(encoder_in_param, - "weight_loader", - default_weight_loader) - - decoder_in_param = model_params_dict[ - 'decoder.embed_tokens.weight'] - decoder_in_weight_loader = getattr(decoder_in_param, - "weight_loader", - default_weight_loader) + encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", + default_weight_loader) lm_head_in_param = top_params_dict['lm_head.weight'] - lm_head_in_weight_loader = getattr(lm_head_in_param, - "weight_loader", - default_weight_loader) + lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", + default_weight_loader) assert shared_embedding_weight is not None if shared_embedding_shard_id: encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, - shared_embedding_shard_id) + shared_embedding_shard_id) decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, - shared_embedding_shard_id) + shared_embedding_shard_id) lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, - shared_embedding_shard_id) + shared_embedding_shard_id) else: encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) - lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) \ No newline at end of file + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) diff --git a/vllm/utils.py b/vllm/utils.py index 92ee1b7149246..cb3e5745d6f1c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -983,6 +983,7 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) + def is_encoder_decoder_model_config(model_config) -> bool: ''' Extract the HF encoder/decoder model flag from the ModelConfig instance. @@ -1000,4 +1001,4 @@ def is_embedding_model_config(model_config) -> bool: Return False if model_config is None. ''' return False if model_config is None else \ - model_config.embedding_mode \ No newline at end of file + model_config.embedding_mode From c95adf50adcdc315f63b276f52ac9a6a2d35b5fa Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 9 Jul 2024 16:49:34 -0400 Subject: [PATCH 314/443] scheduler supports encoder-/cross-attention & passes existing scheduler tests, but needs new encoder/decoder-specific tests --- vllm/core/scheduler.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 624029f2c6fd8..cdd7418d36f58 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1002,6 +1002,18 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} + # Encoder associated with SequenceGroup + encoder_seq_data: SequenceData = \ + seq_group.get_encoder_seq().data \ + if seq_group.is_encoder_decoder() else \ + None + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table: List[int] = \ + self.block_manager.get_cross_block_table(seq_group) \ + if seq_group.is_encoder_decoder() else \ + None + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data @@ -1041,6 +1053,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, state=seq_group.state, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1070,10 +1084,13 @@ def free_seq(self, seq: Sequence) -> None: def free_finished_seq_groups(self) -> None: for queue in [self.running, self.swapped, self.waiting]: - self._finished_requests_ids += [ - seq_group.request_id for seq_group in queue - if seq_group.is_finished() - ] + new_finished_requests_ids = [] + for seq_group in queue: + if seq_group.is_finished(): + new_finished_requests_ids += seq_group.request_id + # Free cross-attention block table, kf it exists + self._free_seq_group(seq_group) + self._finished_requests_ids += new_finished_requests_ids self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) From d1343aac0fe6c0063f950e3600f9264aacb0836d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 9 Jul 2024 17:07:43 -0400 Subject: [PATCH 315/443] scheduler test passes --- tests/core/test_scheduler.py | 35 +++------- tests/core/test_scheduler_encoder_decoder.py | 72 ++++++++++++++++++++ tests/core/utils.py | 28 +++++++- 3 files changed, 107 insertions(+), 28 deletions(-) create mode 100644 tests/core/test_scheduler_encoder_decoder.py diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index bae958211cb7b..996844c12616b 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,33 +10,14 @@ from vllm.core.policy import PolicyFactory from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, SequenceGroup, SequenceStatus - -from .utils import create_dummy_prompt - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(out, token_id: int): - seq_groups = get_sequence_groups(out) - for seq_group in seq_groups: - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) - return metas, out - - -def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): - seq_group.update_num_computed_tokens(token_chunk_size) - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) +from vllm.sequence import SequenceGroup, SequenceStatus + +from .utils import (create_dummy_prompt, + get_sequence_groups, + append_new_token, + schedule_and_update_computed_tokens, + append_new_token_seq_group, + ) def test_scheduler_add_seq_group(): diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py new file mode 100644 index 0000000000000..36f47f2ae9264 --- /dev/null +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -0,0 +1,72 @@ +from typing import List + +import pytest # noqa + +from vllm.config import (CacheConfig, + SchedulerConfig, + ) +from vllm.core.scheduler import Scheduler +from vllm.sequence import SequenceGroup + +from .utils import (create_dummy_prompt_encoder_decoder, + get_sequence_groups, + append_new_token, + schedule_and_update_computed_tokens, + ) + +def test_scheduler_schedule_simple_encoder_decoder(): + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group + cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + req_id_list = [] + for i in range(num_seq_group): + # _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + req_id = str(i) + req_id_list.append(req_id) + _, _, seq_group = create_dummy_prompt_encoder_decoder( + req_id, block_size, block_size, block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group + seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) + # - Verify that sequence group cross-attention block tables are + # registered with the block manager + assert all([(req_id in scheduler.block_manager.cross_block_tables) + for req_id in req_id_list]) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_tokens + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta_list) == num_seq_group + append_new_token(out, 1) + + # Schedule seq groups generation. + seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) + # - Verify that sequence group metadata includes encoder attention + # and cross-attention metadata + assert all([not ((seq_group_meta.encoder_seq_data is None) or \ + (seq_group_meta.cross_block_table is None)) \ + for seq_group_meta in seq_group_meta_list]) + assert set(get_sequence_groups(out)) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta_list) == num_seq_group + append_new_token(out, 1) + + # Abort sequences + for req_id in req_id_list: + scheduler.abort_seq_group(req_id) + # - Verify that sequence group cross-attention block tables are + # NO LONGER registered with the block manager + assert req_id not in scheduler.block_manager.cross_block_tables \ No newline at end of file diff --git a/tests/core/utils.py b/tests/core/utils.py index f249f4b59a2ee..163c6c3dc4039 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -8,6 +8,7 @@ from vllm.sequence import Logprob, Sequence, SequenceGroup + def create_dummy_prompt( request_id: str, prompt_length: int, @@ -177,4 +178,29 @@ def create_seq_group_encoder_decoder( def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size \ No newline at end of file + return (seq_len + block_size - 1) // block_size + +# Helper functions for scheduler tests + +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + + +def append_new_token(out, token_id: int): + seq_groups = get_sequence_groups(out) + for seq_group in seq_groups: + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +def schedule_and_update_computed_tokens(scheduler): + metas, out = scheduler.schedule() + for s, meta in zip(out.scheduled_seq_groups, metas): + s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + return metas, out + + +def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): + seq_group.update_num_computed_tokens(token_chunk_size) + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) \ No newline at end of file From b4a461d983ed0215777c89f6b2ecbaa754422d4e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 9 Jul 2024 17:18:56 -0400 Subject: [PATCH 316/443] formatting --- tests/core/test_scheduler.py | 9 +++------ tests/core/test_scheduler_encoder_decoder.py | 15 +++++---------- tests/core/utils.py | 5 +++-- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 996844c12616b..642ece76a5d16 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -12,12 +12,9 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup, SequenceStatus -from .utils import (create_dummy_prompt, - get_sequence_groups, - append_new_token, - schedule_and_update_computed_tokens, - append_new_token_seq_group, - ) +from .utils import (append_new_token, append_new_token_seq_group, + create_dummy_prompt, get_sequence_groups, + schedule_and_update_computed_tokens) def test_scheduler_add_seq_group(): diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 36f47f2ae9264..c7d18424e2476 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -2,17 +2,13 @@ import pytest # noqa -from vllm.config import (CacheConfig, - SchedulerConfig, - ) +from vllm.config import CacheConfig, SchedulerConfig from vllm.core.scheduler import Scheduler from vllm.sequence import SequenceGroup -from .utils import (create_dummy_prompt_encoder_decoder, - get_sequence_groups, - append_new_token, - schedule_and_update_computed_tokens, - ) +from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, + get_sequence_groups, schedule_and_update_computed_tokens) + def test_scheduler_schedule_simple_encoder_decoder(): block_size = 4 @@ -28,7 +24,6 @@ def test_scheduler_schedule_simple_encoder_decoder(): # Add seq groups to scheduler. req_id_list = [] for i in range(num_seq_group): - # _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) req_id = str(i) req_id_list.append(req_id) _, _, seq_group = create_dummy_prompt_encoder_decoder( @@ -69,4 +64,4 @@ def test_scheduler_schedule_simple_encoder_decoder(): scheduler.abort_seq_group(req_id) # - Verify that sequence group cross-attention block tables are # NO LONGER registered with the block manager - assert req_id not in scheduler.block_manager.cross_block_tables \ No newline at end of file + assert req_id not in scheduler.block_manager.cross_block_tables diff --git a/tests/core/utils.py b/tests/core/utils.py index 163c6c3dc4039..a8dcd90af0fcf 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -8,7 +8,6 @@ from vllm.sequence import Logprob, Sequence, SequenceGroup - def create_dummy_prompt( request_id: str, prompt_length: int, @@ -180,8 +179,10 @@ def create_seq_group_encoder_decoder( def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size + # Helper functions for scheduler tests + def get_sequence_groups(scheduler_output): return [s.seq_group for s in scheduler_output.scheduled_seq_groups] @@ -203,4 +204,4 @@ def schedule_and_update_computed_tokens(scheduler): def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): seq_group.update_num_computed_tokens(token_chunk_size) for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) \ No newline at end of file + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) From 6a71f8f4359dab04b9811b84d338db40dafa72bc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 9 Jul 2024 17:23:01 -0400 Subject: [PATCH 317/443] formatting --- tests/core/test_scheduler_encoder_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index c7d18424e2476..24c2cfdf8c57d 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -49,8 +49,8 @@ def test_scheduler_schedule_simple_encoder_decoder(): seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) # - Verify that sequence group metadata includes encoder attention # and cross-attention metadata - assert all([not ((seq_group_meta.encoder_seq_data is None) or \ - (seq_group_meta.cross_block_table is None)) \ + assert all([not ((seq_group_meta.encoder_seq_data is None) or + (seq_group_meta.cross_block_table is None)) for seq_group_meta in seq_group_meta_list]) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group From 9a63f51bde8059fc361cc7abb2249ce1efb54163 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 10 Jul 2024 12:50:40 -0400 Subject: [PATCH 318/443] wip model runner --- vllm/worker/enc_dec_model_runner.py | 199 ++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 vllm/worker/enc_dec_model_runner.py diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py new file mode 100644 index 0000000000000..5571eb43e568f --- /dev/null +++ b/vllm/worker/enc_dec_model_runner.py @@ -0,0 +1,199 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) +from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU + +logger = init_logger(__name__) + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInput(ModelInputForGPU): + """ + Used by the EncoderDecoderModelRunner. + """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + +class EncoderDecoderModelRunner( + GPUModelRunnerBase[EncoderDecoderModelInput]): + _model_input_cls: Type[EncoderDecoderModelInput] = ( + EncoderDecoderModelInput) + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + multimodal_config=multimodal_config) + + @torch.inference_mode() + def execute_model( + self, + model_input: EncoderDecoderModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[PoolerOutput]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in ModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + if self.attn_backend.get_name() == "flashinfer": + raise NotImplementedError("FlashInfer is currently not supported " + "for encoder/decoder models.") + + # Currently cuda graph is not supported for encoder/decoder models + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + raise NotImplementedError("CUDAGraph is currently not supported " + "for encoder/decoder models.") + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. + # virtual_engine = model_input.virtual_engine + # if prefill_meta is None and decode_meta.use_cuda_graph: + # assert model_input.input_tokens is not None + # graph_batch_size = model_input.input_tokens.shape[0] + # model_executable = self.graph_runners[virtual_engine][ + # graph_batch_size] + # else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multi_modal_kwargs, + **seqlen_agnostic_kwargs) + + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + indices = model_input.sampling_metadata.selected_token_indices + if model_input.is_prompt: + hidden_states = hidden_or_intermediate_states.index_select( + 0, indices) + elif decode_meta.use_cuda_graph: + hidden_states = hidden_or_intermediate_states[:len(indices)] + else: + hidden_states = hidden_or_intermediate_states + + output.hidden_states = hidden_states + + return [output] + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForGPUWithPoolingMetadata: + return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Prepare PoolingMetadata. + assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata From 685604cfcb90b6e74e37dbf5b5ee478e157f8191 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 12 Jul 2024 09:40:42 -0400 Subject: [PATCH 319/443] wip modelrunner --- .../test_encoder_decoder_model_runner.py | 297 +++----- vllm/worker/enc_dec_model_runner.py | 675 +++++++++++++++++- 2 files changed, 743 insertions(+), 229 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 88b982bb8fdc2..d696cbfa57177 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -21,8 +21,7 @@ ENFORCE_EAGER = [True] -def _create_model_runner(model: str, *args, - **kwargs) -> EncoderDecoderModelRunner: +def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( @@ -33,10 +32,73 @@ def _create_model_runner(model: str, *args, cache_config=engine_config.cache_config, load_config=engine_config.load_config, lora_config=engine_config.lora_config, + prompt_adapter_config=engine_config.prompt_adapter_config, is_driver_worker=True, ) return model_runner +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): + """Verify prepare prompt and decode returns empty output.""" + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + model_runner = _create_model_runner("facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) + (input_tokens, + input_positions, + encoder_input_tokens, + encoder_input_positions, + attn_metadata, + ) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.encoder_input_tokens, + model_input.encoder_input_positions, + model_input.attn_metadata, + ) + assert input_tokens is None + assert input_positions is None + assert encoder_input_tokens is None + assert encoder_input_positions is None + assert attn_metadata is None + + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) + (input_tokens, + input_positions, + encoder_input_tokens, + encoder_input_positions, + attn_metadata, + return_seq_lens, + ) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.encoder_input_tokens, + model_input.encoder_input_positions, + model_input.attn_metadata, + model_input.seq_lens, + ) + assert input_tokens is None + assert input_positions is None + assert encoder_input_tokens is None + assert encoder_input_positions is None + assert attn_metadata is None + assert return_seq_lens is None + + + @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -47,10 +109,13 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): override_backend_env_variable(monkeypatch, backend_name) model_runner = _create_model_runner("facebook/bart-base", + seed=0, + dtype="float16", max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager) + enforce_eager=enforce_eager, + ) seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] @@ -72,7 +137,8 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table) + cross_block_table=cross_block_table, + ) assert seq_group_metadata.token_chunk_size == seq_data.get_len() seq_group_metadata_list.append(seq_group_metadata) @@ -84,17 +150,18 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): selected_token_start_idx += seq_len # Decoder model input - model_input = model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens - slot_mapping = model_input.slot_mapping + slot_mapping = attn_metadata.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Encoder model input - encoder_model_input = model_runner._prepare_encoder_model_input( + encoder_model_input = model_runner._prepare_encoder_model_input_tensors( seq_group_metadata_list, attn_metadata) encoder_input_tokens = encoder_model_input.input_tokens encoder_input_positions = encoder_model_input.input_positions @@ -128,7 +195,8 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): start_loc.append(start_idx) assert torch.allclose( attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) + torch.tensor(start_loc, dtype=torch.int32, device=device), + ) # Test decoder seq start locs. Note that for normal prefill it is # equivalent to query_start_loc. @@ -140,49 +208,57 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): assert torch.allclose( attn_metadata.seq_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) + torch.tensor(start_loc, dtype=torch.int32, device=device), + ) assert torch.allclose( attn_metadata.context_lens_tensor, torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, - device=device)) + device=device), + ) # Verify block tables are correct for prompts # - Decoder self-attention expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], dtype=torch.int32, - device=model_runner.device) - assert torch.allclose(attn_metadata.block_tables, expected) + device=model_runner.device, + ) + assert torch.allclose(attn_metadata.block_tables, + expected, + ) # - Encoder/decoder cross-attention - # expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], - # dtype=torch.int32, - # device=model_runner.device) - assert torch.allclose(attn_metadata.cross_block_tables, expected) + assert torch.allclose(attn_metadata.cross_block_tables, + expected, + ) - # Cuda graph should not be used for prefill, regardless of - # enforce_eager setting + # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False # Verify the lengths of input tokens & positions # - Decoder assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + torch.testing.assert_close(input_tokens, + input_positions,) # - Encoder assert len(encoder_input_tokens) == sum(encoder_seq_lens) assert len(encoder_input_tokens) == sum(encoder_seq_lens) - torch.testing.assert_close(encoder_input_tokens, encoder_input_positions) + torch.testing.assert_close(encoder_input_tokens, + encoder_input_positions, + ) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=model_runner.device, - pin_memory=model_runner.pin_memory) + pin_memory=model_runner.pin_memory,) + actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, - dtype=actual.dtype) + dtype=actual.dtype, + ) torch.testing.assert_close(actual, expected) torch.allclose(input_tokens, input_positions) @@ -190,179 +266,4 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - - -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - model_runner = _create_model_runner("facebook/bart-base", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=enforce_eager) - - seq_lens: List[int] = [] - encoder_seq_lens: List[int] = [] - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData(list(range(encoder_seq_len))) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - - # Decoder model input - model_input = model_runner._prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = model_input.slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - - # Encoder model input - encoder_model_input = model_runner._prepare_encoder_model_input( - seq_group_metadata_list, attn_metadata) - encoder_input_tokens = encoder_model_input.input_tokens - encoder_input_positions = encoder_model_input.input_positions - return_encoder_seq_lens = attn_metadata.encoder_seq_lens - cross_slot_mapping = attn_metadata.cross_slot_mapping - assert return_encoder_seq_lens == encoder_seq_lens - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for decode phase. - # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.allclose( - attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.allclose( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. - start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - # 1 decoded token per sequence - start_idx += 1 - start_loc.append(start_idx) - assert torch.allclose( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - - # Test decoder seq start locs. Note that for normal prefill it is - # equivalent to query_start_loc. - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - - assert torch.allclose( - attn_metadata.seq_start_loc, - torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) - assert torch.allclose( - attn_metadata.context_lens_tensor, - torch.tensor([seq_len - 1 for seq_len in seq_lens], - dtype=torch.int, - device=device)) - - # Verify block tables are correct for prompts - # - Decoder self-attention - expected = torch.tensor( - [block_tables[0] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) - assert torch.allclose(attn_metadata.block_tables, expected) - # - Encoder/decoder cross-attention - expected = torch.tensor( - [cross_block_table for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) - assert torch.allclose(attn_metadata.cross_block_tables, expected) - - # Cuda graph should not be used for prefill. - assert attn_metadata.use_cuda_graph == (not enforce_eager) - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(seq_lens) - assert len(input_positions) == len(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_positions) == 0 - - -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): - """Verify prepare prompt and decode returns empty output.""" - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - enforce_eager=enforce_eager, - ) - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input(seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.slot_mapping, - ) - assert len(input_tokens) == 0 - assert len(input_positions) == 0 - assert attn_metadata is None - assert len(slot_mapping) == 0 - - model_input = model_runner._prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata, slot_mapping, - return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.slot_mapping, - model_input.seq_lens, - ) - assert len(input_tokens) == 0 - assert len(input_positions) == 0 - assert attn_metadata is None - assert len(slot_mapping) == 0 - assert len(return_seq_lens) == 0 + torch.testing.assert_close(actual, expected) \ No newline at end of file diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 5571eb43e568f..0675a300e2719 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -11,7 +11,76 @@ from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPU, + LORA_WARMUP_RANK, + _BATCH_SIZES_TO_CAPTURE, + _PAD_SLOT_ID, + ) +from vllm.distributed import get_pp_group +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) + +import dataclasses +import gc +import time +import warnings +from collections import defaultdict +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, + Tuple, Type, TypeVar, Union) + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.distributed import get_pp_group +from vllm.distributed.parallel_state import graph_capture +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.interfaces import (supports_lora, + supports_vision) +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, + MultiModalInputs) +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) +from vllm.sampling_params import SamplingParams +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata) +from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, + is_pin_memory_available, make_tensor_with_pad) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -139,8 +208,8 @@ def execute_model( if model_input.is_prompt: hidden_states = hidden_or_intermediate_states.index_select( 0, indices) - elif decode_meta.use_cuda_graph: - hidden_states = hidden_or_intermediate_states[:len(indices)] + # elif decode_meta.use_cuda_graph: + # hidden_states = hidden_or_intermediate_states[:len(indices)] else: hidden_states = hidden_or_intermediate_states @@ -151,49 +220,593 @@ def execute_model( def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, - Any]) -> ModelInputForGPUWithPoolingMetadata: - return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + Any]) -> EncoderDecoderModelInput: + return EncoderDecoderModelInput.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) def prepare_model_input( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithPoolingMetadata: - assert seq_group_metadata_list is not None + ) -> EncoderDecoderModelInput: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - # Prepare PoolingMetadata. - assert model_input.seq_lens is not None - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) - + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) return dataclasses.replace(model_input, - pooling_metadata=pooling_metadata) + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests: List[LoRARequest] = [] + dummy_lora_requests_per_seq: List[LoRARequest] = [] + if self.lora_config: + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] - def _prepare_pooling( + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision encoding, which needs + # to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + model_config = self.model_config + + if supports_vision(self.model): + max_mm_tokens = MULTIMODAL_REGISTRY \ + .get_max_multimodal_tokens(model_config) + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. " + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ + .dummy_data_for_profiling(model_config, seq_len) + + # Having more tokens is over-conservative but otherwise fine + assert len(seq_data.prompt_token_ids) >= seq_len, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but got: {len(seq_data.prompt_token_ids)}") + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + encoder_seq_data=seq_data, + cross_block_table=None, + multi_modal_data=dummy_multi_modal_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.cuda.synchronize() + return + + def _prepare_encoder_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> PoolingMetadata: - """Prepare PoolingMetadata for the sequence group metadata list.""" - seq_groups: List[Tuple[List[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - pooling_params = seq_group_metadata.pooling_params - seq_groups.append((seq_ids, pooling_params)) + model_input: EncoderDecoderModelInput, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInput: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + prompt_adapter_index_mapping: List[int] = [] + prompt_adapter_prompt_mapping: List[int] = [] + prompt_adapter_requests: Set[PromptAdapterRequest] = set() + + seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + block_tables: List[List[int]] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 + + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len: List[int] = [] + + if len(seq_group_metadata_list) == 0: + # Leave the encoder/cross-attention input + # fields at default values if the seq group + # metadata list arg is an empty list + return model_input + + if self.sliding_window is not None: + raise NotImplementedError() + # sliding_window_blocks = (self.sliding_window + self.block_size - + # 1) // self.block_size + # block_aligned_sliding_window = \ + # sliding_window_blocks * self.block_size - seq_data: Dict[int, SequenceData] = {} for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt - pooling_metadata = PoolingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - ) + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.encoder_seq_data + cross_block_table = seq_group_metadata.cross_block_table + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 - return pooling_metadata + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + # if (self.sliding_window is not None and not is_prompt): + # curr_sliding_window_blocks = sliding_window_blocks + # if self.scheduler_config.use_v2_block_manager: + # # number of elements in last block + # suff_len = seq_len % self.block_size + # sliding_seq_len = min( + # seq_len, block_aligned_sliding_window + suff_len) + # if suff_len > 0: + # curr_sliding_window_blocks += 1 + # else: + # sliding_seq_len = min(seq_len, self.sliding_window) + # sliding_context_len = sliding_seq_len - 1 + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = cross_block_table + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if cross_block_table is not None: + # chunked prefill or decode + block_table = cross_block_table + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) + + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(sliding_seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * query_len + lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None else 1)) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + # Process multi-modal data + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + if prompt_adapter_id > 0 and is_prompt: + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + is_profile_run = _is_single_block_table_empty( + seq_group_metadata.block_tables) + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + block_table = cross_block_table + + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + if is_prompt: + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + # Prepare input tensors for flashinfer + if self.attn_backend.get_name() == "flashinfer": + assert False + + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_seq_len = (max(prefill_seq_lens, default=0) if is_prompt else + max(decode_seq_lens, default=0)) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = ( + decode_only and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seq_len <= self.max_seq_len_to_capture) + if use_captured_graph: + assert False + + if use_captured_graph: + assert False + else: + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + logits_soft_cap = getattr(self.model_config.hf_config, + 'attn_logit_softcapping', None) + if logits_soft_cap is not None and self.attn_backend.get_name( + ) != "flashinfer": + raise ValueError("Please use Flashinfer backend for models with" + "logits_soft_cap (i.e., Gemma-2)." + " Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + if self.attn_backend.get_name() == "flashinfer": + if len(paged_kv_indptr) > 0: + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + device='cpu', + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, + device='cpu', + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, device='cpu', dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, + self.model_config.dtype) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=self.device, + data_type=kv_cache_dtype, + use_cuda_graph=use_captured_graph, + logits_soft_cap=logits_soft_cap) + + else: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_seq_len, + max_decode_seq_len=max_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, + device=self.device) + request_ids_to_seq_ids = { + seq_group_metadata.request_id: + list(seq_group_metadata.seq_data.keys()) + for seq_group_metadata in seq_group_metadata_list + } + return self._model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_requests_ids=finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests, + ) + +def _is_single_block_table_empty(block_table: Optional[List[int]]): + """ + Check if a single block table has not been constructed + """ + if block_table is None: + return True + return False \ No newline at end of file From 196f30cd7f25a682dc3d2320d994f949b00084a2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 12 Jul 2024 11:15:56 -0400 Subject: [PATCH 320/443] enc/dec decoder test working, sans sampling check --- tests/core/test_scheduler_encoder_decoder.py | 8 +- .../test_encoder_decoder_model_runner.py | 337 ++++++++++++++---- vllm/worker/enc_dec_model_runner.py | 273 +++++--------- 3 files changed, 368 insertions(+), 250 deletions(-) diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 24c2cfdf8c57d..4c5fa1983be34 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -49,9 +49,11 @@ def test_scheduler_schedule_simple_encoder_decoder(): seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) # - Verify that sequence group metadata includes encoder attention # and cross-attention metadata - assert all([not ((seq_group_meta.encoder_seq_data is None) or - (seq_group_meta.cross_block_table is None)) - for seq_group_meta in seq_group_meta_list]) + assert all([ + not ((seq_group_meta.encoder_seq_data is None) or + (seq_group_meta.cross_block_table is None)) + for seq_group_meta in seq_group_meta_list + ]) assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group assert (not out.blocks_to_copy and not out.blocks_to_swap_in diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index d696cbfa57177..bbb2961915b1f 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -21,7 +21,8 @@ ENFORCE_EAGER = [True] -def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: +def _create_model_runner(model: str, *args, + **kwargs) -> EncoderDecoderModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( @@ -37,6 +38,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunn ) return model_runner + @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): @@ -45,23 +47,25 @@ def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - model_runner = _create_model_runner("facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=enforce_eager, - ) + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, - input_positions, - encoder_input_tokens, - encoder_input_positions, - attn_metadata, - ) = ( + ( + input_tokens, + input_positions, + encoder_input_tokens, + encoder_input_positions, + attn_metadata, + ) = ( model_input.input_tokens, model_input.input_positions, model_input.encoder_input_tokens, @@ -76,13 +80,14 @@ def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, - input_positions, - encoder_input_tokens, - encoder_input_positions, - attn_metadata, - return_seq_lens, - ) = ( + ( + input_tokens, + input_positions, + encoder_input_tokens, + encoder_input_positions, + attn_metadata, + return_seq_lens, + ) = ( model_input.input_tokens, model_input.input_positions, model_input.encoder_input_tokens, @@ -98,8 +103,6 @@ def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): assert return_seq_lens is None - - @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) @@ -108,14 +111,15 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - model_runner = _create_model_runner("facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=enforce_eager, - ) + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] @@ -149,23 +153,31 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): seq_len - 1) selected_token_start_idx += seq_len - # Decoder model input - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens + # Build decoder model inputs & + # decoder self-attention KV caching data structures + decoder_only_model_input = ( + model_runner._prepare_model_input_tensors( + seq_group_metadata_list)) + input_tokens = decoder_only_model_input.input_tokens + input_positions = decoder_only_model_input.input_positions + attn_metadata = decoder_only_model_input.attn_metadata + return_seq_lens = decoder_only_model_input.seq_lens slot_mapping = attn_metadata.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) - # Encoder model input - encoder_model_input = model_runner._prepare_encoder_model_input_tensors( - seq_group_metadata_list, attn_metadata) - encoder_input_tokens = encoder_model_input.input_tokens - encoder_input_positions = encoder_model_input.input_positions + # Augment model input data structure with encoder model + # inputs & encoder/decoder cross-attention KV caching + # data structures + encoder_decoder_model_input = ( + model_runner._prepare_encoder_model_input_tensors( + seq_group_metadata_list, decoder_only_model_input)) + encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens + encoder_input_positions = encoder_decoder_model_input.encoder_input_positions + attn_metadata = encoder_decoder_model_input.attn_metadata cross_slot_mapping = attn_metadata.cross_slot_mapping + return_encoder_seq_lens = encoder_decoder_model_input.attn_metadata.encoder_seq_lens + assert return_encoder_seq_lens == encoder_seq_lens assert len(cross_slot_mapping) == len(encoder_input_tokens) # Verify input metadata is correct for prompts. @@ -196,7 +208,7 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): assert torch.allclose( attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device), - ) + ) # Test decoder seq start locs. Note that for normal prefill it is # equivalent to query_start_loc. @@ -209,27 +221,30 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device), - ) + ) assert torch.allclose( attn_metadata.context_lens_tensor, torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device), - ) + ) # Verify block tables are correct for prompts # - Decoder self-attention - expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.allclose(attn_metadata.block_tables, - expected, - ) + expected = torch.tensor( + [[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.allclose( + attn_metadata.block_tables, + expected, + ) # - Encoder/decoder cross-attention - assert torch.allclose(attn_metadata.cross_block_tables, - expected, - ) + assert torch.allclose( + attn_metadata.cross_block_tables, + expected, + ) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False @@ -238,27 +253,32 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # - Decoder assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, - input_positions,) + torch.testing.assert_close( + input_tokens, + input_positions, + ) # - Encoder assert len(encoder_input_tokens) == sum(encoder_seq_lens) assert len(encoder_input_tokens) == sum(encoder_seq_lens) - torch.testing.assert_close(encoder_input_tokens, - encoder_input_positions, - ) + torch.testing.assert_close( + encoder_input_tokens, + encoder_input_positions, + ) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=model_runner.device, - pin_memory=model_runner.pin_memory,) + pin_memory=model_runner.pin_memory, + ) actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) + expected = torch.tensor( + expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype, + ) torch.testing.assert_close(actual, expected) torch.allclose(input_tokens, input_positions) @@ -266,4 +286,185 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) - torch.testing.assert_close(actual, expected) \ No newline at end of file + torch.testing.assert_close(actual, expected) + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + # Build decoder model inputs & + # decoder self-attention KV caching data structures + decoder_only_model_input = ( + model_runner._prepare_model_input_tensors( + seq_group_metadata_list)) + input_tokens = decoder_only_model_input.input_tokens + input_positions = decoder_only_model_input.input_positions + attn_metadata = decoder_only_model_input.attn_metadata + return_seq_lens = decoder_only_model_input.seq_lens + slot_mapping = attn_metadata.slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + + # Augment model input data structure with encoder model + # inputs & encoder/decoder cross-attention KV caching + # data structures + encoder_decoder_model_input = ( + model_runner._prepare_encoder_model_input_tensors( + seq_group_metadata_list, decoder_only_model_input)) + encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens + encoder_input_positions = encoder_decoder_model_input.encoder_input_positions + attn_metadata = encoder_decoder_model_input.attn_metadata + return_encoder_seq_lens = attn_metadata.encoder_seq_lens + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert return_encoder_seq_lens == encoder_seq_lens + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify input metadata is correct for decode phase. + # - Decoder attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.allclose( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) + + # Test decoder subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device), + ) + + # Test decoder seq start locs. Note that for normal prefill it is + # equivalent to query_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device), + ) + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor([seq_len - 1 for seq_len in seq_lens], + dtype=torch.int, + device=device)) + + # Verify block tables are correct for prompts + # - Decoder self-attention + expected = torch.tensor( + [block_tables[0] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose( + attn_metadata.block_tables, + expected, + ) + # - Encoder/decoder cross-attention + expected = torch.tensor( + [cross_block_table for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose( + attn_metadata.cross_block_tables, + expected, + ) + + # Cuda graph should is currently not supported for encoder/decoer. + assert attn_metadata.use_cuda_graph is False + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(seq_lens) + assert len(input_positions) == len(seq_lens) + torch.testing.assert_close( + input_tokens, + input_positions, + ) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_tokens) == 0 + torch.testing.assert_close( + encoder_input_tokens, + encoder_input_positions, + ) + + # sampling_metadata = SamplingMetadata.prepare( + # seq_group_metadata_list, + # seq_lens, + # query_lens=seq_lens, + # device=model_runner.device, + # pin_memory=model_runner.pin_memory, + # ) + + # actual = sampling_metadata.selected_token_indices + # expected = torch.tensor( + # expected_selected_token_indices, + # device=actual.device, + # dtype=actual.dtype, + # ) + # torch.testing.assert_close(actual, expected) + # torch.allclose(input_tokens, input_positions) + + # actual = sampling_metadata.selected_token_indices + # expected = torch.tensor(expected_selected_token_indices, + # device=actual.device, + # dtype=actual.dtype) + # torch.testing.assert_close(actual, expected) \ No newline at end of file diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 0675a300e2719..aebf63bde1b26 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -7,16 +7,16 @@ ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPU, - LORA_WARMUP_RANK, - _BATCH_SIZES_TO_CAPTURE, - _PAD_SLOT_ID, - ) +from vllm.worker.model_runner import ( + GPUModelRunnerBase, + ModelInputForGPU, + ModelInputForGPUWithSamplingMetadata, + LORA_WARMUP_RANK, + _BATCH_SIZES_TO_CAPTURE, + _PAD_SLOT_ID, +) from vllm.distributed import get_pp_group from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -84,16 +84,17 @@ logger = init_logger(__name__) + @dataclasses.dataclass(frozen=True) -class EncoderDecoderModelInput(ModelInputForGPU): +class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): """ Used by the EncoderDecoderModelRunner. """ encoder_input_tokens: Optional[torch.Tensor] = None encoder_input_positions: Optional[torch.Tensor] = None -class EncoderDecoderModelRunner( - GPUModelRunnerBase[EncoderDecoderModelInput]): + +class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): _model_input_cls: Type[EncoderDecoderModelInput] = ( EncoderDecoderModelInput) @@ -218,9 +219,7 @@ def execute_model( return [output] def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, - Any]) -> EncoderDecoderModelInput: + self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: return EncoderDecoderModelInput.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, @@ -361,8 +360,7 @@ def profile_run(self) -> None: def _prepare_encoder_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInput, - finished_requests_ids: Optional[List[str]] = None + model_input: EncoderDecoderModelInput ) -> EncoderDecoderModelInput: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not @@ -395,30 +393,11 @@ def _prepare_encoder_model_input_tensors( query_lens: List[int] = [] block_tables: List[List[int]] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] - request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) decode_only = True num_prefills = 0 num_prefill_tokens = 0 num_decode_tokens = 0 - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - paged_kv_last_page_len: List[int] = [] - if len(seq_group_metadata_list) == 0: # Leave the encoder/cross-attention input # fields at default values if the seq group @@ -436,11 +415,11 @@ def _prepare_encoder_model_input_tensors( seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt - computed_block_nums = seq_group_metadata.computed_block_nums + computed_block_nums = None if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled and not (computed_block_nums is None - or computed_block_nums == [])): + or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") @@ -453,11 +432,10 @@ def _prepare_encoder_model_input_tensors( # get_num_computed_tokens is incorrect for spec decoding. # So, we should have a special logic here. # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 + context_len = seq_data.get_len() + + seq_len = seq_data.get_len() - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) if is_prompt: tokens = seq_data.get_token_ids()[context_len:seq_len] else: @@ -469,8 +447,7 @@ def _prepare_encoder_model_input_tensors( # Prefix is not supported with sliding_window prefix_cache_hit = (computed_block_nums is not None and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) + and self.sliding_window is None and is_prompt) # These are seq_len/context_len capped to the sliding window. # They are passed to decode kernel. @@ -519,13 +496,12 @@ def _prepare_encoder_model_input_tensors( else: block_table = computed_block_nums elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): + or not is_prompt): if cross_block_table is not None: # chunked prefill or decode block_table = cross_block_table if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] + block_table = block_table[-curr_sliding_window_blocks:] else: # Only happens when memory profiling runs. block_table = [] @@ -550,9 +526,9 @@ def _prepare_encoder_model_input_tensors( decode_only = False prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) + # assert is_encoder_seq or query_len == 1, ( + # "seq_len: {}, context_len: {}, query_len: {}".format( + # seq_len, context_len, query_len)) num_decode_tokens += query_len decode_seq_lens.append(sliding_seq_len) @@ -562,9 +538,9 @@ def _prepare_encoder_model_input_tensors( lora_index_mapping += [lora_id] * query_len lora_prompt_mapping.extend( [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) + (query_len if seq_group_metadata.sampling_params and + seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) mm_data = seq_group_metadata.multi_modal_data if mm_data: @@ -579,13 +555,13 @@ def _prepare_encoder_model_input_tensors( num_tokens = seq_group_metadata.\ prompt_adapter_num_virtual_tokens pm = [prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) + ] * num_tokens + [0] * (query_len - num_tokens) prompt_adapter_index_mapping += pm prompt_adapter_prompt_mapping.extend( [prompt_adapter_id] * (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - else 1)) + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) is_profile_run = _is_single_block_table_empty( seq_group_metadata.block_tables) @@ -634,32 +610,31 @@ def _prepare_encoder_model_input_tensors( batch_size = len(input_tokens) max_query_len = max(query_lens) - max_seq_len = (max(prefill_seq_lens, default=0) if is_prompt else - max(decode_seq_lens, default=0)) + max_seq_len = (max(prefill_seq_lens, default=0) + if is_prompt else max(decode_seq_lens, default=0)) # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. # vLLM uses cuda graph only for decoding requests. - use_captured_graph = ( - decode_only and not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) + use_captured_graph = (decode_only + and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: assert False - if use_captured_graph: - assert False - else: - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + assert (not is_prompt) or max_query_len > 0, ( + "Decode-phase query_lens: {}".format(query_lens) ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -687,126 +662,66 @@ def _prepare_encoder_model_input_tensors( dtype=query_start_loc.dtype, out=query_start_loc[1:]) - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) + attn_metadata = model_input.attn_metadata + slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) + # Set encoder-oriented attention metadata fields + attn_metadata.num_encoder_tokens = sum(seq_lens) + attn_metadata.encoder_seq_lens = seq_lens + attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor + attn_metadata.max_encoder_seq_len = max_seq_len + attn_metadata.cross_slot_mapping = slot_mapping_tensor + attn_metadata.cross_block_tables = block_tables + + if seq_group_metadata.is_prompt: + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + + else: + + input_tokens_tensor = torch.tensor([], + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor([], + dtype=torch.long, + device=self.device) + + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=input_tokens_tensor, + encoder_input_positions=input_positions_tensor, + ) + logits_soft_cap = getattr(self.model_config.hf_config, 'attn_logit_softcapping', None) if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": - raise ValueError("Please use Flashinfer backend for models with" - "logits_soft_cap (i.e., Gemma-2)." - " Otherwise, the output might be wrong." - " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + raise ValueError("Models with logits_soft_cap (i.e., Gemma-2)" + " require FlashInfer backend, however vLLM" + " currently only supports xFormers backend" + " for encoder/decoder models.") - if self.attn_backend.get_name() == "flashinfer": - if len(paged_kv_indptr) > 0: - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - device='cpu', - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - device='cpu', - dtype=torch.int) - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, device='cpu', dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_len_tensor = None - - kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, - self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_seq_len, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=seq_start_loc, - query_start_loc=query_start_loc, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph, - logits_soft_cap=logits_soft_cap) + return model_input - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_seq_len, - max_decode_seq_len=max_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - if self.prompt_adapter_config: - prompt_adapter_mapping = PromptAdapterMapping( - prompt_adapter_index_mapping, - prompt_adapter_prompt_mapping, - ) - else: - prompt_adapter_mapping = None - - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) - request_ids_to_seq_ids = { - seq_group_metadata.request_id: - list(seq_group_metadata.seq_data.keys()) - for seq_group_metadata in seq_group_metadata_list - } - return self._model_input_cls( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=finished_requests_ids, - prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=prompt_adapter_requests, - ) - def _is_single_block_table_empty(block_table: Optional[List[int]]): """ Check if a single block table has not been constructed """ if block_table is None: return True - return False \ No newline at end of file + return False From 3d5bb888cfc10c835ff17c18ca367c930a335785 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 04:48:48 -0400 Subject: [PATCH 321/443] EncoderDecoderModelInput correctly handles encoder token/position fields --- .../test_encoder_decoder_model_runner.py | 12 ++--- vllm/worker/enc_dec_model_runner.py | 51 ++++++++++++++----- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index bbb2961915b1f..5612f71f5d92b 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -156,8 +156,7 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # Build decoder model inputs & # decoder self-attention KV caching data structures decoder_only_model_input = ( - model_runner._prepare_model_input_tensors( - seq_group_metadata_list)) + model_runner._prepare_model_input_tensors(seq_group_metadata_list)) input_tokens = decoder_only_model_input.input_tokens input_positions = decoder_only_model_input.input_positions attn_metadata = decoder_only_model_input.attn_metadata @@ -171,7 +170,7 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): # data structures encoder_decoder_model_input = ( model_runner._prepare_encoder_model_input_tensors( - seq_group_metadata_list, decoder_only_model_input)) + seq_group_metadata_list, decoder_only_model_input)) encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens encoder_input_positions = encoder_decoder_model_input.encoder_input_positions attn_metadata = encoder_decoder_model_input.attn_metadata @@ -335,8 +334,7 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): # Build decoder model inputs & # decoder self-attention KV caching data structures decoder_only_model_input = ( - model_runner._prepare_model_input_tensors( - seq_group_metadata_list)) + model_runner._prepare_model_input_tensors(seq_group_metadata_list)) input_tokens = decoder_only_model_input.input_tokens input_positions = decoder_only_model_input.input_positions attn_metadata = decoder_only_model_input.attn_metadata @@ -350,7 +348,7 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): # data structures encoder_decoder_model_input = ( model_runner._prepare_encoder_model_input_tensors( - seq_group_metadata_list, decoder_only_model_input)) + seq_group_metadata_list, decoder_only_model_input)) encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens encoder_input_positions = encoder_decoder_model_input.encoder_input_positions attn_metadata = encoder_decoder_model_input.attn_metadata @@ -467,4 +465,4 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): # expected = torch.tensor(expected_selected_token_indices, # device=actual.device, # dtype=actual.dtype) - # torch.testing.assert_close(actual, expected) \ No newline at end of file + # torch.testing.assert_close(actual, expected) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index aebf63bde1b26..6b47fa38c4043 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, cast import torch @@ -93,6 +93,36 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): encoder_input_tokens: Optional[torch.Tensor] = None encoder_input_positions: Optional[torch.Tensor] = None + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInput": + return cast( + EncoderDecoderModelInput, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): _model_input_cls: Type[EncoderDecoderModelInput] = ( @@ -358,10 +388,8 @@ def profile_run(self) -> None: return def _prepare_encoder_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInput - ) -> EncoderDecoderModelInput: + self, seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInput) -> EncoderDecoderModelInput: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. @@ -633,8 +661,7 @@ def _prepare_encoder_model_input_tensors( device=self.device, ) assert (not is_prompt) or max_query_len > 0, ( - "Decode-phase query_lens: {}".format(query_lens) - ) + "Decode-phase query_lens: {}".format(query_lens)) context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -700,11 +727,11 @@ def _prepare_encoder_model_input_tensors( # dataclasses.replace to construct a new model input # instance. model_input = dataclasses.replace( - model_input, - attn_metadata=attn_metadata, - encoder_input_tokens=input_tokens_tensor, - encoder_input_positions=input_positions_tensor, - ) + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=input_tokens_tensor, + encoder_input_positions=input_positions_tensor, + ) logits_soft_cap = getattr(self.model_config.hf_config, 'attn_logit_softcapping', None) From db5539a85f83ceaa929e2c02129a1a174fa29424 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 05:00:25 -0400 Subject: [PATCH 322/443] format --- .../test_encoder_decoder_model_runner.py | 9 ++- vllm/worker/enc_dec_model_runner.py | 79 +++++-------------- 2 files changed, 26 insertions(+), 62 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 5612f71f5d92b..d59ef8ef09445 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -172,10 +172,12 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): model_runner._prepare_encoder_model_input_tensors( seq_group_metadata_list, decoder_only_model_input)) encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens - encoder_input_positions = encoder_decoder_model_input.encoder_input_positions + encoder_input_positions = ( + encoder_decoder_model_input.encoder_input_positions) attn_metadata = encoder_decoder_model_input.attn_metadata cross_slot_mapping = attn_metadata.cross_slot_mapping - return_encoder_seq_lens = encoder_decoder_model_input.attn_metadata.encoder_seq_lens + return_encoder_seq_lens = ( + encoder_decoder_model_input.attn_metadata.encoder_seq_lens) assert return_encoder_seq_lens == encoder_seq_lens assert len(cross_slot_mapping) == len(encoder_input_tokens) @@ -350,7 +352,8 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): model_runner._prepare_encoder_model_input_tensors( seq_group_metadata_list, decoder_only_model_input)) encoder_input_tokens = encoder_decoder_model_input.encoder_input_tokens - encoder_input_positions = encoder_decoder_model_input.encoder_input_positions + encoder_input_positions = ( + encoder_decoder_model_input.encoder_input_positions) attn_metadata = encoder_decoder_model_input.attn_metadata return_encoder_seq_lens = attn_metadata.encoder_seq_lens cross_slot_mapping = attn_metadata.cross_slot_mapping diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 6b47fa38c4043..3228de930f8a5 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,38 +1,19 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, cast import torch +import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.logger import init_logger -from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, - SequenceGroupMetadata) -from vllm.worker.model_runner import ( - GPUModelRunnerBase, - ModelInputForGPU, - ModelInputForGPUWithSamplingMetadata, - LORA_WARMUP_RANK, - _BATCH_SIZES_TO_CAPTURE, - _PAD_SLOT_ID, -) from vllm.distributed import get_pp_group -from vllm.sequence import (IntermediateTensors, SamplerOutput, +from vllm.logger import init_logger +from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, SequenceGroupMetadata) - -import dataclasses -import gc -import time -import warnings -from collections import defaultdict -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, - Tuple, Type, TypeVar, Union) - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, + LORA_WARMUP_RANK, GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -45,39 +26,17 @@ BatchPrefillWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) -from vllm.distributed import get_pp_group -from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models.interfaces import (supports_lora, - supports_vision) -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, - MultiModalInputs) -from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.model_executor.models.interfaces import supports_vision +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.prompt_adapter.worker_manager import ( - LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad) +from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) + _add_sampling_metadata_broadcastable_dict) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -632,9 +591,9 @@ def _prepare_encoder_model_input_tensors( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - # Prepare input tensors for flashinfer - if self.attn_backend.get_name() == "flashinfer": - assert False + # # Prepare input tensors for flashinfer + # if self.attn_backend.get_name() == "flashinfer": + # assert False batch_size = len(input_tokens) max_query_len = max(query_lens) @@ -649,7 +608,8 @@ def _prepare_encoder_model_input_tensors( and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: - assert False + raise NotImplementedError("CUDAGraph is currently not supported " + "for encoder/decoder models.") max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -663,9 +623,9 @@ def _prepare_encoder_model_input_tensors( assert (not is_prompt) or max_query_len > 0, ( "Decode-phase query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + # context_lens_tensor = torch.tensor(context_lens, + # dtype=torch.int, + # device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, @@ -690,6 +650,7 @@ def _prepare_encoder_model_input_tensors( out=query_start_loc[1:]) attn_metadata = model_input.attn_metadata + assert attn_metadata is not None slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.long, From 760355bfeea93c7b85cf440f597485e11a7357b1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 06:04:43 -0400 Subject: [PATCH 323/443] bart test skipped on CPU version of vllm --- tests/core/test_scheduler_encoder_decoder.py | 1 - tests/models/test_bart.py | 7 +++++-- tests/worker/test_encoder_decoder_model_runner.py | 14 ++++++++++---- vllm/attention/backends/utils.py | 5 ++++- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 4c5fa1983be34..ec1488b505b92 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -9,7 +9,6 @@ from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, get_sequence_groups, schedule_and_update_computed_tokens) - def test_scheduler_schedule_simple_encoder_decoder(): block_size = 4 num_seq_group = 4 diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 2bf8c97131da3..cc5755208c263 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -5,7 +5,10 @@ import pytest from tests.kernels.utils import override_backend_env_variable -from vllm.utils import STR_XFORMERS_ATTN_VAL +from vllm.utils import (STR_XFORMERS_ATTN_VAL, + is_cpu, + ) +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CPU from .utils import check_logprobs_close @@ -16,7 +19,7 @@ # Currently only XFormers is supported BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] - +@pytest.mark.skipif(condition=is_cpu(),reason=STR_NOT_IMPL_ENC_DEC_CPU) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index d59ef8ef09445..4ba09adf80a7f 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -7,7 +7,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import STR_XFORMERS_ATTN_VAL +from vllm.utils import (STR_XFORMERS_ATTN_VAL, is_cpu) from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner # Backends under test @@ -38,7 +38,9 @@ def _create_model_runner(model: str, *args, ) return model_runner - +@pytest.mark.skipif(condition=is_cpu(),reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): @@ -102,7 +104,9 @@ def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): assert attn_metadata is None assert return_seq_lens is None - +@pytest.mark.skipif(condition=is_cpu(),reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) @@ -289,7 +293,9 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): dtype=actual.dtype) torch.testing.assert_close(actual, expected) - +@pytest.mark.skipif(condition=is_cpu(),reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index a3cfc6e20748b..61af9bd79662f 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -3,5 +3,8 @@ # Error string(s) for encoder/decoder # unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP backend is not currently supported " "with encoder/decoder models.") + +STR_NOT_IMPL_ENC_DEC_CPU = ("CPU backend is not current supported with " + "encoder/decoder models.") From 590a240fe53dd78e62c78f7ac0263b0c3fda6949 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 06:05:18 -0400 Subject: [PATCH 324/443] Formatting --- tests/core/test_scheduler_encoder_decoder.py | 1 + tests/models/test_bart.py | 7 +++-- .../test_encoder_decoder_model_runner.py | 26 ++++++++++++------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index ec1488b505b92..4c5fa1983be34 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -9,6 +9,7 @@ from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, get_sequence_groups, schedule_and_update_computed_tokens) + def test_scheduler_schedule_simple_encoder_decoder(): block_size = 4 num_seq_group = 4 diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index cc5755208c263..f392a20053f2b 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -5,10 +5,8 @@ import pytest from tests.kernels.utils import override_backend_env_variable -from vllm.utils import (STR_XFORMERS_ATTN_VAL, - is_cpu, - ) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CPU +from vllm.utils import STR_XFORMERS_ATTN_VAL, is_cpu from .utils import check_logprobs_close @@ -19,7 +17,8 @@ # Currently only XFormers is supported BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] -@pytest.mark.skipif(condition=is_cpu(),reason=STR_NOT_IMPL_ENC_DEC_CPU) + +@pytest.mark.skipif(condition=is_cpu(), reason=STR_NOT_IMPL_ENC_DEC_CPU) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 4ba09adf80a7f..51b5f3b58c500 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -7,7 +7,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import (STR_XFORMERS_ATTN_VAL, is_cpu) +from vllm.utils import STR_XFORMERS_ATTN_VAL, is_cpu from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner # Backends under test @@ -38,9 +38,11 @@ def _create_model_runner(model: str, *args, ) return model_runner -@pytest.mark.skipif(condition=is_cpu(),reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") + +@pytest.mark.skipif(condition=is_cpu(), + reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): @@ -104,9 +106,11 @@ def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): assert attn_metadata is None assert return_seq_lens is None -@pytest.mark.skipif(condition=is_cpu(),reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") + +@pytest.mark.skipif(condition=is_cpu(), + reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) @@ -293,9 +297,11 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): dtype=actual.dtype) torch.testing.assert_close(actual, expected) -@pytest.mark.skipif(condition=is_cpu(),reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") + +@pytest.mark.skipif(condition=is_cpu(), + reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) From 8b8d9812f7b7317448d4872db32cffcb45444c02 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 06:17:41 -0400 Subject: [PATCH 325/443] refactored AttentionType and related imports; skip BART test definitions entirely if on vllm CPU version (to avoid xformers import --- tests/kernels/test_encoder_decoder_attn.py | 3 +- tests/kernels/utils.py | 5 +- tests/models/test_bart.py | 99 +++++++++++----------- vllm/attention/__init__.py | 4 +- vllm/attention/layer.py | 2 +- vllm/model_executor/models/bart.py | 3 +- 6 files changed, 59 insertions(+), 57 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e0880a051f834..42d8ea6679458 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,8 +16,7 @@ from tests.kernels.utils import * from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor -from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention import (Attention, AttentionMetadata, AttentionBackend, AttentionType,) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.utils import LIST_ENC_DEC_SUPPORTED_BACKENDS, is_hip diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index f4dfbb977ab88..11ea8140336a7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -8,8 +8,9 @@ import pytest import torch -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, AttentionType) +from vllm.attention import (AttentionBackend, + AttentionMetadata, AttentionType, + ) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index f392a20053f2b..032119d924a57 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -2,54 +2,55 @@ Run `pytest tests/models/test_bart.py`. """ -import pytest - -from tests.kernels.utils import override_backend_env_variable -from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CPU from vllm.utils import STR_XFORMERS_ATTN_VAL, is_cpu -from .utils import check_logprobs_close - -MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] - -# Backends under test -# -# Currently only XFormers is supported -BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] - - -@pytest.mark.skipif(condition=is_cpu(), reason=STR_NOT_IMPL_ENC_DEC_CPU) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -def test_models( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, - backend_name: str, - monkeypatch, -) -> None: - # TODO(sang): Sliding window should be tested separately. - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - with hf_runner(model, dtype=dtype, - is_encoder_decoder_model=True) as hf_model: - hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( - example_encoder_decoder_prompts, max_tokens, num_logprobs) - - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - example_encoder_decoder_prompts, max_tokens, num_logprobs) - - check_logprobs_close(outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm") +if not is_cpu(): + # CPU backend is not currently supported with encoder/decoder models + + import pytest + from tests.kernels.utils import override_backend_env_variable + from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CPU + + from .utils import check_logprobs_close + + MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] + + # Backends under test + # + # Currently only XFormers is supported + BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] + + @pytest.mark.parametrize("model", MODELS) + @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) + @pytest.mark.parametrize("max_tokens", [64]) + @pytest.mark.parametrize("num_logprobs", [5]) + @pytest.mark.parametrize("backend_name", BACKEND_NAMES) + def test_models( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + backend_name: str, + monkeypatch, + ) -> None: + # TODO(sang): Sliding window should be tested separately. + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + with hf_runner(model, dtype=dtype, + is_encoder_decoder_model=True) as hf_model: + hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( + example_encoder_decoder_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + example_encoder_decoder_prompts, max_tokens, num_logprobs) + + check_logprobs_close(outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm") diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index f6bce9a187c64..3208453a66bbe 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -7,6 +8,7 @@ "Attention", "AttentionBackend", "AttentionMetadata", + "AttentionType", "Attention", "get_attn_backend", ] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b8cc87be8c748..b53f1c9c0a0a7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 8e892052dc396..239dea6a48a35 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -25,8 +25,7 @@ from transformers.activations import ACT2FN from transformers.utils import logging -from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionType +from vllm.attention import (Attention, AttentionMetadata,AttentionType,) from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( From ff940f7adf771465e92a6fad350fb2f1aca4f694 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 06:18:58 -0400 Subject: [PATCH 326/443] formatting --- tests/kernels/test_encoder_decoder_attn.py | 3 ++- tests/kernels/utils.py | 4 +--- tests/models/test_bart.py | 15 ++++++++------- vllm/attention/__init__.py | 3 +-- vllm/model_executor/models/bart.py | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 42d8ea6679458..79cdba851059b 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,7 +16,8 @@ from tests.kernels.utils import * from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor -from vllm.attention import (Attention, AttentionMetadata, AttentionBackend, AttentionType,) +from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, + AttentionType) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.utils import LIST_ENC_DEC_SUPPORTED_BACKENDS, is_hip diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 11ea8140336a7..e942336ff7fdc 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -8,9 +8,7 @@ import pytest import torch -from vllm.attention import (AttentionBackend, - AttentionMetadata, AttentionType, - ) +from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 032119d924a57..ba6d73cbac037 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -8,8 +8,8 @@ # CPU backend is not currently supported with encoder/decoder models import pytest + from tests.kernels.utils import override_backend_env_variable - from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CPU from .utils import check_logprobs_close @@ -42,15 +42,16 @@ def test_models( override_backend_env_variable(monkeypatch, backend_name) with hf_runner(model, dtype=dtype, - is_encoder_decoder_model=True) as hf_model: - hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit( - example_encoder_decoder_prompts, max_tokens, num_logprobs) + is_encoder_decoder_model=True) as hf_model: + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + example_encoder_decoder_prompts, max_tokens, num_logprobs)) with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( example_encoder_decoder_prompts, max_tokens, num_logprobs) check_logprobs_close(outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm") + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm") diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 3208453a66bbe..523d8be98964a 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,6 +1,5 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionType) + AttentionMetadata, AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 239dea6a48a35..3d71e7955af59 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -25,7 +25,7 @@ from transformers.activations import ACT2FN from transformers.utils import logging -from vllm.attention import (Attention, AttentionMetadata,AttentionType,) +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( From 64d71980c823c167239d5c7338096a144586b7f3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 06:59:49 -0400 Subject: [PATCH 327/443] wip --- tests/models/test_bart.py | 2 + vllm/inputs/data.py | 78 ++++++++++++++++++++++++++++++++++++++- vllm/inputs/utils.py | 12 ++++++ 3 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 vllm/inputs/utils.py diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index ba6d73cbac037..790ed4f1cbf30 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -6,6 +6,8 @@ if not is_cpu(): # CPU backend is not currently supported with encoder/decoder models + # skip test definitions entirely to avoid importing GPU kernel libs + # (xFormers, etc.) import pytest diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c6381fcc01e5f..7a2b4d2ae577b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -3,6 +3,11 @@ from typing_extensions import NotRequired +from .utils import (has_required_keys, + is_str, + is_dict, + ) + if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -110,7 +115,38 @@ class TextTokensPrompt(TypedDict): """ -PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +DecoderOnlyPromptInputs = Union[str, TextPrompt, TokensPrompt, + TextTokensPrompt] +StrictDecoderOnlyPromptInputs = Union[str, TextPrompt, TokensPrompt] + + +class ExplicitEncoderDecoderPrompt(TypedDict): + """Represents an encoder/decoder model input prompt, + comprising an encoder prompt and a decoder prompt. + + Only the encoder prompt may have multi-modal data. + """ + + encoder_prompt: DecoderOnlyPromptInputs + + decoder_prompt: DecoderOnlyPromptInputs + + +class ExplicitEncoderDecoderPromptStrict(TypedDict): + """Represents an encoder/decoder model input prompt, + comprising an encoder prompt and a decoder prompt. + Strictly forbid a prompt containing both text and + tokens. + + Only the encoder prompt may have multi-modal data. + """ + + encoder_prompt: StrictDecoderOnlyPromptInputs + + decoder_prompt: StrictDecoderOnlyPromptInputs + +PromptStrictInputs = Union[StrictDecoderOnlyPromptInputs, + ExplicitEncoderDecoderPromptStrict] """ The inputs to the LLM, which can take one of the following forms: @@ -118,10 +154,39 @@ class TextTokensPrompt(TypedDict): - A tokenized prompt (:class:`TokensPrompt`) """ -PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +PromptInputs = Union[DecoderOnlyPromptInputs, + ExplicitEncoderDecoderPrompt] """Same as :const:`PromptStrictInputs` but additionally accepts :class:`TextTokensPrompt`.""" +AllPromptInputs = Union[PromptInputs, + ExplicitEncoderDecoderPromptStrict] +"""All possible input prompt options, strict or non-strict""" + +def get_single_prompt_type(prompt: AllPromptInputs): + required_keys_dict = { + 'TextPrompt': ['prompt'], + 'TokensPrompt': ['prompt_token_ids'], + 'TextTokensPrompt': ['prompt','prompt_token_ids'], + 'ExplicitEncoderDecoder': ['encoder_prompt','decoder_prompt'], + } + + if is_dict(prompt): + for ptype in required_keys_dict: + if has_required_keys(prompt,required_keys_dict[ptype]): + return ptype + + raise ValueError(f"Invalid prompt {prompt}, valid types are " + "required_keys_dict={required_keys_dict}") + elif is_str(prompt): + return "str" + + raise ValueError(f"Invalid prompt {prompt}") + +def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs): + + + return True class LLMInputs(TypedDict): """ @@ -136,6 +201,15 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ + encoder_prompt_token_ids: NotRequired[Optional[List[int]]] + """The token IDs of the encoder prompt.""" + + encoder_prompt: NotRequired[Optional[str]] + """ + The original encoder prompt text corresponding to the token IDs, if + available. + """ + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] """ Optional multi-modal data to pass to the model, diff --git a/vllm/inputs/utils.py b/vllm/inputs/utils.py new file mode 100644 index 0000000000000..319304566ff49 --- /dev/null +++ b/vllm/inputs/utils.py @@ -0,0 +1,12 @@ +'''Utility functions for input types''' + +def has_required_keys(d: dict, + required_keys: list, + ) -> bool: + return set(required_keys).issubset(d.keys()) + +def is_str(s,) -> bool: + return isinstance(s, str) + +def is_dict(d,) -> bool: + return isinstance(d, dict) \ No newline at end of file From 83c5c43dd6e06d13d9d05c01882b6d705a5aefaa Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 07:14:34 -0400 Subject: [PATCH 328/443] prompt type checks --- vllm/inputs/data.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 7a2b4d2ae577b..407f193cae564 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -163,7 +163,13 @@ class ExplicitEncoderDecoderPromptStrict(TypedDict): ExplicitEncoderDecoderPromptStrict] """All possible input prompt options, strict or non-strict""" -def get_single_prompt_type(prompt: AllPromptInputs): +def get_single_prompt_type(prompt: AllPromptInputs, + ) -> str: + """ + Get the type-name of the prompt argument instance, given that + isinstance() cannot apply to TypedDict subclasses directly. + """ + required_keys_dict = { 'TextPrompt': ['prompt'], 'TokensPrompt': ['prompt_token_ids'], @@ -183,9 +189,23 @@ def get_single_prompt_type(prompt: AllPromptInputs): raise ValueError(f"Invalid prompt {prompt}") -def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs): - +def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs, + ) -> bool: + """ + Return True if prompt has the correct structure for an encoder/decoder + prompt. + """ + if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' and + (prompt['encoder_prompt'] is None or + prompt['decoder_prompt']['multi_modal_data'] is not None)): + # For explicit encoder/decoder prompts, encoder prompt + # must be non-None and decoder prompt must be free of + # multi-modal data (which should instead be passed to + # the encoder.) + return False + # Any valid prompt type other than an explicit encoder/decoder + # prompt is a guaranteed-valid prompt return True class LLMInputs(TypedDict): From 10ed7145053546d2112ed98252dc46f782a04b72 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 07:18:13 -0400 Subject: [PATCH 329/443] Format --- vllm/inputs/data.py | 45 +++++++++++++++++++++----------------------- vllm/inputs/utils.py | 16 ++++++++++------ 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 407f193cae564..fb06ebac9340b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -3,10 +3,7 @@ from typing_extensions import NotRequired -from .utils import (has_required_keys, - is_str, - is_dict, - ) +from .utils import has_required_keys, is_dict, is_str if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict @@ -145,6 +142,7 @@ class ExplicitEncoderDecoderPromptStrict(TypedDict): decoder_prompt: StrictDecoderOnlyPromptInputs + PromptStrictInputs = Union[StrictDecoderOnlyPromptInputs, ExplicitEncoderDecoderPromptStrict] """ @@ -154,17 +152,15 @@ class ExplicitEncoderDecoderPromptStrict(TypedDict): - A tokenized prompt (:class:`TokensPrompt`) """ -PromptInputs = Union[DecoderOnlyPromptInputs, - ExplicitEncoderDecoderPrompt] +PromptInputs = Union[DecoderOnlyPromptInputs, ExplicitEncoderDecoderPrompt] """Same as :const:`PromptStrictInputs` but additionally accepts :class:`TextTokensPrompt`.""" -AllPromptInputs = Union[PromptInputs, - ExplicitEncoderDecoderPromptStrict] +AllPromptInputs = Union[PromptInputs, ExplicitEncoderDecoderPromptStrict] """All possible input prompt options, strict or non-strict""" -def get_single_prompt_type(prompt: AllPromptInputs, - ) -> str: + +def get_single_prompt_type(prompt: AllPromptInputs, ) -> str: """ Get the type-name of the prompt argument instance, given that isinstance() cannot apply to TypedDict subclasses directly. @@ -173,15 +169,15 @@ def get_single_prompt_type(prompt: AllPromptInputs, required_keys_dict = { 'TextPrompt': ['prompt'], 'TokensPrompt': ['prompt_token_ids'], - 'TextTokensPrompt': ['prompt','prompt_token_ids'], - 'ExplicitEncoderDecoder': ['encoder_prompt','decoder_prompt'], + 'TextTokensPrompt': ['prompt', 'prompt_token_ids'], + 'ExplicitEncoderDecoder': ['encoder_prompt', 'decoder_prompt'], } if is_dict(prompt): for ptype in required_keys_dict: - if has_required_keys(prompt,required_keys_dict[ptype]): + if has_required_keys(prompt, required_keys_dict[ptype]): return ptype - + raise ValueError(f"Invalid prompt {prompt}, valid types are " "required_keys_dict={required_keys_dict}") elif is_str(prompt): @@ -189,25 +185,26 @@ def get_single_prompt_type(prompt: AllPromptInputs, raise ValueError(f"Invalid prompt {prompt}") -def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs, - ) -> bool: + +def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs, ) -> bool: """ Return True if prompt has the correct structure for an encoder/decoder prompt. """ - if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' and - (prompt['encoder_prompt'] is None or - prompt['decoder_prompt']['multi_modal_data'] is not None)): - # For explicit encoder/decoder prompts, encoder prompt - # must be non-None and decoder prompt must be free of - # multi-modal data (which should instead be passed to - # the encoder.) - return False + if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' + and (prompt['encoder_prompt'] is None + or prompt['decoder_prompt']['multi_modal_data'] is not None)): + # For explicit encoder/decoder prompts, encoder prompt + # must be non-None and decoder prompt must be free of + # multi-modal data (which should instead be passed to + # the encoder.) + return False # Any valid prompt type other than an explicit encoder/decoder # prompt is a guaranteed-valid prompt return True + class LLMInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are diff --git a/vllm/inputs/utils.py b/vllm/inputs/utils.py index 319304566ff49..3ab4da64a4db1 100644 --- a/vllm/inputs/utils.py +++ b/vllm/inputs/utils.py @@ -1,12 +1,16 @@ '''Utility functions for input types''' -def has_required_keys(d: dict, - required_keys: list, - ) -> bool: + +def has_required_keys( + d: dict, + required_keys: list, +) -> bool: return set(required_keys).issubset(d.keys()) -def is_str(s,) -> bool: + +def is_str(s, ) -> bool: return isinstance(s, str) -def is_dict(d,) -> bool: - return isinstance(d, dict) \ No newline at end of file + +def is_dict(d, ) -> bool: + return isinstance(d, dict) From 78d3d3c00d30af324dbd1ca0973c1dd68d4cdb5b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 07:20:50 -0400 Subject: [PATCH 330/443] modified LLM.generate() error message --- vllm/entrypoints/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 57e81a6317725..dd1cf7125fa43 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -287,8 +287,8 @@ def generate( """ if self.llm_engine.model_config.embedding_mode: raise ValueError( - "LLM.generate() is only supported for generation models " - "(XForCausalLM).") + "LLM.generate() is only supported for (conditional)generation " + "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: inputs = self._convert_v1_inputs( From 6c953808f11122a0c5482786b41825a79788a9a4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 07:25:01 -0400 Subject: [PATCH 331/443] wip engine is_encoder_decoder() setting --- vllm/engine/llm_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622221d2dd13e..f77b73af8e904 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -42,7 +42,9 @@ get_tokenizer_group) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import (Counter, + is_embedding_model_config, + is_encoder_decoder_model_config,) from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) From 304caed04dcbc25b76d8e80321da00414ac7dc17 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 07:36:33 -0400 Subject: [PATCH 332/443] formatting --- vllm/engine/llm_engine.py | 11 ++++++++--- vllm/entrypoints/llm.py | 6 ++++++ vllm/inputs/data.py | 13 +++++++++++-- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f77b73af8e904..98996da158073 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -42,9 +42,8 @@ get_tokenizer_group) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import (Counter, - is_embedding_model_config, - is_encoder_decoder_model_config,) +from vllm.utils import (Counter, is_embedding_model_config, + is_encoder_decoder_model_config) from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -1182,3 +1181,9 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: seq_span.set_attribute( SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) + + def is_encoder_decoder_model(self): + return is_encoder_decoder_model_config(self.model_config) + + def is_embedding_model(self): + return is_embedding_model_config(self.model_config) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index dd1cf7125fa43..7f1df0cfb092c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -594,3 +594,9 @@ def _run_engine( # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) + + def _is_encoder_decoder_model(self): + return self.llm_engine.is_encoder_decoder_model() + + def _is_embedding_model(self): + return self.llm_engine.is_embedding_model() diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index fb06ebac9340b..3b726e3fe4216 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -175,7 +175,11 @@ def get_single_prompt_type(prompt: AllPromptInputs, ) -> str: if is_dict(prompt): for ptype in required_keys_dict: - if has_required_keys(prompt, required_keys_dict[ptype]): + # Ignore type checking in the conditional below because type + # checker does not understand that is_dict(prompt) narrows + # down the possible types + if has_required_keys(prompt, + required_keys_dict[ptype]): # type: ignore return ptype raise ValueError(f"Invalid prompt {prompt}, valid types are " @@ -191,9 +195,14 @@ def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs, ) -> bool: Return True if prompt has the correct structure for an encoder/decoder prompt. """ + # Ignore type checking in the conditional below because type checker + # does not understand that + # get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' narrows + # down the possible types if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' and (prompt['encoder_prompt'] is None - or prompt['decoder_prompt']['multi_modal_data'] is not None)): + or prompt['decoder_prompt']['multi_modal_data'] + is not None)): # type: ignore # For explicit encoder/decoder prompts, encoder prompt # must be non-None and decoder prompt must be free of # multi-modal data (which should instead be passed to From 7b0803b1bb9fbf222be2b719729b3494ade79087 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 07:41:25 -0400 Subject: [PATCH 333/443] formatting? --- vllm/inputs/data.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 3b726e3fe4216..e1d983f0d5d25 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -178,8 +178,9 @@ def get_single_prompt_type(prompt: AllPromptInputs, ) -> str: # Ignore type checking in the conditional below because type # checker does not understand that is_dict(prompt) narrows # down the possible types - if has_required_keys(prompt, - required_keys_dict[ptype]): # type: ignore + if has_required_keys( + prompt, # type: ignore + required_keys_dict[ptype]): return ptype raise ValueError(f"Invalid prompt {prompt}, valid types are " @@ -199,10 +200,10 @@ def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs, ) -> bool: # does not understand that # get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' narrows # down the possible types - if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' - and (prompt['encoder_prompt'] is None - or prompt['decoder_prompt']['multi_modal_data'] - is not None)): # type: ignore + if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' and + (prompt['encoder_prompt'] is None # type: ignore + or prompt['decoder_prompt']['multi_modal_data'] # type: ignore + is not None)): # For explicit encoder/decoder prompts, encoder prompt # must be non-None and decoder prompt must be free of # multi-modal data (which should instead be passed to From 552551137b19a9e9c2ebc13856c8e5a22834ae1b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 08:51:18 -0400 Subject: [PATCH 334/443] Sequence may be constructed with encoder/decoder LLMInput configurations --- vllm/inputs/__init__.py | 8 +++++-- vllm/inputs/data.py | 16 +++++++++++++- vllm/sequence.py | 46 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index d094156962955..18202275b77e7 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,6 +1,8 @@ from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, PromptStrictInputs, TextPrompt, TextTokensPrompt, - TokensPrompt, parse_and_batch_prompt) + TokensPrompt, parse_and_batch_prompt, + get_single_prompt_type, is_valid_encoder_decoder_prompt, + is_valid_encoder_decoder_llm_inputs) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -15,5 +17,7 @@ __all__ = [ "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", - "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" + "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry", + "get_single_prompt_type", "is_valid_encoder_decoder_prompt", + "is_valid_encoder_decoder_llm_inputs" ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index e1d983f0d5d25..278a776a4c023 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -228,7 +228,7 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - encoder_prompt_token_ids: NotRequired[Optional[List[int]]] + encoder_prompt_token_ids: NotRequired[List[int]] """The token IDs of the encoder prompt.""" encoder_prompt: NotRequired[Optional[str]] @@ -242,3 +242,17 @@ class LLMInputs(TypedDict): Optional multi-modal data to pass to the model, if the model supports it. """ + +def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs, ) -> bool: + """ + Return True if the LLMInputs instance has the correct configuration + for encoder/decoder. + """ + + if ('encoder_prompt_token_ids' in inputs and + inputs['encoder_prompt_token_ids'] is not None): + # Encoder prompt token ids field exists & + # is not None + return True + + return False \ No newline at end of file diff --git a/vllm/sequence.py b/vllm/sequence.py index 1cebf68d463db..d89b40a823b10 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -15,7 +15,9 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - from vllm.inputs import LLMInputs + from vllm.inputs import (LLMInputs, get_single_prompt_type, + is_valid_encoder_decoder_llm_inputs, + ) from vllm.multimodal import MultiModalDataDict from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -251,7 +253,8 @@ def __init__( block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + from_decoder_prompt: bool = True ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -259,6 +262,15 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request + self.from_decoder_prompt = True + self._prompt = None + self._prompt_token_ids = None + + if not (from_decoder_prompt or + is_valid_encoder_decoder_llm_inputs(inputs)): + raise ValueError("Cannot extract encoder input prompt from " + f"invalid input {inputs}; did you forget the " + "encoder input prompt fields?") self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -279,11 +291,37 @@ def n_blocks(self) -> int: @property def prompt(self) -> Optional[str]: - return self.inputs.get("prompt") + if self._prompt is not None: + # Reuse precomputed prompt string + return self._prompt + + # Select decoder or encoder input prompt str, + # as appropriate + if self.from_decoder_prompt: + prompt_key = "prompt" + else: + prompt_key = "encoder_prompt" + + # Cache prompt + self._prompt = self.inputs.get(prompt_key) + return self._prompt @property def prompt_token_ids(self) -> List[int]: - return self.inputs["prompt_token_ids"] + if self._prompt_token_ids is not None: + # Reuse precomputed prompt token ids + return self._prompt_token_ids + + # Select decoder or encoder input prompt + # token ids, as appropriate + if self.from_decoder_prompt: + prompt_key = "prompt_token_ids" + else: + prompt_key = "encoder_prompt_token_ids" + + # Cache computed prompt token ids + self._prompt_token_ids = self.inputs.get(prompt_key) + return self._prompt_token_ids @property def multi_modal_data(self) -> "MultiModalDataDict": From dd4031c8e3201ee2e874e40df69c1bd52e7c54be Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 09:11:34 -0400 Subject: [PATCH 335/443] wip but having wllm.commit_id error --- tests/core/utils.py | 65 +++++++++++++++++++++++++---------------- vllm/inputs/__init__.py | 6 ++-- vllm/inputs/data.py | 9 +++--- vllm/sequence.py | 44 ++++++++++++---------------- 4 files changed, 66 insertions(+), 58 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index a8dcd90af0fcf..95f251c1d55e0 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -56,24 +56,31 @@ def create_dummy_prompt_encoder_decoder( # and prompt "0 ... block_size". decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + + inputs={ + "encoder_prompt": { + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + }, + "decoder_prompt": { + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "multi_modal_data": None, + } + } decoder_prompt = Sequence(int(request_id), - inputs={ - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "multi_modal_data": None, - }, - block_size=block_size) + inputs=inputs, + block_size=block_size, + from_decoder_prompt=True) - encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) - encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) encoder_prompt = Sequence(int(request_id), - inputs={ - "prompt": encoder_prompt_str, - "prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, - }, - block_size=block_size) + inputs=inputs, + block_size=block_size, + from_decoder_prompt=False) seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams( @@ -139,16 +146,27 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len + inputs = { + "encoder_prompt": { + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + "decoder_prompt": { + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + } + } + seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): + # Construct decoder input sequences seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, + inputs=inputs, block_size=16, + from_decoder_prompt=True ) for i in range(output_len): @@ -158,15 +176,12 @@ def create_seq_group_encoder_decoder( ) seqs.append(seq) - # Encoder sequence + # Encoder input sequence encoder_seq = Sequence( seq_id=seq_id_start + len(seq_output_lens), - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, + inputs=inputs, block_size=16, + from_decoder_prompt=False ) return SequenceGroup(request_id=request_id, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 18202275b77e7..07643d340b26a 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,8 +1,8 @@ from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, PromptStrictInputs, TextPrompt, TextTokensPrompt, - TokensPrompt, parse_and_batch_prompt, - get_single_prompt_type, is_valid_encoder_decoder_prompt, - is_valid_encoder_decoder_llm_inputs) + TokensPrompt, get_single_prompt_type, + is_valid_encoder_decoder_llm_inputs, + is_valid_encoder_decoder_prompt, parse_and_batch_prompt) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 278a776a4c023..176d3711dcbe4 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -243,16 +243,17 @@ class LLMInputs(TypedDict): if the model supports it. """ + def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs, ) -> bool: """ Return True if the LLMInputs instance has the correct configuration for encoder/decoder. """ - if ('encoder_prompt_token_ids' in inputs and - inputs['encoder_prompt_token_ids'] is not None): + if ('encoder_prompt_token_ids' in inputs + and inputs['encoder_prompt_token_ids'] is not None): # Encoder prompt token ids field exists & # is not None return True - - return False \ No newline at end of file + + return False diff --git a/vllm/sequence.py b/vllm/sequence.py index d89b40a823b10..f35400f03c5f8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -15,9 +15,7 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - from vllm.inputs import (LLMInputs, get_single_prompt_type, - is_valid_encoder_decoder_llm_inputs, - ) + from vllm.inputs import LLMInputs, is_valid_encoder_decoder_llm_inputs from vllm.multimodal import MultiModalDataDict from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -246,16 +244,14 @@ class Sequence: """ - def __init__( - self, - seq_id: int, - inputs: "LLMInputs", - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - from_decoder_prompt: bool = True - ) -> None: + def __init__(self, + seq_id: int, + inputs: "LLMInputs", + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + from_decoder_prompt: bool = True) -> None: self.seq_id = seq_id self.inputs = inputs self.block_size = block_size @@ -263,11 +259,11 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request self.from_decoder_prompt = True - self._prompt = None - self._prompt_token_ids = None + self._prompt: Optional[str] = None + self._prompt_token_ids: Optional[List[int]] = None - if not (from_decoder_prompt or - is_valid_encoder_decoder_llm_inputs(inputs)): + if not (from_decoder_prompt + or is_valid_encoder_decoder_llm_inputs(inputs)): raise ValueError("Cannot extract encoder input prompt from " f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") @@ -297,10 +293,8 @@ def prompt(self) -> Optional[str]: # Select decoder or encoder input prompt str, # as appropriate - if self.from_decoder_prompt: - prompt_key = "prompt" - else: - prompt_key = "encoder_prompt" + prompt_key = ("prompt" + if self.from_decoder_prompt else "encoder_prompt") # Cache prompt self._prompt = self.inputs.get(prompt_key) @@ -314,14 +308,12 @@ def prompt_token_ids(self) -> List[int]: # Select decoder or encoder input prompt # token ids, as appropriate - if self.from_decoder_prompt: - prompt_key = "prompt_token_ids" - else: - prompt_key = "encoder_prompt_token_ids" + prompt_key = ("prompt_token_ids" if self.from_decoder_prompt else + "encoder_prompt_token_ids") # Cache computed prompt token ids self._prompt_token_ids = self.inputs.get(prompt_key) - return self._prompt_token_ids + return self._prompt_token_ids @property def multi_modal_data(self) -> "MultiModalDataDict": From 8dccaa510a67e8de71811c13371468024843b71d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 09:34:14 -0400 Subject: [PATCH 336/443] correctly constructing enc/dec sequences --- tests/core/utils.py | 35 +++++++++++++---------------------- vllm/sequence.py | 4 +++- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 95f251c1d55e0..dc69a6bc95cd2 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -53,23 +53,19 @@ def create_dummy_prompt_encoder_decoder( block_size = decoder_prompt_length # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". + # and prompt "0 ... block_size". Note that the prompt string + # doesn't actually match the tokens decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) inputs={ - "encoder_prompt": { - "prompt": encoder_prompt_str, - "prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, - }, - "decoder_prompt": { - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "multi_modal_data": None, - } + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "encoder_prompt": encoder_prompt_str, + "encoder_prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, } decoder_prompt = Sequence(int(request_id), @@ -146,17 +142,12 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len - inputs = { - "encoder_prompt": { - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, - "decoder_prompt": { - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - } + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "encoder_prompt": "", + "encoder_prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, } seqs = [] diff --git a/vllm/sequence.py b/vllm/sequence.py index f35400f03c5f8..64429292187dd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,8 +14,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.inputs import is_valid_encoder_decoder_llm_inputs + if TYPE_CHECKING: - from vllm.inputs import LLMInputs, is_valid_encoder_decoder_llm_inputs + from vllm.inputs import LLMInputs from vllm.multimodal import MultiModalDataDict from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics From 336a77d62d2d31a2ed6c9eba9e36190b50cca713 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 09:34:47 -0400 Subject: [PATCH 337/443] formatting --- tests/core/utils.py | 24 ++++++++++-------------- vllm/sequence.py | 3 +-- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index dc69a6bc95cd2..45a8e74e85324 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -60,7 +60,7 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - inputs={ + inputs = { "prompt": decoder_prompt_str, "prompt_token_ids": decoder_prompt_tokens, "encoder_prompt": encoder_prompt_str, @@ -142,7 +142,7 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len - inputs={ + inputs = { "prompt": "", "prompt_token_ids": prompt_token_ids, "encoder_prompt": "", @@ -153,12 +153,10 @@ def create_seq_group_encoder_decoder( seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): # Construct decoder input sequences - seq = Sequence( - seq_id=seq_id_start + seq_id_offset, - inputs=inputs, - block_size=16, - from_decoder_prompt=True - ) + seq = Sequence(seq_id=seq_id_start + seq_id_offset, + inputs=inputs, + block_size=16, + from_decoder_prompt=True) for i in range(output_len): seq.append_token_id( @@ -168,12 +166,10 @@ def create_seq_group_encoder_decoder( seqs.append(seq) # Encoder input sequence - encoder_seq = Sequence( - seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs, - block_size=16, - from_decoder_prompt=False - ) + encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens), + inputs=inputs, + block_size=16, + from_decoder_prompt=False) return SequenceGroup(request_id=request_id, seqs=seqs, diff --git a/vllm/sequence.py b/vllm/sequence.py index 64429292187dd..b9233d223cbb6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,13 +9,12 @@ import torch +from vllm.inputs import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.inputs import is_valid_encoder_decoder_llm_inputs - if TYPE_CHECKING: from vllm.inputs import LLMInputs from vllm.multimodal import MultiModalDataDict From 46397c74e7c094d86d4f49fc3230cb92985d5fc5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 13:30:21 -0400 Subject: [PATCH 338/443] wip --- vllm/engine/llm_engine.py | 44 ++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 98996da158073..d6b6fbfebe6cb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,12 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs +from vllm.inputs import (INPUT_REGISTRY, + LLMInputs, + PromptInputs, + get_single_prompt_type, + is_valid_encoder_decoder_prompt, + ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -552,24 +557,33 @@ def process_model_inputs( if isinstance(inputs, str): inputs = {"prompt": inputs} - if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + + ptype = get_single_prompt_type(inputs) - prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=inputs["prompt"], - lora_request=lora_request) else: - prompt_token_ids = inputs["prompt_token_ids"] + # Decoder-only operation + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] - if prompt_adapter_request: - prompt_token_ids = \ - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - + prompt_token_ids + if prompt_adapter_request: + prompt_token_ids = \ + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + + prompt_token_ids - llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) return self.input_processor(llm_inputs) From 251f899ea158af33ffe1367c57137ac9ed9212ad Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 15 Jul 2024 16:33:10 -0400 Subject: [PATCH 339/443] wip --- vllm/engine/llm_engine.py | 49 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d6b6fbfebe6cb..82abdc8d39a27 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,8 +1,19 @@ import time from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional -from typing import Sequence as GenericSequence -from typing import Set, Type, TypeVar, Union +from typing import (TYPE_CHECKING, + Any, + ClassVar, + Dict, + Iterable, + List, + Optional, + Set, + Type, + TypeVar, + Union, + Tuple, + Sequence as GenericSequence, + ) from transformers import PreTrainedTokenizer @@ -547,6 +558,31 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() + _LLMInputComponentsType = Tuple[str, List[int],] + + # def _process_single_decoder_prompt_to_llm_input_components(self,inputs: PromptInputs, + # ptype: str, + # request_id: str, + # lora_request: Optional[LoRARequest] = None, + # is_decoder_prompt: bool = True, + # ) -> _LLMInputComponentsType: + # assert ptype != "ExplicitEncoderDecoder" + + # if "prompt_token_ids" in inputs: + # prompt_token_ids = inputs["prompt_token_ids"] + # else: + + # tokenizer = self.get_tokenizer_group("prompts must be None if " + # "skip_tokenizer_init is True") + + # prompt_token_ids = tokenizer.encode(request_id=request_id, + # prompt=inputs["prompt"], + # lora_request=lora_request) + + # prompt = ( + # inputs["prompt"] if "prompt" in inputs else None + # ) + def process_model_inputs( self, request_id: str, @@ -563,6 +599,13 @@ def process_model_inputs( ptype = get_single_prompt_type(inputs) + if ptype == "ExplicitEncoderDecoder": + # User supplied a + pass + else: + # + pass + else: # Decoder-only operation From ddaf0ade21142daafc504df83e15d31911dee497 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 09:28:21 -0400 Subject: [PATCH 340/443] wip --- examples/offline_inference_encoder_decoder.py | 6 +- vllm/engine/llm_engine.py | 81 +++++++------------ vllm/model_executor/models/bart.py | 16 ++-- vllm/worker/worker.py | 13 ++- 4 files changed, 55 insertions(+), 61 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 737221506dbd6..0afaca8a4e164 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -25,7 +25,11 @@ "", ] # - Unified prompts -prompts = [enc_dec for enc_dec in zip(encoder_prompts, decoder_prompts)] +prompts = [{ + "encoder_prompt": encoder_prompt, + "decoder_prompt": decoder_prompt +} for (encoder_prompt, decoder_prompt) in zip(encoder_prompts, decoder_prompts) + ] print(prompts) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 82abdc8d39a27..af1f5295a0c46 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,19 +1,8 @@ import time from contextlib import contextmanager -from typing import (TYPE_CHECKING, - Any, - ClassVar, - Dict, - Iterable, - List, - Optional, - Set, - Type, - TypeVar, - Union, - Tuple, - Sequence as GenericSequence, - ) +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Type, TypeVar, Union from transformers import PreTrainedTokenizer @@ -33,12 +22,8 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, - LLMInputs, - PromptInputs, - get_single_prompt_type, - is_valid_encoder_decoder_prompt, - ) +from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs, + get_single_prompt_type) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -558,30 +543,15 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - _LLMInputComponentsType = Tuple[str, List[int],] + _LLMInputComponentsType = Tuple[str, List[int], ] - # def _process_single_decoder_prompt_to_llm_input_components(self,inputs: PromptInputs, - # ptype: str, - # request_id: str, - # lora_request: Optional[LoRARequest] = None, - # is_decoder_prompt: bool = True, - # ) -> _LLMInputComponentsType: - # assert ptype != "ExplicitEncoderDecoder" - - # if "prompt_token_ids" in inputs: - # prompt_token_ids = inputs["prompt_token_ids"] - # else: - - # tokenizer = self.get_tokenizer_group("prompts must be None if " - # "skip_tokenizer_init is True") - - # prompt_token_ids = tokenizer.encode(request_id=request_id, - # prompt=inputs["prompt"], - # lora_request=lora_request) - - # prompt = ( - # inputs["prompt"] if "prompt" in inputs else None - # ) + def _get_prompt_token_ids_or_tokenize( + self, + inputs, + request_id, + lora_request, + ) -> List[int]: + return [0] def process_model_inputs( self, @@ -600,18 +570,19 @@ def process_model_inputs( ptype = get_single_prompt_type(inputs) if ptype == "ExplicitEncoderDecoder": - # User supplied a + # User supplied a pass else: - # + # pass else: # Decoder-only operation if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + "prompts must be None if " + "skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], @@ -620,13 +591,15 @@ def process_model_inputs( prompt_token_ids = inputs["prompt_token_ids"] if prompt_adapter_request: - prompt_token_ids = \ - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - + prompt_token_ids - - llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + prompt_token_ids = ( + [0] * + prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + llm_inputs = LLMInputs( + prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) return self.input_processor(llm_inputs) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 3d71e7955af59..b1cb20cfef551 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput logger = logging.get_logger(__name__) @@ -709,10 +709,16 @@ def __init__(self, config.vocab_size) self.sampler = Sampler() - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: r""" Args: input_ids diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 56d8587f8f010..7ec8aff50741d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,8 +19,11 @@ from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest +from vllm.utils import (is_embedding_model_config, + is_encoder_decoder_model_config) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput @@ -85,8 +88,10 @@ def __init__( ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_runner_cls is not None: ModelRunnerClass = model_runner_cls - elif self.model_config.embedding_mode: + elif self._is_embedding_model(): ModelRunnerClass = EmbeddingModelRunner + elif self._is_encoder_decoder_model(): + ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, parallel_config, @@ -107,6 +112,12 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.gpu_cache: Optional[List[List[torch.tensor]]] = None + def _is_encoder_decoder_model(self): + return is_encoder_decoder_model_config(self.model_config) + + def _is_embedding_model(self): + return is_embedding_model_config(self.model_config) + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until From 92d9f486b2455ff5ea5215eb61b9cb1e375b17ff Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 09:33:41 -0400 Subject: [PATCH 341/443] conftest: encoder/decoder example prompts --- tests/conftest.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 17f75d948c543..6a4ac4740ec5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,6 +116,26 @@ def example_prompts() -> List[str]: return prompts +@pytest.fixture +def example_encoder_decoder_prompts() -> Tuple[List[str], List[str]]: + ''' + Returns an encoder prompt list and a decoder prompt list, wherein each pair + of same-index entries in both lists corresponds to an (encoder prompt, + decoder prompt) tuple. + + Returns: + * Encoder prompt list + * Decoder prompt list (reverse of encoder prompt list) + ''' + encoder_prompts = [] + for filename in _TEST_PROMPTS: + encoder_prompts += _read_prompts(filename) + + # Encoder prompts, decoder prompts + return encoder_prompts, \ + encoder_prompts[::-1] + + @pytest.fixture def example_long_prompts() -> List[str]: prompts = [] From c5846ac9b31777d131bb0e3af2ad62a74eab1978 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 09:40:46 -0400 Subject: [PATCH 342/443] Hfrunner greedy logprobs limit --- tests/conftest.py | 76 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6a4ac4740ec5a..adeb8a9e88b21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,8 @@ import torch.nn as nn import torch.nn.functional as F from PIL import Image -from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, - AutoTokenizer, BatchEncoding) +from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, AutoTokenizer, BatchEncoding) from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset @@ -175,6 +175,7 @@ def __init__( is_embedding_model: bool = False, is_vision_model: bool = False, is_sparseml_model: bool = False, + is_encoder_decoder_model: bool = False, ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -192,6 +193,8 @@ def __init__( else: if is_vision_model: auto_cls = AutoModelForVision2Seq + elif is_encoder_decoder_model: + auto_cls = AutoModelForSeq2SeqLM elif is_sparseml_model: from sparseml.transformers import SparseAutoModelForCausalLM auto_cls = SparseAutoModelForCausalLM @@ -413,6 +416,75 @@ def generate_greedy_logprobs_limit( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def generate_encoder_decoder_greedy_logprobs_limit( + self, + encoder_decoder_prompts: Tuple[List[str], List[str]], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ''' + Greedy logprobs generation for vLLM encoder/decoder models + ''' + + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for encoder_prompt, decoder_prompt in zip(*encoder_decoder_prompts): + encoder_input_ids = self.tokenizer(encoder_prompt, + return_tensors="pt").input_ids + decoder_input_ids = self.tokenizer(decoder_prompt, + return_tensors="pt").input_ids + output = self.model.generate( + self.wrap_device(encoder_input_ids), + decoder_input_ids=self.wrap_device(decoder_input_ids), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + seq_logprobs: List[torch.Tensor] = [] + for _, decoder_hidden_states in enumerate( + output.decoder_hidden_states): + last_hidden_states = decoder_hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", + None) is not None: + logits += self.model.get_output_embeddings( + ).bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - decoder_input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) From 374880f71d6f81bd2a933b237ff6fa46e0324e6b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 09:49:30 -0400 Subject: [PATCH 343/443] input preparation now includes encoder-oriented input setup: --- vllm/worker/enc_dec_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 3228de930f8a5..ad6fb3cc41e7a 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -235,6 +235,10 @@ def prepare_model_input( """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) + + model_input = self._prepare_encoder_model_input_tensors( + seq_group_metadata_list, model_input) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, From 42ac66b469891ba3085eaa1265c2bd9d445e0839 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 09:59:04 -0400 Subject: [PATCH 344/443] VllmRunner encoder/decoder methods --- tests/conftest.py | 63 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index adeb8a9e88b21..beb43eea64344 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ destroy_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import cuda_device_count_stateless, is_cpu @@ -565,6 +566,22 @@ def generate( outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs + def _final_steps_generate_w_logprobs(self, + req_outputs: List[RequestOutput]) \ + -> List[ + Tuple[List[int], + str, + Optional[ + SampleLogprobs]]]: + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + def generate_w_logprobs( self, prompts: List[str], @@ -583,14 +600,25 @@ def generate_w_logprobs( req_outputs = self.model.generate(inputs, sampling_params=sampling_params) - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] - for req_output in req_outputs: - for sample in req_output.outputs: - output_str = sample.text - output_ids = sample.token_ids - output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) - return outputs + return self._final_steps_generate_w_logprobs(req_outputs) + + def generate_encoder_decoder_w_logprobs( + self, + encoder_decoder_prompts: Tuple[List[str], List[str]], + sampling_params: SamplingParams, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ''' + Logprobs generation for vLLM encoder/decoder models + ''' + + assert sampling_params.logprobs is not None + + prompt_inputs = list( + zip(encoder_decoder_prompts[0], encoder_decoder_prompts[1])) + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) + return self._final_steps_generate_w_logprobs(req_outputs) def generate_greedy( self, @@ -620,6 +648,25 @@ def generate_greedy_logprobs( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def generate_encoder_decoder_greedy_logprobs( + self, + encoder_decoder_prompts: Tuple[List[str], List[str]], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs) + ''' + Greedy logprobs generation for vLLM encoder/decoder models + ''' + + outputs = self.generate_encoder_decoder_w_logprobs( + encoder_decoder_prompts, greedy_logprobs_params) + + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def generate_beam_search( self, prompts: List[str], From 850a97e812662645452989341eb44b79aa4b3276 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 10:25:38 -0400 Subject: [PATCH 345/443] bart parallel vocab --- vllm/model_executor/models/bart.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index b1cb20cfef551..1be4813c6d242 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -31,6 +31,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -102,18 +104,17 @@ def forward(self, input_ids: torch.Tensor, attn_type: AttentionType, return super().forward(positions + self.offset) -class BartScaledWordEmbedding(nn.Embedding): +class BartScaledWordEmbedding(VocabParallelEmbedding): """ - This module overrides nn.Embeddings' + This module overrides VocabParallelEmbedding's forward by multiplying with embeddings scale. """ def __init__(self, num_embeddings: int, embedding_dim: int, - padding_idx: int, embed_scale: Optional[float] = 1.0): - super().__init__(num_embeddings, embedding_dim, padding_idx) + super().__init__(num_embeddings, embedding_dim) self.embed_scale = embed_scale def forward(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -472,13 +473,11 @@ def __init__(self, self.quant_config = quant_config self.lora_config = lora_config embed_dim = config.d_model - self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, embed_dim, - self.padding_idx, embed_scale=embed_scale) if embed_tokens is not None: @@ -554,14 +553,12 @@ def __init__( self.cache_config = cache_config self.quant_config = quant_config self.lora_config = lora_config - self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt( config.d_model) if config.scale_embedding else 1.0 self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, config.d_model, - self.padding_idx, embed_scale=embed_scale) if embed_tokens is not None: @@ -703,7 +700,11 @@ def __init__(self, if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + self.lm_head = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -742,7 +743,7 @@ def forward( def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits From 3c7e19d3d0e4c53ca363f40712fe2df160be1d9e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 10:44:23 -0400 Subject: [PATCH 346/443] zip enc/dec prompts; formatting --- examples/offline_inference_encoder_decoder.py | 9 ++---- tests/conftest.py | 6 ++-- vllm/core/scheduler.py | 8 +++-- vllm/inputs/__init__.py | 29 ++++++++++++++----- vllm/utils.py | 20 +++++++++++++ 5 files changed, 53 insertions(+), 19 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 0afaca8a4e164..85d927e79635f 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -1,6 +1,7 @@ from transformers import AutoTokenizer, BartForConditionalGeneration from vllm import LLM, SamplingParams +from vllm.utils import zip_enc_dec_prompt_lists dtype = "float" @@ -24,12 +25,8 @@ "", "", ] -# - Unified prompts -prompts = [{ - "encoder_prompt": encoder_prompt, - "decoder_prompt": decoder_prompt -} for (encoder_prompt, decoder_prompt) in zip(encoder_prompts, decoder_prompts) - ] +# - Unified encoder/decoder prompts +prompts = zip_enc_dec_prompt_lists(encoder_prompts, decoder_prompts) print(prompts) diff --git a/tests/conftest.py b/tests/conftest.py index beb43eea64344..9982ce2e74181 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,8 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs -from vllm.utils import cuda_device_count_stateless, is_cpu +from vllm.utils import (cuda_device_count_stateless, is_cpu, + zip_enc_dec_prompt_lists) logger = init_logger(__name__) @@ -133,8 +134,7 @@ def example_encoder_decoder_prompts() -> Tuple[List[str], List[str]]: encoder_prompts += _read_prompts(filename) # Encoder prompts, decoder prompts - return encoder_prompts, \ - encoder_prompts[::-1] + return zip_enc_dec_prompt_lists(encoder_prompts, encoder_prompts[::-1]) @pytest.fixture diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 72262d433ac91..be5555358149f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -268,6 +268,7 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, + is_encoder_decoder=False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -276,6 +277,8 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config + self.is_encoder_decoder = is_encoder_decoder + version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" @@ -391,7 +394,8 @@ def _free_seq_group( Free a sequence group from a cross-attention block table. Has no effect on decoder-only models. """ - self.block_manager.free_cross(seq_group) + if self.is_encoder_decoder: + self.block_manager.free_cross(seq_group) def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( @@ -1089,7 +1093,7 @@ def free_finished_seq_groups(self) -> None: for seq_group in queue: if seq_group.is_finished(): new_finished_requests_ids += seq_group.request_id - # Free cross-attention block table, kf it exists + # Free cross-attention block table, if it exists self._free_seq_group(seq_group) self._finished_requests_ids += new_finished_requests_ids self.running = deque(seq_group for seq_group in self.running diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 07643d340b26a..9fe13590d5c9b 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,6 +1,7 @@ -from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, - PromptStrictInputs, TextPrompt, TextTokensPrompt, - TokensPrompt, get_single_prompt_type, +from .data import (ExplicitEncoderDecoderPrompt, + ExplicitEncoderDecoderPromptStrict, LLMInputs, ParsedText, + ParsedTokens, PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, get_single_prompt_type, is_valid_encoder_decoder_llm_inputs, is_valid_encoder_decoder_prompt, parse_and_batch_prompt) from .registry import InputContext, InputRegistry @@ -15,9 +16,21 @@ """ __all__ = [ - "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", - "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", - "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry", - "get_single_prompt_type", "is_valid_encoder_decoder_prompt", - "is_valid_encoder_decoder_llm_inputs" + "ParsedText", + "ParsedTokens", + "parse_and_batch_prompt", + "TextPrompt", + "TokensPrompt", + "TextTokensPrompt", + "PromptStrictInputs", + "PromptInputs", + "LLMInputs", + "INPUT_REGISTRY", + "InputContext", + "InputRegistry", + "get_single_prompt_type", + "is_valid_encoder_decoder_prompt", + "is_valid_encoder_decoder_llm_inputs", + "ExplicitEncoderDecoderPromptStrict", + "ExplicitEncoderDecoderPrompt", ] diff --git a/vllm/utils.py b/vllm/utils.py index 80922d7e9744c..8152d86b4effb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -27,6 +27,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.inputs import ExplicitEncoderDecoderPrompt, PromptInputs from vllm.logger import enable_trace_function_call, init_logger logger = init_logger(__name__) @@ -998,3 +999,22 @@ def is_embedding_model_config(model_config) -> bool: ''' return False if model_config is None else \ model_config.embedding_mode + + +def build_explicit_enc_dec_prompt( + encoder_prompt: PromptInputs, + decoder_prompt: PromptInputs, +) -> ExplicitEncoderDecoderPrompt: + return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt) + + +def zip_enc_dec_prompt_lists( + enc_prompt_list: List[PromptInputs], + dec_prompt_list: List[PromptInputs], +) -> List[ExplicitEncoderDecoderPrompt]: + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) + ] From e534ffc156479d1b4dbec905ccc0877b746cc068 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 13:25:27 -0400 Subject: [PATCH 347/443] wip --- vllm/engine/llm_engine.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index af1f5295a0c46..1c3f560a05ab6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -553,6 +553,20 @@ def _get_prompt_token_ids_or_tokenize( ) -> List[int]: return [0] + def _tokenize_prompt(self, + request_id, + inputs, + prompt, + lora_request, + ) -> List[int]: + tokenizer = self.get_tokenizer_group( + "prompts must be None if " + "skip_tokenizer_init is True") + + return tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + def process_model_inputs( self, request_id: str, From 97d81f0a53506cf6292f24117e8ecbfca5803805 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 14:17:09 -0400 Subject: [PATCH 348/443] encoder/decoder input processing; formatting --- vllm/engine/llm_engine.py | 211 +++++++++++++++++++++++++++++--------- vllm/inputs/__init__.py | 4 +- vllm/inputs/data.py | 4 +- 3 files changed, 167 insertions(+), 52 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c3f560a05ab6..1c30277291bfd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -23,9 +23,10 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs, - get_single_prompt_type) + get_prompt_type) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -545,29 +546,137 @@ def stop_remote_worker_execution_loop(self) -> None: _LLMInputComponentsType = Tuple[str, List[int], ] - def _get_prompt_token_ids_or_tokenize( + def _tokenize_prompt( self, - inputs, request_id, + inputs, lora_request, ) -> List[int]: - return [0] - - def _tokenize_prompt(self, - request_id, - inputs, - prompt, - lora_request, - ) -> List[int]: - tokenizer = self.get_tokenizer_group( - "prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") return tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], lora_request=lora_request) - def process_model_inputs( + def _extract_single_prompt( + self, + request_id: str, + inputs: PromptInputs, + lora_request: Optional[LoRARequest], + is_encoder_prompt: bool = False, + ) -> Tuple[str, List[int], Optional["MultiModalDataDict"]]: + ''' + Extract prompt & prompt_token_ids from any single + encoder or decoder input prompt. For encoder input prompts + in particular, also extract multi-modal data. + + Arguments: + + * request_id + * inputs: single encoder or decoder input prompt + * lora_request + * is_encoder_prompt: True if encoder input prompt + + Returns: + * prompt + * prompt_token_ids + * multi_modal_data (None if is_encoder_prompt) + ''' + + if isinstance(inputs, str): + # prompt = inputs + # prompt_token_ids = tokenize(inputs) + # no multi-modal data + return (inputs, + self._tokenize_prompt( + request_id, + inputs, + lora_request, + ), None) + + # Tokenize + prompt_token_ids = (inputs["prompt_token_ids"] + if inputs["prompt_token_ids"] else + self._tokenize_prompt( + request_id, + inputs, + lora_request, + )) + + if is_encoder_prompt: + # Only care about multi-modal data associated + # with the encoder prompt + return (inputs.get('prompt'), prompt_token_ids, + inputs.get("multi_modal_data")) + else: + # Assume there is no decoder multi-modal data + return (inputs.get('prompt'), prompt_token_ids, None) + + def _get_default_decoder_prompt( + self, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> Tuple[str, List[int]]: + prompt = "" + prompt_token_ids = (self._tokenize_prompt( + request_id, + prompt, + lora_request, + )) + return prompt, prompt_token_ids + + def _process_encoder_decoder_prompt(self, request_id: str, + inputs: PromptInputs, + lora_request: Optional[LoRARequest]): + ptype = get_prompt_type(inputs) + + # Obtain encoder prompt + ( + encoder_prompt, + encoder_prompt_token_ids, + multi_modal_data, + ) = self._extract_single_prompt( + request_id, + (inputs.get('encoder_prompt') if get_prompt_type(inputs) + == "ExplicitEncoderDecoder" else inputs), + lora_request, + is_encoder_prompt=True, + ) + + # Obtain decoder prompt + if ptype == "ExplicitEncoderDecoder": + # User supplied a dict with explicit + # encoder and decoder prompts; extract + # decoder prompt + ( + decoder_prompt, + decoder_prompt_token_ids, + _, + ) = self._extract_single_prompt( + request_id, + inputs.get('decoder_prompt'), + lora_request, + is_encoder_prompt=False, + ) + else: + # User supplied a single prompt (implicitly + # the encoder prompt) so default decoder + # prompt must be inferred + ( + decoder_prompt, + decoder_prompt_token_ids, + ) = self._get_default_decoder_prompt(request_id, lora_request) + + return LLMInputs( + prompt_token_ids=decoder_prompt_token_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt=encoder_prompt, + multi_modal_data=multi_modal_data, + ) + + def _process_decoder_only_prompt( self, request_id: str, inputs: PromptInputs, @@ -577,45 +686,51 @@ def process_model_inputs( if isinstance(inputs, str): inputs = {"prompt": inputs} - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - - ptype = get_single_prompt_type(inputs) - - if ptype == "ExplicitEncoderDecoder": - # User supplied a - pass - else: - # - pass - + if "prompt_token_ids" not in inputs: + prompt_token_ids = self._tokenize_prompt( + request_id, + inputs, + lora_request, + ) else: - # Decoder-only operation + prompt_token_ids = inputs["prompt_token_ids"] - if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group( - "prompts must be None if " - "skip_tokenizer_init is True") + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) - prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=inputs["prompt"], - lora_request=lora_request) - else: - prompt_token_ids = inputs["prompt_token_ids"] + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) - if prompt_adapter_request: - prompt_token_ids = ( - [0] * - prompt_adapter_request.prompt_adapter_num_virtual_tokens + - prompt_token_ids) + def process_model_inputs( + self, + request_id: str, + inputs: PromptInputs, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: - llm_inputs = LLMInputs( - prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self.input_processor( + self._process_encoder_decoder_prompt( + request_id, + inputs, + lora_request, + )) - return self.input_processor(llm_inputs) + else: + # Decoder-only operation + return self.input_processor( + self._process_decoder_only_prompt( + request_id, + inputs, + lora_request, + prompt_adapter_request, + )) def add_request( self, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 9fe13590d5c9b..b04a68e36b02d 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ from .data import (ExplicitEncoderDecoderPrompt, ExplicitEncoderDecoderPromptStrict, LLMInputs, ParsedText, ParsedTokens, PromptInputs, PromptStrictInputs, TextPrompt, - TextTokensPrompt, TokensPrompt, get_single_prompt_type, + TextTokensPrompt, TokensPrompt, get_prompt_type, is_valid_encoder_decoder_llm_inputs, is_valid_encoder_decoder_prompt, parse_and_batch_prompt) from .registry import InputContext, InputRegistry @@ -28,7 +28,7 @@ "INPUT_REGISTRY", "InputContext", "InputRegistry", - "get_single_prompt_type", + "get_prompt_type", "is_valid_encoder_decoder_prompt", "is_valid_encoder_decoder_llm_inputs", "ExplicitEncoderDecoderPromptStrict", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 176d3711dcbe4..d9459bb44d1d2 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -160,7 +160,7 @@ class ExplicitEncoderDecoderPromptStrict(TypedDict): """All possible input prompt options, strict or non-strict""" -def get_single_prompt_type(prompt: AllPromptInputs, ) -> str: +def get_prompt_type(prompt: AllPromptInputs, ) -> str: """ Get the type-name of the prompt argument instance, given that isinstance() cannot apply to TypedDict subclasses directly. @@ -200,7 +200,7 @@ def is_valid_encoder_decoder_prompt(prompt: AllPromptInputs, ) -> bool: # does not understand that # get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' narrows # down the possible types - if (get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' and + if (get_prompt_type(prompt) == 'ExplicitEncoderDecoder' and (prompt['encoder_prompt'] is None # type: ignore or prompt['decoder_prompt']['multi_modal_data'] # type: ignore is not None)): From 713d095b4036404f4580225720da17d7d4e431cb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 14:49:17 -0400 Subject: [PATCH 349/443] incorporated encoder sequence into request-add functionality --- tests/conftest.py | 11 ++++------- vllm/engine/llm_engine.py | 28 ++++++++++++++++++++++------ vllm/utils.py | 8 ++++++++ 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 95548a7066d55..f4f29cd27d07a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import (cuda_device_count_stateless, is_cpu, - zip_enc_dec_prompt_lists) + to_enc_dec_tuple_list, zip_enc_dec_prompt_lists) logger = init_logger(__name__) @@ -426,7 +426,8 @@ def generate_encoder_decoder_greedy_logprobs_limit( all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] - for encoder_prompt, decoder_prompt in zip(*encoder_decoder_prompts): + for (encoder_prompt, + decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): encoder_input_ids = self.tokenizer(encoder_prompt, return_tensors="pt").input_ids decoder_input_ids = self.tokenizer(decoder_prompt, @@ -607,11 +608,7 @@ def generate_encoder_decoder_w_logprobs( ''' assert sampling_params.logprobs is not None - - prompt_inputs = list( - zip(encoder_decoder_prompts[0], encoder_decoder_prompts[1])) - - req_outputs = self.model.generate(prompt_inputs, + req_outputs = self.model.generate(encoder_decoder_prompts, sampling_params=sampling_params) return self._final_steps_generate_w_logprobs(req_outputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c30277291bfd..d3e38e7fbc37d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -511,6 +511,16 @@ def _add_processed_request( seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) + encoder_seq = None + if 'encoder_prompt_token_ids' in processed_inputs: + encoder_seq = Sequence(seq_id, + processed_inputs, + block_size, + eos_token_id, + lora_request, + prompt_adapter_request, + from_decoder_prompt=False) + # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): seq_group = self._create_sequence_group_with_sampling( @@ -520,7 +530,8 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -528,7 +539,8 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -549,14 +561,14 @@ def stop_remote_worker_execution_loop(self) -> None: def _tokenize_prompt( self, request_id, - inputs, + prompt, lora_request, ) -> List[int]: tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") return tokenizer.encode(request_id=request_id, - prompt=inputs["prompt"], + prompt=prompt, lora_request=lora_request) def _extract_single_prompt( @@ -815,6 +827,7 @@ def _create_sequence_group_with_sampling( lora_request: Optional[LoRARequest], trace_headers: Optional[Dict[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -840,7 +853,8 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) return seq_group @@ -852,6 +866,7 @@ def _create_sequence_group_with_pooling( arrival_time: float, lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -863,7 +878,8 @@ def _create_sequence_group_with_pooling( arrival_time=arrival_time, lora_request=lora_request, pooling_params=pooling_params, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: diff --git a/vllm/utils.py b/vllm/utils.py index 8152d86b4effb..b8b9e8dc7c244 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1018,3 +1018,11 @@ def zip_enc_dec_prompt_lists( for (encoder_prompt, decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) ] + + +def to_enc_dec_tuple_list( + enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], +) -> List[Tuple[PromptInputs, PromptInputs]]: + return [(enc_dec_prompt['encoder_prompt'], + enc_dec_prompt['decoder_prompt']) + for enc_dec_prompt in enc_dec_prompts] From 159c7bcf47aa86e4abbd88ad72a34e196c56626e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 21:58:15 -0400 Subject: [PATCH 350/443] fixed decoder-only bug --- vllm/engine/llm_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d3e38e7fbc37d..3d946b5dd7d7f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -696,15 +696,17 @@ def _process_decoder_only_prompt( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): + prompt = inputs inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: prompt_token_ids = self._tokenize_prompt( request_id, - inputs, + prompt, lora_request, ) else: + prompt = inputs.get("prompt") prompt_token_ids = inputs["prompt_token_ids"] if prompt_adapter_request: @@ -713,7 +715,7 @@ def _process_decoder_only_prompt( + prompt_token_ids) return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), + prompt=prompt, multi_modal_data=inputs.get("multi_modal_data")) def process_model_inputs( From 16c9aa2278e7f9d9b5f5ccffb085b0142a7e20ec Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 16 Jul 2024 22:36:44 -0400 Subject: [PATCH 351/443] bugfix --- vllm/engine/llm_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3d946b5dd7d7f..bc2cf251c0342 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -696,8 +696,8 @@ def _process_decoder_only_prompt( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): - prompt = inputs inputs = {"prompt": inputs} + prompt = inputs.get("prompt") if "prompt_token_ids" not in inputs: prompt_token_ids = self._tokenize_prompt( @@ -706,7 +706,6 @@ def _process_decoder_only_prompt( lora_request, ) else: - prompt = inputs.get("prompt") prompt_token_ids = inputs["prompt_token_ids"] if prompt_adapter_request: From 03aea187652fc0418d9a66f7eb5af6bc53c9e535 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 00:34:45 -0400 Subject: [PATCH 352/443] wip --- vllm/engine/llm_engine.py | 121 +++++++++++++++++++++++++++++++++++++- vllm/utils.py | 1 - 2 files changed, 118 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bc2cf251c0342..3b5332720a45c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,5 @@ import time +import torch from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence @@ -493,6 +494,30 @@ def _get_eos_token_id( return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + def _get_decoder_start_token_id( + self, + ) -> Optional[int]: + ''' + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + ''' + + if not self.is_encoder_decoder_model(): + logger.warning("Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if (self.model_config is None or + self.model_config.hf_config is None): + logger.warning("Using None for decoder start token id because " + "model config is not available.") + return None + + return getattr(self.model_config.hf_config, + 'decoder_start_token_id', + None) + def _add_processed_request( self, request_id: str, @@ -558,18 +583,108 @@ def stop_remote_worker_execution_loop(self) -> None: _LLMInputComponentsType = Tuple[str, List[int], ] + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + #model_kwargs: Dict[str, torch.Tensor], + decoder_input_ids: Union[List[int], torch.Tensor], + decoder_start_token_id: Union[int, List[int], torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. + + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + """ + + # Cast decoder_start_token_id to torch.Tensor, if not already + if isinstance(decoder_start_token_id,int): + decoder_start_token_id=torch.tensor([decoder_start_token_id], dtype=torch.int) + elif isinstance(decoder_start_token_id,list): + assert len(decoder_start_token_id) > 0 + assert isinstance(decoder_start_token_id[0],int) + decoder_start_token_id=torch.tensor(decoder_start_token_id, dtype=torch.int) + + # Cast decoder_input_ids to torch.Tensor, if not already + + # # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + # if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + # decoder_input_ids = model_kwargs.pop("decoder_input_ids") + # elif "input_ids" in model_kwargs and model_input_name != "input_ids": + # decoder_input_ids = model_kwargs.pop("input_ids") + # else: + # decoder_input_ids = None + + # 2. `decoder_start_token_id` must have shape (batch_size, 1) + # if device is None: + # device = self.device + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: + raise ValueError( + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" + ) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) + else: + decoder_start_token_id = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) + + # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_start_token_id + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the + # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. + # See: https://github.com/huggingface/transformers/pull/31470 + elif "donut" in self.__class__.__name__.lower() or ( + self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() + ): + pass + elif self.config.model_type in ["whisper"]: + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): + decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs + def _tokenize_prompt( self, request_id, prompt, lora_request, + is_enc_dec_decoder=False, ) -> List[int]: tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") - return tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + if is_enc_dec_decoder: + # Tokenizer decoder prompt *in the context + # of an encoder/decoder model* + pass + + # Decoder-only tokenized prompt + return prompt_token_ids def _extract_single_prompt( self, diff --git a/vllm/utils.py b/vllm/utils.py index b8b9e8dc7c244..7b256512fd283 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -980,7 +980,6 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) - def is_encoder_decoder_model_config(model_config) -> bool: ''' Extract the HF encoder/decoder model flag from the ModelConfig instance. From ef80c85f7dd3febc9c76c793427c444f9e62caa6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 00:35:57 -0400 Subject: [PATCH 353/443] wip --- vllm/engine/llm_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3b5332720a45c..8df3c8cb7914f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -613,6 +613,8 @@ def _prepare_decoder_input_ids_for_generation( decoder_start_token_id=torch.tensor(decoder_start_token_id, dtype=torch.int) # Cast decoder_input_ids to torch.Tensor, if not already + if isinstance(decoder_input_ids,list): + assert (len(decoder_input_ids)==0 or # # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. From f8dd4a5955ec478720531c47945ddc26e450f743 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 00:43:52 -0400 Subject: [PATCH 354/443] fixed scheduler bug --- vllm/core/scheduler.py | 5 +- vllm/engine/llm_engine.py | 164 +++++++++++++++++++------------------- 2 files changed, 83 insertions(+), 86 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index be5555358149f..e72bd6f1303e9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -268,7 +268,6 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, - is_encoder_decoder=False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -277,8 +276,6 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - self.is_encoder_decoder = is_encoder_decoder - version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" @@ -394,7 +391,7 @@ def _free_seq_group( Free a sequence group from a cross-attention block table. Has no effect on decoder-only models. """ - if self.is_encoder_decoder: + if seq_group.is_encoder_decoder(): self.block_manager.free_cross(seq_group) def has_unfinished_seqs(self) -> bool: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8df3c8cb7914f..fbca4481cf6c0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -583,88 +583,88 @@ def stop_remote_worker_execution_loop(self) -> None: _LLMInputComponentsType = Tuple[str, List[int], ] - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - model_input_name: str, - #model_kwargs: Dict[str, torch.Tensor], - decoder_input_ids: Union[List[int], torch.Tensor], - decoder_start_token_id: Union[int, List[int], torch.Tensor], - device: torch.device = None, - ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: - """ - Prepares `decoder_input_ids` for generation with encoder-decoder models. - - Based on - - https://github.com/huggingface/transformers/blob/ - 4037a2b5b1278736e566aec12e169100275545ea/ - src/transformers/generation/utils.py - - specifically GenerationMixin._prepare_decoder_input_ids_for_generation() - """ - - # Cast decoder_start_token_id to torch.Tensor, if not already - if isinstance(decoder_start_token_id,int): - decoder_start_token_id=torch.tensor([decoder_start_token_id], dtype=torch.int) - elif isinstance(decoder_start_token_id,list): - assert len(decoder_start_token_id) > 0 - assert isinstance(decoder_start_token_id[0],int) - decoder_start_token_id=torch.tensor(decoder_start_token_id, dtype=torch.int) - - # Cast decoder_input_ids to torch.Tensor, if not already - if isinstance(decoder_input_ids,list): - assert (len(decoder_input_ids)==0 or - - # # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, - # # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. - # if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - # decoder_input_ids = model_kwargs.pop("decoder_input_ids") - # elif "input_ids" in model_kwargs and model_input_name != "input_ids": - # decoder_input_ids = model_kwargs.pop("input_ids") - # else: - # decoder_input_ids = None - - # 2. `decoder_start_token_id` must have shape (batch_size, 1) - # if device is None: - # device = self.device - if decoder_start_token_id.ndim == 1: - if decoder_start_token_id.shape[0] != batch_size: - raise ValueError( - f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" - ) - decoder_start_token_id = decoder_start_token_id.view(-1, 1) - else: - decoder_start_token_id = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id - ) - - # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - # no user input -> use decoder_start_token_id as decoder_input_ids - if decoder_input_ids is None: - decoder_input_ids = decoder_start_token_id - # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the - # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. - # See: https://github.com/huggingface/transformers/pull/31470 - elif "donut" in self.__class__.__name__.lower() or ( - self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() - ): - pass - elif self.config.model_type in ["whisper"]: - pass - # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust - # decoder_attention_mask if provided) - elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): - decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - decoder_attention_mask = torch.cat( - (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), - dim=-1, - ) - model_kwargs["decoder_attention_mask"] = decoder_attention_mask - - return decoder_input_ids, model_kwargs + # def _prepare_decoder_input_ids_for_generation( + # self, + # batch_size: int, + # model_input_name: str, + # #model_kwargs: Dict[str, torch.Tensor], + # decoder_input_ids: Union[List[int], torch.Tensor], + # decoder_start_token_id: Union[int, List[int], torch.Tensor], + # device: torch.device = None, + # ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + # """ + # Prepares `decoder_input_ids` for generation with encoder-decoder models. + + # Based on + + # https://github.com/huggingface/transformers/blob/ + # 4037a2b5b1278736e566aec12e169100275545ea/ + # src/transformers/generation/utils.py + + # specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + # """ + + # # Cast decoder_start_token_id to torch.Tensor, if not already + # if isinstance(decoder_start_token_id,int): + # decoder_start_token_id=torch.tensor([decoder_start_token_id], dtype=torch.int) + # elif isinstance(decoder_start_token_id,list): + # assert len(decoder_start_token_id) > 0 + # assert isinstance(decoder_start_token_id[0],int) + # decoder_start_token_id=torch.tensor(decoder_start_token_id, dtype=torch.int) + + # # Cast decoder_input_ids to torch.Tensor, if not already + # if isinstance(decoder_input_ids,list): + # assert (len(decoder_input_ids)==0 or + + # # # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # # # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + # # if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + # # decoder_input_ids = model_kwargs.pop("decoder_input_ids") + # # elif "input_ids" in model_kwargs and model_input_name != "input_ids": + # # decoder_input_ids = model_kwargs.pop("input_ids") + # # else: + # # decoder_input_ids = None + + # # 2. `decoder_start_token_id` must have shape (batch_size, 1) + # # if device is None: + # # device = self.device + # if decoder_start_token_id.ndim == 1: + # if decoder_start_token_id.shape[0] != batch_size: + # raise ValueError( + # f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" + # ) + # decoder_start_token_id = decoder_start_token_id.view(-1, 1) + # else: + # decoder_start_token_id = ( + # torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + # ) + + # # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + # # no user input -> use decoder_start_token_id as decoder_input_ids + # if decoder_input_ids is None: + # decoder_input_ids = decoder_start_token_id + # # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the + # # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. + # # See: https://github.com/huggingface/transformers/pull/31470 + # elif "donut" in self.__class__.__name__.lower() or ( + # self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() + # ): + # pass + # elif self.config.model_type in ["whisper"]: + # pass + # # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # # decoder_attention_mask if provided) + # elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): + # decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) + # if "decoder_attention_mask" in model_kwargs: + # decoder_attention_mask = model_kwargs["decoder_attention_mask"] + # decoder_attention_mask = torch.cat( + # (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + # dim=-1, + # ) + # model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + # return decoder_input_ids, model_kwargs def _tokenize_prompt( self, From c2ff615deebea4457721a457103d8e405346b1a5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 00:44:16 -0400 Subject: [PATCH 355/443] format --- vllm/engine/llm_engine.py | 14 +++++--------- vllm/utils.py | 1 + 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fbca4481cf6c0..4d6736af8f50b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -494,28 +494,24 @@ def _get_eos_token_id( return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - def _get_decoder_start_token_id( - self, - ) -> Optional[int]: + def _get_decoder_start_token_id(self, ) -> Optional[int]: ''' Obtain the decoder start token id employed by an encoder/decoder model. Returns None for non-encoder/decoder models or if the model config is unavailable. ''' - + if not self.is_encoder_decoder_model(): logger.warning("Using None for decoder start token id because " "this is not an encoder/decoder model.") return None - if (self.model_config is None or - self.model_config.hf_config is None): + if (self.model_config is None or self.model_config.hf_config is None): logger.warning("Using None for decoder start token id because " "model config is not available.") return None - return getattr(self.model_config.hf_config, - 'decoder_start_token_id', + return getattr(self.model_config.hf_config, 'decoder_start_token_id', None) def _add_processed_request( @@ -614,7 +610,7 @@ def stop_remote_worker_execution_loop(self) -> None: # # Cast decoder_input_ids to torch.Tensor, if not already # if isinstance(decoder_input_ids,list): - # assert (len(decoder_input_ids)==0 or + # assert (len(decoder_input_ids)==0 or # # # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # # # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. diff --git a/vllm/utils.py b/vllm/utils.py index 7b256512fd283..b8b9e8dc7c244 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -980,6 +980,7 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) + def is_encoder_decoder_model_config(model_config) -> bool: ''' Extract the HF encoder/decoder model flag from the ModelConfig instance. From 1c6e06d0be66bf8cbf98cc8429a060b60bb65700 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 02:10:12 -0400 Subject: [PATCH 356/443] bugfix --- vllm/engine/llm_engine.py | 176 +++++++++++++++++++------------------- 1 file changed, 90 insertions(+), 86 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4d6736af8f50b..e4ecdaf4fa622 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,10 +1,10 @@ import time -import torch from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, TypeVar, Union +import torch from transformers import PreTrainedTokenizer from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, @@ -579,88 +579,73 @@ def stop_remote_worker_execution_loop(self) -> None: _LLMInputComponentsType = Tuple[str, List[int], ] - # def _prepare_decoder_input_ids_for_generation( - # self, - # batch_size: int, - # model_input_name: str, - # #model_kwargs: Dict[str, torch.Tensor], - # decoder_input_ids: Union[List[int], torch.Tensor], - # decoder_start_token_id: Union[int, List[int], torch.Tensor], - # device: torch.device = None, - # ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: - # """ - # Prepares `decoder_input_ids` for generation with encoder-decoder models. - - # Based on - - # https://github.com/huggingface/transformers/blob/ - # 4037a2b5b1278736e566aec12e169100275545ea/ - # src/transformers/generation/utils.py - - # specifically GenerationMixin._prepare_decoder_input_ids_for_generation() - # """ - - # # Cast decoder_start_token_id to torch.Tensor, if not already - # if isinstance(decoder_start_token_id,int): - # decoder_start_token_id=torch.tensor([decoder_start_token_id], dtype=torch.int) - # elif isinstance(decoder_start_token_id,list): - # assert len(decoder_start_token_id) > 0 - # assert isinstance(decoder_start_token_id[0],int) - # decoder_start_token_id=torch.tensor(decoder_start_token_id, dtype=torch.int) - - # # Cast decoder_input_ids to torch.Tensor, if not already - # if isinstance(decoder_input_ids,list): - # assert (len(decoder_input_ids)==0 or - - # # # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, - # # # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. - # # if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - # # decoder_input_ids = model_kwargs.pop("decoder_input_ids") - # # elif "input_ids" in model_kwargs and model_input_name != "input_ids": - # # decoder_input_ids = model_kwargs.pop("input_ids") - # # else: - # # decoder_input_ids = None - - # # 2. `decoder_start_token_id` must have shape (batch_size, 1) - # # if device is None: - # # device = self.device - # if decoder_start_token_id.ndim == 1: - # if decoder_start_token_id.shape[0] != batch_size: - # raise ValueError( - # f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" - # ) - # decoder_start_token_id = decoder_start_token_id.view(-1, 1) - # else: - # decoder_start_token_id = ( - # torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id - # ) - - # # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - # # no user input -> use decoder_start_token_id as decoder_input_ids - # if decoder_input_ids is None: - # decoder_input_ids = decoder_start_token_id - # # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the - # # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. - # # See: https://github.com/huggingface/transformers/pull/31470 - # elif "donut" in self.__class__.__name__.lower() or ( - # self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() - # ): - # pass - # elif self.config.model_type in ["whisper"]: - # pass - # # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust - # # decoder_attention_mask if provided) - # elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): - # decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) - # if "decoder_attention_mask" in model_kwargs: - # decoder_attention_mask = model_kwargs["decoder_attention_mask"] - # decoder_attention_mask = torch.cat( - # (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), - # dim=-1, - # ) - # model_kwargs["decoder_attention_mask"] = decoder_attention_mask - - # return decoder_input_ids, model_kwargs + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[Union[List[int], torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. + + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + """ + + decoder_start_token_id: Optional[ + Union[int, List[int], + torch.Tensor]] = (self._get_decoder_start_token_id()) + + # Cast decoder_start_token_id to torch.Tensor, if not already + if isinstance(decoder_start_token_id, int): + decoder_start_token_id = torch.tensor([decoder_start_token_id], + dtype=torch.int) + elif isinstance(decoder_start_token_id, list): + assert len(decoder_start_token_id) > 0 + assert isinstance(decoder_start_token_id[0], int) + decoder_start_token_id = torch.tensor(decoder_start_token_id, + dtype=torch.int) + else: + assert isinstance(decoder_start_token_id, torch.Tensor) + + decoder_start_token_id = decoder_start_token_id.view(-1, 1) + + # Cast decoder_input_ids to torch.Tensor, if not already + originally_list = False + if isinstance(decoder_input_ids, list): + assert (len(decoder_input_ids) == 0 + or isinstance(decoder_input_ids[0], int)) + decoder_input_ids = torch.tensor(decoder_input_ids, + dtype=torch.int) + originally_list = True + + if decoder_input_ids is not None: + assert isinstance(decoder_input_ids, torch.Tensor) + # Reshape: (batch_size=1,num_tokens) + decoder_input_ids = decoder_input_ids.view(1, -1) + + if decoder_input_ids is None: + # no user input -> use decoder_start_token_id as decoder_input_ids + decoder_input_ids = decoder_start_token_id + elif (decoder_input_ids[:, 0] != + decoder_start_token_id[:, 0]).all().item(): + # Encoder-decoder models expect the `decoder_input_ids` to start + # with a special token. Let's ensure that. + decoder_input_ids = (torch.cat( + [decoder_start_token_id, decoder_input_ids], + dim=-1, + )) + + assert isinstance(decoder_input_ids, torch.Tensor) + decoder_input_ids = decoder_input_ids.view(-1) + + if originally_list: + return decoder_input_ids.tolist() + else: + return decoder_input_ids def _tokenize_prompt( self, @@ -672,14 +657,26 @@ def _tokenize_prompt( tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") + if is_enc_dec_decoder and prompt == "": + # Scenario: enc/dec model, decoder input prompt is "" + # => Treat it as None (no decoder input prompt provided) + # & obtain default decoder input prompt + return self._prepare_decoder_input_ids_for_generation() + + # Scenario: + # * Any decoder-only input prompt + # * Enc/dec model, non-empty-str decoder input prompt + # => Tokenize prompt prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=prompt, lora_request=lora_request) if is_enc_dec_decoder: - # Tokenizer decoder prompt *in the context - # of an encoder/decoder model* - pass + # Scenario: enc/dec model, non-empty-str decoder input prompt + # which was just tokenized + # => perform decoder-specific preprocessing + return self._prepare_decoder_input_ids_for_generation( + prompt_token_ids, ) # Decoder-only tokenized prompt return prompt_token_ids @@ -690,6 +687,7 @@ def _extract_single_prompt( inputs: PromptInputs, lora_request: Optional[LoRARequest], is_encoder_prompt: bool = False, + is_enc_dec_model: bool = False, ) -> Tuple[str, List[int], Optional["MultiModalDataDict"]]: ''' Extract prompt & prompt_token_ids from any single @@ -709,6 +707,8 @@ def _extract_single_prompt( * multi_modal_data (None if is_encoder_prompt) ''' + is_enc_dec_decoder = ((not is_encoder_prompt) and is_enc_dec_model) + if isinstance(inputs, str): # prompt = inputs # prompt_token_ids = tokenize(inputs) @@ -718,6 +718,7 @@ def _extract_single_prompt( request_id, inputs, lora_request, + is_enc_dec_decoder=is_enc_dec_decoder, ), None) # Tokenize @@ -727,6 +728,7 @@ def _extract_single_prompt( request_id, inputs, lora_request, + is_enc_dec_decoder=is_enc_dec_decoder, )) if is_encoder_prompt: @@ -767,6 +769,7 @@ def _process_encoder_decoder_prompt(self, request_id: str, == "ExplicitEncoderDecoder" else inputs), lora_request, is_encoder_prompt=True, + is_enc_dec_model=True, ) # Obtain decoder prompt @@ -783,6 +786,7 @@ def _process_encoder_decoder_prompt(self, request_id: str, inputs.get('decoder_prompt'), lora_request, is_encoder_prompt=False, + is_enc_dec_model=True, ) else: # User supplied a single prompt (implicitly From aee5f1615347dcfe2acea9abe16ac61df3404a99 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 06:14:51 -0400 Subject: [PATCH 357/443] fixed sequence bug --- vllm/model_executor/models/bart.py | 29 ----------------------------- vllm/sequence.py | 2 +- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 1be4813c6d242..5d7d9845556e3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -818,35 +818,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): "Conflicting embedding weights.") shared_embedding_weight = loaded_weight shared_embedding_shard_id = shard_id - - # encoder_in_param = model_params_dict[ - # 'encoder.embed_tokens.weight'] - # encoder_in_weight_loader = getattr(encoder_in_param, - # "weight_loader", - # default_weight_loader) - - # decoder_in_param = model_params_dict[ - # 'decoder.embed_tokens.weight'] - # decoder_in_weight_loader = getattr(decoder_in_param, - # "weight_loader", - # default_weight_loader) - - # lm_head_in_param = top_params_dict['lm_head.weight'] - # lm_head_in_weight_loader = getattr(lm_head_in_param, - # "weight_loader", - # default_weight_loader) - - # if shard_id: - # encoder_in_weight_loader(encoder_in_param, loaded_weight, - # shard_id) - # decoder_in_weight_loader(decoder_in_param, loaded_weight, - # shard_id) - # lm_head_in_weight_loader(lm_head_in_param, loaded_weight, - # shard_id) - # else: - # encoder_in_weight_loader(encoder_in_param, loaded_weight) - # decoder_in_weight_loader(decoder_in_param, loaded_weight) - # lm_head_in_weight_loader(lm_head_in_param, loaded_weight) else: # Skip the specific downstream task weight. if name.startswith('cls.'): diff --git a/vllm/sequence.py b/vllm/sequence.py index b9233d223cbb6..408c02184dc12 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -259,7 +259,7 @@ def __init__(self, self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.from_decoder_prompt = True + self.from_decoder_prompt = from_decoder_prompt self._prompt: Optional[str] = None self._prompt_token_ids: Optional[List[int]] = None From ef94623218a718a437526917a8c95e933d614ee9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 07:16:10 -0400 Subject: [PATCH 358/443] added examples utils w/ context manager for backend override; applied to enc/dec example to force XFormers --- examples/offline_inference_encoder_decoder.py | 42 ++++++++++-------- examples/utils.py | 44 +++++++++++++++++++ 2 files changed, 68 insertions(+), 18 deletions(-) create mode 100644 examples/utils.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 85d927e79635f..819fd46f4a572 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -1,7 +1,8 @@ from transformers import AutoTokenizer, BartForConditionalGeneration +from utils import override_backend_env_var_context_manager from vllm import LLM, SamplingParams -from vllm.utils import zip_enc_dec_prompt_lists +from vllm.utils import STR_XFORMERS_ATTN_VAL, zip_enc_dec_prompt_lists dtype = "float" @@ -30,24 +31,29 @@ print(prompts) -# Create a sampling params object. -sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=20, -) +with override_backend_env_var_context_manager(STR_XFORMERS_ATTN_VAL): + # Force usage of XFormers backend which supports + # encoder attention & encoder/decoder cross-attention -# Create an LLM. -llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype) -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + # Create a sampling params object. + sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, + ) + + # Create an LLM. + llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype) + # Generate texts from the prompts. The output is a list of + # RequestOutput objects that contain the prompt, generated + # text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") diff --git a/examples/utils.py b/examples/utils.py new file mode 100644 index 0000000000000..6705a81c82945 --- /dev/null +++ b/examples/utils.py @@ -0,0 +1,44 @@ +'''Example code utils''' + +from vllm.utils import STR_BACKEND_ENV_VAR +from typing import Generator + +import os +from contextlib import contextmanager + +@contextmanager +def override_backend_env_var_context_manager(backend_name: str, + ) -> Generator[None, None, None]: + ''' + Override the environment variable indicating the vLLM backend temporarily, + in a context where pytest monkeypatch is not available (i.e. *outside* + the context of a unit test, such as in an example code file.) + + Accomplish this using a custom context manager. + + Arguments: + + * backend_name: attention backend name to force + + Returns: + + * Generator + ''' + + key = STR_BACKEND_ENV_VAR + + # Save the current state of the environment variable (if it exists) + original_value = os.environ.get(key, None) + + # Set the new value of the environment variable + os.environ[key] = backend_name + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the environment variable to its original state + if original_value is None: + os.environ.pop(key, None) # Remove the variable if it wasn't originally set + else: + os.environ[key] = original_value # Revert back to the original value \ No newline at end of file From b277180575d7d9c85708e2622cc6c32afbc0a383 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 07:17:40 -0400 Subject: [PATCH 359/443] formatting --- examples/utils.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/utils.py b/examples/utils.py index 6705a81c82945..497bee7592a09 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,14 +1,15 @@ '''Example code utils''' -from vllm.utils import STR_BACKEND_ENV_VAR -from typing import Generator - import os from contextlib import contextmanager +from typing import Generator + +from vllm.utils import STR_BACKEND_ENV_VAR + @contextmanager -def override_backend_env_var_context_manager(backend_name: str, - ) -> Generator[None, None, None]: +def override_backend_env_var_context_manager( + backend_name: str, ) -> Generator[None, None, None]: ''' Override the environment variable indicating the vLLM backend temporarily, in a context where pytest monkeypatch is not available (i.e. *outside* @@ -29,16 +30,18 @@ def override_backend_env_var_context_manager(backend_name: str, # Save the current state of the environment variable (if it exists) original_value = os.environ.get(key, None) - + # Set the new value of the environment variable os.environ[key] = backend_name - + # Yield control back to the enclosed code block try: yield finally: # Revert the environment variable to its original state if original_value is None: - os.environ.pop(key, None) # Remove the variable if it wasn't originally set + os.environ.pop( + key, None) # Remove the variable if it wasn't originally set else: - os.environ[key] = original_value # Revert back to the original value \ No newline at end of file + os.environ[ + key] = original_value # Revert back to the original value From cac6283f60f1edc55950eaae54e74db0902ebfd8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 07:25:58 -0400 Subject: [PATCH 360/443] added encoder/decoder example to examples test --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 445d74d6d9bbe..186d0472f1cb7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -144,6 +144,7 @@ steps: - python3 llm_engine_example.py - python3 llava_example.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference_encoder_decoder.py - label: Inputs Test #mirror_hardwares: [amd] From f54f2762f4b4d14165371e3dfc300f1ef3afa9b6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 07:53:12 -0400 Subject: [PATCH 361/443] wip refactoring --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index f4f29cd27d07a..dec1f1815b2e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,6 +121,7 @@ def example_encoder_decoder_prompts() -> Tuple[List[str], List[str]]: decoder prompt) tuple. Returns: + * Encoder prompt list * Decoder prompt list (reverse of encoder prompt list) ''' From 597a07da54fa4c399e42bccbb4a14957d782e37c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 07:59:42 -0400 Subject: [PATCH 362/443] refactor --- .../test_encoder_decoder_model_runner.py | 25 +------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 51b5f3b58c500..cc621ed485173 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -457,27 +457,4 @@ def test_prepare_decode(batch_size, backend_name, enforce_eager, monkeypatch): torch.testing.assert_close( encoder_input_tokens, encoder_input_positions, - ) - - # sampling_metadata = SamplingMetadata.prepare( - # seq_group_metadata_list, - # seq_lens, - # query_lens=seq_lens, - # device=model_runner.device, - # pin_memory=model_runner.pin_memory, - # ) - - # actual = sampling_metadata.selected_token_indices - # expected = torch.tensor( - # expected_selected_token_indices, - # device=actual.device, - # dtype=actual.dtype, - # ) - # torch.testing.assert_close(actual, expected) - # torch.allclose(input_tokens, input_positions) - - # actual = sampling_metadata.selected_token_indices - # expected = torch.tensor(expected_selected_token_indices, - # device=actual.device, - # dtype=actual.dtype) - # torch.testing.assert_close(actual, expected) + ) \ No newline at end of file From 9f5a02c21e785704114f8c15bb829f4fe4cded55 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 08:27:53 -0400 Subject: [PATCH 363/443] RequestOutput & SequenceGroup now include encoder prompt in output, as does encoder/decoder example. --- examples/offline_inference_encoder_decoder.py | 5 ++++- vllm/outputs.py | 12 +++++++++++- vllm/sequence.py | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 819fd46f4a572..93ca6a432e58b 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -52,8 +52,11 @@ # Print the outputs. for output in outputs: prompt = output.prompt + encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Encoder prompt: {encoder_prompt!r}, " + f"Decoder prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") diff --git a/vllm/outputs.py b/vllm/outputs.py index 4cb7f06bdb8c7..085b32b862439 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -88,6 +88,8 @@ def __init__( finished: bool, metrics: Optional[RequestMetrics] = None, lora_request: Optional[LoRARequest] = None, + encoder_prompt: Optional[str] = None, + encoder_prompt_token_ids: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -97,6 +99,8 @@ def __init__( self.finished = finished self.metrics = metrics self.lora_request = lora_request + self.encoder_prompt = encoder_prompt + self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -136,6 +140,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # Every sequence in the sequence group should have the same prompt. prompt = seq_group.prompt prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() finished_time = time.time() if finished else None @@ -147,12 +153,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": outputs, finished, seq_group.metrics, - lora_request=seq_group.lora_request) + lora_request=seq_group.lora_request, + encoder_prompt=encoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"encoder_prompt={self.encoder_prompt!r}, " + f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " f"finished={self.finished}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 408c02184dc12..be2900c775ae4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -501,6 +501,22 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt_token_ids + @property + def encoder_prompt(self) -> Optional[str]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt is distinct + # from the decoder's. + return (self.encoder_seq.prompt + if self.encoder_seq is not None else None) + + @property + def encoder_prompt_token_ids(self) -> Optional[List[int]]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt token ids are + # distinct from the decoder's. + return (self.encoder_seq.prompt_token_ids + if self.encoder_seq is not None else None) + @property def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. From 94c904fb5ff01f7e1c93b8d4a5f195ca2bea5bc0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 08:43:16 -0400 Subject: [PATCH 364/443] wip parallel bart but encountering GPU count issue --- examples/offline_inference_encoder_decoder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 93ca6a432e58b..e52b2d4ef1034 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -44,7 +44,10 @@ ) # Create an LLM. - llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype) + llm = LLM(model="facebook/bart-large-cnn", + enforce_eager=True, + dtype=dtype, + tensor_parallel_size=2) # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. From 1f8c52fac27ed8f10b94a3ecb08e15c1118c186a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 09:34:29 -0400 Subject: [PATCH 365/443] tweaks to enc/dec example --- examples/offline_inference_encoder_decoder.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index e52b2d4ef1034..971950d56d3fc 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -21,7 +21,7 @@ ] # - Decoder prompts decoder_prompts = [ - "", + encoder_prompts[0], "", "", "", @@ -47,7 +47,7 @@ llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype, - tensor_parallel_size=2) + ) # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. @@ -69,9 +69,18 @@ max_length=1024, return_tensors="pt") +# decoder_inputs = tokenizer([''], +# max_length=1024, +# return_tensors="pt") + # Generate Summary -summary_ids = model.generate(inputs["input_ids"], min_length=0, max_length=20) +summary_ids = model.generate(inputs["input_ids"], + # decoder_input_ids=decoder_inputs["input_ids"], + min_length=0, + max_length=20, + ) print( tokenizer.batch_decode(summary_ids, skip_special_tokens=True, - clean_up_tokenization_spaces=False)) + clean_up_tokenization_spaces=False), + ) From 180884605ffd911c553c6b2585c2993204e4a629 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 09:34:42 -0400 Subject: [PATCH 366/443] formatting --- examples/offline_inference_encoder_decoder.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 971950d56d3fc..0811ca1ef8d7c 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -44,10 +44,11 @@ ) # Create an LLM. - llm = LLM(model="facebook/bart-large-cnn", - enforce_eager=True, - dtype=dtype, - ) + llm = LLM( + model="facebook/bart-large-cnn", + enforce_eager=True, + dtype=dtype, + ) # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. @@ -74,13 +75,13 @@ # return_tensors="pt") # Generate Summary -summary_ids = model.generate(inputs["input_ids"], - # decoder_input_ids=decoder_inputs["input_ids"], - min_length=0, - max_length=20, - ) +summary_ids = model.generate( + inputs["input_ids"], + # decoder_input_ids=decoder_inputs["input_ids"], + min_length=0, + max_length=20, +) print( tokenizer.batch_decode(summary_ids, skip_special_tokens=True, - clean_up_tokenization_spaces=False), - ) + clean_up_tokenization_spaces=False), ) From f15eacf140810512335a7ac422b09788a1c1964e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 10:55:46 -0400 Subject: [PATCH 367/443] wip --- tests/conftest.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index dec1f1815b2e0..98eadeecd59a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -433,6 +433,11 @@ def generate_encoder_decoder_greedy_logprobs_limit( return_tensors="pt").input_ids decoder_input_ids = self.tokenizer(decoder_prompt, return_tensors="pt").input_ids + generation_config = GenerationConfig.from_model_config(self.model.config) + generation_config.do_sample = False + generation_config.top_k = None + generation_config.num_beams = 1 + output = self.model.generate( self.wrap_device(encoder_input_ids), decoder_input_ids=self.wrap_device(decoder_input_ids), @@ -441,6 +446,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( max_new_tokens=max_tokens, output_hidden_states=True, return_dict_in_generate=True, + generation_config=generation_config, ) seq_logprobs: List[torch.Tensor] = [] From 6c940f886950ba0ae77ccb9002a161cf95b686ad Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 11:00:34 -0400 Subject: [PATCH 368/443] modified HF behavior in BART test to be truly greedy --- tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 98eadeecd59a7..ee499786d418e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -433,6 +433,8 @@ def generate_encoder_decoder_greedy_logprobs_limit( return_tensors="pt").input_ids decoder_input_ids = self.tokenizer(decoder_prompt, return_tensors="pt").input_ids + + from transformers.generation.configuration_utils import GenerationConfig generation_config = GenerationConfig.from_model_config(self.model.config) generation_config.do_sample = False generation_config.top_k = None From 949ac02c5694069edf3338b2202717dffda276e6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 11:18:01 -0400 Subject: [PATCH 369/443] formatting --- tests/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ee499786d418e..c5d327a377b07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -434,8 +434,10 @@ def generate_encoder_decoder_greedy_logprobs_limit( decoder_input_ids = self.tokenizer(decoder_prompt, return_tensors="pt").input_ids - from transformers.generation.configuration_utils import GenerationConfig - generation_config = GenerationConfig.from_model_config(self.model.config) + from transformers.generation.configuration_utils import ( + GenerationConfig) + generation_config = GenerationConfig.from_model_config( + self.model.config) generation_config.do_sample = False generation_config.top_k = None generation_config.num_beams = 1 From 88c058e8fe5ae00b39f88f57be745d1b819dbca5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 12:23:31 -0400 Subject: [PATCH 370/443] wip parallelizing BART --- examples/offline_inference.py | 31 ++++++++++++------- examples/offline_inference_encoder_decoder.py | 1 + 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..7092ca22eb489 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,4 +1,6 @@ from vllm import LLM, SamplingParams +from vllm.utils import STR_XFORMERS_ATTN_VAL +from utils import override_backend_env_var_context_manager # Sample prompts. prompts = [ @@ -7,16 +9,21 @@ "The capital of France is", "The future of AI is", ] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -# Create an LLM. -llm = LLM(model="facebook/opt-125m") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +with override_backend_env_var_context_manager(STR_XFORMERS_ATTN_VAL): + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Create an LLM. + llm = LLM(model="facebook/opt-125m", + enforce_eager=True, + tensor_parallel_size=4) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 0811ca1ef8d7c..b8826ab8f15d7 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -48,6 +48,7 @@ model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype, + tensor_parallel_size=4, ) # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated From 31e335fd206985f5b3791b6a3cfaa021d21d3629 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 13:03:58 -0400 Subject: [PATCH 371/443] wip activation parallelization --- vllm/model_executor/models/bart.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 5d7d9845556e3..357b434dc6dbd 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -22,7 +22,8 @@ import torch from torch import nn from transformers import BartConfig -from transformers.activations import ACT2FN +#from transformers.activations import ACT2FN +from vllm.model_executor.layers.activation import get_act_fn from transformers.utils import logging from vllm.attention import Attention, AttentionMetadata, AttentionType @@ -39,7 +40,6 @@ logger = logging.get_logger(__name__) - def get_bsz_seq_len(input_ids): shp = input_ids.shape ndim = len(shp) @@ -308,7 +308,8 @@ def __init__( cache_config=cache_config, quant_config=quant_config) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.activation_fn = ACT2FN[config.activation_function] + #self.activation_fn = ACT2FN[config.activation_function] + self.activation_fn = get_act_fn(config.activation_function, quant_config) self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) From 69f0379d24323958dd9b332884f7c57a222acfc6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 17 Jul 2024 13:23:42 -0400 Subject: [PATCH 372/443] wip: --- examples/offline_inference.py | 6 +++--- vllm/model_executor/models/bart.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 7092ca22eb489..c10c99c09dda1 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -16,9 +16,9 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. - llm = LLM(model="facebook/opt-125m", - enforce_eager=True, - tensor_parallel_size=4) + llm = LLM(model="facebook/opt-125m", + enforce_eager=True, + tensor_parallel_size=4) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 357b434dc6dbd..19d94e7a4805a 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -22,7 +22,6 @@ import torch from torch import nn from transformers import BartConfig -#from transformers.activations import ACT2FN from vllm.model_executor.layers.activation import get_act_fn from transformers.utils import logging @@ -40,6 +39,7 @@ logger = logging.get_logger(__name__) + def get_bsz_seq_len(input_ids): shp = input_ids.shape ndim = len(shp) @@ -308,8 +308,8 @@ def __init__( cache_config=cache_config, quant_config=quant_config) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - #self.activation_fn = ACT2FN[config.activation_function] - self.activation_fn = get_act_fn(config.activation_function, quant_config) + self.activation_fn = get_act_fn(config.activation_function, + quant_config) self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) @@ -371,7 +371,8 @@ def __init__( config=config, cache_config=cache_config, quant_config=quant_config) - self.activation_fn = ACT2FN[config.activation_function] + self.activation_fn = get_act_fn(config.activation_function, + quant_config) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) ''' From c00e0a8b561a8243080ef40b1c1b8f0b8257d026 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 00:28:29 -0400 Subject: [PATCH 373/443] CommonMetadataBuilder sets block_tables constructor arg of metadata --- tests/conftest.py | 6 +++--- vllm/attention/__init__.py | 10 ++++------ vllm/attention/backends/utils.py | 1 + 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bd447d0966f29..942fb5ada2395 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,9 +22,9 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs -from vllm.utils import (cuda_device_count_stateless, is_cpu, - to_enc_dec_tuple_list, zip_enc_dec_prompt_lists, - STR_DTYPE_TO_TORCH_DTYPE,) +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, + is_cpu, to_enc_dec_tuple_list, + zip_enc_dec_prompt_lists) logger = init_logger(__name__) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 5b866769db932..4643d316d48b7 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,9 +1,7 @@ -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionMetadata, - AttentionType, - AttentionMetadataBuilder, -) +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index f2833ddde2897..8ed357413603b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -223,6 +223,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int], return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, + block_tables=block_tables, slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, From a16cabb9029d86221a69975935622dd53084a554 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 01:54:22 -0400 Subject: [PATCH 374/443] equalized some generation/sampling config settings between enc/dec HF,vLLM, nonetheless still not perfect match --- tests/conftest.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 942fb5ada2395..a897d9d2cb484 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -435,12 +435,18 @@ def generate_encoder_decoder_greedy_logprobs_limit( generation_config.do_sample = False generation_config.top_k = None generation_config.num_beams = 1 + generation_config.repetition_penalty=1.0 + # generation_config.temperature = 0.0 + generation_config.top_p = 1.0 + # generation_config.min_p = 0.0 + generation_config.length_penalty = 1.0 + generation_config.early_stopping = False output = self.model.generate( self.wrap_device(encoder_input_ids), decoder_input_ids=self.wrap_device(decoder_input_ids), use_cache=True, - do_sample=False, + # do_sample=False, max_new_tokens=max_tokens, output_hidden_states=True, return_dict_in_generate=True, @@ -652,6 +658,7 @@ def generate_encoder_decoder_greedy_logprobs( num_logprobs: int, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, + use_beam_search=False, max_tokens=max_tokens, logprobs=num_logprobs) ''' From 00198a633605b786c5f1fdef007c965d6284b39b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 02:22:01 -0400 Subject: [PATCH 375/443] BART MLPs parallelized --- examples/offline_inference.py | 5 +-- tests/conftest.py | 2 +- vllm/model_executor/models/bart.py | 52 +++++++++++++++++++++++++----- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index c10c99c09dda1..f610c51026857 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -19,8 +19,9 @@ llm = LLM(model="facebook/opt-125m", enforce_eager=True, tensor_parallel_size=4) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. + # Generate texts from the prompts. The output is a list of + # RequestOutput objects that contain the prompt, generated + # text, and other information. outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: diff --git a/tests/conftest.py b/tests/conftest.py index a897d9d2cb484..420370114c549 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -435,7 +435,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( generation_config.do_sample = False generation_config.top_k = None generation_config.num_beams = 1 - generation_config.repetition_penalty=1.0 + generation_config.repetition_penalty = 1.0 # generation_config.temperature = 0.0 generation_config.top_p = 1.0 # generation_config.min_p = 0.0 diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 19d94e7a4805a..ce247be2dcc5e 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -36,6 +36,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) logger = logging.get_logger(__name__) @@ -310,8 +313,24 @@ def __init__( self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = get_act_fn(config.activation_function, quant_config) - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, @@ -336,9 +355,10 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) - hidden_states = self.fc2(hidden_states) + hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -386,8 +406,23 @@ def __init__( config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( @@ -436,9 +471,10 @@ def forward( # Fully Connected residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) - hidden_states = self.fc2(hidden_states) + hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) From fb3227f68714ba6ed00e67e8a242db88288cdb8e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 02:25:12 -0400 Subject: [PATCH 376/443] parallelized BART learned positional embedding --- examples/offline_inference.py | 2 +- vllm/model_executor/models/bart.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index f610c51026857..784587bf5b5a1 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -19,7 +19,7 @@ llm = LLM(model="facebook/opt-125m", enforce_eager=True, tensor_parallel_size=4) - # Generate texts from the prompts. The output is a list of + # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index ce247be2dcc5e..f6154c1566e40 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -52,7 +52,7 @@ def get_bsz_seq_len(input_ids): return shp[:2] -class BartLearnedPositionalEmbedding(nn.Embedding): +class BartLearnedPositionalEmbedding(VocabParallelEmbedding): """ This module learns positional embeddings up to a fixed maximum size. """ From e5bb9de596bd7f4b5d85ab6d0a2440cae06f982a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 02:33:02 -0400 Subject: [PATCH 377/443] all attention layer output linears are parallelized --- vllm/model_executor/models/bart.py | 33 ++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index f6154c1566e40..baa7016f2feff 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -151,7 +151,14 @@ def __init__( self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + out_proj_has_bias = True + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=out_proj_has_bias, + quant_config=quant_config, + ) self.attn = Attention(self.num_heads, self.head_dim, @@ -174,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata, attn_type=AttentionType.ENCODER) - output = self.out_proj(attn_output) + output, _ = self.out_proj(attn_output) return output @@ -205,7 +212,14 @@ def __init__( self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + out_proj_has_bias = True + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=out_proj_has_bias, + quant_config=quant_config, + ) self.attn = Attention(self.num_heads, self.head_dim, @@ -228,7 +242,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata, attn_type=AttentionType.DECODER) - output = self.out_proj(attn_output) + output, _ = self.out_proj(attn_output) return output @@ -259,7 +273,14 @@ def __init__( self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + out_proj_has_bias = True + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=out_proj_has_bias, + quant_config=quant_config, + ) self.attn = Attention(self.num_heads, self.head_dim, @@ -289,7 +310,7 @@ def forward( attn_metadata, attn_type=AttentionType.ENCODER_DECODER) - output = self.out_proj(attn_output) + output, _ = self.out_proj(attn_output) return output From 74abe22287374c9dd801ef059692016ef09777cb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 03:01:07 -0400 Subject: [PATCH 378/443] encoder attention & decoder self-attention parallelized --- examples/offline_inference.py | 3 +- vllm/model_executor/models/bart.py | 129 +++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 36 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 784587bf5b5a1..f15698e8c8be0 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,6 +1,7 @@ +from utils import override_backend_env_var_context_manager + from vllm import LLM, SamplingParams from vllm.utils import STR_XFORMERS_ATTN_VAL -from utils import override_backend_env_var_context_manager # Sample prompts. prompts = [ diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index baa7016f2feff..59b1dc5ed4cb5 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -22,11 +22,15 @@ import torch from torch import nn from transformers import BartConfig -from vllm.model_executor.layers.activation import get_act_fn from transformers.utils import logging from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -36,9 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) logger = logging.get_logger(__name__) @@ -136,9 +137,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + self.d_model = config.d_model self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_kv_heads = self.num_heads + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads self.head_dim = embed_dim // num_heads self.config = config @@ -148,31 +150,58 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) - out_proj_has_bias = True self.out_proj = RowParallelLinear( embed_dim, embed_dim, - bias=out_proj_has_bias, + bias=bias, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.total_num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads, + num_kv_heads=self.total_num_kv_heads, cache_config=cache_config, quant_config=quant_config) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + # q = self.q_proj(hidden_states) + # k = self.k_proj(hidden_states) + # v = self.v_proj(hidden_states) + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) attn_output = self.attn(q, k, @@ -197,9 +226,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + self.d_model = config.d_model self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_kv_heads = self.num_heads + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads self.head_dim = embed_dim // num_heads self.config = config @@ -209,31 +239,58 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) - out_proj_has_bias = True self.out_proj = RowParallelLinear( embed_dim, embed_dim, - bias=out_proj_has_bias, + bias=bias, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.total_num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads, + num_kv_heads=self.total_num_kv_heads, cache_config=cache_config, quant_config=quant_config) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + # q = self.q_proj(hidden_states) + # k = self.k_proj(hidden_states) + # v = self.v_proj(hidden_states) + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) attn_output = self.attn(q, k, @@ -815,15 +872,15 @@ def sample( return next_tokens stacked_params_mapping = { - "query": { + "q_proj": { "param_name": "qkv_proj", "shard_id": "q", }, - "key": { + "k_proj": { "param_name": "qkv_proj", "shard_id": "k", }, - "value": { + "v_proj": { "param_name": "qkv_proj", "shard_id": "v", }, @@ -847,11 +904,13 @@ def _rename_key(self, key: str): def _rename_stacked_param( self, name: str, + cross_attn_keyword: str = 'encoder_attn', ) -> Tuple[str, Optional[str]]: - for key, mapping in self.stacked_params_mapping.items(): - if key in name: - name = name.replace(key, mapping["param_name"]) - return name, mapping["shard_id"] + if cross_attn_keyword not in name: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] return name, None def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From 9bbed43ab159063a8dff27587dae909b11e1a703 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 03:20:20 -0400 Subject: [PATCH 379/443] parallelized LM head --- vllm/model_executor/models/bart.py | 47 ++++++++++++++++++------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 59b1dc5ed4cb5..15e42d9101372 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -117,7 +117,7 @@ class BartScaledWordEmbedding(VocabParallelEmbedding): def __init__(self, num_embeddings: int, embedding_dim: int, - embed_scale: Optional[float] = 1.0): + embed_scale: float = 1.0): super().__init__(num_embeddings, embedding_dim) self.embed_scale = embed_scale @@ -125,6 +125,25 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: return super().forward(input_ids) * self.embed_scale +class BartParallelLMHead(ParallelLMHead): + """ + This module overrides ParallelLMHead's + forward by dividing by embeddings scale, + yielding effectively the inverse of + BartScaledWordEmbedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) / self.embed_scale + + class BartEncoderAttention(nn.Module): def __init__( @@ -150,10 +169,6 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - # self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.qkv_proj = QKVParallelLinear( self.d_model, self.d_model // self.total_num_heads, @@ -196,9 +211,6 @@ def __init__( def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - # q = self.q_proj(hidden_states) - # k = self.k_proj(hidden_states) - # v = self.v_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -239,10 +251,6 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - # self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.qkv_proj = QKVParallelLinear( self.d_model, self.d_model // self.total_num_heads, @@ -285,9 +293,6 @@ def __init__( def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - # q = self.q_proj(hidden_states) - # k = self.k_proj(hidden_states) - # v = self.v_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -818,9 +823,13 @@ def __init__(self, embed_scale = math.sqrt( config.d_model) if config.scale_embedding else 1.0 - self.lm_head = BartScaledWordEmbedding(config.vocab_size, - config.d_model, - embed_scale=embed_scale) + # self.lm_head = BartScaledWordEmbedding(config.vocab_size, + # config.d_model, + # embed_scale=embed_scale) + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) From fdf71de8557d588ff3b5767e96df09de4e9278d5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 03:48:35 -0400 Subject: [PATCH 380/443] parallelized enc/dec cross-attention, using a slight hack --- vllm/model_executor/models/bart.py | 70 +++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 15e42d9101372..39049b33adcd1 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -320,9 +320,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + self.d_model = config.d_model self.embed_dim = embed_dim - self.num_heads = num_heads - self.num_kv_heads = self.num_heads + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads self.head_dim = embed_dim // num_heads self.config = config @@ -332,22 +333,42 @@ def __init__( f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) - out_proj_has_bias = True self.out_proj = RowParallelLinear( embed_dim, embed_dim, - bias=out_proj_has_bias, + bias=bias, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.total_num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads, + num_kv_heads=self.total_num_kv_heads, cache_config=cache_config, quant_config=quant_config) @@ -359,11 +380,22 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - q = self.q_proj(decoder_hidden_states) - k=None if encoder_hidden_states is None else \ - self.k_proj(encoder_hidden_states) - v=None if encoder_hidden_states is None else \ - self.v_proj(encoder_hidden_states) + # q = self.q_proj(decoder_hidden_states) + # k=None if encoder_hidden_states is None else \ + # self.k_proj(encoder_hidden_states) + # v=None if encoder_hidden_states is None else \ + # self.v_proj(encoder_hidden_states) + + # (afeldman-nm 2024/07/22) TODO: + # Need a more efficient solution for q/k/v + qkv_dec, _ = self.qkv_proj(decoder_hidden_states) + q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if encoder_hidden_states is None: + k=None + v=None + else: + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1) attn_output = self.attn(q, k, @@ -913,13 +945,11 @@ def _rename_key(self, key: str): def _rename_stacked_param( self, name: str, - cross_attn_keyword: str = 'encoder_attn', ) -> Tuple[str, Optional[str]]: - if cross_attn_keyword not in name: - for key, mapping in self.stacked_params_mapping.items(): - if key in name: - name = name.replace(key, mapping["param_name"]) - return name, mapping["shard_id"] + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] return name, None def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From 3551b6bf56ab74228c923b698e59a88b06bac81c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 03:59:22 -0400 Subject: [PATCH 381/443] fixed bug where underlying Attention was constructed using full head-count --- vllm/model_executor/models/bart.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 39049b33adcd1..4bf8c5d49610e 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -201,7 +201,7 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.attn = Attention(self.total_num_heads, + self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.total_num_kv_heads, @@ -283,7 +283,7 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.attn = Attention(self.total_num_heads, + self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.total_num_kv_heads, @@ -365,7 +365,7 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.attn = Attention(self.total_num_heads, + self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.total_num_kv_heads, @@ -389,13 +389,15 @@ def forward( # (afeldman-nm 2024/07/22) TODO: # Need a more efficient solution for q/k/v qkv_dec, _ = self.qkv_proj(decoder_hidden_states) - q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) if encoder_hidden_states is None: - k=None - v=None + k = None + v = None else: qkv_enc, _ = self.qkv_proj(encoder_hidden_states) - _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) attn_output = self.attn(q, k, From b174c7ab2da60e24a2ca576eccee671541ae142a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 04:02:56 -0400 Subject: [PATCH 382/443] bart is parallelized, modulo an unfortunate hack for QKVParallelLinear in cross-attention --- vllm/model_executor/models/bart.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 4bf8c5d49610e..48d932c02f87d 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -204,7 +204,7 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.total_num_kv_heads, + num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config) @@ -286,7 +286,7 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.total_num_kv_heads, + num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config) @@ -368,7 +368,7 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.total_num_kv_heads, + num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config) From c43a6ed191e76f81bfd27f25e2ca8bac1fc01bcc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 22 Jul 2024 04:03:59 -0400 Subject: [PATCH 383/443] commented out BART TP=4 --- examples/offline_inference_encoder_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index b8826ab8f15d7..7f8ae57e0ac6e 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -48,7 +48,7 @@ model="facebook/bart-large-cnn", enforce_eager=True, dtype=dtype, - tensor_parallel_size=4, + # tensor_parallel_size=4, ) # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated From c51a1682be7443ec7d32062491868bd49c631eb8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 01:47:43 -0400 Subject: [PATCH 384/443] fixed bug in how conftest was handling HF encoder/decoder outputs; disabled HF engram repeat checks --- tests/conftest.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 420370114c549..045902723eb1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -130,8 +130,11 @@ def example_encoder_decoder_prompts() -> Tuple[List[str], List[str]]: for filename in _TEST_PROMPTS: encoder_prompts += _read_prompts(filename) + prompt_list = zip_enc_dec_prompt_lists(encoder_prompts, + encoder_prompts[::-1]) + # Encoder prompts, decoder prompts - return zip_enc_dec_prompt_lists(encoder_prompts, encoder_prompts[::-1]) + return [prompt_list[1]] @pytest.fixture @@ -417,6 +420,10 @@ def generate_encoder_decoder_greedy_logprobs_limit( Greedy logprobs generation for vLLM encoder/decoder models ''' + # decoder_start_token_id = getattr(self.model.config, + # 'decoder_start_token_id', + # None) + all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] @@ -428,6 +435,13 @@ def generate_encoder_decoder_greedy_logprobs_limit( decoder_input_ids = self.tokenizer(decoder_prompt, return_tensors="pt").input_ids + # # If the decoder input ids do not begin with decoder start + # # token, HF transformers will likely add it automatically. + # # This becomes important information later. + # implicit_decoder_start_token=(True + # if decoder_input_ids.shape[1] < 1 else + # (decoder_input_ids[0][0] ==decoder_start_token_id)) + from transformers.generation.configuration_utils import ( GenerationConfig) generation_config = GenerationConfig.from_model_config( @@ -441,6 +455,8 @@ def generate_encoder_decoder_greedy_logprobs_limit( # generation_config.min_p = 0.0 generation_config.length_penalty = 1.0 generation_config.early_stopping = False + generation_config.no_repeat_ngram_size = None + generation_config.min_length = 0 output = self.model.generate( self.wrap_device(encoder_input_ids), @@ -454,6 +470,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( ) seq_logprobs: List[torch.Tensor] = [] + output_len = len(output.decoder_hidden_states) for _, decoder_hidden_states in enumerate( output.decoder_hidden_states): last_hidden_states = decoder_hidden_states[-1][0] @@ -484,7 +501,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] - output_len = seq_ids.shape[0] - decoder_input_ids.shape[1] + #output_len = seq_ids.shape[0] - decoder_input_ids.shape[1] output_ids = seq_ids[-output_len:] all_output_ids.append(output_ids.tolist()) all_output_strs.append(self.tokenizer.decode(output_ids)) From b01937f0ce29bc9e417e85cb4dd18ddb47a98e3b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 04:14:06 -0400 Subject: [PATCH 385/443] set up None/empty str tests which are not passing --- tests/conftest.py | 28 ++++++++++++++++++++-------- tests/models/test_bart.py | 15 +++++++++++++-- tests/models/utils.py | 11 +++++++++++ 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 045902723eb1e..d753939aad9e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForVision2Seq, AutoTokenizer, BatchEncoding) +from tests.models.utils import DecoderPromptType from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig @@ -115,7 +116,9 @@ def example_prompts() -> List[str]: @pytest.fixture -def example_encoder_decoder_prompts() -> Tuple[List[str], List[str]]: +def example_encoder_decoder_prompts() \ + -> Dict[DecoderPromptType, + Tuple[List[str], List[Optional[str]]]]: ''' Returns an encoder prompt list and a decoder prompt list, wherein each pair of same-index entries in both lists corresponds to an (encoder prompt, @@ -126,15 +129,24 @@ def example_encoder_decoder_prompts() -> Tuple[List[str], List[str]]: * Encoder prompt list * Decoder prompt list (reverse of encoder prompt list) ''' + encoder_prompts = [] for filename in _TEST_PROMPTS: encoder_prompts += _read_prompts(filename) - prompt_list = zip_enc_dec_prompt_lists(encoder_prompts, - encoder_prompts[::-1]) + custom_decoder_prompts = encoder_prompts[::-1] + empty_str_decoder_prompts = [""] * len(encoder_prompts) + none_decoder_prompts = [None] * len(encoder_prompts) - # Encoder prompts, decoder prompts - return [prompt_list[1]] + # NONE decoder prompt type + return { + DecoderPromptType.NONE: + zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts), + DecoderPromptType.EMPTY_STR: + zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts), + DecoderPromptType.CUSTOM: + zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts), + } @pytest.fixture @@ -420,7 +432,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( Greedy logprobs generation for vLLM encoder/decoder models ''' - # decoder_start_token_id = getattr(self.model.config, + # decoder_start_token_id = getattr(self.model.config, # 'decoder_start_token_id', # None) @@ -438,8 +450,8 @@ def generate_encoder_decoder_greedy_logprobs_limit( # # If the decoder input ids do not begin with decoder start # # token, HF transformers will likely add it automatically. # # This becomes important information later. - # implicit_decoder_start_token=(True - # if decoder_input_ids.shape[1] < 1 else + # implicit_decoder_start_token=(True + # if decoder_input_ids.shape[1] < 1 else # (decoder_input_ids[0][0] ==decoder_start_token_id)) from transformers.generation.configuration_utils import ( diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 790ed4f1cbf30..ac09bcf5abdc6 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -12,6 +12,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable + from tests.models.utils import DecoderPromptType from .utils import check_logprobs_close @@ -22,11 +23,17 @@ # Currently only XFormers is supported BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] + DECODER_PROMPT_TYPES = ([ + DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR, + DecoderPromptType.NONE + ]) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) + @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES) def test_models( hf_runner, vllm_runner, @@ -36,6 +43,7 @@ def test_models( max_tokens: int, num_logprobs: int, backend_name: str, + decoder_prompt_type: DecoderPromptType, monkeypatch, ) -> None: # TODO(sang): Sliding window should be tested separately. @@ -43,15 +51,18 @@ def test_models( # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) + test_case_prompts = example_encoder_decoder_prompts[ + decoder_prompt_type] + with hf_runner(model, dtype=dtype, is_encoder_decoder_model=True) as hf_model: hf_outputs = ( hf_model.generate_encoder_decoder_greedy_logprobs_limit( - example_encoder_decoder_prompts, max_tokens, num_logprobs)) + test_case_prompts, max_tokens, num_logprobs)) with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - example_encoder_decoder_prompts, max_tokens, num_logprobs) + test_case_prompts, max_tokens, num_logprobs) check_logprobs_close(outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/utils.py b/tests/models/utils.py index 425f57ef9b966..84e439a6e239d 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,4 +1,5 @@ import warnings +from enum import Enum from typing import Dict, List, Optional, Sequence, Tuple, Union from vllm.sequence import SampleLogprobs @@ -110,3 +111,13 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) + + +class DecoderPromptType(Enum): + ''' + For encoder/decoder models only - + + ''' + CUSTOM = 1 + NONE = 2 + EMPTY_STR = 3 From 059273f3ca43947413572a0014c1437a53e33b8a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 16:56:07 -0400 Subject: [PATCH 386/443] wip --- examples/offline_inference_encoder_decoder.py | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 7f8ae57e0ac6e..56dd149cd6cb1 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -61,28 +61,4 @@ generated_text = output.outputs[0].text print(f"Encoder prompt: {encoder_prompt!r}, " f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") - -model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") -tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn") - -ARTICLE_TO_SUMMARIZE = encoder_prompts[0] -inputs = tokenizer([ARTICLE_TO_SUMMARIZE], - max_length=1024, - return_tensors="pt") - -# decoder_inputs = tokenizer([''], -# max_length=1024, -# return_tensors="pt") - -# Generate Summary -summary_ids = model.generate( - inputs["input_ids"], - # decoder_input_ids=decoder_inputs["input_ids"], - min_length=0, - max_length=20, -) -print( - tokenizer.batch_decode(summary_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False), ) + f"Generated text: {generated_text!r}") \ No newline at end of file From 7e7bbd9e16900449e350bf8634d584e4b1a5c2f0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 16:57:41 -0400 Subject: [PATCH 387/443] deleted unnecessary dependency --- examples/offline_inference_encoder_decoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 56dd149cd6cb1..00faa0e77966c 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -1,4 +1,3 @@ -from transformers import AutoTokenizer, BartForConditionalGeneration from utils import override_backend_env_var_context_manager from vllm import LLM, SamplingParams From aa01d71f90f0c3cda8a7ea419ff4f1fb6dc9d13c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 20:56:51 -0400 Subject: [PATCH 388/443] empty-string decoder input is now handled for encoder/decoder --- vllm/engine/llm_engine.py | 10 +++++----- vllm/model_executor/models/bart.py | 3 --- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4fbfc6a59ad09..4195587b4a1cd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -682,11 +682,11 @@ def _tokenize_prompt( tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") - if is_enc_dec_decoder and prompt == "": - # Scenario: enc/dec model, decoder input prompt is "" - # => Treat it as None (no decoder input prompt provided) - # & obtain default decoder input prompt - return self._prepare_decoder_input_ids_for_generation() + # if is_enc_dec_decoder and prompt == "": + # # Scenario: enc/dec model, decoder input prompt is "" + # # => Treat it as None (no decoder input prompt provided) + # # & obtain default decoder input prompt + # return self._prepare_decoder_input_ids_for_generation() # Scenario: # * Any decoder-only input prompt diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 48d932c02f87d..9740fd5e9f697 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -857,9 +857,6 @@ def __init__(self, embed_scale = math.sqrt( config.d_model) if config.scale_embedding else 1.0 - # self.lm_head = BartScaledWordEmbedding(config.vocab_size, - # config.d_model, - # embed_scale=embed_scale) self.lm_head = BartParallelLMHead(config.vocab_size, config.d_model, From 0b29fd27f17f2751550262f218e6ef1afbef7087 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 21:35:25 -0400 Subject: [PATCH 389/443] enc/dec handles empty str and None decoder prompts correctly --- tests/conftest.py | 15 +++++++++------ vllm/engine/llm_engine.py | 12 ++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2bdd78366fd42..6dda2f07237f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -450,10 +450,13 @@ def generate_encoder_decoder_greedy_logprobs_limit( for (encoder_prompt, decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): - encoder_input_ids = self.tokenizer(encoder_prompt, - return_tensors="pt").input_ids - decoder_input_ids = self.tokenizer(decoder_prompt, - return_tensors="pt").input_ids + encoder_input_ids = self.wrap_device(self.tokenizer(encoder_prompt, + return_tensors="pt").input_ids) + decoder_input_ids = (None if decoder_prompt is None + else + self.wrap_device(self.tokenizer(decoder_prompt, + return_tensors="pt").input_ids) + ) # # If the decoder input ids do not begin with decoder start # # token, HF transformers will likely add it automatically. @@ -479,8 +482,8 @@ def generate_encoder_decoder_greedy_logprobs_limit( generation_config.min_length = 0 output = self.model.generate( - self.wrap_device(encoder_input_ids), - decoder_input_ids=self.wrap_device(decoder_input_ids), + encoder_input_ids, + decoder_input_ids=decoder_input_ids, use_cache=True, # do_sample=False, max_new_tokens=max_tokens, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4195587b4a1cd..3a00245fa49d2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -682,11 +682,11 @@ def _tokenize_prompt( tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") - # if is_enc_dec_decoder and prompt == "": - # # Scenario: enc/dec model, decoder input prompt is "" - # # => Treat it as None (no decoder input prompt provided) - # # & obtain default decoder input prompt - # return self._prepare_decoder_input_ids_for_generation() + if is_enc_dec_decoder and prompt is None: + # Scenario: enc/dec model, decoder input prompt is "" + # => Treat it as None (no decoder input prompt provided) + # & obtain default decoder input prompt + return self._prepare_decoder_input_ids_for_generation() # Scenario: # * Any decoder-only input prompt @@ -734,7 +734,7 @@ def _extract_single_prompt( is_enc_dec_decoder = ((not is_encoder_prompt) and is_enc_dec_model) - if isinstance(inputs, str): + if inputs is None or isinstance(inputs, str): # prompt = inputs # prompt_token_ids = tokenize(inputs) # no multi-modal data From dd784b5423ba21fc6b8188908df417d128376a1f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 23 Jul 2024 21:37:19 -0400 Subject: [PATCH 390/443] typing fix --- tests/conftest.py | 13 ++++++------- vllm/engine/llm_engine.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6dda2f07237f5..a3ef9b6ce019a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -450,13 +450,12 @@ def generate_encoder_decoder_greedy_logprobs_limit( for (encoder_prompt, decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): - encoder_input_ids = self.wrap_device(self.tokenizer(encoder_prompt, - return_tensors="pt").input_ids) - decoder_input_ids = (None if decoder_prompt is None - else - self.wrap_device(self.tokenizer(decoder_prompt, - return_tensors="pt").input_ids) - ) + encoder_input_ids = self.wrap_device( + self.tokenizer(encoder_prompt, return_tensors="pt").input_ids) + decoder_input_ids = ( + None if decoder_prompt is None else self.wrap_device( + self.tokenizer(decoder_prompt, + return_tensors="pt").input_ids)) # # If the decoder input ids do not begin with decoder start # # token, HF transformers will likely add it automatically. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3a00245fa49d2..8ad5efc60568e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -713,7 +713,7 @@ def _extract_single_prompt( lora_request: Optional[LoRARequest], is_encoder_prompt: bool = False, is_enc_dec_model: bool = False, - ) -> Tuple[str, List[int], Optional["MultiModalDataDict"]]: + ) -> Tuple[Optional[str], List[int], Optional["MultiModalDataDict"]]: ''' Extract prompt & prompt_token_ids from any single encoder or decoder input prompt. For encoder input prompts From 61d2ad2cc7791b6e32c8678b8e88ed99bbab4118 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 00:28:20 -0400 Subject: [PATCH 391/443] fixed bugs in handling non-text formats for individual prompts --- examples/offline_inference_encoder_decoder.py | 87 ++++++++++++------- vllm/engine/llm_engine.py | 36 ++++++-- vllm/inputs/data.py | 5 +- 3 files changed, 91 insertions(+), 37 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 00faa0e77966c..f2672d9897efb 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -1,32 +1,68 @@ +''' +Demonstrate prompting of text-to-text +encoder/decoder models, specifically BART +''' from utils import override_backend_env_var_context_manager from vllm import LLM, SamplingParams +from vllm.inputs import (TextPrompt, + TokensPrompt, + ExplicitEncoderDecoderPrompt) from vllm.utils import STR_XFORMERS_ATTN_VAL, zip_enc_dec_prompt_lists dtype = "float" -# Sample prompts. -# - Encoder prompts -encoder_prompts = [ - "PG&E stated it scheduled the blackouts in " - "response to forecasts for high winds " - "amid dry conditions. The aim is to reduce " - "the risk of wildfires. Nearly 800 thousand customers were " - "scheduled to be affected by the shutoffs which " - "were expected to last through at least midday tomorrow.", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# - Decoder prompts -decoder_prompts = [ - encoder_prompts[0], - "", - "", - "", -] -# - Unified encoder/decoder prompts -prompts = zip_enc_dec_prompt_lists(encoder_prompts, decoder_prompts) +# Create a BART encoder/decoder model instance +llm = LLM( + model="facebook/bart-large-cnn", + enforce_eager=True, + dtype=dtype, + # tensor_parallel_size=4, +) + +# Get BART tokenizer +tokenizer=llm.llm_engine.get_tokenizer_group() + +# Test prompts +# - Helpers for building prompts +text_prompt_raw = "Hello, my name is" +text_prompt = TextPrompt(prompt="The president of the United States is") +tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( + prompt="The capital of France is", + ) +) +# - Pass a single prompt to encoder/decoder model (implicitly encoder input prompt); +# decoder input prompt is assumed to be None +single_text_prompt_raw = text_prompt_raw +single_text_prompt = text_prompt +single_tokens_prompt = tokens_prompt +# - Pass explicit encoder and decoder input prompts within a single data structure. +# Encoder and decoder prompts can both independently be text or tokens, with +# no requirement that they be the same prompt type. Some example prompt-type +# combinations are shown below. +enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( + encoder_prompt=single_text_prompt_raw, + decoder_prompt=single_tokens_prompt, +) +enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( + encoder_prompt=single_text_prompt, + decoder_prompt=single_text_prompt_raw, +) +enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( + encoder_prompt=single_tokens_prompt, + decoder_prompt=single_text_prompt, +) +# - Build prompt list +prompts = [single_text_prompt_raw, + single_text_prompt, + single_tokens_prompt, + enc_dec_prompt1, + enc_dec_prompt2, + enc_dec_prompt3 + ] + +# # - Unified encoder/decoder prompts +# prompts = zip_enc_dec_prompt_lists(encoder_prompts, decoder_prompts) print(prompts) @@ -42,13 +78,6 @@ max_tokens=20, ) - # Create an LLM. - llm = LLM( - model="facebook/bart-large-cnn", - enforce_eager=True, - dtype=dtype, - # tensor_parallel_size=4, - ) # Generate texts from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8ad5efc60568e..5c7fa0b90f5ca 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -711,6 +711,7 @@ def _extract_single_prompt( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest], + ptype: Optional[str] = None, is_encoder_prompt: bool = False, is_enc_dec_model: bool = False, ) -> Tuple[Optional[str], List[int], Optional["MultiModalDataDict"]]: @@ -722,6 +723,7 @@ def _extract_single_prompt( Arguments: * request_id + * ptype: str representation of the input prompt type * inputs: single encoder or decoder input prompt * lora_request * is_encoder_prompt: True if encoder input prompt @@ -732,9 +734,18 @@ def _extract_single_prompt( * multi_modal_data (None if is_encoder_prompt) ''' + # Determine prompt type, if not provided + ptype = ( + get_prompt_type(inputs) if ptype is None else + ptype + ) + is_enc_dec_decoder = ((not is_encoder_prompt) and is_enc_dec_model) - if inputs is None or isinstance(inputs, str): + # Any prompt such as None, string + # or TextPrompt that is not a dict + if ptype in ['None','str']: + # prompt = inputs # prompt_token_ids = tokenize(inputs) # no multi-modal data @@ -748,22 +759,25 @@ def _extract_single_prompt( # Tokenize prompt_token_ids = (inputs["prompt_token_ids"] - if inputs["prompt_token_ids"] else + if ptype == "TokensPrompt" else self._tokenize_prompt( request_id, - inputs, + inputs['prompt'], lora_request, is_enc_dec_decoder=is_enc_dec_decoder, )) + # None if no prompt field is present + prompt = inputs.get('prompt') + if is_encoder_prompt: # Only care about multi-modal data associated # with the encoder prompt - return (inputs.get('prompt'), prompt_token_ids, + return (prompt, prompt_token_ids, inputs.get("multi_modal_data")) else: # Assume there is no decoder multi-modal data - return (inputs.get('prompt'), prompt_token_ids, None) + return (prompt, prompt_token_ids, None) def _get_default_decoder_prompt( self, @@ -783,6 +797,13 @@ def _process_encoder_decoder_prompt(self, request_id: str, lora_request: Optional[LoRARequest]): ptype = get_prompt_type(inputs) + if ptype == "ExplicitEncoderDecoder": + extracted_encoder_prompt = inputs.get('encoder_prompt') + encoder_ptype = None + else: + extracted_encoder_prompt = inputs + encoder_ptype = ptype + # Obtain encoder prompt ( encoder_prompt, @@ -790,9 +811,9 @@ def _process_encoder_decoder_prompt(self, request_id: str, multi_modal_data, ) = self._extract_single_prompt( request_id, - (inputs.get('encoder_prompt') if get_prompt_type(inputs) - == "ExplicitEncoderDecoder" else inputs), + extracted_encoder_prompt, lora_request, + ptype=encoder_ptype, is_encoder_prompt=True, is_enc_dec_model=True, ) @@ -810,6 +831,7 @@ def _process_encoder_decoder_prompt(self, request_id: str, request_id, inputs.get('decoder_prompt'), lora_request, + ptype = None, is_encoder_prompt=False, is_enc_dec_model=True, ) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 40a2df4a7ec84..c900e424daf54 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -121,12 +121,15 @@ class ExplicitEncoderDecoderPrompt(TypedDict): :class:`TextTokensPrompt`.""" -def get_prompt_type(prompt: PromptInputs, ) -> str: +def get_prompt_type(prompt: Optional[PromptInputs], ) -> Optional[str]: """ Get the type-name of the prompt argument instance, given that isinstance() cannot apply to TypedDict subclasses directly. """ + if prompt is None: + return 'None' + required_keys_dict = { 'TextPrompt': ['prompt'], 'TokensPrompt': ['prompt_token_ids'], From f36ffb5695b0694947f4ae9e7417cc1afa85e19c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 00:33:47 -0400 Subject: [PATCH 392/443] example includes prompt zipper --- examples/offline_inference_encoder_decoder.py | 35 +++++++++---------- vllm/engine/llm_engine.py | 17 ++++----- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index f2672d9897efb..f0565efe059a8 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -5,9 +5,7 @@ from utils import override_backend_env_var_context_manager from vllm import LLM, SamplingParams -from vllm.inputs import (TextPrompt, - TokensPrompt, - ExplicitEncoderDecoderPrompt) +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt from vllm.utils import STR_XFORMERS_ATTN_VAL, zip_enc_dec_prompt_lists dtype = "float" @@ -21,22 +19,21 @@ ) # Get BART tokenizer -tokenizer=llm.llm_engine.get_tokenizer_group() +tokenizer = llm.llm_engine.get_tokenizer_group() # Test prompts # - Helpers for building prompts text_prompt_raw = "Hello, my name is" text_prompt = TextPrompt(prompt="The president of the United States is") -tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( - prompt="The capital of France is", - ) -) -# - Pass a single prompt to encoder/decoder model (implicitly encoder input prompt); +tokens_prompt = TokensPrompt( + prompt_token_ids=tokenizer.encode(prompt="The capital of France is", )) +# - Pass a single prompt to encoder/decoder model +# (implicitly encoder input prompt); # decoder input prompt is assumed to be None single_text_prompt_raw = text_prompt_raw single_text_prompt = text_prompt single_tokens_prompt = tokens_prompt -# - Pass explicit encoder and decoder input prompts within a single data structure. +# - Pass explicit encoder and decoder input prompts within one data structure. # Encoder and decoder prompts can both independently be text or tokens, with # no requirement that they be the same prompt type. Some example prompt-type # combinations are shown below. @@ -52,14 +49,16 @@ encoder_prompt=single_tokens_prompt, decoder_prompt=single_text_prompt, ) +# - Here's a useful helper function for zipping encoder and decoder prompt lists +# together into a list of ExplicitEncoderDecoderPrompt instances +zipped_prompt_list = zip_enc_dec_prompt_lists( + ['An encoder prompt', 'Another encoder prompt'], + ['A decoder prompt', 'Another decoder prompt']) # - Build prompt list -prompts = [single_text_prompt_raw, - single_text_prompt, - single_tokens_prompt, - enc_dec_prompt1, - enc_dec_prompt2, - enc_dec_prompt3 - ] +prompts = [ + single_text_prompt_raw, single_text_prompt, single_tokens_prompt, + enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 +] + zipped_prompt_list # # - Unified encoder/decoder prompts # prompts = zip_enc_dec_prompt_lists(encoder_prompts, decoder_prompts) @@ -89,4 +88,4 @@ generated_text = output.outputs[0].text print(f"Encoder prompt: {encoder_prompt!r}, " f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") \ No newline at end of file + f"Generated text: {generated_text!r}") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5c7fa0b90f5ca..3482652c2fc51 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -735,16 +735,13 @@ def _extract_single_prompt( ''' # Determine prompt type, if not provided - ptype = ( - get_prompt_type(inputs) if ptype is None else - ptype - ) + ptype = (get_prompt_type(inputs) if ptype is None else ptype) is_enc_dec_decoder = ((not is_encoder_prompt) and is_enc_dec_model) # Any prompt such as None, string # or TextPrompt that is not a dict - if ptype in ['None','str']: + if ptype in ['None', 'str']: # prompt = inputs # prompt_token_ids = tokenize(inputs) @@ -758,9 +755,8 @@ def _extract_single_prompt( ), None) # Tokenize - prompt_token_ids = (inputs["prompt_token_ids"] - if ptype == "TokensPrompt" else - self._tokenize_prompt( + prompt_token_ids = (inputs["prompt_token_ids"] if ptype + == "TokensPrompt" else self._tokenize_prompt( request_id, inputs['prompt'], lora_request, @@ -773,8 +769,7 @@ def _extract_single_prompt( if is_encoder_prompt: # Only care about multi-modal data associated # with the encoder prompt - return (prompt, prompt_token_ids, - inputs.get("multi_modal_data")) + return (prompt, prompt_token_ids, inputs.get("multi_modal_data")) else: # Assume there is no decoder multi-modal data return (prompt, prompt_token_ids, None) @@ -831,7 +826,7 @@ def _process_encoder_decoder_prompt(self, request_id: str, request_id, inputs.get('decoder_prompt'), lora_request, - ptype = None, + ptype=None, is_encoder_prompt=False, is_enc_dec_model=True, ) From 02114bdcd5a832c3610318a8d0b8cfb26070f3ef Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 04:31:32 -0400 Subject: [PATCH 393/443] _free_seq_group() -> _free_seq_group_cross_attn_blocks() --- vllm/core/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e72bd6f1303e9..52236ca64d3dd 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -381,9 +381,9 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: seq.status = SequenceStatus.FINISHED_ABORTED self.free_seq(seq) - self._free_seq_group(aborted_group) + self._free_seq_group_cross_attn_blocks(aborted_group) - def _free_seq_group( + def _free_seq_group_cross_attn_blocks( self, seq_group: SequenceGroup, ) -> None: @@ -1091,7 +1091,7 @@ def free_finished_seq_groups(self) -> None: if seq_group.is_finished(): new_finished_requests_ids += seq_group.request_id # Free cross-attention block table, if it exists - self._free_seq_group(seq_group) + self._free_seq_group_cross_attn_blocks(seq_group) self._finished_requests_ids += new_finished_requests_ids self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) From 5a270ff49f3ebafecf8fb45e090f08d705aa416a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 04:46:32 -0400 Subject: [PATCH 394/443] refactoring --- vllm/entrypoints/llm.py | 2 +- vllm/inputs/__init__.py | 3 +-- vllm/inputs/data.py | 54 ++++++++++++++++++----------------------- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cf170adc32cc4..bc7d7e4373228 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -292,7 +292,7 @@ def generate( """ if self.llm_engine.model_config.embedding_mode: raise ValueError( - "LLM.generate() is only supported for (conditional)generation " + "LLM.generate() is only supported for (conditional) generation " "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index a08e97087cd4a..0fc4dbed1460b 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText, ParsedTokens, PromptInputs, TextPrompt, TokensPrompt, get_prompt_type, is_valid_encoder_decoder_llm_inputs, - is_valid_encoder_decoder_prompt, parse_and_batch_prompt) + parse_and_batch_prompt) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -25,7 +25,6 @@ "InputContext", "InputRegistry", "get_prompt_type", - "is_valid_encoder_decoder_prompt", "is_valid_encoder_decoder_llm_inputs", "ExplicitEncoderDecoderPrompt", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c900e424daf54..c39ef256583b2 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -95,12 +95,22 @@ class TokensPrompt(TypedDict): DecoderOnlyPromptInputs = Union[str, TextPrompt, TokensPrompt] +""" +Set of possible schemas for a single LLM input: +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" class ExplicitEncoderDecoderPrompt(TypedDict): """Represents an encoder/decoder model input prompt, comprising an encoder prompt and a decoder prompt. + The encoder and decoder prompts, respectively, + may formatted according to any of the + DecoderOnlyPromptInputs schemas, and are not + required to have the same schema. + Only the encoder prompt may have multi-modal data. """ @@ -108,23 +118,31 @@ class ExplicitEncoderDecoderPrompt(TypedDict): decoder_prompt: DecoderOnlyPromptInputs - +PromptInputs = Union[DecoderOnlyPromptInputs, ExplicitEncoderDecoderPrompt] """ -The inputs to the LLM, which can take one of the following forms: +Set of possible schemas for an LLM input, including +both decoder-only and encoder/decoder input types: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- A single data structure containing both an encoder and a decoder prompt + (:class:`ExplicitEncoderDecoderPrompt`) """ -PromptInputs = Union[DecoderOnlyPromptInputs, ExplicitEncoderDecoderPrompt] -"""Same as :const:`PromptStrictInputs` but additionally accepts -:class:`TextTokensPrompt`.""" - def get_prompt_type(prompt: Optional[PromptInputs], ) -> Optional[str]: """ Get the type-name of the prompt argument instance, given that isinstance() cannot apply to TypedDict subclasses directly. + If the prompt is None, return 'None' as the type name. + + Arguments: + + * prompt: LLM input prompt or None + + Returns: + + * String representation of prompt type """ if prompt is None: @@ -154,30 +172,6 @@ def get_prompt_type(prompt: Optional[PromptInputs], ) -> Optional[str]: raise ValueError(f"Invalid prompt {prompt}") -def is_valid_encoder_decoder_prompt(prompt: PromptInputs, ) -> bool: - """ - Return True if prompt has the correct structure for an encoder/decoder - prompt. - """ - # Ignore type checking in the conditional below because type checker - # does not understand that - # get_single_prompt_type(prompt) == 'ExplicitEncoderDecoder' narrows - # down the possible types - if (get_prompt_type(prompt) == 'ExplicitEncoderDecoder' and - (prompt['encoder_prompt'] is None # type: ignore - or prompt['decoder_prompt']['multi_modal_data'] # type: ignore - is not None)): - # For explicit encoder/decoder prompts, encoder prompt - # must be non-None and decoder prompt must be free of - # multi-modal data (which should instead be passed to - # the encoder.) - return False - - # Any valid prompt type other than an explicit encoder/decoder - # prompt is a guaranteed-valid prompt - return True - - class LLMInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are From ed4a56b9ca31cdf06033611887114920318ad397 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 04:46:49 -0400 Subject: [PATCH 395/443] formatting --- vllm/inputs/data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c39ef256583b2..a70836a3cc35f 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -102,6 +102,7 @@ class TokensPrompt(TypedDict): - A tokenized prompt (:class:`TokensPrompt`) """ + class ExplicitEncoderDecoderPrompt(TypedDict): """Represents an encoder/decoder model input prompt, comprising an encoder prompt and a decoder prompt. @@ -118,6 +119,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict): decoder_prompt: DecoderOnlyPromptInputs + PromptInputs = Union[DecoderOnlyPromptInputs, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including From 4b5b2cf956141e3adbc22a7a2aa2ebbb9bad8979 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 04:51:48 -0400 Subject: [PATCH 396/443] removed unnecessary argument reordering --- vllm/attention/backends/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index c1f0b49abd42e..b5c7bb0228fea 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -224,7 +224,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, - block_tables=block_tables, slot_mapping=slot_mapping_tensor, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -236,5 +235,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=use_captured_graph, ) From d82b27346b444778eeba42e015ac716883c37f76 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 05:01:27 -0400 Subject: [PATCH 397/443] enc/dec example comments' --- examples/offline_inference_encoder_decoder.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index f0565efe059a8..c9aa2041c0562 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -22,6 +22,10 @@ tokenizer = llm.llm_engine.get_tokenizer_group() # Test prompts +# +# This section shows all of the valid ways to prompt an +# encoder/decoder model. +# # - Helpers for building prompts text_prompt_raw = "Hello, my name is" text_prompt = TextPrompt(prompt="The president of the United States is") @@ -30,39 +34,49 @@ # - Pass a single prompt to encoder/decoder model # (implicitly encoder input prompt); # decoder input prompt is assumed to be None -single_text_prompt_raw = text_prompt_raw -single_text_prompt = text_prompt -single_tokens_prompt = tokens_prompt + +single_text_prompt_raw = text_prompt_raw # Pass a string directly +single_text_prompt = text_prompt # Pass a TextPrompt +single_tokens_prompt = tokens_prompt # Pass a TokensPrompt + # - Pass explicit encoder and decoder input prompts within one data structure. # Encoder and decoder prompts can both independently be text or tokens, with # no requirement that they be the same prompt type. Some example prompt-type -# combinations are shown below. +# combinations are shown below, note that these are not exhaustive. + enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt string directly, & + # pass decoder prompt tokens encoder_prompt=single_text_prompt_raw, decoder_prompt=single_tokens_prompt, ) enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( + # Pass TextPrompt to encoder, and + # pass decoder prompt string directly encoder_prompt=single_text_prompt, decoder_prompt=single_text_prompt_raw, ) enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt tokens directly, and + # pass TextPrompt to decoder encoder_prompt=single_tokens_prompt, decoder_prompt=single_text_prompt, ) -# - Here's a useful helper function for zipping encoder and decoder prompt lists -# together into a list of ExplicitEncoderDecoderPrompt instances + +# - Finally, here's a useful helper function for zipping encoder and +# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt +# instances zipped_prompt_list = zip_enc_dec_prompt_lists( ['An encoder prompt', 'Another encoder prompt'], ['A decoder prompt', 'Another decoder prompt']) -# - Build prompt list + +# - Let's put all of the above example prompts together into one list +# which we will pass to the encoder/decoder LLM. prompts = [ single_text_prompt_raw, single_text_prompt, single_tokens_prompt, enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 ] + zipped_prompt_list -# # - Unified encoder/decoder prompts -# prompts = zip_enc_dec_prompt_lists(encoder_prompts, decoder_prompts) - print(prompts) with override_backend_env_var_context_manager(STR_XFORMERS_ATTN_VAL): @@ -77,10 +91,11 @@ max_tokens=20, ) - # Generate texts from the prompts. The output is a list of + # Generate output tokens from the prompts. The output is a list of # RequestOutput objects that contain the prompt, generated # text, and other information. outputs = llm.generate(prompts, sampling_params) + # Print the outputs. for output in outputs: prompt = output.prompt From 0af58ec10ac6eb9cab3f78abfa62390ade9ca64c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 24 Jul 2024 05:10:20 -0400 Subject: [PATCH 398/443] responses to feedback --- tests/worker/test_encoder_decoder_model_runner.py | 4 ++-- vllm/worker/enc_dec_model_runner.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index cc621ed485173..8268eb7b9582d 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -46,7 +46,8 @@ def _create_model_runner(model: str, *args, @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) def test_empty_seq_group(backend_name, enforce_eager, monkeypatch): - """Verify prepare prompt and decode returns empty output.""" + """Verify prepare prompt and decode returns empty output + for empty seq group list""" # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) @@ -268,7 +269,6 @@ def test_prepare_prompt(batch_size, backend_name, enforce_eager, monkeypatch): ) # - Encoder assert len(encoder_input_tokens) == sum(encoder_seq_lens) - assert len(encoder_input_tokens) == sum(encoder_seq_lens) torch.testing.assert_close( encoder_input_tokens, encoder_input_positions, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index ad6fb3cc41e7a..e718316d10b57 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -122,7 +122,8 @@ def execute_model( num_steps: int = 1, ) -> Optional[List[PoolerOutput]]: if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in ModelRunner") + raise ValueError("num_steps > 1 is not supported in " + "EncoderDecoderModelRunner") if self.lora_config: assert model_input.lora_requests is not None From 47b4eb2a06bf0811f143668fbfe1f8c2caedc827 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 00:50:08 -0400 Subject: [PATCH 399/443] fixed bug caused by upstream refactoring --- vllm/worker/enc_dec_model_runner.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index e718316d10b57..2d10086a1a20f 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -13,7 +13,9 @@ SequenceGroupMetadata) from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, LORA_WARMUP_RANK, GPUModelRunnerBase, - ModelInputForGPUWithSamplingMetadata) + ModelInputForGPUWithSamplingMetadata, + ModelInputForGPUBuilder, + ) try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -86,6 +88,8 @@ def from_broadcasted_tensor_dict( class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): _model_input_cls: Type[EncoderDecoderModelInput] = ( EncoderDecoderModelInput) + _builder_cls: Type[ModelInputForGPUBuilder] = ( + ModelInputForGPUBuilder) def __init__( self, @@ -125,11 +129,11 @@ def execute_model( raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + # if self.lora_config: + # assert model_input.lora_requests is not None + # assert model_input.lora_mapping is not None + # self.set_active_loras(model_input.lora_requests, + # model_input.lora_mapping) if self.prompt_adapter_config: assert model_input.prompt_adapter_requests is not None From 393515eb07a84c3d1604f0c0bc52eb2d8f7c5ae0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 00:50:27 -0400 Subject: [PATCH 400/443] formatting --- vllm/worker/enc_dec_model_runner.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 2d10086a1a20f..de132c84aeddc 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -13,9 +13,8 @@ SequenceGroupMetadata) from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, LORA_WARMUP_RANK, GPUModelRunnerBase, - ModelInputForGPUWithSamplingMetadata, ModelInputForGPUBuilder, - ) + ModelInputForGPUWithSamplingMetadata) try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -88,8 +87,7 @@ def from_broadcasted_tensor_dict( class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): _model_input_cls: Type[EncoderDecoderModelInput] = ( EncoderDecoderModelInput) - _builder_cls: Type[ModelInputForGPUBuilder] = ( - ModelInputForGPUBuilder) + _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) def __init__( self, From c2cc010acc1bb632bb7297da970ff865b22c7f27 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 01:33:04 -0400 Subject: [PATCH 401/443] Removed lora from enc/dec model runner --- vllm/worker/enc_dec_model_runner.py | 55 ++--------------------------- 1 file changed, 2 insertions(+), 53 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index de132c84aeddc..23c749f7f8870 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -12,7 +12,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, SequenceGroupMetadata) from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, - LORA_WARMUP_RANK, GPUModelRunnerBase, + GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) @@ -28,7 +28,6 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.inputs import INPUT_REGISTRY -from vllm.lora.request import LoRARequest from vllm.model_executor import SamplingMetadata from vllm.model_executor.models.interfaces import supports_vision from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs @@ -59,8 +58,6 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "prompt_adapter_mapping": self.prompt_adapter_mapping, "prompt_adapter_requests": self.prompt_adapter_requests, @@ -109,7 +106,7 @@ def __init__( device_config, cache_config, load_config, - lora_config=lora_config, + lora_config=None, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, prompt_adapter_config=prompt_adapter_config, @@ -127,12 +124,6 @@ def execute_model( raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") - # if self.lora_config: - # assert model_input.lora_requests is not None - # assert model_input.lora_mapping is not None - # self.set_active_loras(model_input.lora_requests, - # model_input.lora_mapping) - if self.prompt_adapter_config: assert model_input.prompt_adapter_requests is not None assert model_input.prompt_adapter_mapping is not None @@ -260,29 +251,6 @@ def profile_run(self) -> None: sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs - # This represents the maximum number of different requests - # that will have unique loras, an therefore the max amount of memory - # consumption create dummy lora request copies from the lora request - # passed in, which contains a lora from the lora warmup path. - dummy_lora_requests: List[LoRARequest] = [] - dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: - assert self.lora_manager is not None - with self.lora_manager.dummy_lora_cache(): - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_local_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. @@ -329,8 +297,6 @@ def profile_run(self) -> None: seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, encoder_seq_data=seq_data, cross_block_table=None, multi_modal_data=dummy_multi_modal_data, @@ -373,9 +339,6 @@ def _prepare_encoder_model_input_tensors( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() prompt_adapter_index_mapping: List[int] = [] prompt_adapter_prompt_mapping: List[int] = [] prompt_adapter_requests: Set[PromptAdapterRequest] = set() @@ -510,7 +473,6 @@ def _prepare_encoder_model_input_tensors( query_lens.append(query_len) input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id prompt_adapter_id = seq_group_metadata.prompt_adapter_id if is_prompt: @@ -520,22 +482,9 @@ def _prepare_encoder_model_input_tensors( decode_only = False prefill_seq_lens.append(seq_len) else: - # assert is_encoder_seq or query_len == 1, ( - # "seq_len: {}, context_len: {}, query_len: {}".format( - # seq_len, context_len, query_len)) num_decode_tokens += query_len decode_seq_lens.append(sliding_seq_len) - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params and - seq_group_metadata.sampling_params.prompt_logprobs is not None - else 1)) - mm_data = seq_group_metadata.multi_modal_data if mm_data: # Process multi-modal data From 3327e5be3b07bc35a607a1f4fa1fba2fc4f5904e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 09:49:44 -0400 Subject: [PATCH 402/443] removed lora & vision & mm code from enc/dec modelrunner --- vllm/worker/enc_dec_model_runner.py | 97 +++++------------------------ 1 file changed, 15 insertions(+), 82 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 23c749f7f8870..74da873b6babc 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -29,8 +29,6 @@ from vllm.inputs import INPUT_REGISTRY from vllm.model_executor import SamplingMetadata -from vllm.model_executor.models.interfaces import supports_vision -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.utils import make_tensor_with_pad @@ -58,7 +56,6 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, - "multi_modal_kwargs": self.multi_modal_kwargs, "prompt_adapter_mapping": self.prompt_adapter_mapping, "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, @@ -100,17 +97,18 @@ def __init__( prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, ): - super().__init__(model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=None, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - multimodal_config=multimodal_config) + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=None, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + ) @torch.inference_mode() def execute_model( @@ -142,18 +140,9 @@ def execute_model( if prefill_meta is None and decode_meta.use_cuda_graph: raise NotImplementedError("CUDAGraph is currently not supported " "for encoder/decoder models.") - # TODO(andoorve): We can remove this once all - # virtual engines share the same kv cache. - # virtual_engine = model_input.virtual_engine - # if prefill_meta is None and decode_meta.use_cuda_graph: - # assert model_input.input_tokens is not None - # graph_batch_size = model_input.input_tokens.shape[0] - # model_executable = self.graph_runners[virtual_engine][ - # graph_batch_size] - # else: + model_executable = self.model - multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, @@ -166,7 +155,6 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **multi_modal_kwargs, **seqlen_agnostic_kwargs) # Compute the logits in the last pipeline stage. @@ -255,27 +243,8 @@ def profile_run(self) -> None: # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for vision encoding, which needs - # to be accounted for when calculating the GPU blocks for - # vLLM blocker manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - model_config = self.model_config - if supports_vision(self.model): - max_mm_tokens = MULTIMODAL_REGISTRY \ - .get_max_multimodal_tokens(model_config) - max_num_seqs_orig = max_num_seqs - max_num_seqs = min(max_num_seqs, - max_num_batched_tokens // max_mm_tokens) - if max_num_seqs < 1: - expr = (f"min({max_num_seqs_orig}, " - f"{max_num_batched_tokens} // {max_mm_tokens})") - logger.warning( - "Computed max_num_seqs (%s) to be less than 1. " - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 + model_config = self.model_config batch_size = 0 for group_id in range(max_num_seqs): @@ -283,7 +252,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ + seq_data, _ = INPUT_REGISTRY \ .dummy_data_for_profiling(model_config, seq_len) # Having more tokens is over-conservative but otherwise fine @@ -299,7 +268,6 @@ def profile_run(self) -> None: block_tables=None, encoder_seq_data=seq_data, cross_block_table=None, - multi_modal_data=dummy_multi_modal_data, ) seqs.append(seq) @@ -349,7 +317,6 @@ def _prepare_encoder_model_input_tensors( context_lens: List[int] = [] query_lens: List[int] = [] block_tables: List[List[int]] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] decode_only = True num_prefills = 0 num_prefill_tokens = 0 @@ -363,10 +330,6 @@ def _prepare_encoder_model_input_tensors( if self.sliding_window is not None: raise NotImplementedError() - # sliding_window_blocks = (self.sliding_window + self.block_size - - # 1) // self.block_size - # block_aligned_sliding_window = \ - # sliding_window_blocks * self.block_size for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -414,22 +377,6 @@ def _prepare_encoder_model_input_tensors( sliding_seq_len = seq_len sliding_context_len = context_len - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - # if (self.sliding_window is not None and not is_prompt): - # curr_sliding_window_blocks = sliding_window_blocks - # if self.scheduler_config.use_v2_block_manager: - # # number of elements in last block - # suff_len = seq_len % self.block_size - # sliding_seq_len = min( - # seq_len, block_aligned_sliding_window + suff_len) - # if suff_len > 0: - # curr_sliding_window_blocks += 1 - # else: - # sliding_seq_len = min(seq_len, self.sliding_window) - # sliding_context_len = sliding_seq_len - 1 - # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. @@ -485,12 +432,6 @@ def _prepare_encoder_model_input_tensors( num_decode_tokens += query_len decode_seq_lens.append(sliding_seq_len) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - if prompt_adapter_id > 0 and is_prompt: prompt_adapter_requests.add( seq_group_metadata.prompt_adapter_request) @@ -547,10 +488,6 @@ def _prepare_encoder_model_input_tensors( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - # # Prepare input tensors for flashinfer - # if self.attn_backend.get_name() == "flashinfer": - # assert False - batch_size = len(input_tokens) max_query_len = max(query_lens) max_seq_len = (max(prefill_seq_lens, default=0) @@ -579,10 +516,6 @@ def _prepare_encoder_model_input_tensors( assert (not is_prompt) or max_query_len > 0, ( "Decode-phase query_lens: {}".format(query_lens)) - # context_lens_tensor = torch.tensor(context_lens, - # dtype=torch.int, - # device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) From 47c5548936cd7bfe476d31e8248e3208a8a663d1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 09:53:23 -0400 Subject: [PATCH 403/443] checked out examples/offline_inference.py from main --- examples/offline_inference.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index f15698e8c8be0..9b758fa2479f6 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,7 +1,4 @@ -from utils import override_backend_env_var_context_manager - from vllm import LLM, SamplingParams -from vllm.utils import STR_XFORMERS_ATTN_VAL # Sample prompts. prompts = [ @@ -10,22 +7,16 @@ "The capital of France is", "The future of AI is", ] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -with override_backend_env_var_context_manager(STR_XFORMERS_ATTN_VAL): - - # Create a sampling params object. - sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - # Create an LLM. - llm = LLM(model="facebook/opt-125m", - enforce_eager=True, - tensor_parallel_size=4) - # Generate texts from the prompts. The output is a list of - # RequestOutput objects that contain the prompt, generated - # text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +# Create an LLM. +llm = LLM(model="facebook/opt-125m") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 1bb7ad9f2f5e4c84e283c5c0c59006d817440609 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 09:59:34 -0400 Subject: [PATCH 404/443] updated RequestOutput docstring --- vllm/outputs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/outputs.py b/vllm/outputs.py index 085b32b862439..f36a1ac59e19c 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -76,6 +76,10 @@ class RequestOutput: finished: Whether the whole request is finished. metrics: Metrics associated with the request. lora_request: The LoRA request that was used to generate the output. + encoder_prompt: The encoder prompt string of the request; + None if decoder-only + encoder_prompt_token_ids: The encoder token IDs of the prompt; + None if decoder-only """ def __init__( From 035d90dfc21bbc12d12d2368a2d5d5175ead31ca Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 10:01:31 -0400 Subject: [PATCH 405/443] updated RequestOutput docstring --- vllm/outputs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index f36a1ac59e19c..a3973a9fc6399 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -70,7 +70,11 @@ class RequestOutput: Args: request_id: The unique ID of the request. prompt: The prompt string of the request. + For encoder/decoder models, this is the + decoder input prompt. prompt_token_ids: The token IDs of the prompt. + For encoder/decoder models, this is the + decoder input prompt token ids. prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. @@ -78,7 +82,7 @@ class RequestOutput: lora_request: The LoRA request that was used to generate the output. encoder_prompt: The encoder prompt string of the request; None if decoder-only - encoder_prompt_token_ids: The encoder token IDs of the prompt; + encoder_prompt_token_ids: The token IDs of the encoder prompt; None if decoder-only """ From 64685acfe52177d1e01362ece71d3faab73e8e45 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 10:13:44 -0400 Subject: [PATCH 406/443] Sequence docstring --- vllm/sequence.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/sequence.py b/vllm/sequence.py index e6ba368f1a886..70ebc2a87113d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -236,13 +236,26 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. + The sequence is constructed from the LLMInputs instance passed + in through the `inputs` constructor argument. + + For encoder/decoder models, LLMInputs encapsulates both a + decoder and encoder prompt, creating an ambiguity about which + prompt to construct the sequence from. The `from_decoder_prompt` + constructor argument signals whether to construct the Sequence + from the LLMInputs decoder prompt, or encoder prompt. + Args: seq_id: The ID of the sequence. inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. + from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt + (True) or encoder prompt (False.) Must be True + for decoder-only model. """ From d1751db42bac1baf50b5fa542c770fbab13ba9ff Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 10:35:45 -0400 Subject: [PATCH 407/443] removed flashinfer references from enc/dec modelrunner --- vllm/worker/enc_dec_model_runner.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 74da873b6babc..c4b0b68fcd7d8 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,17 +16,6 @@ ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -try: - from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper - from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 - from vllm.inputs import INPUT_REGISTRY from vllm.model_executor import SamplingMetadata from vllm.prompt_adapter.request import PromptAdapterRequest @@ -129,10 +118,6 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) - if self.attn_backend.get_name() == "flashinfer": - raise NotImplementedError("FlashInfer is currently not supported " - "for encoder/decoder models.") - # Currently cuda graph is not supported for encoder/decoder models assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata @@ -391,14 +376,8 @@ def _prepare_encoder_model_input_tensors( "Prefix caching is not supported with sliding window" sliding_context_len = context_len - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = cross_block_table - else: - block_table = computed_block_nums + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled or not is_prompt): if cross_block_table is not None: @@ -585,6 +564,7 @@ def _prepare_encoder_model_input_tensors( logits_soft_cap = getattr(self.model_config.hf_config, 'attn_logit_softcapping', None) + if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": raise ValueError("Models with logits_soft_cap (i.e., Gemma-2)" From f0abcc27e642dda6371eb1440de519166642a9e7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 10:37:45 -0400 Subject: [PATCH 408/443] format --- vllm/worker/enc_dec_model_runner.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index c4b0b68fcd7d8..ae250e5352f73 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -8,19 +8,18 @@ ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.distributed import get_pp_group +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, SequenceGroupMetadata) +from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) - -from vllm.inputs import INPUT_REGISTRY -from vllm.model_executor import SamplingMetadata -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -377,7 +376,7 @@ def _prepare_encoder_model_input_tensors( sliding_context_len = context_len block_table = computed_block_nums - + elif (self.scheduler_config.chunked_prefill_enabled or not is_prompt): if cross_block_table is not None: @@ -564,7 +563,7 @@ def _prepare_encoder_model_input_tensors( logits_soft_cap = getattr(self.model_config.hf_config, 'attn_logit_softcapping', None) - + if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": raise ValueError("Models with logits_soft_cap (i.e., Gemma-2)" From 4bb7fc442f67dd162a001900e485d02d64fa24ed Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 10:45:03 -0400 Subject: [PATCH 409/443] removed chunked prefill logic/docstring text from enc/dec modelrunner --- vllm/worker/enc_dec_model_runner.py | 56 ++++------------------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index ae250e5352f73..2bb41f8d4f70c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -157,20 +157,6 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, ) - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert model_input.sampling_metadata is not None - indices = model_input.sampling_metadata.selected_token_indices - if model_input.is_prompt: - hidden_states = hidden_or_intermediate_states.index_select( - 0, indices) - # elif decode_meta.use_cuda_graph: - # hidden_states = hidden_or_intermediate_states[:len(indices)] - else: - hidden_states = hidden_or_intermediate_states - - output.hidden_states = hidden_states - return [output] def make_model_input_from_broadcasted_tensor_dict( @@ -189,15 +175,10 @@ def prepare_model_input( """Prepare the model input based on a given sequence group, including metadata for the sampling step. - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, + Since chunked prefill is not supported for encoder/decoder models, + `input_tokens` is assumed to be either entirely prefill tokens or + entirely decode tokens. - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) @@ -320,13 +301,6 @@ def _prepare_encoder_model_input_tensors( is_prompt = seq_group_metadata.is_prompt computed_block_nums = None - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") seq_data = seq_group_metadata.encoder_seq_data cross_block_table = seq_group_metadata.cross_block_table @@ -361,26 +335,9 @@ def _prepare_encoder_model_input_tensors( sliding_seq_len = seq_len sliding_context_len = context_len - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - block_table = computed_block_nums - - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): + if not is_prompt: if cross_block_table is not None: - # chunked prefill or decode + # Decode block_table = cross_block_table if curr_sliding_window_blocks is not None: block_table = block_table[-curr_sliding_window_blocks:] @@ -388,8 +345,9 @@ def _prepare_encoder_model_input_tensors( # Only happens when memory profiling runs. block_table = [] else: - # Prefill without chunked prefill or memory profiling. + # Prefill without memory profiling. block_table = [] + block_tables.append(block_table) seq_lens.append(sliding_seq_len) From a936faa57000aca5be159de260fae8c8849148b6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 25 Jul 2024 10:52:50 -0400 Subject: [PATCH 410/443] removed prefix caching from enc/dec modelrunner --- vllm/worker/enc_dec_model_runner.py | 30 ----------------------------- 1 file changed, 30 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 2bb41f8d4f70c..b6a0266345a2f 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -293,9 +293,6 @@ def _prepare_encoder_model_input_tensors( # metadata list arg is an empty list return model_input - if self.sliding_window is not None: - raise NotImplementedError() - for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt @@ -321,12 +318,6 @@ def _prepare_encoder_model_input_tensors( # tokens. tokens = [seq_data.get_last_token_id()] - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None and is_prompt) - # These are seq_len/context_len capped to the sliding window. # They are passed to decode kernel. # We still need original seq_len/context_len to compute slot @@ -396,28 +387,7 @@ def _prepare_encoder_model_input_tensors( # Compute the slot mapping. block_table = cross_block_table - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size From 53c5148e9f5024f2eb6a83bbf7af191dc88fe555 Mon Sep 17 00:00:00 2001 From: laishzh Date: Tue, 13 Aug 2024 16:11:53 +0800 Subject: [PATCH 411/443] (WIP)feat: EmbeddingModelRunner support encoder model --- examples/offline_inference_bert_embedding.py | 1 + vllm/core/embedding_model_block_manager.py | 6 + vllm/model_executor/models/bert_embedding.py | 30 +- vllm/utils.py | 18 +- vllm/worker/embedding_model_runner.py | 391 ++++++++++++++++++- 5 files changed, 438 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index 30982316e55b0..99427cb0cff7d 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -3,6 +3,7 @@ # Sample prompts. prompts = [ "This is an example sentence.", + # "A Good Day", ] # Create an LLM. diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f2d67306d7ceb..f1db4965c160b 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -62,6 +62,12 @@ def free(self, seq: Sequence) -> None: def get_block_table(self, seq: Sequence) -> List[int]: return None # type: ignore + def get_cross_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def free_cross(self, seq_group: SequenceGroup) -> None: + return + def get_num_free_gpu_blocks(self) -> int: return 1 diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index d7de187d75fb5..f927936938018 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -60,18 +60,21 @@ def __init__( self.model = BertModel(config=kwargs["config"], cache_config=kwargs.get("cache_config", None), quant_config=kwargs.get("quant_config", None)) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self._pooler = BertPooler( ) def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.model(input_ids=input_ids, - position_ids=positions, + return self.model(input_ids=encoder_input_ids, + position_ids=encoder_positions, kv_caches=kv_caches, inputs_embeds=inputs_embeds, attn_metadata=attn_metadata) @@ -304,7 +307,7 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: self_outputs = self.self(hidden_states, kv_cache, attn_metadata) - attn_output = self.output(self_outputs[0], hidden_states) + attn_output = self.output(self_outputs, hidden_states) return attn_output @@ -363,7 +366,7 @@ def forward( v, kv_cache, attn_metadata, - attn_type=AttentionType) + attn_type=AttentionType.ENCODER) return output @@ -375,7 +378,8 @@ def __init__(self, config: BertConfig): self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.layernorm(hidden_states + input_tensor) return hidden_states @@ -410,3 +414,17 @@ def forward( hidden_states = self.dense(hidden_states) hidden_states = self.layernorm(hidden_states + input_tensor) return hidden_states + + +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output diff --git a/vllm/utils.py b/vllm/utils.py index 30bb81722aa04..0b32c9e05f920 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1090,7 +1090,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. - + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" @@ -1141,3 +1141,19 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) + + +# TODO(): +def is_encoder_decoder_model_config(model_config) -> bool: + ''' + Extract the HF encoder/decoder model flag from the ModelConfig instance. + Return False if model_config is None. + ''' + if model_config is None: + return False + + is_encoder_decoder = getattr(model_config.hf_config, "is_encoder_decoder", + False) + is_decoder = getattr(model_config.hf_config, "is_decoder", False) + + return is_encoder_decoder or not is_decoder diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 197c4c730e5a7..350b942c2b1da 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Set import torch @@ -7,13 +7,19 @@ ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalInputs from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) +from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, ModelInputForGPUBuilder) +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, + GPUModelRunnerBase) logger = init_logger(__name__) @@ -25,6 +31,9 @@ class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): """ pooling_metadata: Optional["PoolingMetadata"] = None + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + class EmbeddingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): @@ -106,6 +115,10 @@ def execute_model( model_input.input_tokens, "positions": model_input.input_positions, + "encoder_input_ids": + model_input.encoder_input_tokens, + "encoder_positions": + model_input.encoder_input_positions, "kv_caches": kv_caches, "attn_metadata": @@ -143,6 +156,9 @@ def prepare_model_input( assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) + + model_input = self._prepare_encoder_model_input_tensors( + seq_group_metadata_list, model_input) # Prepare PoolingMetadata. assert model_input.seq_lens is not None pooling_metadata = self._prepare_pooling(seq_group_metadata_list, @@ -151,10 +167,373 @@ def prepare_model_input( return dataclasses.replace(model_input, pooling_metadata=pooling_metadata) + def _prepare_encoder_model_input_tensors( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: ModelInputForGPUWithPoolingMetadata + ) -> ModelInputForGPUWithPoolingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + prompt_adapter_index_mapping: List[int] = [] + prompt_adapter_prompt_mapping: List[int] = [] + prompt_adapter_requests: Set[PromptAdapterRequest] = set() + + seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] + context_lens: List[int] = [] + query_lens: List[int] = [] + block_tables: List[List[int]] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 + + if len(seq_group_metadata_list) == 0: + # Leave the encoder/cross-attention input + # fields at default values if the seq group + # metadata list arg is an empty list + return model_input + + if self.sliding_window is not None: + raise NotImplementedError() + # sliding_window_blocks = (self.sliding_window + self.block_size - + # 1) // self.block_size + # block_aligned_sliding_window = \ + # sliding_window_blocks * self.block_size + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + computed_block_nums = None + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.encoder_seq_data + cross_block_table = seq_group_metadata.cross_block_table + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() + + seq_len = seq_data.get_len() + + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None and is_prompt) + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + # if (self.sliding_window is not None and not is_prompt): + # curr_sliding_window_blocks = sliding_window_blocks + # if self.scheduler_config.use_v2_block_manager: + # # number of elements in last block + # suff_len = seq_len % self.block_size + # sliding_seq_len = min( + # seq_len, block_aligned_sliding_window + suff_len) + # if suff_len > 0: + # curr_sliding_window_blocks += 1 + # else: + # sliding_seq_len = min(seq_len, self.sliding_window) + # sliding_context_len = sliding_seq_len - 1 + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = cross_block_table + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if cross_block_table is not None: + # chunked prefill or decode + block_table = cross_block_table + if curr_sliding_window_blocks is not None: + block_table = block_table[-curr_sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) + + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + # assert is_encoder_seq or query_len == 1, ( + # "seq_len: {}, context_len: {}, query_len: {}".format( + # seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(sliding_seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * query_len + lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params and + seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + # Process multi-modal data + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + if prompt_adapter_id > 0 and is_prompt: + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + is_profile_run = _is_single_block_table_empty( + seq_group_metadata.block_tables) + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + if block_table := cross_block_table: + + block_table = cross_block_table + + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + if is_prompt: + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + # # Prepare input tensors for flashinfer + # if self.attn_backend.get_name() == "flashinfer": + # assert False + + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_seq_len = (max(prefill_seq_lens, default=0) + if is_prompt else max(decode_seq_lens, default=0)) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = (decode_only + and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seq_len <= self.max_seq_len_to_capture) + if use_captured_graph: + raise NotImplementedError("CUDAGraph is currently not supported " + "for encoder/decoder models.") + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + assert (not is_prompt) or max_query_len > 0, ( + "Decode-phase query_lens: {}".format(query_lens)) + + # context_lens_tensor = torch.tensor(context_lens, + # dtype=torch.int, + # device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + # Set encoder-oriented attention metadata fields + attn_metadata.num_encoder_tokens = sum(seq_lens) + attn_metadata.encoder_seq_lens = seq_lens + attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor + attn_metadata.max_encoder_seq_len = max_seq_len + attn_metadata.cross_slot_mapping = slot_mapping_tensor + attn_metadata.cross_block_tables = block_tables + + if seq_group_metadata.is_prompt: + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + + else: + + input_tokens_tensor = torch.tensor([], + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor([], + dtype=torch.long, + device=self.device) + + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=input_tokens_tensor, + encoder_input_positions=input_positions_tensor, + ) + + logits_soft_cap = getattr(self.model_config.hf_config, + 'attn_logit_softcapping', None) + if logits_soft_cap is not None and self.attn_backend.get_name( + ) != "flashinfer": + raise ValueError("Models with logits_soft_cap (i.e., Gemma-2)" + " require FlashInfer backend, however vLLM" + " currently only supports xFormers backend" + " for encoder/decoder models.") + + return model_input + def _prepare_pooling( self, seq_group_metadata_list: List[SequenceGroupMetadata], prompt_lens: List[int], + is_decoder: bool = True, ) -> PoolingMetadata: """Prepare PoolingMetadata for the sequence group metadata list.""" seq_groups: List[Tuple[List[int], PoolingParams]] = [] @@ -174,3 +553,13 @@ def _prepare_pooling( ) return pooling_metadata + + +## TODO: move to utils, and modify enc_dec_model_runner.py +def _is_single_block_table_empty(block_table: Optional[List[int]]): + """ + Check if a single block table has not been constructed + """ + if block_table is None: + return True + return False From 63fb7a582cef08ec29a8b30024a01602dc5ee636 Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 14 Aug 2024 02:39:31 +0800 Subject: [PATCH 412/443] WIP: bert embedding --- vllm/config.py | 5 +- vllm/core/scheduler.py | 10 +- vllm/engine/llm_engine.py | 16 +- vllm/model_executor/models/bert_embedding.py | 4 +- vllm/sequence.py | 3 +- vllm/worker/embedding_model_runner.py | 379 +------------------ vllm/worker/enc_dec_model_runner.py | 10 +- 7 files changed, 48 insertions(+), 379 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 809d6370763dc..b52eda539f69c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -466,7 +466,10 @@ def _get_num_seqlen_agnostic_layers( @property def is_encoder_decoder_model(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return getattr(self.hf_config, "is_encoder_decoder", False) + is_encoder_decoder = getattr(self.hf_config, "is_encoder_decoder", + False) + is_decoder = getattr(self.hf_config, "is_decoder", False) + return is_encoder_decoder or not is_decoder @property def is_embedding_model(self) -> bool: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b16850c7eb9f8..d45f303b866de 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -750,9 +750,9 @@ def _schedule_prefills( num_new_tokens = self._get_num_new_tokens(seq_group, SequenceStatus.WAITING, enable_chunking, budget) - if not enable_chunking: - num_prompt_tokens = waiting_seqs[0].get_len() - assert num_new_tokens == num_prompt_tokens + # if not enable_chunking: + # num_prompt_tokens = waiting_seqs[0].get_len() + # assert num_new_tokens == num_prompt_tokens prompt_limit = self._get_prompt_limit(seq_group) if num_new_tokens > prompt_limit: @@ -1323,10 +1323,12 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, Returns 0 if the new token cannot be computed due to token budget. """ num_new_tokens = 0 - seqs = seq_group.get_seqs(status=status) + seqs = [seq_group.encoder_seq] + seqs.extend(seq_group.get_seqs(status=status)) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 + # Chunk if a running request cannot fit in. # If number of seq > 1, it means it is doing beam search in a # decode phase. Do not chunk in that case. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1191d0c66044d..1cefa7c033e6a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -279,8 +279,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) - if not self.model_config.embedding_mode: - self._initialize_kv_caches() + # if not self.model_config.embedding_mode: + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -579,8 +579,12 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self._get_eos_token_id(lora_request) - seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request, prompt_adapter_request) + seq = Sequence(seq_id, + processed_inputs, + block_size, + eos_token_id, + lora_request, + prompt_adapter_request) encoder_seq = None if 'encoder_prompt_token_ids' in processed_inputs: @@ -654,7 +658,9 @@ def _prepare_decoder_input_ids_for_generation( """ decoder_start_token_id = self._get_decoder_start_token_id() - assert decoder_start_token_id is not None + # assert decoder_start_token_id is not None + if decoder_start_token_id is None: + return [] if decoder_input_ids is None: # no decoder prompt input -> diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index f927936938018..a309f5607b645 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -60,8 +60,8 @@ def __init__( self.model = BertModel(config=kwargs["config"], cache_config=kwargs.get("cache_config", None), quant_config=kwargs.get("quant_config", None)) - # self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - self._pooler = BertPooler( ) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # self._pooler = BertPooler(config=kwargs["config"]) def forward( self, diff --git a/vllm/sequence.py b/vllm/sequence.py index 7349bc6f13bd6..810f4ae46cd0e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -678,7 +678,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 - for seq in self.seqs: + seqs = self.seqs + self.encoder_seq + for seq in seqs: if not seq.is_finished(): num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 350b942c2b1da..07501541aa9e1 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -16,6 +16,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) from vllm.utils import make_tensor_with_pad +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunnerBase from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, ModelInputForGPUBuilder) from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, @@ -36,7 +37,7 @@ class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): class EmbeddingModelRunner( - GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): + EncoderDecoderModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( ModelInputForGPUWithPoolingMetadata) _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder @@ -157,377 +158,27 @@ def prepare_model_input( model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - model_input = self._prepare_encoder_model_input_tensors( + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = super()._prepare_encoder_model_input_tensors( seq_group_metadata_list, model_input) - # Prepare PoolingMetadata. - assert model_input.seq_lens is not None - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) - - return dataclasses.replace(model_input, - pooling_metadata=pooling_metadata) - - def _prepare_encoder_model_input_tensors( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: ModelInputForGPUWithPoolingMetadata - ) -> ModelInputForGPUWithPoolingMetadata: - """Helper method to prepare the model input based on a given sequence - group. Prepares metadata needed for the base model forward pass but not - metadata for possible additional steps, e.g., sampling. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. - """ - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() - prompt_adapter_index_mapping: List[int] = [] - prompt_adapter_prompt_mapping: List[int] = [] - prompt_adapter_requests: Set[PromptAdapterRequest] = set() - - seq_lens: List[int] = [] - prefill_seq_lens: List[int] = [] - decode_seq_lens: List[int] = [] - context_lens: List[int] = [] - query_lens: List[int] = [] - block_tables: List[List[int]] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] - decode_only = True - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = 0 - - if len(seq_group_metadata_list) == 0: - # Leave the encoder/cross-attention input - # fields at default values if the seq group - # metadata list arg is an empty list - return model_input - - if self.sliding_window is not None: - raise NotImplementedError() - # sliding_window_blocks = (self.sliding_window + self.block_size - - # 1) // self.block_size - # block_aligned_sliding_window = \ - # sliding_window_blocks * self.block_size - - for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - is_prompt = seq_group_metadata.is_prompt - - computed_block_nums = None - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - seq_data = seq_group_metadata.encoder_seq_data - cross_block_table = seq_group_metadata.cross_block_table - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - - seq_len = seq_data.get_len() - - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None and is_prompt) - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len - - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - # if (self.sliding_window is not None and not is_prompt): - # curr_sliding_window_blocks = sliding_window_blocks - # if self.scheduler_config.use_v2_block_manager: - # # number of elements in last block - # suff_len = seq_len % self.block_size - # sliding_seq_len = min( - # seq_len, block_aligned_sliding_window + suff_len) - # if suff_len > 0: - # curr_sliding_window_blocks += 1 - # else: - # sliding_seq_len = min(seq_len, self.sliding_window) - # sliding_context_len = sliding_seq_len - 1 - - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = cross_block_table - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): - if cross_block_table is not None: - # chunked prefill or decode - block_table = cross_block_table - if curr_sliding_window_blocks is not None: - block_table = block_table[-curr_sliding_window_blocks:] - else: - # Only happens when memory profiling runs. - block_table = [] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - block_tables.append(block_table) - - seq_lens.append(sliding_seq_len) - context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - query_lens.append(query_len) - input_tokens.extend(tokens) - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - prompt_adapter_id = seq_group_metadata.prompt_adapter_id - - if is_prompt: - assert len(seq_ids) == 1 - num_prefills += 1 - num_prefill_tokens += len(tokens) - decode_only = False - prefill_seq_lens.append(seq_len) - else: - # assert is_encoder_seq or query_len == 1, ( - # "seq_len: {}, context_len: {}, query_len: {}".format( - # seq_len, context_len, query_len)) - num_decode_tokens += query_len - decode_seq_lens.append(sliding_seq_len) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params and - seq_group_metadata.sampling_params.prompt_logprobs is not None - else 1)) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - - if prompt_adapter_id > 0 and is_prompt: - prompt_adapter_requests.add( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.\ - prompt_adapter_num_virtual_tokens - pm = [prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - prompt_adapter_index_mapping += pm - prompt_adapter_prompt_mapping.extend( - [prompt_adapter_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - else 1)) - - is_profile_run = _is_single_block_table_empty( - seq_group_metadata.block_tables) - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - if block_table := cross_block_table: - - block_table = cross_block_table - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - # # Prepare input tensors for flashinfer - # if self.attn_backend.get_name() == "flashinfer": - # assert False - - batch_size = len(input_tokens) - max_query_len = max(query_lens) - max_seq_len = (max(prefill_seq_lens, default=0) - if is_prompt else max(decode_seq_lens, default=0)) - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - use_captured_graph = (decode_only - and not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) - if use_captured_graph: - raise NotImplementedError("CUDAGraph is currently not supported " - "for encoder/decoder models.") - - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - assert (not is_prompt) or max_query_len > 0, ( - "Decode-phase query_lens: {}".format(query_lens)) - - # context_lens_tensor = torch.tensor(context_lens, - # dtype=torch.int, - # device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - attn_metadata = model_input.attn_metadata - assert attn_metadata is not None - - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - # Set encoder-oriented attention metadata fields - attn_metadata.num_encoder_tokens = sum(seq_lens) - attn_metadata.encoder_seq_lens = seq_lens - attn_metadata.encoder_seq_lens_tensor = seq_lens_tensor - attn_metadata.max_encoder_seq_len = max_seq_len - attn_metadata.cross_slot_mapping = slot_mapping_tensor - attn_metadata.cross_block_tables = block_tables - - if seq_group_metadata.is_prompt: - - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - - else: - - input_tokens_tensor = torch.tensor([], - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor([], - dtype=torch.long, - device=self.device) - - # Inject attn_metadata encoder/cross-attention fields & - # encoder input tokens/positions into model_input. - # Frozen dataclass fields cannot be modified, so use - # dataclasses.replace to construct a new model input - # instance. model_input = dataclasses.replace( model_input, attn_metadata=attn_metadata, - encoder_input_tokens=input_tokens_tensor, - encoder_input_positions=input_positions_tensor, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, ) - logits_soft_cap = getattr(self.model_config.hf_config, - 'attn_logit_softcapping', None) - if logits_soft_cap is not None and self.attn_backend.get_name( - ) != "flashinfer": - raise ValueError("Models with logits_soft_cap (i.e., Gemma-2)" - " require FlashInfer backend, however vLLM" - " currently only supports xFormers backend" - " for encoder/decoder models.") + # Prepare PoolingMetadata. + assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) - return model_input + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) def _prepare_pooling( self, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4e66a04674c2a..e0b8219e6ca7a 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -21,7 +21,8 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + ModelInputForGPUWithSamplingMetadata, + TModelInputForGPU) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -64,7 +65,7 @@ def from_broadcasted_tensor_dict( super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) -class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): +class EncoderDecoderModelRunnerBase(GPUModelRunnerBase[TModelInputForGPU]): _model_input_cls: Type[EncoderDecoderModelInput] = ( EncoderDecoderModelInput) _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) @@ -471,3 +472,8 @@ def _prepare_encoder_model_input_tensors( return (attn_metadata, encoder_input_tokens_tensor, encoder_input_positions_tensor) + + +class EncoderDecoderModelRunner( + EncoderDecoderModelRunnerBase[EncoderDecoderModelInput]): + pass From 37bcba01408d37b192063e2ee2b9ac1c3087393c Mon Sep 17 00:00:00 2001 From: laishzh Date: Wed, 14 Aug 2024 17:47:05 +0800 Subject: [PATCH 413/443] feat: full pipeline --- examples/offline_inference_bert_embedding.py | 6 ++--- examples/offline_inference_encoder_decoder.py | 4 +++ vllm/config.py | 7 +++++ vllm/core/scheduler.py | 9 ++++--- vllm/engine/llm_engine.py | 26 ++++++++++++------- vllm/inputs/__init__.py | 4 ++- vllm/inputs/data.py | 6 +++++ vllm/worker/enc_dec_model_runner.py | 9 ++++--- 8 files changed, 49 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index 99427cb0cff7d..9286afc346775 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -1,10 +1,8 @@ from vllm import LLM +from vllm.inputs import build_encoder_prompt # Sample prompts. -prompts = [ - "This is an example sentence.", - # "A Good Day", -] +prompts = [build_encoder_prompt("This is an example sentence.")] # Create an LLM. model = LLM(model="bert-base-uncased", enforce_eager=True) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 0f266d7918853..aed5695fe581d 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -55,6 +55,7 @@ ) enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( # Pass encoder prompt tokens directly, and + # prompts = zip_enc_dec_prompts(['Only encoder prompt'], ["Decoder prompt"]) # pass TextPrompt to decoder encoder_prompt=single_tokens_prompt, decoder_prompt=single_text_prompt, @@ -74,6 +75,9 @@ enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 ] + zipped_prompt_list +# prompts = zip_enc_dec_prompts(['Only encoder prompt'], ["Decoder prompt"]) +prompts = zip_enc_dec_prompts(['Only encoder prompt'], [None]) + print(prompts) # Create a sampling params object. diff --git a/vllm/config.py b/vllm/config.py index b52eda539f69c..7bcc443cf4784 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -471,6 +471,13 @@ def is_encoder_decoder_model(self) -> bool: is_decoder = getattr(self.hf_config, "is_decoder", False) return is_encoder_decoder or not is_decoder + @property + def is_encoder_model(self) -> bool: + is_encoder_decoder = getattr(self.hf_config, "is_encoder_decoder", + False) + is_decoder = getattr(self.hf_config, "is_decoder", False) + return is_encoder_decoder is False and is_decoder is False + @property def is_embedding_model(self) -> bool: """Extract the embedding model flag.""" diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d45f303b866de..8dd574a9da97e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -309,8 +309,8 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" + # if self.scheduler_config.embedding_mode: + # version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) @@ -1323,8 +1323,9 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, Returns 0 if the new token cannot be computed due to token budget. """ num_new_tokens = 0 - seqs = [seq_group.encoder_seq] - seqs.extend(seq_group.get_seqs(status=status)) + seqs = seq_group.get_seqs(status=status) + # seqs = [seq_group.encoder_seq] + # seqs.extend(seq_group.get_seqs(status=status)) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1cefa7c033e6a..d302d6cb96101 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -526,7 +526,11 @@ def _get_bos_token_id(self, "is not initialized") return None - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + bos_token_id = self.tokenizer.get_lora_tokenizer( + lora_request).bos_token_id + if bos_token_id is None: + bos_token_id = 1 + return bos_token_id def _get_eos_token_id(self, lora_request: Optional[LoRARequest] = None @@ -545,6 +549,11 @@ def _get_decoder_start_token_id(self) -> Optional[int]: model config is unavailable. ''' + if self.is_encoder_model(): + logger.warning("Using 1 for decoder start token id because " + "this is an encoder model.") + return 1 + if not self.is_encoder_decoder_model(): logger.warning("Using None for decoder start token id because " "this is not an encoder/decoder model.") @@ -579,12 +588,8 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self._get_eos_token_id(lora_request) - seq = Sequence(seq_id, - processed_inputs, - block_size, - eos_token_id, - lora_request, - prompt_adapter_request) + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request, prompt_adapter_request) encoder_seq = None if 'encoder_prompt_token_ids' in processed_inputs: @@ -658,9 +663,7 @@ def _prepare_decoder_input_ids_for_generation( """ decoder_start_token_id = self._get_decoder_start_token_id() - # assert decoder_start_token_id is not None - if decoder_start_token_id is None: - return [] + assert decoder_start_token_id is not None if decoder_input_ids is None: # no decoder prompt input -> @@ -1619,5 +1622,8 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.model_config.is_encoder_decoder_model + def is_encoder_model(self): + return self.model_config.is_encoder_model + def is_embedding_model(self): return self.model_config.is_embedding_model diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0b08e9691f915..639c45db5aea3 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) + build_encoder_prompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -21,6 +22,7 @@ "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", + "build_encoder_prompt", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..ec875326d7ecc 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -176,3 +176,9 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] + + +def build_encoder_prompt( + encoder_prompt: _T1) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: + return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, + decoder_prompt=None) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index e0b8219e6ca7a..376a6e5433d04 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -217,7 +217,7 @@ def prepare_model_input( seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None - ) -> EncoderDecoderModelInput: + ) -> TModelInputForGPU: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -312,7 +312,7 @@ def profile_run(self) -> None: def _prepare_encoder_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInput, + model_input: TModelInputForGPU, ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], Optional[torch.Tensor]]: """Helper method to prepare the encoder- and cross-attn-related @@ -476,4 +476,7 @@ def _prepare_encoder_model_input_tensors( class EncoderDecoderModelRunner( EncoderDecoderModelRunnerBase[EncoderDecoderModelInput]): - pass + + _model_input_cls: Type[EncoderDecoderModelInput] = ( + EncoderDecoderModelInput) + _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) From 76b47fb1b7920fb50a889f19e1c1421e4385d1ca Mon Sep 17 00:00:00 2001 From: laishzh Date: Thu, 15 Aug 2024 13:18:53 +0800 Subject: [PATCH 414/443] chore: recover --- examples/offline_inference_encoder_decoder.py | 1 - examples/utils.py | 47 ------------------- tests/models/utils.py | 11 ----- vllm/attention/backends/utils.py | 3 -- vllm/core/scheduler.py | 15 +++--- vllm/engine/llm_engine.py | 14 ++---- vllm/inputs/utils.py | 16 ------- vllm/utils.py | 16 ------- 8 files changed, 11 insertions(+), 112 deletions(-) delete mode 100644 examples/utils.py delete mode 100644 vllm/inputs/utils.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index aed5695fe581d..fc48f0e12142c 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -55,7 +55,6 @@ ) enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( # Pass encoder prompt tokens directly, and - # prompts = zip_enc_dec_prompts(['Only encoder prompt'], ["Decoder prompt"]) # pass TextPrompt to decoder encoder_prompt=single_tokens_prompt, decoder_prompt=single_text_prompt, diff --git a/examples/utils.py b/examples/utils.py deleted file mode 100644 index 497bee7592a09..0000000000000 --- a/examples/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -'''Example code utils''' - -import os -from contextlib import contextmanager -from typing import Generator - -from vllm.utils import STR_BACKEND_ENV_VAR - - -@contextmanager -def override_backend_env_var_context_manager( - backend_name: str, ) -> Generator[None, None, None]: - ''' - Override the environment variable indicating the vLLM backend temporarily, - in a context where pytest monkeypatch is not available (i.e. *outside* - the context of a unit test, such as in an example code file.) - - Accomplish this using a custom context manager. - - Arguments: - - * backend_name: attention backend name to force - - Returns: - - * Generator - ''' - - key = STR_BACKEND_ENV_VAR - - # Save the current state of the environment variable (if it exists) - original_value = os.environ.get(key, None) - - # Set the new value of the environment variable - os.environ[key] = backend_name - - # Yield control back to the enclosed code block - try: - yield - finally: - # Revert the environment variable to its original state - if original_value is None: - os.environ.pop( - key, None) # Remove the variable if it wasn't originally set - else: - os.environ[ - key] = original_value # Revert back to the original value diff --git a/tests/models/utils.py b/tests/models/utils.py index d96301b853c85..ff29a0ae81d6e 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,5 +1,4 @@ import warnings -from enum import Enum from typing import Dict, List, Optional, Sequence, Tuple, Union from vllm.sequence import SampleLogprobs @@ -136,13 +135,3 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) - - -class DecoderPromptType(Enum): - ''' - For encoder/decoder models only - - - ''' - CUSTOM = 1 - NONE = 2 - EMPTY_STR = 3 diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 403bd619537bb..e6b5f820c5fa0 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -12,9 +12,6 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " "with encoder/decoder models.") -STR_NOT_IMPL_ENC_DEC_CPU = ("CPU backend is not current supported with " - "encoder/decoder models.") - PAD_SLOT_ID = -1 # Switch to numpy implementation of compute_slot_mapping diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8dd574a9da97e..4e04cef8523cf 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -476,7 +476,7 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + Returns: SchedulerRunningOutputs. """ @@ -750,9 +750,9 @@ def _schedule_prefills( num_new_tokens = self._get_num_new_tokens(seq_group, SequenceStatus.WAITING, enable_chunking, budget) - # if not enable_chunking: - # num_prompt_tokens = waiting_seqs[0].get_len() - # assert num_new_tokens == num_prompt_tokens + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens prompt_limit = self._get_prompt_limit(seq_group) if num_new_tokens > prompt_limit: @@ -823,7 +823,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can @@ -923,7 +923,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1324,12 +1324,9 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, """ num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) - # seqs = [seq_group.encoder_seq] - # seqs.extend(seq_group.get_seqs(status=status)) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 - # Chunk if a running request cannot fit in. # If number of seq > 1, it means it is doing beam search in a # decode phase. Do not chunk in that case. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d302d6cb96101..6bdaa8ae271e9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -99,13 +99,13 @@ class LLMEngine: scheduler_config: The configuration related to the request scheduler. device_config: The configuration related to the device. lora_config (Optional): The configuration related to serving multi-LoRA. - multimodal_config (Optional): The configuration related to multimodal + multimodal_config (Optional): The configuration related to multimodal models. speculative_config (Optional): The configuration related to speculative decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. @@ -526,11 +526,7 @@ def _get_bos_token_id(self, "is not initialized") return None - bos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).bos_token_id - if bos_token_id is None: - bos_token_id = 1 - return bos_token_id + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id def _get_eos_token_id(self, lora_request: Optional[LoRARequest] = None @@ -786,7 +782,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: "default" decoder prompt be . However, it is possible that in the future - other models may have different or more + other models may have different or more complex logic for the default decoder prompt. This motivates having a special helper method for default decoder prompts. @@ -849,7 +845,7 @@ def _process_encoder_decoder_prompt( have any possible singleton type; thus this method relies on helper functions to obtain token ids for the sub-prompts. - + Arguments: * inputs: an input prompt diff --git a/vllm/inputs/utils.py b/vllm/inputs/utils.py deleted file mode 100644 index 3ab4da64a4db1..0000000000000 --- a/vllm/inputs/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -'''Utility functions for input types''' - - -def has_required_keys( - d: dict, - required_keys: list, -) -> bool: - return set(required_keys).issubset(d.keys()) - - -def is_str(s, ) -> bool: - return isinstance(s, str) - - -def is_dict(d, ) -> bool: - return isinstance(d, dict) diff --git a/vllm/utils.py b/vllm/utils.py index 0b32c9e05f920..aab5b18893476 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1141,19 +1141,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) - - -# TODO(): -def is_encoder_decoder_model_config(model_config) -> bool: - ''' - Extract the HF encoder/decoder model flag from the ModelConfig instance. - Return False if model_config is None. - ''' - if model_config is None: - return False - - is_encoder_decoder = getattr(model_config.hf_config, "is_encoder_decoder", - False) - is_decoder = getattr(model_config.hf_config, "is_decoder", False) - - return is_encoder_decoder or not is_decoder From aca786e4359ef55d0af006199728c8b941558579 Mon Sep 17 00:00:00 2001 From: laishzh Date: Thu, 15 Aug 2024 13:44:03 +0800 Subject: [PATCH 415/443] feat: default bos_token_id of encoder model --- vllm/engine/llm_engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6bdaa8ae271e9..c202b4b8096dc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -52,6 +52,7 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 +_DEFAULT_BOS_TOKEN_ID = 1 def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: @@ -521,6 +522,9 @@ def _verify_args(self) -> None: def _get_bos_token_id(self, lora_request: Optional[LoRARequest] = None ) -> Optional[int]: + if self.is_encoder_model(): + return _DEFAULT_BOS_TOKEN_ID + if self.tokenizer is None: logger.warning("Using None for BOS token id because tokenizer " "is not initialized") @@ -546,9 +550,7 @@ def _get_decoder_start_token_id(self) -> Optional[int]: ''' if self.is_encoder_model(): - logger.warning("Using 1 for decoder start token id because " - "this is an encoder model.") - return 1 + return self._get_bos_token_id() if not self.is_encoder_decoder_model(): logger.warning("Using None for decoder start token id because " From 682c455bb0b8c950e1e00b43a6841f433f62db97 Mon Sep 17 00:00:00 2001 From: laishzh Date: Thu, 15 Aug 2024 14:36:40 +0800 Subject: [PATCH 416/443] feat: recover sequence --- vllm/sequence.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 810f4ae46cd0e..e29be0c637c26 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -678,8 +678,7 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 - seqs = self.seqs + self.encoder_seq - for seq in seqs: + for seq in self.seqs: if not seq.is_finished(): num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens @@ -759,7 +758,7 @@ class SequenceGroupMetadata: used in prefix caching. multi_modal_data: Multi modal data. encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None + (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. cross_block_table: Optional cross-attention block table associated From 872e79531b39d1bf12ea81ddcd5bf919dd97265d Mon Sep 17 00:00:00 2001 From: laishzh Date: Thu, 15 Aug 2024 21:40:55 +0800 Subject: [PATCH 417/443] feat: embedding model forward --- examples/offline_inference_embedding.py | 4 +++- vllm/model_executor/models/bert_embedding.py | 4 ++-- vllm/model_executor/models/llama_embedding.py | 2 ++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 7d5ef128bc8e0..4bffe8630b3b9 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -9,7 +9,9 @@ ] # Create an LLM. -model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +model = LLM(model="intfloat/e5-mistral-7b-instruct", + enforce_eager=True, + disable_sliding_window=True) # Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index a309f5607b645..165f4876166ae 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -67,8 +67,8 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, + encoder_input_ids: Optional[torch.Tensor], + encoder_positions: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py index 8f1c77da50d96..2204e177c05f1 100644 --- a/vllm/model_executor/models/llama_embedding.py +++ b/vllm/model_executor/models/llama_embedding.py @@ -34,6 +34,8 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + encoder_input_ids: Optional[torch.Tensor], + encoder_positions: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, inputs_embeds: Optional[torch.Tensor] = None, From a0ad0df28c9de89bdd66b587502f6af9265065be Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 16 Aug 2024 11:15:28 +0800 Subject: [PATCH 418/443] chore: recover unchanged files --- vllm/sequence.py | 2 +- vllm/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index e29be0c637c26..7349bc6f13bd6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -758,7 +758,7 @@ class SequenceGroupMetadata: used in prefix caching. multi_modal_data: Multi modal data. encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None + (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. cross_block_table: Optional cross-attention block table associated diff --git a/vllm/utils.py b/vllm/utils.py index aab5b18893476..30bb81722aa04 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1090,7 +1090,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. - + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" From f2158848b9abd839c515c568acd592d0416c6682 Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 16 Aug 2024 11:21:54 +0800 Subject: [PATCH 419/443] chore: recover --- vllm/worker/embedding_model_runner.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 07501541aa9e1..80505bb0dad24 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Set +from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -7,20 +7,13 @@ ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalInputs -from vllm.multimodal import MultiModalInputs from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunnerBase -from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, - ModelInputForGPUBuilder) -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _PAD_SLOT_ID, - GPUModelRunnerBase) +from vllm.worker.model_runner import (ModelInputForGPU, ModelInputForGPUBuilder) logger = init_logger(__name__) @@ -184,11 +177,10 @@ def _prepare_pooling( self, seq_group_metadata_list: List[SequenceGroupMetadata], prompt_lens: List[int], - is_decoder: bool = True, ) -> PoolingMetadata: """Prepare PoolingMetadata for the sequence group metadata list.""" seq_groups: List[Tuple[List[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) pooling_params = seq_group_metadata.pooling_params seq_groups.append((seq_ids, pooling_params)) @@ -204,13 +196,3 @@ def _prepare_pooling( ) return pooling_metadata - - -## TODO: move to utils, and modify enc_dec_model_runner.py -def _is_single_block_table_empty(block_table: Optional[List[int]]): - """ - Check if a single block table has not been constructed - """ - if block_table is None: - return True - return False From 7657af3f49cdb567bc96b44157c89f18cc4d0a22 Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 16 Aug 2024 15:01:26 +0800 Subject: [PATCH 420/443] feat: fix lint --- vllm/inputs/__init__.py | 4 ++-- vllm/worker/embedding_model_runner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 639c45db5aea3..0ce7106ab0aae 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, - build_encoder_prompt, to_enc_dec_tuple_list, + TokensPrompt, build_encoder_prompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 80505bb0dad24..903d00b010544 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -13,7 +13,7 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunnerBase -from vllm.worker.model_runner import (ModelInputForGPU, ModelInputForGPUBuilder) +from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder logger = init_logger(__name__) From 91e23d8ad2b45790590889d6ee437702f5003792 Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 16 Aug 2024 15:04:30 +0800 Subject: [PATCH 421/443] feat: fix lint --- vllm/inputs/data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index ec875326d7ecc..0e3b914a073a8 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -67,7 +67,7 @@ class TokensPrompt(TypedDict): # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a + comprising an explicit encoder prompt and a decoder prompt. The encoder and decoder prompts, respectively, @@ -179,6 +179,6 @@ def to_enc_dec_tuple_list( def build_encoder_prompt( - encoder_prompt: _T1) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: - return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, - decoder_prompt=None) + encoder_prompt: _T1) -> ExplicitEncoderDecoderPrompt[_T1, SingletonPromptInputs]: + return build_explicit_enc_dec_prompt(encoder_prompt=encoder_prompt, + decoder_prompt=None) From 0b3f55c66e5eb40808f46ebde3c38213478050c7 Mon Sep 17 00:00:00 2001 From: laishzh Date: Fri, 16 Aug 2024 15:12:51 +0800 Subject: [PATCH 422/443] feat: fix lint --- vllm/inputs/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 0e3b914a073a8..20512b1ba124d 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -179,6 +179,7 @@ def to_enc_dec_tuple_list( def build_encoder_prompt( - encoder_prompt: _T1) -> ExplicitEncoderDecoderPrompt[_T1, SingletonPromptInputs]: + encoder_prompt: _T1, +) -> ExplicitEncoderDecoderPrompt[_T1, SingletonPromptInputs]: return build_explicit_enc_dec_prompt(encoder_prompt=encoder_prompt, decoder_prompt=None) From 275f49de32136eb9e4298d42aa85a1e2dc56924c Mon Sep 17 00:00:00 2001 From: laishzh Date: Sat, 17 Aug 2024 01:03:55 +0800 Subject: [PATCH 423/443] feat: embedding model prompt --- examples/offline_inference_bert_embedding.py | 3 +-- examples/offline_inference_embedding.py | 6 ++++-- vllm/inputs/__init__.py | 7 ++++--- vllm/inputs/data.py | 18 +++++++++++++----- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index 9286afc346775..10dd791c01d38 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -1,8 +1,7 @@ from vllm import LLM -from vllm.inputs import build_encoder_prompt # Sample prompts. -prompts = [build_encoder_prompt("This is an example sentence.")] +prompts = ["This is an example sentence."] # Create an LLM. model = LLM(model="bert-base-uncased", enforce_eager=True) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 4bffe8630b3b9..62e43774e0ab9 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -1,16 +1,18 @@ from vllm import LLM +from vllm.inputs import build_decoder_prompts # Sample prompts. -prompts = [ +prompts = build_decoder_prompts([ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", -] +]) # Create an LLM. model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True, + # NOTE: sliding_window is not supported by encoder_decoder_model disable_sliding_window=True) # Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0ce7106ab0aae..a2b01c1623fb4 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, - TokensPrompt, build_encoder_prompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + build_decoder_prompt, build_decoder_prompts, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -22,7 +22,8 @@ "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", - "build_encoder_prompt", + "build_decoder_prompt", + "build_decoder_prompts", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 20512b1ba124d..3db6cc260ffb3 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -178,8 +178,16 @@ def to_enc_dec_tuple_list( for enc_dec_prompt in enc_dec_prompts] -def build_encoder_prompt( - encoder_prompt: _T1, -) -> ExplicitEncoderDecoderPrompt[_T1, SingletonPromptInputs]: - return build_explicit_enc_dec_prompt(encoder_prompt=encoder_prompt, - decoder_prompt=None) +def build_decoder_prompt( + prompt: _T2, +) -> ExplicitEncoderDecoderPrompt[SingletonPromptInputs, _T2]: + return build_explicit_enc_dec_prompt(encoder_prompt="", + decoder_prompt=prompt) + + +def build_decoder_prompts( + prompts: Iterable[_T1], +) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: + return [ + build_decoder_prompt(prompt) for prompt in prompts + ] From ce9a599194dbc3a208a6a4a21fdccaaa5c26ece8 Mon Sep 17 00:00:00 2001 From: laishzh Date: Sat, 17 Aug 2024 02:18:54 +0800 Subject: [PATCH 424/443] feat: bos_token_id --- vllm/engine/llm_engine.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c202b4b8096dc..88a366916a052 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -522,15 +522,18 @@ def _verify_args(self) -> None: def _get_bos_token_id(self, lora_request: Optional[LoRARequest] = None ) -> Optional[int]: - if self.is_encoder_model(): - return _DEFAULT_BOS_TOKEN_ID - if self.tokenizer is None: logger.warning("Using None for BOS token id because tokenizer " "is not initialized") return None - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + bos_token_id = self.tokenizer.get_lora_tokenizer( + lora_request).bos_token_id + + if bos_token_id is None and self.is_encoder_model(): + bos_token_id = _DEFAULT_BOS_TOKEN_ID + + return bos_token_id def _get_eos_token_id(self, lora_request: Optional[LoRARequest] = None @@ -549,9 +552,6 @@ def _get_decoder_start_token_id(self) -> Optional[int]: model config is unavailable. ''' - if self.is_encoder_model(): - return self._get_bos_token_id() - if not self.is_encoder_decoder_model(): logger.warning("Using None for decoder start token id because " "this is not an encoder/decoder model.") @@ -564,9 +564,12 @@ def _get_decoder_start_token_id(self) -> Optional[int]: dec_start_token_id = getattr(self.model_config.hf_config, 'decoder_start_token_id', None) + if dec_start_token_id is None: - logger.warning("Falling back on for decoder start token id " - "because decoder start token id is not available.") + if not self.is_encoder_model(): + logger.warning( + "Falling back on for decoder start token id " + "because decoder start token id is not available.") dec_start_token_id = self._get_bos_token_id() return dec_start_token_id From 7e1196d25054d76d92b3777bc077d3cffd742599 Mon Sep 17 00:00:00 2001 From: laishzh Date: Sat, 17 Aug 2024 14:43:32 +0800 Subject: [PATCH 425/443] fix: fix hint --- vllm/inputs/__init__.py | 6 +++--- vllm/inputs/data.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index a2b01c1623fb4..4707c2afedbf2 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,8 +1,8 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - build_decoder_prompt, build_decoder_prompts, - zip_enc_dec_prompts) + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, build_decoder_prompt, + build_decoder_prompts, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 3db6cc260ffb3..ed56ff293b9c9 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -186,8 +186,8 @@ def build_decoder_prompt( def build_decoder_prompts( - prompts: Iterable[_T1], -) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: + prompts: Iterable[_T2], +) -> List[ExplicitEncoderDecoderPrompt[SingletonPromptInputs, _T2]]: return [ build_decoder_prompt(prompt) for prompt in prompts ] From b99d783bd852eb4cae228fcd8faf3344cd9a6fed Mon Sep 17 00:00:00 2001 From: laishzh Date: Sun, 18 Aug 2024 00:49:57 +0800 Subject: [PATCH 426/443] feat: remove embedding block space manager --- examples/offline_inference_encoder_decoder.py | 3 --- vllm/core/embedding_model_block_manager.py | 1 + vllm/core/scheduler.py | 2 -- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index fc48f0e12142c..0f266d7918853 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -74,9 +74,6 @@ enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 ] + zipped_prompt_list -# prompts = zip_enc_dec_prompts(['Only encoder prompt'], ["Decoder prompt"]) -prompts = zip_enc_dec_prompts(['Only encoder prompt'], [None]) - print(prompts) # Create a sampling params object. diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f1db4965c160b..16e62df712040 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -1,3 +1,4 @@ +# TODO: Remove this file if possible. from typing import List, Tuple from vllm.core.interfaces import AllocStatus, BlockSpaceManager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4e04cef8523cf..8669290f120aa 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -309,8 +309,6 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" - # if self.scheduler_config.embedding_mode: - # version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) From b76da51c0d9ba1b4e39d432b8fb557ed8319034f Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 11:35:22 +0800 Subject: [PATCH 427/443] feat: enc_dec_runner base --- vllm/worker/embedding_model_runner.py | 29 +--- vllm/worker/enc_dec_model_runner.py | 235 +++++++++++++++----------- 2 files changed, 140 insertions(+), 124 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 903d00b010544..5e634b14279af 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -12,27 +12,23 @@ from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunnerBase -from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder +from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelRunnerBase, + EncoderDecoderModelInputBase) +from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @dataclasses.dataclass(frozen=True) -class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): +class EmbeddingModelInput(EncoderDecoderModelInputBase): """ Used by the EmbeddingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None - encoder_input_tokens: Optional[torch.Tensor] = None - encoder_input_positions: Optional[torch.Tensor] = None - -class EmbeddingModelRunner( - EncoderDecoderModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): - _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( - ModelInputForGPUWithPoolingMetadata) +class EmbeddingModelRunner(EncoderDecoderModelRunnerBase[EmbeddingModelInput]): + _model_input_cls: Type[EmbeddingModelInput] = EmbeddingModelInput _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder def __init__( @@ -66,7 +62,7 @@ def __init__( @torch.inference_mode() def execute_model( self, - model_input: ModelInputForGPUWithPoolingMetadata, + model_input: EmbeddingModelInput, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, @@ -132,21 +128,12 @@ def execute_model( pooling_metadata=model_input.pooling_metadata) ] - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, - Any]) -> ModelInputForGPUWithPoolingMetadata: - return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithPoolingMetadata: + ) -> EmbeddingModelInput: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 376a6e5433d04..17f097c626598 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Tuple, Type, cast, TypeVar import torch import torch.distributed @@ -21,6 +21,7 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder, + ModelInputForGPU, ModelInputForGPUWithSamplingMetadata, TModelInputForGPU) from vllm.worker.model_runner_base import ( @@ -30,9 +31,12 @@ logger = init_logger(__name__) +TEncoderDecoderModelInput = TypeVar('TEncoderDecoderModelInput', + bound="EncoderDecoderModelInputBase") + @dataclasses.dataclass(frozen=True) -class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): +class EncoderDecoderModelInputBase(ModelInputForGPU): """ Used by the EncoderDecoderModelRunner. """ @@ -50,8 +54,6 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) return tensor_dict @classmethod @@ -59,16 +61,13 @@ def from_broadcasted_tensor_dict( cls, tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, - ) -> "EncoderDecoderModelInput": + ) -> "EncoderDecoderModelInputBase": return cast( - EncoderDecoderModelInput, + EncoderDecoderModelInputBase, super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) class EncoderDecoderModelRunnerBase(GPUModelRunnerBase[TModelInputForGPU]): - _model_input_cls: Type[EncoderDecoderModelInput] = ( - EncoderDecoderModelInput) - _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) def __init__( self, @@ -163,103 +162,13 @@ def _empty_int32_tensor(self) -> torch.Tensor: def _empty_long_tensor(self) -> torch.Tensor: return self._list_to_long_tensor([]) - @torch.inference_mode() - def execute_model( - self, - model_input: EncoderDecoderModelInput, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[PoolerOutput]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in " - "EncoderDecoderModelRunner") - - model_executable = self.model - - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_seqlen_agnostic else {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **seqlen_agnostic_kwargs) - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if not self.is_driver_worker: - return [] - - # Sample the next token. - output: SamplerOutput = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: - return EncoderDecoderModelInput.from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> TEncoderDecoderModelInput: + return TEncoderDecoderModelInput.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> TModelInputForGPU: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - Since chunked prefill is not supported for encoder/decoder models, - `input_tokens` is assumed to be either entirely prefill tokens or - entirely decode tokens. - - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - - ( - attn_metadata, - encoder_input_tokens_tensor, - encoder_input_positions_tensor, - ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input)) - - # Inject attn_metadata encoder/cross-attention fields & - # encoder input tokens/positions into model_input. - # Frozen dataclass fields cannot be modified, so use - # dataclasses.replace to construct a new model input - # instance. - model_input = dataclasses.replace( - model_input, - attn_metadata=attn_metadata, - encoder_input_tokens=encoder_input_tokens_tensor, - encoder_input_positions=encoder_input_positions_tensor, - ) - - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory) - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. @@ -312,12 +221,12 @@ def profile_run(self) -> None: def _prepare_encoder_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: TModelInputForGPU, + model_input: TEncoderDecoderModelInput, ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], Optional[torch.Tensor]]: """Helper method to prepare the encoder- and cross-attn-related model inputs based on a given sequence group. These additional inputs - are used to augment an already-computed `EncoderDecoderModelInput` + are used to augment an already-computed `TEncoderDecoderModelInput` data structure which already has decoder-related model inputs populated. @@ -474,9 +383,129 @@ def _prepare_encoder_model_input_tensors( encoder_input_positions_tensor) +class EncoderDecoderModelInput(EncoderDecoderModelInputBase, + ModelInputForGPUWithSamplingMetadata): + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInput": + return cast( + EncoderDecoderModelInput, + super(EncoderDecoderModelInputBase, + cls).from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + class EncoderDecoderModelRunner( EncoderDecoderModelRunnerBase[EncoderDecoderModelInput]): _model_input_cls: Type[EncoderDecoderModelInput] = ( EncoderDecoderModelInput) _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + + @torch.inference_mode() + def execute_model( + self, + model_input: TEncoderDecoderModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[PoolerOutput]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in " + "EncoderDecoderModelRunner") + + model_executable = self.model + + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **seqlen_agnostic_kwargs) + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output] + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> TModelInputForGPU: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + Since chunked prefill is not supported for encoder/decoder models, + `input_tokens` is assumed to be either entirely prefill tokens or + entirely decode tokens. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input)) + + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) From 8b107a24a4ef9abb194686066c3bebc6923c6876 Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 13:41:49 +0800 Subject: [PATCH 428/443] feat: fix lint --- examples/offline_inference_embedding.py | 9 +++++---- vllm/inputs/data.py | 7 ++----- vllm/worker/embedding_model_runner.py | 20 +++++++++++++++++++- vllm/worker/enc_dec_model_runner.py | 18 +++++++++--------- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 62e43774e0ab9..4b742b4ba65fc 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -10,10 +10,11 @@ ]) # Create an LLM. -model = LLM(model="intfloat/e5-mistral-7b-instruct", - enforce_eager=True, - # NOTE: sliding_window is not supported by encoder_decoder_model - disable_sliding_window=True) +model = LLM( + model="intfloat/e5-mistral-7b-instruct", + enforce_eager=True, + # NOTE: sliding_window is not supported by encoder_decoder_model + disable_sliding_window=True) # Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index ed56ff293b9c9..120a99ee73acf 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -179,8 +179,7 @@ def to_enc_dec_tuple_list( def build_decoder_prompt( - prompt: _T2, -) -> ExplicitEncoderDecoderPrompt[SingletonPromptInputs, _T2]: + prompt: _T2, ) -> ExplicitEncoderDecoderPrompt[SingletonPromptInputs, _T2]: return build_explicit_enc_dec_prompt(encoder_prompt="", decoder_prompt=prompt) @@ -188,6 +187,4 @@ def build_decoder_prompt( def build_decoder_prompts( prompts: Iterable[_T2], ) -> List[ExplicitEncoderDecoderPrompt[SingletonPromptInputs, _T2]]: - return [ - build_decoder_prompt(prompt) for prompt in prompts - ] + return [build_decoder_prompt(prompt) for prompt in prompts] diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 165cff5a0201d..4e75394f97853 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,8 +1,9 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, cast, Type import torch +from vllm.attention.backends.abstract import AttentionBackend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -26,6 +27,16 @@ class EmbeddingModelInput(EncoderDecoderModelInputBase): """ pooling_metadata: Optional["PoolingMetadata"] = None + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EmbeddingModelInput": + return cast( + EmbeddingModelInput, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + class EmbeddingModelRunner(EncoderDecoderModelRunnerBase[EmbeddingModelInput]): _model_input_cls: Type[EmbeddingModelInput] = EmbeddingModelInput @@ -57,6 +68,13 @@ def __init__( prompt_adapter_config=prompt_adapter_config, observability_config=observability_config) + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EmbeddingModelInput: + return EmbeddingModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + @torch.inference_mode() def execute_model( self, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d029e7894137e..496a9e21ca6b6 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -164,13 +164,6 @@ def _empty_int32_tensor(self) -> torch.Tensor: def _empty_long_tensor(self) -> torch.Tensor: return self._list_to_long_tensor([]) - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> TEncoderDecoderModelInput: - return TEncoderDecoderModelInput.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. @@ -428,10 +421,17 @@ class EncoderDecoderModelRunner( EncoderDecoderModelInput) _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: + return EncoderDecoderModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + @torch.inference_mode() def execute_model( self, - model_input: TEncoderDecoderModelInput, + model_input: EncoderDecoderModelInput, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, @@ -475,7 +475,7 @@ def prepare_model_input( seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None - ) -> TModelInputForGPU: + ) -> EncoderDecoderModelInput: """Prepare the model input based on a given sequence group, including metadata for the sampling step. From bfd7ec9e043cf304e6dea024912eb2a18c786bd6 Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 14:59:06 +0800 Subject: [PATCH 429/443] feat: model input --- vllm/worker/embedding_model_runner.py | 4 +-- vllm/worker/enc_dec_model_runner.py | 41 ++++----------------------- 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 4e75394f97853..ae4b7457c4d48 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -14,14 +14,14 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelRunnerBase, - EncoderDecoderModelInputBase) + EncoderDecoderModelInput) from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @dataclasses.dataclass(frozen=True) -class EmbeddingModelInput(EncoderDecoderModelInputBase): +class EmbeddingModelInput(EncoderDecoderModelInput): """ Used by the EmbeddingModelRunner. """ diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 496a9e21ca6b6..29af243e6538d 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -22,7 +22,6 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPU, ModelInputForGPUWithSamplingMetadata, TModelInputForGPU) from vllm.worker.model_runner_base import ( @@ -33,11 +32,11 @@ logger = init_logger(__name__) TEncoderDecoderModelInput = TypeVar('TEncoderDecoderModelInput', - bound="EncoderDecoderModelInputBase") + bound="EncoderDecoderModelInput") @dataclasses.dataclass(frozen=True) -class EncoderDecoderModelInputBase(ModelInputForGPU): +class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): """ Used by the EncoderDecoderModelRunner. """ @@ -55,6 +54,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "finished_requests_ids": self.finished_requests_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) return tensor_dict @classmethod @@ -62,9 +63,9 @@ def from_broadcasted_tensor_dict( cls, tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, - ) -> "EncoderDecoderModelInputBase": + ) -> "EncoderDecoderModelInput": return cast( - EncoderDecoderModelInputBase, + EncoderDecoderModelInput, super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) @@ -384,36 +385,6 @@ def _prepare_encoder_model_input_tensors( encoder_input_positions_tensor) -class EncoderDecoderModelInput(EncoderDecoderModelInputBase, - ModelInputForGPUWithSamplingMetadata): - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "encoder_input_tokens": self.encoder_input_tokens, - "encoder_input_positions": self.encoder_input_positions, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "EncoderDecoderModelInput": - return cast( - EncoderDecoderModelInput, - super(EncoderDecoderModelInputBase, - cls).from_broadcasted_tensor_dict(tensor_dict, attn_backend)) - - class EncoderDecoderModelRunner( EncoderDecoderModelRunnerBase[EncoderDecoderModelInput]): From 6f006f5ad698d76599e0b005520e65921042d07b Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 15:06:21 +0800 Subject: [PATCH 430/443] chore: fix lint --- vllm/inputs/__init__.py | 6 +++--- vllm/worker/embedding_model_runner.py | 6 +++--- vllm/worker/enc_dec_model_runner.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 4707c2afedbf2..9f6d0d9f7e092 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,8 +1,8 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, build_decoder_prompt, - build_decoder_prompts, zip_enc_dec_prompts) + TokensPrompt, build_decoder_prompt, build_decoder_prompts, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index ae4b7457c4d48..a63bd16936d38 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, cast, Type +from typing import Any, Dict, List, Optional, Tuple, Type, cast import torch @@ -13,8 +13,8 @@ from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) -from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelRunnerBase, - EncoderDecoderModelInput) +from vllm.worker.enc_dec_model_runner import (EncoderDecoderModelInput, + EncoderDecoderModelRunnerBase) from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 29af243e6538d..24f0edde6ccb5 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, cast, TypeVar +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, cast import torch import torch.distributed From 37f698b4241a42c9634030e372e419b47e2a1e9c Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 15:16:34 +0800 Subject: [PATCH 431/443] feat: move BertEmbeddingModel to the end of file --- vllm/model_executor/models/bert_embedding.py | 226 +++++++++---------- 1 file changed, 113 insertions(+), 113 deletions(-) diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 165f4876166ae..9c620cd4627f0 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -19,119 +19,6 @@ from vllm.sequence import PoolerOutput -class BertEmbeddingModel(nn.Module): - """A model that uses Bert to provide embedding functionalities. - - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - - stacked_params_mapping = { - "query": { - "param_name": "qkv_proj", - "shard_id": "q", - }, - "key": { - "param_name": "qkv_proj", - "shard_id": "k", - }, - "value": { - "param_name": "qkv_proj", - "shard_id": "v", - }, - } - - params_mapping = { - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - } - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__() - self.base_model_prefix = "bert" - self.model = BertModel(config=kwargs["config"], - cache_config=kwargs.get("cache_config", None), - quant_config=kwargs.get("quant_config", None)) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - # self._pooler = BertPooler(config=kwargs["config"]) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - encoder_input_ids: Optional[torch.Tensor], - encoder_positions: Optional[torch.Tensor], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.model(input_ids=encoder_input_ids, - position_ids=encoder_positions, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds, - attn_metadata=attn_metadata) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - - params_dict = dict(self.model.named_parameters()) - - for name, loaded_weight in weights: - name = self._rename_key(name) - name, shard_id = self._rename_stacked_param(name) - - # Skip the specific downstream task weight. - if name.startswith('cls.'): - continue - # use Pooler instead. - if name.startswith('pooler.'): - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - if shard_id: - weight_loader(param, loaded_weight, shard_id) - else: - weight_loader(param, loaded_weight) - - def _rename_key(self, key: str): - prefix = f"{self.base_model_prefix}." - key = key[len(prefix):] if key.startswith(prefix) else key - - for src, dst in self.params_mapping.items(): - key = key.replace(src, dst) - - return key - - def _rename_stacked_param( - self, - name: str, - ) -> Tuple[str, Optional[str]]: - for key, mapping in self.stacked_params_mapping.items(): - if key in name: - name = name.replace(key, mapping["param_name"]) - return name, mapping["shard_id"] - return name, None - - class BertModel(nn.Module): def __init__( @@ -428,3 +315,116 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output + + +class BertEmbeddingModel(nn.Module): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + stacked_params_mapping = { + "query": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "key": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "value": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.base_model_prefix = "bert" + self.model = BertModel(config=kwargs["config"], + cache_config=kwargs.get("cache_config", None), + quant_config=kwargs.get("quant_config", None)) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # self._pooler = BertPooler(config=kwargs["config"]) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + encoder_input_ids: Optional[torch.Tensor], + encoder_positions: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids=encoder_input_ids, + position_ids=encoder_positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + attn_metadata=attn_metadata) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.model.named_parameters()) + + for name, loaded_weight in weights: + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None From d09860763500b85193230588386f0e3d515e231c Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 15:24:51 +0800 Subject: [PATCH 432/443] feat: remove embedding_model_block_manager.py --- vllm/core/embedding_model_block_manager.py | 89 ---------------------- 1 file changed, 89 deletions(-) diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index 16e62df712040..001bf5fb31019 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -1,90 +1 @@ # TODO: Remove this file if possible. -from typing import List, Tuple - -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.sequence import Sequence, SequenceGroup - - -class EmbeddingModelBlockSpaceManager(BlockSpaceManager): - """An embedding version of BlockSpaceManager for use in environments - with embedding models where block management is not required. - - This class provides the same interface as BlockSpaceManager, but its - methods perform no actions or return simple values like True in specific - actions. It's designed to be used in scenarios where the overhead of - block management is unnecessary, such as in an embedding environment. - """ - - def __init__( - self, - **kwargs, - ) -> None: - pass - - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: - # Always return OK for dummy purposes - return AllocStatus.OK - - def allocate(self, seq_group: SequenceGroup) -> None: - # No actual allocation logic needed - pass - - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return True - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - return None # type: ignore - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - pass - - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - return AllocStatus.OK - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - return None # type: ignore - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - return True - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - return None # type: ignore - - def free(self, seq: Sequence) -> None: - # No operation on free - return - - def get_block_table(self, seq: Sequence) -> List[int]: - return None # type: ignore - - def get_cross_block_table(self, seq: Sequence) -> List[int]: - return None # type: ignore - - def free_cross(self, seq_group: SequenceGroup) -> None: - return - - def get_num_free_gpu_blocks(self) -> int: - return 1 - - def get_num_free_cpu_blocks(self) -> int: - return 1 - - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - pass - - def get_common_computed_block_ids(self, - seq_group: SequenceGroup) -> List[int]: - return None # type: ignore - - def mark_blocks_as_computed(self, seq_group: SequenceGroup): - pass From fc1f2b7ceb69f9588799820831145babf29aaa64 Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 19 Aug 2024 15:39:33 +0800 Subject: [PATCH 433/443] chore: fix lint --- vllm/core/interfaces.py | 5 ----- vllm/core/scheduler.py | 6 +++--- vllm/engine/llm_engine.py | 8 +++----- vllm/inputs/data.py | 2 +- vllm/worker/embedding_model_runner.py | 14 +++++++------- 5 files changed, 14 insertions(+), 21 deletions(-) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795b8..034f340ad78b5 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,11 +35,6 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 - if version == "embedding": - from vllm.core.embedding_model_block_manager import ( - EmbeddingModelBlockSpaceManager) - return EmbeddingModelBlockSpaceManager - raise ValueError(f"Unknown version {version=}") @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index df5b27ac9f296..e913f4c0303c0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -473,7 +473,7 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + Returns: SchedulerRunningOutputs. """ @@ -823,7 +823,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can @@ -923,7 +923,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 476d304211600..e7507192ad04f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -104,7 +104,7 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. @@ -279,7 +279,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) - # if not self.model_config.embedding_mode: self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. @@ -570,7 +569,6 @@ def _get_decoder_start_token_id(self) -> Optional[int]: dec_start_token_id = getattr(self.model_config.hf_config, 'decoder_start_token_id', None) - if dec_start_token_id is None: if not self.is_encoder_model(): logger.warning( @@ -793,7 +791,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: "default" decoder prompt be . However, it is possible that in the future - other models may have different or more + other models may have different or more complex logic for the default decoder prompt. This motivates having a special helper method for default decoder prompts. @@ -856,7 +854,7 @@ def _process_encoder_decoder_prompt( have any possible singleton type; thus this method relies on helper functions to obtain token ids for the sub-prompts. - + Arguments: * inputs: an input prompt diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 120a99ee73acf..d096a02d250df 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -67,7 +67,7 @@ class TokensPrompt(TypedDict): # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a + comprising an explicit encoder prompt and a decoder prompt. The encoder and decoder prompts, respectively, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a63bd16936d38..f80c9b5d51b0d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -68,13 +68,6 @@ def __init__( prompt_adapter_config=prompt_adapter_config, observability_config=observability_config) - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> EmbeddingModelInput: - return EmbeddingModelInput.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - @torch.inference_mode() def execute_model( self, @@ -144,6 +137,13 @@ def execute_model( pooling_metadata=model_input.pooling_metadata) ] + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EmbeddingModelInput: + return EmbeddingModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], From 612cf1a969fa46105c3685b2eb025cde6416747d Mon Sep 17 00:00:00 2001 From: laishzh Date: Tue, 27 Aug 2024 15:19:50 +0800 Subject: [PATCH 434/443] feat: modify test_embedding --- tests/models/test_embedding.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py index 6556998b68a74..aabd48a28625c 100644 --- a/tests/models/test_embedding.py +++ b/tests/models/test_embedding.py @@ -6,6 +6,8 @@ import torch import torch.nn.functional as F +from vllm.inputs import build_decoder_prompts + MODELS = [ "intfloat/e5-mistral-7b-instruct", ] @@ -31,8 +33,14 @@ def test_models( with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + with vllm_runner( + model, + dtype=dtype, + disable_sliding_window=True, + enforce_eager=True, + gpu_memory_utilization=0.95, + ) as vllm_model: + vllm_outputs = vllm_model.encode(build_decoder_prompts(example_prompts)) similarities = compare_embeddings(hf_outputs, vllm_outputs) all_similarities = torch.stack(similarities) From 7d0ecb90c5034d41f0d9b38eede25f50bf941e3d Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Wed, 28 Aug 2024 16:35:25 -0300 Subject: [PATCH 435/443] Add support for Roberta embedding models It's almost identical to the Bert models Signed-off-by: Max de Bayser --- vllm/attention/ops/paged_attn.py | 2 +- vllm/model_executor/models/__init__.py | 2 + .../models/roberta_embedding.py | 78 +++++++++++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/roberta_embedding.py diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 92023d5b75f5a..076f151ffcb61 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -34,7 +34,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 120, 128, 192, 256] + return [32, 64, 80, 96, 112, 120, 128, 192, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d4d64cafd3f3e..c85e70056325e 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -67,6 +67,8 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "BertForMaskedLM": ("bert_embedding", "BertEmbeddingModel"), + "RobertaForMaskedLM": ("roberta_embedding", "RobertaEmbeddingModel"), + "RobertaModel": ("roberta_embedding", "RobertaEmbeddingModel"), } _MULTIMODAL_MODELS = { diff --git a/vllm/model_executor/models/roberta_embedding.py b/vllm/model_executor/models/roberta_embedding.py new file mode 100644 index 0000000000000..8f0a7d8f9a582 --- /dev/null +++ b/vllm/model_executor/models/roberta_embedding.py @@ -0,0 +1,78 @@ +from typing import Optional + +from torch import nn +from transformers import RobertaConfig + +from vllm.config import CacheConfig +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.models.bert_embedding import (BertEmbedding, + BertEmbeddingModel, + BertEncoder, BertModel) + + +class RobertaModel(BertModel): + + def __init__( + self, + config: RobertaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + # Skip BertModel.__init__() + nn.Module.__init__(self) + self.embeddings = RobertaEmbedding(config) + self.encoder = BertEncoder(config, cache_config, quant_config) + + +class RobertaEmbedding(BertEmbedding): + + def __init__(self, config: RobertaConfig): + # Skip BertEmbedding.__init__() + nn.Module.__init__(self) + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx) + + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + + +class RobertaEmbeddingModel(BertEmbeddingModel): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the RobertaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of RobertaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + # Skip BertEmbeddingModule.__init__() + nn.Module.__init__(self) + self.base_model_prefix = "roberta" + self.model = RobertaModel( + config=kwargs["config"], + cache_config=kwargs.get("cache_config", None), + quant_config=kwargs.get("quant_config", None)) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + # self._pooler = BertPooler(config=kwargs["config"]) From e351bfd0febe4bbf8030fcd07f39eef5cce97641 Mon Sep 17 00:00:00 2001 From: laishzh Date: Sun, 8 Sep 2024 23:50:18 +0800 Subject: [PATCH 436/443] feat: bert embedding implemented, but still have some bugs with mistral, --- examples/offline_inference_bert_embedding.py | 5 +++- examples/offline_inference_embedding.py | 5 +++- tests/models/test_embedding.py | 24 ++++++++++++++------ vllm/model_executor/layers/pooler.py | 7 ++++++ vllm/model_executor/models/bert_embedding.py | 2 +- vllm/worker/embedding_model_runner.py | 7 +++--- vllm/worker/enc_dec_model_runner.py | 12 ++++++---- 7 files changed, 44 insertions(+), 18 deletions(-) diff --git a/examples/offline_inference_bert_embedding.py b/examples/offline_inference_bert_embedding.py index 10dd791c01d38..7cf6e3fb5933b 100644 --- a/examples/offline_inference_bert_embedding.py +++ b/examples/offline_inference_bert_embedding.py @@ -1,7 +1,10 @@ from vllm import LLM # Sample prompts. -prompts = ["This is an example sentence."] +prompts = [ + "This is an example sentence.", + "Another example sentence.", +] # Create an LLM. model = LLM(model="bert-base-uncased", enforce_eager=True) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 4b742b4ba65fc..80ac365d65040 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -14,7 +14,10 @@ model="intfloat/e5-mistral-7b-instruct", enforce_eager=True, # NOTE: sliding_window is not supported by encoder_decoder_model - disable_sliding_window=True) + disable_sliding_window=True, + max_model_len = 2672, + gpu_memory_utilization=0.95, +) # Generate embedding. The output is a list of EmbeddingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py index aabd48a28625c..0528072211238 100644 --- a/tests/models/test_embedding.py +++ b/tests/models/test_embedding.py @@ -9,7 +9,8 @@ from vllm.inputs import build_decoder_prompts MODELS = [ - "intfloat/e5-mistral-7b-instruct", + {"name": "intfloat/e5-mistral-7b-instruct", "is_decoder_only": True}, + # {"name": "bert-base-uncased", "is_decoder_only": False, "max_model_len": 512}, ] @@ -27,24 +28,33 @@ def test_models( hf_runner, vllm_runner, example_prompts, - model: str, + model: dict, dtype: str, ) -> None: - with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: - hf_outputs = hf_model.encode(example_prompts) + # # FIXME: + example_prompts = example_prompts[0] + + model_name = model["name"] + is_decoder_only = model["is_decoder_only"] + max_model_len = model["max_model_len"] if "max_model_len" in model else 1024 + # with hf_runner(model_name, dtype=dtype, is_embedding_model=True) as hf_model: + # hf_outputs = hf_model.encode(example_prompts) + hf_outputs = [] with vllm_runner( - model, + model_name, dtype=dtype, disable_sliding_window=True, enforce_eager=True, gpu_memory_utilization=0.95, + max_model_len=max_model_len, ) as vllm_model: - vllm_outputs = vllm_model.encode(build_decoder_prompts(example_prompts)) + prompt_inputs = build_decoder_prompts(example_prompts) if is_decoder_only else example_prompts + vllm_outputs = vllm_model.encode(prompt_inputs) similarities = compare_embeddings(hf_outputs, vllm_outputs) all_similarities = torch.stack(similarities) tolerance = 1e-2 assert torch.all((all_similarities <= 1.0 + tolerance) & (all_similarities >= 1.0 - tolerance) - ), f"Not all values are within {tolerance} of 1.0" + ), f"Not all values are within {tolerance} of 1.0, all_similarities: {all_similarities}" diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 445b30b8c6e9b..d9a43abfa0ba9 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -11,6 +11,7 @@ class PoolingType(IntEnum): """Enumeration for different types of pooling methods.""" LAST = 0 + MEAN = 1 class Pooler(nn.Module): @@ -43,6 +44,12 @@ def forward( if self.pooling_type == PoolingType.LAST: last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 pooled_data = hidden_states[last_token_flat_indices] + elif self.pooling_type == PoolingType.MEAN: + # Calculate mean pooling + cumsum = torch.cumsum(hidden_states, dim=0) + start_indices = torch.cat([torch.tensor([0], device=hidden_states.device), torch.cumsum(prompt_lens[:-1], dim=0)]) + end_indices = torch.cumsum(prompt_lens, dim=0) + pooled_data = (cumsum[end_indices - 1] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py index 9c620cd4627f0..f157b3b56a85a 100644 --- a/vllm/model_executor/models/bert_embedding.py +++ b/vllm/model_executor/models/bert_embedding.py @@ -358,7 +358,7 @@ def __init__( self.model = BertModel(config=kwargs["config"], cache_config=kwargs.get("cache_config", None), quant_config=kwargs.get("quant_config", None)) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self._pooler = Pooler(pooling_type=PoolingType.MEAN, normalize=False) # self._pooler = BertPooler(config=kwargs["config"]) def forward( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index f80c9b5d51b0d..a748cfa3afd0d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -158,6 +158,7 @@ def prepare_model_input( attn_metadata, encoder_input_tokens_tensor, encoder_input_positions_tensor, + encoder_seq_lens, ) = super()._prepare_encoder_model_input_tensors( seq_group_metadata_list, model_input) @@ -169,9 +170,9 @@ def prepare_model_input( ) # Prepare PoolingMetadata. - assert model_input.seq_lens is not None - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) + seq_lens = model_input.seq_lens if not self.model_config.is_encoder_model else encoder_seq_lens + assert seq_lens is not None, f"model is_encoder_model: {self.model_config.is_encoder_model}" + pooling_metadata = self._prepare_pooling(seq_group_metadata_list,seq_lens) return dataclasses.replace(model_input, pooling_metadata=pooling_metadata) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 24f0edde6ccb5..ca13040f98afc 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -42,6 +42,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): """ encoder_input_tokens: Optional[torch.Tensor] = None encoder_input_positions: Optional[torch.Tensor] = None + encoder_seq_lens: Optional[List[int]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -225,7 +226,7 @@ def _prepare_encoder_model_input_tensors( seq_group_metadata_list: List[SequenceGroupMetadata], model_input: TEncoderDecoderModelInput, ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], - Optional[torch.Tensor]]: + Optional[torch.Tensor], List[int]]: """Helper method to prepare the encoder- and cross-attn-related model inputs based on a given sequence group. These additional inputs are used to augment an already-computed `TEncoderDecoderModelInput` @@ -262,7 +263,7 @@ def _prepare_encoder_model_input_tensors( """ if len(seq_group_metadata_list) == 0: - return (model_input.attn_metadata, None, None) + return (model_input.attn_metadata, None, None, []) # Since we are not supporting chunked prefill either the entire # batch is prefill or it is decode @@ -382,7 +383,7 @@ def _prepare_encoder_model_input_tensors( ) return (attn_metadata, encoder_input_tokens_tensor, - encoder_input_positions_tensor) + encoder_input_positions_tensor, encoder_seq_lens) class EncoderDecoderModelRunner( @@ -462,8 +463,9 @@ def prepare_model_input( attn_metadata, encoder_input_tokens_tensor, encoder_input_positions_tensor, - ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input)) + _ + ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input) # Inject attn_metadata encoder/cross-attention fields & # encoder input tokens/positions into model_input. From 3ff2d36375d9560f87c56860ffff8a774a217cf9 Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 9 Sep 2024 10:29:01 +0800 Subject: [PATCH 437/443] feat: some changes on test_embedding.py --- tests/models/test_embedding.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py index 0528072211238..f24b8c2b17cb9 100644 --- a/tests/models/test_embedding.py +++ b/tests/models/test_embedding.py @@ -10,7 +10,7 @@ MODELS = [ {"name": "intfloat/e5-mistral-7b-instruct", "is_decoder_only": True}, - # {"name": "bert-base-uncased", "is_decoder_only": False, "max_model_len": 512}, + {"name": "bert-base-uncased", "is_decoder_only": False, "max_model_len": 512}, ] @@ -31,22 +31,19 @@ def test_models( model: dict, dtype: str, ) -> None: - # # FIXME: - example_prompts = example_prompts[0] - model_name = model["name"] is_decoder_only = model["is_decoder_only"] max_model_len = model["max_model_len"] if "max_model_len" in model else 1024 - # with hf_runner(model_name, dtype=dtype, is_embedding_model=True) as hf_model: - # hf_outputs = hf_model.encode(example_prompts) - hf_outputs = [] + with hf_runner(model_name, dtype=dtype, is_embedding_model=True) as hf_model: + hf_outputs = hf_model.encode(example_prompts) with vllm_runner( model_name, dtype=dtype, disable_sliding_window=True, enforce_eager=True, - gpu_memory_utilization=0.95, + # NOTE: Uncomment this line if runs out of GPU memory. + # gpu_memory_utilization=0.95, max_model_len=max_model_len, ) as vllm_model: prompt_inputs = build_decoder_prompts(example_prompts) if is_decoder_only else example_prompts @@ -57,4 +54,4 @@ def test_models( tolerance = 1e-2 assert torch.all((all_similarities <= 1.0 + tolerance) & (all_similarities >= 1.0 - tolerance) - ), f"Not all values are within {tolerance} of 1.0, all_similarities: {all_similarities}" + ), f"Not all values are within {tolerance} of 1.0" From 0ea4da1c549bf35c8456c47729da46dd33481cac Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 9 Sep 2024 23:01:22 +0800 Subject: [PATCH 438/443] feat: fix lint --- examples/offline_inference_embedding.py | 1 - tests/models/test_embedding.py | 17 +++++++++++++---- vllm/model_executor/layers/pooler.py | 9 +++++++-- vllm/worker/embedding_model_runner.py | 10 +++++++--- vllm/worker/enc_dec_model_runner.py | 4 ++-- 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 80ac365d65040..013d2d6bb735f 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -15,7 +15,6 @@ enforce_eager=True, # NOTE: sliding_window is not supported by encoder_decoder_model disable_sliding_window=True, - max_model_len = 2672, gpu_memory_utilization=0.95, ) # Generate embedding. The output is a list of EmbeddingRequestOutputs. diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py index f24b8c2b17cb9..482419e195f0c 100644 --- a/tests/models/test_embedding.py +++ b/tests/models/test_embedding.py @@ -9,8 +9,15 @@ from vllm.inputs import build_decoder_prompts MODELS = [ - {"name": "intfloat/e5-mistral-7b-instruct", "is_decoder_only": True}, - {"name": "bert-base-uncased", "is_decoder_only": False, "max_model_len": 512}, + { + "name": "intfloat/e5-mistral-7b-instruct", + "is_decoder_only": True + }, + { + "name": "bert-base-uncased", + "is_decoder_only": False, + "max_model_len": 512 + }, ] @@ -34,7 +41,8 @@ def test_models( model_name = model["name"] is_decoder_only = model["is_decoder_only"] max_model_len = model["max_model_len"] if "max_model_len" in model else 1024 - with hf_runner(model_name, dtype=dtype, is_embedding_model=True) as hf_model: + with hf_runner(model_name, dtype=dtype, + is_embedding_model=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner( @@ -46,7 +54,8 @@ def test_models( # gpu_memory_utilization=0.95, max_model_len=max_model_len, ) as vllm_model: - prompt_inputs = build_decoder_prompts(example_prompts) if is_decoder_only else example_prompts + prompt_inputs = build_decoder_prompts( + example_prompts) if is_decoder_only else example_prompts vllm_outputs = vllm_model.encode(prompt_inputs) similarities = compare_embeddings(hf_outputs, vllm_outputs) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d9a43abfa0ba9..3201ce931a90d 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -47,9 +47,14 @@ def forward( elif self.pooling_type == PoolingType.MEAN: # Calculate mean pooling cumsum = torch.cumsum(hidden_states, dim=0) - start_indices = torch.cat([torch.tensor([0], device=hidden_states.device), torch.cumsum(prompt_lens[:-1], dim=0)]) + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) end_indices = torch.cumsum(prompt_lens, dim=0) - pooled_data = (cumsum[end_indices - 1] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + pooled_data = ( + cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a748cfa3afd0d..81b9d121212a3 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -170,9 +170,13 @@ def prepare_model_input( ) # Prepare PoolingMetadata. - seq_lens = model_input.seq_lens if not self.model_config.is_encoder_model else encoder_seq_lens - assert seq_lens is not None, f"model is_encoder_model: {self.model_config.is_encoder_model}" - pooling_metadata = self._prepare_pooling(seq_group_metadata_list,seq_lens) + seq_lens = model_input.seq_lens\ + if not self.model_config.is_encoder_model \ + else encoder_seq_lens + assert seq_lens is not None, "model is_encoder_model: "\ + f"{self.model_config.is_encoder_model}" + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + seq_lens) return dataclasses.replace(model_input, pooling_metadata=pooling_metadata) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 392f435079248..90958fad61890 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -466,8 +466,8 @@ def prepare_model_input( encoder_input_tokens_tensor, encoder_input_positions_tensor, _ - ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input) + ) = self._prepare_encoder_model_input_tensors( + seq_group_metadata_list, model_input) # Inject attn_metadata encoder/cross-attention fields & # encoder input tokens/positions into model_input. From 15be7fa8bce185f64fafecaabdb8c828e83f4ad8 Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 9 Sep 2024 23:04:44 +0800 Subject: [PATCH 439/443] feat: fix lint --- vllm/worker/enc_dec_model_runner.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 90958fad61890..1a25329d5eea7 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -461,12 +461,9 @@ def prepare_model_input( model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - ( - attn_metadata, - encoder_input_tokens_tensor, - encoder_input_positions_tensor, - _ - ) = self._prepare_encoder_model_input_tensors( + (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor, + _) = self._prepare_encoder_model_input_tensors( seq_group_metadata_list, model_input) # Inject attn_metadata encoder/cross-attention fields & From 08f1781d6bd49653bd62ffdfde4f86d903f0c65a Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 23 Sep 2024 17:04:35 -0300 Subject: [PATCH 440/443] add head size 32 Signed-off-by: Max de Bayser --- csrc/attention/attention_kernels.cu | 6 ++++++ csrc/cpu/attention.cpp | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index bcd170411e7cb..c53cda16d4714 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -739,6 +739,9 @@ void paged_attention_v1_launcher( // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V1(32); + break; case 64: LAUNCH_PAGED_ATTENTION_V1(64); break; @@ -903,6 +906,9 @@ void paged_attention_v2_launcher( // NOTE(woosuk): To reduce the compilation time, we only compile for the // head sizes that we use in the model. However, we can easily extend this // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V2(32); + break; case 64: LAUNCH_PAGED_ATTENTION_V2(64); break; diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index abb4e3bea14bb..55921c05711f1 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -375,6 +375,9 @@ void paged_attention_v1_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; case 64: LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; @@ -692,6 +695,9 @@ void paged_attention_v2_impl_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { + case 32: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; case 64: LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); break; From 04b0bc6ff534495a9627f5548767f5bfb95268e8 Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 7 Oct 2024 02:54:55 +0800 Subject: [PATCH 441/443] feat: revert embedding_block_manager --- .../embedding/language/test_embedding.py | 3 - vllm/core/embedding_model_block_manager.py | 98 ++++++++++++++++++- vllm/core/interfaces.py | 5 + vllm/core/scheduler.py | 2 + vllm/engine/llm_engine.py | 7 +- vllm/inputs/preprocess.py | 3 +- vllm/worker/enc_dec_model_runner.py | 10 +- 7 files changed, 114 insertions(+), 14 deletions(-) diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index b73a3b9c90632..a8335d1184124 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -49,9 +49,6 @@ def test_models( model_name, dtype=dtype, disable_sliding_window=True, - enforce_eager=True, - # NOTE: Uncomment this line if runs out of GPU memory. - # gpu_memory_utilization=0.95, max_model_len=max_model_len, ) as vllm_model: prompt_inputs = build_decoder_prompts( diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index 001bf5fb31019..a11fe9881a78b 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -1 +1,97 @@ -# TODO: Remove this file if possible. +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class EmbeddingModelBlockSpaceManager(BlockSpaceManager): + """An embedding version of BlockSpaceManager for use in environments + with embedding models where block management is not required. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. + """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return None # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def free_cross(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_cross_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: List[Sequence]) -> List[int]: + return [] + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b75db4230abd1..6346711587301 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -36,6 +36,11 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 + if version == "embedding": + from vllm.core.embedding_model_block_manager import ( + EmbeddingModelBlockSpaceManager) + return EmbeddingModelBlockSpaceManager + raise ValueError(f"Unknown version {version=}") @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 14627e1105d68..f3a5016d0e62a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -314,6 +314,8 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" + if self.scheduler_config.embedding_mode: + version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a69eeb09d31e2..d6258c6413d87 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -59,7 +59,6 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 -_DEFAULT_BOS_TOKEN_ID = 1 def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: @@ -349,7 +348,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) - self._initialize_kv_caches() + if not self.model_config.embedding_mode: + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -1879,9 +1879,6 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() - def is_encoder_model(self): - return self.model_config.is_encoder_model - def is_embedding_model(self): return self.model_config.is_embedding_model diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 6f705f03e0a62..2b4e661de0e1c 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -551,4 +551,5 @@ async def preprocess_async( ) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.model_config.is_encoder_decoder_model \ + or self.model_config.is_encoder_model diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 549261663647f..be58f71861452 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -328,10 +328,12 @@ def _prepare_encoder_model_input_tensors( cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) else: for i in range(0, seq_len): - block_number = seq_group_metadata.cross_block_table[ - i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset + slot = PAD_SLOT_ID + if seq_group_metadata.cross_block_table is not None: + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset cross_slot_mapping.append(slot) # Build encoder input tokens From 80c18855fcff195175b7046923c4b0c3815f141a Mon Sep 17 00:00:00 2001 From: laishzh Date: Mon, 7 Oct 2024 12:04:34 +0800 Subject: [PATCH 442/443] feat: update with origin/main --- tests/models/embedding/language/test_embedding.py | 4 ++++ vllm/model_executor/models/registry.py | 1 + 2 files changed, 5 insertions(+) diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index de07a38686a7b..3a644479dd3b7 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -13,6 +13,10 @@ "name": "intfloat/e5-mistral-7b-instruct", "is_decoder_only": True }, + { + "name": "BAAI/bge-multilingual-gemma2", + "is_decoder_only": True + }, { "name": "bert-base-uncased", "is_decoder_only": False, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ccb0e155ff4aa..8a0ea021e5e47 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -84,6 +84,7 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"), + "BertForMaskedLM": ("bert_embedding", "BertEmbeddingModel"), } _MULTIMODAL_MODELS = { From 935c58d9e70ed6e84559e95f696c65dfb282e422 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 11 Oct 2024 14:28:57 -0300 Subject: [PATCH 443/443] add registry of encoder-only models Signed-off-by: Max de Bayser --- vllm/config.py | 5 +---- vllm/model_executor/models/registry.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 57419aca9b632..c082903d81617 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -558,10 +558,7 @@ def is_encoder_decoder_model(self) -> bool: @property def is_encoder_model(self) -> bool: - is_encoder_decoder = getattr(self.hf_config, "is_encoder_decoder", - False) - is_decoder = getattr(self.hf_config, "is_decoder", False) - return is_encoder_decoder is False and is_decoder is False + return ModelRegistry.is_encoder_model(self.hf_config.architectures) @property def is_embedding_model(self) -> bool: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5ac793592efcd..42e5bd296cbba 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -90,6 +90,13 @@ "RobertaModel": ("roberta_embedding", "RobertaEmbeddingModel"), } +_ENCODER_MODELS = [ + "BertForMaskedLM", + "BertModel", + "RobertaForMaskedLM", + "RobertaModel", +] + _MULTIMODAL_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), @@ -342,6 +349,15 @@ def is_embedding_model(architectures: Union[str, List[str]]) -> bool: default=False) return any(is_emb(arch) for arch in architectures) + + @staticmethod + def is_encoder_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _ENCODER_MODELS for arch in architectures) @staticmethod def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: