【自然语言处理】大模型高效微调:PEFT 使用案例

文章目录

  • 一、PEFT介绍
  • 二、PEFT 使用
    • 2.1 PeftConfig
    • 2.2 PeftModel
    • 2.3 保存和加载模型
  • 三、PEFT支持任务
    • 3.1 Models support matrix
      • 3.1.1 Causal Language Modeling
      • 3.1.2 Conditional Generation
      • 3.1.3 Sequence Classification
      • 3.1.4 Token Classification
      • 3.1.5 Text-to-Image Generation
      • 3.1.6 Image Classification
      • 3.1.7 Image to text (Multi-modal models)
  • 四、PEFT原理
    • 4.1 LoRA
    • 4.2 Prompt tuning
    • 4.3 IA3

一、PEFT介绍

PEFT(Parameter-Efficient Fine-Tuning,参数高效微调),是一个用于在不微调所有模型参数的情况下,高效地将预训练语言模型(PLM)适应到各种下游应用的库。

PEFT方法仅微调少量(额外的)模型参数,显著降低了计算和存储成本,因为对大规模PLM进行完整微调的代价过高。最近的最先进的PEFT技术实现了与完整微调相当的性能。

代码:

https://github.com/huggingface/peft

文档:

https://huggingface.co/docs/peft/index

二、PEFT 使用

接下来将展示 PEFT 的主要特点,并帮助在消费设备上通常无法访问的情况下训练大型预训练模型。您将了解如何使用LoRA来训练1.2B参数的bigscience/mt0-large模型,以生成分类标签并进行推理。

2.1 PeftConfig

每个 PEFT 方法由一个PeftConfig类来定义,该类存储了用于构建PeftModel的所有重要参数。

由于您将使用LoRA,您需要加载并创建一个LoraConfig类。在LoraConfig中,指定以下参数:

  • task_type,在本例中为序列到序列语言建模
  • inference_mode,是否将模型用于推理
  • r,低秩矩阵的维度
  • lora_alpha,低秩矩阵的缩放因子
  • lora_dropout,LoRA层的dropout概率
from peft import LoraConfig, TaskTypepeft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

有关您可以调整的其他参数的更多详细信息,请参阅LoraConfig参考。

2.2 PeftModel

使用 get_peft_model() 函数可以创建PeftModel。它需要一个基础模型 - 您可以从 Transformers 库加载 - 以及包含配置特定 PEFT 方法的PeftConfig。

首先加载您要微调的基础模型。

from transformers import AutoModelForSeq2SeqLMmodel_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解您模型中可训练参数的数量,可以使用print_trainable_parameters方法。在这种情况下,您只训练了模型参数的0.19%!

from peft import get_peft_modelmodel = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# 输出示例: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

至此,我们已经完成了!现在您可以使用Transformers的Trainer、 Accelerate,或任何自定义的PyTorch训练循环来训练模型。

2.3 保存和加载模型

在模型训练完成后,您可以使用save_pretrained函数将模型保存到目录中。您还可以使用push_to_hub函数将模型保存到Hub(请确保首先登录您的Hugging Face帐户)。

model.save_pretrained("output_dir")# 如果要推送到Hub
from huggingface_hub import notebook_loginnotebook_login()
model.push_to_hub("my_awesome_peft_model")

这只保存了已经训练的增量PEFT权重,这意味着存储、传输和加载都非常高效。例如,这个在RAFT数据集的twitter_complaints子集上使用LoRA训练的bigscience/T0_3B模型只包含两个文件:adapter_config.json和adapter_model.bin,后者仅有19MB!

使用from_pretrained函数轻松加载模型进行推理:

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfigpeft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

三、PEFT支持任务

3.1 Models support matrix

3.1.1 Causal Language Modeling

在这里插入图片描述

3.1.2 Conditional Generation

在这里插入图片描述

3.1.3 Sequence Classification

在这里插入图片描述

3.1.4 Token Classification

在这里插入图片描述

3.1.5 Text-to-Image Generation

在这里插入图片描述

3.1.6 Image Classification

在这里插入图片描述

3.1.7 Image to text (Multi-modal models)

在这里插入图片描述

四、PEFT原理

4.1 LoRA

LoRA(Low-Rank Adaptation)是一种技术,通过低秩分解将权重更新表示为两个较小的矩阵(称为更新矩阵),从而加速大型模型的微调,并减少内存消耗。

为了使微调更加高效,LoRA的方法是通过低秩分解,使用两个较小的矩阵(称为更新矩阵)来表示权重更新。这些新矩阵可以通过训练适应新数据,同时保持整体变化的数量较少。原始的权重矩阵保持冻结,不再接收任何进一步的调整。为了产生最终结果,同时使用原始和适应后的权重进行合并。

