几种预训练模型微调方法和peft包的使用介绍

文章目录

  • 微调方法
    • Lora(在旁边添加训练参数)
    • Adapter(在前面添加训练参数)
    • Prefix-tuning(在中间添加训练参数)
    • Prompt tuning
  • PEFT
    • PEFT 使用
      • PeftConfig
      • PeftModel
      • 保存和加载模型

微调方法

现流行的微调方法有:Lora,prompt,p-tunning v1,p-tunning v2,prefix,adapter等等,下面抱着学习的心态进行宏观层面的介绍
如有错误,欢迎指出

Lora(在旁边添加训练参数)

LoRA(Low-Rank Adaptation)是一种技术,通过低秩分解将权重更新表示为两个较小的矩阵(称为更新矩阵),从而加速大型模型的微调,并减少内存消耗。

为了使微调更加高效,LoRA的方法是通过低秩分解,使用两个较小的矩阵(称为更新矩阵)来表示权重更新。这些新矩阵可以通过训练适应新数据,同时保持整体变化的数量较少。原始的权重矩阵保持冻结,不再接收任何进一步的调整。为了产生最终结果,同时使用原始和适应后的权重进行合并。
在这里插入图片描述
假设权重的更新在适应过程中也具有较低的“内在秩”。对于一个预训练的权重矩阵W0 ∈ Rd×k,我们通过低秩分解W0 + ∆W = W0 + BA来表示其更新,其中B ∈ Rd×r,A ∈ Rr×k,且秩r ≤ min(d,k)。在训练过程中,W0被冻结,不接收梯度更新,而A和B包含可训练参数。需要注意的是,W0和∆W = BA都与相同的输入进行乘法运算,它们各自的输出向量在坐标上求和。前向传播公式如下:h = W0x + ∆Wx = W0x + BAx

在上图中我们对A使用随机高斯随机分布初始化,对B使用零初始化,因此在训练开始时∆W = BA为零。然后,通过αr对∆Wx进行缩放,其中α是r中的一个常数。在使用Adam优化时,适当地缩放初始化,调整α的过程与调整学习率大致相同。因此,只需将α设置为我们尝试的第一个r,并且不对其进行调整。这种缩放有助于在改变r时减少重新调整超参数的需求。

Adapter(在前面添加训练参数)

2019 年,Houlsby N 等人将 Adapter 引入 NLP 领域,作为全模型微调的一种替代方案。Adapter 主体架构下图所示。
在这里插入图片描述
在预训练模型每一层(或某些层)中添加 Adapter 模块(如上图左侧结构所示),微调时冻结预训练模型主体,由 Adapter 模块学习特定下游任务的知识。每个 Adapter 模块由两个前馈子层组成,第一个前馈子层将 Transformer 块的输出作为输入,将原始输入维度 d 投影到 m,通过控制 m 的大小来限制 Adapter 模块的参数量,通常情况下 m<<d。在输出阶段,通过第二个前馈子层还原输入维度,将 m 重新投影到 d,作为 Adapter 模块的输出(如上图右侧结构)。通过添加 Adapter 模块来产生一个易于扩展的下游模型,每当出现新的下游任务,通过添加 Adapter 模块来避免全模型微调与灾难性遗忘的问题。Adapter 方法不需要微调预训练模型的全部参数,通过引入少量针对特定任务的参数,来存储有关该任务的知识,降低对模型微调的算力要求。

Prefix-tuning(在中间添加训练参数)

前缀微调(prefix-tunning),用于生成任务的轻量微调。前缀微调将一个连续的特定于任务的向量序列添加到输入,称之为前缀,如下图中的红色块所示。与提示(prompt)不同的是,前缀完全由自由参数组成,与真正的 token 不对应。相比于传统的微调,前缀微调只优化了前缀。因此,我们只需要存储一个大型 Transformer 和已知任务特定前缀的副本,对每个额外任务产生非常小的开销。

