深度学习——线性神经网络(五、图像分类数据集——Fashion-MNIST数据集)

目录

  • 5.1 读取数据集
  • 5.2 读取小批量
  • 5.3 整合所有组件

  MNIST数据集是图像分类中广泛使用的数据集之一,但是作为基准数据集过于简单,在本小节将使用类似但更复杂的Fashion-MNIST数据集。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l# 这个函数的目的是设置图形显示格式为SVG(Scalable Vector Graphics),
# 这是一种基于矢量的图形格式,可以清晰地缩放而不失真。
d2l.use_svg_display()

5.1 读取数据集

  可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans,download=True)

在这里插入图片描述
  Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集中的6000张图像和测试数据集中的1000张图像组成。因此,训练集和测试集分别总共包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

print(len(mnist_train))
print(len(mnist_test))
60000
10000

  每个输入图像的高度和宽度均为28像素,数据集由灰度图像组成,其通道数为1.

  在图像处理和计算机视觉中,“通道”一词常用来描述图像中颜色信息的存储方式。每个通道代表图像中一种颜色的成分,不同的颜色模式会有不用的通道数。
  灰度图像的通道数为1,在灰度图像中,每个像素只有一个强度值,表示黑白之间的不同灰度级别,不包含颜色信息。

print(mnist_train[0][0].shape)
torch.Size([1, 28, 28])

  Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
  以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

  现在创建一个可视化函数来查看样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""创建一个函数来可视化这些样本,绘制图像列表,目的是在一张图中显示多个图像。imgs是要显示的图像列表,num_rows是创建的子图的行数,num_cols是创建的子图的列数,该子图没有设置标题,调整子图大小的缩放因子默认为1.5"""figsize = (num_cols * scale, num_rows * scale) # 计算整个子图的尺寸,基于子图的行数和列数以及缩放因子来决定_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # figsize 参数设置了整个图形的大小axes = axes.flatten() # 将子图网格展平为一维数组,方便后续遍历for i, (ax, img) in enumerate(zip(axes, imgs)):"""使用enumerate函数和zip函数来迭代两个列表:axes和imgs。这个循环将同时遍历这两个列表,并将它们对应的元素组合在一起,然后进行处理。其中enumerate函数用于跟踪循环的当前迭代次数(即索引i),并返回每个元素及其索引。"""if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)# 子图中隐藏坐标轴。具体来说,它们分别隐藏了x轴和y轴ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i]) # 用来给每个子图设置标题plt.show()plt.savefig('class')return axesX, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) # 用于拿到第一个小批量,批量大小为18
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

在这里插入图片描述

5.2 读取小批量

  为了使我们在读取训练集和测试集时更容易,使用内置的数据迭代器,而不是从开始创建。在每次迭代中,数据加载器都会读取一小批量数据,大小为batch_size,通过内置的数据迭代器,我们可以随机打乱所有样本,从而无偏见地读取小批量。

batch_size = 256 # 设置批量大小def get_dataloader_workers():"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())# 看一下读取训练数据所需的时间
timer = d2l.Timer()
for X, y in train_iter:continue
print(f'{timer.stop():.2f} sec')
2.36 sec

  下面设置了不同的进程数所需的时间。设置的8个进程数读取小批量所需的时间比较少。
在这里插入图片描述

5.3 整合所有组件

  现在我们定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

  我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

   小结:
  数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程的可能性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/454483.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024软考网络工程师笔记 - 第10章.组网技术

文章目录 交换机基础1️⃣交换机分类2️⃣其他分类方式3️⃣级联和堆叠4️⃣堆叠优劣势5️⃣交换机性能参数 🕑路由器基础1️⃣路由器接口2️⃣交换机路由器管理方式2️⃣交换机路由器管理方式 交换机基础 1️⃣交换机分类 1.根据交换方式分 存储转发式交换(Store…

Hadoop 踩坑汇总

文章目录 一、完整教程二、解决问题问题①: DataNode 没有问题②: 网页打不开 三、大功告成!! 一、完整教程 这个教程比较详细,博主是按照这个来执行的 https://blog.csdn.net/qq_47831505/article/details/123806514…

保姆级VsCode配置C++编译环境

文章目录 一、下载安装VSCODE二、 安装C拓展三、下载配置MinGW-w64四、下载配置CMake五、 配置vscode中的 json文件六、 谨记 在现代开发中,VSCode以其轻量、强大的扩展生态圈,逐渐成为了众多开发者的首选编辑器,尤其是在C开发环境中&#xf…

全光网络架构

目前组网架构 世界上有一种最快的速度又是光,以前传统以太网络规划满足不了现在的需求。 有线网 无线网 全光网络方案 场景 全光网络分类 以太全光网络 PON(Pas-sive-Optical Network 无源光网络) 再典型的中大型高校网络中 推荐万兆入…

MySQL程序特别酷

这一篇和上一篇有重合的内容,,我决定从头开始再学一下MySQL,和上一篇的区别是写的更细了,以及写这篇的时候Linux已经学完了 下面就是关于MySQL很多程序的介绍: MySQL安装完成通常会包含如下程序: Linux系…

ArcGIS002:软件自定义设置

摘要:本文详细介绍安装arcgis10.2后软件自定义设置内容,包括工具条的启用、扩展模块的启用、如何加载项管理器、快捷键设置、样式管理器的使用以及软件常规设置。 一、工具条的启用 依次点击菜单栏【自定义】->【工具条】,根据工作需求勾…

基于neo4j的医疗图谱问答与展示

找不到好的毕业设计题材?或者对人工智能领域感兴趣却不知道如何下手?这里给大家推荐一款基于Neo4j的医疗图谱问答系统项目,绝对是毕业设计的不二选择。 这个项目依托于医疗领域的知识图谱,为用户提供交流问答系统。它不仅具有知识…

吃透高并发模型与RPC框架,拿下大厂offer!!!

在当前的互联网市场环境下,竞争愈发激烈,内卷现象严重。在这种背景下,「高并发模型和RPC框架已经成为了大型企业面试的重要环节」。你是否曾因为无法回答相关技术问题而感到尴尬?例如,Java岗位的面试中会询问NIO和Reac…

使用JUC包的AtomicXxxFieldUpdater实现更新的原子性

写在前面 本文一起来看下使用JUC包的AtomicXxxxFieldUpdater实现更新的原子性。代码位置如下: 当前有针对int,long,ref三种类型的支持。如果你需要其他类型的支持的话,也可以照葫芦画瓢。 1:例子 1.1:普…

Java项目-基于springboot框架的学习选课系统项目实战(附源码+文档)

作者:计算机学长阿伟 开发技术:SpringBoot、SSM、Vue、MySQL、ElementUI等,“文末源码”。 开发运行环境 开发语言:Java数据库:MySQL技术:SpringBoot、Vue、Mybaits Plus、ELementUI工具:IDEA/…

MATLAB图像重心计算

图像重心(或质心)计算是计算机视觉和图像处理领域 应用领域广泛:包括医疗,生物,动画,机器人等。 该文章通过灰度转换->二值化->质心计算 以下是代码中涉及的一些数学概念和公式: 灰度转换&#xff1a…

力扣困难题汇总(14道)

题4(困难): 思路: 找两数组中位数,这个看起来简单,顺手反应就是数第(mn)/2个,这个难在要求时间复杂度为log(mn),所以不能这样搞,我的思路是:每次切割长度为较…

Systemd:简介

1号进程 Systemd是linux系统的守护进程,它要管理正在运行的 Linux 主机的许多方面,包括挂载文件系统、管理硬件、处理定时器以及启动和管理生产性主机所需的系统服务。 $ ps -u -p 1 USER PID %CPU %MEM VSZ RSS TTY STAT START TI…

R语言机器学习算法实战系列(九)决策树分类算法 (Decision Trees Classifier)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍教程下载数据加载R包导入数据数据预处理数据描述数据切割调节参数构建模型模型的决策树预测测试数据评估模型模型准确性混淆矩阵模型评估指标ROC CurvePRC Curve特征的重要性保存模…

N9042B UXA 信号分析仪

N9042BUXA 信号分析仪 - 2Hz到50GHz - 使用 N9042B UXA X 系列信号分析仪和各种测量应用软件,可以测试 5G、卫星等应用中的毫米波(mmWave)创新设计的真实性能。 N9042B 具有是德科技信号分析仪中较大的分析带宽和较深的动态范围&#xff0c…

【云原生】Kubernetes部署Jenkins静动Slave

Kubernetes部署Jenkins静动Slave 文章目录 Kubernetes部署Jenkins静动Slave文档介绍资源列表基础环境一、Jenkins Kubernetes清单文件二、使用静态Slave2.1、安装Kubernetes插件2.2、添加Agent2.3、使用Slave 三、使用动态Slave3.1、添加凭据3.2、配置动态Slave3.3、配置Jenkin…

基于SpringBoot+Vue+uniapp微信小程序的澡堂预订的微信小程序的详细设计和实现

项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念,提供了一套默认的配置,让开发者可以更专注于业务逻辑而不是配置文件。Spring Boot 通过自动化配置和约…

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper代码——加性注意力(Additive Attention) 【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper代码——加性注意力(Additive Attention) 文章目录…

kubernetes(三)

k8s之持久化存储pv&pvc 存储资源管理 在基于k8s容器云平台上,对存储资源的使用需求通常包括以下几方面: 1.应用配置文件、密钥的管理; 2.应用的数据持久化存储; 3.在不同的应用间共享数据存储; k8s支持Volume类…

Spring MVC文件请求处理-MultipartResolver

Spring Boot中的MultipartResolver是一个用于解析multipart/form-data类型请求的策略接口,通常用于文件上传。 对应后端使用MultipartFile对象接收。 RequestMapping("/upload")public String uploadFile(MultipartFile file) throws IOException {Strin…