4.2 Prompt tuning

训练大型预训练语言模型是非常耗时且计算密集的。随着模型尺寸的增长,越来越多的人对更高效的训练方法产生了兴趣,例如提示(Prompting)。提示通过包括描述任务的文本提示或甚至演示任务示例的文本提示来为特定的下游任务准备一个冻结的预训练模型。通过使用提示,您可以避免为每个下游任务完全训练单独的模型,而是使用相同的冻结预训练模型。这更加方便,因为您可以将同一模型用于多个不同的任务,而训练和存储一小组提示参数要比训练所有模型参数要高效得多。

提示方法可以分为两类:

  • 硬提示(Hard Prompts):手工制作的具有离散输入标记的文本提示;缺点是需要花费很多精力来创建一个好的提示。
  • 软提示(Soft Prompts):可与输入嵌入连接并进行优化以适应数据集的可学习张量;缺点是它们不太易读,因为您不是将这些“虚拟标记”与实际单词的嵌入进行匹配。

4.3 IA3

为了使微调更加高效,IA3(通过抑制和放大内部激活来注入适配器)使用学习向量对内部激活进行重新缩放。这些学习向量被注入到典型的基于Transformer架构中的注意力和前馈模块中。这些学习向量是微调过程中唯一可训练的参数,因此原始权重保持冻结。处理学习向量(而不是像LoRA一样对权重矩阵进行学习的低秩更新)可以大大减少可训练参数的数量。

与LoRA类似,IA3具有许多相同的优点:

  • IA3通过大大减少可训练参数的数量使微调更加高效(对于T0模型,IA3模型仅具有约0.01%的可训练参数,而即使是LoRA也有超过0.1%)。
  • 原始的预训练权重保持冻结,这意味着您可以在其之上构建多个轻量级和便携的IA3模型,用于各种下游任务。
  • 使用IA3进行微调的模型性能与完全微调的模型性能相当。
  • IA3不会增加任何推理延迟,因为适配器权重可以与基础模型合并。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83728.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

K8S系列文章之 自动化运维利器 Ansible

Ansible-安装 第一步:安装我们的epel扩展源 yum -y install epel-release 我这里会报/var/run/yum.pid 已被锁定,如果没有直接进行下一步 [rootmaster home]# yum -y install epel-release 已加载插件:fastestmirror, langpacks /var/run/…

基于Pyqt5+serial的串口电池监测工具

本章,其他的没有,废话没有,介绍一下新开源了一个公司的测试工具,写了差不多三周吧。先来看看界面: 这是一个串口调试界面,使用Pyqt5+serial完成。升级功能暂未移入,占一个坑位。 基于serial二次开发的功能各位如有需要可以照搬走,这是一个纯手写的轮子,稳定! 左侧使用…

【网络基础知识铺垫】

文章目录 1 :peach:计算机网络背景:peach:1.1 :apple:网络发展:apple: 2 :peach:协议:peach:2.1 :apple:协议分层:apple:2.2 :apple:OSI七层模型:apple:2.3 :apple:TCP/IP模型:apple:2.4 :apple:TCP/IP模型与操作系统的关系:apple: 3 :peach:网络传输基本流程:peach:4 :peach:网…

Apache2.4源码安装与配置

环境准备 openssl-devel pcre-devel expat-devel libtool gcc libxml2-devel 这些包要提前安装,否则httpd编译安装时候会报错 下载源码、解压缩、软连接 1、wget下载[rootnode01 ~]# wget https://downloads.apache.org/httpd/httpd-2.4.57.tar.gz --2023-07-20 …

CVE漏洞复现-CVE-2021-3493 Linux 提权内核漏洞

CVE-2021-3493 Linux 提权内核漏洞 漏洞描述 CVE-2021-3493 用户漏洞是 Linux 内核中没有文件系统中的 layfs 中的 Ubuntu over 特定问题,在 Ubuntu 中正确验证有关名称空间文件系统的应用程序。buntu 内核代码允许低权限用户在使用 unshare() 函数创建的用户命名…

面试热题(最长上升子序列)

给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2,7] 是数组 [0,3,1,6,2,2,7] 的子序列。 输入&#xff1…

智慧影院--java开源电影票优惠券制作系统快速开发

搭建一个智慧影院可以通过使用Java开源电影票优惠券制作系统来快速开发。这个系统可以帮助影院管理电影票的销售和优惠活动,提供便捷的购票方式和优惠券的生成与使用功能。 首先,我们需要建立一个数据库来存储电影、影厅、放映计划、订单等信息。在数据…

Spring Security6入门及自定义登录

