1,是什么
知识蒸馏(Knowledge Distillation)是一种模型压缩和加速技术,旨在将大型模型(通常称为教师模型)所学到的知识迁移到小型模型(通常称为学生模型)中,从而让小型模型在减少计算资源消耗和推理时间的同时,尽可能达到接近大型模型的性能。
2,为什么
知识蒸馏可以用小模型来实现(接近)大模型的推理效果,用于加快应用落地,尤其是移动端的应用落地
3,怎么做
用教师模型创建软标签,学生模型结合软标签以及Ground truth来计算损失函数,损失函数是比如KL散度,然后反向传播调整权重;
4,实际操作
用教师模型生成QA对,然后通过微调过程来对学生模型进行微调(让学生模型尽可能作出与教师模型一样的回答)
5,与传统模型训练的区别:
目标同样是学生模型,不同的是损失函数的计算需要用到教师模型的输出
6,用Trainer or STFTrainer:
Trainer更加通用,可配置的参数较多,需要对数据集进行tokenize预处理,可能需要写继承自Dataset的自己的数据集类
STFTrainer继承Trainer,在少量数据集上表现更加高效率以及更少的内存占用,接受对话式的数据集,对话式数据集不需要对数据集进行预处理,接受文本列表,json,参考:
监督微调训练器 - Hugging Face 机器学习平台
如果用了LoRA,最后需要合并模型参数
还可以用LLama Factory,微调、训练、评估框架
7,LLamaFactory使用
Package Version Editable project location
------------------ ------------ ----------------------------------------------accelerate 1.2.1
aiofiles 23.2.1
aiohappyeyeballs 2.4.6
aiohttp 3.11.12
aiohttp-cors 0.7.0
aiosignal 1.3.2
airportsdata 20241001
annotated-types 0.7.0
anyio 4.8.0
astor 0.8.1
async-timeout 5.0.1
attrs 25.1.0
audioread 3.0.1
av 14.1.0
bitsandbytes 0.45.2
blake3 1.0.4
cachetools 5.5.1
certifi 2025.1.31
cffi 1.17.1
charset-normalizer 3.4.1
click 8.1.8
cloudpickle 3.1.1
colorful 0.5.6
compressed-tensors 0.9.1
contourpy 1.3.1
cycler 0.12.1
datasets 3.2.0
decorator 5.1.1
deepspeed 0.16.2
depyf 0.18.0
dill 0.3.8
diskcache 5.6.3
distlib 0.3.9
distro 1.9.0
docstring_parser 0.16
einops 0.8.1
exceptiongroup 1.2.2
fastapi 0.115.8
ffmpy 0.5.0
filelock 3.17.0
fire 0.7.0
fonttools 4.56.0
frozenlist 1.5.0
fsspec 2024.9.0
gguf 0.10.0
google-api-core 2.24.1
google-auth 2.38.0
googleapis-common-protos 1.66.0
gradio 5.12.0
gradio_client 1.5.4
grpcio 1.70.0
h11 0.14.0
hjson 3.1.0
httpcore 1.0.7
httptools 0.6.4
httpx 0.28.1
huggingface-hub 0.28.1
idna 3.10
importlib_metadata 8.6.1
iniconfig 2.0.0
interegular 0.3.3
jieba 0.42.1
Jinja2 3.1.5
jiter 0.8.2
joblib 1.4.2
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
kiwisolver 1.4.8
lark 1.2.2
lazy_loader 0.4
librosa 0.10.2.post1
llamafactory 0.9.2.dev0 /media/PycharmProjects/llama_factory/LLaMa-Factory-Ubuntu
llvmlite 0.44.0
lm-format-enforcer 0.10.9
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.10.0
mdurl 0.1.2
mistral_common 1.5.3
modelscope 1.22.3
mpmath 1.3.0
msgpack 1.1.0
msgspec 0.19.0
multidict 6.1.0
multiprocess 0.70.16
nest-asyncio 1.6.0
networkx 3.4.2
ninja 1.11.1.3
nltk 3.9.1
numba 0.61.0
numpy 1.26.4
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-ml-py 12.570.86
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
openai 1.61.1
opencensus 0.11.4
opencensus-context 0.1.3
opencv-python-headless 4.11.0.86
orjson 3.10.15
outlines 0.1.11
outlines_core 0.1.26
packaging 24.2
pandas 2.2.3
partial-json-parser 0.2.1.1.post5
peft 0.12.0
pillow 11.1.0
pip 24.2
platformdirs 4.3.6
pluggy 1.5.0
pooch 1.8.2
prometheus_client 0.21.1
prometheus-fastapi-instrumentator 7.0.2
propcache 0.2.1
proto-plus 1.26.0
protobuf 5.29.3
psutil 6.1.1
py-cpuinfo 9.0.0
py-spy 0.4.0
pyarrow 19.0.0
pyasn1 0.6.1
pyasn1_modules 0.4.1
pybind11 2.13.6
pycountry 24.6.1
pycparser 2.22
pydantic 2.10.6
pydantic_core 2.27.2
pydub 0.25.1
Pygments 2.19.1
pyparsing 3.2.1
pytest 8.3.4
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.20
pytz 2025.1
PyYAML 6.0.2
pyzmq 26.2.1
ray 2.42.0
referencing 0.36.2
regex 2024.11.6
requests 2.32.3
rich 13.9.4
rouge-chinese 1.0.3
rpds-py 0.22.3
rsa 4.9
ruff 0.9.5
safehttpx 0.1.6
safetensors 0.5.2
scikit-learn 1.6.1
scipy 1.15.1
semantic-version 2.10.0
sentencepiece 0.2.0
setuptools 75.2.0
shellingham 1.5.4
shtab 1.7.1
six 1.17.0
smart-open 7.1.0
sniffio 1.3.1
soundfile 0.13.1
soxr 0.5.0.post1
sse-starlette 2.2.1
starlette 0.45.3
sympy 1.13.1
termcolor 2.5.0
threadpoolctl 3.5.0
tiktoken 0.8.0
tokenizers 0.21.0
tomli 2.2.1
tomlkit 0.13.2
torch 2.5.1
torchaudio 2.5.1
torchvision 0.20.1
tqdm 4.67.1
transformers 4.48.3
triton 3.1.0
trl 0.9.6
typer 0.15.1
typing_extensions 4.12.2
tyro 0.8.14
tzdata 2025.1
urllib3 2.3.0
uvicorn 0.34.0
uvloop 0.21.0
virtualenv 20.29.1
vllm 0.7.2
watchfiles 1.0.4
websockets 14.2
wrapt 1.17.2
xformers 0.0.28.post3
xgrammar 0.1.11
xxhash 3.5.0
yarl 1.18.3
zipp 3.21.0
我的机器Titan XP用不了fatten attention,因为架构太老,vllm也用不了
按照官方文档进行微调即可,自定义的数据集要把格式弄成sharegpt那样的。微调用lora或者全量。
lora微调的话,输出的文件是适配器文件,也就是仅包含微调参数的文件,测试时,可以分别加载原始模型和适配器文件,也可以合并原始模型和适配器文件之后,再加载合并后的模型来进行推理;
分布微调:比如在两台机器上做微调,用llamafactory-cli没有成功,用deepspeed成功了,脚本如下:
deepspeed \
--hostfile hostfile \
--no_ssh \
--node_rank=0 \
--master_addr 192.168.10.1 \
--master_port=9900 \
src/train.py \
--stage sft \
--finetuning_type lora \
--lora_rank 8 \
--lora_target all \
--model_name_or_path /media/PycharmProjects/QWen2.5-0.5B-Instruct \
--template qwen \
--do_train true \
--dataset gx_bank_data \
--cutoff_len 2048 \
--max_samples 1000000 \
--preprocessing_num_workers 16 \
--output_dir saves/qwen2.5-0.5B/lora/sft-dist \
--overwrite_cache true \
--overwrite_output_dir true \
--plot_loss true \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--learning_rate 1.0e-4 \
--num_train_epochs 3.0 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 500 \
--learning_rate 1e-4 \
--bf16 true \
--warmup_ratio 0.1 \
--deepspeed examples/deepspeed/ds_z0_config.json
hostfile的内容:
192.168.10.1 slots=1
192.168.10.2 slots=1
主机是ip为192.168.10.1,有一块GPU,从机是192.168.10.2,ZeRO 阶段选择z0,因为两个机器的显存都能加载完模型参数,因此选择z0,如果选z3,显存占用会降低,但是速度会比较慢
注意脚本中没有加上--num_nodes 2 和--num_gpus 2,加了会不成功,原因未知。
8,评测
每一种数据集评测方法不一样,比如文本生成,指标用BLEU,看回答跟数据集的相似度和流畅度有多少,如果是题目,只需要回答A,B,C,D这种,那就要提示大模型,仅需要回答A,B,C,D,然后用回答跟数据集的文本做比较;
如果是代码类的数据集,就要看生成的代码运行结果是否跟标答的运行结果一致,指标是pass@k
参考:
ceval/README_zh.md at main · hkust-nlp/ceval · GitHub
https://zhuanlan.zhihu.com/p/691397120
开源的评测框架:
opencompass/README_zh-CN.md at main · open-compass/opencompass · GitHub