dl学习笔记(8):fashion-mnist

过完年懒羊羊也要复工了,这一节的内容不多,我们接着上次的fashion-mnist数据集。

首先第一步就是导入数据集,由于这个数据集很有名,是深度学习的常见入门数据集,所以可以在库里面导入。由于是图像数据集所以,被存放在视觉模块里面。

import torchvision
import torchvision.transforms as transforms
mnist = torchvision.datasets.FashionMNIST(root=r'E:\桌面\深度学习课件\lesson 11\MINST-FASHION', train=True, download=False, transform=transforms.ToTensor())

下面我们来解释一下这几个参数:

1)root指定数据集存储的本地路径,如果路径不存在,且 download=True,PyTorch会自动创建该路径并下载数据。如果路径已存在且包含数据集文件,则直接加载本地数据。

2)train决定加载的是训练集还是测试集。

  • train=True:加载训练集(60,000张图片)

  • train=False:加载测试集(10,000张图片)

3)download:控制是否从网络下载数据集。

  • download=True:如果本地路径 root 中不存在数据集,则自动下载。

  • download=False:不下载,直接加载本地数据(需确保本地路径已存在数据集)。

4)transform:定义数据预处理操作。

  • ToTensor() 将PIL图像或NumPy数组转换为PyTorch张量(Tensor),并自动进行以下操作:

    将图像数据范围从 [0, 255] 缩放到 [0, 1]。调整张量维度为 [C, H, W](通道、高度、宽度),例如FashionMNIST是灰度图,因此 C=1
  • 如果需要对数据做进一步处理(如归一化),可以组合多个变换

运行结果如上,下一步可以查看属性信息。

这里的size含义就是有六万张图片,每张都是28*28的像素,需要注意的是这里省略了颜色通道,由于该数据集是灰度图片所以这里默认是1。

我们可以通过targets来查看标签,再通过unique来获得标签的唯一值,可以看到是一个多分类任务,总共十个类别。我们还可以通过classes来查看每个数字对应的具体衣服的类别是什么。

下一步我们通过索引来具体看看里面存储的是什么:

图片有点长,如果我们仔细看的话,前面全是图片像素点的张量,最后有一个不起眼的9就是这张图片的标签,所以我们可以通过[0][0]来索引张量,下面我们来展示出来这张图片。

我们将像素部分的张量传入,由于这里是tensor结构,所以我们需要最后转化成numpy才行。

再展示一张:

由于前面已经看过标签和样本已经打包在一起了,所以这里我们不需要使用之前学的dataset的打包功能了,只需要dataloader的分批次。

最后我们开始完整的建模之前我们先复习一下上次说过的完整流程:
1)设置步长 ,动量值 ,迭代次数 ,batch_size等信息,(如果需要)设置初始权重
2)导入数据,将数据切分成batches
3)定义神经网络架构
4)定义损失函数 ,如果需要的话,将损失函数调整成凸函数,以便求解最小值
5)定义所使用的优化算法
6)开始在epoches和batch上循环,执行优化算法:
6.1)调整数据结构,确定数据能够在神经网络、损失函数和优化算法中顺利运行
6.2)完成向前传播,计算初始损失
6.3)利用反向传播,在损失函数上求偏导数
6.4)迭代当前权重
6.5)清空本轮梯度
6.6)完成模型进度与效果监控
7)输出结果

按照惯例首先还是先导入库,下面是所有用到的库

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

1)确定超参数

lr = 0.1
gamma = 0.7
epochs = 5
bs = 128

2)导入数据,将数据切分成batches

batcheddata = DataLoader(mnist,batch_size = bs,shuffle = True)

我们可以通过查看shape属性来看结果是否符合要求:

3)定义神经网络架构

先定义输入输出神经元个数:

input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique())

定义架构:

def fit(net, batchdata, lr=0.01, epochs=5, gamma=0):criterion = nn.NLLLoss()  # 定义损失函数opt = optim.SGD(net.parameters(), lr=lr, momentum=gamma)  # 定义优化算法for epoch in range(epochs):net.train()  # 设置模型为训练模式running_loss = 0.0correct = 0total = 0for batch_idx, (x, y) in enumerate(batchdata):y = y.view(x.shape[0])  # 确保y是一个一维的张量opt.zero_grad()  # 清除之前的梯度sigma = net(x)  # 前向传播loss = criterion(sigma, y)  # 计算损失loss.backward()  # 反向传播opt.step()  # 更新参数# 计算损失running_loss += loss.item()# 计算准确率_, predicted = torch.max(sigma, 1)  # 获取模型的预测total += y.size(0)correct += (predicted == y).sum().item()# 输出每个epoch的平均损失和准确率avg_loss = running_loss / len(batchdata)accuracy = 100 * correct / totalprint(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

