【新手适用】手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN二: 如何训练模型,内附详细损失、准确率、均值计算

手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN(新手适用)一: model.py:创建模块-CSDN博客

 从零开始实现一个基于Pytorch的卷积神经网络 - 知乎

目录

1 设备device定义

2 训练模型定义

3 开始训练

3.1 step、batchsize和数据集中图片数的关系

3.2 关于类型转换

4 保存模型

5 训练过程可视化

5.1 损失

5.1.1 损失的计算

 5.1.2 平均损失的计算

5.2 准确率

5.2.1 准确率计算

5.2.2 平均准确率计算

 6. train.py完整代码

7. 训练结果


设备device定义

通过torch.device()来指定使用的设备device,然后通过.to()方法将模型和数据放到指定的设备上,这样我们就可以通过定义device来指定是在cpu还是显卡上进行训练了,而且在多显卡的情况下也可以指定使用其中的某一张显卡进行训练。

torch.cuda.is_available()可以判断本设备是否支持CUDA,如果支持就返回True,不支持就返回False。这个函数可以让其自动判断是否支持CUDA加速并自动选择设备。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

训练模型定义

  1. 初始化和导入模型
  2. 定义超参数、数据集和DataLoader
  3. 定义损失函数loss function和优化器optimizer
  4. 启用梯度:torch.set_grad_enabled(True)

代码如下,其中的具体解释可看知乎原文。

# 训练数据
import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data# 1. 导入模型文件并且定义
from model import LeNet
model = LeNet()
# 2. 定义参数:轮数,批次和学习率
Epoch = 5
batch_size = 64
lr = 0.001
# 3. 获取训练数据
train_data = torchvision.datasets.MNIST(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download=False)
#定义train data的数据集
train_loader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# 4. 定义损失函数、优化器
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 5. 梯度计算
torch.set_grad_enabled(True)
#启用Batch Normalization层和Dropout层
model.train()

model.train()方法:该方法用于启用Batch Normalization层和Dropout层。虽然模型中并没有这两层,但是我们不妨将其加上,并作为一个习惯,以免在真正需要时忘记。

3 开始训练

1) 获得DataLoader中的数据x和标签y

2) 将优化器的梯度清零

3) 将数据送入模型中获得预测的结果y_pred

4) 将标签和预测结果送入损失函数获得损失

5) 将损失值反向传播

6) 使用优化器对模型的参数进行更新

# 训练
for epoch in range(Epoch):for step,data in enumerate(train_data):# 取出data中的数据和标签x,y=data# 优化器梯度清零optimizer.zero_grad()# 计算预测值y_pred = model(x.to(device,torch.float))# 计算损失loss = loss_function(y_pred, y.to(device, torch.long))# 梯度更新loss.backward()# 优化器更新模型参数optimizer.step()

3.1 step、batchsize和数据集中图片数的关系

一个 epoch 表示将训练数据集中的所有样本都用于训练一次。步数(steps)表示在一个 epoch 中所执行的批次数量。

  • step的计算方法:将总样本数除以批次大小。

我们在前面定义的batchsize是64且丢弃最后一批,手写数字数据集中有6万张图片,60000/64=937.5,故每个epoch中有step=937。

3.2 关于类型转换

# 计算预测值y_pred = model(x.to(device,torch.float))# 计算损失loss = loss_function(y_pred, y.to(device, torch.long))

两行中都涉及到了类型转换。在深度学习中,通常希望输入数据和模型参数的数据类型是一致的,这样可以避免类型不匹配的错误,并且可以更有效地利用硬件加速器(如 GPU)进行计算。

  1. 第一行把y_pred输入数据转换为 torch.float 类型的目的是确保模型接收到的数据类型与模型参数的数据类型匹配。在很多情况下,神经网络模型的参数通常是浮点数类型(float),因此将输入数据转换为 torch.float 类型可以确保与模型参数的类型匹配。
  2. 在深度学习中,通常使用整数类型(如 torch.long)来表示类别标签或离散值。许多损失函数(例如交叉熵损失函数)在计算损失时需要模型输出的预测值和真实标签值具有相同的数据类型。

