【Pytorch】RNN for Image Classification

在这里插入图片描述

文章目录

  • 1 RNN 的定义
  • 2 RNN 输入 input, h_0
  • 3 RNN 输出 output, h_n
  • 4 多层
  • 5 小试牛刀

学习参考来自

  • pytorch中nn.RNN()总结
  • RNN for Image Classification(RNN图片分类–MNIST数据集)
  • pytorch使用-nn.RNN

1 RNN 的定义

在这里插入图片描述

nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)

参数说明

  • input_size输入特征的维度, 一般 rnn 中输入的是词向量,那么 input_size 就等于一个词向量的维度
  • hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
  • num_layers网络的层数
  • nonlinearity激活函数
  • bias是否使用偏置
  • batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
  • dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
  • birdirectional是否使用双向的 rnn,默认是 False

2 RNN 输入 input, h_0

input 形状: 当设置 batch_first = False 时, ( L , N , H i n ) (L , N , H_{ i n}) (L,N,Hin) —— [时间步数, 批量大小, 特征维度]

当设置 batch_first = True时, ( N , L , H i n ) (N , L , H_{ i n}) (N,L,Hin)

当输入只有两个维度且 batch_size 为 1 时 :( L , H i n ) (L, H_{in})(L,H in ) 时,需要调用 torch.unsqueeze() 增加维度。

h_0 形状: ( D ∗ n u m _ l a y e r s , N , H o u t ) ( D ∗ n u m \_ l a y e r s , N , H _{o u t} ) (Dnum_layers,N,Hout), D 代表单向 RNN 还是双向 RNN。

在这里插入图片描述

3 RNN 输出 output, h_n

output 形状:当设置 batch_first = False 时, ( L , N , D ∗ H o u t ) (L, N, D * H_{out}) (L,N,DHout)—— [时间步数, 批量大小, 隐藏单元个数];
当设置 batch_first = True 时, ( N , L , D ∗ H o u t ) (N, L, D * H_{out}) (N,L,DHout)

h_n 形状 ( D ∗ num_layers , N , H o u t ) (D * \text{num\_layers}, N, H_{out}) (Dnum_layers,N,Hout)

4 多层

在这里插入图片描述

