项目实例_FashionMNIST_CNN

前言

提醒:
文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。
其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展及意见建议,欢迎评论区讨论交流。

文章目录

  • 前言
  • CNN介绍
  • 数据集介绍
      • FashionMNIST 数据集概述
        • 主要特点:
        • 类别标签:
        • 目标:
        • 为什么使用 FashionMNIST:
        • 图像示例:
        • 数据加载与预处理:
        • 下载和加载 FashionMNIST(PyTorch 代码示例):
        • 数据集使用场景:
      • 数据集可视化
      • 总结:
  • 项目实例
    • 代码


CNN介绍

可见于:
MNIST数据集_CNN
在这里插入图片描述

数据集介绍

FashionMNIST 数据集概述

FashionMNIST 是一个包含 10 个类别的图像数据集,用于训练和测试机器学习模型,特别是在图像分类任务中的应用。它是由 Zalando 提供的,目的是为了给机器学习社区提供一个标准化、简洁但具有挑战性的视觉分类数据集。FashionMNIST 是 MNIST 数据集的一个变种,后者包含手写数字图像。

主要特点:
  • 图像大小:每张图像的分辨率为 28x28 像素,每个像素为灰度值(单通道,值在 0 到 255 之间)。
  • 图像类型:这些图像展示了 10 种不同类型的时尚商品,例如鞋子、T 恤、外套等。
  • 数据集结构
    • 训练集:60,000 张图像
    • 测试集:10,000 张图像
  • 类别:数据集包含 10 个类别,分别对应不同的服装商品(每个类别有对应的标签)。
类别标签:
  • 0: T 恤/上衣
  • 1: 裤子
  • 2: 套头衫
  • 3: 连衣裙
  • 4: 外套
  • 5: 凉鞋
  • 6: 衬衫
  • 7: 运动鞋
  • 8: 包
  • 9: 靴子
目标:

FashionMNIST 的目标是对每张 28x28 的灰度图像进行分类,判定该图像属于哪个类别(例如是 “T 恤”、“裤子” 还是 “运动鞋” 等)。因此,它是一个 多类分类 问题,通常被用来评估各种机器学习模型,尤其是在图像分类任务中的表现。

为什么使用 FashionMNIST:

FashionMNIST 与原始的 MNIST 数据集相似,但相比 MNIST(手写数字),FashionMNIST 的图像内容更加复杂且多样。这使得它在很多机器学习和深度学习领域成为了一个较为简单但富有挑战性的测试集。它被广泛应用于:

  • 深度学习模型的评估:特别是用于测试卷积神经网络(CNN)等模型的性能。
  • 学习和研究:它提供了一个简单且标准化的图像分类数据集,适用于机器学习入门或新模型的验证。
图像示例:

每张图像都是 28x28 像素的灰度图,显示的是一件衣物的图片。图像尺寸较小,适合在初学者的机器学习项目中进行训练,因为它不需要大量的计算资源。

数据加载与预处理:

FashionMNIST 数据集一般会在加载时进行一些标准的预处理步骤,如:

  • 归一化:将像素值从 [0, 255] 范围映射到 [0, 1] 范围,或者进行标准化,常常帮助提升模型的收敛速度。
  • 转换:通常使用 transforms.ToTensor() 将图像转换为 PyTorch 中的 Tensor 格式。
下载和加载 FashionMNIST(PyTorch 代码示例):
import torchvision
from torchvision import transforms# 加载训练集数据
train_data = torchvision.datasets.FashionMNIST(root='data',     # 存储路径train=True,      # 训练集download=True,   # 下载数据集transform=transforms.ToTensor(),  # 转换为 Tensor 格式
)# 加载测试集数据
test_data = torchvision.datasets.FashionMNIST(root='data',train=False,     # 测试集download=True,transform=transforms.ToTensor(),
)

这段代码使用 torchvision.datasets.FashionMNIST 来下载并加载训练和测试数据集,图像会被转换为 PyTorch 的 Tensor 格式。

数据集使用场景:
  • 入门项目:对于初学者,FashionMNIST 是一个非常适合入门的图像分类数据集,因为它相对简单,且具有一定的挑战性。
  • 模型对比与验证:在机器学习和深度学习领域,FashionMNIST 常常作为测试模型性能的标准数据集之一,用来对比不同算法(如支持向量机、KNN、神经网络等)的表现。
  • 神经网络训练:尤其是卷积神经网络(CNN)在图像分类任务中的应用,FashionMNIST 为其提供了一个理想的训练平台。