4 保存模型

把模型的定义和参数全保存在一个文件中,后面可以直接使用训练好的权重文件。

torch.save(model, './LeNet.pkl')

训练过程可视化

  • 查看训练过程中的损失和准确率等等过程参数。

可以在每隔一定的step后输出当前损失和准确率的平均值。MNIST的训练集共有六万张图像,而我们的batch_size是64且丢弃最后一批,因此在每个Epoch中有937个step,实际训练59968张图像。可以每迭代100次后输出当前Epoch的损失和准确率的平均值,并输出当前处在哪一次Epochstep

5.1 损失

5.1.1 损失的计算

对每次计算产生的损失进行相加,把结果放在running_loss中,因此,我们需要在反向传播后添加一个累加操作。由于loss在我们之前定义的设备上,因此我们需要获得loss的值,然后将其传回cpu并转换为float类型,即:

 # 计算损失loss = loss_function(y_pred, y.to(device, torch.long))# 梯度更新loss.backward()# 累加梯度running_loss += float(loss.data.cpu())

 5.1.2 平均损失的计算

如果当前step是100次,则计算一次平均损失。由于step从0开始,所以需要+1.

loss_avg = running_loss / (step + 1)

5.2 准确率

5.2.1 准确率计算

acc = 预测正确的数目 / 总数目

y_pred是一个二维的张量,其形状为[batch_size, num_classes],在这边channel是10,即十个数字。如果我们将batch中的任意一行提取出来就获得了一个10维的向量,向量里的每个数代表与其下标所对应的标签的相关性,相关性越大则代表越有可能是这个数字。

因此,我们需要获得这个向量中最大数的下标,在pytorch中,我们可以用.argmax(dim)方法实现,输入维度dim,即可返回这个维度下最大值的下标,即pred = y_pred.argmax(dim=1)。在此基础上,我们就可以计算其预测正确的数量了;

先获取pred的值,即模型预测的图片的类别,然后传回cpu,用==筛选模型预测的标签和图片标注的标签相等的个数,然后使用.sum()相加统计预测正确的个数

acc保存的即是模型预测正确的个数。

acc += (pred.data.cpu()==y.data).sum()

5.2.2 平均准确率计算

接下来先对steps进行统计,设置每轮数中的每100个steps计算一次平均损失、准确率等。

平均损失值 = 损失值 / steps

平均准确率 = 预测正确的个数 / 总个数 

 # 判断该step是不是该epoch中的第100步if step%100==99:# 平均损失 = 损失值/stepsloss_avg = running_loss / (step+1)# 平均准确率 = 预测正确的数量/总个数acc_avg = float(acc / ((step + 1) * batch_size))# 输出print('Epoch', epoch + 1, ',step', step + 1, '| Loss_avg: %.4f' % loss_avg, '|Acc_avg:%.4f' % acc_avg)

 6. train.py完整代码

# 训练数据
import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data
# 导入模型文件并且定义
from model import LeNetmodel = LeNet()
# 定义参数:轮数,批次和学习率
Epoch = 5
batch_size = 64
lr = 0.001# 获取训练数据
train_data = torchvision.datasets.MNIST(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download=False)
#定义train data的数据集
train_loader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# 定义损失函数、优化器
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#梯度计算
torch.set_grad_enabled(True)
#启用Batch Normalization层和Dropout层
model.train()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)# 训练
for epoch in range(Epoch):# 定义损失值和准确率running_loss = 0.0acc = 0.0for step,data in enumerate(train_loader):# 取出data中的数据和标签,x为数据y为标签x,y=data# 优化器梯度清零optimizer.zero_grad()# 计算预测值y_pred = model(x.to(device,torch.float))# 计算损失loss = loss_function(y_pred, y.to(device, torch.long))# 梯度更新loss.backward()# 累加梯度running_loss += float(loss.data.cpu())# 取出预测的最大值pred = y_pred.argmax(dim=1)# 统计预测正确的个数acc += (pred.data.cpu()==y.data).sum()# 优化器更新模型参数optimizer.step()# 判断该step是不是该epoch中的第100步if step%100==99:# 平均损失 = 损失值/stepsloss_avg = running_loss / (step+1)# 平均准确率 = 预测正确的数量/总个数acc_avg = float(acc / ((step + 1) * batch_size))# 输出print('Epoch', epoch + 1, ',step', step + 1, '| Loss_avg: %.4f' % loss_avg, '|Acc_avg:%.4f' % acc_avg)# 保存模型
torch.save(model, './LeNet.pkl')

