This article focuses on reproducing the ChatGPT training pipeline.
Source: 无数据不智能
Environment setup
Create a virtual environment
conda create -n chatgpt python=3.10
conda activate chatgpt
Install dependencies
git clone https://github.com/LAION-AI/Open-Assistant.git
cd Open-Assistant/model
pip install -r model_training/requirements.txt
pip install -r reward/instructor/requirements.txt
Install trlx
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
pip install -e .
From the Open-Assistant directory, install oasst-shared
cd oasst-shared/
pip install -e .
SFT
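Taking translation as the example task, the prompt templates are:
"zh": ["翻译成中文: {}", "{} 这句中文翻译怎麽写?", "我需要这句话的中文翻译: {}"]
Data sample
[
"<human>" + randomly chosen prompt.format(source sentence) + "<bot>",
"translated sentence"
]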
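The data sample above can be sketched in Python as follows. This is a minimal illustration, not the project's own preprocessing code: it assumes the "+" in the sample means plain string concatenation, and the function name build_sft_sample and the example sentence pair are hypothetical.

```python
import random

# Prompt templates copied from the article; "{}" is filled with the source sentence.
ZH_PROMPTS = [
    "翻译成中文: {}",
    "{} 这句中文翻译怎麽写?",
    "我需要这句话的中文翻译: {}",
]


def build_sft_sample(source, target, rng=random):
    """Return a [prompt, answer] pair laid out like the data sample."""
    template = rng.choice(ZH_PROMPTS)
    # Assumption: the markers are concatenated directly around the filled template.
    prompt = "<human>" + template.format(source) + "<bot>"
    return [prompt, target]


# Hypothetical sentence pair, used only for illustration.
sample = build_sft_sample("Good morning.", "早上好。", random.Random(0))
print(sample)
```

Picking the template at random for every sentence means the model sees the same task phrased in several ways, rather than memorizing a single instruction string.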
Training script
mkdir cache
mkdir sft_model
python trainer_sft.py --configs defaults pythia --cache_dir ./cache --output_dir ./sft_model
Configuration file
defaults:
  learning_rate: 1e-5
  gradient_checkpointing: false
  gradient_accumulation_steps: 32
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 2
  weight_decay: 0.00
  warmup_steps: 600
  eval_steps: 500
  save_steps: 500
  max_length: 512
  num_train_epochs: 3
  logging_steps: 10
  max_grad_norm: 2.0
  save_total_limit: 4
  fp16: false
  eval_accumulation_steps:
  freeze_layer:
  datasets:
    - webgpt
    - squad_v2
  cache_dir: .cache
  loss_fn: CrossEntropyLoss
  eval_size:
  log_dir: "base"
  quantization: false
  seq2seqmodel: false
  poly_eps: 1.0
  fuse_gelu: true
  log_wandb: true
  samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within
  verbose: false
  output_dir: saved_model

pythia:
  learning_rate: 8e-6
  model_name: EleutherAI/pythia-70m-deduped
  weight_decay: 0.01
  max_length: 520
  warmup_steps: 1000
  gradient_checkpointing: false
  gradient_accumulation_steps: 9
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4
  output_dir: pythia_model
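The command passes two section names, "--configs defaults pythia". A plausible reading, stated here as an assumption rather than a description of the trainer's code, is that the sections are merged in order and later sections override earlier ones. The sketch below shows that merge on a subset of the values above; resolve_config is a hypothetical helper.

```python
# Values copied from the two config sections in the article (subset only).
CONFIGS = {
    "defaults": {
        "learning_rate": 1e-5,
        "gradient_accumulation_steps": 32,
        "per_device_train_batch_size": 2,
        "max_length": 512,
        "warmup_steps": 600,
        "output_dir": "saved_model",
    },
    "pythia": {
        "learning_rate": 8e-6,
        "model_name": "EleutherAI/pythia-70m-deduped",
        "gradient_accumulation_steps": 9,
        "per_device_train_batch_size": 2,
        "max_length": 520,
        "warmup_steps": 1000,
        "output_dir": "pythia_model",
    },
}


def resolve_config(names, configs=CONFIGS):
    """Merge the named sections left to right; later sections win on conflicts."""
    merged = {}
    for name in names:
        merged.update(configs[name])
    return merged


conf = resolve_config(["defaults", "pythia"])
# Sequences consumed per optimizer step on a single device.
effective_batch = (
    conf["per_device_train_batch_size"] * conf["gradient_accumulation_steps"]
)
print(conf["learning_rate"], conf["max_length"], effective_batch)
```

Under this reading the pythia run trains with a learning rate of 8e-6, a maximum length of 520, and 2 x 9 = 18 sequences per optimizer step per device.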
RM
Data sample
{
"question full text": ["answer 1", "answer 2"] # ranked by score
}
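A ranked answer list like this is typically turned into (better, worse) pairs for reward-model training. The sketch below assumes the list is ordered best first and uses the common pairwise loss -log(sigmoid(r_better - r_worse)); the article does not state which loss the trainer uses, so both the ordering and the loss are assumptions, and the three-answer sample is hypothetical.

```python
import math
from itertools import combinations


def ranked_pairs(answers):
    """Expand a best-to-worst answer list into (better, worse) pairs."""
    return list(combinations(answers, 2))


def pairwise_loss(score_better, score_worse):
    """-log(sigmoid(score_better - score_worse)), written via log1p."""
    diff = score_better - score_worse
    return math.log1p(math.exp(-diff))


# Hypothetical sample with three ranked answers.
sample = {"question full text": ["answer 1", "answer 2", "answer 3"]}
for question, answers in sample.items():
    pairs = ranked_pairs(answers)
    print(question, pairs)

print(pairwise_loss(2.0, 0.5))  # smaller: the better answer scores higher
print(pairwise_loss(0.5, 2.0))  # larger: the ranking is violated
```

A list of n ranked answers yields n(n-1)/2 pairs, so a question with only a few ranked answers still contributes several training comparisons.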
Training script
cd ../reward/instructor
mkdir model
python trainer.py configs/deberta-v3-base.yml --output_dir ./reward_model
Configuration file
model_name: microsoft/deberta-v3-base
learning_rate: 1e-5
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 2
datasets:
  - webgpt
  - hfsummary
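The config combines "scheduler: cosine" with "warmup_steps: 600". A common interpretation is linear warmup followed by cosine decay to zero; whether the trainer uses exactly this shape is an assumption. In the sketch below, total_steps is a hypothetical value, since the real number depends on dataset size, batch size, gradient accumulation, and num_train_epochs.

```python
import math


def lr_at_step(step, base_lr=1e-5, warmup_steps=600, total_steps=5000):
    """Linear warmup to base_lr, then cosine decay to zero.

    total_steps is a hypothetical placeholder, not a value from the config.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


for step in (0, 300, 600, 2800, 5000):
    print(step, lr_at_step(step))
```

With these numbers the learning rate climbs to 1e-5 over the first 600 steps, reaches about half of that midway through the decay phase, and ends near zero.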
RL
Data sample
"<human>" + randomly chosen prompt.format(source sentence) + "<bot>"
Training script
cd ../../model_training
python trainer_rl.py --configs defaults_rlhf --cache_dir ./cache --rank_model ../reward/instructor/reward_model --sft_model ../model_training/sft_model --output_dir ./rl_model
Configuration file
defaults_rlhf:
  dataset:
  rank_model: TODO
  sft_model: TODO
  eval_prompts:
  batch_size: 64
  epochs: 10
  datasets:
    - oa_private:
        split: rl
        val_split: 0.0
        fraction: 1
        file: 2023-02-10_oasst_prod.jsonl
  cache_dir: .cache
  quantization: false
  seq2seqmodel: false
  output_dir: output
  reward_model_batch_size: 32

debug_rlhf:
  rank_model: /local/home/sanagnos/general/Open-Assistant/model/reward/instructor/facebook/galactica-125m-finetuned/checkpoint-500/
  sft_model: /local/home/sanagnos/general/Open-Assistant/model/model_training/EleutherAI/pythia-70m-deduped-base-finetuned/checkpoint-20/
  batch_size: 2
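In the RL stage the policy generates replies to the prompts and the reward model scores them. The config sets reward_model_batch_size to 32 while the RL batch_size is 64, which suggests the generated samples are scored in smaller chunks. The sketch below illustrates that chunking only; make_reward_fn, score_batch, and toy_scorer are hypothetical names, and the toy scorer is a stand-in, not a reward model.

```python
def batched(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def make_reward_fn(score_batch, batch_size=32):
    """Wrap a batch scorer so it can score any number of samples.

    score_batch is a hypothetical callable: list[str] -> list[float].
    In a real run it would tokenize the texts and call the trained reward model.
    """
    def reward_fn(samples):
        rewards = []
        for chunk in batched(samples, batch_size):
            rewards.extend(score_batch(chunk))
        return rewards
    return reward_fn


# Toy stand-in scorer (NOT a reward model): longer texts get higher scores.
def toy_scorer(texts):
    return [float(len(t)) for t in texts]


reward_fn = make_reward_fn(toy_scorer, batch_size=32)
samples = ["<human>prompt %d<bot>reply" % i for i in range(70)]
rewards = reward_fn(samples)
print(len(rewards), rewards[:3])
```

Keeping the scoring batch size separate from the RL batch size lets the reward model run within its own memory budget regardless of how many samples the policy produces per step.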
Related links
CarperAI/trlx: A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) (github.com)
microsoft/DeepSpeed: DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective. (github.com)
TimDettmers/bitsandbytes: 8-bit CUDA functions for PyTorch (github.com)
huggingface/evaluate: 🤗 Evaluate: A library for easily evaluating machine learning models and datasets. (github.com)
wkentaro/gdown: Download a large file from Google Drive (curl/wget fails because of the security notice). (github.com)
wandb/wandb: 🔥 A tool for visualizing and tracking your machine learning experiments. This repo contains the CLI and Python API. (github.com)
huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX. (github.com)
pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration (github.com)