第20周:Pytorch文本分类入门

目录

前言

一、前期准备

1.1 环境安装导入包

1.2 加载数据

1.3 构建词典

1.4 生成数据批次和迭代器

二、准备模型

2.1 定义模型

2.2 定义示例

2.3 定义训练函数与评估函数

三、训练模型

3.1 拆分数据集并运行模型

3.2 使用测试数据集评估模型

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

本周任务:了解文本分类的基本流程、学习常用数据清洗方法、学习如何使用jieba实现英文分词、学习如何构建文本向量

我的环境:Python3.8、Pycharm2020、torch1.12.1+cu113

数据来源:[K同学啊]


一、前期准备

1.1 环境安装导入包

本文是一个使用Pytorch实现的简单文本分类实战案例,在本案例中,我们将使用AG News数据集进行文本分类

需要确保已经安装了torchtext与poralocker库

PS:torchtext库的安装需要与Pytorch、python的版本进行匹配,具体可参考torchtext版本对应

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings
from torchtext.datasets import AG_NEWS
from torchtext.vocab import build_vocab_from_iterator
import torchtext.data.utils as utils
from torch.utils.data import DataLoader
from torch import nn
import timewarnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

1.2 加载数据

本文使用的数据集是AG News(AG's News Topic Classification Dataset)是一个广泛用于文本分;分类任务的数据集,尤其是在新闻领域,该数据集是由AG’s Corpus of News Ariticles收集整理而来,包含了四个主要的类别:世界、体育、商业和科技

torchtext.datasets.AG_NEWS是一个用于加载AG News数据集的torchtext数据集类,具体的参数如下:

  • root:数据集的根目录,默认值是'.data'
  • split:数据集的拆分train、test
  • **kwargs:可选的关键字参数,可传递给torchtext.datasets.TextClassificationDataset类构造的函数

该类加载的数据集是一个列表,其中每个条目都是一个元祖,包含以下两个元素

  • 一条新闻文章的文本内容
  • 新闻文章所属的类别(一个整数,从1到4,分别对应世界、科技、体育和商业)

代码如下:

train_iter = AG_NEWS(split='train')

1.3 构建词典

torchtext.data.ttils.get_tokenizer()是一个用于将文本数据分词的函数,它返回一个分词器(tokenizer)函数,可以将一个字符串转换成一个单词的列表

函数原型:torchtext.data.ttils.get_tokenizer(tokenizer, language = 'en')

tokenizer参数是用于指定使用的分词器名称,可以是一下之一

  • basic_english:用于基本英文文本的分词器
  • moses:用于处理各种语言的分词器,支持多种选项
  • spacy:使用spaCy分词器,需要安装spaCy库
  • toktok:用于各种语言的分词器,速度较快

PS:分词器函数返回的单词列表中不包含任何标点符号或空格

代码如下:

tokenizer = utils.get_tokenizer('basic_english')def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  #设置默认索引,如果找不到单词,则会选择默认索引print(vocab(['here', 'is', 'an', 'example']))text_pipeline = lambda  x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1print(text_pipeline('here is an example'))
print(label_pipeline('10'))

输出结果:

[475, 21, 30, 5297]
[475, 21, 30, 5297]
9

1.4 生成数据批次和迭代器

代码如下:

def collate_batch(batch):label_list, text_list, offsets = [],[],[0]for (_label, _text) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.int64)text_list.append(processed_text)#偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)return label_list.to(device),text_list.to(device), offsets.to(device)#数据集加载器
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

二、准备模型

2.1 定义模型

首先对文本进行嵌入,然后对句子嵌入之后的结果进行均值聚合

代码如下:

#2.1 准备模型
class TextClassificationModel(nn.Module):def __init__(self,vocab_size,embed_dim,num_calss):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)self.fc = nn.Linear(embed_dim, num_calss)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

