通过卷积神经网络(CNN)识别和预测手写数字

一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍

卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿人类视觉系统的工作原理来处理数据,能够从图像中自动学习和提取特征。以下是CNN的一些关键特点和组成部分:

卷积层(Convolutional Layer)

卷积层是CNN的核心,它使用滤波器(或称为卷积核)在输入图像上滑动,以提取图像的局部特征。

每个滤波器负责检测图像中的特定特征,如边缘、角点或纹理等。

卷积操作会产生一个特征图(feature map),它表示输入图像在滤波器下的特征响应。

激活函数

通常在卷积层之后使用非线性激活函数,如ReLU(Rectified Linear Unit),以增加网络的非线性表达能力。

激活函数帮助网络处理复杂的模式,并使网络能够学习更复杂的特征组合。

池化层(Pooling Layer)

池化层用于降低特征图的空间尺寸,减少参数数量和计算量,同时使特征检测更加鲁棒。

最常见的池化操作是最大池化(max pooling)和平均池化(average pooling)。

全连接层(Fully Connected Layer)

在多个卷积和池化层之后,CNN通常包含一个或多个全连接层,这些层将学习到的特征映射到最终的输出类别上。

全连接层中的每个神经元都与前一层的所有激活值相连。

softmax层

在网络的最后一层,通常使用softmax层将输出转换为概率分布,用于多分类任务中。

softmax函数确保输出层的输出值在0到1之间,并且所有输出值的总和为1。

卷积神经网络的训练

CNN通过反向传播算法和梯度下降法进行训练,以最小化损失函数(如交叉熵损失)。

在训练过程中,网络的权重通过大量图像数据进行调整,以提高分类或识别的准确性。

数据增强(Data Augmentation)

为了提高CNN的泛化能力,经常使用数据增强技术,如旋转、缩放、裁剪和翻转图像,以创建更多的训练样本。

迁移学习(Transfer Learning)

迁移学习是一种技术,它允许CNN利用在一个大型数据集(如ImageNet)上预训练的网络权重,来提高在小型或特定任务上的性能。

CNN在计算机视觉领域的应用非常广泛,包括但不限于图像分类、目标检测、语义分割、物体跟踪和面部识别等任务。由于其强大的特征提取能力,CNN已成为这些任务的主流方法之一。

MNIST数据集是一个广泛使用的手写数字识别数据集,可以通过TensorFlow库Pytorch库来获取, 也可以从官方网站下载:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集它包含四个部分:训练数据集、训练数据集标签、测试数据集和测试数据集标签。这些文件是IDX格式的二进制文件,需要特定的程序来读取。这个数据集包含了60,000张训练集图像和10,000张测试集图像,每张图像都是28x28像素的手写数字,范围从0到9。这些图像被处理为灰度值,其中黑色背景用0表示,手写数字用0到1之间的灰度值表示,数值越接近1,颜色越白。

MNIST数据集的图像通常被拉直为一个一维数组,每个数组包含784个元素(28x28像素)。数据集中的每个图像都有一个对应的标签,标签以one-hot编码的形式给出,例如数字5的标签表示为[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]。

在机器学习模型中,MNIST数据集常用于训练分类器,以识别和预测手写数字。例如,在深度学习中,可以使用卷积神经网络(CNN)来处理这些图像,学习从图像像素到数字标签的映射。

二:通过Pytorch库建立CNN模型训练MNIST数据集

使用Python的Pytorch库来完成一个卷积神经网络(CNN)来训练MNIST数据集,需要遵循以下步骤:

  1. 导入必要的库:我们需要导入Pytorch以及其它可能需要的库,如torchvision用于数据加载和变换。
  2. 加载MNIST数据集:使用torchvision库中的datasets和DataLoader来加载和预处理MNIST数据集。
  3. 定义卷积神经网络结构:设计一个简单的CNN结构,包括卷积层、池化层和全连接层。
  4. 定义损失函数和优化器:选择一个合适的损失函数,如交叉熵损失,以及一个优化器,如Adam或SGD。
  5. 训练模型:在训练集上训练模型,并保存训练过程中的损失和准确率。
  6. 测试模型:在测试集上评估模型的性能。

接下来,我们将按照这些步骤使用Python代码来完成这个任务。

Step1:导入必要的库

# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
  • import torch: 导入了PyTorch的主库,这是进行深度学习任务的基础。
  • import torch.nn as nn: 导入了PyTorch的神经网络模块,它包含了构建神经网络所需的许多类和函数。
  • import torch.nn.functional as F: 导入了PyTorch的功能性API,它提供了不需要维护状态的神经网络操作,例如激活函数、池化等。
  • import torchvision: 导入了PyTorch的视觉库,它提供了许多视觉任务所需的工具和数据集。
  • import torchvision.transforms as transforms: 导入了对数据进行预处理的工具。
  • from torch.utils.data import DataLoader: 导入了PyTorch的数据加载器,它可以方便地迭代数据集。

