基于Bert模型的增量微调3-使用csv文件训练

我们使用weibo评价数据,8分类的csv格式数据集。

一、创建数据集合

使用csv格式的数据作为数据集。

1、创建MydataCSV.py

from  torch.utils.data import Dataset
from datasets import load_datasetclass MyDataset(Dataset):#初始化数据集def __init__(self, split):# 加载csv数据self.dataset=load_dataset(path="csv",data_files=f"D:\Test\LLMTrain\day03\data\Weibo/{split}.csv", split= "train")# 返回数据集长度def __len__(self):return len(self.dataset)# 对每条数据单独进行数据处理def __getitem__(self, idx):text=self.dataset[idx]["text"]label=self.dataset[idx]["label"]return  text,labelif __name__== "__main__":train_dataset=MyDataset("test")for i in range(10):print(train_dataset[i])

二、处理模型

我们使用8分类任务

创建netCSV.py

import torch
from transformers import BertModel#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)#加载预训练模型
path1=r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
pretrained = BertModel.from_pretrained(path1).to(DEVICE)
print(pretrained)#定义下游任务(增量模型)
class Model(torch.nn.Module):def __init__(self):super().__init__()#设计全连接网络,实现8分类任务self.fc = torch.nn.Linear(768,8)#使用模型处理数据(执行前向计算)def forward(self,input_ids,attention_mask,token_type_ids):#冻结Bert模型的参数,让其不参与训练with torch.no_grad():out = pretrained(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)#增量模型参与训练out = self.fc(out.last_hidden_state[:,0])return out

8分类任务,所以 self.fc=torch.nn.Liner768,8) 。

我们是对大模型做增量微调训练,所以需要冻结Bert模型的参数,让其不参与训练。所以使用 

with torch.no_grad()。

我们定义一个下游任务增量模型Model类,继承 torch.nn.Module。

三、训练的代码

1、创建目录params

存放训练后的结果。

2、写代码

创建train_val_csv.py

