-
Notifications
You must be signed in to change notification settings - Fork 4
/
lz.py
99 lines (75 loc) · 4.66 KB
/
lz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from sequential import *
from settings import Config
import models
import utils
class UserModeling(Seq2Vec):
    """User-modeling network that scores a candidate title against clicked titles.

    The clicked history is encoded title-by-title with the shared document
    encoder, masked, compressed by ``models.LzCompressionPredictor`` and
    pooled with ``models.LzQueryAttentionPooling`` using the candidate vector
    as the second argument. The variant is chosen by ``self.config.arch``.

    ``get_doc_encoder``, ``loss`` and ``aux_loss`` are not defined in this
    class; presumably they come from ``Seq2Vec`` -- verify against the base.
    """

    @staticmethod
    def _blend_self_attention(clicked_vec):
        """Average ``clicked_vec`` with its ``models._LzSelfAttention`` output."""
        attended = models._LzSelfAttention(mapping=True)(clicked_vec)
        return keras.layers.Average()([clicked_vec, attended])

    def _build_model(self):
        """Build, compile and return the Keras model selected by ``config.arch``.

        Supported ``self.config.arch`` patterns (checked in this order, by
        substring; the text after the last ``-`` is parsed as the integer
        channel count):

        * contains ``"pre-train"``: compression in ``mode="pretrain"``; no
          auxiliary loss is attached.
        * contains ``"pre-plus"``: compression in ``mode="Pre"`` plus the
          orthogonality regulariser as an auxiliary loss and extra metric.
        * contains ``"pretrain-preplus"``: as ``"pre-plus"`` but with
          ``enable_pretrain_attention=True`` and a TensorFlow scalar summary
          of the regulariser.

        For the last two, the presence of ``"self"`` in the arch string adds
        a self-attention blend before compression.

        Side effects: sets ``self.doc_encoder`` and ``self.model``; the
        regularised variants also set ``self.config.l2_norm_coefficient`` to
        ``0.1``, and ``"pretrain-preplus"`` additionally sets
        ``self.config.enable_pretrain_attention`` to ``True``.

        Returns:
            The compiled ``keras.Model`` taking ``[clicked, candidate]``.

        Raises:
            ValueError: if ``self.config.arch`` matches none of the patterns,
                or if its last ``-``-separated token is not an integer.
        """
        self.doc_encoder = doc_encoder = self.get_doc_encoder()
        user_encoder = keras.layers.TimeDistributed(doc_encoder)

        clicked = keras.Input((self.config.window_size, self.config.title_shape))
        candidate = keras.Input((self.config.title_shape,))

        clicked_vec = user_encoder(clicked)
        candidate_vec = doc_encoder(candidate)

        # Multiply every encoded click by its mask value so masked positions
        # contribute zeros downstream.
        mask = models.LzComputeMasking(0)(clicked)
        clicked_vec = keras.layers.Lambda(
            lambda x: x[0] * keras.backend.expand_dims(x[1])
        )([clicked_vec, mask])

        user_model = self.config.arch
        logging.info('[!] Selected User Model: %s', user_model)

        # ``orth_reg`` stays None for variants without an auxiliary loss.
        orth_reg = None

        if "pre-train" in user_model:
            channel_count = int(user_model.split("-")[-1])
            clicked_vec, _unused_reg = models.LzCompressionPredictor(
                channel_count=channel_count,
                mode="pretrain",
                enable_pretrain_attention=True,
            )(clicked_vec)
        elif "pre-plus" in user_model:
            logging.info("preplus")
            channel_count = int(user_model.split("-")[-1])
            if "self" in user_model:
                clicked_vec = self._blend_self_attention(clicked_vec)
            clicked_vec, orth_reg = models.LzCompressionPredictor(
                channel_count=channel_count,
                mode="Pre",
            )(clicked_vec)
            orth_reg = orth_reg[0]
        elif "pretrain-preplus" in user_model:
            logging.info("pretrain-preplus")
            channel_count = int(user_model.split("-")[-1])
            self.config.enable_pretrain_attention = True
            if "self" in user_model:
                clicked_vec = self._blend_self_attention(clicked_vec)
            clicked_vec, orth_reg = models.LzCompressionPredictor(
                channel_count=channel_count,
                mode="Pre",
                enable_pretrain_attention=True,
            )(clicked_vec)
            with tf.name_scope('orth_reg_tensor'):
                orth_reg = orth_reg[0]
                tf.summary.scalar('orthreg', orth_reg)
        else:
            # ValueError is still caught by callers using ``except Exception``.
            raise ValueError("No available models. Please check param!")

        # Shared tail: pool the compressed history, score it against the
        # candidate with a dot product and wrap everything in a model.
        clicked_vec = models.LzQueryAttentionPooling()(clicked_vec, candidate_vec)
        logits = models.LzLogits(mode="dot")([clicked_vec, candidate_vec])
        self.model = keras.Model([clicked, candidate], logits)

        if orth_reg is not None:
            # The coefficient is set before calling ``aux_loss``, matching the
            # original statement order (presumably ``aux_loss`` reads it).
            self.config.l2_norm_coefficient = 0.1
            self.model.add_loss(
                self.aux_loss(orth_reg * (channel_count / 3.0) ** 0.75)
            )

        self.model.compile(
            optimizer=keras.optimizers.Adam(lr=self.config.learning_rate, clipnorm=5.0),
            loss=self.loss,
            metrics=[utils.auc_roc],
        )

        if orth_reg is not None:
            # Report the raw regulariser next to the compiled metrics.
            # NOTE(review): ``metrics_tensors`` is an attribute of older Keras
            # releases -- confirm the pinned Keras version still exposes it.
            self.model.metrics_names += ['orth_reg']
            self.model.metrics_tensors += [orth_reg]

        # The model is built and compiled exactly once; it must not be rebuilt
        # after this point, or the auxiliary loss and metric would be lost.
        return self.model