【NLP练习】中文文本分类-Pytorch实现

中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、准备工作

1. 任务说明

本次使用Pytorch实现中文文本分类。主要代码与文本分类代码基本一致,不同的是本次任务使用了本地的中文数据,数据示例如下:
在这里插入图片描述

2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")   #忽略警告信息device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
import pandas as pd#加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
train_data.head()

输出:
在这里插入图片描述

#构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,ytrain_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

二、数据预处理

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
label_name = list(set(train_data[1].values[:]))
print(label_name)

输出:

['FilmTele-Play', 'Alarm-Update', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'Travel-Query', 'Music-Play', 'Video-Play', 'HomeAppliance-Control', 'Calendar-Query', 'TVProgram-Play', 'Other']
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
7

lambda表达式的语法为:lambda arguments: expression
其中arguments是函数的参数,可以有多个参数,用逗号分隔。expression是一个表达式,它定义了函数的返回值。

  • text_pipeline函数: 将原始文本数据转换为整数列表,使用了之前构建的vocab词表和tokenizer分词器函数。具体步骤:
  1. 接受一个字符串x作为输入
  2. 使用tokenizer将其分词
  3. 将每个词在vocab词表中的索引放入一个列表返回
  • label_pipeline函数: 将原始标签数据转换为整数,它接受一个字符串x作为输入,并使用 label_index.index(x) 方法获取x在label_name列表中的索引作为输出。

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [],[],[0]         for(_text, _label) in batch:#标签列表label_list.append(label_pipeline(_label))#文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)#偏移量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)#数据加载器
dataloader = DataLoader(train_iter,batch_size = 8,shuffle = False,collate_fn = collate_batch
)

三、模型构建

1. 搭建模型

#搭建模型
from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小embed_dim,        # 嵌入的维度sparse=False)     #self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

2. 初始化模型

#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

#定义训练与评估函数
import timedef train(dataloader):model.train()          #切换为训练模式total_acc, train_loss, total_count = 0,0,0log_interval = 50start_time = time.time()for idx, (text,label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                             #grad属性归零loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真loss.backward()                                   #反向传播torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪optimizer.step()                                  #每一步自动更新#记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0staet_time = time.time()def evaluate(dataloader):model.eval()      #切换为测试模式total_acc,train_loss,total_count = 0,0,0with torch.no_grad():for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label,label)   #计算loss值#记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)#获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss))print('-' * 69)

输出:

['还有双鸭山到淮阴的汽车票吗13号的' '从这里怎么回家' '随便播放一首专辑阁楼里的佛里的歌' ...'黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席' '百事盖世群星星光演唱会有谁' '下周一视频会议的闹钟帮我开开']
|epoch1|  50/ 152 batches|train_acc0.953 train_loss0.00282
|epoch1| 100/ 152 batches|train_acc0.953 train_loss0.00271
|epoch1| 150/ 152 batches|train_acc0.952 train_loss0.00292
---------------------------------------------------------------------
| epoch 1 | time:5.50s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.961 train_loss0.00231
|epoch2| 100/ 152 batches|train_acc0.967 train_loss0.00204
|epoch2| 150/ 152 batches|train_acc0.963 train_loss0.00228
---------------------------------------------------------------------
| epoch 2 | time:5.06s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.975 train_loss0.00173
|epoch3| 100/ 152 batches|train_acc0.973 train_loss0.00177
|epoch3| 150/ 152 batches|train_acc0.972 train_loss0.00166
---------------------------------------------------------------------
| epoch 3 | time:5.07s | valid_acc 0.948 valid_loss 0.003
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.984 train_loss0.00137
|epoch4| 100/ 152 batches|train_acc0.987 train_loss0.00123
|epoch4| 150/ 152 batches|train_acc0.983 train_loss0.00119
---------------------------------------------------------------------
| epoch 4 | time:5.07s | valid_acc 0.950 valid_loss 0.003
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.985 train_loss0.00125
|epoch5| 100/ 152 batches|train_acc0.987 train_loss0.00119
|epoch5| 150/ 152 batches|train_acc0.986 train_loss0.00120
---------------------------------------------------------------------
| epoch 5 | time:5.03s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.985 train_loss0.00118
|epoch6| 100/ 152 batches|train_acc0.989 train_loss0.00114
|epoch6| 150/ 152 batches|train_acc0.985 train_loss0.00120
---------------------------------------------------------------------
| epoch 6 | time:5.40s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.984 train_loss0.00119
|epoch7| 100/ 152 batches|train_acc0.986 train_loss0.00119
|epoch7| 150/ 152 batches|train_acc0.989 train_loss0.00112
---------------------------------------------------------------------
| epoch 7 | time:5.71s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.985 train_loss0.00115
|epoch8| 100/ 152 batches|train_acc0.986 train_loss0.00128
|epoch8| 150/ 152 batches|train_acc0.989 train_loss0.00107
---------------------------------------------------------------------
| epoch 8 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.988 train_loss0.00114
|epoch9| 100/ 152 batches|train_acc0.983 train_loss0.00127
|epoch9| 150/ 152 batches|train_acc0.989 train_loss0.00109
---------------------------------------------------------------------
| epoch 9 | time:5.28s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.986 train_loss0.00115
|epoch10| 100/ 152 batches|train_acc0.987 train_loss0.00117
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00119
---------------------------------------------------------------------
| epoch 10 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

