
Commit

fix
xiangw2 committed Nov 26, 2024
1 parent 87e06d4 commit ad3bc76
Showing 1 changed file with 5 additions and 46 deletions.
vllm/model_executor/models/telechat2.py
@@ -32,8 +32,9 @@
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.llama import (LlamaAttention,
-                                              LlamaDecoderLayer, LlamaMLP,
+from vllm.model_executor.models.llama import (LlamaAttention,
+                                              LlamaDecoderLayer,
+                                              LlamaForCausalLM, LlamaMLP,
                                               LlamaModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
@@ -159,11 +160,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
         total_num_heads = self.config.n_head
         head_dim = self.config.hidden_size // total_num_heads
         for name, loaded_weight in weights:
-            #name = name.replace(".h.", ".layers.")
             if "self_attn.key_value" in name:
                 k_weight = []
                 v_weight = []
-                #name = name.replace(".self_attention.", ".self_attn.")
                 for i in range(total_num_heads):
                     start = i * head_dim * 2
                     k_weight.append(loaded_weight[start:start + head_dim, :])
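
For readers of the hunk above: TeleChat2 stores attention weights as a fused key_value matrix in which each head contributes head_dim rows of K followed by head_dim rows of V, and this loop slices them back apart. Below is a minimal sketch of that unpacking with toy sizes; the v_weight slice and the final concatenation are hidden behind the fold in this diff, so treat those lines as an assumption about the rest of the loop body:

    import torch

    # Toy sizes; the real values come from the TeleChat2 config
    # (config.n_head and config.hidden_size // config.n_head).
    total_num_heads = 4
    head_dim = 8
    hidden_size = total_num_heads * head_dim

    # Fused layout: per head, head_dim rows of K, then head_dim rows of V.
    loaded_weight = torch.randn(total_num_heads * head_dim * 2, hidden_size)

    k_weight, v_weight = [], []
    for i in range(total_num_heads):
        start = i * head_dim * 2
        k_weight.append(loaded_weight[start:start + head_dim, :])
        # Assumed continuation of the folded loop body:
        v_weight.append(loaded_weight[start + head_dim:start + 2 * head_dim, :])

    k = torch.cat(k_weight, dim=0)  # (total_num_heads * head_dim, hidden_size)
    v = torch.cat(v_weight, dim=0)
    assert k.shape == v.shape == (hidden_size, hidden_size)
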
@@ -200,10 +199,10 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class TeleChat2ForCausalLM(nn.Module):
+class TeleChat2ForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
+        super(LlamaForCausalLM, self).__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         config.intermediate_size = config.ffn_hidden_size
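
The one subtle line in this hunk is super(LlamaForCausalLM, self).__init__(): it starts the method-resolution walk one level above LlamaForCausalLM, so only nn.Module's __init__ runs, while forward, compute_logits, and sample are still inherited from LlamaForCausalLM. A self-contained sketch of the idiom, using placeholder classes rather than the vLLM types:

    class Base:                      # stands in for torch.nn.Module
        def __init__(self):
            print("Base.__init__")

    class Parent(Base):              # stands in for LlamaForCausalLM
        def __init__(self):
            print("Parent.__init__")  # would build Llama-specific state

        def forward(self):
            return "Parent.forward"

    class Child(Parent):             # stands in for TeleChat2ForCausalLM
        def __init__(self):
            # Skip Parent.__init__ but keep Parent's methods: start the
            # MRO lookup one level above Parent, so Base.__init__ runs.
            super(Parent, self).__init__()

    c = Child()         # prints "Base.__init__" only
    print(c.forward())  # "Parent.forward", inherited unchanged
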
@@ -222,46 +221,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
-        return model_output
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def make_empty_intermediate_tensors(
-        self, batch_size: int, dtype: torch.dtype,
-        device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