7. 训练结果

可以看到loss呈下降趋势,acc提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/294495.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用filezilla连接Ubuntu22.04虚拟机

获取电脑IP和虚拟机IP ① 在windows下ctrlR再输入cmd,打开指令窗口,输入 ipconfig 虚拟机连接电脑用的是NAT模式,故看VMnet8的IP地址 ② 查看虚拟机IP地址 终端输入 ifconfig 如果没安装,按提示安装net-tools sudo apt install …

物联网学习2、MQTT 发布/订阅模式介绍

MQTT 发布/订阅模式 发布订阅模式(Publish-Subscribe Pattern)是一种消息传递模式,它将发送消息的客户端(发布者)与接收消息的客户端(订阅者)解耦,使得两者不需要建立直接的联系也不…

使用CMake搭建简单的Qt程序

目录结构 代码 CMakeLists.txt: cmake_minimum_required(VERSION 3.15)set(CMAKE_AUTOUIC ON) set(CMAKE_AUTOMOC ON) set(CMAKE_AUTORCC ON)# set the project name project(xxx)# 设置Qt的路径 # 例如 E:/Qt/Qt/aaa/msvc2019_64 # aaa 为Qt的版本号 set(QT_PATH…

【工具-MATLAB】

MATLAB ■ MATLAB-简介■ MATLAB-应用领域■ MATLAB■ MATLAB■ MATLAB■ MATLAB ■ MATLAB-简介 MATLAB是matrix&laboratory两个词的组合,意为矩阵工厂(矩阵实验室) 美国MathWorks公司出品的商业数学软件, MATLAB和Mathematica、Maple并…

高风险IP来自哪里:探讨IP地址来源及其风险性质

在网络安全领域,高风险IP地址是指那些可能涉及恶意活动或网络攻击的IP地址。了解这些高风险IP地址的来源可以帮助网络管理员更好地识别和应对潜在的安全威胁。本文将探讨高风险IP地址的来源及其风险性质,并提供一些有效的应对措施。 风险IP查询&#xf…

物联网技术在数字化工厂中的应用研究——青创智通

工业物联网解决方案-工业IOT-青创智通 随着科技的不断进步和数字化浪潮的推动,物联网(IoT)技术在各个领域中得到了广泛应用。其中,数字化工厂作为现代制造业的重要代表,物联网技术的应用更是为其带来了革命性的变革。…

C# WPF编程-元素绑定

C# WPF编程-元素绑定 将元素绑定到一起绑定表达式绑定错误绑定模式代码创建绑定移除绑定使用代码检索绑定多绑定绑定更新绑定延时 数据绑定是一种关系,该关系告诉WPF从源对象提取一下信息,并用这些信息设置目标对象的属性。目标属性始终是依赖项属性&…

Centos7 elasticsearch-7.7.0 集群搭建,启用x-pack验证 Kibana7.4用户管理

前言 Elasticsearch 是一个分布式、RESTful 风格的搜索和数据分析引擎,能够解决不断涌现出的各种用例。 作为 Elastic Stack 的核心,它集中存储您的数据,帮助您发现意料之中以及意料之外的情况。 环境准备 软件 …

VSCode - 离线安装扩展python插件教程

