中文 | English
现在的大语言模型的参数往往较大,消费级电脑单纯做推理都比较慢,更别说想自己从头开始训练一个模型了。本项目的目标是从0开始训练一个生成式语言模型,包括数据清洗、tokenizer训练、模型预训练、SFT指令微调、RLHF优化等。
ChatLM-mini-Chinese为中文对话小模型,模型参数只有0.2B(算共享权重约210M),可以在最低4GB显存的机器进行预训练(batch_size=1
,fp16
或者 bf16
),float16
加载、推理最少只需要512MB显存。
Huggingface
NLP框架,包括transformers
、accelerate
、trl
、peft
等。trainer
,支持单机单卡、单机多卡进行预训练、SFT微调。训练过程中支持在任意位置停止,及在任意位置继续训练。Text-to-Text
预训练,非mask
掩码预测预训练。
sentencepiece
、huggingface tokenizers
的tokenizer训练;batch_size=1, max_len=320
下,最低支持在16GB内存+4GB显存的机器上进行预训练;trainer
支持prompt指令微调, 支持任意断点继续训练;Huggingface trainer
的sequence to sequence
微调;peft lora
进行偏好优化;Lora adapter
合并到原始模型中。如果需要做基于小模型的检索增强生成(RAG),可以参考我的另一个项目Phi2-mini-Chinese,代码见rag_with_langchain.ipynb
?最近更新
所有数据集均来自互联网公开的单轮对话数据集,经过数据清洗、格式化后保存为parquet文件。数据处理过程见utils/raw_data_process.py
。主要数据集包括:
Belle_open_source_1M
、train_2M_CN
、及train_3.5M_CN
中部分回答较短、不含复杂表格结构、翻译任务(没做英文词表)的数据,共370万行,清洗后剩余338万行。N
个词为回答,使用202309
的百科数据,清洗后剩余119万的词条提示语和回答。Wiki下载:zhwiki,将下载的bz2文件转换为wiki.txt参考:WikiExtractor。数据集总数量1023万:Text-to-Text预训练集:930万,评估集:2.5万(因为解码较慢,所以没有把评估集设置太大)。测试集:90万。
SFT微调和DPO优化数据集见下文。
T5模型(Text-to-Text Transfer Transformer),详情见论文: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer。
模型源码来自huggingface,见:T5ForConditionalGeneration。
模型配置见model_config.json,官方的T5-base
:encoder layer
和decoder layer
均为为12层,本项目这两个参数修改为10层。
模型参数:0.2B。词表大小:29298,仅包含中文和少量英文。
硬件:
# 预训练阶段:
CPU: 28 vCPU Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz
内存:60 GB
显卡:RTX A5000(24GB) * 2
# sft及dpo阶段:
CPU: Intel(R) i5-13600k @ 5.1GHz
内存:32 GB
显卡:NVIDIA GeForce RTX 4060 Ti 16GB * 1
tokenizer 训练: 现有tokenizer
训练库遇到大语料时存在OOM问题,故全量语料按照类似BPE
的方法根据词频合并、构造词库,运行耗时半天。
Text-to-Text 预训练:学习率为1e-4
到5e-3
的动态学习率,预训练时间为8天。训练损失:
belle
指令训练数据集(指令和回答长度都在512以下),学习率为1e-7
到5e-5
的动态学习率,微调时间2天。微调损失:chosen
文本,步骤2
中SFT模型对数据集中的prompt做批量generate
,得到rejected
文本,耗时1天,dpo全量偏好优化,学习率le-5
,半精度fp16
,共2
个epoch
,耗时3h。dpo损失:默认使用huggingface transformers
的 TextIteratorStreamer
实现流式对话,只支持greedy search
,如果需要beam sample
等其他生成方式,请将cli_demo.py
的stream_chat
参数修改为False
。
存在问题:预训练数据集只有900多万,模型参数也仅0.2B,不能涵盖所有方面,会有答非所问、废话生成器的情况。
如果无法连接huggingface,请使用modelscope.snapshot_download
从modelscope下载模型文件。
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
model_id = 'charent/ChatLM-mini-Chinese'
# 如果无法连接huggingface,打开以下两行代码的注释,将从modelscope下载模型文件,模型文件保存到'./model_save'目录
# from modelscope import snapshot_download
# model_id = snapshot_download(model_id, cache_dir='./model_save')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True).to(device)
txt = '如何评价Apple这家公司?'
encode_ids = tokenizer([txt])
input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])
outs = model.my_generate(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
max_seq_len=256,
search_type='beam',
)
outs_txt = tokenizer.batch_decode(outs.cpu().numpy(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(outs_txt[0])
Apple是一家专注于设计和用户体验的公司,其产品在设计上注重简约、流畅和功能性,而在用户体验方面则注重用户的反馈和使用体验。作为一家领先的科技公司,苹果公司一直致力于为用户提供最优质的产品和服务,不断推陈出新,不断创新和改进,以满足不断变化的市场需求。
在iPhone、iPad和Mac等产品上,苹果公司一直保持着创新的态度,不断推出新的功能和设计,为用户提供更好的使用体验。在iPad上推出的iPad Pro和iPod touch等产品,也一直保持着优秀的用户体验。
此外,苹果公司还致力于开发和销售软件和服务,例如iTunes、iCloud和App Store等,这些产品在市场上也获得了广泛的认可和好评。
总的来说,苹果公司在设计、用户体验和产品创新方面都做得非常出色,为用户带来了许多便利和惊喜。
Caution
本项目模型为TextToText
模型,在预训练、SFT、RLFH阶段的prompt
、response
等字段,请务必加上[EOS]
序列结束标记。
git clone --depth 1 https://github.com/charent/ChatLM-mini-Chinese.git
cd ChatLM-mini-Chinese
本项目推荐使用python 3.10
,过老的python版本可能不兼容所依赖的第三方库。
pip安装:
pip install -r ./requirements.txt
如果pip安装了CPU版本的pytorch,可以通过下面的命令安装CUDA版本的pytorch:
# pip 安装torch + cu118
pip3 install torch --index-url https://download.pytorch.org/whl/cu118
conda安装:
conda install --yes --file ./requirements.txt
用git
命令从Hugging Face Hub
下载模型权重及配置文件,需要先安装Git LFS,然后运行:
# 使用git命令下载huggingface模型,先安装[Git LFS],否则下载的模型文件不可用
git clone --depth 1 https://huggingface.co/charent/ChatLM-mini-Chinese
# 如果无法连接huggingface,请从modelscope下载
git clone --depth 1 https://www.modelscope.cn/charent/ChatLM-mini-Chinese.git
mv ChatLM-mini-Chinese model_save
也可以直接从Hugging Face Hub
仓库ChatLM-Chinese-0.2B手工下载,将下载的文件移动到model_save
目录下即可。
语料要求尽可能全,建议添加多个语料,如百科、代码、论文、博客、对话等。
本项目以wiki中文百科为主。获取中文wiki语料方法:中文Wiki下载地址:zhwiki,下载zhwiki-[存档日期]-pages-articles-multistream.xml.bz2
文件,大概2.7GB, 将下载的bz2文件转换为wiki.txt参考:WikiExtractor,再利用python的OpenCC
库转换为简体中文,最后将得到的wiki.simple.txt
放到项目根目录的data
目录下即可。多个语料请自行合并为一个txt
文件。
由于训练tokenizer非常耗内存,如果你的语料非常大(合并后的txt
文件超过2G),建议对语料按照类别、比例进行采样,以减少训练时间和内存消耗。训练1.7GB的txt
文件需要消耗48GB左右的内存(预估的,我只有32GB,频繁触发swap,电脑卡了好久T_T),13600k cpu耗时1小时左右。
char level
和byte level
的区别如下(具体使用上的区别请自行检索资料)。默认训练char level
的tokenizer,如果需要byte level
,在train_tokenizer.py
中设置token_type='byte'
即可。
# 原始文本
txt = '这是一段中英混输的句子, (chinese and English, here are words.)'
tokens = charlevel_tokenizer.tokenize(txt)
print(tokens)
# char level tokens输出
# ['▁这是', '一段', '中英', '混', '输', '的', '句子', '▁,', '▁(', '▁ch', 'inese', '▁and', '▁Eng', 'lish', '▁,', '▁h', 'ere', '▁', 'are', '▁w', 'ord', 's', '▁.', '▁)']
tokens = bytelevel_tokenizer.tokenize(txt)
print(tokens)
# byte level tokens输出
# ['Ġè¿Ļæĺ¯', 'ä¸Ģ段', 'ä¸Ńèĭ±', 'æ··', 'è¾ĵ', 'çļĦ', 'åı¥åŃIJ', 'Ġ,', 'Ġ(', 'Ġch', 'inese', 'Ġand', 'ĠEng', 'lish', 'Ġ,', 'Ġh', 'ere', 'Ġare', 'Ġw', 'ord', 's', 'Ġ.', 'Ġ)']
开始训练:
# 确保你的训练语料`txt`文件已经data目录下
python train_tokenizer.py
{
"prompt": "对于花园街,你有什么了解或看法吗?",
"response": "花园街(是香港油尖旺区的一条富有特色的街道,位于九龙旺角东部,北至界限街,南至登打士街,与通菜街及洗衣街等街道平行。现时这条街道是香港著名的购物区之一。位于亚皆老街以南的一段花园街,也就是"波鞋街"整条街约150米长,有50多间售卖运动鞋和运动用品的店舖。旺角道至太子道西一段则为排档区,售卖成衣、蔬菜和水果等。花园街一共分成三段。明清时代,花园街是芒角村栽种花卉的地方。此外,根据历史专家郑宝鸿的考证:花园街曾是1910年代东方殷琴拿烟厂的花园。纵火案。自2005年起,花园街一带最少发生5宗纵火案,当中4宗涉及排档起火。2010年。2010年12月6日,花园街222号一个卖鞋的排档于凌晨5时许首先起火,浓烟涌往旁边住宅大厦,消防接报4"
}
jupyter-lab 或者 jupyter notebook:
见文件train.ipynb
,推荐使用jupyter-lab,避免考虑与服务器断开后终端进程被杀的情况。
控制台:
控制台训练需要考虑连接断开后进程被杀的,推荐使用进程守护工具Supervisor
或者screen
建立连接会话。
首先要配置accelerate
,执行以下命令, 根据提示选择即可,参考accelerate.yaml
,注意:DeepSpeed在Windows安装比较麻烦。
accelerate config
开始训练,如果要使用工程提供的配置请在下面的命令accelerate launch
后加上参数--config_file ./accelerate.yaml
,该配置按照单机2xGPU配置。
预训练有两个脚本,本项目实现的trainer对应train.py
,huggingface实现的trainer对应pre_train.py
,用哪个都可以,效果一致。本项目实现的trainer训练信息展示更美观、更容易修改训练细节(如损失函数,日志记录等),均支持断点继续训练,本项目实现的trainer支持在任意位置断点后继续训练,按ctrl+c
退出脚本时会保存断点信息。
单机单卡:
# 本项目实现的trainer
accelerate launch ./train.py train
# 或者使用 huggingface trainer
python pre_train.py
单机多卡:
2
为显卡数量,请根据自己的实际情况修改。
# 本项目实现的trainer
accelerate launch --multi_gpu --num_processes 2 ./train.py train
# 或者使用 huggingface trainer
accelerate launch --multi_gpu --num_processes 2 pre_train.py
从断点处继续训练:
# 本项目实现的trainer
accelerate launch --multi_gpu --num_processes 2 ./train.py train --is_keep_training=True
# 或者使用 huggingface trainer
# 需要在`pre_train.py`中的`train`函数添加`resume_from_checkpoint=True`
accelerate launch --multi_gpu --num_processes 2 pre_train.py
SFT数据集全部来自BELLE大佬的贡献,感谢。SFT数据集分别为:generated_chat_0.4M、train_0.5M_CN和train_2M_CN,清洗后剩余约137万行。 sft指令微调数据集示例:
{
"prompt": "解释什么是欧洲启示录",
"response": "欧洲启示录(The Book of Revelation)是新约圣经的最后一卷书,也被称为《启示录》、《默示录》或《约翰默示录》。这本书从宗教的角度描述了世界末日的来临,以及上帝对世界的审判和拯救。 书中的主题包括来临的基督的荣耀,上帝对人性的惩罚和拯救,以及魔鬼和邪恶力量的存在。欧洲启示录是一个充满象征和暗示的文本,对于解读和理解有许多不同的方法和观点。"
}
参考data
目录下的示例parquet
文件制作自己的数据集,数据集格式:parquet
文件分两列,一列prompt
文本,表示提示语,一列response
文本,表示期待的模型输出。
微调细节见model/trainer.py
下的train
方法, is_finetune
设置为True
时,将进行微调,微调默认会冻结embedding层和encoder层,只训练decoder层。如需要冻结其他参数,请自行调整代码。
运行SFT微调:
# 本项目实现的trainer, 添加参数`--is_finetune=True`即可, 参数`--is_keep_training=True`可从任意断点处继续训练
accelerate launch --multi_gpu --num_processes 2 ./train.py --is_finetune=True
# 或者使用 huggingface trainer, 多GPU请用accelerate launch --multi_gpu --num_processes gpu个数 sft_train.py
python sft_train.py
偏好方法这里介绍常见的两种:PPO和DPO,具体实现请自行搜索论文及博客。
PPO方法(近似偏好优化,Proximal Policy Optimization)
步骤1:使用微调数据集做有监督微调(SFT, Supervised Finetuning)。
步骤2:使用偏好数据集(一个prompt至少包含2个回复,一个想要的回复,一个不想要的回复。多个回复可以按照分数排序,最想要的分数最高)训练奖励模型(RM, Reward Model)。可使用peft
库快速搭建Lora奖励模型。
步骤3:利用RM对SFT模型进行有监督PPO训练,使得模型满足偏好。
使用DPO(直接偏好优化,Direct Preference Optimization)微调(本项目采用DPO微调方法,比较节省显存)
在获得SFT模型的基础上,无需训练奖励模型,取得正向回答(chosen)和负向回答(rejected)即可开始微调。微调的chosen
文本来自原数据集alpaca-gpt4-data-zh,拒绝文本rejected
来自SFT微调1个epoch后的模型输出,另外两个数据集:huozi_rlhf_data_json和rlhf-reward-single-round-trans_chinese,合并后共8万条dpo数据。
dpo数据集处理过程见utils/dpo_data_process.py
。
DPO偏好优化数据集示例:
{
"prompt": "为给定的产品创建一个创意标语。,输入:可重复使用的水瓶。",
"chosen": ""保护地球,从拥有可重复使用的水瓶开始!"",
"rejected": ""让你的水瓶成为你的生活伴侣,使用可重复使用的水瓶,让你的水瓶成为你的伙伴""
}
运行偏好优化:
# 多GPU请用accelerate launch --multi_gpu --num_processes gpu个数 dpo_train.py
python dpo_train.py
确保model_save
目录下有以下文件,这些文件都可以在Hugging Face Hub
仓库ChatLM-Chinese-0.2B中找到:
ChatLM-mini-Chinese
├─model_save
| ├─config.json
| ├─configuration_chat_model.py
| ├─generation_config.json
| ├─model.safetensors
| ├─modeling_chat_model.py
| ├─special_tokens_map.json
| ├─tokenizer.json
| └─tokenizer_config.json
python cli_demo.py
python api_demo.py
API调用示例:
curl --location '127.0.0.1:8812/api/chat'
--header 'Content-Type: application/json'
--header 'Authorization: Bearer Bearer'
--data '{
"input_txt": "感冒了要怎么办"
}'
这里以文本中三元组信息为例,做下游微调。该任务的传统深度学习抽取方法见仓库pytorch_IE_model。抽取出一段文本中所有的三元组,如句子《写生随笔》是冶金工业2006年出版的图书,作者是张来亮
,抽取出三元组(写生随笔,作者,张来亮)
和(写生随笔,出版社,冶金工业)
。
原始数据集为:百度三元组抽取数据集。加工得到的微调数据集格式示例:
{
"prompt": "请抽取出给定句子中的所有三元组。给定句子:《家乡的月亮》是宋雪莱演唱的一首歌曲,所属专辑是《久违的哥们》",
"response": "[(家乡的月亮,歌手,宋雪莱),(家乡的月亮,所属专辑,久违的哥们)]"
}
可以直接使用sft_train.py
脚本进行微调,脚本finetune_IE_task.ipynb里面包含详细的解码过程。训练数据集约17000
条,学习率5e-5
,训练epoch5
。微调后其他任务的对话能力也没有消失。
微调效果:
将百度三元组抽取数据集
公开的dev
数据集作为测试集,对比传统方法pytorch_IE_model。
模型 | F1分数 | 精确率P | 召回率R |
---|---|---|---|
ChatLM-Chinese-0.2B微调 | 0.74 | 0.75 | 0.73 |
ChatLM-Chinese-0.2B无预训练 | 0.51 | 0.53 | 0.49 |
传统深度学习方法 | 0.80 | 0.79 | 80.1 |
备注:ChatLM-Chinese-0.2B无预训练
指直接初始化随机参数,开始训练,学习率1e-4
,其他参数和微调一致。
模型本身没有使用较大的数据集训练,也没有针对回答选择题的指令做微调,C-Eval分数基本上是baseline水平,有需要的可以当个参考。C-Eval评测代码见:eval/c_eavl.ipynb
category | correct | question_count | accuracy |
---|---|---|---|
Humanities | 63 | 257 | 24.51% |
Other | 89 | 384 | 23.18% |
STEM | 89 | 430 | 20.70% |
Social Science | 72 | 275 | 26.18% |
如果你觉得本项目对你有所帮助,欢迎引用。
@misc{Charent2023,
author={Charent Chen},
title={A small chinese chat language model with 0.2B parameters base on T5},
year={2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {url{https://github.com/charent/ChatLM-mini-Chinese}},
}
本项目不承担开源模型和代码导致的数据安全、舆情风险或发生任何模型被误导、滥用、传播、不当利用而产生的风险和责任。