深度学习每周学习总结N4:中文文本分类-Pytorch实现(基本分类(熟悉流程)、textCNN分类(通用模型)、Bert分类(模型进阶))

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

目录

    • 0. 总结:
    • 1. 基础模型
      • a. 数据加载
      • b. 数据预处理
      • c. 模型搭建与初始化
      • d. 训练函数
      • e. 评估函数
      • f.拆分数据集运行模型
      • g. 结果可视化
      • h. 测试指定数据
    • 2. TextCNN(通用模型-待拓展)
    • 3. Bert(高级模型-待拓展)

0. 总结:

之前有学习过文本预处理的环节,对文本处理的主要方式有以下三种:

1:词袋模型(one-hot编码)

2:TF-IDF

3:Word2Vec(词向量(Word Embedding) 以及Word2vec(Word Embedding 的方法之一))

详细介绍及中英文分词详见pytorch文本分类(一):文本预处理

上上期主要介绍Embedding,及EmbeddingBag 使用示例(对词索引向量转化为词嵌入向量) ,上期主要介绍:应用三种模型的英文分类

本期将主要介绍中文基本分类(熟悉流程)、拓展:textCNN分类(通用模型)、拓展:Bert分类(模型进阶)

1. 基础模型

a. 数据加载

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasetsimport numpy as np
import pandas as pdimport os,PIL,pathlib,warnings
import matplotlib.pyplot as plt
import warningswarnings.filterwarnings("ignore") # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False   # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
# 加载自定义中文数据集
train_data = pd.read_csv('./data/N4-train.csv',sep = '\t',header = None)
train_data.head()
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
train_data[0]
0                  还有双鸭山到淮阴的汽车票吗13号的
1                            从这里怎么回家
2                   随便播放一首专辑阁楼里的佛里的歌
3                          给看一下墓王之王嘛
4              我想看挑战两把s686打突变团竞的游戏视频...             
12095          一千六百五十三加三千一百六十五点六五等于几
12096                      稍小点客厅空调风速
12097    黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席
12098                  百事盖世群星星光演唱会有谁
12099                 下周一视频会议的闹钟帮我开开
Name: 0, Length: 12100, dtype: object
# 构建数据迭代器
def custom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])

b. 数据预处理

# 构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分词方法
tokenizer = jieba.lcut # jieba.cut返回的是一个生成器,而jieba.lcut返回的是一个列表def yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\Cheng\AppData\Local\Temp\jieba.cache
Loading model cost 0.549 seconds.
Prefix dict has been built successfully.
tokenizer('我想看和平精英上战神必备技巧的游戏视频')
['我', '想', '看', '和平', '精英', '上', '战神', '必备', '技巧', '的', '游戏', '视频']
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
text_pipeline = lambda x:vocab(tokenizer(x))
label_pipeline = lambda x:label_name.index(x)
label_name = list(set(train_data[1].values[:]))
print(label_name)
['HomeAppliance-Control', 'Audio-Play', 'Other', 'Weather-Query', 'Music-Play', 'Travel-Query', 'TVProgram-Play', 'Alarm-Update', 'Video-Play', 'Calendar-Query', 'FilmTele-Play', 'Radio-Listen']

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
8
# 生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list,text_list,offsets = [],[],[0]for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype = torch.int64)text_list.append(processed_text)# 偏移量(即语句的总词汇量)offsets.append(processed_text.size(0))label_list = torch.tensor(label_list,dtype = torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 返回维度dim中输入元素的累计和return text_list.to(device),label_list.to(device),offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size = 8,shuffle = False,collate_fn = collate_batch)

c. 模型搭建与初始化

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self,vocab_size,embed_dim,num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,  # 词典大小embed_dim,   # 嵌入维度sparse=False)self.fc = nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange,initrange) # 初始化权重self.fc.weight.data.uniform_(-initrange,initrange)self.fc.bias.data.zero_()def forward(self,text,offsets):embedded = self.embedding(text,offsets)return self.fc(embedded)
# 初始化模型
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size,em_size,num_class).to(device)