输出:

模型准确率为:0.9492

2. 测试指定数据

#测试指定的数据
def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

该文本的类别是: Travel-Query

五、总结

训练神经网络时,可使用梯度裁剪的方法来防止梯度爆炸,使得模型训练更加稳定

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/298486.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis 解决上篇的参数绑定问题以及XML方式交互

前言 上文:MyBatis 初识简单操作-CSDN博客 上篇文章我们谈到的Spring中如何使用注解对Mysql进行交互 但是我们发现我们返回出来的数据明显有问题 我们发现后面三个字段的信息明显没有展示出来 下面我们来谈谈解决方案 解决方案 这里的原因本质上是因为mysql中和对象中的字段属性…

【微服务】------核心组件架构选型

1.微服务简介 微服务架构&#xff08;Microservice Architecture&#xff09;是一种架构概念&#xff0c;旨在通过将功能分解到各个离散的服务中以实现对解决方案的解耦&#xff0c;从而降低系统的耦合性&#xff0c;并提供更加灵活的服务支持。 2.微服务技术选型 区域内容…

【零基础学数据结构】顺序表实现书籍存储

目录 书籍存储的实现规划 ​编辑 前置准备&#xff1a; 书籍结构体&#xff1a; 书籍展示的初始化和文件加载 书籍展示的销毁和文件保存 书籍展示的容量检查 书籍展示的尾插实现 书籍展示的书籍增加 书籍展示的书籍打印 书籍删除展示数据 书籍展示修改数据 在指定位置之前…

2024年第八届人工智能与虚拟现实国际会议(AIVR 2024)即将召开!

2024年第八届人工智能与虚拟现实国际会议&#xff08;AIVR 2024&#xff09;将2024年7月19-21日在日本福冈举行。人工智能与虚拟现实的发展对推动科技进步、促进经济发展、提升人类生活质量等具有重要意义。AIVR 2024将携手各专家学者&#xff0c;共同挖掘智能与虚拟的无限可能…

加速度:电子元器件营销网站的功能和开发周期

据工信部预计&#xff0c;到2023年&#xff0c;我国电子元器件销售总额将达到2.1万亿元。随着资本的涌入&#xff0c;在这个万亿级赛道&#xff0c;市场竞争变得更加激烈的同时&#xff0c;行业数字化发展已是大势所趋。电子元器件B2B商城平台提升数据化驱动能力&#xff0c;扩…

【机器学习】如何通过群体智慧解决机器学习的挑战“

机器学习的发展日新月异&#xff0c;但其成功实施的关键之一仍然是获取高质量的、标注良好的数据集。在这篇文章中&#xff0c;我们将探讨如何通过群体智慧来构建和改善机器学习的数据集&#xff0c;尤其是通过reCAPTCHA和带有目的的游戏&#xff08;Games with a Purpose, GWA…

齐护机器人方位传感器指南针罗盘陀螺仪

一、方位传感器原理及功能说明 齐护方位传感器是一款集成了三轴磁传感器芯片的方位传感器模块。适用于无人机、机器人、移动和个人手持设备中的罗盘&#xff08;指南针&#xff09;、导航和游戏等高精度应用。模块可以感应XYZ平面角度外&#xff0c;还可实现1至2的水平面角度罗…

Python | Leetcode Python题解之第10题正则表达式匹配

题目&#xff1a; 题解&#xff1a; class Solution:def isMatch(self, s: str, p: str) -> bool:m, n len(s), len(p)dp [False] * (n1)# 初始化dp[0] Truefor j in range(1, n1):if p[j-1] *:dp[j] dp[j-2]# 状态更新for i in range(1, m1):dp2 [False] * (n1) …

Transformer位置编码详解