self.embedding.weight.data.uniform_(-initrange, initrange)是在PyTorch框架下用于初始化神经网络的词嵌入层(embedding layer)权重的一种方法,这里使用了均匀分布的随机值来初始化权重,具体来说,其作用如下:

  • self.embedding:这是神经网络的词嵌入层(embedding layer),其嵌入层的作用是将离散的单词表示(通常为整数索引)映射为固定大小的连续向量,这些向量捕捉了单词之间的语义关系,并作为网络的输入
  • self.embedding.weight:这是词嵌入层的权重矩阵,它的形状为(vocab_size,embedding_dim),其中vocab_size是词汇表的大小,embedding_dim是嵌入向量的维度
  • self.embedding.weight.data:这是权重矩阵的数据部分,可以直接操作其底层的张量
  • .uniform_(-initrange, initrange):这是一个原地操作,用于将权重矩阵的值用一个均匀分布进行初始化,均匀分布的范围为[-initrange,initrange],其中initrange是一个正数

这种方式初始化词嵌入层的权重,可以使得模型在训练开始时具有一定的随机性,有助于避免梯度消失或梯度爆炸等问题,在训练过程中,这些权重将通过优化算法不断更新,以捕捉到更好的单词表示

2.2 定义示例

代码如下:

#2.2定义实例
num_class = len(set([label for (label,text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)epochs = 10
lr = 5
batch_szie =64critertion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)

2.3 定义训练函数与评估函数

#2.3 定义训练函数与评估函数
def train(dataloader):model.train()total_acc, train_loss, total_count = 0,0,0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()loss = critertion(predicted_label, label)loss.backward()optimizer.step()#记录acc和losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx>0:elapsed = time.time() - start_timeprint('| epoch{:1d} | {:4d}/{:4d} batches''train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()total_acc, train_loss, total_count = 0, 0 ,0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = critertion(predicted_label, label)# 记录acc和losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

三、训练模型

3.1 拆分数据集并运行模型

     torchtext.data.functional.to_map_style_dataset函数的作用是将一个迭代式的数据集(Iterable-style dataset)转换为映射式的数据集(Map-style dataset)。这个转换使得我们可以通过索引更方便访问数据集中的元素

       在PyTorch中,数据集可以分为两种类型:Iterable-style和Map-style,Iterable-style数据集实现了__iter__()方法,可以迭代访问数据集中的元素,但不支持通过索引访问,而Map-style数据集实现了__getitem__()和__len__()方法,可以直接通过索引访问特定元素,并能获取数据集的大小、

      TorchText是Pytorch的一个扩展库,专注于处理文本数据,torchtext.data.functional中的to_map_style_dataset函数可以帮助我们将一个Iterable-style数据集转换为一个易于操作的Map-style数据集,这样就可以通过索引直接访问数据集中的特定样本,从而简化了训练、验证和测试过程中的数据处理。

代码如下:

#3.1 拆分数据集并运行模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_datasettotal_accu = Nonetrain_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset)-num_train])train_dataloader = DataLoader(split_train_, batch_size=batch_szie, shuffle=True, collate_fn=collate_batch)
vaild_dataloader = DataLoader(split_valid_, batch_size=batch_szie, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_szie, shuffle=True, collate_fn=collate_batch)for epoch in range(1, epochs+1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(vaild_dataloader)if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch{:1d} | time: {:4.2f}s |''vaild_acc {:4.3f} vaild_loss {:4.3f}'.format(epoch,time.time()-epoch_start_time,val_acc, val_loss))print('-' * 69)

输出结果:

| epoch1 |  500/1782 batchestrain_acc 0.714 train_loss 0.01141
| epoch1 | 1000/1782 batchestrain_acc 0.864 train_loss 0.00623
| epoch1 | 1500/1782 batchestrain_acc 0.879 train_loss 0.00551
---------------------------------------------------------------------
| epoch1 | time: 7.21s |vaild_acc 0.810 vaild_loss 0.008
---------------------------------------------------------------------
| epoch2 |  500/1782 batchestrain_acc 0.905 train_loss 0.00445
| epoch2 | 1000/1782 batchestrain_acc 0.905 train_loss 0.00440
| epoch2 | 1500/1782 batchestrain_acc 0.906 train_loss 0.00442
---------------------------------------------------------------------
| epoch2 | time: 5.68s |vaild_acc 0.910 vaild_loss 0.004
---------------------------------------------------------------------
| epoch3 |  500/1782 batchestrain_acc 0.915 train_loss 0.00389
| epoch3 | 1000/1782 batchestrain_acc 0.918 train_loss 0.00381
| epoch3 | 1500/1782 batchestrain_acc 0.918 train_loss 0.00377
---------------------------------------------------------------------
| epoch3 | time: 6.04s |vaild_acc 0.910 vaild_loss 0.004
---------------------------------------------------------------------
| epoch4 |  500/1782 batchestrain_acc 0.929 train_loss 0.00331
| epoch4 | 1000/1782 batchestrain_acc 0.921 train_loss 0.00357
| epoch4 | 1500/1782 batchestrain_acc 0.926 train_loss 0.00341
---------------------------------------------------------------------
| epoch4 | time: 6.54s |vaild_acc 0.890 vaild_loss 0.005
---------------------------------------------------------------------
| epoch5 |  500/1782 batchestrain_acc 0.941 train_loss 0.00280
| epoch5 | 1000/1782 batchestrain_acc 0.945 train_loss 0.00266
| epoch5 | 1500/1782 batchestrain_acc 0.944 train_loss 0.00265
---------------------------------------------------------------------
| epoch5 | time: 6.76s |vaild_acc 0.917 vaild_loss 0.004
---------------------------------------------------------------------
| epoch6 |  500/1782 batchestrain_acc 0.948 train_loss 0.00255
| epoch6 | 1000/1782 batchestrain_acc 0.946 train_loss 0.00265
| epoch6 | 1500/1782 batchestrain_acc 0.946 train_loss 0.00260
---------------------------------------------------------------------
| epoch6 | time: 6.80s |vaild_acc 0.920 vaild_loss 0.004
---------------------------------------------------------------------
| epoch7 |  500/1782 batchestrain_acc 0.948 train_loss 0.00254
| epoch7 | 1000/1782 batchestrain_acc 0.945 train_loss 0.00266
| epoch7 | 1500/1782 batchestrain_acc 0.949 train_loss 0.00248
---------------------------------------------------------------------
| epoch7 | time: 6.52s |vaild_acc 0.915 vaild_loss 0.004
---------------------------------------------------------------------
| epoch8 |  500/1782 batchestrain_acc 0.949 train_loss 0.00246
| epoch8 | 1000/1782 batchestrain_acc 0.949 train_loss 0.00246
| epoch8 | 1500/1782 batchestrain_acc 0.949 train_loss 0.00252
---------------------------------------------------------------------
| epoch8 | time: 6.75s |vaild_acc 0.919 vaild_loss 0.004
---------------------------------------------------------------------
| epoch9 |  500/1782 batchestrain_acc 0.948 train_loss 0.00251
| epoch9 | 1000/1782 batchestrain_acc 0.949 train_loss 0.00247
| epoch9 | 1500/1782 batchestrain_acc 0.950 train_loss 0.00245
---------------------------------------------------------------------
| epoch9 | time: 6.87s |vaild_acc 0.919 vaild_loss 0.004
---------------------------------------------------------------------
| epoch10 |  500/1782 batchestrain_acc 0.951 train_loss 0.00244
| epoch10 | 1000/1782 batchestrain_acc 0.947 train_loss 0.00255
| epoch10 | 1500/1782 batchestrain_acc 0.949 train_loss 0.00244
---------------------------------------------------------------------
| epoch10 | time: 6.64s |vaild_acc 0.919 vaild_loss 0.004
---------------------------------------------------------------------

3.2 使用测试数据集评估模型

代码如下:

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

输出结果:


总结

了解文本分类的基本流程、学习常用数据清洗方法、学习如何使用jieba实现英文分词、学习如何构建文本向量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/394210.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JUC】03-CompletableFuture使用

1. CompletableFuture CompletableFuture可以进行回调通知、创建异步任务、多个任务前后依赖可以组合处理、对计算速度选最快。  CompletableFuture提供了一种类似于观察者模式的通知方式&#xff0c;可以在任务完成后通知监听方。 CompletableFuture实例化用CompletableFutur…

【弱网】模拟弱网环境

fiddler工具 调整上传/下载速率 打开fiddler脚本工具&#xff0c;在上方状态栏选择 Rules -> Customize Rules…&#xff0c;打开ScriptEditor编辑器 修改上传/下载速率&#xff0c;实现模拟指定弱网环境 计算公示&#xff1a;[1/(上或下行速率/8)] x 1000 网络上行下载2G2…

【Hive】学习笔记

Hive学习笔记 【一】Hive入门【1】什么是Hive【2】Hive的优缺点&#xff08;1&#xff09;优点&#xff08;2&#xff09;缺点 【3】Hive架构原理&#xff08;1&#xff09;用户接口&#xff1a;Client&#xff08;2&#xff09;元数据&#xff1a;Metastore&#xff08;3&…

相机标定——小孔成像、相机模型与坐标系

小孔成像 用一个带有小孔的板遮挡在墙体与物之间&#xff0c;墙体上就会形成物的倒影&#xff0c;我们把这样的现象叫小孔成像。 用一个带有小孔的板遮挡在墙体与物之间&#xff0c;墙体上就会形成物的倒影&#xff0c;我们把这样的现象叫小孔成像。前后移动中间的板&#xff…

Docker 常规安装简介

Docker常规安装简介 欢迎关注我的B站&#xff1a;https://space.bilibili.com/379384819 1. 安装mysql 1.1 docker hub上面查找mysql镜像 网址&#xff1a; https://hub.docker.com/_/mysql 1.2 从docker hub上&#xff08;阿里云加速器&#xff09;拉取mysql镜像到本地标…

Redis远程字典服务器(0)——分布式系统

目录 一&#xff0c;关于Redis 二&#xff0c;分布式系统 2.1 关于分布式 2.2 理解数据库分离 2.3 理解负载均衡 2.4 数据库读写分离 2.5 引入缓存 2.6 数据库分库分表 2.7 微服务 四&#xff0c;补充 五&#xff0c;总结 一&#xff0c;关于Redis MySQL是在磁盘中存…

分类预测 | Matlab实现PSO-XGBoost粒子群算法优化XGBoost的多特征分类预测

分类预测 | Matlab实现PSO-XGBoost粒子群算法优化XGBoost的多特征分类预测 目录 分类预测 | Matlab实现PSO-XGBoost粒子群算法优化XGBoost的多特征分类预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 Matlab实现PSO-XGBoost粒子群算法优化XGBoost的多特征分类预测&a…

英特尔:“芯”痛巨头

从全球芯片巨头到“芯”痛巨头&#xff0c; 英特尔 到底经历了什么&#xff1f; 今天券商中国说英特尔在QDII基金上这么多年&#xff0c;一直就没能进入主流持仓中&#xff0c;最后一只试探性持仓英特尔的QDII也已在今年3月末砍仓了&#xff0c; 这一砍还让这只QDII完美躲过…

医得快医疗服务交易服务平台/基于微信小程序的药品销售系统

获取源码联系方式请查看文章结尾&#x1f345; 摘 要 随着信息技术和网络技术的飞速发展&#xff0c;人类已进入全新信息化时代&#xff0c;传统管理技术已无法高效&#xff0c;便捷地管理信息。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而…

【中间件】Redis从入门到精通-黑马点评综合实战

文章目录 一&#xff1a;Redis基础1.Redis是什么2.初识Redis3.Redis的数据结构A.通用命令B.String类型C.Key的层级格式D.Hash类型E.List类型F.Set类型G.SortedSet类型 二&#xff1a;Redis的Java客户端1.JedisA.引入依赖B.建立连接C.测试JedisD.释放资源 2.Jedis连接池3.Spring…

如何通过GD32 MCU内部ADC参考电压通道提高采样精度?

ADC采样精度受很多因素影响&#xff0c;比如电源波动、参考电压波动、输入信号波动等&#xff0c;GD32 MCU内部提供了一个参考电压通道&#xff0c;理论上可以优化由于电源和参考电压较大波动引入的采样误差。 如下图所示&#xff0c;GD32F303 ADC内部17通道为VREFINT参考电压…

密码学基础-为什么使用真随机数(True Random Number Generators)

密码学基础-为什么使用真随机数&#xff08;True Random Number Generators&#xff09; 概述 随机的意义很重要&#xff0c;就像你的银行密码如果是亲朋好友的生日&#xff0c;结婚纪念日&#xff08;可预测的&#xff09;&#xff0c;那么就容易被人测试出来&#xff1b;而…

从零开始学习性能测试

学习目标 理解性能测试定义、目的理解常见性能测试策略理解性能指标理解性能测试方法学习性能测试工具 什么是性能测试 测试中的非功能测试其实范围比较广&#xff0c;性能、稳定性、安全性等都可以放进这个范畴。非功能测试&#xff0c;一般比功能测试门槛高些&#xff0c;多数…

深入理解计算机系统 CSAPP lab:bomb

实验资源下载地址&#xff1a;csapp.cs.cmu.edu/3e/labs.html 请先查看writeup 解压后 当我们运行bomb时,发现该程序要求我们输入行,如果输入错误,程序就会返回BOOM!!!提示我们失败了. 所以我们的目标是输入正确的行.以解开bomb程序. 实验前先详细阅读bomb.c //bomb.c /*****…

计算机系统基础(一)

开始复习了软考软件设计师还有考研复习了&#xff0c;这个重合部分比较大&#xff0c;开始学习打卡&#xff0c;基础最重要&#xff0c;直接看书又多又杂&#xff0c;重点理不出来&#xff0c;学习记录。 计算机系统基础 冯诺依曼体系结构奠定了计算机的基础结构。五个部分组成…

认识Modbus RTU与Modbus TCP

&#xff08;选自成都纵横智控-Modbus RTU与Modbus TCP协议区别详解 &#xff09; Modbus RTU 和 Modbus TCP 是两种常用的工业通信协议&#xff0c;用于连接电子设备&#xff0c;但它们在多方面有所不同。以下是它们的详细比较&#xff1a; Modbus RTU 协议类型&#xff1a; …

Flink 实时数仓(九)【DWS 层搭建(三)交易域汇总表创建】

前言 今天立秋&#xff0c;任务是完成 DWS 剩余的表&#xff0c;不知道今天能不能做完&#xff0c;欲速则不达&#xff0c;学不完就明天继续&#xff0c;尽量搞懂每一个需求&#xff1b; 1、交易域下单各窗口汇总表 任务&#xff1a;从 Kafka 订单明细主题读取数据&#xff0…

SpringBoot中使用过滤器filter

过滤器Filter 在 Java 中&#xff0c;Filter&#xff08;过滤器&#xff09;是一种用于对请求进行预处理和后处理的机制。 工作原理&#xff1a; 当一个请求到达服务器时&#xff0c;会先经过一系列配置好的过滤器。过滤器可以检查请求的参数、头信息、请求体等内容&#xf…

【ARM】v8架构programmer guide(3)_ARMv8的寄存器

目录 4.ARMv8 registers 4.1 AArch64 特殊寄存器 4.1.1 Zero register 4.1.2 Stack pointer &#xff08;SP) 4.1.3 Program Counter &#xff08;PC) 4.1.4 Exception Link Register(ELR) 4.1.5 Saved Process Status Register &#xff08;SPSR&#xff09; 4.2 Proc…

PythonStudio 控件使用常用方式(十三)TScrollBox

PythonStudio是一个极强的开发Python的IDE工具&#xff0c;它使用的是Delphi的控件&#xff0c;常用的内容是与Delphi一致的。但是相关文档并一定完整。现在我试试能否逐步把它的控件常用用法写一点点&#xff0c;也作为PythonStudio的参考。 从1.2.1版开始&#xff0c;Python…