Pytorch从零开始实战01

Pytorch从零开始实战——MNIST手写数字识别

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——MNIST手写数字识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 可视化展示

环境准备

本系列基于Jupyter notebook,使用Python3.7.12,Pytorch1.7.0+cu110,torchvision0.8.0,需读者自行配置好环境且有一些深度学习理论基础。

导入需要用到的包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import random
from time import time
import random
import numpy as np
import pandas as pd
import datetime
import gc
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type=‘cuda’)

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本次实战使用MNIST数据集,这是一个包含了手写数字的灰度图像的数据集,每个图像都是28x28像素大小,并且标记了相应的数字,也是很多计算机视觉初学者第一个使用的数据集。

导入训练集与测试集,使用torchvision.datasets可以在线下载很多常见数据集,只需要将后面参数设置download=True即可直接下载,train=True为训练集,train=False为测试集

# 导入训练集和测试集
train_data = torchvision.datasets.MNIST('data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.ToTensor(),download=True)

定义一个函数,随机查看5张图片

# 随机展示5个图片 data = torchvision.datasets....  需要接受tensor格式的对象
def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(train_data)

在这里插入图片描述

使用DataLoder将它按照batch_size批量划分,并将训练集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

模型选择

由于数据集较为简单,所以本次实验使用简单的卷积神经网络。

第一次卷积和池化:
self.conv1 是第一个卷积层,将输入特征图的通道数从1增加到32,同时使用3x3的卷积核进行卷积。由于没有填充(padding)操作,卷积后的特征图大小减小为原来的大小减2(28x28 -> 26x26)。
self.pool1 是第一个最大池化层,将特征图的大小减半,从26x26变为13x13。
第二次卷积和池化:
self.conv2 是第二个卷积层,将输入特征图的通道数从32增加到64,同样使用3x3的卷积核进行卷积。由于没有填充操作,卷积后的特征图大小再次减小为原来的大小减2(13x13 -> 11x11)。
self.pool2 是第二个最大池化层,将特征图的大小再次减半,从11x11变为5x5。
全连接层:
在进入全连接层之前,需要将最后一个池化层的输出拉平成一个一维向量。这是通过 torch.flatten(x, start_dim=1) 完成的,它将5x5x64的三维张量转换为长度为5x5x64 = 1600的一维向量。
然后,self.fc1 是第一个全连接层,将1600个输入特征映射到64个输出特征。
最后进行10分类输出结果。

num_classes = 10 # 10分类
class Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1) # 拉平x = F.relu(self.fc1(x))x = self.fc2(x)return x

将模型转移到GPU中,并使用summary查看模型

from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

在这里插入图片描述

模型训练

定义损失函数、学习率、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.01
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

定义训练函数,返回一个epoch的模型的准确率和损失

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

定义测试函数,与训练函数类似,只是停止梯度更新,节省计算内存消耗

def test (dataloader, model, loss_fn):size = len(dataloader.dataset) num_batches = len(dataloader)         test_loss, test_acc = 0, 0with torch.no_grad():for X, target in dataloader:X, target = X.to(device), target.to(device)pred = model(X)loss = loss_fn(pred, target)test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

开始训练,一共进行了5轮epoch,最后在训练集准确率可达97.7%,测试集准确率可达98.1%

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")

可视化展示

使用matplotlib进行训练、测试的可视化

plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/126599.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CH06_第一组重构(下)