数据集可视化

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np# 下载 FashionMNIST 数据集
transform = transforms.Compose([transforms.ToTensor()])  # 只需要转换为 Tensor
train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)# 获取前 10 张图像以及对应的标签
images, labels = zip(*[(train_data[i][0], train_data[i][1]) for i in range(10)])# 类别名称映射
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]# 设置绘图
fig, axes = plt.subplots(1, 10, figsize=(15, 15))  # 创建 1 行 10 列的子图for i in range(10):ax = axes[i]ax.imshow(images[i].squeeze(), cmap='gray')  # squeeze 去掉多余的维度,cmap 为灰度色ax.set_title(class_names[labels[i]])  # 标注类别名称ax.axis('off')  # 不显示坐标轴plt.show()  # 显示图像

运行结果:
在这里插入图片描述

总结:

FashionMNIST 是一个标准化的图像分类数据集,由 Zalando 提供。它由 10 类不同的时尚商品构成,训练集包含 60,000 张图像,测试集包含 10,000 张图像。它适用于深度学习、机器学习模型的训练和评估,尤其适合初学者学习和实验。

项目实例

代码

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt# 下载训练数据集FashionMNIST
training_data = torchvision.datasets.FashionMNIST(root="data",  # 数据集存储位置train=True,    # 使用训练集download=True, # 如果数据集不存在,则下载transform=transforms.ToTensor(),  # 转换为Tensor
)# 下载测试数据集FashionMNIST
test_data = torchvision.datasets.FashionMNIST(root="data",train=False,   # 使用测试集download=True,transform=transforms.ToTensor(),  # 转换为Tensor
)# 标签的映射字典,数字标签对应的衣物类别名称
labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}# 可视化FashionMNIST数据集中前9张图片
figure = plt.figure(figsize=(8, 8))  # 创建一个8x8的图像画布
cols, rows = 3, 3  # 设置行列数
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()  # 随机选取一张图片img, label = training_data[sample_idx]  # 获取图片和标签figure.add_subplot(rows, cols, i)  # 在画布上添加子图plt.title(labels_map[label])  # 设置图片的标题为标签对应的衣物类别plt.axis("off")  # 关闭坐标轴显示plt.imshow(img.squeeze(), cmap="gray")  # 显示图像,squeeze去掉多余维度,cmap设置为灰度图
plt.show()  # 展示图像# 设置训练过程中的超参数
num_epochs = 10       # 训练的轮数
batch_size = 32       # 批大小
weight_decay = 1e-4    # 权重衰减(L2正则化)
learning_rate = 0.001 # 学习率# 创建数据加载器,训练集和测试集
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)# 打印测试集的一个batch的尺寸和标签类型
for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")  # X是输入图像,N是批大小,C是通道数,H和W是图像的高和宽print(f"Shape of y: {y.shape} {y.dtype}")      # y是标签,显示其维度和数据类型break  # 只打印一个batch的信息# 检测是否使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")  # 输出当前使用的设备(CUDA or CPU)# 定义一个简单的神经网络模型
class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)   # 第一层全连接层,输入784个特征,输出hidden_size个特征self.relu = nn.ReLU()                           # ReLU激活函数self.fc2 = nn.Linear(hidden_size, num_classes)  # 第二层全连接层,输出num_classes个类别的预测值def forward(self, x):out = self.fc1(x)    # 输入通过第一层out = self.relu(out)  # ReLU激活out = self.fc2(out)   # 输入通过第二层,输出结果return out# 假设输入是28x28的图像,展开为784维,隐藏层大小为500,分类数为10
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数,适用于多分类问题
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)  # 使用Adam优化器# 开始训练模型
total_step = len(train_dataloader)  # 获取训练集的总批次数
for epoch in range(num_epochs):  # 遍历每一个epochfor i, (images, labels) in enumerate(train_dataloader):  # 遍历每一个batchimages = images.reshape(-1, 28*28).to(device)  # 将28x28的图像展开成784维向量,转移到device(GPU/CPU)labels = labels.to(device)  # 标签转移到设备上# 前向传播outputs = model(images)  # 将输入传入模型,得到预测输出loss = criterion(outputs, labels)  # 计算损失# 反向传播和优化optimizer.zero_grad()  # 清零之前的梯度loss.backward()  # 计算当前梯度optimizer.step()  # 更新模型参数# 每100个batch输出一次训练状态if (i + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')# 训练完成后,在测试集上评估模型的准确率
model.eval()  # 设置模型为评估模式(此时BatchNorm等层使用移动平均值而不是批量值)
with torch.no_grad():  # 不需要计算梯度correct = 0total = 0for images, labels in test_dataloader:images = images.reshape(-1, 28*28).to(device)  # 将图像展开为784维labels = labels.to(device)  # 标签转移到设备上outputs = model(images)  # 获取模型的输出_, predicted = torch.max(outputs.data, 1)  # 获取预测类别,outputs.data返回模型的预测结果total += labels.size(0)  # 统计总样本数correct += (predicted == labels).sum().item()  # 统计预测正确的样本数# 输出模型在测试集上的准确率print(f'Test Accuracy of the model on the 10000 test images: {100 * correct / total} %')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/486256.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Autosar FO时间分析和设计规范导读

