总结一些在学习工作中用过的的文本分类算法,欢迎star!
- data : 垃圾短信数据
- lr_main.py : 基于逻辑回归的文本分类
- nb_main.py: 基于朴素贝叶斯的文本分类
- svm_main.py: 基于支持向量机的文本分类
这里找到一份垃圾短信数据,data文件夹下为垃圾短信分类。
- python3 lr_main.py
- python3 nb_main.py
- python3 svm_main.py
[1] https://blog.csdn.net/Explorer_Du/article/details/84067510
[2] https://blog.csdn.net/Explorer_Du/article/details/85242690
- data: 训练数据集文件夹
- data_helper: 数据预处理
- models: 模型结构
- result: 存放结果的文件夹
- util: 训练词向量
- run.py: 训练模型文件
- train_val_test.py: 具体训练验证和测试的代码
使用THUCNews的一个子集进行训练与测试,由gaussic提供。
类别如下:
体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐
百度网盘下载链接在data文件夹下的README。
数据集划分如下:
- 训练集: 5000*10
- 验证集: 500*10
- 测试集: 1000*10
- cnews.train.txt: 训练集(50000条)
- cnews.val.txt: 验证集(5000条)
- cnews.test.txt: 测试集(10000条)
- 在util文件夹下运行:python3 word2vec.py
- 训练模型,运行run.py: python3 run.py --model xxx --mode train
- 测试模型,运行run.py: python3 run.py --model xxx --mode test
模型 | val准确率 | test准确率 | 备注 |
TextCNN | 0.931 | 0.945 | step为2200时候,提前终止了 |
TextRNN | 0.946 | 0.97 | 训练的速度相对TextCNN较慢 |
TextRCNN | 0.922 | 0.964 | GRU+pooling |
TextRNN_Att | 0.935 | 0.962 | 内存不够可以调小batch_size |
DPCNN | 0.934 | 0.955 | 训练的速度较快 |
HAN | 0.913 | 0.95 | 采用GRU |
Transformer | 0.909 | 0.92 | 效果最差... |
pytorch==1.1.0 transformers==3.0.2
- data: 训练数据集文件夹
- imgs: 存放运行结果
- models: 模型结构
- result: 存放结果的文件夹
- util: 工具函数文件
- run.py: 训练模型文件
- train_val_test.py: 具体训练验证和测试的代码
数据来源为649453932项目的数据。
python3 run.py --mode=bert
mode: bert/albert/roberta/ernie
- 2021年9月26日,计划增加新的预训练模型
- 2020年12月10日,增加基于pytorch版ernie文本分类实现
- 2020年12月9日,增加基于pytorch版roberta文本分类实现
- 2020年12月8日,增加基于pytorch版albert文本分类实现
- 2020年11月6日,增加基于pytorch版bert文本分类实现
- 2020年3月31日,增加DL(深度学习)文本分类实现
- 2019年7月5日,增加机器学习垃圾短信分类实现
[1] https://github.com/gaussic/text-classification-cnn-rnn
[2] https://github.com/cjymz886/text-cnn
[3] https://github.com/649453932/Chinese-Text-Classification-Pytorch