
Commit

fixed LayerNormalization to be per-feature
removed log_softmax from projection layer
hkproj committed Sep 13, 2023
1 parent 1942164 commit c9ddb87
Showing 1 changed file with 18 additions and 18 deletions.
model.py (36 changes: 18 additions & 18 deletions)
@@ -4,11 +4,11 @@

class LayerNormalization(nn.Module):

-    def __init__(self, eps:float=10**-6) -> None:
+    def __init__(self, features: int, eps:float=10**-6) -> None:
         super().__init__()
         self.eps = eps
-        self.alpha = nn.Parameter(torch.ones(1)) # alpha is a learnable parameter
-        self.bias = nn.Parameter(torch.zeros(1)) # bias is a learnable parameter
+        self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
+        self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
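(The forward pass is cut off by the hunk above. A minimal sketch of what a per-feature layer norm with these parameters looks like, assuming the statistics are taken over the last dimension as the shape comment suggests; only the parameter shapes come from this commit, the forward body is illustrative.)

import torch
import torch.nn as nn

class PerFeatureLayerNormSketch(nn.Module):
    def __init__(self, features: int, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # per-feature scale, shape (features,)
        self.bias = nn.Parameter(torch.zeros(features))  # per-feature shift, shape (features,)

    def forward(self, x):
        # x: (batch, seq_len, features); normalize each position over its feature vector
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

(Before this commit, alpha and bias were single scalars, torch.ones(1) and torch.zeros(1), so every feature shared one scale and shift; per-feature parameters match the elementwise affine behaviour of nn.LayerNorm(features).)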
@@ -72,10 +72,10 @@ def forward(self, x):

class ResidualConnection(nn.Module):

-    def __init__(self, dropout: float) -> None:
+    def __init__(self, features: int, dropout: float) -> None:
         super().__init__()
         self.dropout = nn.Dropout(dropout)
-        self.norm = LayerNormalization()
+        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
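(ResidualConnection is a pre-norm residual: the per-feature norm is applied before the sublayer and dropout before the addition. A hedged usage sketch; the tensor sizes and the identity sublayer are made up, and it assumes the class above lives in model.py.)

import torch
from model import ResidualConnection  # the class shown in the hunk above

d_model = 512
res = ResidualConnection(d_model, dropout=0.1)  # features must match x's last dimension
x = torch.randn(2, 10, d_model)                 # (batch, seq_len, d_model)
out = res(x, lambda y: y)                       # x + dropout(sublayer(norm(x)))
print(out.shape)                                # torch.Size([2, 10, 512])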
@@ -135,11 +135,11 @@ def forward(self, q, k, v, mask):

class EncoderBlock(nn.Module):

-    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
+    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
         super().__init__()
         self.self_attention_block = self_attention_block
         self.feed_forward_block = feed_forward_block
-        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
+        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
@@ -148,10 +148,10 @@ def forward(self, x, src_mask):

class Encoder(nn.Module):

-    def __init__(self, layers: nn.ModuleList) -> None:
+    def __init__(self, features: int, layers: nn.ModuleList) -> None:
         super().__init__()
         self.layers = layers
-        self.norm = LayerNormalization()
+        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
@@ -160,12 +160,12 @@

class DecoderBlock(nn.Module):

-    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
+    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
         super().__init__()
         self.self_attention_block = self_attention_block
         self.cross_attention_block = cross_attention_block
         self.feed_forward_block = feed_forward_block
-        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])
+        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
@@ -175,10 +175,10 @@ def forward(self, x, encoder_output, src_mask, tgt_mask):

class Decoder(nn.Module):

-    def __init__(self, layers: nn.ModuleList) -> None:
+    def __init__(self, features: int, layers: nn.ModuleList) -> None:
         super().__init__()
         self.layers = layers
-        self.norm = LayerNormalization()
+        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
@@ -193,7 +193,7 @@ def __init__(self, d_model, vocab_size) -> None:

    def forward(self, x) -> None:
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
-        return torch.log_softmax(self.proj(x), dim = -1)
+        return self.proj(x)

class Transformer(nn.Module):
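(With log_softmax removed, ProjectionLayer now returns raw logits. As a general note, not necessarily the author's stated reason, which is in the issue linked in the comments below: raw logits pair with nn.CrossEntropyLoss, which applies log_softmax internally, while log-probabilities pair with nn.NLLLoss; applying log_softmax twice distorts the loss. A small illustrative sketch with made-up sizes:)

import torch
import torch.nn as nn

vocab_size, batch, seq_len = 1000, 2, 5
logits = torch.randn(batch, seq_len, vocab_size)        # what the layer returns after this commit
target = torch.randint(0, vocab_size, (batch, seq_len))

# raw logits -> CrossEntropyLoss (log_softmax is applied internally)
ce = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), target.view(-1))

# log-probabilities -> NLLLoss (what the old log_softmax output would pair with)
log_probs = torch.log_softmax(logits, dim=-1)
nll = nn.NLLLoss()(log_probs.view(-1, vocab_size), target.view(-1))

print(torch.allclose(ce, nll))  # True: the two pairings compute the same loss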

@@ -237,7 +237,7 @@ def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
-        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
+        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
@@ -246,12 +246,12 @@ def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
-        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
+        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
-    encoder = Encoder(nn.ModuleList(encoder_blocks))
-    decoder = Decoder(nn.ModuleList(decoder_blocks))
+    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
+    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
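(A quick sanity check one could run after this commit. Everything beyond the truncated signature above is an assumption: that build_transformer also takes a tgt_seq_len argument, that d_model defaults to 512 as used in the function body, and that the assembled Transformer exposes its encoder as .encoder.)

from model import build_transformer

# Hedged sanity check: the layer norms should now hold per-feature parameters.
model = build_transformer(src_vocab_size=1000, tgt_vocab_size=1000,
                          src_seq_len=128, tgt_seq_len=128)   # tgt_seq_len is assumed
print(model.encoder.norm.alpha.shape)  # expected: torch.Size([512]), was torch.Size([1]) before
print(model.encoder.norm.bias.shape)   # expected: torch.Size([512])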

3 comments on commit c9ddb87

@Tiger-Mondo

Hi, I'd like to ask why you removed log_softmax from the projection layer?

@hkproj (Owner, Author) commented on c9ddb87 on Dec 2, 2023

> Hi, I'd like to ask why you removed log_softmax from the projection layer?

The reason is described in this issue.

@Tiger-Mondo

I understand now, thank you for your reply!