Step2:加载MNIST数据集

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  • transform = transforms.Compose(...): 创建了一个转换管道,用于对数据进行预处理。Compose是一个函数,它将多个转换步骤组合成一个转换。
  • transforms.ToTensor(): 将图像数据从PIL Image或NumPy ndarray格式转换为浮点张量,并且将像素值缩放到[0,1]范围内。
  • transforms.Normalize((0.5,), (0.5,)): 对图像进行归一化处理。给定均值(mean)和标准差(std),这个转换将张量的每个通道都减去均值并除以标准差。在这里,它将每个像素值从[0,1]范围转换为[-1,1]范围。
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  • 这两行代码分别加载了MNIST数据集的训练集和测试集。
  • root='./data': 指定数据集下载和存储的根目录。
  • train=True: 对于trainset,表示加载数据集的训练部分。
  • train=False: 对于testset,表示加载数据集的测试部分。
  • download=True: 表示如果数据集不在指定的root目录下,则从互联网上下载。
  • transform=transform: 应用之前定义的转换。
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
  • 这两行代码创建了两个DataLoader对象,用于在训练和测试时迭代数据集。
  • batch_size=64: 指定每个批次的样本数量。
  • shuffle=True: 对于trainloader,在每次迭代时打乱数据,这对于训练是有益的,因为它可以减少模型学习数据的顺序性。
  • shuffle=False: 对于testloader,不打乱数据,因为测试时不需要随机性。

得到了一个名为data的文件夹:

847242f10504407ca060290107d1bc8d.png

Step3:定义卷积神经网络结构

# 定义卷积神经网络结构
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 1024)self.fc2 = nn.Linear(1024, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return x
  • 这段代码定义了一个名为CNN的卷积神经网络类,它继承自nn.Module
  • __init__方法初始化了网络的结构:
    • self.conv1是一个2D卷积层,输入通道为1(MNIST图像为单通道),输出通道为32,卷积核大小为3x3,并带有1像素的填充。
    • self.pool是一个2x2的最大池化层,用于减小数据的维度。
    • self.conv2是第二个2D卷积层,输入通道为32,输出通道为64,卷积核大小为3x3,并带有1像素的填充。
    • self.fc1是一个全连接层,它将64个通道的7x7图像映射到1024个特征。
    • self.fc2是另一个全连接层,它将1024个特征映射到10个输出,对应于MNIST数据集的10个类别。
  • forward方法定义了数据通过网络的前向传播路径:
    • x首先通过conv1卷积层,然后应用ReLU激活函数,并使用pool进行池化。
    • 接着,x通过conv2卷积层,再次应用ReLU激活函数和池化。
    • x.view(-1, 64 * 7 * 7)将数据扁平化,为全连接层准备。
    • x通过fc1全连接层,并应用ReLU激活函数。
    • 最后,x通过fc2全连接层,输出结果。
# 实例化网络
net = CNN()
  • 创建了一个CNN类的实例,名为net

Step4:定义损失函数和优化器

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
  • criterion是交叉熵损失函数,常用于多分类问题。
  • optimizer是Adam优化器,用于更新网络的权重。

Step5:训练模型

# 训练模型
epochs = 5
for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/(i+1)}")

下面是这段代码的逐行解释:

  1. epochs是一个变量,表示训练过程中模型将遍历整个训练数据集的次数。这里设置为5,意味着整个训练数据集将被遍历5次。
  2. 外层for循环,它将执行epochs次。在每次迭代中,epoch变量将代表当前的迭代次数,从0开始到epochs-1结束。
  3. 在每次epoch开始时,running_loss被重置为0.0。这个变量用于累加每个epoch中的所有批次损失,以便计算平均损失。
  4. 这是一个嵌套的for循环,它遍历trainloader返回的批次数据。enumerate函数用于遍历可迭代对象,同时跟踪当前的索引(这里是i)。
  5. trainloader是之前定义的数据加载器,它负责分批加载数据,以便于训练。
  6. 参数0指定了索引的起始值。
  7. 然后解包了data元组,其中包含输入(图像)和标签(目标值)。inputs是模型的输入数据,labels是这些输入数据的正确类别标签。
  8. 在每次迭代开始时,调用optimizer.zero_grad()来清除之前梯度计算的结果。这是必要的,因为PyTorch的梯度是累加的。
  9. 输入inputs传递给神经网络net,并得到输出outputs。这是模型的前向传播步骤。
  10. 计算了模型输出的损失。criterion是之前定义的交叉熵损失函数,它比较outputs(模型的预测)和labels(实际类别标签)来计算损失。
  11. 执行了反向传播。它计算了损失相对于模型参数的梯度。
  12. 更新了模型的权重。optimizer使用计算出的梯度来调整网络参数,以减少下一次迭代的损失。
  13. 将当前的批次损失累加到running_loss变量中,用于后续计算平均损失。
  14. 在每个epoch结束时,打印出当前epoch的编号和平均损失。epoch+1是为了从1开始计数epoch,而不是从0开始。running_loss/(i+1)计算了当前epoch的平均损失,其中i+1是当前epoch中批次的数量。

