PyTorch Study Notes, Day 4 — Training on the MNIST Dataset and a First Close Read

Here it is at last, hahaha. Basically every machine learning beginner runs into this example sooner or later: it's open, the data quality is high, the samples are uniformly sized, and the problem is simple. It really is ideal beginner fare.

Today I got the code running; with the weekend here, it's a good time to chew through the details.

Code implementation

First, a thank-you to Baidu AI: the generated code really does run as-is, which is delightful. Almost nine years ago, in my second year of university, getting this little bit of code meant sitting through a several-hour English video. And although the network looks very, very shallow, it already reaches quite good prediction accuracy.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel, 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc = nn.Linear(20 * 4 * 4, 10)

    def forward(self, x):
        batch_size = x.size(0)         # first dim is the batch: a 1*28*28 image arrives as 64*1*28*28
        x = torch.relu(self.conv1(x))  # in 64*1*28*28,  out 64*10*24*24
        x = torch.max_pool2d(x, 2, 2)  # in 64*10*24*24, out 64*10*12*12 (pooling layer)
        x = torch.relu(self.conv2(x))  # in 64*10*12*12, out 64*20*8*8
        x = torch.max_pool2d(x, 2, 2)  # in 64*20*8*8,   out 64*20*4*4
        x = x.view(batch_size, -1)     # in 64*20*4*4,   out 64*320
        x = self.fc(x)                 # in 64*320,      out 64*10
        return x

if __name__ == "__main__":
    # Hyperparameters
    batch_size = 64
    epochs = 10
    learning_rate = 0.01

    # Preprocessing: ToTensor scales pixels to [0,1]; the constants are MNIST's mean/std
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load train/test data. batch_size: samples per step; shuffle: reshuffle after each epoch
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = datasets.MNIST('data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # Instantiate the model, loss function and optimizer
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Train the model
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):  # batching happens automatically
            optimizer.zero_grad()   # the canonical training step
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                # note: the 0. factor below makes the progress field print as 0% every time,
                # which is why the log further down always shows (0%); use
                # 100. * batch_idx / len(train_loader) to get an actual percentage
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    0. * batch_idx / len(train_loader), loss.item()))

    # Test the model
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

The run output is as follows:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [02:41<00:00, 61401.03it/s]
Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 97971.03it/s]
Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:29<00:00, 56423.58it/s]
Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 4339528.19it/s]
Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw

Train Epoch: 0 [0/60000 (0%)]	Loss: 2.275243
Train Epoch: 0 [6400/60000 (0%)]	Loss: 0.200208
Train Epoch: 0 [12800/60000 (0%)]	Loss: 0.064670
Train Epoch: 0 [19200/60000 (0%)]	Loss: 0.066074
Train Epoch: 0 [25600/60000 (0%)]	Loss: 0.115960
Train Epoch: 0 [32000/60000 (0%)]	Loss: 0.171170
Train Epoch: 0 [38400/60000 (0%)]	Loss: 0.041663
Train Epoch: 0 [44800/60000 (0%)]	Loss: 0.179172
Train Epoch: 0 [51200/60000 (0%)]	Loss: 0.014898
Train Epoch: 0 [57600/60000 (0%)]	Loss: 0.035095
Train Epoch: 1 [0/60000 (0%)]	Loss: 0.016566
Train Epoch: 1 [6400/60000 (0%)]	Loss: 0.008371
Train Epoch: 1 [12800/60000 (0%)]	Loss: 0.006069
Train Epoch: 1 [19200/60000 (0%)]	Loss: 0.009995
Train Epoch: 1 [25600/60000 (0%)]	Loss: 0.020422
Train Epoch: 1 [32000/60000 (0%)]	Loss: 0.155348
Train Epoch: 1 [38400/60000 (0%)]	Loss: 0.059595
Train Epoch: 1 [44800/60000 (0%)]	Loss: 0.038654
Train Epoch: 1 [51200/60000 (0%)]	Loss: 0.084179
Train Epoch: 1 [57600/60000 (0%)]	Loss: 0.147250
Train Epoch: 2 [0/60000 (0%)]	Loss: 0.040161
Train Epoch: 2 [6400/60000 (0%)]	Loss: 0.147080
Train Epoch: 2 [12800/60000 (0%)]	Loss: 0.037228
Train Epoch: 2 [19200/60000 (0%)]	Loss: 0.257872
Train Epoch: 2 [25600/60000 (0%)]	Loss: 0.052811
Train Epoch: 2 [32000/60000 (0%)]	Loss: 0.005805
Train Epoch: 2 [38400/60000 (0%)]	Loss: 0.092318
Train Epoch: 2 [44800/60000 (0%)]	Loss: 0.084066
Train Epoch: 2 [51200/60000 (0%)]	Loss: 0.000331
Train Epoch: 2 [57600/60000 (0%)]	Loss: 0.011482
Train Epoch: 3 [0/60000 (0%)]	Loss: 0.042851
Train Epoch: 3 [6400/60000 (0%)]	Loss: 0.004001
Train Epoch: 3 [12800/60000 (0%)]	Loss: 0.008942
Train Epoch: 3 [19200/60000 (0%)]	Loss: 0.045065
Train Epoch: 3 [25600/60000 (0%)]	Loss: 0.099309
Train Epoch: 3 [32000/60000 (0%)]	Loss: 0.054098
Train Epoch: 3 [38400/60000 (0%)]	Loss: 0.059155
Train Epoch: 3 [44800/60000 (0%)]	Loss: 0.016098
Train Epoch: 3 [51200/60000 (0%)]	Loss: 0.114458
Train Epoch: 3 [57600/60000 (0%)]	Loss: 0.231477
Train Epoch: 4 [0/60000 (0%)]	Loss: 0.003781
Train Epoch: 4 [6400/60000 (0%)]	Loss: 0.068822
Train Epoch: 4 [12800/60000 (0%)]	Loss: 0.103501
Train Epoch: 4 [19200/60000 (0%)]	Loss: 0.002396
Train Epoch: 4 [25600/60000 (0%)]	Loss: 0.174503
Train Epoch: 4 [32000/60000 (0%)]	Loss: 0.027796
Train Epoch: 4 [38400/60000 (0%)]	Loss: 0.013167
Train Epoch: 4 [44800/60000 (0%)]	Loss: 0.011576
Train Epoch: 4 [51200/60000 (0%)]	Loss: 0.000726
Train Epoch: 4 [57600/60000 (0%)]	Loss: 0.069251
Train Epoch: 5 [0/60000 (0%)]	Loss: 0.006919
Train Epoch: 5 [6400/60000 (0%)]	Loss: 0.015165
Train Epoch: 5 [12800/60000 (0%)]	Loss: 0.117820
Train Epoch: 5 [19200/60000 (0%)]	Loss: 0.031030
Train Epoch: 5 [25600/60000 (0%)]	Loss: 0.031566
Train Epoch: 5 [32000/60000 (0%)]	Loss: 0.046268
Train Epoch: 5 [38400/60000 (0%)]	Loss: 0.055709
Train Epoch: 5 [44800/60000 (0%)]	Loss: 0.021299
Train Epoch: 5 [51200/60000 (0%)]	Loss: 0.004246
Train Epoch: 5 [57600/60000 (0%)]	Loss: 0.014340
Train Epoch: 6 [0/60000 (0%)]	Loss: 0.056358
Train Epoch: 6 [6400/60000 (0%)]	Loss: 0.104084
Train Epoch: 6 [12800/60000 (0%)]	Loss: 0.097005
Train Epoch: 6 [19200/60000 (0%)]	Loss: 0.009379
Train Epoch: 6 [25600/60000 (0%)]	Loss: 0.078417
Train Epoch: 6 [32000/60000 (0%)]	Loss: 0.217889
Train Epoch: 6 [38400/60000 (0%)]	Loss: 0.079795
Train Epoch: 6 [44800/60000 (0%)]	Loss: 0.052873
Train Epoch: 6 [51200/60000 (0%)]	Loss: 0.127716
Train Epoch: 6 [57600/60000 (0%)]	Loss: 0.087016
Train Epoch: 7 [0/60000 (0%)]	Loss: 0.045884
Train Epoch: 7 [6400/60000 (0%)]	Loss: 0.087923
Train Epoch: 7 [12800/60000 (0%)]	Loss: 0.164549
Train Epoch: 7 [19200/60000 (0%)]	Loss: 0.111163
Train Epoch: 7 [25600/60000 (0%)]	Loss: 0.300172
Train Epoch: 7 [32000/60000 (0%)]	Loss: 0.045357
Train Epoch: 7 [38400/60000 (0%)]	Loss: 0.087294
Train Epoch: 7 [44800/60000 (0%)]	Loss: 0.110581
Train Epoch: 7 [51200/60000 (0%)]	Loss: 0.001932
Train Epoch: 7 [57600/60000 (0%)]	Loss: 0.066714
Train Epoch: 8 [0/60000 (0%)]	Loss: 0.047415
Train Epoch: 8 [6400/60000 (0%)]	Loss: 0.106327
Train Epoch: 8 [12800/60000 (0%)]	Loss: 0.016832
Train Epoch: 8 [19200/60000 (0%)]	Loss: 0.013452
Train Epoch: 8 [25600/60000 (0%)]	Loss: 0.035256
Train Epoch: 8 [32000/60000 (0%)]	Loss: 0.026502
Train Epoch: 8 [38400/60000 (0%)]	Loss: 0.011809
Train Epoch: 8 [44800/60000 (0%)]	Loss: 0.171943
Train Epoch: 8 [51200/60000 (0%)]	Loss: 0.209570
Train Epoch: 8 [57600/60000 (0%)]	Loss: 0.047113
Train Epoch: 9 [0/60000 (0%)]	Loss: 0.126423
Train Epoch: 9 [6400/60000 (0%)]	Loss: 0.016720
Train Epoch: 9 [12800/60000 (0%)]	Loss: 0.210951
Train Epoch: 9 [19200/60000 (0%)]	Loss: 0.072410
Train Epoch: 9 [25600/60000 (0%)]	Loss: 0.042366
Train Epoch: 9 [32000/60000 (0%)]	Loss: 0.002912
Train Epoch: 9 [38400/60000 (0%)]	Loss: 0.074261
Train Epoch: 9 [44800/60000 (0%)]	Loss: 0.004673
Train Epoch: 9 [51200/60000 (0%)]	Loss: 0.074964
Train Epoch: 9 [57600/60000 (0%)]	Loss: 0.040360

Test set: Average loss: 0.0011, Accuracy: 9795/10000 (98%)

Notes on the code

The following signature defines a 2D convolutional layer:

nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
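To make the shape arithmetic in the model's comments concrete, here's a quick sanity check (a sketch I added; the sizes assume the MNIST setup above). With the defaults stride=1, padding=0 and dilation=1, the output size is floor((in + 2*padding - kernel_size)/stride) + 1, so 28 shrinks to 24:

import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 10, kernel_size=5)  # same layer as in the model above
x = torch.randn(64, 1, 28, 28)           # a fake batch of 64 MNIST-sized images
y = conv1(x)
print(y.shape)                           # torch.Size([64, 10, 24, 24]): (28 - 5)/1 + 1 = 24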

For reference, this blog post covers the parameters in detail: https://blog.csdn.net/qq_60245590/article/details/135856418
Baidu AI also gave an explanation of them.
The training data is downloaded on the fly by Python. Opening the data folder, there's quite a lot in there; the core of it should be the training data and the test data. But then why do we build a separate train_dataset and test_dataset? I was slightly puzzled; see the probe below.
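As far as I can tell, the download grabs all four files at once (train/test images and labels, as the log above shows), and the train=True/False flag merely selects which pair the Dataset object reads. A small probe, sketched under that assumption; the printed values are what standard MNIST should give:

from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)
print(len(train_dataset), len(test_dataset))  # 60000 10000
img, label = train_dataset[0]                 # each item is an (image, label) pair
print(img.shape, label)                       # torch.Size([1, 28, 28]) 5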
And the batching, we don't even have to do it ourselves?! Huh, does MindSpore have this feature too? Can data I cook up myself be auto-batched? If so, that would be very convenient. (See the sketch below for the PyTorch side.)
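For PyTorch at least, the answer is yes: wrap your own tensors in a TensorDataset (or any object implementing __len__ and __getitem__), and DataLoader handles the batching and shuffling for you. A minimal sketch with made-up data (I can't speak for MindSpore here):

import torch
from torch.utils.data import TensorDataset, DataLoader

# toy data, entirely made up: 1000 samples with 8 features, labels in 0..9
features = torch.randn(1000, 8)
labels = torch.randint(0, 10, (1000,))

dataset = TensorDataset(features, labels)            # wrap raw tensors as a Dataset
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for xb, yb in loader:                                # batches come out automatically
    print(xb.shape, yb.shape)                        # torch.Size([64, 8]) torch.Size([64])
    break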
Alright! The Honkai: Star Rail version preview is today. Off to go play~