4)实例化

torch.manual_seed(250)
net = model(in_features=input_, out_features=output_)
fit(net,batcheddata,lr=lr,epochs=epochs,gamma=gamma)

由于上面的代码都是前面的章节中已经提及过的,这里就不再重复了。

完整代码:

#完整代码
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transformslr = 0.1
gamma = 0.7
epochs = 5
bs = 128mnist = torchvision.datasets.FashionMNIST(root=r'E:\桌面\深度学习课件\lesson 11\MINST-FASHION', train=True, download=False, transform=transforms.ToTensor())
batcheddata = DataLoader(mnist,batch_size = bs,shuffle = True)
input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique())
class model(nn.Module):def __init__(self,in_features=1,out_features=2):super().__init__()self.linear1 = nn.Linear(in_features,128,bias=False)self.output = nn.Linear(128,out_features,bias=False)def forward(self,x):x = x.view(-1,28*28)sigma1 = torch.relu(self.linear1(x))z2 = self.output(sigma1)sigma2 = F.log_softmax(z2,dim=1)return sigma2def fit(net, batchdata, lr=0.01, epochs=5, gamma=0):criterion = nn.NLLLoss()  # 定义损失函数opt = optim.SGD(net.parameters(), lr=lr, momentum=gamma)  # 定义优化算法for epoch in range(epochs):net.train()  # 设置模型为训练模式running_loss = 0.0correct = 0total = 0for batch_idx, (x, y) in enumerate(batchdata):y = y.view(x.shape[0])  # 确保y是一个一维的张量opt.zero_grad()  # 清除之前的梯度sigma = net(x)  # 前向传播loss = criterion(sigma, y)  # 计算损失loss.backward()  # 反向传播opt.step()  # 更新参数# 计算损失running_loss += loss.item()# 计算准确率_, predicted = torch.max(sigma, 1)  # 获取模型的预测total += y.size(0)correct += (predicted == y).sum().item()# 输出每个epoch的平均损失和准确率avg_loss = running_loss / len(batchdata)accuracy = 100 * correct / totalprint(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
torch.manual_seed(250)
net = model(in_features=input_, out_features=output_)
fit(net,batcheddata,lr=lr,epochs=epochs,gamma=gamma)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/13141.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Rust自学】20.2. 最后的项目:多线程Web服务器

说句题外话,这篇文章非常要求Rust的各方面知识,最好看一下我的【Rust自学】专栏的所有内容。这篇文章也是整个专栏最长(4762字)的文章,需要多次阅读消化,最好点个收藏,免得刷不到了。 喜欢的话…

Android学习21 -- launcher

1 前言 之前在工作中,第一次听到launcher有点蒙圈,不知道是啥,当时还赶鸭子上架去和客户PK launcher的事。后来才知道其实就是安卓的桌面。本来还以为很复杂,毕竟之前接触过windows的桌面,那叫一个复杂。。。 后面查了…

[创业之路-276]:从燃油汽车到智能汽车:工业革命下的价值变迁

目录 前言: 从燃油汽车到智能汽车:工业革命下的价值变迁 前言: 燃油汽车,第一次、第二次工业革命,机械化、电气化时代的产物,以机械和电气自动化为核心价值。 智能汽车,第三次、第四次工业革…

Spring Boot - 数据库集成07 - 数据库连接池

数据库连接池 文章目录 数据库连接池一:知识准备1:什么是数据库连接池?2:数据库连接池基本原理 二:HikariCP连接池1:简单使用2:进一步理解2.1:是SpringBoot2.x默认连接池2.2&#xf…

Python-基于PyQt5,Pillow,pathilb,imageio,moviepy,sys的GIF(动图)制作工具

前言:在抖音,快手等社交平台上,我们常常见到各种各样的GIF动画。在各大评论区里面,GIF图片以其短小精悍、生动有趣的特点,被广泛用于分享各种有趣的场景、搞笑的瞬间、精彩的动作等,能够快速吸引我们的注意…

使用线性回归模型逼近目标模型 | PyTorch 深度学习实战

前一篇文章,计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 使用线性回归模型逼近目标模型 什么是回归什么是线性回归使用 PyTorch 实现线性回归模型代码执行结…

【蓝桥杯嵌入式】2_LED

