LoRA微调大语言模型Bert

LoRA是一种流行的微调大语言模型的手段,这是因为LoRA仅需在预训练模型需要微调的地方添加旁路矩阵。LoRA 的作者们还提供了一个易于使用的库 loralib,它极大地简化了使用 LoRA 微调模型的过程。这个库允许用户轻松地将 LoRA 层添加到现有的模型架构中,而无需深入了解其底层实现细节。这使得 LoRA 成为了一种非常实用的技术,既适合研究者也适合开发人员。下面给出了一个LoRA微调Bert模型的具体例子。
下图给出了一个LoRA微调Bert中自注意力矩阵 W Q W^Q WQ的例子。如图所示,通过冻结矩阵 W Q W^Q WQ,并且添加旁路低秩矩阵 A , B A,B A,B来进行微调。同理,使用LoRA来微调 W K W^K WK也是如此。
image.png
我们给出了通过LoRA来微调Bert模型中自注意力矩阵的具体代码。代码是基于huggingface中Bert开源模型进行改造。Bert开源项目链接如下:
https://huggingface.co/transformers/v4.3.3/_modules/transformers/models/bert/modeling_bert.html

基于LoRA微调的代码如下:
# 环境配置
# pip install loralib
# 或者
# pip install git+https://github.com/microsoft/LoRA
import loralib as loraclass LoraBertSelfAttention(BertSelfAttention):"""继承BertSelfAttention模块对Query,Value用LoRA进行微调参数:- r (int): LoRA秩的大小- config: Bert模型的参数配置"""def __init__(self, r=8, *config):super().__init__(*config)# 获得所有的注意力的头数d = self.all_head_size # 使用LoRA提供的库loralibself.lora_query = lora.Linear(d, d, r)self.lora_value = lora.Linear(d, d, r)def lora_query(self, x):"""对Query矩阵执行Wx + BAx操作"""return self.query(x) + F.linear(x, self.lora_query)def lora_value(self, x):"""对Value矩阵执行Wx + BAx操作"""return self.value(x) + F.linear(x, self.lora_value)def forward(self, hidden_states, *config):"""更新涉及到Query矩阵和Value矩阵的操作"""# 通过LoRA微调Query矩阵mixed_query_layer = self.lora_query(hidden_states)is_cross_attention = encoder_hidden_states is not Noneif is_cross_attention and past_key_value is not None:# reuse k,v, cross_attentionskey_layer = past_key_value[0]value_layer = past_key_value[1]attention_mask = encoder_attention_maskelif is_cross_attention:key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))# 通过LoRA微调Value矩阵value_layer = self.transpose_for_scores(self.lora_value(hidden_states))attention_mask = encoder_attention_maskelif past_key_value is not None:key_layer = self.transpose_for_scores(self.key(hidden_states))# 通过LoRA微调Value矩阵value_layer = self.transpose_for_scores(self.lora_value(hidden_states))key_layer = torch.cat([past_key_value[0], key_layer], dim=2)value_layer = torch.cat([past_key_value[1], value_layer], dim=2)else:key_layer = self.transpose_for_scores(self.key(hidden_states))# 通过LoRA微调Value矩阵value_layer = self.transpose_for_scores(self.lora_value(hidden_states))query_layer = self.transpose_for_scores(mixed_query_layer)if self.is_decoder:past_key_value = (key_layer, value_layer)# Query矩阵与Key矩阵算点积得到注意力分数attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":seq_length = hidden_states.size()[1]position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)distance = position_ids_l - position_ids_rpositional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibilityif self.position_embedding_type == "relative_key":relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)attention_scores = attention_scores + relative_position_scoreselif self.position_embedding_type == "relative_key_query":relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_keyattention_scores = attention_scores / math.sqrt(self.attention_head_size)if attention_mask is not None:attention_scores = attention_scores + attention_maskattention_probs = nn.Softmax(dim=-1)(attention_scores)attention_probs = self.dropout(attention_probs)if head_mask is not None:attention_probs = attention_probs * head_maskcontext_layer = torch.matmul(attention_probs, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape)outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)if self.is_decoder:outputs = outputs + (past_key_value,)return outputsclass LoraBert(nn.Module):def __init__(self, task_type, num_classes=None, dropout_rate=0.1, model_id="bert-base-cased",lora_rank=8, train_biases=True, train_embedding=False, train_layer_norms=True):"""- task_type: 设计任务的类型,如:'glue', 'squad_v1', 'squad_v2'.- num_classes: 分类类别的数量.- model_id: 预训练好的Bert的ID,如:"bert-base-uncased","bert-large-uncased".- lora_rank: LoRA秩的大小.- train_biases, train_embedding, train_layer_norms: 这是参数是否需要训练    """super().__init__()# 1.加载权重self.model_id = model_idself.tokenizer = BertTokenizer.from_pretrained(model_id)self.model = BertForPreTraining.from_pretrained(model_id)self.model_config = self.model.config# 2.添加模块d_model = self.model_config.hidden_sizeself.finetune_head_norm = nn.LayerNorm(d_model)self.finetune_head_dropout = nn.Dropout(dropout_rate)self.finetune_head_classifier = nn.Linear(d_model, num_classes)# 3.通过LoRA微调模型self.replace_multihead_attention()self.freeze_parameters()def replace_self_attention(self, model):"""把预训练模型中的自注意力换成自己定义的LoraBertSelfAttention"""for name, module in model.named_children():if isinstance(module, RobertaSelfAttention):layer = LoraBertSelfAttention(r=self.lora_rank, config=self.model_config)layer.load_state_dict(module.state_dict(), strict=False)setattr(model, name, layer)else:self.replace_self_attention(module)def freeze_parameters(self):"""将除了涉及LoRA微调模块的其他参数进行冻结LoRA微调影响到的模块: the finetune head, bias parameters, embeddings, and layer norms """for name, param in self.model.named_parameters():is_trainable = ("lora_" in name or"finetune_head_" in name or(self.train_biases and "bias" in name) or(self.train_embeddings and "embeddings" in name) or(self.train_layer_norms and "LayerNorm" in name))param.requires_grad = is_trainablepeft库中包含了LoRA在内的许多大模型高效微调方法,并且与transformer库兼容。使用peft库对大模型flan-T5-xxl进行LoRA微调的代码例子如下:
# 通过LoRA微调flan-T5-xxl
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
# 模型介绍:https://huggingface.co/google/flan-t5-xxl
model_name_or_path = "google/flan-t5-xxl"model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
peft_config = LoraConfig(r=8,lora_alpha=16, target_modules=["q", "v"], # 仅对Query,Value矩阵进行微调lora_dropout=0.1,bias="none", task_type=TaskType.SEQ_2_SEQ_LM
)
model = get_peft_model(model, peft_config)
# 打印可训练的参数
model.print_trainable_parameters()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/403377.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB R2023b配置Fortran编译器