封装变量(Encapsulate Variable | 132) 曾用名:自封装字段(Self-Encapsulate Field) 曾用名:封装字段(Encapsulate Field) let defaultOwner {firstName: "Martin", la…

用半天时间从零开始复习前端之html

目录 前言 科班生的标配:半天听完一门标记型语言 准备工作 webstorm2022 webstrom 第一个html页面 body h系列标签 行标签和块标签 列表标签 表格标签(另起一篇) 万能的input 1.快速生成多个标签 2.同时选中多个 前言 科班生的标…

系统报错“由于找不到msvcp140.dll无法继续执行代码”的处理方法

我在使用电脑时,突然发现了一个错误提示:“无法启动程序,因为找不到msvcp140.dll文件”。这让我非常困惑,因为我确定这个文件应该存在于我的电脑上。但是电脑依然报错“由于找不到msvcp140.dll无法继续执行代码”,这个…

【周末闲谈】如何利用AIGC为我们创造有利价值?

个人主页:【😊个人主页】 系列专栏:【❤️周末闲谈】 系列目录 ✨第一周 二进制VS三进制 ✨第二周 文心一言,模仿还是超越? ✨第二周 畅想AR 文章目录 系列目录前言AIGCAI写作AI绘画AI视频生成AI语音合成 前言 在此之…

Linux防火墙(iptables)

一、linux的防火墙组成 linux的防火墙由netfilter和iptables组成。用户空间的iptables制定防火墙规则,内核空间的netfilter实现防火墙功能。 netfilter(内核空间)位于Linux内核中的包过滤防火墙功能体系,称为Linux防火墙的“内核…

MHA高可用及故障切换

一、什么是 MHA MHA(MasterHigh Availability)是一套优秀的MySQL高可用环境下故障切换和主从复制的软件。 MHA 的出现就是解决MySQL 单点的问题。 MySQL故障切换过程中,MHA能做到0-30秒内自动完成故障切换操作。 MHA能在故障切换的过程中最大…

Vue中如何实现城市3D分布图

cityfenbu.vue <template><div ><el-card class"seriesmap-box-card"><div slot"header" class"clearfix"><span>城市分布图 (点击可下钻到县)</span></div><div><div class"series-ma…

c语言练习45:模拟实现内存函数memcpy

模拟实现内存函数memcpy 针对内存块&#xff0c;不在乎内存中的数据。 拷贝内容有重叠的话应用memmove 模拟实现&#xff1a; 代码&#xff1a; 模拟实现memcpy #include<stdio.h> #include<assert.h> void* my_memcpy(void* dest, const void* src, size_t num…

【Linux】网络编程网络基础(C++)

目录 一、计算机网络背景 二、认识 "协议" 三、网络协议初识 【3.1】协议分层 【3.2】OSI七层模型 【3.3】TCP/IP五层(或四层)模型 四、网络传输基本流程 【4.1】网络传输流程图 【4.2】数据包封装和分用 五、网络中的地址管理 一、计算机网络背景 【独立…

谷粒商城----缓存与分布式锁

1、缓存使用 为了系统性能的提升&#xff0c;我们一般都会将部分数据放入缓存中&#xff0c;加速访问。而 db 承担数据落盘工作。 哪些数据适合放入缓存&#xff1f;  即时性、数据一致性要求不高的  访问量大且更新频率不高的数据&#xff08;读多&#xff0c;写少&…

The WebSocket session [x] has been closed and no method (apart from close())

在向客户端发送消息时&#xff0c;session关闭了。 不管是单客户端发送消息还是多客户端发送消息&#xff0c;在发送消息之前判断session 是否关闭 使用 isOpen() 方法

Nginx 学习(九)集群概述与LVS工作模式的配置

一 集群 1 概述 通过高速网络将很多服务器集中起来一起提供同一种服务&#xff0c;在客户端看来就像是只有一个服务器可以在付出较低成本的情况下获得在性能、可靠性、灵活性方面的相对较高的收益任务调度是集群系统中的核心技术 2 目的 提高性能。如计算密集型应用&…

【C++进阶】:AVL树(平衡因子)

AVL树 一.概念二.插入1.搜索二叉树2.平衡因子 三.旋转1.更新平衡因子2.旋转1.左单旋2.右单旋3.先右旋再左旋4.先左旋再右旋 四.完整代码 一.概念 二叉搜索树虽可以缩短查找的效率&#xff0c;但如果数据有序或接近有序二叉搜索树将退化为单支树,查找元素相当于在顺序表中搜索元…

2023区块链应用操作员认证(4级)报名来弘博创新

区块链应用操作员&#xff0c;是指运用区块链技术及工具&#xff0c;从事政务、金融、医疗、教育、养老等场景系统应用操作的人员。 腾讯作为广东省第一批公布的社会培训评价组织&#xff0c;可开展职业技能等级认定职业(工种)区块链应用操作员(4-3-2-1级)。 证书含金量 证书是…

Redis 删除策略

文章目录 Redis 删除策略一、过期数据二、数据删除策略1、定时删除2、惰性删除3、定期删除4、删除策略对比 三、逐出算法 Redis 删除策略 一、过期数据 Redis是一种内存级数据库&#xff0c;所有数据均存放在内存中&#xff0c;内存中的数据可以通过TTL指令获取其状态 XX &a…

100万级连接,爱奇艺WebSocket网关如何架构

说在前面 在40岁老架构师 尼恩的读者社区(50)中&#xff0c;很多小伙伴拿到一线互联网企业如阿里、网易、有赞、希音、百度、滴滴的面试资格。 最近&#xff0c;尼恩指导一个小伙伴简历&#xff0c;写了一个《高并发网关项目》&#xff0c;此项目帮这个小伙拿到 字节/阿里/微…

BPPISE数据科学案例框架

本专题共10篇内容&#xff0c;包含淘宝APP基础链路过去一年在用户体验数据科学领域&#xff08;包括商详、物流、性能、消息、客服、旅程等&#xff09;一些探索和实践经验。 在商详页基于用户动线和VOC挖掘用户决策因子带来浏览体验提升&#xff1b;在物流侧洞察用户求助时间与…

Kafka3.0.0版本——增加副本因子

目录 一、服务器信息二、启动zookeeper和kafka集群2.1、先启动zookeeper集群2.2、再启动kafka集群 三、增加副本因子3.1、增加副本因子的概述3.2、增加副本因子的示例3.2.1、创建topic(主题)3.2.2、手动增加副本存储 一、服务器信息 四台服务器 原始服务器名称原始服务器ip节点…

被问到: http 协议和 https 协议的区别怎么办?别慌,这篇文章给你答案

前言 作为软件测试师&#xff0c;大家都知道一些常用的网络协议是我们必须要了解和掌握的&#xff0c;比如 HTTP 协议&#xff0c;HTTPS 协议就是两个使用非常广泛的协议&#xff0c;所以也是面试官问的面试的时候问的比较多的两个协议&#xff1b;因为这两个协议有相似和关联的…

为什么说网络安全是风口行业?是it行业最后的红利?

前言 “没有网络安全就没有国家安全”。当前&#xff0c;网络安全已被提升到国家战略的高度&#xff0c;成为影响国家安全、社会稳定至关重要的因素之一。 网络安全行业特点 1、就业薪资非常高&#xff0c;涨薪快 2021年猎聘网发布网络安全行业就业薪资行业最高人均33.77万&…