365天深度学习训练营-第P1周:实现mnist手写数字识别

  • 🍨 本文为🔗365天深度学习训练营 内部限免文章(版权归 K同学啊 所有)
  • 🍦 参考文章地址: 🔗第P1周:实现mnist手写数字识别 | 365天深度学习训练营
  • 🍖 作者:K同学啊 | 接辅导、程序定制

文章目录

  • 我的环境:
  • 一、前期工作
    • 1. 设置 GPU
    • 2. 导入数据
    • 3. 数据可视化
  • 二、构建简单的CNN网络
  • 三、训练模型
    • 1. 设置超参数
    • 2. 编写训练函数
    • 3. 编写测试函数
    • 4. 正式训练
  • 四、结果可视化
  • 五、用自己制作的图片进行预测

我的环境:

  • 语言环境:Python 3.7.13
  • 编译器:jupyter notebook
  • 深度学习环境:
    • torch==1.12.1+cu113、cuda==11.3.1
    • torchvision==0.13.1+cu113、cuda==11.3.1

一、前期工作

1. 设置 GPU

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvisiondevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")device
device(type='cuda')

2. 导入数据

train_ds = torchvision.datasets.MNIST('data', train=True, transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensordownload=True)test_ds  = torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensordownload=True)
batch_size = 32train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)test_dl  = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)
# 取一个批次查看数据格式
# 数据的shape为:[batch_size, channel, height, weight]
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。
imgs, labels = next(iter(train_dl))
imgs.shape
torch.Size([32, 1, 28, 28])

3. 数据可视化

import numpy as np# 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 5)) 
for i, imgs in enumerate(imgs[:20]):# 维度缩减npimg = np.squeeze(imgs.numpy())# 将整个figure分成2行10列,绘制第i+1个子图。plt.subplot(2, 10, i+1)plt.imshow(npimg, cmap=plt.cm.binary)plt.axis('off')

在这里插入图片描述

二、构建简单的CNN网络

使用 image_dataset_from_directory 方法将磁盘中的数据加载到 tf.data.Dataset 中

import torch.nn.functional as Fnum_classes = 10  # 图片的类别数class Model(nn.Module):def __init__(self):super().__init__()# 特征提取网络self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # 第一层卷积,卷积核大小为3*3self.pool1 = nn.MaxPool2d(2)                  # 设置池化层,池化核大小为2*2self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 第二层卷积,卷积核大小为3*3   self.pool2 = nn.MaxPool2d(2) # 分类网络self.fc1 = nn.Linear(1600, 64)          self.fc2 = nn.Linear(64, num_classes)# 前向传播def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))     x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1)x = F.relu(self.fc1(x))x = self.fc2(x)return x
from torchinfo import summary
# 将模型转移到GPU中(我们模型运行均在GPU中进行)
model = Model().to(device)summary(model)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Model                                    --
├─Conv2d: 1-1                            320
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            18,496
├─MaxPool2d: 1-4                         --
├─Linear: 1-5                            102,464
├─Linear: 1-6                            650
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
=================================================================

三、训练模型

1. 设置超参数

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片num_batches = len(dataloader)   # 批次数目,1875(60000/32)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4. 正式训练

epochs     = 5
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
Epoch: 1, Train_acc:77.6%, Train_loss:0.744, Test_acc:91.1%,Test_loss:0.284
Epoch: 2, Train_acc:94.1%, Train_loss:0.196, Test_acc:96.2%,Test_loss:0.128
Epoch: 3, Train_acc:96.2%, Train_loss:0.123, Test_acc:97.5%,Test_loss:0.089
Epoch: 4, Train_acc:97.1%, Train_loss:0.094, Test_acc:97.4%,Test_loss:0.078
Epoch: 5, Train_acc:97.5%, Train_loss:0.078, Test_acc:98.0%,Test_loss:0.062
Done

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

五、用自己制作的图片进行预测

在这里插入图片描述

for i in range(10):img_path = 'imgs/no' + str(i) + '.png'img = Image.open(img_path)img = img.convert('L')img = data_transform(img)img = torch.unsqueeze(img, dim=0)img = img.to(device)model.eval()with torch.no_grad():output = model(img)print(output.argmax(1).item())

预测结果:

0
1
2
3
4
5
6
7
8
9

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/44075.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

找工作交流群

群定义 源码圈找工作交流群,不同于原有的源码圈技术群,主要如下几点不同: 按照地区拆分。精力有限,暂时只有北上广深杭,拉人进群很累的!!!探讨面试相关的问题。你懂的,面…

最近找工作的行情大家来交流交流

人在广州 4 年经验前端没有大厂经历,广州的外包岗位都不好进,现在开始往北京上海投简历了,恐怕面临转行 最近广州大厂是不是在裁员,这样下去更难找了 坐标上海外企,有岗位,投递简历也很多,但是大…

字节跳动、小米、吉利汽车、同程艺龙、沙特阿美等公司高管变动

中国 字节跳动确认,TikTok首席执行官凯文梅耶尔(Kevin Mayer)已辞职。TikTok现任总经理瓦妮莎帕帕斯(Vanessa Pappas)将成为临时首席执行官。凯文梅耶尔于今年6月1日正式加入字节跳动,担任字节跳动COO兼TikTok全球CEO,此前为迪士尼流媒体负责…

【舆情监控】社会化大数据应用平台TOOM舆情监测系统

