CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

from tensorflow.keras.datasets import cifar10(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/249430.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android平台如何实现RTSP转GB28181

为什么要做GB28181设备接入侧&#xff1f; 实际上&#xff0c;在做Android平台GB28181设备接入模块的时候&#xff0c;我们已经有了非常好的技术积累&#xff0c;比如RTMP推送、轻量级RTSP服务、一对一互动模块、业内几乎最好的RTMP|RTSP低延迟播放器。 Android平台GB28181接…

VBA技术资料MF113:将文件夹图像添加到PowerPoint

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。我的教程一共九套&#xff0c;分为初级、中级、高级三大部分。是对VBA的系统讲解&#xff0c;从简单的入门&#xff0c;到…

第七讲_JavaScript的Iterator和Generator

JavaScript的Iterator和Generator 1. Iterator1.2 for-of语法糖 2. Generator2.1 定义一个生成器函数2.2 常用的方法2.3 基本用法2.4 传参的用法2.5 异步的用法 1. Iterator ES6 中&#xff0c;默认的 Iterator 接口部署在数据结构的 Symbol.iterator 属性。一个数据结构只要拥…

vue前端html导出pdf

package.json中添加依赖 调用方&#xff1a; import htmlToPdf from ../../../utils/file/htmlToPdf.js// 下载方法&#xff0c;pdfDownloadDpi为onClickDownLoad() {htmlToPdf.getPdf(标题1, jsfgyzcpgxmShow, this.pdfDownloadDpi)}htmlToPdf.js // 页面导出为pdf格式 imp…

微分几何——梅向明第四版学习笔记(二) 曲面论、外微分形式和活动标架

目录 引出曲面论曲面的概念曲面的切平面和法线曲面的第一基本形式曲面域的面积 曲面的第二基本形式曲面上曲线的曲率曲面的渐进方向曲面的共轭方向曲面的主方向和曲率线 曲面的三个曲率主曲率高斯曲率&#xff0c;平均曲率案例 曲面在一点邻近的结构曲面的第三基本形式高斯曲率…

zookeeper(2) 服务器动态上下线监听和分布式锁案例

案例一&#xff1a;服务器动态上下线监听 某分布式系统中&#xff0c;主节点可以有多台&#xff0c;可以动态上下线&#xff0c;任意一台客户端都能实时感知 到主节点服务器的上下线。 1.服务端代码 package com.atguigu.case1;import org.apache.zookeeper.*;import java.io…

静态时序分析:时序弧以及其时序敏感(单调性)

相关阅读 静态时序分析https://blog.csdn.net/weixin_45791458/category_12567571.html?spm1001.2014.3001.5482 在静态时序分析中&#xff0c;不管是组合逻辑单元&#xff08;如与门、或门、与非门等&#xff09;还是时序逻辑&#xff08;D触发器等&#xff09;在时序建模时…

springCloud gateway 防止XSS漏洞

springCloud gateway 防止XSS漏洞 一.XSS(跨站脚本)漏洞详解1.XSS的原理和分类2.XSS漏洞的危害3.XSS的防御 二.Java开发中防范XSS跨站脚本攻击的思路三.相关代码&#xff08;适用于spring cloud gateway&#xff09;1.CacheBodyGlobalFilter.java2.XssRequestGlobalFilter.java…

强大的虚拟机Parallels Desktop 19 mac中文激活

Parallels Desktop是一款功能全面、易于使用的虚拟机软件&#xff0c;它为用户提供了在Mac电脑上同时运行多个操作系统的便利。 软件下载&#xff1a;Parallels Desktop 19 mac中文激活版下载 Parallels Desktop 19 mac具有快速启动和关闭虚拟机的能力&#xff0c;让用户能够迅…

STM32G4 系列命名规则

STM32G4产品线 基础型系列STM32G4x1 具有入门级模拟外设配置&#xff0c;单存储区Flash&#xff0c;支持的Flash存储器容量范围从32到512KB。 增强型系列STM32G4x3 与基本型器件相比具有更多数量的模拟外设&#xff0c;以及双存储区Flash&#xff0c;Flash存储器容量也提高…

C#学习笔记_类(Class)

类的定义 类的定义是以关键字 class 开始&#xff0c;后跟类的名称。类的主体&#xff0c;包含在一对花括号内。 语法格式如下&#xff1a; 访问标识符 class 类名 {//变量定义访问标识符 数据类型 变量名;访问标识符 数据类型 变量名;访问标识符 数据类型 变量名;......//方…

【Day37】代码随想录之贪心_738.单调递增的数字_968.监控二叉树

文章目录 738.单调递增的数字968.监控二叉树 738.单调递增的数字 参考文档&#xff1a;代码随想录 题目&#xff1a; 给定一个非负整数 N&#xff0c;找出小于或等于 N 的最大的整数&#xff0c;同时这个整数需要满足其各个位数上的数字是单调递增。 &#xff08;当且仅当每个相…

你的MiniFilter安全吗?

简介 筛选器管理器 (FltMgr.sys)是Windows系统提供的内核模式驱动程序, 用于实现和公开文件系统筛选器驱动程序中通常所需的功能; 第三方文件系统筛选器开发人员可以使用FltMgr的功能可以更加简单的编写文件过滤驱动, 这种驱动我们通常称为MiniFilter, 下面是MiniFilter的基本…

基于SpringBoot+Vue学科竞赛管理系统

文章目录 基于SpringBootVue学科竞赛管理系统1系统概述1.3系统设计思想 2相关技术2.1 MYSQL数据库2.2 B/S结构2.3 Spring Boot框架简介2.4 Vue简介 3系统分析3.1可行性分析3.1.1技术可行性3.1.2经济可行性3.1.3操作可行性 3.2系统性能分析3.2.1 系统安全性3.2.2 数据完整性 3.4…

【2024】大三寒假再回首:缺乏自我意识是毒药,反思和回顾是解药

2024年初&#xff0c;学习状态回顾 开稿时间&#xff1a;2024-1-23 归家百里去&#xff0c;飘雪送客迟。 搁笔日又久&#xff0c;一顾迷惘时。 我们饱含着过去的习惯&#xff0c;缺乏自我意识是毒药&#xff0c;反思和回顾是解药。 文章目录 2024年初&#xff0c;学习状态回顾一…

虚拟机(VMware)ubuntu16.04 直接连接网口设备 USRP 吊舱

编辑虚拟网络编辑器 点击之后 选择网卡之后&#xff0c;点击确定。 电脑配置 使用了&#xff1a;192.168.2.56 虚拟机内部配置 和PC的配置一致

【Matlab】音频信号分析及FIR滤波处理——凯泽(Kaiser)窗

一、前言 1.1 课题内容: 利用麦克风采集语音信号(人的声音、或乐器声乐),人为加上环境噪声(窄带)分析上述声音信号的频谱,比较两种情况下的差异根据信号的频谱分布,选取合适的滤波器指标(频率指标、衰减指标),设计对应的 FIR 滤波器实现数字滤波,将滤波前、后的声音…

Zoho Mail 2023:回顾过去,展望未来:不断进化的企业级邮箱解决方案

当我们告别又一个非凡的一年时&#xff0c;我们想回顾一下Zoho Mail如何融合传统与创新。我们迎来了成立15周年&#xff0c;这是一个由客户、合作伙伴和我们的敬业团队共同庆祝的里程碑。与我们一起回顾这段旅程&#xff0c;探索定义Zoho Mail历史篇章的敏捷性、精确性和创新性…

Spring 事务原理二

该说些什么呢&#xff1f;一连几天&#xff0c;我都沉溺在孤芳自赏的思维中无法自拔。不知道自己为什么会有这种令人不齿的表现&#xff0c;更不知道这颗定时炸弹何时会将人炸的粉身碎骨。好在儒派宗师曾老夫子“吾日三省吾身”的名言警醒了我。遂潜心自省&#xff0c;溯源头以…

C++初阶入门之命名空间和缺省参数的详细解析

个人主页&#xff1a;点我进入主页 专栏分类&#xff1a;C语言初阶 C语言进阶 数据结构初阶 Linux C初阶 欢迎大家点赞&#xff0c;评论&#xff0c;收藏。 一起努力&#xff0c;一起奔赴大厂 目录 一.前言 二.命名空间 2.1命名冲突的例子 2.2解决方案 2.3命…