Prompt tuning

提示通过包括描述任务的文本提示或甚至演示任务示例的文本提示来为特定的下游任务准备一个冻结的预训练模型。具体的,给每个任务定义 Prompt,拼接到数据上作为输入,同时 freeze 预训练模型进行训练,在没有加额外层的情况下,可以看到随着模型体积增大效果越来越好,最终追上了精调的效果。

提示方法可以分为两类:

  • 硬提示(Hard Prompts):手工制作的具有离散输入标记的文本提示;缺点是需要花费很多精力来创建一个好的提示。
  • 软提示(Soft Prompts):可与输入嵌入连接并进行优化以适应数据集的可学习张量;缺点是它们不太易读,因为您不是将这些“虚拟标记”与实际单词的嵌入进行匹配。

PEFT

PEFT(Parameter-Efficient Fine-Tuning,参数高效微调),是一个用于在不微调所有模型参数的情况下,高效地将预训练语言模型(PLM)适应到各种下游应用的库。

PEFT方法仅微调少量(额外的)模型参数,显著降低了计算和存储成本,因为对大规模PLM进行完整微调的代价过高。最近的最先进的PEFT技术实现了与完整微调相当的性能。

代码:
https://github.com/huggingface/peft
文档:
https://huggingface.co/docs/peft/index

PEFT 使用

接下来将展示 PEFT 的主要特点,并帮助在消费设备上通常无法访问的情况下训练大型预训练模型。您将了解如何使用LoRA来训练1.2B参数的bigscience/mt0-large模型,以生成分类标签并进行推理。

PeftConfig

每个 PEFT 方法由一个PeftConfig类来定义,该类存储了用于构建PeftModel的所有重要参数。

由于您将使用LoRA,您需要加载并创建一个LoraConfig类。在LoraConfig中,指定以下参数:

task_type,在本例中为序列到序列语言建模
inference_mode,是否将模型用于推理
r,低秩矩阵的维度
lora_alpha,低秩矩阵的缩放因子
lora_dropout,LoRA层的dropout概率
from peft import LoraConfig, TaskType
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

有关您可以调整的其他参数的更多详细信息,请参阅LoraConfig参考。

PeftModel

使用 get_peft_model() 函数可以创建PeftModel。它需要一个基础模型 - 您可以从 Transformers 库加载 - 以及包含配置特定 PEFT 方法的PeftConfig。

首先加载您要微调的基础模型。

from transformers import AutoModelForSeq2SeqLMmodel_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解您模型中可训练参数的数量,可以使用print_trainable_parameters方法。在这种情况下,您只训练了模型参数的0.19%!

from peft import get_peft_modelmodel = get_peft_model(model, peft_config)
model.print_trainable_parameters()
输出示例: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

至此,我们已经完成了!现在您可以使用Transformers的Trainer、 Accelerate,或任何自定义的PyTorch训练循环来训练模型。

保存和加载模型

在模型训练完成后,您可以使用save_pretrained函数将模型保存到目录中。您还可以使用push_to_hub函数将模型保存到Hub(请确保首先登录您的Hugging Face帐户)。

model.save_pretrained("output_dir")

如果要推送到Hub

from huggingface_hub import notebook_loginnotebook_login()
model.push_to_hub("my_awesome_peft_model")

这只保存了已经训练的增量PEFT权重,这意味着存储、传输和加载都非常高效。例如,这个在RAFT数据集的twitter_complaints子集上使用LoRA训练的bigscience/T0_3B模型只包含两个文件:adapter_config.json和adapter_model.bin,后者仅有19MB!

使用from_pretrained函数轻松加载模型进行推理:

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfigpeft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

参考文章:
http://lihuaxi.xjx100.cn/news/1428773.html?action=onClick

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/155124.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Java+SpringBoot+Vue民宿管理系统的设计与实现 前后端分离【Java毕业设计·文档报告·代码讲解·安装调试】

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