TOOM舆情监测系统 1.全面性 整合智能爬虫信息采集技术及信息研判和情感分析技术,对网上海量信息自动抓取、自动分类聚类、主题检测、专题聚焦,实现用户的网络舆情监测需求,形成简报、报告、图表等分析结果,为客户全面掌握网民舆…

大数据舆情监测

大数据舆情监测是当前比较流行的一项监测,今天,大数据技术的应用范围日益广泛。大数据正在促进信息技术与各行业的深度融合,其中的重点应用范围之一是在商业领域的运用,接下来我们简单了解大数据舆情监测分析方案,以及…

TOOM舆情分析网络舆情监控平台研究现状

随着网络舆情迅速发展,国内的舆情监测行业也日渐完善,舆情监控平台在企业发展过程中发挥重要作用,但同样也是有问题存在的,接下来TOOM舆情分析网络舆情监控平台研究现状? 一、网络舆情监控平台 网络舆情监控平台是一种能够对网…

舆情监测平台TOOM

随着互联网快速发展,如今市场上网络舆情监测平台种类有很多,但对于消费者很难挑选一款合适的舆情监测平台,接下来我们从网络舆情监测平台有哪些,舆情监测平台实时方案,如何挑选合适舆情监测平台三个方面,让…

网络舆情监测TOOM

网络舆情工作是收集整理分析和报送网络舆情信息,通过网络舆情监测实时监测网络信息,为企业提供强有力的支持工作,有效防范网络舆情危机,全面监测网络舆情信息的需求不断增加,接下来我们简单了解网络舆情监测相关事宜。…

搞笑短视频如何撰写脚本?分享简单小技巧

搞笑短视频如何撰写脚本?分享简单小技巧 在正式拍摄短视频之前,我们往往还需要撰写好脚本才行,它可以帮助我们更加顺利的拍摄短视频,也能让我们在后期制作的时候更为方便。而且短视频脚本撰写其实也相当于是短视频拍摄前的准备工…

【剪辑必备】短视频全自动切片软件,带货直播切片必备脚本【永久脚本+技术教程】

全自动切片系统 多线程处理 2小时的视频只需要30秒切片完成 影视剪辑 解说 抖音看电影项目 带货直播切片必备 如果你不会做影视解说 那你可以配合抖音看电影项目一起做 这项目目前都是用的这种软件切片制作的 某大V直播带货的时候把直播间录制下来 然后马上切片发布作…

TamperMonkey脚本开发_无限制视频提取

背景 已购课程下载 ,在提取m3u8视频时,视频缓存使用ASE加密 以及VI偏移量等等,由于对这方面了解并不多.不知道如何提取到真实的地址 通过几种方式 嗅探 抓包 控制台监控 都无法获取到 IDM的视频下载由于法律原因无法下载该ts文件 但是这些ts都是分段的 就算下载了 我自己也无…

互动视频脚本 : 电子类的短视频

测试视频:电子实验理论与实践 01 测试互动段视频 这是一个测试互动段视频的测试片段。 一、P1-有趣的电子实验 这是一个测试短视频,主要是用来练习在B站搭建互动段视频的过程。 下面选取其中几段视频,组成三个视频分P片段,用于互…

【剪辑必备】情感对话号必备-微信对话生成脚本,一键生成视频【安卓永久版脚本】

微信模拟聊天软件,可以自行更改网名,改头像,聊天内容随意修改,下载即可使用!没有试用教程,用法非常简单 设备需求:安卓系统 教程工具请到CSDN下载https://download.csdn.net/download/Linxiaoyu2022/87423…

用脚本帮同学自动生成文章观后感后,这名大四学生火了...

点击关注上方“五分钟学算法”, 设为“置顶或星标”,第一时间送达干货。 转自大数据文摘 同一个中国,同一个网课。 3 月 9 日,全国大学生共同上了一堂疫情防控思政大课,这可能是中国参与人数最多的一次网课了。 据统计…

最“赚钱”编程语言出炉,惊到我了.....

Stack Overflow 发布了 2023 年开发者调查报告,据称共计超过 9 万名开发者参与了此次调查。 完整报告包含了受访开发者画像,以及关于开发技术、AI、职业、社区等方面的内容。本文主要介绍关于开发技术和 AI 的部分。 懒人目录: 最流行编程语…

使用AI轻松搞定UI设计;a16z:快速高效使用LLM构建应用程序;AI时代99%软件都会消失;豆瓣9.3的经典Python入门书 | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🤖 Indeed:美国5月份生成式AI职位发布量增长 20% Indeed 是美国就业门户网站,根据其最新发布的数据显示&#xff…

投递简历用什么邮箱

一、关于邮箱 1、给应聘单位发送求职邮件时,最好用比较正式的邮箱,比如TOMVIP邮箱, 用户名最好选用你的英文名,或者你的英文名姓氏,简单大方,HR很容易记下,且会觉得你很专业、正式。 2、设置签名…

关于简历

简历格式 这个图是我以前在网上找的,但不记得出处了。如有侵权,请联系我。

怎么从手机上下载应聘简历模板?个人简历如何从手机做?

​当我们找工作时,一般都会先准备一份应聘简历,当我们投递简历或者面试时都会用到,那么如果想要用手机来制作一份简历时,该如何操作呢?想要制作简历,那么首先要下载一份自己喜欢的简历模板,会让…

网页版简历

最近在学习HTML和CSS,就想用他们写点东西。想了想想不出(想不想谁知道….)。最后瞟了一眼我的高大上的简历(你就吹吧……),就是决定你了。 咳咳……。下面就和大家分享一下我用html和css写的简历。欢迎提出…