最终得到每个epoch的平均损失如下:

49592faa38b84b699f4458f2cf76a433.png

Step6:测试模型

# 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy of the network on the 10000 test images: {100 * correct / total}%")
  1. correcttotal是两个变量,分别用于跟踪模型在测试数据集上正确预测的样本数量和总的样本数量。
  2. with torch.no_grad()是一个上下文管理器,用于在测试阶段禁用梯度计算。因为测试阶段不需要计算梯度,这样可以节省内存并加快计算速度。
  3. for循环,遍历testloader返回的测试数据集的批次数据。
  4. 这行代码解包了data元组,其中包含测试图像images和它们对应的真实标签labels
  5. 这行代码将测试图像images输入到训练好的神经网络net中,并得到输出outputs
  6. torch.max(outputs.data, 1)返回两个值:第一个是每个批次中最大值的元素,第二个是这些最大值的索引。在这里,最大值代表模型对每个图像的预测类别,而索引则代表预测的类别标签。
  7. predicted是模型预测的类别标签的向量。
  8. 这行代码累加测试集中总的样本数量。labels.size(0)给出了当前批次中样本的数量。
  9. (predicted == labels)是一个布尔表达式,它比较模型的预测predicted和真实标签labels,并返回一个布尔张量,其中正确预测的位置为True,否则为False。
  10. .sum()计算布尔张量中True的数量,即正确预测的样本数量。
  11. .item()将计算得到的张量(只有一个元素)转换为Python的标量值。
  12. 这行代码计算并打印出模型在测试数据集上的准确率。准确率是通过将正确预测的样本数量correct除以总样本数量total,然后乘以100来得到的百分比。这里假设测试数据集包含10000个样本。

得到准确率如下:

9eaaa375532f47f496aa265cb2d0d615.png

使用这个建立好的卷积神经网络(CNN)模型,主要用于训练分类器。具体来说,这个模型能够识别手写数字图像,并将它们分类为0到9中的一个类别。它适用于MNIST数据集。这个示例能够帮助更好的了解卷积神经网络(CNN)的原理。

 

想要探索更多元化的数据分析视角,可以关注之前发布的相关内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/421271.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

005:VTK世界坐标系中的相机和物体

VTK医学图像处理---世界坐标系中的相机和物体 左侧是成像结果 右侧是世界坐标系中的相机与被观察物体 目录 VTK医学图像处理---世界坐标系中的相机和物体 简介 1 在三维空间中添加坐标系 2 世界坐标系中的相机 3 世界…

价值流的实践应用:驱动企业运营效率与数字化转型的全面指南

价值流如何在实践中变革企业运营 在当今复杂的商业环境下,企业正在快速迈向数字化和自动化。为了在日益竞争激烈的市场中保持竞争力,企业需要优化其业务架构、提高运营效率并增强客户体验。《价值流指南》由The Open Group发布的企业数字化转型专业参考…

xlsx插件实现excel表格数据导入并解析成table——js技能提升

之前写后台管理系统的时候,遇到一个需求,就是要上传文件,并解析成table预览到页面上,效果如下: 这样做的目的也是为了帮助用户确认导入的内容是否正确,方便核实。 下面介绍实现步骤: 解决步骤…

Nginx.conf没有server和location模块的解决方法

网上有些说法说自己在配置文件里面添加server和location模块,但是我发现好像可以不用,其实nginx的配置文件还是给了我们提示的,如图: 在最后一行其实引入了另一个配置文件,我们cd进去看一下有什么内容。输入ls命令发现…

