算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)

算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)

目录

  • 算法面试准备 - 手撕系列第七期 - MLP(利用FashionMNIST数据集)
  • FashionMINIST 图像分类原理解析
    • 1. 全连接的原理图
    • 2. 背景介绍
    • 3.引入相关库函数
    • 4. 数据预处理
    • 5. 模型设计
    • 6. 初始化网络,损失函数与优化器
    • 7. 训练与测试
      • 7.1 训练过程
      • 7.2 测试过程
    • 8. 结论
  • 参考

FashionMINIST 图像分类原理解析

本文将详细解析基于 PyTorch 实现的 FashionMNIST 图像分类的原理及代码结构,适用于初学者理解深度学习图像分类任务的完整流程。


1. 全连接的原理图

在这里插入图片描述

全连接的原理图

2. 背景介绍

FashionMNIST 数据集是一个用于替代经典 MNIST 数据集的基准数据集。它包含 10 类不同的服装图像,每张图像大小为 28x28 像素,灰度图像。

类别编号类别名称
0T 恤/上衣
1裤子
2套衫
3连衣裙
4外套
5凉鞋
6衬衫
7运动鞋
8
9短靴

3.引入相关库函数

# 该模块主要是为了实现FashionMinist图像分类。图像的大小为(28,28),类别为同样为10类
'''
# Part1引入相关的库函数
'''
import torch
from torch import nn
from torch.utils import dataimport torchvision
from torchvision import transforms

4. 数据预处理

图像分类任务的第一步是数据加载和预处理。在代码中,通过 torchvision.datasets.FashionMNIST 加载数据集,并对图像进行以下处理:

  1. 转换为张量:使用 transforms.ToTensor() 将图像转换为 PyTorch 张量格式,并将像素值归一化到 [0, 1]。
  2. 数据分割:划分为训练集和测试集,使用 DataLoader 封装成可迭代的数据加载器。
'''
# Part2 数据集的加载,和dataloader的初始化
'''transforms_action = [transforms.ToTensor()]
transforms_action = transforms.Compose(transforms_action)Minist_train = torchvision.datasets.FashionMNIST(root='Minist', train=True, transform=transforms_action, download=True)
Minist_test = torchvision.datasets.FashionMNIST(root='Minist', train=False, transform=transforms_action, download=True)train_dataloader = data.DataLoader(dataset=Minist_train, batch_size=15, shuffle=True)
test_dataloader = data.DataLoader(dataset=Minist_test, batch_size=15, shuffle=True)

5. 模型设计

本例中使用的是多层感知机(MLP)模型,它由以下组件构成:

  1. 输入层:将输入的 28x28 图像展开为一维向量(大小为 784)。
  2. 隐藏层:一层全连接层,输出大小为 128,激活函数使用 ReLU。
  3. 输出层:全连接层输出大小为 10,对应 10 个类别。
class MLP(nn.Module):def __init__(self, image_size, num_kind,latent=128):super(MLP, self).__init__()self.Linear1 = nn.Linear(image_size, latent, bias=False)self.relu1 = nn.ReLU()# 因为最后一层常用于一些其他操作,进行信息传递,一般就不添加非线性的激活函数了,一般都是不需要的。self.Linear2 = nn.Linear(latent, num_kind, bias=False)# 计算CrossEntropyLoss时候会自动计算softmax所以不需要。# self.softmax = nn.Softmax(dim=-1)def forward(self, x):  # (batch,1,28,28)x = x.reshape(x.size()[0], -1)x = self.Linear1(x)x = self.relu1(x)x = self.Linear2(x)# x = self.softmax(x)return x  # (batch,10)

6. 初始化网络,损失函数与优化器

分类任务中使用交叉熵损失(CrossEntropyLoss),其原理是衡量预测类别分布与真实类别分布之间的差异。

优化器选择随机梯度下降(SGD),其更新公式为:
θ t + 1 = θ t − η ∇ L ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) θt+1=θtηL(θt)
其中:

  • θ t \theta_t θt 为当前参数
  • η \eta η 为学习率
  • ∇ L ( θ t ) \nabla L(\theta_t) L(θt) 为损失函数的梯度
# 初始化网络
net = MLP(784, 10)
# 初始化loss
loss = nn.CrossEntropyLoss()
# 初始化优化器
optimizer = torch.optim.SGD(params=net.parameters(), lr=1e-3)