1,下载插件 (1)首先使用浏览器打开 VSCode 插件市场link (2)进入插件主页,点击右侧的 Download Extension 链接,将离线安装包下载下来(文件后缀为 .vsix) 2,…

GoogleNet神经网络介绍

一、简介 GoogleNet,也称为GoogLeNet,是谷歌工程师设计的一种深度神经网络结构,它在2014年的ImageNet图像识别挑战赛中取得了冠军。该神经网络的设计特点主要体现在其深度和宽度上,通过引入名为Inception的核心子网络结构&#x…

HarmonyOS NEXT应用开发案例——阻塞事件冒泡

介绍 本示例主要介绍在点击事件中,子组件enabled属性设置为false的时候,如何解决点击子组件模块区域会触发父组件的点击事件问题;以及触摸事件中当子组件触发触摸事件的时候,父组件如果设置触摸事件的话,如何解决父组…

IDEA报错,`java.io.NotSerializableException`,解决:一个类只有实现了Serializable接口,它的对象才是可序列化的

问题:IDEA报错,java.io.NotSerializableException 解决办法:在出问题的类中加上implements Serializable,如下所示: 原因:当要将该实体类对象保存至某个地方时(我这里是想将Catalog2Vo保存至R…

小型分布式文件存储系统GoFastDfs应用简介

前言 最近稍微留意了一下各个文件存储系统的协议,发现minio是LGPLV3, 而fastdfs 是GPL3,这些协议其实对于商业应用是一个大坑。故而寻找一些代替品。 go-fastdfs就是其中之一,官网在: go-fastdfs 具体应用 其实可以直接查看官网教程的。 下…

GEE:基于光谱距离方法的变化检测(以滑坡为例)

作者:CSDN @ _养乐多_ 本文将介绍在 Google Earth Engine(GEE)平台上,使用光谱向量距离度量方法进行变化检测的代码。代码中使用哨兵数据的光谱向量,并以检测滑坡为例进行演示。 结果如下图所示, 文章目录 一、参考内容1.1 光谱距离1.2 点积二、代码链接三、完整代码一…

python怎么处理txt

导入文件处理模块 import os 检测路径是否存在,存在则返回True,不存在则返回False os.path.exists("demo.txt") 如果你要创建一个文件并要写入内容 #如果demo.txt文件存在则会覆盖,并且demo.txt文件里面的内容被清空,如…

实现一个Google身份验证代替短信验证

最近才知道公司还在做国外的业务,要实现一个登陆辅助验证系统。咱们国内是用手机短信做验证,当然 这个google身份验证只是一个辅助验证登陆方式。看一下演示 看到了嘛。 手机下载一个谷歌身份验证器就可以 。 谷歌身份验证器,我本身是一个基…

每天五分钟深度学习:神经网络和深度学习有什么样的关系?

本文重点 神经网络是一种模拟人脑神经元连接方式的计算模型,通过大量神经元之间的连接和权重调整,实现对输入数据的处理和分析。而深度学习则是神经网络的一种特殊形式,它通过构建深层次的神经网络结构,实现对复杂数据的深度学习…

python爬虫----了解爬虫(十一天)

🎈🎈作者主页: 喔的嘛呀🎈🎈 🎈🎈所属专栏:python爬虫学习🎈🎈 ✨✨谢谢大家捧场,祝屏幕前的小伙伴们每天都有好运相伴左右,一定要天天…

算法打卡day23

今日任务: 1)39. 组合总和 2)40.组合总和II 3)131.分割回文串 39. 组合总和 题目链接:39. 组合总和 - 力扣(LeetCode) 给定一个无重复元素的数组 candidates 和一个目标数 target ,…

NB-IOT——浅谈NB-IOT及模块测试

浅谈NB-IOT及模块基本使用测试 介绍什么是NB-IOT?NB-IOT的特点 使用准备基本使用 总结 介绍 什么是NB-IOT? NB-IoT,即窄带物联网(Narrowband Internet of Things),是一种低功耗广域物联网(LPW…