d. 训练函数

import timedef train(dataloader):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)train_acc,train_loss = 0,0 # 初始化训练损失和正确率for idx,(text,label,offsets) in enumerate(dataloader):# 计算预测误差predicted_label = model(text,offsets)   # 网络输出loss = criterion(predicted_label,label) # 计算网络输出和真实值之间的差距,label为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss

e. 评估函数

import timedef evaluate(dataloader):test_acc,test_loss,total_count = 0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):# 计算预测误差predicted_label = model(text,offsets)loss = criterion(predicted_label,label)# 记录测试数据test_acc += (predicted_label.argmax(1) == label).sum().item()test_loss += loss.item()total_count += label.size(0)return test_acc/total_count,test_loss/total_count

f.拆分数据集运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 10 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)
import copytrain_acc = []
train_loss = []
test_acc = []
test_loss = []best_acc = None # 设置一个最佳准确率,作为最佳模型的判别指标for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()model.train() # 切换为训练模式epoch_train_acc,epoch_train_loss = train(train_dataloader)model.eval() # 切换为测试模式epoch_test_acc,epoch_test_loss = evaluate(valid_dataloader)if best_acc is not None and best_acc > epoch_test_acc:scheduler.step()else:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')print(template.format(epoch,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print('Done!')
Epoch: 1,Train_acc:63.8%,Train_loss:1.339,Test_acc:79.5%,Test_loss:0.012,Lr:5.00E+00
Epoch: 2,Train_acc:83.2%,Train_loss:0.592,Test_acc:84.8%,Test_loss:0.008,Lr:5.00E+00
Epoch: 3,Train_acc:88.2%,Train_loss:0.413,Test_acc:86.8%,Test_loss:0.007,Lr:5.00E+00
Epoch: 4,Train_acc:91.1%,Train_loss:0.313,Test_acc:87.6%,Test_loss:0.006,Lr:5.00E+00
Epoch: 5,Train_acc:93.3%,Train_loss:0.241,Test_acc:89.8%,Test_loss:0.006,Lr:5.00E+00
Epoch: 6,Train_acc:95.0%,Train_loss:0.189,Test_acc:89.8%,Test_loss:0.006,Lr:5.00E-01
Epoch: 7,Train_acc:96.7%,Train_loss:0.144,Test_acc:89.6%,Test_loss:0.006,Lr:5.00E-02
Epoch: 8,Train_acc:96.8%,Train_loss:0.139,Test_acc:89.6%,Test_loss:0.006,Lr:5.00E-03
Epoch: 9,Train_acc:96.8%,Train_loss:0.139,Test_acc:89.6%,Test_loss:0.006,Lr:5.00E-04
Epoch:10,Train_acc:96.8%,Train_loss:0.139,Test_acc:89.6%,Test_loss:0.005,Lr:5.00E-05
Done!

g. 结果可视化

epochs_range = range(EPOCHS)plt.figure(figsize=(12,3))
plt.subplot(1,2,1)plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Train Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Loss')
plt.show()


在这里插入图片描述

h. 测试指定数据

注意以下俩种测试方法,必须保证模型和数据在同样的设备上(GPU或CPU)

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()# ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"model = model.to("cpu")print("该文本的类别是:%s" %label_name[predict(ex_text_str, text_pipeline)])
该文本的类别是:Travel-Query

def predict(text, text_pipeline, model, device):model.eval()  # 切换为评估模式with torch.no_grad():# 将文本和偏移量张量都移动到相同的设备text = torch.tensor(text_pipeline(text)).to(device)offsets = torch.tensor([0]).to(device)output = model(text, offsets)return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"# 确保模型在设备上
model = model.to(device)print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline, model, device)])
该文本的类别是:Travel-Query

2. TextCNN(通用模型-待拓展)


3. Bert(高级模型-待拓展)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/381505.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt QProcess 进程间通信读写数据通信

