AIGC笔记--基于PEFT库使用LoRA

1--相关讲解

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

PEFT-LoRA

2--基本原理

        固定原始层,通过添加和训练两个低秩矩阵,达到微调模型的效果;

3--简单代码

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict# 创建模型
class Simple_Model(nn.Module):def __init__(self):super().__init__()self.linear1 = nn.Linear(64, 128)self.linear2 = nn.Linear(128, 256)def forward(self, x: torch.Tensor):x = self.linear1(x)x = self.linear2(x)return xif __name__ == "__main__":# 初始化原始模型origin_model = Simple_Model()# 配置lora configmodel_lora_config = LoraConfig(r = 32, lora_alpha = 32, # scaling = lora_alpha / r 一般来说,lora_alpha的参数初始化为与r相同,即scale=1init_lora_weights = "gaussian", # 参数初始化方式target_modules = ["linear1", "linear2"], # 对应层添加lora层lora_dropout = 0.1)# Test datainput_data = torch.rand(2, 64)origin_output = origin_model(input_data)# 原始模型的权重参数origin_state_dict = origin_model.state_dict() # 两种方式生成对应的lora模型,调用后会更改原始的模型new_model1 = get_peft_model(origin_model, model_lora_config)new_model2 = LoraModel(origin_model, model_lora_config, "default")output1 = new_model1(input_data)output2 = new_model2(input_data)# 初始化时,lora_B矩阵会初始化为全0,因此最初 y = WX + (alpha/r) * BA * X == WX# origin_output == output1 == output2# 获取lora权重参数,两者在key_name上会有区别new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)# origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]print("All Done!")

4--权重保存和合并

核心公式是:new_weights = origin_weights + alpha* (BA)

    # 借助diffuser的save_lora_weights保存模型权重from diffusers import StableDiffusionPipelinesave_path = "./"global_step = 0StableDiffusionPipeline.save_lora_weights(save_directory = save_path,unet_lora_layers = new_model1_lora_state_dict,safe_serialization = True,weight_name = f"checkpoint-{global_step}.safetensors",)# 加载lora模型权重(参考Stable Diffusion),其实可以重写一个简单的版本from safetensors import safe_openalpha = 1. # 参数融合因子lora_path = "./" + f"checkpoint-{global_step}.safetensors"state_dict = {}with safe_open(lora_path, framework="pt", device="cpu") as f:for key in f.keys():state_dict[key] = f.get_tensor(key)all_lora_weights = []for idx,key in enumerate(state_dict):# only process lora down keyif "lora_B." in key: continueup_key    = key.replace(".lora_A.", ".lora_B.") # 通过lora_A直接获取lora_B的键名model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")layer_infos = model_key.split(".")[:-1]curr_layer = new_model1while len(layer_infos) > 0:temp_name = layer_infos.pop(0)curr_layer = curr_layer.__getattr__(temp_name)weight_down = state_dict[key].to(curr_layer.weight.data.device)weight_up   = state_dict[up_key].to(curr_layer.weight.data.device)# 将lora参数合并到原模型参数中 -> new_W = origin_W + alpha*(BA)curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)all_lora_weights.append([model_key, torch.mm(weight_up, weight_down).t()])print('Load Lora Done')

5--完整代码

PEFT_LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/335134.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

奇门遁甲古籍1《奇门秘术》(双页版)PDF电子书

《奇门秘术》 全书共102页 时间有限,仅上传部分图片,结缘私!

ROS基础学习-话题通信机制研究

研究ROS通信机制 研究ROS通信机制 0.前言1.话题通信1.1 理论模型1.2 话题通讯的基本操作1.2.1 C++1.2.2 Python中使用自己的虚拟环境包1.2.2.1 参考11.2.2.2 参考21.2.2.3 /usr/bin/env:“python”:没有那个文件或目录1.2.3 Python1.2.2.1 发布方1.2.2.2 订阅方1.2.2.3 添加可执…

一些Spring的理解

说说你对Spring的理解 首先Spring是一个生态:可以构建企业级应用程序所需的一切基础设施 但是,通常Spring指的就是Spring Framework,它有两大核心: IOC和DI 它的核心就是一个对象管理工厂容器,Spring工厂用于生产Bea…

03 Prometheus+Grafana可视化配置

03 PrometheusGrafana可视化配置 大家好,我是秋意零。接上篇Prometheus入门安装教程 grafana官网下载安装包比较慢,如果没有魔法。可关注公众号【秋意零】回复101获取 Grafana官网下载:https://grafana.com/grafana/download 这里采用的二进制…

2024年社会发展、人文艺术与文化国际会议(ICSDHAC 2024)

2024年社会发展、人文艺术与文化国际会议(ICSDHAC 2024) 会议简介 2024年国际社会发展、人文、艺术和文化会议(ICSDHAC 2024)将在广州举行。会议旨在为从事社会发展、人文、艺术和文化研究的专家学者提供一个平台,分…

为什么说想当产品经理,最好的时候就是现在?