#模型训练
import torch
from MyDataCSV import MyDataset
from torch.utils.data import DataLoader
from netCSV import Model
from transformers import BertTokenizer,AdamW#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#定义训练的轮次(将整个数据集训练完一次为一轮)
EPOCH = 30000#加载字典和分词器
token = BertTokenizer.from_pretrained(r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")#将传入的字符串进行编码
def collate_fn(data):sents = [i[0]for i in data]label = [i[1] for i in data]#编码data = token.batch_encode_plus(batch_text_or_text_pairs=sents,# 当句子长度大于max_length(上限是model_max_length)时,截断truncation=True,max_length=512,# 一律补0到max_lengthpadding="max_length",# 可取值为tf,pt,np,默认为listreturn_tensors="pt",# 返回序列长度return_length=True)input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]label = torch.LongTensor(label)return input_ids,attention_mask,token_type_ids,label#创建数据集
train_dataset = MyDataset("train")
train_loader = DataLoader(dataset=train_dataset,#训练批次batch_size=50,#打乱数据集shuffle=True,#舍弃最后一个批次的数据,防止形状出错drop_last=True,#对加载的数据进行编码collate_fn=collate_fn
)
#创建验证数据集
val_dataset = MyDataset("validation")
val_loader = DataLoader(dataset=val_dataset,#训练批次batch_size=50,#打乱数据集shuffle=True,#舍弃最后一个批次的数据,防止形状出错drop_last=True,#对加载的数据进行编码collate_fn=collate_fn
)
if __name__ == '__main__':#开始训练print(DEVICE)model = Model().to(DEVICE)#定义优化器optimizer = AdamW(model.parameters())#定义损失函数loss_func = torch.nn.CrossEntropyLoss()#初始化验证最佳准确率best_val_acc = 0.0for epoch in range(EPOCH):for i,(input_ids,attention_mask,token_type_ids,label) in enumerate(train_loader):#将数据放到DVEVICE上面input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),label.to(DEVICE)#前向计算(将数据输入模型得到输出)out = model(input_ids,attention_mask,token_type_ids)#根据输出计算损失loss = loss_func(out,label)#根据误差优化参数optimizer.zero_grad()loss.backward()optimizer.step()#每隔5个批次输出训练信息if i%5 ==0:out = out.argmax(dim=1)#计算训练精度acc = (out==label).sum().item()/len(label)print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")#验证模型(判断模型是否过拟合)#设置为评估模型model.eval()#不需要模型参与训练with torch.no_grad():val_acc = 0.0val_loss = 0.0for i, (input_ids, attention_mask, token_type_ids, label) in enumerate(val_loader):# 将数据放到DVEVICE上面input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), label.to(DEVICE)# 前向计算(将数据输入模型得到输出)out = model(input_ids, attention_mask, token_type_ids)# 根据输出计算损失val_loss += loss_func(out, label)#根据数据,计算验证精度out = out.argmax(dim=1)val_acc+=(out==label).sum().item()val_loss/=len(val_loader)val_acc/=len(val_loader)print(f"验证集:loss:{val_loss},acc:{val_acc}")# #每训练完一轮,保存一次参数# torch.save(model.state_dict(),f"params/{epoch}_bert.pth")# print(epoch,"参数保存成功!")#根据验证准确率保存最优参数if val_acc > best_val_acc:best_val_acc = val_acctorch.save(model.state_dict(),"params1/best_bert.pth")print(f"EPOCH:{epoch}:保存最优参数:acc{best_val_acc}")#保存最后一轮参数torch.save(model.state_dict(), "params1/last_bert.pth")print(f"EPOCH:{epoch}:最后一轮参数保存成功!")

3、执行代码

这个过程需等待很久,若是使用cuda环境,显存越大,速度越快。

train_loader的训练批次batch_size=50,这个数值是根据电脑的配置来的,数值越大越好,只要不超过显存或者内存的90%即可。

四、使用训练好的模型

我们写一个控制台程序,也可以使用FastAPI。创建run.py文件。

#模型使用接口(主观评估)
#模型训练
import torch
from net import Model
from transformers import BertTokenizer#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")#加载字典和分词器
token = BertTokenizer.from_pretrained(r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")
model = Model().to(DEVICE)
names = ["负向评价","正向评价"]#将传入的字符串进行编码
def collate_fn(data):sents = []sents.append(data)#编码data = token.batch_encode_plus(batch_text_or_text_pairs=sents,# 当句子长度大于max_length(上限是model_max_length)时,截断truncation=True,max_length=512,# 一律补0到max_lengthpadding="max_length",# 可取值为tf,pt,np,默认为listreturn_tensors="pt",# 返回序列长度return_length=True)input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]return input_ids,attention_mask,token_type_idsdef test():#加载模型训练参数model.load_state_dict(torch.load("params/best_bert.pth"))#开启测试模型model.eval()while True:data = input("请输入测试数据(输入‘q’退出):")if data=='q':print("测试结束")breakinput_ids,attention_mask,token_type_ids = collate_fn(data)input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE)#将数据输入到模型,得到输出with torch.no_grad():out = model(input_ids,attention_mask,token_type_ids)out = out.argmax(dim=1)print("模型判定:",names[out],"\n")if __name__ == '__main__':test()

运行程序 ,输入test测试集里的数据进行验证,或许输入其他的文本验证。

 正确率还是非常棒的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/32851.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言 —— 此去经年梦浪荡魂音 - 深入理解指针(卷一)

目录 1. 内存和地址 2. 指针变量和地址 2.1 取地址操作符(&) 2.2 指针变量 2.3 解引用操作符 (*) 3. 指针的解引用 3.1 指针 - 整数 3.2 void* 指针 4. const修饰指针 4.1 const修饰变量 4.2 const修饰指针变量 5…

【Linux】线程

文章目录 线程(Thread)1. 什么是线程? 创建线程多线程中的重入问题线程异常线程等待总结 线程(Thread) 1. 什么是线程? 线程是进程中的一个执行单元,它是 CPU 调度的基本单位。线程依赖于进程…

SpringBoot第二天

目录 1.Web开发 1.1简介 1.2SpringBoot对静态资源的映射规则 1.3模板引擎 1.3.1引入thymeleaf; 1.3.2Thymeleaf语法 1.3.2.1标准表达式语法 1.变量表达式 1.3.2.2表达式支持的语法 1.3.2.3常用的thymeleaf标签 1.4Springboot整合springmvc 1.4.1Springmvc…

如何接入DeepSeek布局企业AI系统开发技术

在当今科技飞速发展的时代,人工智能(AI)已成为企业提升竞争力、实现创新突破的关键驱动力。DeepSeek作为一款强大的AI工具,为企业开发自身AI系统提供了有力支持。那么,企业该如何接入DeepSeek进行AI系统开发呢&#xf…

日期累加(注意点)

注意点&#xff1a;①月可能超过12月 ②新年需要重新判断闰年 日期累加 #include <stdio.h>int pd(int year) {return (year % 4 0 && year % 100 ! 0) || (year % 400 0); }int main() {int m;int year, month, day, add;scanf("%d", &m);f…

vue3 前端路由权限控制与字典数据缓存实践(附Demo)

目录 前言1. 基本知识2. Demo3. 实战 前言 &#x1f91f; 找工作&#xff0c;来万码优才&#xff1a;&#x1f449; #小程序://万码优才/r6rqmzDaXpYkJZF 从实战中出发&#xff1a; 1. 基本知识 Vue3 和 Java 通信时如何进行字典数据管理 需要了解字典数据的结构。通常&#…

用于 RGB-D 显著目标检测的点感知交互和 CNN 诱导的细化网络

摘要 通过整合来自RGB图像和深度图的互补信息&#xff0c;能够提升在复杂且具有挑战性场景下的显著性目标检测&#xff08;SOD&#xff09;能力。近年来&#xff0c;卷积神经网络&#xff08;CNNs&#xff09;在特征提取和跨模态交互方面的重要作用已得到充分挖掘&#xff0c;但…

基于SpringBoot的“校园周边美食探索及分享平台”的设计与实现(源码+数据库+文档+PPT)

基于SpringBoot的“校园周边美食探索及分享平台”的设计与实现&#xff08;源码数据库文档PPT) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBoot 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 校园周边美食探索及分享平台结构图…

chrome浏览器插件拓展捕获页面的响应体内容

因为chrome extension官方没有的直接获取响应体的方法&#xff0c;所以需要自己实现方法来获取&#xff0c;实现的方式有很多种&#xff0c;这是记录的第二种&#xff0c;第一种就是使用vconsole来实现&#xff0c;vconsole是一个开源框架&#xff0c;一个轻量、可拓展、针对手…

【Linux指北】Linux的重定向与管道

一、了解Linux目录配置标准FHS FHS本质&#xff1a;是一套规定Linux目录结构&#xff0c;软件建议安装位置的标准。 (使用Linux来开发产品或者发布软件的公司、个人太多&#xff0c;如果每家公司或者个人都按照自己的意愿来配置文件或者软件的存放位置&#xff0c;这无疑是一…

Qt6.8.2中JavaScript调用WebAssembly的js文件<1>

前段时间已经学习了如何在QtAssembly中编译FFmpeg资源了&#xff0c;接下来需要使用Html来调用QtCreator中WebAssembly套件写的功能&#xff0c;逐步实现javascrpt与c复杂功能的视线。 接下来我先为大家介绍一个非常简单的加法调用吧&#xff01; 功能讲解 开发环境&#xf…

3.13-进程

进程 进程和程序 程序&#xff1a;编译好的二进制文件&#xff0c;不占用系统资源&#xff08;内存&#xff09;。进程&#xff1a;活跃的程序&#xff0c;不消耗系统图资源&#xff08;内存&#xff09;。 MMU PCB 进程控制块 本质&#xff1a;结构体&#xff1a;struct …

在 CentOS 7 上安装 PHP 7.3

在 CentOS 7 上安装 PHP 7.3 可以按照以下步骤进行操作&#xff1a; 1. 安装必要的依赖和 EPEL 仓库 EPEL&#xff08;Extra Packages for Enterprise Linux&#xff09;是为企业级 Linux 提供额外软件包的仓库&#xff0c;yum-utils 用于管理 yum 仓库。 sudo yum install -…

DeepSeek模型本地化部署方案及Python实现

DeepSeek实在是太火了&#xff0c;虽然经过扩容和调整&#xff0c;但反应依旧不稳定&#xff0c;甚至小圆圈转半天最后却提示“服务器繁忙&#xff0c;请稍后再试。” 故此&#xff0c;本文通过讲解在本地部署 DeepSeek并配合python代码实现&#xff0c;让你零成本搭建自己的AI…

C++从入门到入土(七)——多态

目录 前言 多态的概念 多态的定义 虚函数的介绍 虚函数的重写/覆盖 析构函数的重写 override和final关键字 纯虚函数和抽象类 重写/重载/隐藏总结 多态的原理 小结 前言 C一共有三个特性&#xff0c;封装、继承和多态&#xff0c;在前面的文章中&#xff0c;我们分别…

浅谈时钟启动和Systemlnit函数

时钟是STM32的关键&#xff0c;是整个系统的心脏&#xff0c;时钟如何启动&#xff0c;时钟源如何选择&#xff0c;各个参数如何设置&#xff0c;我们从源码来简单分析一下时钟的启动函数Systemlnit&#xff08;&#xff09;。 Systemlnit函数简介 我们先来看一下源程序的注释…

【数据结构】6栈

0 章节 3&#xff0e;1到3&#xff0e;3小节。 认知与理解栈结构&#xff1b; 列举栈的操作特点。 理解并列举栈的应用案例。 重点 栈的特点与实现&#xff1b; 难点 栈的灵活实现与应用 作业或思考题 完成学习测试&#xff12;&#xff0c;&#xff1f; 内容达成以下标准(考核…

HOT100——链表篇Leetcode160. 相交链表

文章目录 题目&#xff1a;Leetcode160. 相交链表原题链接思路代码 题目&#xff1a;Leetcode160. 相交链表 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 图示两个链表…

江科大51单片机笔记【16】AD/DA转换(下)

写在前言 此为博主自学江科大51单片机&#xff08;B站&#xff09;的笔记&#xff0c;方便后续重温知识 在后面的章节中&#xff0c;为了防止篇幅过长和易于查找&#xff0c;我把一个小节分成两部分来发&#xff0c;上章节主要是关于本节课的硬件介绍、电路图、原理图等理论知识…

【CF】Day5——Codeforces Round 921 (Div. 2) BC

B. A Balanced Problemset? 题目&#xff1a; 思路&#xff1a; 这道题要我们分成n个子问题&#xff0c;我们假设这几个子问题分别是a1,a2,a3,...an&#xff0c; 那么就是让我们求 gcd(a1,a2,a3,....,an)&#xff0c;我们假设这个值是d 那么就有 d | a1&#xff0c;d | a2…