本文介绍了如何使用Qt的QProcess 进行程序开发&#xff0c;包括启动进程间通信、设置环境变量、通用方法&#xff1b;方便在日常开发中使用&#xff1b; 1.使用Qt进行程序开发&#xff0c;可以通过QProcess类用于启动外部程序并与其进行通信.&#xff1b; 进程A&#xff08;…

Grafana :利用Explore方式实现多条件查询

背景 日志统一推送到Grafana上管理。所以&#xff0c;有了在Grafana上进行日志搜索的需求&#xff0c;而进行日志搜索通常需要多条件组合。 解决方案 通过Grafana的Explore的方式实现多条件查询。 直接看操作步骤&#xff1a; 在主页搜索框中输入“Explore” 进入这个界面…

Linux下开放指定端口

比如需要开放82端口&#xff1a; #查询是否开通 firewall-cmd --query-port82/tcp#开放端口82 firewall-cmd --zonepublic --add-port82/tcp --permanent#重新加载防火墙 firewall-cmd --reload

[计算机网络] VPN技术

1. 概述 虚拟专用网络&#xff08;VPN&#xff09;技术利用互联网服务提供商&#xff08;ISP&#xff09;和网络服务提供商&#xff08;NSP&#xff09;的网络基础设备&#xff0c;在公用网络中建立专用的数据通信通道。VPN的主要优点包括节约成本和提供安全保障。 优点&#…

Android RSA 加解密

文章目录 一、RSA简介二、RSA 原理介绍三、RSA 秘钥对生成1. 密钥对生成2. 获取公钥3. 获取私钥 四、PublicKey 和PrivateKey 的保存1. 获取公钥十六进制字符串1. 获取私钥十六进制字符串 五、PublicKey 和 PrivateKey 加载1. 加载公钥2. 加载私钥 六、 RSA加解密1. RSA 支持三…

[HTML]一文掌握

背景知识 主流浏览器 浏览器是展示和运行网页的平台&#xff0c; 常见的五大浏览器有 IE浏览器、火狐浏览器&#xff08;Firefox&#xff09;、谷歌浏览器&#xff08;Chrome&#xff09;、Safari浏览器、欧朋浏览器&#xff08;Opera&#xff09; 渲染引擎 浏览器解析代码渲…

Go语言 Import导入

本文主要介绍Go语言import导入使用时注意事项和功能实现示例。 目录 Import 创建功能文件夹 加法 减法 主函数 优化导入的包名 .引入方法 总结 Import 创建功能文件夹 做一个计算器来演示&#xff0c;首先创建test文件夹。 加法 在test文件夹中创建add文件夹&#xff…

高性能分布式IO系统BL205 OPC UA耦合器

边缘计算是指在网络的边缘位置进行数据处理和分析&#xff0c;而不是将所有数据都传送到云端或中心服务器&#xff0c;这样可以减少延迟、降低带宽需求、提高响应速度并增强数据安全性。 钡铼BL205耦合器就内置边缘计算功能&#xff0c;它不依赖上位机和云平台&#xff0c;就能…

大语言模型-Transformer-Attention Is All You Need

一、背景信息&#xff1a; Transformer是一种由谷歌在2017年提出的深度学习模型。 主要用于自然语言处理&#xff08;NLP&#xff09;任务&#xff0c;特别是序列到序列&#xff08;Sequence-to-Sequence&#xff09;的学习问题&#xff0c;如机器翻译、文本生成等。Transfor…

上位机图像处理和嵌入式模块部署(香橙派AI Pro开发板试用)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】 和工控机相比较,linux嵌入式开发板使用上面方便很多、也容易很多。很多的第三方库都可以通过yum、apt-get这样的方法直接下载到,不需要自己通过源代码重新进行编译、安装。因为自…

单链表<数据结构 C版>

