huggingface 笔记:peft

1 介绍

  • PEFT 提供了参数高效的方法来微调大型预训练模型。
  • 传统的范式是为每个下游任务微调模型的所有参数,但由于当前模型的参数数量巨大,这变得极其昂贵且不切实际。
  • 相反,训练较少数量的提示参数或使用诸如低秩适应 (LoRA) 的重新参数化方法来减少可训练参数数量是更有效的

2 训练

2.1 加载并创建 LoraConfig 类

  • 每种 PEFT 方法都由一个 PeftConfig 类定义,该类存储了构建 PeftModel 的所有重要参数
  • eg:使用 LoRA 进行训练,加载并创建一个 LoraConfig 类,并指定以下参数
    • task_type:要训练的任务(在本例中为序列到序列语言建模)
    • inference_mode:是否将模型用于推理
    • r:低秩矩阵的维度
    • lora_alpha:低秩矩阵的缩放因子
    • lora_dropout:LoRA 层的丢弃概率
from peft import LoraConfig, TaskTypepeft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

2.2 创建 PeftModel

  • 一旦设置了 LoraConfig,就可以使用 get_peft_model() 函数创建一个 PeftModel
    • 需要一个基础模型 - 可以从 Transformers 库加载 -
    • 以及包含如何配置 LoRA 模型参数的 LoraConfig

2.2.1 加载需要微调的基础模型

from transformers import AutoModel,AutoTokenizer
import os
import torchos.environ["HF_TOKEN"] = '*'
#huggingface的私钥tokenizer=AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
model=AutoModel.from_pretrained('meta-llama/Meta-Llama-3-8B',torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)

2.2.2 创建 PeftModel

将基础模型和 peft_config 与 get_peft_model() 函数一起包装以创建 PeftModel

from peft import get_peft_modelmodel = get_peft_model(model, peft_config)
model.print_trainable_parameters()
#了解模型中可训练参数的数量
#trainable params: 3,407,872 || all params: 7,508,332,544 || trainable%: 0.0454

之后就可以train了

3保存模型

模型训练完成后,可以使用 save_pretrained 函数将模型保存到目录中。

model.save_pretrained("output_dir")

4 推理

  • 使用 AutoPeftModel 类和 from_pretrained 方法加载任何 PEFT 训练的模型进行推理
    • 对于没有明确支持 AutoPeftModelFor 类的任务,可以使用基础的 AutoPeftModel 类加载任务模型
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
import torchmodel = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
#LORA过的模型
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
#普通的tokenizermodel.eval()
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=50)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])# "Preheat the oven to 350 degrees and place the cookie dough in the center of the oven. In a large bowl, combine the flour, baking powder, baking soda, salt, and cinnamon. In a separate bowl, combine the egg yolks, sugar, and vanilla."

5LoraConfig主要参数

r (int)

LoRA 注意力维度(“秩”)

【默认8】

target_modules

应用适配器的模块名称。

  • 如果指定,则仅替换具有指定名称的模块。
    • 传递字符串时,将执行正则表达式匹配。
    • 传递字符串列表时,将执行精确匹配或检查模块名称是否以传递字符串之一结尾。
    • 如果指定为 'all-linear',则选择所有线性/Conv1D 模块,排除输出层。
    • 如果未指定,将根据模型架构选择模块。
    • 如果架构未知,将引发错误
      • - 在这种情况下,应手动指定目标模块。

【默认None]

lora_alpha

LoRA 缩放的 alpha 参数

【默认8】

lora_dropout

LoRA 层的 dropout 概率

【默认0】

fan_in_fan_out

如果要替换的层存储权重为 (fan_in, fan_out),则设置为 True

例如,GPT-2 使用 Conv1D,它存储权重为 (fan_in, fan_out),因此应设置为 True

bias

LoRA 的偏置类型。可以是 'none'、'all' 或 'lora_only'。

如果是 'all' 或 'lora_only',则在训练期间将更新相应的偏置

use_rslora
  • 设置为 True 时,使用 Rank-Stabilized LoRA,将适配器缩放因子设置为 lora_alpha/math.sqrt(r),因为其效果更好。
  • 否则,将使用原始默认值 lora_alpha/r
