"""An implementation of Transformer Language model.
Authors
* Jianyuan Zhong
* Samuele Cornell
"""
import torch # noqa 42
from torch import nn
from speechbrain.nnet.linear import Linear
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.nnet.containers import ModuleList
from speechbrain.lobes.models.transformer.Transformer import (
    TransformerInterface,
    get_lookahead_mask,
    get_key_padding_mask,
    NormalizedEmbedding,
)


class TransformerLM(TransformerInterface):
    """This is an implementation of a Transformer language model.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    vocab : int
        The size of the vocabulary (number of output classes).
    d_model : int
        The number of expected features in the encoder/decoder inputs (default=512).
    nhead : int
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers : int
        The number of sub-encoder-layers in the encoder (default=12).
    num_decoder_layers : int
        The number of sub-decoder-layers in the decoder (default=0).
    d_ffn : int
        The dimension of the feedforward network model (default=2048).
    dropout : float
        The dropout value (default=0.1).
    activation : torch class
        The activation function of the encoder/decoder intermediate layer,
        e.g. relu or gelu (default=relu).

    Example
    -------
    >>> src = torch.randint(0, 720, [8, 120])
    >>> net = TransformerLM(720, 512, 8, 1, 0, 1024, activation=torch.nn.GELU)
    >>> enc_out = net.forward(src)
    >>> print(enc_out.shape)
    torch.Size([8, 120, 720])
    """

    def __init__(
        self,
        vocab,
        d_model=512,
        nhead=8,
        num_encoder_layers=12,
        num_decoder_layers=0,
        d_ffn=2048,
        dropout=0.1,
        activation=nn.ReLU,
        positional_encoding="fixed_abs_sine",
        normalize_before=False,
        d_embedding=None,
        max_length=2500,
        causal=True,
        attention_type="regularMHA",
        pad_idx=-1,
    ):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ffn=d_ffn,
            dropout=dropout,
            activation=activation,
            positional_encoding=positional_encoding,
            normalize_before=normalize_before,
            max_length=max_length,
            causal=causal,
            attention_type=attention_type,
        )

        self.d_embedding = d_embedding
        if d_embedding is None:
            self.d_embedding = d_model

        self.custom_src_module = NormalizedEmbedding(self.d_embedding, vocab)

        self.embedding_proj = None
        if d_embedding is not None:
            self.embedding_proj = Linear(
                input_size=self.d_embedding, n_neurons=d_model
            )

        self.output_proj = ModuleList(
            Linear(input_size=d_model, n_neurons=d_model),
            LayerNorm(d_model, eps=1e-6),
            Linear(input_size=d_model, n_neurons=vocab),
        )

        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers

        # reset the params of the transformer model
        self._reset_params()

        self.pad_idx = pad_idx

    def forward(self, src, hx=None):
        """
        Arguments
        ---------
        src : tensor
            The sequence to the encoder (required).
        hx : tensor
            Unused; kept for interface compatibility with recurrent LMs.
        """
        src_mask, src_key_padding_mask = self.make_masks(src)

        src = self.custom_src_module(src)
        if self.embedding_proj is not None:
            src = self.embedding_proj(src)
        src = src + self.positional_encoding(src)

        if self.num_encoder_layers > 0:
            encoder_out, _ = self.encoder(
                src=src,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
            )

        if self.num_decoder_layers > 0:
            encoder_out, _ = self.decoder(
                src=src,
                tgt=src,
                tgt_mask=src_mask,
                tgt_key_padding_mask=src_key_padding_mask,
            )

        pred = self.output_proj(encoder_out)
        return pred

    def _reset_params(self):
        """Re-initializes all multi-dimensional parameters with Xavier normal init."""
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_normal_(p)

    def make_masks(
        self, src, pad_idx=None, look_ahead_mask=True, padding_mask=True
    ):
        """Builds the look-ahead (causal) mask and the key padding mask for src."""
        if pad_idx is None:
            pad_idx = self.pad_idx

        src_mask = None
        if look_ahead_mask:
            src_mask = get_lookahead_mask(src)

        src_key_padding_mask = None
        if padding_mask:
            src_key_padding_mask = get_key_padding_mask(src, pad_idx)

        return src_mask, src_key_padding_mask
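

# A hypothetical helper (not in the original file) illustrating the masks that
# ``make_masks`` returns, assuming the upstream SpeechBrain helpers produce a
# (T, T) look-ahead mask and a (B, T) boolean key-padding mask.
def _demo_masks():
    lm = TransformerLM(vocab=100, num_encoder_layers=1, pad_idx=0)
    tokens = torch.tensor([[5, 7, 9, 0, 0]])  # last two positions are padding
    src_mask, pad_mask = lm.make_masks(tokens)
    print(src_mask.shape)  # expected: torch.Size([5, 5])
    print(pad_mask)        # expected: tensor([[False, False, False,  True,  True]])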


class TestTransformerLM(TransformerLM):
    """Transformer LM for setups where the last vocabulary index is used for padding."""

    def forward(self, src, hx=None):
        pred = super().forward(src, hx=hx)
        # Drop the final (padding) class from the output.
        return pred[:, :, :-1]
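

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the vocabulary
    # size, batch size, and sequence length below are arbitrary example values
    # chosen to mirror the docstring example above.
    vocab_size = 720
    tokens = torch.randint(0, vocab_size, (8, 120))

    # Plain Transformer LM: one logit per vocabulary entry.
    lm = TransformerLM(vocab_size, 512, 8, 1, 0, 1024, activation=torch.nn.GELU)
    print(lm(tokens).shape)  # expected: torch.Size([8, 120, 720])

    # TestTransformerLM drops the last (padding) index from the output.
    test_lm = TestTransformerLM(
        vocab_size, 512, 8, 1, 0, 1024,
        activation=torch.nn.GELU, pad_idx=vocab_size - 1,
    )
    print(test_lm(tokens).shape)  # expected: torch.Size([8, 120, 719])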