Bert中文文本分类

这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main

我使用的训练环境是

pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;

一、准备训练数据

1.1 准备中文文本分类任务的训练数据

这里Demo数据如下:

各银行信用卡挂失费迥异 北京银行收费最高    0
莫泰酒店流拍 大摩叫价或降至6亿美元 4
乌兹别克斯坦议会立法院主席获连任   6
德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人    7
辉立证券给予广汽集团持有评级 2
图文-业余希望赛海南站第二轮 球场的菠萝蜜  7
陆毅鲍蕾:近乎完美的爱情(组图)(2)    9
7000亿美元救市方案将成期市毒药  0
保诚启动210亿美元配股交易以融资收购AIG部门   2

分类class类别文件:

finance
realty
stocks
education
science
society
politics
sports
game
entertainment

1.2 数据读取和截断,使满足BERT模型输入

读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。

import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
#         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码  -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后  -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)

将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。

二、微调BERT模型

我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。

至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。

pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。

简单调用下BERT模型,打印出来最后一层看下:

import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)

 输出结果:

2.1 文本分类任务,选择使用pooler_output作为线性层的输入。

import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu"  self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False)  # 是否输出所有encoder层的结果# shape (batch_size, hidden_size)  pooler_output -->  hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred

2.2 优化器使用Adam、损失函数使用交叉熵损失函数

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()

三、训练模型

3.1 参数配置

def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args

3.2 模型训练

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")

模型保存为:

-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth

四、模型推理预测

准备预测文本文件,加载模型,进行文本的类别预测。


def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")

以上,基本流程完成。当然模型还需要调优来改进预测效果的。

代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:

 myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

Done

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/496995.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot的pom.xml文件中,scope标签有几种配置?

1.compile&#xff08;默认&#xff09; 含义&#xff1a;表示该依赖在项目的所有阶段&#xff08;编译、测试、运行&#xff09;都需要。 当你依赖一个库&#xff0c;并且这个库是你项目的核心部分&#xff0c;比如 Spring Boot 的spring - boot - starter - web&#xff0c…

FPGA三模冗余TMR工具(二)

学术和商业领域有许多自动化的三模冗余TMR工具&#xff0c;本文介绍当前主流的基于寄存器传输级的三模冗余工具&#xff08;Register-Transfer Level&#xff0c;RTL&#xff09;&#xff0c;基于重要软核资源的三模冗余工具&#xff0c;以及新兴的基于高层次综合的三模冗余工具…

STM32 I2C通信协议

单片机学习&#xff01; 文章目录 目录 文章目录 前言 一、I2C通信 1.1 I2C总线 1.2 I2C通信线 1.3 同步半双工且数据应答 1.4 一主多从 二、硬件电路 2.1 I2C电路模型 2.2 I2C接线要求 2.3 I2C上拉电阻作用 三、I2C时序基本单元 3.1 起始终止条件 3.1.1 起始条件 3.1.2 终止条…

【开源】一款基于SpringBoot的智慧小区物业管理系统

一、下载项目文件 项目文件源码链接&#xff1a;https://pan.quark.cn/s/3998d958e182如出现网盘空间不够存的情况&#xff01;&#xff01;&#xff01;解决办法是先用夸克手机app注册&#xff0c;然后保存上方链接&#xff0c;就可以得到1TB空间了&#xff01;&#xff01;&…

AMD | GPU | 深度学习 | 如何使用

问题&#xff1a;我在复现代码的时候&#xff0c;发现自己只拥有AMD的GPU&#xff0c;对于一个硬件小白来说&#xff0c;怎么办呢&#xff1f;我想看看怎么使用&#xff1b;解决&#xff1a; 首先要安装支持AMD的GPU的pytorch&#xff0c;pytorch&#xff1b; 使程序在安装了支…

【HarmonyOS】鸿蒙arrayBuffer和Uint8Array互相转化

【HarmonyOS】鸿蒙arrayBuffer和Uint8Array互相转化 前言 ArrayBuffer ArrayBuffer内部包含一块Native内存&#xff0c;该ArrayBuffer的JS对象壳被分配在虚拟机本地堆&#xff08;LocalHeap&#xff09;。与普通对象一样&#xff0c;需要经过序列化与反序列化拷贝传递&#x…

从 ELK Stack 到简单 — Elastic Cloud Serverless 上的 Elastic 可观察性

作者&#xff1a;来自 Elastic Bahubali Shetti, Chris DiStasio 宣布 Elastic Cloud Serverless 上的 Elastic Observability 正式发布 — 一款完全托管的可观察性解决方案。 随着组织规模的扩大&#xff0c;一个能够处理分布式云环境的复杂性并提供实时洞察的可观察性解决方…

MySQL数据库的索引