modules_to_save除适配器层外,在最终检查点中设置为可训练和保存的模块列表
init_lora_weights

如何初始化适配器层的权重。

  • 传递 True(默认)将导致来自微软参考实现的默认初始化。
  • 传递 'gaussian' 将导致高斯初始化,按 LoRA 秩缩放。
  • 将初始化设置为 False 会导致完全随机初始化,不建议使用。
  • 传递 'loftq' 使用 LoftQ 初始化。
  • 传递 'pissa' 使用 PiSSA 初始化,收敛速度更快,性能更优。
  • 传递 pissa_niter_[number of iters]使用快速 SVD 基于 PiSSA 的初始化,其中 [number of iters] 表示执行 FSVD 的子空间迭代次数,必须为非负整数。
    • 当 [number of iters] 设置为 16 时,可以在几秒钟内完成 7b 模型的初始化,训练效果大致相当于使用 SVD。
layers_to_transform

要转换的层索引列表。

  • 如果传递一个整数列表,则将适配器应用于该列表中指定的层索引。
  • 如果传递单个整数,则在该索引处的层上应用转换。
layers_pattern层模式名称,仅在 layers_to_transform 不为 None 时使用
rank_pattern层名称或正则表达式到秩的映射
alpha_pattern层名称或正则表达式到 alpha 的映射

6举例

6.1 初始 model

from transformers import AutoModel,AutoTokenizer
import os
import torchos.environ["HF_TOKEN"] = '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'tokenizer=AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
model=AutoModel.from_pretrained('meta-llama/Meta-Llama-3-8B',torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)model

6.2 经典lora

from peft import LoraConfig, TaskTypepeft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1)from peft import get_peft_modellora_model = get_peft_model(model, peft_config)
lora_model.print_trainable_parameters()lora_model

6.3 查看可训练参数

for name,tensor in model.named_parameters():print(name,tensor.requires_grad)

