
Commit 6370c62
fix
Signed-off-by: shunxing12345 <[email protected]>
shunxing12345 committed Dec 30, 2024
1 parent fa79011 commit 6370c62
Showing 2 changed files with 13 additions and 22 deletions.
15 changes: 4 additions & 11 deletions src/transformers/models/telechat2/__init__.py
@@ -13,12 +13,8 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)

from ...utils import (OptionalDependencyNotAvailable, _LazyModule,
is_torch_available)

_import_structure = {
"configuration_telechat2": ["TeleChat2Config"],
@@ -46,11 +42,8 @@
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_telechat2 import (
TeleChat2PreTrainedModel,
TeleChat2Model,
TeleChat2ForCausalLM,
)
from .modeling_telechat2 import (TeleChat2ForCausalLM, TeleChat2Model,
TeleChat2PreTrainedModel)


else:
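For context, the import lines changed above feed the standard lazy-module boilerplate that transformers model packages use. The following is a minimal sketch of that pattern under the usual layout; only the identifiers visible in the diff are taken from it, the rest is assumed.

# Sketch of the lazy-import pattern for a transformers model package.
# Everything beyond the names shown in the diff is an assumption.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

_import_structure = {
    "configuration_telechat2": ["TeleChat2Config"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_telechat2"] = [
        "TeleChat2ForCausalLM",
        "TeleChat2Model",
        "TeleChat2PreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_telechat2 import TeleChat2Config

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_telechat2 import (
            TeleChat2ForCausalLM,
            TeleChat2Model,
            TeleChat2PreTrainedModel,
        )
else:
    import sys

    # The module is swapped for a _LazyModule so submodules load on first access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)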
20 changes: 9 additions & 11 deletions src/transformers/models/telechat2/modeling_telechat2.py
@@ -12,14 +12,13 @@

"""PyTorch TeleChat2 model implementation, refactored with Transformers-style conventions."""

import math
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -34,15 +33,14 @@
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_telechat2 import TeleChat2Config

import pdb

logger = logging.get_logger(__name__)


@@ -263,7 +261,7 @@ def __init__(self, config: TeleChat2Config, layer_idx: int):
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True

self.query = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
)
@@ -311,7 +309,7 @@ def forward(
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
self,
query_states,
@@ -334,7 +332,7 @@ def __init__(self, config: TeleChat2Config, layer_idx: int):
self.hidden_size = config.hidden_size

self.self_attention = TeleChat2Attention(config=config, layer_idx=layer_idx)

self.mlp = TeleChat2MLP(config)
self.input_layernorm = TeleChat2RMSNorm(self.hidden_size, eps=config.layer_norm_epsilon)
self.post_attention_layernorm = TeleChat2RMSNorm(self.hidden_size, eps=config.layer_norm_epsilon)
@@ -368,7 +366,7 @@ def forward(
**kwargs,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -518,9 +516,9 @@ def __init__(self, config: TeleChat2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

self.h = nn.ModuleList(
[TeleChat2DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
)
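The Callable added to the typing import and the ALL_ATTENTION_FUNCTIONS lookup in the attention hunk follow the pluggable attention-backend dispatch used elsewhere in transformers. Below is a minimal, self-contained sketch of that dispatch; only ALL_ATTENTION_FUNCTIONS and config._attn_implementation appear in the diff itself, so the registry name, eager_attention_forward, and the call signature here are illustrative.

# Sketch of registry-based attention dispatch; names beyond the diff are assumptions.
from typing import Callable, Dict

import torch


def eager_attention_forward(query, key, value, attention_mask=None, scaling=1.0):
    # Plain scaled dot-product attention as the default ("eager") backend.
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value), attn_weights


# Registry keyed by the configured implementation name, mirroring
# ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] in the hunk above.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention_forward}

attention_interface: Callable = ATTENTION_FUNCTIONS["eager"]
q = k = v = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
attn_output, attn_weights = attention_interface(q, k, v, scaling=8**-0.5)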