在处理自然语言时候&#xff0c;因Transformer是基于注意力机制&#xff0c;不像RNN有词位置顺序信息&#xff0c;故需要加入词的位置信息来显示的表明词的上下文关系。具体是将词经过位置编码(positional encoding)&#xff0c;然后与emb词向量求和&#xff0c;作为编码块(Enc…

备考2024年思维100春季线上比赛?来做做官方模拟题(附答案)

2024年春季思维100活动第一阶段线上比赛&#xff08;4月20日&#xff0c;星期六&#xff0c;上午&#xff09;的报名正在进行中&#xff0c;更多安排和需要提前了解的关键点可以见我前面写的文章&#xff0c;或者直接联系我获取相关资料。 【提醒】2024年春季的思维100在线比赛…

递归算法解读

递归&#xff08;Recursion&#xff09;是计算机科学中的一个重要概念&#xff0c;它指的是一个函数&#xff08;或过程&#xff09;在其定义中直接或间接地调用自身。递归函数通过把问题分解为更小的相似子问题来解决原问题&#xff0c;这些更小的子问题也使用相同的解决方案&…

ClickHouse笔记

1. 简介 开发背景: ClickHouse 由 Yandex 于 2016 年开源&#xff0c;目的是提供高性能的 OLAP 解决方案。性能: ClickHouse 能够以极高的速度处理大量数据&#xff0c;每秒可以处理数亿到十亿多行数据。架构: 它使用 C 编写&#xff0c;提供丰富的数据类型、数据库引擎和表引…

深度学习方法;乳腺癌分类

乳腺癌的类型很多&#xff0c;但大多数常见的是浸润性导管癌、导管原位癌和浸润性小叶癌。浸润性导管癌(IDC)是最常见的乳腺癌类型。这些都是恶性肿瘤的亚型。大约80%的乳腺癌是浸润性导管癌(IDC)&#xff0c;它起源于乳腺的乳管。 浸润性是指癌症已经“侵袭”或扩散到周围的乳…

SSM 项目学习(Vue3+ElementPlus+Axios+SSM)

文章目录 1 项目介绍1.1 项目功能/界面 2 项目基础环境搭建2.1 创建项目2.2 项目全局配置 web.xml2.3 SpringMVC 配置2.4 配置 Spring 和 MyBatis , 并完成整合2.5 创建表&#xff0c;使用逆向工程生成 Bean、XxxMapper 和 XxxMapper.xml2.6 注意事项和细节说明 3 实现功能 01-…

redis进阶入门主从复制与哨兵集群

一、主从复制 1.1背景 一般来说&#xff0c;要将 Redis用于工程项目中&#xff0c;只使用一台 Redist是万万不能的&#xff0c;原因如下&#xff1a; 从结构上&#xff0c;单个 Redist服务器会发生单点故障&#xff0c;井且一台服务器需要处理所有的请求负載&#xff0c;压力…

软件测试(测试用例详解)(三)

1. 测试用例的概念 测试用例&#xff08;Test Case&#xff09;是为了实施测试而向被测试的系统提供的一组集合。 测试环境操作步骤测试数据预取结果 测试用例的评价标准&#xff1a; 用例表达清楚&#xff0c;无二义性。。用例可操作性强。用例的输入与输出明确。一条用例只有…

数据库性能优化入门:数据库分片初探

数据库分片是一种用于提升数据库性能的架构模式&#xff0c;选择正确的分片策略和实施方式对于提高数据库性能和应对大规模数据挑战至关重要。 本文介绍了数据库分片的定义、原理和实施方法。文章解释了数据库分片是如何通过将数据切分、分散存储在多个服务器上来提升性能&…

Linux-程序地址空间

目录 1. 程序地址空间分布 2. 两个问题 3. 虚拟地址和物理地址 4. 页表 5. 解决问题 6. 为什么要有地址空间 1. 程序地址空间分布 测试一下&#xff1a; #include<stdio.h> #include<stdlib.h> #include<unistd.h> #include<sys/types.h>int ga…

计算机服务器中了halo勒索病毒怎么办,halo勒索病毒解密流程步骤

随着网络技术的不断应用&#xff0c;企业的生产运营得到了快速发展&#xff0c;越来越多的企业开始利用服务器数据库存储企业的重要信息文件&#xff0c;数据库为企业的生产运营提供了极大便利&#xff0c;但网络技术的不断发展也为企业的数据安全带来严重威胁。近日&#xff0…

全栈的自我修养 ———— react中router入门+路由懒加载

router 下载router配置view创建目录配置index.js 下载router npm install react-router-dom配置view 如下将组件倒出 const Login () > {return <div>这是登陆</div> } export default Login创建目录 配置index.js React.lazy有路由懒加载的功能&#xff0…