7. 训练与测试

7.1 训练过程

  1. 遍历每个批次的数据:
    • 将图像输入模型,计算预测结果。
    • 计算损失函数,反向传播计算梯度。
    • 使用优化器更新模型参数。
    • 清零梯度以避免累积。
  2. 每轮训练结束后保存模型状态。
'''
# Part4 循环训练计算损失
'''epochs = 10for epoch in range(epochs):for images, labels in train_dataloader:# 首先前向传播result = net(images)# 计算损失L = loss(result, labels)# 反向传播L.backward()# 参数更新optimizer.step()# 清除梯度optimizer.zero_grad()# 存储模型torch.save(net, 'checkpoint/module_epoch_{}.pth'.format(epoch))

7.2 测试过程

  1. 模型设置为评估模式,禁用梯度计算(torch.no_grad())。
  2. 遍历测试集,计算平均测试损失。
# 每个epoch在测试集跑一遍进行计算平均损失total_loss = 0total_batches = 0with torch.no_grad():for images_test, labels_test in Minist_test:# 形状是Batchsize*hanglabels_hat = net(images_test)L_test = loss(labels_hat, labels_test)total_loss += L_test.item()total_batches += 1# 计算平均测试损失并记录avg_test_loss = total_loss / total_batchesprint(f'第 {epoch + 1} 轮训练完成,平均测试损失为:{avg_test_loss}')

8. 结论

通过上述流程,我们成功实现了基于 FashionMNIST 数据集的分类模型。代码结构清晰,包含了数据加载、模型定义、训练与测试的完整过程,为深度学习图像分类任务提供了良好的实践基础。

参考

自己(好像会了好像又不会,容易忘记各种简单的操作,比如数据集存储的位置啥的):小菜鸟博士-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/3325.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

单元测试与unittest框架

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,薪资嘎嘎涨 单元测试的定义 1. 什么是单元测试? 单元测试是指,对软件中的最小可测试单元在与程序其他部分相隔离的情况下进行检查和验证的工作&am…

【Hugging Face】下载开源大模型步骤

Mac M1 1、国内镜像站 模型基本都可以在国内镜像站 https://hf-mirror.com/ 下载。 部分 Gated Repo 需登录申请许可,需先前往 Hugging Face 官网登录、申请许可,在官网这里获取 Access Token 后回镜像站用命令行下载。 2、注册登陆 Hugging Face 2.1…

STM32的集成开发环境STM32CubeIDE安装

STM32CubeIDE - STM32的集成开发环境 - 意法半导体STMicroelectronics

.Net8 Avalonia跨平台UI框架——<vlc:VideoView>控件播放海康监控、摄像机视频(Windows / Linux)

一、UI效果 二、新建用户控件:VideoViewControl.axaml 需引用:VideoLAN.LibVLC.Windows包 Linux平台需安装:VLC 和 LibVLC (sudo apt-get update、sudo apt-get install vlc libvlccore-dev libvlc-dev) .axaml 代码 注…

51.WPF应用加图标指南 C#例子 WPF例子

完整步骤: 先使用文心一言生成一个图标如左边使用Windows图片编辑器编辑,去除背景使用正方形,放大图片使图标铺满图片使用格式工程转换为ico格式,分辨率为最大 在资源管理器中右键项目添加ico类型图片到项目里图片属性设置为始终…

C++(二十一)

前言: 本文承接上文,将详细讲解指针概念。 一,通过指针了解变量的数值。 在将变量地址存入指针后,从指针反推也可以知道原变量的值,若想进行反退,就需要使用间接引用运算符:*。 语法&#x…

Redis 性能优化:多维度技术解析与实战策略

文章目录 1 基准性能2 使用 slowlog 优化耗时命令3 big key 优化4 使用 lazy free 特性5 缩短键值对的存储长度6 设置键值的过期时间7 禁用耗时长的查询命令8 使用 Pipeline 批量操作数据9 避免大量数据同时失效10 客户端使用优化11 限制 Redis 内存大小12 使用物理机而非虚拟机…

网络安全面试题汇总(个人经验)

1.谈一下SQL主从备份原理? 答:主将数据变更写入自己的二进制log,从主动去主那里去拉二进制log并写入自己的二进制log,从而自己数据库依据二进制log内容做相应变更。主写从读 2.linux系统中的计划任务crontab配置文件中的五个星星分别代表什么&#xff…