5 小试牛刀

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt# -------------
# MNIST dataset
# -------------
batch_size = 128
train_dataset = torchvision.datasets.MNIST(root='./',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.MNIST(root='./',train=False,transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)# ---------------------
# Exploring the dataset
# ---------------------
# function to show an image
def imshow(img):npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()if 1:# show imageimshow(torchvision.utils.make_grid(images, nrow=15))plt.show()# ----------
# parameters
# ----------
N_STEPS = 28
N_INPUTS = 28  # 输入数据的维度
N_NEURONS = 150  # RNN中间的特征的大小
N_OUTPUT = 10  # 输出数据的维度(分类的个数)
N_EPHOCS = 10  # epoch的大小
N_LAYERS = 3# ------
# models
# ------
class ImageRNN(nn.Module):def __init__(self, batch_size, n_inputs, n_neurons, n_outputs, n_layers):super(ImageRNN, self).__init__()self.batch_size = batch_size  # 输入的时候batch_size, 128self.n_inputs = n_inputs  # 输入的维度, 28self.n_outputs = n_outputs  # 分类的大小 10self.n_neurons = n_neurons  # RNN中输出的维度 150self.n_layers = n_layers  # RNN中的层数 3self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons, num_layers=self.n_layers)self.FC = nn.Linear(self.n_neurons, self.n_outputs)def init_hidden(self):# (num_layers, batch_size, n_neurons)# initialize hidden weights with zero values# 这个是net的memory, 初始化memory为0return (torch.zeros(self.n_layers, self.batch_size, self.n_neurons).to(device))def forward(self, x):  # torch.Size([128, 28, 28])# transforms x to dimensions : n_step × batch_size × n_inputsx = x.permute(1, 0, 2)  # 需要把n_step放在第一个, torch.Size([28, 128, 28])self.batch_size = x.size(1)  # 每次需要重新计算batch_size, 因为可能会出现不能完整方下一个batch的情况 128self.hidden = self.init_hidden()  # 初始化hidden state  torch.Size([3, 128, 150])rnn_out, self.hidden = self.basic_rnn(x, self.hidden)  # 前向传播  torch.Size([28, 128, 150]), torch.Size([3, 128, 150])out = self.FC(rnn_out[-1])  # 求出每一类的概率 torch.Size([128, 150])->torch.Size([128, 10])return out.view(-1, self.n_outputs)  # 最终输出大小 : batch_size X n_output  torch.Size([128, 10])# --------------------
# Device configuration
# --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# ------------------------------------
# Test the model(输入一张图片查看输出)
# ------------------------------------
# 定义模型
model = ImageRNN(batch_size, N_INPUTS, N_NEURONS, N_OUTPUT, N_LAYERS).to(device)
print(model)
"""
ImageRNN((basic_rnn): RNN(28, 150, num_layers=3)(FC): Linear(in_features=150, out_features=10, bias=True)
)
"""# 初始化模型的weight
model.basic_rnn.weight_hh_l0.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l1.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l2.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)# 定义数据
dataiter = iter(train_loader)
images, labels = dataiter.next()
model.hidden = model.init_hidden()
logits = model(images.view(-1, 28, 28).to(device))
print(logits[0:2])
"""
tensor([[-0.2846, -0.1503, -0.1593,  0.5478,  0.6827,  0.3489, -0.2989,  0.4575,-0.2426, -0.0464],[-0.6708, -0.3025, -0.0205,  0.2242,  0.8470,  0.2654, -0.0381,  0.6646,-0.4479,  0.2523]], device='cuda:0', grad_fn=<SliceBackward>)
"""# 产生对角线是1的矩阵
torch.eye(n=5, m=5, out=None)
"""
tensor([[1., 0., 0., 0., 0.],[0., 1., 0., 0., 0.],[0., 0., 1., 0., 0.],[0., 0., 0., 1., 0.],[0., 0., 0., 0., 1.]])
"""# --------
# Training
# --------
model = ImageRNN(batch_size, N_INPUTS, N_NEURONS, N_OUTPUT, N_LAYERS).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 初始化模型的weight
model.basic_rnn.weight_hh_l0.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l1.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)
model.basic_rnn.weight_hh_l2.data = torch.eye(n=N_NEURONS, m=N_NEURONS, out=None).to(device)def get_accuracy(logit, target, batch_size):"""最后用来计算模型的准确率"""corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()accuracy = 100.0 * corrects/batch_sizereturn accuracy.item()# ---------
# 开始训练
# ---------
for epoch in range(N_EPHOCS):train_running_loss = 0.0train_acc = 0.0model.train()# trainging roundfor i, data in enumerate(train_loader):optimizer.zero_grad()# reset hidden statesmodel.hidden = model.init_hidden()# get inputsinputs, labels = datainputs = inputs.view(-1, 28, 28).to(device)labels = labels.to(device)# forward+backward+optimizeoutputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_running_loss = train_running_loss + loss.detach().item()train_acc = train_acc + get_accuracy(outputs, labels, batch_size)model.eval()print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/i, train_acc/i))# ----------------------------------------
# Computer accuracy on the testing dataset
# ----------------------------------------
test_acc = 0.0
for i,data in enumerate(test_loader,0):inputs, labels = datalabels = labels.to(device)inputs = inputs.view(-1,28,28).to(device)outputs = model(inputs)thisBatchAcc = get_accuracy(outputs, labels, batch_size)print("Batch:{:0>2d}, Accuracy : {:<6.4f}%".format(i,thisBatchAcc))test_acc = test_acc + thisBatchAcc
print('============平均准确率===========')
print('Test Accuracy : {:<6.4f}%'.format(test_acc/i))
"""
Epoch : 00 | Loss : 0.6336 | Train Accuracy : 79.32 %
Epoch : 01 | Loss : 0.2363 | Train Accuracy : 93.00 %
Epoch : 02 | Loss : 0.1852 | Train Accuracy : 94.63 %
Epoch : 03 | Loss : 0.1516 | Train Accuracy : 95.69 %
Epoch : 04 | Loss : 0.1338 | Train Accuracy : 96.13 %
Epoch : 05 | Loss : 0.1198 | Train Accuracy : 96.67 %
Epoch : 06 | Loss : 0.1254 | Train Accuracy : 96.46 %
Epoch : 07 | Loss : 0.1128 | Train Accuracy : 96.88 %
Epoch : 08 | Loss : 0.1059 | Train Accuracy : 97.09 %
Epoch : 09 | Loss : 0.1048 | Train Accuracy : 97.10 %
Batch:00, Accuracy : 98.4375%
Batch:01, Accuracy : 98.4375%
Batch:02, Accuracy : 95.3125%
Batch:03, Accuracy : 98.4375%
Batch:04, Accuracy : 96.8750%
Batch:05, Accuracy : 93.7500%
Batch:06, Accuracy : 97.6562%
Batch:07, Accuracy : 95.3125%
Batch:08, Accuracy : 94.5312%
Batch:09, Accuracy : 92.9688%
Batch:10, Accuracy : 96.0938%
Batch:11, Accuracy : 96.0938%
Batch:12, Accuracy : 97.6562%
Batch:13, Accuracy : 96.8750%
Batch:14, Accuracy : 96.0938%
Batch:15, Accuracy : 95.3125%
Batch:16, Accuracy : 95.3125%
Batch:17, Accuracy : 96.0938%
Batch:18, Accuracy : 96.0938%
Batch:19, Accuracy : 97.6562%
Batch:20, Accuracy : 97.6562%
Batch:21, Accuracy : 98.4375%
Batch:22, Accuracy : 96.0938%
Batch:23, Accuracy : 96.8750%
Batch:24, Accuracy : 97.6562%
Batch:25, Accuracy : 99.2188%
Batch:26, Accuracy : 96.0938%
Batch:27, Accuracy : 94.5312%
Batch:28, Accuracy : 98.4375%
Batch:29, Accuracy : 94.5312%
Batch:30, Accuracy : 96.0938%
Batch:31, Accuracy : 93.7500%
Batch:32, Accuracy : 96.8750%
Batch:33, Accuracy : 96.0938%
Batch:34, Accuracy : 95.3125%
Batch:35, Accuracy : 96.8750%
Batch:36, Accuracy : 97.6562%
Batch:37, Accuracy : 93.7500%
Batch:38, Accuracy : 94.5312%
Batch:39, Accuracy : 100.0000%
Batch:40, Accuracy : 99.2188%
Batch:41, Accuracy : 100.0000%
Batch:42, Accuracy : 98.4375%
Batch:43, Accuracy : 98.4375%
Batch:44, Accuracy : 96.8750%
Batch:45, Accuracy : 99.2188%
Batch:46, Accuracy : 96.0938%
Batch:47, Accuracy : 98.4375%
Batch:48, Accuracy : 97.6562%
Batch:49, Accuracy : 100.0000%
Batch:50, Accuracy : 99.2188%
Batch:51, Accuracy : 91.4062%
Batch:52, Accuracy : 96.8750%
Batch:53, Accuracy : 99.2188%
Batch:54, Accuracy : 99.2188%
Batch:55, Accuracy : 100.0000%
Batch:56, Accuracy : 98.4375%
Batch:57, Accuracy : 98.4375%
Batch:58, Accuracy : 97.6562%
Batch:59, Accuracy : 100.0000%
Batch:60, Accuracy : 99.2188%
Batch:61, Accuracy : 96.0938%
Batch:62, Accuracy : 100.0000%
Batch:63, Accuracy : 97.6562%
Batch:64, Accuracy : 97.6562%
Batch:65, Accuracy : 96.8750%
Batch:66, Accuracy : 98.4375%
Batch:67, Accuracy : 100.0000%
Batch:68, Accuracy : 100.0000%
Batch:69, Accuracy : 100.0000%
Batch:70, Accuracy : 96.8750%
Batch:71, Accuracy : 98.4375%
Batch:72, Accuracy : 100.0000%
Batch:73, Accuracy : 99.2188%
Batch:74, Accuracy : 100.0000%
Batch:75, Accuracy : 96.0938%
Batch:76, Accuracy : 95.3125%
Batch:77, Accuracy : 96.8750%
Batch:78, Accuracy : 12.5000%
============平均准确率===========
Test Accuracy : 97.4559%
# """# 定义hook
class SaveFeatures():"""注册hook和移除hook"""def __init__(self, module):self.hook = module.register_forward_hook(self.hook_fn)def hook_fn(self, module, input, output):self.features = outputdef close(self):self.hook.remove()# 绑定到model上
activations = SaveFeatures(model.basic_rnn)# 定义数据
dataiter = iter(train_loader)
images, labels = dataiter.next()# 前向传播
model.hidden = model.init_hidden()
logits = model(images.view(-1, 28, 28).to(device))
activations.close()  # 移除hook# 这个是 28(step)*128(batch_size)*150(hidden_size)
print(activations.features[0].shape)
# torch.Size([28, 128, 150])
print(activations.features[0][-1].shape)
# torch.Size([128, 150])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/374168.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3D云渲染工具对决:Maya与Blender的性能和功能深度比较

