Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Latest commit

 

History

History
86 lines (59 loc) · 3.35 KB

README.zh-CN.md

File metadata and controls

86 lines (59 loc) · 3.35 KB

Keras XLNet

Version License

[中文|English|通用问题]

XLNet的非官方实现。嵌入提取有记忆的嵌入提取展示了如何加载预训练检查点并得到transformer的输出特征。

安装

pip install keras-xlnet

使用

GLUE微调

点击任务名可以查看基础模型的训练样例:

任务名 指标 验证集上大致结果
CoLA Matthew Corr. 52
SST-2 Accuracy 93
MRPC Accuracy/F1 86/89
STS-B Pearson Corr. / Spearman Corr. 86/87
QQP Accuracy/F1 90/86
MNLI Accuracy 84/84
QNLI Accuracy 86
RTE Accuracy 64
WNLI Accuracy 56

(注意:WNLI数据集上只输出了0,不是一个正常结果)

加载预训练检查点

import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,
    memory_len=512,
    target_len=128,
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_BI,
)
model.summary()

参数batch_sizememory_lentarget_len用于初始化记忆单元,代表最大尺寸,实际属于可以小于对应数值。如果in_train_phaseTrue会返回一个用于训练语言模型的模型,否则返回一个用于fine-tuning的模型。

关于输入输出

注意:依赖记忆时输入有序,一定不能打乱输入顺序,fitfit_generatorshuffle应该为False

in_train_phaseFalse

3个输入:

  • 词的ID,形状为(batch_size, target_len)
  • 段落的ID,形状为(batch_size, target_len)
  • 历史记忆的长度,形状为(batch_size, 1)

1个输出:

  • 每个词的特征,形状为(batch_size, target_len, units)

in_train_phaseTrue

4个输入,前三个和in_train_phaseFalse时相同:

  • 词的ID,形状为(batch_size, target_len)
  • 段落的ID,形状为(batch_size, target_len)
  • 历史记忆的长度,形状为(batch_size, 1)
  • 被遮罩的词的蒙版,形状为(batch_size, target_len)

1个输出:

  • 每个位置每个词的概率,形状为(batch_size, target_len, num_token)