Linux基本指令(1)

Linux基本指令&#xff08;1&#xff09; 1.ls指令1.1ls的用法 2. pwd指令3.cd指令3.1 cd3.2补充内容3.3 cd - 指令3.4 cd ~ 指令 4. touch指令4.1stat指令 5.mkdir 指令6.rmdir/rm指令6.1补充内容 7.man指令8.nano 指令9.cat指令10 cp指令11 mv指令12 echo指令12.1 > 输出重…

[CSAWQual 2019]Web_Unagi - 文件上传+XXE注入(XML编码绕过)

[CSAWQual 2019]Web_Unagi 1 解题流程1.1 分析1.2 解题2 思考总结1 解题流程 这篇博客讲了xml进行编码转换绕过的原理:https://www.shawroot.cc/156.html 1.1 分析 页面可以上传,上传一句话php失败,点击示例发现是xml格式,那么就是XXE注入了 点击about得到flag位置: Fla…

【数据结构-字符串 四】【字符串识别】字符串转为整数、比较版本号

废话不多说&#xff0c;喊一句号子鼓励自己&#xff1a;程序员永不失业&#xff0c;程序员走向架构&#xff01;本篇Blog的主题是【字符串转换】&#xff0c;使用【字符串】这个基本的数据结构来实现&#xff0c;这个高频题的站点是&#xff1a;CodeTop&#xff0c;筛选条件为&…

redis-6.2.7 集群安装3主3从

因为资源有限准备了3 台 服务器&#xff0c;先查看防火墙的端口是否开放&#xff0c;如果没有开放先开放端口我使用的 6379 和 6380 这两个端口 所以将这两个端口放开。去redis 官网下载redis 安装包。下载地址 &#xff1a; redis 安装包下载 3. 安装redis 上传上去之后 3 台…

4年测试经验,面试却突破不了20K,真是太卷了····

先说一个插曲&#xff1a;上个月我有同学在深圳被裁员了&#xff0c;和我一样都是软件测试&#xff0c;不过他是平安外包&#xff0c;所以整个组都撤了&#xff0c;他工资和我差不多都是14K。 现在IT互联网已经比较寒冬&#xff0c;特别是软件测试&#xff0c;裁员先裁测试&am…

IOday3作业

#include <head.h> int get_filePerrmison(mode_t mode)//获取文件权限 {char per[] "rwx";for(int i0;i<9;i){if((mode&(0400>>i))0){putchar(-);continue;}putchar(per[i%3]);}} int get_fileType(mode_t m) //获取文件类型 {switch(m&S_IF…

【ComfyUI】MacBook Pro 安装(Intel 集成显卡)

文章目录 环境概述配置pip镜像配置pip代理git配置&#xff08;选配&#xff09;下载comfyUI代码创建、激活虚拟环境下载依赖安装torchvision启动comfyUI为什么Mac不支持CUDA&#xff0c;即英伟达的显卡&#xff1f;安装Intel工具包 环境 显卡&#xff1a;Intel Iris Plus Grap…

快速学习微服务保护框架--Sentinel

学习一个框架最好的方式就是查看官方地址,sentinel是国内阿里巴巴公司的,官网更方便官网 官网 微服务保护框架 Sentinel 1.初识Sentinel 1.1.雪崩问题及解决方案 1.1.1.雪崩问题 微服务中&#xff0c;服务间调用关系错综复杂&#xff0c;一个微服务往往依赖于多个其它微…

[Mono Depth/3DOD]单目3D检测基础

1.数据增强 图像放缩和裁剪后&#xff0c;相机内参要做相应变化 import random def random_scale(image, calib, scale_range(0.8, 1.2)):scale random.uniform(*scale_range)width, height image.sizeimage image.resize((int(width * scale), int(height * scale)))cali…

【C++】神奇字符串(力扣481)

