add assert to verify assumption
Signed-off-by: Max de Bayser <[email protected]>
maxdebayser committed Nov 14, 2024
1 parent 49e8381 commit 1267bba
Showing 1 changed file with 35 additions and 1 deletion.
vllm/model_executor/models/roberta.py (+35, -1)
@@ -1,13 +1,16 @@
-from typing import Optional
+from typing import List, Optional
 
 import torch
 from torch import nn
 from transformers import RobertaConfig
 
+from vllm.attention import AttentionMetadata
+from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
+from vllm.sequence import IntermediateTensors
 
 
 class RobertaEmbedding(nn.Module):
@@ -82,3 +85,34 @@ def _build_model(self,
         return BertModel(vllm_config=vllm_config,
                          prefix=prefix,
                          embedding_class=RobertaEmbedding)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # Verify the assumption that positions are always a sequence
+        # from 0 to N (to keep it simple, only 0 and N are checked).
+        # This is important for fixing the positions, which RoBERTa
+        # models assume to start from padding_idx + 1 instead of 0.
+        assert isinstance(attn_metadata, XFormersMetadata)
+        cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
+        start_pos = torch.cat(
+            (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
+             cumulative[:-1]))
+        assert len(torch.nonzero(positions[start_pos])) == 0
+        end_pos = cumulative - 1
+        last_tokens = attn_metadata.seq_lens_tensor - 1
+        assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
+
+        return super().forward(input_ids=input_ids,
+                               positions=positions,
+                               kv_caches=kv_caches,
+                               attn_metadata=attn_metadata,
+                               intermediate_tensors=intermediate_tensors,
+                               inputs_embeds=inputs_embeds)
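Aside for readers of this diff: the following is a minimal standalone sketch (not part of the commit) that replays the boundary check the new asserts perform. The seq_lens and positions values are made up for illustration; in the model they come from attn_metadata.seq_lens_tensor and the flattened position ids of the batch.

import torch

# Hypothetical flattened batch: three sequences of lengths 3, 2 and 4,
# with their per-token position ids packed into one tensor.
seq_lens = torch.tensor([3, 2, 4])
positions = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])

# Cumulative lengths mark where each sequence ends...
cumulative = seq_lens.cumsum(dim=0)  # tensor([3, 5, 9])
# ...and, shifted right by one, where each sequence starts.
start_pos = torch.cat((torch.tensor([0]), cumulative[:-1]))  # tensor([0, 3, 5])

# The first position of every sequence must be 0...
assert len(torch.nonzero(positions[start_pos])) == 0
# ...and the last position must equal seq_len - 1.
end_pos = cumulative - 1  # tensor([2, 4, 8])
last_tokens = seq_lens - 1  # tensor([2, 1, 3])
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0

Checking only the two endpoints of each sequence is what the code comment means by keeping it simple: a full check would compare every position against a torch.arange per sequence, while this version gathers just two values per sequence.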
