XLNet的非官方实现。嵌入提取和有记忆的嵌入提取展示了如何加载预训练检查点并得到transformer的输出特征。
pip install keras-xlnet
点击任务名可以查看基础模型的训练样例:
任务名 | 指标 | 验证集上大致结果 |
---|---|---|
CoLA | Matthew Corr. | 52 |
SST-2 | Accuracy | 93 |
MRPC | Accuracy/F1 | 86/89 |
STS-B | Pearson Corr. / Spearman Corr. | 86/87 |
QQP | Accuracy/F1 | 90/86 |
MNLI | Accuracy | 84/84 |
QNLI | Accuracy | 86 |
RTE | Accuracy | 64 |
WNLI | Accuracy | 56 |
(注意:WNLI数据集上只输出了0,不是一个正常结果)
import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'
tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
batch_size=16,
memory_len=512,
target_len=128,
in_train_phase=False,
attention_type=ATTENTION_TYPE_BI,
)
model.summary()
参数batch_size
、memory_len
、target_len
用于初始化记忆单元,代表最大尺寸,实际属于可以小于对应数值。如果in_train_phase
是True
会返回一个用于训练语言模型的模型,否则返回一个用于fine-tuning的模型。
注意:依赖记忆时输入有序,一定不能打乱输入顺序,fit
或fit_generator
的shuffle
应该为False
。
3个输入:
- 词的ID,形状为
(batch_size, target_len)
。 - 段落的ID,形状为
(batch_size, target_len)
。 - 历史记忆的长度,形状为
(batch_size, 1)
。
1个输出:
- 每个词的特征,形状为
(batch_size, target_len, units)
。
4个输入,前三个和in_train_phase
为False
时相同:
- 词的ID,形状为
(batch_size, target_len)
。 - 段落的ID,形状为
(batch_size, target_len)
。 - 历史记忆的长度,形状为
(batch_size, 1)
。 - 被遮罩的词的蒙版,形状为
(batch_size, target_len)
。
1个输出:
- 每个位置每个词的概率,形状为
(batch_size, target_len, num_token)
。