Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: shunxing12345 <[email protected]>
  • Loading branch information
shunxing12345 committed Dec 31, 2024
1 parent 1cca3c1 commit 55166ec
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 32 deletions.
24 changes: 2 additions & 22 deletions src/transformers/models/telechat2/configuration_telechat2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ class TeleChat2Config(PretrainedConfig):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. TeleChat2 1 supports up to 2048 tokens,
TeleChat2 2 up to 4096, CodeTeleChat2 up to 16384.
Expand Down Expand Up @@ -118,18 +116,12 @@ class TeleChat2Config(PretrainedConfig):
Only used with 'telechat23'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'telechat23'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
```python
>>> from transformers import TeleChat2Model, TeleChat2Config
Expand Down Expand Up @@ -165,7 +157,6 @@ def __init__(
n_layer=30,
n_head=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_epsilon=1e-6,
Expand All @@ -177,42 +168,31 @@ def __init__(
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
use_sliding_window=False,
sliding_window=None,
embed_layernorm=False,
max_window_layers=28,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.num_hidden_layers = n_layer
self.num_attention_heads = n_head
self.n_layer = n_layer
self.n_head = n_head
self.hidden_dropout = hidden_dropout
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = n_head

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_epsilon = layer_norm_epsilon
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.embed_layernorm = embed_layernorm
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
Expand Down
20 changes: 10 additions & 10 deletions src/transformers/models/telechat2/modeling_telechat2.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ class TeleChat2MLP(nn.Module):
def __init__(self, config: TeleChat2Config):
super().__init__()
hidden_size = config.hidden_size
self.gate_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=False)
self.gate_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(hidden_size, config.ffn_hidden_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(config.ffn_hidden_size, hidden_size, bias=True)
self.hidden_dropout = config.hidden_dropout

Expand All @@ -211,7 +211,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
num_key_value_heads, seqlen, head_dim) to (batch, n_head, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
Expand Down Expand Up @@ -255,15 +255,15 @@ def __init__(self, config: TeleChat2Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.n_head)
self.num_key_value_groups = config.n_head // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True

self.query = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.query = nn.Linear(config.hidden_size, config.n_head * self.head_dim, bias=False)
self.key_value = nn.Linear(config.hidden_size, self.head_dim * config.num_key_value_heads * 2, bias=False)
self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size)
self.dense = nn.Linear(config.n_head * self.head_dim, config.hidden_size)

def forward(
self,
Expand Down Expand Up @@ -500,7 +500,7 @@ def _init_weights(self, module):
)
class TeleChat2Model(TeleChat2PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TeleChat2DecoderLayer`]
Transformer decoder consisting of *config.n_layer* layers. Each layer is a [`TeleChat2DecoderLayer`]
Args:
config: TeleChat2Config
Expand All @@ -513,7 +513,7 @@ def __init__(self, config: TeleChat2Config):

self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

self.h = nn.ModuleList([TeleChat2DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
self.h = nn.ModuleList([TeleChat2DecoderLayer(config, i) for i in range(config.n_layer)])
self.ln_f = TeleChat2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.rotary_emb = TeleChat2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
Expand Down Expand Up @@ -586,7 +586,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.h[: self.config.num_hidden_layers]:
for decoder_layer in self.h[: self.config.n_layer]:
if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand Down

0 comments on commit 55166ec

Please sign in to comment.