一、数据库的索引 1. 索引的概论 索引&#xff08;Index&#xff09;是书籍的重要组成部分&#xff0c;它列出了书中的重要名词及其对应的页码&#xff0c;方便读者快速查找这些名词的定义和含义。通过索引&#xff0c;用户无需通读整本书就能迅速找到所需的信息。 数据库索…

仓颉语言实战——1. 类型

仓颉语言实战——1. 类型 仓颉语言&#xff08;Cangjie Language&#xff09;是一个现代化的、简洁而强大的编程语言&#xff0c;它的类型系统为高效开发提供了极大的支持。本篇文章将围绕仓颉语言中的类型系统展开&#xff0c;结合实战代码&#xff0c;帮助开发者快速掌握这一…

【已解决】图片png转ico格式

起因&#xff1a; pyinstaller 打包时需要 ico 格式图片&#xff0c;但是通常手上只有png格式的图片&#xff0c;为了将png转为ico&#xff0c;直接改后缀会报错“struct.error: unpack requires a buffer of 16 bytes”&#xff0c;我就上网搜了一下&#xff0c;发现都是一些…

机器学习详解(11):分类任务的模型评估标准

模型评估是利用不同的评估指标来了解机器学习模型的性能&#xff0c;以及其优势和劣势的过程。评估对于确保机器学习模型的可靠性、泛化能力以及在新数据上的准确预测能力至关重要。 文章目录 1 介绍2 评估准则3 分类指标3.1 准确率 (Accuracy)3.2 精确率 (Precision)3.3 召回率…

Python-网络爬虫

随着网络的迅速发展&#xff0c;如何有效地提取并利用信息已经成为一个巨大的挑战。为了更高效地获取指定信息&#xff0c;需定向抓取并分析网页资源&#xff0c;从而促进了网络爬虫的发展。本章将介绍使用Python编写网络爬虫的方法。 学习目标&#xff1a; 理解网络爬虫的基本…

【超级详细】七牛云配置阿里云域名详细过程记录

0. 准备一个阿里云域名&#xff0c;记得要备案&#xff01;&#xff01;&#xff01;&#xff01; 1. 创建七牛云存储空间 首先&#xff0c;登录七牛云控制台&#xff0c;创建一个新的存储空间&#xff08;Bucket&#xff09;。这个存储空间将用于存放你的文件&#xff0c;并…

WPF使用资源定义和样式资源,解耦视图与逻辑(较多样式重复的时候使用)

-- 将Button的Style写到Window.Resources中 其中Window.Resource的Style也是可以继承的&#xff0c;需要使用BaseOn这个属性 还有很多用法的&#xff0c;有空再补充

GitLab安装及使用

目录 一、安装 1.创建一个目录用来放rpm包 2.检查防火墙状态 3.安装下载好的rpm包 4.修改配置文件 5.重新加载配置 6.查看版本 7.查看服务器状态 8.重启服务器 9.输网址 二、GitLab的使用 1.创建空白项目 2.配置ssh 首先生成公钥&#xff1a; 查看公钥 把上面的…

Socket学习(一):控制台聊天demo

实现效果 客户端连接服务端后&#xff0c;可在控制台输入要发送的消息&#xff0c;服务端收到消息后自动回复消息并将消息转发给所有连接上的客户端&#xff1a; 服务端收到消息并回复 客户端1发送消息并接收服务端的回复 客户端2接收服务端转发的消息 源码 SocketServer…

虚拟机桥接模式

主机Win10,虚拟机xp 1.虚拟机设置中选择桥接模式 2.在虚拟机菜单&#xff1a;编辑>虚拟机网络编辑&#xff0c;点击“更改设置”&#xff0c;可以看到三个网卡&#xff0c;这三个网卡分别对应不同的网络共享模式。桥接模式须使用VMnet0&#xff0c;如果没看到这个网卡&…

功能测试和接口测试

&#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 本文主要分为两个部分&#xff1a; 第一部分&#xff1a;主要从问题出发&#xff0c;引入接口测试的相关内容并与前端测试进行简单对比&#xff0c;总结两者之…

2022博客之星年度总评选开始了

作者简介&#xff1a;陶然同学 专注于Java领域开发 熟练掌握Java、js等语言的“Hello World” CSDN原力计划作者、CSDN内容合伙人、Java领域优质作者、Java领域新星作者、51CTO专家、华为云专家、阿里云专家等 &#x1f3ac; 陶然同学&#x1f3a5; 由 陶然同学 原创&#…

Spring自动化创建脚本-解放繁琐的初始化配置!!!(自动化SSM整合)

一、实现功能(原创&#xff0c;转载请告知) 1.自动配置pom配置文件 2.自动识别数据库及数据表&#xff0c;创建Entity、Dao、Service、Controller等 3.自动创建database.properties、mybatis-config.xml等数据库文件 4.自动创建spring-dao.xml spring-mvc.xml …