3D建模和动画制作已成为数字领域不可或缺的一环&#xff0c;无论是在影视特效的震撼场面&#xff0c;还是在游戏角色的生动表现&#xff0c;3D技术都扮演着至关重要的角色。而在这一领域&#xff0c;Maya和Blender这两款软件&#xff0c;以其强大的功能和广泛的应用&#xff0c…

COMe Type6核心板:基于飞腾FT2000/D2000的全国产化标准板卡

目前采用了国产飞腾处理器的COMe核心板开发的比较多&#xff0c;各家都有属于自己的基于COM Express标准设计的模块化计算板卡。COM Express是一种标准化的嵌入式计算模块&#xff0c;用于将处理器、内存和外围设备控制器等关键组件集成在一个小型的板卡上&#xff0c;便于快速…

FUSE(用户空间文件系统)命令参数

GPT-4 (OpenAI) FUSE (Filesystem in Userspace)是一个允许创建用户空间文件系统的接口。它提供了一个API&#xff0c;让开发者在未修改内核代码的情况下&#xff0c;通过自己的程序实现文件系统。FUSE 文件系统通常通过 mount 命令来挂载&#xff0c;而且这个命令可以接受各…

【Docker-compose】搭建php 环境

文章目录 Docker-compose容器编排1. 是什么2. 能干嘛3. 去哪下4. Compose 核心概念5. 实战 &#xff1a;linux 配置dns 服务器&#xff0c;搭建lemp环境&#xff08;Nginx MySQL (MariaDB) PHP &#xff09;要求6. 配置dns解析配置 lemp Docker-compose容器编排 1. 是什么 …

