lora微调过程

import os
import pickle
from transformers import AutoModelForCausalLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskTypedevice = "cuda:0"#1.创建lora微调基本的配置
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.1,
)

 2.通过调用get_peft_model方法包装基础的Transformer模型

#通过调用get_peft_model方法包装基础的Transformer模型
model = AutoModelForCausalLM("/root/paddlejob/workspace/llama-2-7b-chat")
model = get_peft_model(model, peft_config)

下面是lora微调的模型结构,可以看到多了两个矩阵,一个降维一个升维 

3.训练

# optimizer and lr scheduler
'''len(train_dataloader) 是训练数据集中的批次数量,num_epochs 是训练过程中的迭代次数。因此,len(train_dataloader) * num_epochs 表示整个训练过程中的总迭代次数,即总共要遍历训练数据集的批次数'''
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=0,num_training_steps=(len(train_dataloader) * num_epochs),)#training and evaluation
model = model.to('cuda:0')
for epoch in range(num_epochs):model.train()total_loss = 0for step, batch in enumerate(train_dataloader):batch = {k: v.to('cuda:0') for k, v in batch.items()}outputs = model(**batch)loss = outputs.losstotal_loss = total_loss + loss.detach().float()loss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()model.eval()eval_loss = 0eval_preds = []for step, batch in enumerate(eval_dataloader):batch = {k: v.to('cuda:0') for k, v in bacth.items()}with torch.nograd():outputs = model(**bacth)loss = outputs.losseval_loss = eval_loss + loss.detach().float()eval_preds.extend(tokenizer.bacth_decode(torch.argmax(outputs.logits, -1), skip_special_tokens=True))eval_epoch_loss = eval_loss / len(eval_dataloader)eval_ppl = torch.exp(eval_epoch_loss)train_epoch_loss = total_loss / len(train_dataloader)train_ppl = torch.exp(train_epoch_loss)'''
在 Python 中,** 运算符用于将字典解包为关键字参数传递给函数或方法。在 PyTorch 中,model(**batch) 中的 **batch 将字典 batch 中的键值对作为关键字参数传递给模型的方法(通常是前向传播方法)。具体来说,**batch 将字典 batch 中的每个键值对解包为一组关键字参数。例如,如果 batch 字典包含键值对 {'input_ids': tensor1, 'attention_mask': tensor2},那么 model(**batch) 实际上就等价于 model(input_ids=tensor1, attention_mask=tensor2)。这种方式可以方便地将字典中的数据传递给函数或方法,并且使代码更加简洁和易读。
''''''
loss.detach().float() 的作用是将计算图中的 loss 张量分离出来并转换为浮点数类型。具体来说:loss.detach() 会创建一个新的张量,其值与 loss 相同,但不再跟踪梯度信息。这样做是因为在训练过程中,我们通常只需要保存当前步骤的损失值,而不需要其相关的计算图和梯度信息。
.float() 将张量转换为浮点数类型。这是因为通常情况下,损失值是作为浮点数来计算和累加的。
所以,total_loss += loss.detach().float() 的作用就是将当前步骤的损失值添加到总损失值中,保证总损失值是一个浮点数。
''''''
在使用 PyTorch 进行梯度下降优化时,optimizer.zero_grad() 的作用是将模型参数的梯度归零,以便进行新一轮的梯度计算和更新。这是因为在 PyTorch 中,每次调用 .backward() 方法都会累积梯度,而不是覆盖之前的梯度。因此,在每次迭代更新参数之前,需要先将之前的梯度清零,以免影响当前迭代的梯度计算。简而言之,optimizer.zero_grad() 用于初始化梯度,确保每次迭代都是基于当前 batch 的梯度计算和参数更新,而不会受到之前迭代的影响。
''''''outputs.logits 是模型生成的原始输出,通常是一个三维张量,其中包含了模型对于每个词汇的得分(未经过 softmax 处理)。在语言模型中,这个张量的维度通常是 (batch_size, sequence_length, vocab_size),其中 batch_size 表示批量大小,sequence_length 表示每个序列的长度,vocab_size 表示词汇表的大小。
在生成文本任务中,outputs.logits 的每个元素表示模型在当前位置生成每个词汇的得分。通常,需要对这些得分进行 softmax 处理以获得每个词汇的概率分布,然后根据概率分布进行采样或选择最高概率的词汇作为模型生成的下一个词。torch.argmax 是 PyTorch 库中的一个函数,用于返回张量中指定维度上的最大值的索引。具体而言,对于一个输入张量,torch.argmax(input, dim=None, keepdim=False) 函数将返回指定维度 dim 上最大值的索引。如果不指定 dim,则默认返回整个张量中最大值的索引。
例如,对于一个形状为 (batch_size, seq_length, vocab_size) 的张量,torch.argmax(outputs.logits, -1) 将返回在 vocab_size 维度上每个位置上的最大值对应的索引,即得分最高的词的索引。这行代码的作用是将模型输出的 logits (对应每个词的得分)经过 torch.argmax 函数找到得分最高的词的索引,然后使用 tokenizer 对这些索引进行解码,将索引转换为对应的词,并通过 skip_special_tokens=True 参数去除特殊标记(如 [CLS], [SEP] 等)。最终得到的是模型生成的文本内容。
'''

 4.模型保存