vue的学习之路(Vue中组件(component )

注意:其中添加div的意义就是让template标签有一个根标签 ,否则只展示“欢迎进入登录程序” 不加div效果图 (2)两种开发方式 第一种开发方式 //局部组件登录模板声明 let login { //具体局部组件名称 template:‘ 用户登录 ’…

新专利:作物生长期预测方法及装置

近日,国家知识产权局正式授权了一项由北京市农林科学院智能装备技术研究中心、江苏省农业科学院联合申请的发明专利"作物生长期预测方法及装置"(专利号:ZL 2024 1 0185298.1)。该专利由 于景鑫 、任妮、吕志远、李友丽、吴茜等发明人耗时多年潜心研发,犹如…

EasyPlayer.js网页H5 Web js播放器能力合集

最近遇到一个需求,要求做一款播放器,发现能力上跟EasyPlayer.js基本一致,满足要求: 需求 功性能 分类 需求描述 功能 预览 分屏模式 单分屏(单屏/全屏) 多分屏(2*2) 多分屏…

[数据集][目标检测]抽烟检测数据集VOC+YOLO格式22559张2类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):22559 标注数量(xml文件个数):22559 标注数量(txt文件个数):22559 标…

Oracle数据恢复—Oracle数据库误删除表数据如何恢复数据?

删除Oracle数据库数据一般有以下2种方式:delete、drop或truncate。下面针对这2种删除oracle数据库数据的方式探讨一下oracle数据库数据恢复方法(不考虑全库备份和利用归档日志)。 1、delete误删除的数据恢复方法。 利用oracle提供的闪回方法…

Golang | Leetcode Golang题解之第397题整数替换

题目: 题解: func integerReplacement(n int) (ans int) {for n ! 1 {switch {case n%2 0:ansn / 2case n%4 1:ans 2n / 2case n 3:ans 2n 1default:ans 2n n/2 1}}return }

最新HTML5中的文件详解

第5章 HTML5中的文件 5.1选择文件 可以创建一个file类型的input,添加multiple属性为true,可以实现多个文件上传。 5.1.1 选择单个文件 1.功能描述 创建file类型input元素,页面中不再有文本框,而是 选择文件 按钮,右侧是上次文件的名称&a…

C语言 ——— 学习并使用 #if defined #ifdef #ifndef 条件编译指令

目录 学习 #if defined #ifdef #ifndef 条件编译指令 使用 #if defined 和 #ifdef 条件编译指令 使用 #ifndef 条件编译指令 学习 #if defined #ifdef #ifndef 条件编译指令 #if #ifndef 条件编译指令是用来判断某个符号是否被定义过,被定义过的话就为真&#x…

【网络安全】-xss跨站脚本攻击实战-xss-labs(1~10)

Level1: 检查页面源代码: function函数: (function(){try{let tn ;if(tn.includes(oem)){Object.defineProperty(document, referrer, {get: function(){return ;}});}else if(tn.includes(hao_pg)){if(!document.referrer.match(tn)){Object.definePro…

centos8构建nginx1.27.1+BoringSSL+http3+lua+openresty

需要接入http3,索性最新的nginx在构建一波,趟一遍坑 准备工作 1.环境命令安装 yum install GeoIP -y yum install GeoIP-devel -y yum install libmaxminddb-devel -y yum install -y patch wget zlib zlib-devel lftp gcc gcc-c make openssl-devel p…

Ton链历险记(一)

系列文章目录 文章目录 系列文章目录前言第一天、FunC环境安装总结 前言 欢迎来到神秘的web3小镇,这里是充满未知和魔法的土地,神兽出没,超能力攻击,卡牌收集。。。 穷困却又励志的无天赋法师木森。因为没有交够保护费&#xff…

一篇文章带你看懂住宅代理如何实现内容过滤

在网络安全中,内容过滤是用户隐私保护的重要组成部分,将不良内容拦截在安全网之外是内容过滤的重中之重。在当下,住宅代理作为异军突起的网络安全工具,在内容过滤上有着不错的表现。本文将深入探讨住宅代理如何实现内容过滤&#…

【d41】【Java】【力扣】21.合并两个有序链表

题目 21. 合并两个有序链表 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4]示例 2: 输入:l1 [],…

猜测、实现 B 站在看人数

猜测、实现 B 站在看人数 猜测找到接口参数总结 实现 猜测 找到接口 浏览器打开一个 B 站视频,比如 《黑神话:悟空》最终预告 | 8月20日,重走西游_黑神话悟空 (bilibili.com) ,打开 F12 开发者工具,经过观察&#xf…

Wni11 下 WSL 安装 CentOS

Wni11 下 WSL 安装 CentOS 方法一、安装包安装下载包安装安装打开 CentOS1. 从 Windows 终端 打开2. 从 PowerShell 打开 方法二、导入 CentOS 的 tar 文件进行安装0. 查看版本(可选)1. 导出 Docker 容器到 tar 文件2. 将 tar 文件导入 WSL2.1. 导入 tar…

最大间距问题

LeetCode164 最大间距 基数排序 #include <iostream> #include <vector> using namespace std;class Solution { public:int maximumGap(vector<int>& nums) {int nnums.size();if(n<2) return 0;int exp1;int Maxnums[0];vector<int> buf(n)…