目录 概念 链表的单个结点 链表的打印操作 新结点的申请 尾部插入 头部插入 尾部删除 头部删除 查找 在指定位置之前插入数据 在任意位置之后插入数据 测试运行一下&#xff1a; 删除pos结点 删除pos之后结点 销毁链表 概念 单链表是一种在物理存储结构上非连续、非顺序…

【C语言】 利用栈完成十进制转二进制(分文件编译,堆区申请空间malloc)

利用栈先进后出的特性&#xff0c;在函数内部&#xff0c;进行除二取余的操作&#xff0c;把每次的余数存入栈内&#xff0c;最后输出刚好就是逆序输出&#xff0c;为二进制数 学习过程中&#xff0c;对存储栈进行堆区的内存申请时候&#xff0c;并不是很熟练&#xff0c;一开始…

Linux操作系统的有关常用的命令

1.linux系统的概述 1.1 什么是Linux系统? Linux&#xff0c;全称GNU/Linux&#xff0c;是一种免费使用和自由传播的类UNIX操作系统&#xff0c;其内核由林纳斯本纳第克特托瓦 兹&#xff08;Linus Benedict Torvalds&#xff09;于1991年10月5日首次发布&#xff0c;它主要受…

Unity 调试死循环程序

如果游戏出现死循环如何调试呢。 测试脚本 我们来做一个测试。 首先写一个死循环代码&#xff1a; using System.Collections; using System.Collections.Generic; using UnityEngine;public class dead : MonoBehaviour {void Start(){while (true){int a 1;}}}Unity对象设…

matlab simulink气隙局部放电仿真技术研究

1、内容简介 略 87-可以交流、咨询、答疑 2、内容说明 略 为了解决目前国内外局部放电仿真方法难以计算气隙局部放电暂态过程的问题 , 利用 MATLAB (SIMULINK ) 的公共模块库和电力系统专业模块库 , 根据单气隙局部放电仿真物理模型 , 构造了气隙局部放 电仿真计算的电…

C++_单例模式

目录 1、饿汉方式实现单例 2、懒汉方式实现单例 3、单例模式的总结 结语 前言&#xff1a; 在C中有许多设计模式&#xff0c;单例模式就是其中的一种&#xff0c;该模式主要针对类而设计&#xff0c;确保在一个进程下该类只能实例化出一个对象&#xff0c;因此名为单例。而…

新能源风机视觉数据集

需要的同学私信联系&#xff0c;推荐关注上面图片右下角平台自取下载。 全球风电装机的快速扩张推高了风电场运维巡检的需求&#xff0c;原本高度依赖人力的风电运维巡检工作正因智能化、数字化、无人化技术的应用出现变革。AI智慧风机检测可以促进风电领域运维检测新技术产、…

操作系统——文件管理

1&#xff09;什么是文件&#xff1f; 2&#xff09;单个文件的逻辑结构和物理结果之间是否存在制约关系&#xff1f; 本节内容较为抽象&#xff0c;本节要注意区分文件的逻辑结构和物理结构。 一、文件系统基础 1.文件的基本概念&#xff08;一切皆文件&#xff09; 文件&…

【漏洞复现】Netgear WN604 downloadFile.php 信息泄露漏洞(CVE-2024-6646)

0x01 产品简介 NETGEAR WN604是一款由NETGEAR&#xff08;网件&#xff09;公司生产的无线接入器&#xff08;或无线路由器&#xff09;提供Wi-Fi保护协议&#xff08;WPA2-PSK, WPA-PSK&#xff09;&#xff0c;以及有线等效加密&#xff08;WEP&#xff09;64位、128位和152…

AVL树超详解上

前言 学习过了二叉树以及二叉搜索树后&#xff08;不了解二叉搜索树的朋友可以先看看这篇博客&#xff0c;二叉搜索树详解-CSDN博客&#xff09;&#xff0c;我们在一般情况下对于二叉搜索树的插入与查询时间复杂度都是O(lgN)&#xff0c;是十分快的&#xff0c;但是在一些特殊…