一、前言 Spring Security已经更新到了6.x,通过本专栏记录以下Spring Security6学习过程,当然大家可参考Spring Security5专栏对比学习 Spring Security5专栏地址:security5 Spring Security是spring家族产品中的一个安全框架,核心功能包括…

Oracle 开发篇+Java通过HiKariCP访问Oracle数据库

标签:HikariCP、数据库连接池、JDBC连接池、释义:HikariCP 是一个高性能的 JDBC 连接池组件,号称性能最好的后起之秀,是一个基于BoneCP做了不少的改进和优化的高性能JDBC连接池。 ★ Java代码 import java.sql.Connection; impor…

大数据传输的定义与大数据传输解决方案的选择

当我们需要处理大量的数据时,我们就要把数据从一个地方移动到另一个地方。这个过程就叫做大数据传输。它通常需要用到高速的网络连接、分散的存储系统和数据传输协议,以保证数据的快速、可靠和安全的移动。常用的大数据传输技术有Hadoop分布式文件系统&a…

整理mongodb文档:改

个人博客 整理mongodb文档:改 求关注,求批评,求进步 文章概叙 本文主要讲的是mongodb的updateOne以及updateMany,主要还是在shell下进行操作,也讲解下主要的参数upsert以及更新的参数。 数据准备 本次需要准备的数据不是很多…

【Spring】使用注解的方式获取Bean对象(对象装配)

目录 一、了解对象装配 1、属性注入 1.1、属性注入的优缺点分析 2、setter注入 2.1、setter注入的优缺点分析 3、构造方法注入 3.1、构造方法注入的优缺点 二、Resource注解 三、综合练习 上一个博客中,我们了解了使用注解快速的将对象存储到Spring中&#x…

《大型网站技术架构设计》第二篇 架构-性能

不同视角下的网站性能 1、用户 从用户角度,网站性能就是用户在浏览器上直观感受到的网站响应速度快还是慢。用户感受到的时间。 2、开发人员 开发人员关注的主要是应用程序本身及其相关子系统的性能,包括响应延迟、系统吞吐量、并发处理能力、系统稳定…

nsqd的架构及源码分析

文章目录 一 nsq的整体代码结构 二 回顾nsq的整体架构图 三 nsqd进程的作用 四 nsqd启动流程的源码分析 五 本篇博客总结 在博客 nsq整体架构及各个部件作用详解_YZF_Kevin的博客-CSDN博客 中我们讲了nsq的整体框架,各个部件的大致作用。如果没看过的&…

【100天精通python】Day30:使用python操作数据库_数据库基础入门

专栏导读 专栏订阅地址:https://blog.csdn.net/qq_35831906/category_12375510.html 1 数据库基础知识介绍 1.1 什么是数据库? 数据库是一个结构化存储和组织数据的集合,它可以被有效地访问、管理和更新。数据库的目的是为了提供一种可靠的…

paddleseg数据集自定义比例划分为测试集test.txt,训练集train.txt,验证集val.txt

将语义分割的数据集标注好后如下所示: 整理好图片和标签文后需要按照比例划分为训练集,验证集,测试集。 具体划分代码见下: import glob import os.path import argparse import warnings import numpy as npdef parse_args():p…

“清凉计划”KCOFFEE来了,华为天气和肯德基携手提升你的冰凉咖位

八月的热浪席卷着城市,开启了高温酷暑的“闷烤”模式,此时“怕热星人”急需一杯冰爽的冷饮,为了拯救热热热亻七的你们,华为天气联手肯德基带着超强冷势力,发起“清凉计划”,送上一杯KCOFFEE现磨冰咖啡&…

Debian10:安装PHPVirtualBox

PHPVirtualBox 是一个用 PHP 编写,用于管理 VirtualBox 的 Web 前端(由AJAX实现)。 参考文章:VirtualBoxPHPVirtualBox部署_骡子先生的博客-CSDN博客php virualbox,浏览器远程控制VBox 虚拟机phpVirtualBox_weixin_39815879的博客…

编写一个指令(v-focus2end)使输入框文本在聚焦时焦点在文本最后一个位置

项目反馈输入框内容比较多时候,让鼠标光标在最后一个位置,心想什么奇葩需求,后面试了一下,是有点影响体验,于是就有了下面的效果,我目前的项目都是若依的架子,用的是vue2版本。vue3的朋友想要使…

JVM—编译器、类加载的过程、双亲委派机制这些你还记得吗?

背景介绍 这两天在对JVM的知识进行回顾,顺便来分享分享,接下来也会有系列文章,欢迎大家一起讨论。 过程 为什么叫JVM? Java Virtual Machine,java虚拟机。可以理解成一个以字节码为机器指令的CPU 有哪些特点呢&#…