一、规范功能概述 “Timing Analysis and Design AUTOSAR FO R24 - 11” 文档主要聚焦于汽车电子系统开发中的定时分析与设计,详细阐述了相关概念、方法、用例及涉及的各项要素,旨在为汽车电子系统的开发提供全面且系统的定时分析指导,以确保…

使用 libssh2_session_set_timeout 设置 SSH 会话超时时间

使用 libssh2_session_set_timeout 设置 SSH 会话超时时间 函数原型参数说明返回值示例代码注意事项libssh2_session_set_timeout 是 libssh2 库中的一个函数,用于设置 SSH 会话的超时时间。这对于防止网络延迟或连接中断导致的长时间挂起非常有用。 函数原型 int libssh2_se…

001 LVGL PC端模拟搭建

01 LVGL模拟器介绍 使用PC端软件模拟LVGL运行,而不需要任何嵌入式硬件 环境搭建:codeblocks-20.03mingw-setup 正常安装流程即可 工程获取:LVGL官网-> github仓库 本地安装包下载资源包 工程模版和软件安装包 补充:…

开源ISP介绍(2)————嵌入式Vitis搭建

Vivado搭建参考前一节Vivado基于IP核的视频处理框架搭建: 开源ISP介绍(1)——开源ISP的Vivado框架搭建-CSDN博客 导出Hardware 在vivado中导出Hardware文件,成功综合—实现—生成比特流后导出硬件.xsa文件。(注意导…

人工智能-自动驾驶领域

目录 引言自动驾驶与人工智能的结合为什么自动驾驶领域适合发表文章博雅智信的自动驾驶辅导服务结语 引言 自动驾驶技术的崛起是当代交通行业的一场革命。通过结合先进的人工智能算法、传感器技术与计算机视觉,自动驾驶不仅推动了技术的进步,也使得未来…

Kubernetes 深入浅出系列 | 容器编排与作业调度之Deployment

目录 概述Deployment 的更新原理实验 概述 Kubernetes 中,Deployment 控制器是用于管理应用程序生命周期的核心对象。Deployment 通过管理 ReplicaSet 来间接控制 Pod,确保在任何时刻都能维持指定数量的 Pod 副本。这种间接管理使得 Deployment 功能比 …

Java——异常机制(上)

1 异常机制本质 (异常在Java里面是对象) (抛出异常:执行一个方法时,如果发生异常,则这个方法生成代表该异常的一个对象,停止当前执行路径,并把异常对象提交给JRE) 工作中,程序遇到的情况不可能完美。比如…

如何查看电脑的屏幕刷新率?

1、按一下键盘的 win i 键,打开如下界面,选择【系统】: 2、选择【屏幕】-【高级显示设置】 如下位置,显示屏幕的刷新率:60Hz 如果可以更改,则选择更高的刷新率,有助于电脑使用起来界面更加流…

WAT绕过姿势

一.空格字符绕过 两个空格代替⼀个空格,⽤ Tab 代替空格,%a0空格 %20 %09 %0a %0b %0c %0d %a0 %00 /**/ /*!*/ select * from users where id1 /*!union*//*!select*/1,2,3,4; %09 TAB 键(⽔平)%0a 新建⼀⾏%0c 新的⼀⻚%0d …

Ubuntu 环境美化

一、终端选择 zsh 参考文章使用 oh-my-zsh 美化终端 Oh My Zsh 是基于 zsh 命令行的一个扩展工具集,提供了丰富的扩展功能。 先安装zsh再安装Oh My Zsh 1.zsh安装 sudo apt-get install zsh 2.设置默认终端为 zsh chsh -s /bin/zsh 3.安装 oh-my-zsh 官网&…

QT的ui界面显示不全问题(适应高分辨率屏幕)

//自动适应高分辨率 QCoreApplication::setAttribute(Qt::AA_EnableHighDpiScaling);一、问题 电脑分辨率高,默认情况下,打开QT的ui界面,显示不全按钮内容 二、解决方案 如果自己的电脑分辨率较高,可以尝试以下方案:自…

docker报错ls: cannot access SURF: Transport endpoint is not connected

docker挂载nfs文件夹/CMADAAS/DATA。它大部分时间都可用,只是有时会断开连接。重新挂载后,实际挂载的文件夹将再次可用。 问题是我将此文件夹放入docker卷中以使其可供我的应用程序使用:/SURF。当我启动容器时,该卷可用。 但是&…

AJAX三、XHR,基本使用,查询参数,数据提交,promise的三种状态,封装-简易axios-获取省份列表 / 获取地区列表 / 注册用户,天气预报

一、XMLHttpRequest基本使用 XMLHttpRequest(XHR)对象用于与服务器交互。 二、XMLHttpRequest-查询参数 语法: 用 & 符号分隔的键/值对列表 三、XMLHttpRequest-数据提交 核心步骤 : 1. 请求头 设置 Content-Type 2. 请求体 携带 符合要求 的数…

【Ubuntu】URDC(Ubuntu远程桌面助手)安装、用法,及莫名其妙进入全黑模式的处理

1、简述 URDC是Ubuntu远程桌面助手的简称。 它可以: 实时显示桌面:URDC支持通过Windows连接至Ubuntu设备(包括x86和ARM架构,例如Jetson系列、树莓派等)的桌面及光标。远程操控双向同步剪切板多客户端连接:同一Ubuntu设备最多可同时被三台Windows客户端连接和操控,适用于…

MVC基础——市场管理系统(一)

文章目录 项目地址一、创建项目结构1.1 创建程序以及Controller1.2 创建View1.3 创建Models层,并且在Edit页面显示1.4 创建Layou模板页面1.5 创建静态文件css中间件二、Categories的CRUD2.1 使用静态仓库存储数据2.2 将Categorie的列表显示在页面中(List)2.3 创建_ViewImport.…

KV Shifting Attention Enhances Language Modeling

基本信息 📝 原文链接: https://arxiv.org/abs/2411.19574👥 作者: Mingyu Xu, Wei Cheng, Bingning Wang, Weipeng Chen🏷️ 关键词: KV shifting attention, induction heads, language modeling📚 分类: 机器学习, 自然语言处…

spring下的beanutils.copyProperties实现深拷贝

spring下的beanutils.copyProperties方法是深拷贝还是浅拷贝?可以实现深拷贝吗? 答案:浅拷贝。 一、浅拷贝深拷贝的理解 简单说拷贝就是将一个类中的属性拷贝到另一个中,对于BeanUtils.copyProperties来说,你必须保…

沈阳工业大学《2024年827自动控制原理真题》 (完整版)

本文内容,全部选自自动化考研联盟的:《沈阳工业大学827自控考研资料》的真题篇。后续会持续更新更多学校,更多年份的真题,记得关注哦~ 目录 2024年真题 Part1:2024年完整版真题 2024年真题

Milvus Cloud 2.5:向量数据库的新里程碑与全文检索的革新

Milvus Cloud 2.5:向量数据库的新里程碑与全文检索的革新 各位同仁,大家好!我是大禹智库的向量数据库高级研究员王帅旭,也是《向量数据库指南》的作者。今天,我怀着激动的心情,为大家带来 Milvus Cloud 2.5 最新版本的深度解读。这个版本不仅标志着我们在向量数据库领域…

【金猿CIO展】复旦大学附属中山医院计算机网络中心副主任张俊钦:推进数据安全风险评估,防范化解数据安全风险,筑牢医疗数据安全防线...

‍ 张俊钦 本文由复旦大学附属中山医院计算机网络中心副主任张俊钦撰写并投递参与“数据猿年度金猿策划活动——2024大数据产业年度优秀CIO榜单及奖项”评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 数据要素时代,医疗数据已成为医院运营与决策的重要基石…