#save model
peft_model_id = f"{model_name_path}_{peft_config.peft_type}_{peft_config.peft.task_type}"
model.save_pretrained(peft_model_id)

5.模型训练的其余部分无需更改,当模型训练完成后,保存高效微调的模型权重部分以供模型推理 

#加载微调后的权重
from peft import PeftModel, PeftConfigconfig = PeftConfig.from_pretrained(peft_model_id)
##加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
##加载peft模型
model = PeftModel.from_pretrained(model, peft_model_id)##加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokennizer.pad_token = tokenizer.eos_token

6.加载微调后的权重文件,并进行推理 

#利用微调后的模型进行推理
##tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")##模型推理
outputs = model.generate(input_ids=inputs["input_ids"],attention_mask=inputs["attention_mask"],max_new_tokens=10,eos_token_id=3
)##tokenizer解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/302851.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Fecify站点斗篷cloak

斗篷cloak站点斗篷模式功能发布!全新的应用场景,该模式是针对推广不用GMC,而是通过facebook,或者其他的一些平台/工具推广,这些推广方式的特点是:不需要商品的图片,或者说不会排查商品图片的侵权…

基于数据沙箱与LLM用例自愈的UI自动化测试平台

UI自动化测试能够在一定程度上确保产品质量,尤其在降本提效的大背景下,其重要性愈发凸显。理想情况下,UI自动化测试不仅能够能帮我们规避不少线上问题,又能加快产品上线速度。然而现实却往往相去甚远,在多数情况下&…

【React】React hooks 清除定时器并验证效果

React hooks 清除定时器并验证效果 目录结构如下useTime hookClock.tsx使用useTime hookApp.tsx显示Clock组件显示时间(开启定时器)隐藏时间(清除定时器) 总结参考 目录结构如下 useTime hook // src/hooks/common.ts import { u…

【随笔】Git 高级篇 -- 分离 HEAD(十一)

💌 所属专栏:【Git】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! 💖 欢迎大…

Python高级

不定长参数 位置不定长参数&#xff0c;获取参数args会整合为一个元组 def info(*args):print(arg is, args)print(type(arg) is, type(args))info(1, 2, 3, 4, a, b)# 输出 # arg is (1, 2, 3, 4, a, b) # type(arg) is <class tuple> 关键字不定长参数&#xff0c;&…

VRRP虚拟路由实验(思科)

一&#xff0c;技术简介 VRRP&#xff08;Virtual Router Redundancy Protocol&#xff09;是一种网络协议&#xff0c;用于实现路由器冗余&#xff0c;提高网络可靠性和容错能力。VRRP允许多台路由器共享一个虚拟IP地址&#xff0c;其中一台路由器被选为Master&#xff0c;负…

xshell使用

个人笔记&#xff08;整理不易&#xff0c;有帮助点个赞&#xff09; 笔记目录&#xff1a;学习笔记目录_pytest和unittest、airtest_weixin_42717928的博客-CSDN博客 个人随笔&#xff1a;工作总结随笔_8、以前工作中都接触过哪些类型的测试文档-CSDN博客 Xshell是用于连接和管…

Superset二次开发之图表标题动态化

