#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
_____.___._______________ __.____ __________ _________ ___ ___ _____ .___
\__ | |\_ _____/ |/ _| | \ \ \_ ___ \ / | \ / _ \ | |
/ | | | __)_| < | | / | \ / \ \// ~ \/ /_\ \| |
\____ | | \ | \| | / | \ \ \___\ Y / | \ |
/ ______|/_______ /____|__ \______/\____|__ / \______ /\___|_ /\____|__ /___|
\/ \/ \/ \/ \/ \/ \/
==========================================================================================
@author: Yekun Chai
@license: School of Informatics, Edinburgh
@contact: [email protected]
@file: Model.py
@time: 29/09/2019 20:25
@desc:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math, copy, time

import seaborn
seaborn.set_context(context="talk")

import utils
from Layers import MultiHeadedAttention, PositionwiseFeedForward, PositionalEncoding, EncoderLayer, DecoderLayer, \
    Embeddings, MultiHeadedAttention_RPR


class EncoderDecoder(nn.Module):
    """
    A standard encoder-decoder architecture, the base of the Transformer model.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """ Take in and process masked source and target sequences. """
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


class Generator(nn.Module):
    """ Standard generation step: linear projection from d_model to the vocabulary, followed by softmax. """

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.softmax(self.proj(x), dim=-1)


class Encoder(nn.Module):
    """ Core encoder: a stack of N identical layers. """

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = utils.clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, mask):
        """ Pass the input (and mask) through each layer in turn. """
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class Decoder(nn.Module):
    """ Generic N-layer decoder with masking. """

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = utils.clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=.1):
    """ Construct a Transformer model from hyper-parameters. """
    c = copy.deepcopy
    # Multi-head attention with relative position representations (RPR) is used for
    # encoder and decoder self-attention; vanilla multi-head attention is used for
    # the decoder's encoder-decoder (source) attention.
    attn_rpr = MultiHeadedAttention_RPR(d_model, h, max_relative_position=5)
    attn = MultiHeadedAttention(d_model, h)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn_rpr), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn_rpr), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # Initialise weight matrices with Glorot / Xavier uniform initialisation.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
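

if __name__ == '__main__':
    # Minimal smoke test (a hedged sketch, not part of the original file): build a
    # small model and run one forward pass on dummy batches. The mask shapes
    # (batch, 1, src_len) and (1, tgt_len, tgt_len) follow the standard
    # annotated-Transformer convention and may need adjusting to the conventions
    # used in this repository's Layers / utils modules.
    src_vocab, tgt_vocab = 11, 11
    tmp_model = make_model(src_vocab, tgt_vocab, N=2)

    batch_size, src_len, tgt_len = 2, 10, 9
    src = torch.randint(1, src_vocab, (batch_size, src_len))
    tgt = torch.randint(1, tgt_vocab, (batch_size, tgt_len))
    src_mask = torch.ones(batch_size, 1, src_len, dtype=torch.bool)
    # Causal (subsequent) mask so each target position only attends to earlier positions;
    # it broadcasts over the batch dimension.
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).bool().unsqueeze(0)

    out = tmp_model(src, tgt, src_mask, tgt_mask)   # (batch, tgt_len, d_model)
    probs = tmp_model.generator(out)                # (batch, tgt_len, tgt_vocab)
    print(out.shape, probs.shape)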