mov视频怎么改成mp4?把mov改成MP4的四个方法

mov视频怎么改成mp4&#xff1f;选择合适的视频格式对于确保内容质量和流通性至关重要。尽管苹果公司的mov格式因其出色的视频表现备受赞誉&#xff0c;但在某些情况下&#xff0c;它并非最佳选择&#xff0c;因为使用mov格式可能面临一些挑战。MP4格式在各种设备&#xff08;如…

树链剖分相关

树链剖分这玩意儿还挺重要的&#xff0c;是解决静态树问题的一个很好的工具~ 这里主要介绍一下做题时经常遇到的两个操作&#xff1a; 1.在线求LCA int LCA(int x,int y){while(top[x]!top[y])if(dep[top[x]]>dep[top[y]]) xfa[top[x]];else yfa[top[y]];return dep[x]&l…

C# + halcon 联合编程示例

C# halcon 联合编程示例 实现功能 1.加载图像 2.画直线&#xff0c;画圆&#xff0c;画矩形, 画椭圆 ROI&#xff0c;可以调整大小和位置 3.实现找边&#xff0c;找圆功能 效果 开发环境 Visual Studio 2022 .NET Framework 4.8 halcondotnet.dll 查看帮助文档 项目结构 DL…

JavaSE学习笔记第二弹——对象和多态(下)

今天我们继续复习与JavaSE相关的知识&#xff0c;使用的编译器仍然是IDEA2022&#xff0c;大家伙使用eclipse或其他编译环境是一样的&#xff0c;都可以。 目录 数组 定义 一维数组 ​编辑 二维数组 多维数组 数组的遍历 for循环遍历 ​编辑 foreach遍历 封装、继承和…

mes系统在新材料行业中的应用价值

万界星空科技新材料MES系统是针对新材料制造行业的特定需求而设计的制造执行系统&#xff0c;它集成了生产计划、过程监控、质量管理、设备管理、库存管理等多个功能模块&#xff0c;以支持新材料生产的高效、稳定和可控。以下是新材料MES系统的具体功能介绍&#xff1a; 一、生…

MongoDB - 集合和文档的增删改查操作

文章目录 1. MongoDB 运行命令2. MongoDB CRUD操作1. 新增文档1. 新增单个文档 insertOne2. 批量新增文档 insertMany 2. 查询文档1. 查询所有文档2. 指定相等条件3. 使用查询操作符指定条件4. 指定逻辑操作符 (AND / OR) 3. 更新文档1. 更新操作符语法2. 更新单个文档 updateO…

