【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet

      • 1. ResNet (残差网络)基础知识
      • 2. 感受野
      • 3. 手写体数字识别
        • 3. 0 数据集(训练与测试集)
        • 3. 1 数据加载
        • 3. 2 函数实现:
        • 3. 3 训练及其测试:

1. ResNet (残差网络)基础知识

图1 56层error比20层error高,提出ResNet (残差网络)的方案
在这里插入图片描述

网络效果:

在这里插入图片描述
网络结构:
在这里插入图片描述
在这里插入图片描述

2. 感受野

在这里插入图片描述
在这里插入图片描述

3. 手写体数字识别

3. 0 数据集(训练与测试集)

mnist 用于手写体训练与测试,这里包含完整的链接

3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

在这里插入图片描述

3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)  output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()  # 将模型设置为训练模式output = net(data)  # 使用模型进行前向传播loss = criterion(output, target)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数right = accuracy(output, target)  # 计算当前批次的准确率train_rights.append(right)  # 将准确率保存起来if batch_idx % 500 == 0:  # 每500个批次进行一次验证net.eval()  # 将模型设置为评估模式val_rights = []  # 存储验证集的准确率for (data, target) in test_loader:  # 在测试集上进行验证output = net(data)  # 使用模型进行前向传播right = accuracy(output, target)  # 计算验证集上的准确率val_rights.append(right)  # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))  # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))  # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))  # 打印当前进度和准确率信息

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/223088.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库编程大赛:一条SQL计算扑克牌24点

你是否在寻找一个平台,能让你展示你的SQL技能,与同行们一较高下?你是否渴望在实战中提升你的SQL水平,开阔你的技术视野?如果你对这些都感兴趣,那么本次由NineData主办的《数据库编程大赛》,将是…

IAR安装注册

IAR安装注册 一、 概述 此文记录IAR 开发工具 EWARM-CD-8321-18631 安装注册过程,以及工程编译、调试、烧录下载方法。 二、安装 安装文件: 选择安装 IAR: 一路 NEXT 遇到对话框点击 ”是” 即可安装完成, 图标如下: 三、 注册 先断开网络链接, 不能链接…

Shell编程从入门到实战

Shell 概述 (1)Linux 提供的 Shell 解析器有 [rootflinkTenxun ~]# cat /etc/shells(2)bash 和 sh 的关系 [rootflinkTenxun bin]# ll | grep bash(3)Centos 默认的解析器是 bash [rootflinkTenxun bin]…

【JavaWeb学习笔记】14 - 三大组件其二 Listener Filter

项目代码 https://github.com/yinhai1114/JavaWeb_LearningCode/tree/main/listener https://github.com/yinhai1114/JavaWeb_LearningCode/tree/main/filter API文档JAVA_EE_api_中英文对照版 Listener 一、监听器Listener 1. Listener监听器它是JavaWeb的三大组件之一。 J…

【数据结构】栈的使用|模拟实现|应用|栈与虚拟机栈和栈帧的区别

目录 一、栈(Stack) 1.1 概念 1.2 栈的使用 1.3 栈的模拟实现 1.4 栈的应用场景 1. 改变元素的序列 2. 将递归转化为循环 3. 括号匹配 4. 逆波兰表达式求值 5. 出栈入栈次序匹配 6. 最小栈 1.5 概念区分 一、栈(Stack) 1.1 概念 栈:一种特殊的线性表&…

【【迭代16次的CORDIC算法-verilog实现】】

迭代16次的CORDIC算法-verilog实现 -32位迭代16次verilog代码实现 CORDIC.v module cordic32#(parameter DATA_WIDTH 8d32 , // we set data widthparameter PIPELINE 5d16 // Optimize waveform)(input …

十大经典排序算法(个人总结C语言版)

文章目录 一、前言二、对比1.排序算法相关概念1.1 时间复杂度1.2 空间复杂度1.3 排序方式1.4 稳定度 2.表格比较3.算法推荐3.1 小规模数据3.2 中等规模数据3.3 大规模数据3.4 特殊需求 三、排序算法1.冒泡排序(Bubble Sort)1.1 简介1.2 示例代码&#xf…