需求:图表标题动态展示原生筛选器的值 非编辑状态 分析前端代码,找到元素对应的class=header-title 通过class查找对应的代码,核心就是这个title 路径:superset-frontend\src\dashboard\components\SliceHeader\index.tsx SliceHeader组件负责处理仪表板上某个切片(slice…

C++类与对象中(个人笔记)

类与对象中 类的6个默认成员函数1.构造函数1.1特性 2.析构函数2.1特性 3.拷贝构造函数3.1特性 4.赋值运算符重载4.1特性 5.日期类的实现6.const成员6.1const成员的几个问题 7.取地址及const取地址操作符重载 类的6个默认成员函数 如果一个类中什么成员都没有&#xff0c;简称为…

异常的种类

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 Oracle 运行时错误可以分为 Oracle 错误和用户自定义错误&#xff0c;与此对应&#xff0c;根据异常产生的机制和原理&#xff0c;可将 Oracle 的系统异常分为 3 种 预定义…

Linux使用宝塔面板安装MySQL结合内网穿透实现公网连接本地数据库

文章目录 推荐前言1.Mysql服务安装2.创建数据库3.安装cpolar3.2 创建HTTP隧道 4.远程连接5.固定TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网TCP端口地址 推荐 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不…

ssm034学生请假系统+jsp

学生请假系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本学生请假系统就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处…

uniapp:Hbuilder没有检测到设备请插入设备或启动模拟器的问题解决

问题 使用模拟器调试运行项目时&#xff0c;出现以下提示&#xff0c;“没有检测到设备&#xff0c;请插入设备或启动模拟器后点击刷新再试”。排查了一天最终找到原因。 解决 已确认模拟器是已经正常启动&#xff0c;并且Hbuilder设置中的adb路径和端口都配置没有问题&#…

【Unity添加远程桌面】使用Unity账号远程控制N台电脑

设置地址&#xff1a; URDP终极远程桌面&#xff1b;功能强大&#xff0c;足以让开发人员、设计师、建筑师、工程师等等随时随地完成工作或协助别人https://cloud-desktop.u3dcloud.cn/在网站登录自己的Unity 账号上去 下载安装被控端安装 保持登录 3.代码添加当前主机 "…

UE5、CesiumForUnreal实现建筑白模生长动画效果

文章目录 1.实现目标2.实现过程2.1 实现原理2.2 具体代码2.3 应用测试3.参考资料1.实现目标 在上篇文章加载本地建筑轮廓GeoJson数据生成建筑白模的基础上,本文通过材质“顶点偏移”实现建筑白模生长效果,GIF动图如下所示: 2.实现过程 常用的实现建筑生长效果的方式有两种,…

HTML - 请你谈一谈img标签图片和background背景图片的区别

难度级别:中级及以上 提问概率:65% 面试官当然不会问如何使用img标签或者background来加载一张图片,这些知识点都很基础,相信只要从事前端开发一小段时间以后,就可以轻松搞定加载图片的问题。但很多人习惯用img标签,很多人习惯用backgro…

css字体相关属性

属性汇总 属性作用font-family 设置文章字体 font-size 设置字体大小 font-weight设置字体粗细font-style设置字体斜体font总体设置以上属性 设置文章字体 font-family属性 案例&#xff1a; 设置字体大小 font-size属性 注意事项&#xff1a; 1.必须要加单位&#xff0…

转圈游戏——快速幂

目录 题目 思路 代码 题目 思路 每个小朋友移动一次的位置为&#xff0c;移动 q 次的位置则为。那么题目要求移动 &#xff0c;最后的位置为 。 但 的范围是&#xff0c;而总的移动次数是 。时间复杂度是在&#xff0c;因此是一定不能硬算的&#xff0c;肯定会超时。那么该…

uniapp选择退出到指定页面

方法一&#xff1a;返回上n层页面 onUnload(){uni.navigateBack({delta:5,//返回上5层})},方法二&#xff1a;关闭当前页面&#xff0c;跳转到应用内的某个页面。 uni.redirectTo({url: "../home/index"//页面地址}) 方法三&#xff1a;关闭所有页面&#xff0c;打…

淘宝销量API商品详情页原数据APP接口测试㊣

淘宝/天猫获得淘宝app商品详情原数据 API 返回值说明 item_get_app-获得淘宝app商品详情原数据 公共参数 名称类型必须描述keyString是调用key&#xff08;必须以GET方式拼接在URL中&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地…