for name,tensor in lora_model.named_parameters():
    print(name,tensor.requires_grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/366837.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

和小红书一起参会! 了解大模型与大数据融合的技术趋势

在过去的两年中,“大模型”无疑成为互联网行业的焦点话题,曾经炙手可热的大数据架构似乎淡出公众视野。然而,大数据领域并未停滞不前,反而快速演进,传统依赖众多开源组件的大数据平台正逐步过渡到以融合与简化为核心特…

【漏洞复现】电信网关配置管理系统——命令执行

声明:本文档或演示材料仅供教育和教学目的使用,任何个人或组织使用本文档中的信息进行非法活动,均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 电信网关配置管理系统是一个用于管理和配置电信网关设备的软件系…

C语言编程-基于单链表实现贪吃蛇游戏

基于单链表实现贪吃蛇游戏 1.定义结构体参数 蛇行走的方向 蛇行走的状态 蛇身节点类 维护蛇的结构体型 2.游戏运行前预备工作 定位光标位置 游戏欢迎界面 绘制游戏地图(边界) 初始化游戏中的蛇身 创建食物 3.游戏运行 下一个位置是食物,就吃掉…

Py之dashscope:dashscope的简介、安装和使用方法、案例应用之详细攻略

Py之dashscope:dashscope的简介、安装和使用方法、案例应用之详细攻略 目录 dashscope的简介 1、产品的主要特点和优势包括: dashscope的安装和使用方法 1、安装 2、使用方法 dashscope的案例应用 1、通义千问-Max:通义千问2.5系列 2…

【瑞吉外卖 | day01】项目介绍+后台登录退出功能

文章目录 瑞吉外卖 — day011. 所需知识2. 软件开发整体介绍2.1 软件开发流程2.2 角色分工2.3 软件环境 3. 瑞吉外卖项目介绍3.1 项目介绍3.2 产品原型展示3.3 技术选型3.4 功能架构3.5 角色 4. 开发环境搭建4.1 数据库环境搭建4.2 Maven项目构建 5. 后台系统登录功能5.1 创建需…

【Python】已解决:SyntaxError: positional argument follows keyword argument

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决:SyntaxError: positional argument follows keyword argument 一、分析问题背景 在Python编程中,当我们在调用函数时混合使用位置参数(p…

golang使用RSA加密和解密

目录 前提 生成RSA公钥和密钥 读取文件 加密 解密 前提 本文章我们是先读取的RSA文件,所以需要先生成RSA,并且保存在文件中,再进行加密 生成RSA公钥和密钥 如果没有公钥和密钥,可以先看看我上一篇文章 生成RSA公钥和密钥h…

基于Java微信小程序同城家政服务系统设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码数据库🌟感兴趣的可以先收藏起来,还…

3ds Max导出fbx贴图问题简单记录

1.前言 工作中发现3ds Max导出的fbx在其它软件(Autodesk viewer,blender,navisworks,FBXReview等)中丢失了部分贴图,但导出的fbx用3ds Max打开却正常显示。 fbx格式使用范围较广,很多常见的三…

【深度学习】卷积神经网络CNN

李宏毅深度学习笔记 图像分类 图像可以描述为三维张量(张量可以想成维度大于 2 的矩阵)。一张图像是一个三维的张量,其中一维代表图像的宽,另外一维代表图像的高,还有一维代表图像的通道(channel&#xff…

华为手机怎么打印文件?

关于华为手机打印的问题,如果您有打印机,并且已经成功和华为手机相连,在解决上就要容易很多。 具体操作如下: 选择文件 文件来源:华为手机上的文件可以来自多个应用,如图库、备忘录、文件管理等&#xf…

C语言之线程的学习

线程属于某一个进程 共同点:都能并发 线程共享变量,进程不共享。 多线程任务中,其中某一个线程调用了exit了,其他线程会跟着一起退出 如果是特定的线程就调用pthread_exit 失败返回的是错误号 下面也是

解码未来城市:探秘数字孪生的奥秘

在科技日新月异的今天,"数字孪生"(Digital Twin)这一概念如同一颗璀璨的新星,照亮了智慧城市、智能制造等多个领域的前行之路。本文将深入浅出地解析数字孪生的定义、技术原理、应用场景及未来发展,带您一窥…

【介绍下Pwn,什么是Pwn?】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

2021强网杯

一、环境 网上自己找 二、步骤 2.1抛出引题 在这个代码中我们反序列&#xff0c;再序列化 <?php$raw O:1:"A":1:{s:1:"a";s:1:"b";};echo serialize(unserialize($raw));//O:1:"A":1:{s:1:"a";s:1:"b";…

[leetcode]文件组合

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:vector<vector<int>> fileCombination(int target) {vector<vector<int>> vec;vector<int> res;int sum 0, limit (target - 1) / 2; // (target - 1) / 2 等效于 target /…

代码随想录Day69(图论Part05)

并查集 // 1.初始化 int fa[MAXN]; void init(int n) {for (int i1;i<n;i)fa[i]i; }// 2.查询 找到的祖先直接返回&#xff0c;未进行路径压缩 int.find(int i){if(fa[i] i)return i;// 递归出口&#xff0c;当到达了祖先位置&#xff0c;就返回祖先elsereturn find(fa[i])…

构造,析构,拷贝【类和对象(中)】

P. S.&#xff1a;以下代码均在VS2019环境下测试&#xff0c;不代表所有编译器均可通过。 P. S.&#xff1a;测试代码均未展示头文件stdio.h的声明&#xff0c;使用时请自行添加。 博主主页&#xff1a;LiUEEEEE                        …

yolov8obb角度预测原理解析

预测头 ultralytics/nn/modules/head.py class OBB(Detect):"""YOLOv8 OBB detection head for detection with rotation models."""def __init__(self, nc80, ne1, ch()):"""Initialize OBB with number of classes nc and la…

1.k8s:架构,组件,基础概念

目录 一、k8s了解 1.什么是k8s 2.为什么要k8s &#xff08;1&#xff09;部署方式演变 &#xff08;2&#xff09;k8s作用 &#xff08;3&#xff09;Mesos&#xff0c;Swarm&#xff0c;K8S三大平台对比 二、k8s架构、组件 1.k8s架构 2.k8s基础组件 3.k8s附加组件 …