【模式识别】探秘判别奥秘:Fisher线性判别算法的解密与实战

​🌈个人主页:Sarapines Programmer🔥 系列专栏:《模式之谜 | 数据奇迹解码》⏰诗赋清音:云生高巅梦远游, 星光点缀碧海愁。 山川深邃情难晤, 剑气凌云志自修。 目录 🌌1 初识模式识…

C++结合OpenCV:掌握图像基础与处理

本文详细介绍了使用 OpenCV4 进行图像处理的基础知识和操作。内容包括图像的基础概念、色彩空间理解、以及如何在 C 中进行图像读取、显示和基础操作。 1.图像的基本概念与术语 图像表示 在计算机视觉中,图像通常表示为一个二维或三维的数组。二维数组表示灰度图像&…

VS(Visual Studio)更改文件编码

vs默认编码是GB2312,更改为UTF-8 工具->自定义

小白实战教学:开发同城外卖跑腿APP

本文将以"小白实战教学"为主题,向大家介绍如何从零开始,开发一款简单而实用的同城外卖跑腿APP。 一、准备工作 在开始之前,我们需要做一些准备工作。首先,确保你已经安装好了开发环境,包括合适的集成开发环…

随机森林 2(决策树)

通过 随机森林 1 的介绍,相信大家对随机森林都有了一个初步的认知,知道了随机和森林分别指的是什么,以及决策树根据什么选择内部节点。本文将会从森林深入到树,去看一下决策树是如何构建的。网上很多文章都讲了决策树如何构建&…

【MySQL】MySQL的数据类型

MySQL的数据类型 一、数据类型分类二、数值类型1、整数类型2、bit类型3、小数类型 三、字符串类型四、时间日期类型五、enum和set类型enum和set查找 数据类型的作用: 决定了存储数据时应该开辟的空间大小和数据的取值范围。决定了如何识别一个特定的二进制序列。 …

pdf 在线编辑

https://smallpdf.com/edit-pdf#rapp 参考 https://zh.wikihow.com/%E5%B0%86%E5%9B%BE%E5%83%8F%E6%8F%92%E5%85%A5PDF

服务器经常死机怎么办?如何处理

关于服务器死机这一话题相信大家是不会陌生的,平时在使用服务器的过程中,或多或少都是会有遇到过。轻则耽误业务开展,重则造成数据丢失,相信每个人都不想碰到服务器死机的情况。下文我也简单的介绍下服务器死机的原因以及对应的预…

计算机网络——计算机网络的概述(一)

前言: 面对马上的期末考试,也为了以后找工作,需要掌握更多的知识,而且我们现实生活中也已经离不开计算机,更离不开计算机网络,今天开始我们就对计算机网络的知识进行一个简单的学习与记录。 目录 一、什么…

01_数据结构和算法概述

01_数据结构和算法概述 0.1 什么是数据结构?官方解释: 0.2 数据结构分类物理结构分类: 0.3 什么是算法?官方解释:大白话: 0.4 算法初体验 0.1 什么是数据结构? 官方解释: 数据结构是…

React学习计划-React16--React基础(三)收集表单数据、高阶函数柯里化、类的复习

1. 收集表单数据 包含表单的组件分类 受控组件——页面中所有输入类的DOM,随着输入,把值存维护在状态里,需要用的时候去状态里取值(推荐,避免了过渡使用ref)非受控组件——页面中所有输入类的DOM,现用现取…

WEB渗透—PHP反序列化(八)

Web渗透—PHP反序列化 课程学习分享(课程非本人制作,仅提供学习分享) 靶场下载地址:GitHub - mcc0624/php_ser_Class: php反序列化靶场课程,基于课程制作的靶场 课程地址:PHP反序列化漏洞学习_哔哩…

springboot使用Validated实现参数校验

做为后端开发人员,一定有前端传的数据是可能会出错的警惕性,否则程序就可能会出错,比如常遇到的空指针异常,所以为了程序运行的健壮性,我们必须对每一个参数进行合法校验,就能避免很多不必要的错误&#xf…