神奇字符串的规律&#xff1a; 神奇字符串 s 仅由 ‘1’ 和 ‘2’ 组成&#xff0c;并需要遵守下面的规则&#xff1a; 神奇字符串 s 的神奇之处在于&#xff0c;串联字符串中 1 和 2 的连续出现次数可以生成该字符串。 s 的前几个元素是 s “1221121221221121122……” 。如果…

encoding/json vs json-iterator

encoding/json vs json-iterator 100% Compatibility 默认情况下&#xff0c;jsoniter 不会像标准库那样对映射键进行排序。如果你想要 100% 的兼容性&#xff0c;就这样使用 m : map[string]interface{}{"3": 3,"1": 1,"2": 2, } json : json…

c语言:通讯录管理系统(动态分配内存版)

前言&#xff1a;在大多数高校内&#xff0c;都是通过设计一个通讯录管理系统来作为c语言课程设计&#xff0c;通过一个具体的系统设计将我们学习过的结构体和函数等知识糅合起来&#xff0c;可以很好的锻炼学生的编程思维&#xff0c;本文旨在为通讯录管理系统的设计提供思路和…

2018架构真题案例(四十九)

某文件采用多级索引结构&#xff0c;磁盘大小4K字节&#xff0c;每个块号4字节&#xff0c;那么二级索引结果时&#xff0c;文件最大。 A、1024 B、1024*1024 C、2048*2048 D、4096*4096 答案&#xff1a;B 霍尔三维结构以时间堆、&#xff08;&#xff09;堆、知识堆组成…

R语言实现向量自回归和误差修正模型——附实战代码

大家好&#xff0c;我是带我去滑雪&#xff01; 向量自回归&#xff08;VAR&#xff09;模型和误差修正模型&#xff08;ECM&#xff09;是时间序列分析中常用的两种模型&#xff0c;它们用于研究多个变量之间的动态关系。VAR 模型适用于研究多个相关变量之间的相互影响和动态关…

二叉搜索树--验证二叉搜索树

验证二叉搜索树-力扣 98 题 解题思路&#xff1a;利用二叉树中序遍历的特性&#xff1a;遍历出来的结果是升序的即符合二叉搜索树 对于二叉树中序遍历不是太理解的&#xff0c;作者推荐的小白书&#xff1a;二叉树的初步认识_加瓦不加班的博客-CSDN博客 中序非递归实现 // 解…

剖析伦敦银最新价格走势图

国际金融市场瞬息万变&#xff0c;伦敦银的价格走势会受到诸多因素的影响&#xff0c;比如重要经济数据的公布&#xff0c;国际间的政治博弈&#xff0c;突发的政经大事&#xff0c;都可以令白银价格的走势&#xff0c;在短时间内暴涨暴跌的情况。 要在伦敦银市场实现良好的收益…

CCF CSP认证 历年题目自练Day27

题目一 试题编号&#xff1a; 202104-1 试题名称&#xff1a; 灰度直方图 时间限制&#xff1a; 1.0s 内存限制&#xff1a; 512.0MB 样例输入 7 11 8 0 7 0 0 0 7 0 0 7 7 0 7 0 7 0 7 0 7 0 7 0 7 7 0 0 0 7 0 0 0 7 0 7 7 0 0 0 0 7 0 0 7 7 0 7 0 0 0 0 0 7 0 7 0 0 7 0 …

Apache Doris (三十九):Doris数据导出 - MySQL dump导出

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录

解决Opencv dnn模块无法使用onnx模型的问题(将onnx的动态输入改成静态)

一、问题来源 最近做人脸识别项目&#xff0c;想只用OpenCV自带的人脸检测和识别模块实现&#xff0c;使用OpenCV传统方法&#xff1a;Haar级联分类器人脸检测LBPH算法人脸识别的教程已经有了&#xff0c;于是想着用OpenCV中的dnn模块来实现&#xff0c;dnn实现人脸检测也有&a…