MATLAB R2023b配置Fortran编译器 引言1. 安装Visual Studio 20192. 安装Intel API20243. 配置xml文件文件4. 设置环境变量5. MATLAB编译Fortran 引言 当我们需要用到MATLAB编译Fortran代码后进行调用计算时,整个配置流程较繁琐。下面以MATLAB R2023b为例&#xff0…

在亚马逊云科技上部署开源大模型并利用RAG和LangChain开发生成式AI应用

项目简介: 小李哥将继续每天介绍一个基于亚马逊云科技AWS云计算平台的全球前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS AI最佳实践,并应用到自己的日常工作里。 本次介绍的是如何在亚马逊云科技上利用Sag…

IDEA2023版本创建mavenWeb项目及maven的相关配置

在使用idea创建maven项目之前,首先要确保本地已经下载并配置好maven的环境变量,可以参考我主页的maven下载及环境变量配置篇。 接下来首先介绍我们需要对maven安装目录文件进行的修改介绍。 maven功能配置 我们需要需改 maven/conf/settings.xml 配置…

中间件|day1.Redis

Redis 定义 Redis 是一个开源(BSD许可)的,内存中的数据结构存储系统,它可以用作数据库、缓存和消息中间件。 它支持多种类型的数据结构,如 字符串(strings), 散列(hash…

构建一个Markdown编辑器:Fyne综合案例

在本文中,我们将通过一个完整的案例来介绍如何使用Go语言的Fyne库来构建一个简单的Markdown编辑器。Fyne是一个易于使用的库,它允许开发者使用Go语言来创建跨平台的GUI应用程序。 1. 项目结构 首先,我们需要创建一个Go项目,并引…

Linux进程间通信——SystemV共享内存

文章目录 System V共享内存创建内存共享释放共享内存命令行shmctlshmatshmdt System V共享内存 System V是一种进程间通信的标准,在这种标准之下,本地通信的方式一般由三种,内存共享、消息队列、信号量 这里我们主要介绍共享内存 创建内存…

Shell参考 - Linux Shell 训练营

出品方<Linux.cn & 阿里云开发者学堂> 一&#xff0c;Linux 可以划分为以下四个部分&#xff1a; 1. 应用软件 2. 窗口管理软件 Unity Gnome KDE 3. GNU 系统工具链 Software- GNU Project - Free Software Foundation 4. Linux 内核 二&#xff0c;什么是shell 1. L…

达梦数据库的系统视图v$database

达梦数据库的系统视图v$database 达梦数据库&#xff08;DM Database&#xff09;提供了多个系统视图&#xff0c;以便管理员能够监控和管理数据库的运行状态和性能。V$DATABASE 是其中一个关键系统视图&#xff0c;它包含有关数据库全局状态和配置的综合信息。 V$DATABASE 系…

Linux 软件编程学习第十五天

1.TCP粘包问题&#xff1a; TCP发送数据是连续的&#xff0c;两次发送的数据可能粘连成一包被接收到 1.解决粘包问题方法&#xff1a; 1.接收指定长度&#xff1a;&#xff08;不稳定&#xff09; 发送5个字节 接收5个字节 2.睡眠&#x…

Java | Leetcode Java题解之第342题4的幂

题目: 题解&#xff1a; class Solution {public boolean isPowerOfFour(int n) {return n > 0 && (n & (n - 1)) 0 && n % 3 1;} }

【数据结构算法经典题目刨析(c语言)】使用数组实现循环队列(图文详解)

&#x1f493; 博客主页&#xff1a;C-SDN花园GGbond ⏩ 文章专栏&#xff1a;数据结构经典题目刨析(c语言) 目录 一.题目描述 二.解题思路 1.循环队列的结构定义 2.队列初始化 3.判空 4.判满 5.入队列 6.出队列 7.取队首元素 8.取队尾元素 三.完整代码实…

接入谷歌支付配置

1.谷歌云创建项目 网址&#xff1a;https://console.cloud.google.com/ 按照步骤创建即可 创建好后选择项目&#xff0c;转到项目设置 选择服务账户&#xff0c;选择创建新的服务账户 名称输入好后访问权限吗账号权限都可以不用填写&#xff0c;默认就好了 然后点击电子邮…

企业高性能web服务器

目录 一.Web 服务基础介绍 1.1 Web 服务介绍 1.2.1 Apache 经典的 Web 服务端 1.2.1.1 Apache prefork 模型 1.2.1.2 Apache worker 模型 1.2.1.3 Apache event模型 1.2.2 Nginx-高性能的 Web 服务端 1.2.3 影响用户体验的因素 1.2.4 服务端 I/O 流程 1.2.4.1 磁盘 I…

【详细】linux 打包QT程序

【详细】linux 打包QT程序 一. 安装linuxdeployqt1.1 下载linuxdeployqt源码并修改如下 二. 安装patchelf三. 打包appimage四. 打包成 Debian包4.1 control文件内容4.2 postinst文件内容4.3 postrm文件内容4.4 打包命令4.4 安装命令4.5 卸载命令 一. 安装linuxdeployqt 下载地…

Gin框架接入Prometheus,grafana辅助pprof检测内存泄露

prometheus与grafana的安装 grom接入Prometheus,grafana-CSDN博客 Prometheus 动态加载 我们想给Prometheus新增监听任务新增ginapp项目只需要在原来的配置文件下面新增ginapp相关metric 在docker compose文件下面新增 执行 docker-compose up -d curl -X POST http://lo…

复现dom破坏案例和靶场

文章目录 靶场网址第一个实验步骤和原理(代码为示例要根据自己的实验修改) 第二个实验步骤和原理(代码为示例要根据自己的实验修改) 靶场网址 注册后点击 第一个实验 此实验室包含一个 DOM 破坏漏洞。注释功能允许“安全”HTML。为了解决这个实验&#xff0c;请构造一个 HT…

依赖注入+中央事件总线:Vue 3组件通信新玩法

​&#x1f308;个人主页&#xff1a;前端青山 &#x1f525;系列专栏&#xff1a;Vue篇 &#x1f516;人终将被年少不可得之物困其一生 依旧青山,本期给大家带来Vue篇专栏内容:Vue-依赖注入-中央事件总线 目录 中央事件总线使用 依赖注入使用 总结 中央事件总线 依赖注入…

【TiDB】09-修改tidb客户端访问密码

目录 1、修改配置文件 2、停止tidb-server 3、以root方式启动脚本 4、修改密码 5、停止脚本重启服务 1、修改配置文件 进入tidb-server默认部署位置 #切换tidb账号 su tidb# 进入tidb-server部署路径 cd /tidb-deploy/tidb-4000# 修改配置 vim ./conf/tidb.toml添加内容…

Datawhale AI 夏令营 第四期 AIGC Task3

活动简介 活动链接&#xff1a;Datawhale AI 夏令营&#xff08;第四期&#xff09; 以及AIGC里面的本次任务说明&#xff1a;Task 3 进阶上分-实战优化 这次任务呢&#xff0c;主要是对知识的一个讲解&#xff0c;包括ComfyUI工具的使用啊&#xff0c;以及LoRA的原理啊&…

学习记录第三十天

管道&#xff1a; 无名管道&#xff1a;只能用于亲缘关系进程之间的通信&#xff1a; 有名管道&#xff1a;是一种特殊的文件&#xff0c;存在于内存中&#xff0c;在系统中有对应的名称&#xff0c;文件大小为0字节&#xff1b; 编程&#xff1a; Linux系统中&#xff0c;…