今年,随着人工智能(AI)技术的火热,AI产品经理岗位的需求也一路暴涨,薪资也同步水涨船高。 根据美国招聘社交媒体Glassdoor的数据,AI产品经理年收入高达125万元,是普通产品经理年收入的1.43倍,更是项目经理年收入的2.14倍。在中国,大厂AI产品经理的月收入也高达3到7万左右。但即…

【ai】livekit服务本地开发模式及example app信令交互详细流程

文档要安装git lfs 下载当前最新版本1.6.1 windows版本:启动dev模式 服务器启动 (.venv) PS D:\XTRANS\pythonProject\LIVEKIT> cd .\livekit_release\ (.venv) PS D:\XTRANS\pythonProject\LIVEKIT\livekit_release> lsDirectory: D:\XTRANS\pythonProject\L…

yolo 算法 易主

标题:YOLOv10: Real-Time End-to-End Object Detection 论文:https://arxiv.org/pdf/2405.14458ethttps%3A//arxiv.org/pdf/2405.14458.zhihu.com/?targethttps%3A//arxiv.org/pdf/2405.14458 源码:https://github.com/THU-MIG/yolov10 分析…

Django Web:搭建Websocket服务器(入门篇)

Django Web架构 搭建Websocket服务器(1) - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:htt…

算法之堆排序

堆排序是一种基于比较的排序算法,通过构建二叉堆(Binary Heap),可以利用堆的性质进行高效的排序。二叉堆是一个完全二叉树,可以有最大堆和最小堆两种形式。在最大堆中,父节点的值总是大于或等于其子节点的值…

Linux文本处理三剑客(详解)

一、文本三剑客是什么? 1. 对于接触过Linux操作系统的人来说,应该都听过说Linux中的文本三剑客吧,即awk、grep、sed,也是必须要掌握的Linux命令之一,三者都是用来处理文本的,但侧重点各不相同,a…

kubeadm引导欧拉系统高可用的K8S1.28.X

文章目录 一. 核心组件架构二. 有状态与无状态应用三. 资源对象3.1 规约与状态3.2 资源的分类-元数据,集群,命名空间3.2.1 元数据3.2.2 集群资源 3.3 命名空间级3.3.1 pod3.3.2 pod-副本集3.3.3 pod-控制器 四. Kubeadm安装k8s集群4.1 初始操作4.2 ~~所有节点安装Docker&#x…

Java基础:基本语法(一)

Java基础:基本语法(一) 文章目录 Java基础:基本语法(一)1. 前言2. 开发环境搭建2.1 Java开发工具包下载2.2 环境变量配置2.3 Java程序的运行过程 3. 数据类型3.1 基本数据类型3.2 引用数据类型 4. 常量与变…

maven部署到私服

方法一:网页上传 1、账号登录 用户名/密码 2、地址 http://自己的ip:自己的端口/nexus 3、查看Repositories列表,选择Public Repositories,确定待上传jar包不在私服中 4、选择3rd party仓库,点击Artifact Upload页签 5、GAV Definition选…

SQL面试题练习 —— 连续登录超过N天用户(一)

题目 现有用户登录日志表 t_login_log,包含用户ID(user_id),登录日期(login_date)。数据已经按照用户日期去重,请查出连续登录超过4天的用户ID。 样例数据 样例输出 建表语句 CREATE TABLE t_login_log (user_id VARCHAR(255) COMMENT 用户ID,login_date DATE CO…

08.tomcat多实例

在加两个tomcat实例 [rootweb01 ~]# ll apache-tomcat-8.0.27.tar.gz -rw-r--r-- 1 root root 9128610 10月 5 2015 apache-tomcat-8.0.27.tar.gz [rootweb01 ~]# tar xf apache-tomcat-8.0.27.tar.gz [rootweb01 ~]# cp -a apache-tomcat-8.0.27 tomcat_8081 [rootweb01 ~…

大模型中的Tokenizer

在使用GPT 、BERT模型输入词语常常会先进行tokenize 。 tokenize的目标是把输入的文本流,切分成一个个子串,每个子串相对有完整的语义,便于学习embedding表达和后续模型的使用。 一、粒度 三种粒度:word/subword/char word词&a…

qt把虚拟键盘部署到arm开发板上(imx6ull)

分为了qt官方配置的虚拟键盘以及各路大神自己开源的第三方键盘,我本来想尝试利用官方键盘结果一直失败,最后放弃了,后面我用的第三方键盘参考了如下文章: https://blog.csdn.net/2301_76250105/article/details/136441243 https…

代码随想录——找树左下角的值(Leetcode513)

题目链接 层序遍历 思路:使用层序遍历,记录每一行 i 0 的元素,就可以找到树左下角的值 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}*…

深入理解深度学习中的激活层:Sigmoid和Softmax作为非终结层的应用

深入理解深度学习中的激活层:Sigmoid和Softmax作为非终结层的应用Sigmoid 和 Softmax 激活函数简介Sigmoid函数Softmax函数 Sigmoid 和 Softmax 作为非终结层多任务学习特征变换增加网络的非线性实际案例 注意事项结论 深入理解深度学习中的激活层:Sigmo…