定位:代码复现贴
教程:https://datawhaler.feishu.cn/wiki/PLCHwQ8pai12rEkPzDqcufWKnDd
模型加载
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
)
-
AutoModelForCausalLM.from_pretrained(path)
:- 这是
transformers
库中的一种通用方法,用于从预训练模型路径(path
)加载一个因果语言模型(Causal Language Model,CLM)。 - 因果语言模型是一种序列到序列的模型,通常用于生成任务,例如自动完成或文本生成。
- 这是
-
device_map="auto"
:- 该参数用于自动选择计算设备(如 GPU 或 CPU)来加载模型。设置为
"auto"
后,模型会根据可用资源自动映射到适当的设备。
- 该参数用于自动选择计算设备(如 GPU 或 CPU)来加载模型。设置为
-
torch_dtype=torch.bfloat16
:- 这将模型的计算精度设置为
bfloat16
(一种 16 位浮点格式),这通常用于加速计算和减少显存占用,同时保持数值稳定性。
- 这将模型的计算精度设置为
-
trust_remote_code=True
:- 这个参数表示信任远程代码,允许加载自定义模型结构。如果预训练模型所在的路径中包含自定义的模型定义文件(而不是标准的
transformers
库模型),这个选项允许这些自定义代码被执行。
- 这个参数表示信任远程代码,允许加载自定义模型结构。如果预训练模型所在的路径中包含自定义的模型定义文件(而不是标准的
输出的模型如下:
模型结构分析
Yuan 在 Transformer 的 Decoder 进行改进,引入了一种新的注意力机制 Localized Filtering-based Attention(LFA)
-
YuanForCausalLM
:- 这是一个自定义的因果语言模型类,可能来自于远程代码定义。该模型包含了实际的
YuanModel
和一个lm_head
(语言模型的输出头)。
- 这是一个自定义的因果语言模型类,可能来自于远程代码定义。该模型包含了实际的
-
YuanModel
:- 该模型是
YuanForCausalLM
的核心部分,包含嵌入层、多个解码器层(YuanDecoderLayer
)、和一个归一化层。
- 该模型是
-
embed_tokens
:- 这是词嵌入层,用于将输入的标记(tokens)转换为高维向量表示。这里的词表大小为
135040
,每个标记被嵌入到一个2048
维的向量空间中。
- 这是词嵌入层,用于将输入的标记(tokens)转换为高维向量表示。这里的词表大小为
-
layers
:- 这是模型的主体,由
24
个YuanDecoderLayer
组成,每个解码器层包含自注意力机制、MLP(多层感知器)层、和归一化层。
- 这是模型的主体,由
-
YuanAttention
:- 这是一个自注意力机制模块,包含了查询(
q_proj
)、键(k_proj
)、值(v_proj
)的线性投影,以及一个旋转嵌入(rotary_emb
)和本地过滤模块(lf_gate
)。
- 这是一个自注意力机制模块,包含了查询(
-
YuanMLP
:- 这是一个 MLP 层,包含了向上和向下的线性投影(
up_proj
和down_proj
),以及一个激活函数SiLU
。
- 这是一个 MLP 层,包含了向上和向下的线性投影(
-
YuanRMSNorm
:- 这是一个归一化层,使用 RMSNorm(Root Mean Square Layer Normalization)来稳定训练过程。
-
lm_head
:- 这是模型的输出层,用于将解码器层的输出转换为预测的词概率分布。它是一个线性层,输入维度为
2048
,输出维度为135040
(与词表大小一致)。
- 这是模型的输出层,用于将解码器层的输出转换为预测的词概率分布。它是一个线性层,输入维度为
配置Lora
from peft import LoraConfig, TaskType, get_peft_modelconfig = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],inference_mode=False, # 训练模式r=8, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)
我们输出config,可以观测到其中的完整配置选项。
LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'k_proj', 'down_proj', 'o_proj', 'up_proj', 'gate_proj', 'v_proj', 'q_proj'},lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, # <=== doralayer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False))
没想到后面还有一个use_dora的选项,碰巧之前浏览过这块,可以分享一下:
DoRA
首先对预训练模型的权重进行分解,将每个权重矩阵分解为幅度(magnitude)向量和方向(direction)矩阵
在微调过程中,DoRA使用LoRA进行方向性更新,只调整方向部分的参数,而保持幅度部分不变。这种方式可以减少需要调整的参数数量,提高微调的效率。
后面,我们构建一个 PeftModel并且查看对应的训练参数量占比:
# 构建PeftModel
model = get_peft_model(model, config)
model.print_trainable_parameters()
输出如下:
trainable params: 9,043,968 || all params: 2,097,768,448 || trainable%: 0.4311
总参数量为 2,097,768,448(~ 21亿参数),使用LoRA后只需要微调的参数量为 9,043,968(~904万参数),约占总参数量的0.4311%
但是后面微调还是爆了,所以稍微去除一点不太重要的微调目标模块(个人观点),但是肯定会损耗微调性能的。
config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "k_proj", "v_proj"],inference_mode=False, # 训练模式r=4, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)
后续输出微调的参数占比为:
trainable params: 2,359,296 || all params: 2,091,083,776 || trainable%: 0.1128
当然,也降低了批处理大小 (牺牲速度):
# 设置训练参数
args = TrainingArguments(output_dir="./output/Yuan2.0-2B_lora_bf16",per_device_train_batch_size=6, # <===== 12gradient_accumulation_steps=1,logging_steps=1,save_strategy="epoch",num_train_epochs=3,learning_rate=5e-5,save_on_each_node=True,gradient_checkpointing=True,bf16=True
)
微调成功之后效果如下,即便增加了一些其他信息,也能保持相关的抽取。
(但是多次几次依旧容易翻车,会输出极其符合数据集分布的答案。)
数据集中的组织名和姓名是互斥的,且中国难识别归类到国籍。
关于更多的微调知识,感觉可以参考这篇知乎大佬的笔记:https://zhuanlan.zhihu.com/p/696837567