使用微信聊天记录微调大语言模型,我使用了大概2万条整合后的有效数据,最后结果只能说差强人意,但有时候真的很搞笑。
目前项目默认使用chatglm3-6b模型,LoRA方法对sft阶段微调,大约需要16GB显存。也可以使用LLaMA Factory支持的其他模型和方法,占用显存更少,需要自行修改模板的system提示词等相关配置。
需要显存的估算值:
训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
---|---|---|---|---|---|---|
全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
必需项 | 至少 | 推荐 |
---|---|---|
python | 3.8 | 3.10 |
torch | 1.13.1 | 2.2.1 |
transformers | 4.37.2 | 4.38.1 |
datasets | 2.14.3 | 2.17.1 |
accelerate | 0.27.2 | 0.27.2 |
peft | 0.9.0 | 0.9.0 |
trl | 0.7.11 | 0.7.11 |
可选项 | 至少 | 推荐 |
---|---|---|
CUDA | 11.6 | 12.2 |
deepspeed | 0.10.0 | 0.13.4 |
bitsandbytes | 0.39.0 | 0.41.3 |
flash-attn | 2.3.0 | 2.5.5 |
git clone https://github.com/xming521/WeClone.git
conda create -n weclone python=3.10
conda activate weclone
cd WeClone
pip install -r requirements.txt
训练以及推理相关配置统一在文件settings.json
请使用PyWxDump提取微信聊天记录。下载软件并解密数据库后,点击聊天备份,导出类型为CSV,可以导出多个联系人或群聊,然后将导出的位于wxdump_tmp/export
的 csv
文件夹放在./data
目录即可,也就是不同人聊天记录的文件夹一起放在 ./data/csv
。 示例数据位于data/example_chat.csv。
项目默认去除了数据中的手机号、身份证号、邮箱、网址。还提供了一个禁用词词库blocked_words,可以自行添加需要过滤的词句(会默认去掉包括禁用词的整句)。
执行 ./make_dataset/csv_to_json.py
脚本对数据进行处理。
在同一人连续回答多句的情况下,有三种处理方式:
文件 | 处理方式 |
---|---|
csv_to_json.py | 用逗号连接 |
csv_to_json-单句回答.py(已废弃) | 只选择最长的回答作为最终数据 |
csv_to_json-单句多轮.py | 放在了提示词的'history'中 |
首选在Hugging Face下载ChatGLM3 模型。如果您在 Hugging Face 模型的下载中遇到了问题,可以通过下述方法使用魔搭社区,后续训练推理都需要先执行export USE_MODELSCOPE_HUB=1
来使用魔搭社区的模型。
由于模型较大,下载过程比较漫长请耐心等待。
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
git lfs install
git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git
-
(可选)修改 settings.json选择本地下载好的其他模型。
-
修改
per_device_train_batch_size
以及gradient_accumulation_steps
来调整显存占用。 -
可以根据自己数据集的数量和质量修改
num_train_epochs
、lora_rank
、lora_dropout
等参数。
运行 src/train_sft.py
进行sft阶段微调,本人loss只降到了3.5左右,降低过多可能会过拟合。
python src/train_sft.py
pip install deepspeed
deepspeed --num_gpus=使用显卡数量 src/train_sft.py
Note
也可以先对pt阶段进行微调,似乎提升效果不明显,仓库也提供了pt阶段数据集预处理和训练的代码。
python ./src/web_demo.py
python ./src/api_service.py
python ./src/api_service.py
python ./src/test_model.py
Important
微信有封号风险,建议使用小号,并且必须绑定银行卡才能使用
python ./src/api_service.py # 先启动api服务
python ./src/wechat_bot/main.py
默认在终端显示二维码,扫码登录即可。可以私聊或者在群聊中@机器人使用。
Todo
Todo