从AI生成内容到虚拟现实:娱乐体验的新边界

引言 在快速发展的科技时代,娱乐行业正经历一场前所未有的变革。传统的娱乐方式正与先进技术融合,创造出全新的沉浸式体验。从AI生成的个性化内容,到虚拟现实带来的身临其境的互动场景,科技不仅改变了我们消费娱乐的方式&#xf…

爬虫基础学习

什么是爬虫: 通过编写程序,模拟浏览器上网,然后让其去互联网上抓取数据的过程。 爬虫的价值: 实际应用就业 爬虫究竟是合法还是违法的? 在法律中是不被禁止具有违法风险善意爬虫 恶意爬虫 爬虫带来的风险可以体现在如下方面: 爬虫干扰了被访问网…

wireshark抓路由器上的包 抓包路由器数据

文字目录 抓包流程概述设置抓包配置选项 设置信道设置无线数据包加密信息设置MAC地址过滤器 抓取联网过程 抓包流程概述 使用Omnipeek软件分析网络数据包的流程大概可以分为以下几个步骤: 扫描路由器信息,确定抓包信道;设置连接路由器的…

第34天:Web开发-PHP应用鉴别修复AI算法流量检测PHP.INI通用过滤内置函数

#知识点 1、安全开发-原生PHP-PHP.INI安全 2、安全开发-原生PHP-全局文件&单函数 3、安全开发-原生PHP-流量检测&AI算法 一、通用-PHP.INI设置 参考: https://www.yisu.com/ask/28100386.html https://blog.csdn.net/u014265398/article/details/109700309 …

Python爬虫学习前传 —— Python从安装到学会一站式服务

早上好啊,大佬们。我们的python基础内容的这一篇终于写好了,啪唧啪唧啪唧…… 说实话,这一篇确实写了很久,一方面是在忙其他几个专栏的内容,再加上生活学业上的事儿,确实精力有限,另一方面&…

【Flink系列】6. Flink中的时间和窗口

6. Flink中的时间和窗口 在批处理统计中,我们可以等待一批数据都到齐后,统一处理。但是在实时处理统计中,我们是来一条就得处理一条,那么我们怎么统计最近一段时间内的数据呢?引入“窗口”。 所谓的“窗口”&#xff…

《汽车维修技师》是什么级别的期刊?是正规期刊吗?能评职称吗?

​问题解答: 问:《汽车维修技师》是不是核心期刊? 答:不是,是知网收录的正规学术期刊。 问:《汽车维修技师》级别? 答:省级。主管单位:北方联合出版传媒(…

HTML中如何保留字符串的空白符和换行符号的效果

有个字符串 储值门店{{thing3.DATA}}\n储值卡号{{character_string1.DATA}}\n储值金额{{amount4.DATA}}\n当前余额{{amount5.DATA}}\n储值时间{{time2.DATA}} , HTML中想要保留 \n的换行效果的有下面3种方法: 1、style 中 设置 white-space: pre-lin…

Git在码云上的使用指南:从安装到推送远程仓库

目录 目录 前言: 1、git的安装 1.1.Linux-centos环境下安装 1.2.Linux-ubuntu环境下安装 2.创建Git本地仓库 3.配置Git 4.认识⼯作区、暂存区、版本库 5.添加文件 5.1.git命令 5.2.commit命令 6.远程操作 6.1.新建远程仓库 6.2.克隆远程仓库&#xff…

群论学习笔记

什么是对称? 对称是一个保持对象结构不变的变换,对称是一个过程,而不是一个具体的事物,伽罗瓦的对称是对方程根的置换,而一个置换就是对一系列事物的重排方式,严格的说,它也并不是这个重排本身…

联通用户管理系统(一)

#联通用户管理系统(一) 1.新建项目 如果你是windows的话,界面应该是如下的: 2.创建app python manage.py startapp app01一般情况下:我们是在pycharm的终端中运行上述指令,但是pychrm中为我们提供了工具…

.Net Core微服务入门全纪录(二)——Consul-服务注册与发现(上)

系列文章目录 1、.Net Core微服务入门系列(一)——项目搭建 2、.Net Core微服务入门全纪录(二)——Consul-服务注册与发现(上) 3、.Net Core微服务入门全纪录(三)——Consul-服务注…