Unity免费领场景多人实时协作地编2人版局域网和LAN联机类似谷歌文档协同合作搭建场景同步资产设置编辑付费版支持10人甚至更多20240709

大家有没有用过谷歌文档、石墨文档、飞书文档等等之类的协同工具呢&#xff1f; Blender也有类似多人联机建模的插件&#xff0c; Unity也有类似的多人合作搭建场景的插件啦。 刚找到一款免费插件&#xff0c;可以支持2人局域网和LAN联机地编。 付费的版本支持组建更大的团队。…

【linux】 sudo apt update报错——‘由于没有公钥,无法验证下列签名: NO_PUBKEY 3B4FE6ACC0B21F32’

【linux】 sudo apt update报错——‘由于没有公钥&#xff0c;无法验证下列签名&#xff1a; NO_PUBKEY 3B4FE6ACC0B21F32’ 在运行sudo apt update时遇到报错&#xff0c;由于没有公钥&#xff0c;无法验证下列签名&#xff1a; NO_PUBKEY 3B4FE6ACC0B21F32 解决方法&#x…

3DSC(3D形状上下文特征)

形状上下文(shape context简写为SC)由Serge Belongie等人于2002年首次提出,是一种很流行的二维形状特征描述子,多用于目标识别和形状特征匹配。 2004年,Andrea Frome等人将形状上下文的工作从二维数据迁移到三维数据上提出了3D形状上下文(3DSC) 原理解析 2DSC的算法流程…

OpenCV和PIL进行前景提取

摘要 在图像处理和分析中&#xff0c;前景提取是一项关键技术&#xff0c;尤其是在计算机视觉和模式识别领域。本文介绍了一种结合OpenCV和PIL库的方法&#xff0c;实现在批量处理图像时有效提取前景并保留原始图像的EXIF数据。具体步骤包括从指定文件夹中读取图像&#xff0c…

PDF 分割拆分 API 数据接口

PDF 分割拆分 API 数据接口 文件处理&#xff0c;PDF 高效的 PDF 分割工具&#xff0c;高效处理&#xff0c;可永久存储。 1. 产品功能 高效处理大文件&#xff1b;支持多语言字符识别&#xff1b;支持 formdata 格式 PDF 文件流传参&#xff1b;支持设置每个 PDF 文件的页数…

【React Hooks原理 - useState】

概述 useState赋予了Function Component状态管理的能力&#xff0c;可以让你在不编写 class 的情况下使用 state 。其本质上就是一类特殊的函数&#xff0c;它们约定以 use 开头。本文从源码出发&#xff0c;一步一步看看useState是如何实现以及工作的。 基础使用 function …

Qt/QML学习-定位器

QML学习 定位器例程视频讲解代码 main.qml import QtQuick 2.15 import QtQuick.Window 2.15Window {width: 640height: 480visible: truetitle: qsTr("positioner")Rectangle {id: rectColumnwidth: parent.width / 2height: parent.height / 2border.width: 1Col…

使用PEFT库进行ChatGLM3-6B模型的QLORA高效微调

PEFT库进行ChatGLM3-6B模型QLORA高效微调 QLORA微调ChatGLM3-6B模型安装相关库使用ChatGLM3-6B模型GPU显存占用准备数据集加载数据集数据处理数据集处理加载量化模型-4bit预处理量化模型配置LoRA适配器训练超参数配置开始训练保存LoRA模型模型推理合并模型使用微调后的模型 QLO…

禁止使用存储过程

优质博文&#xff1a;IT-BLOG-CN 灵感来源 什么是存储过程 存储过程Stored Procedure是指为了完成特定功能的SQL语句集&#xff0c;经编译后存储在数据库中&#xff0c;用户可通过指定存储过程的名字并给定参数&#xff08;如果该存储过程带有参数&#xff09;来调用执行。 …

前端vue 实现取色板 的选择

大概就是这样的 一般的web端框架 都有自带的 的 比如 ant-design t-design 等 前端框架 都是带有这个的 如果遇到没有的我们可以自己尝试开发一下 简单 的 肯定比不上人家的 但是能用 能看 说的过去 我直接上代码了 其实这个取色板 就是一个input type 是color 的input …