1、电路图 74HC573是八位锁存器,当控制端LE脚为高电平时,芯片“导通”,LE为低电平时芯片“截止”即将输出状态“锁存”,led此时不会改变状态,所以可通过led对应的八个引脚的电平来控制led的状态,原理图分析…

尝试在Office里调用免费大语言模型的阶段性进展

我个人觉得通过api而不是直接浏览器客户端聊天调用大语言模型是使用人工智能大模型的一个相对进阶的阶段。 于是就尝试了一下。我用的是老师木 袁进辉博士新创的硅基流动云上的免费的大模型。——虽然自己获赠了不少免费token,但测试阶段用不上。 具体步骤如下&am…

LabVIEW自定义测量参数怎么设置?

以下通过一个温度采集案例,说明在 LabVIEW 中设置自定义测量参数的具体方法: 案例背景 ​ 假设使用 NI USB-6009 数据采集卡 和 热电偶传感器 监测温度,需自定义以下参数: 采样率:1 kHz 输入量程:0~10 V&a…

理解 C 与 C++ 中的 const 常量与数组大小的关系

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C语言 文章目录 💯前言💯数组大小的常量要求💯C 语言中的数组大小要求💯C 中的数组大小要求💯为什么 C 中 const 变量可以作为数组大小💯进一步的…

【Elasticsearch】文本分类聚合Categorize Text Aggregation

响应参数讲解: key (字符串)由 categorization_analyzer 提取的标记组成,这些标记是类别中所有输入字段值的共同部分。 doc_count (整数)与类别匹配的文档数量。 max_matching_length (整数)从…

基于SpringBoot的信息技术知识赛系统的设计与实现(源码+SQL脚本+LW+部署讲解等)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

Windows Docker笔记-安装docker

安装环境 操作系统:Windows 11 家庭中文版 docker版本:Docker Desktop version: 4.36.0 (175267) 注意: Docker Desktop 支持以下Windows操作系统: 支持的版本:Windows 10(家庭版、专业版、企业版、教育…

《Kettle保姆级教学-界面介绍》

目录 一、Kettle介绍二、界面介绍1.界面构成2、菜单栏详细介绍2.1 【文件F】2.2 【编辑】2.3 【视图】2.4 【执行】2.5 【工具】2.6 【帮助】 3、转换界面介绍4、作业界面介绍5、执行结果 一、Kettle介绍 Kettle 是一个开源的 ETL(Extract, Transform, Load&#x…

新型智慧城市建设方案-1

智慧城市建设的背景与需求 随着信息技术的飞速发展,新型智慧城市建设成为推动城市现代化、提升城市管理效率的重要途径。智慧城市通过整合信息资源,优化城市规划、建设和管理,旨在打造更高效、便捷、宜居的城市环境。 智慧城市建设的主要内容…

【Java计算机毕业设计】基于Springboot的物业信息管理系统【源代码+数据库+LW文档+开题报告+答辩稿+部署教程+代码讲解】

源代码数据库LW文档(1万字以上)开题报告答辩稿 部署教程代码讲解代码时间修改教程 一、开发工具、运行环境、开发技术 开发工具 1、操作系统:Window操作系统 2、开发工具:IntelliJ IDEA或者Eclipse 3、数据库存储&#xff1a…

ollama部署deepseek实操记录

1. 安装 ollama 1.1 下载并安装 官网 https://ollama.com/ Linux安装命令 https://ollama.com/download/linux curl -fsSL https://ollama.com/install.sh | sh安装成功截图 3. 开放外网访问 1、首先停止ollama服务:systemctl stop ollama 2、修改ollama的servic…

Agentic Automation:基于Agent的企业认知架构重构与数字化转型跃迁---我的AI经典战例

文章目录 Agent代理Agent组成 我在企业实战AI Agent企业痛点我构建的AI Agent App 项目开源 & 安装包下载 大家好,我是工程师令狐,今天想给大家讲解一下AI智能体,以及企业与AI智能体的结合,文章中我会列举自己在企业中Agent实…

图论常见算法

图论常见算法 算法prim算法Dijkstra算法 用途最小生成树(MST):最短路径:拓扑排序:关键路径: 算法用途适用条件时间复杂度Kruskal最小生成树无向图(稀疏图)O(E log E)Prim最小生成树无…

手机上运行AI大模型(Deepseek等)

最近deepseek的大火,让大家掀起新一波的本地部署运行大模型的热潮,特别是deepseek有蒸馏的小参数量版本,电脑上就相当方便了,直接ollamaopen-webui这种类似的组合就可以轻松地实现,只要硬件,如显存&#xf…