【深度学习实战(33)】训练之model.train()和model.eval()

一、model.train(),model.eval()作用?

model.train() 和 model.eval() 是 PyTorch 中的两个方法,用于设置模型的训练模式和评估模式。

model.train() 方法将模型设置为训练模式。在训练模式下,模型会启用 dropout 和 batch normalization 等正则化方法,并且可以计算梯度以进行参数更新,同时还可以追踪梯度计算的图。训练时,均值、方差分别是该批次内数据相应维度的均值与方差

model.eval() 方法将模型设置为评估模式。在评估模式下,模型会禁用 dropout 和 batch normalization 等正则化方法,这样可以保证每次评估的结果是确定的。评估模式下的模型通常用于模型的测试、验证或推理阶段。推理时,均值、方差是基于所有批次的期望计算所得

区分训练模式和评估模式的目的在于保证模型在不同阶段的行为一致性。例如,在训练模式下,模型需要计算并追踪梯度以进行反向传播和参数更新;而在评估模式下,模型不需要计算梯度,只需要给出确定的预测结果。

二、model.train(),model.eval()对dropout产生的影响

使用model.train():有神经元被置零,且比例符合nn.Dropout(0.5)中的0.5设定

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
model.train()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述
使用model.eval():没有神经元置零,nn.Dropout(0.5)被关闭

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
#model.train()
model.eval()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

不使用model.train()和model.eval():有神经元被置零,但是比例非常随机,不符合nn.Dropout(0.5)中的0.5设定
import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
#model.train()
#model.eval()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

三、model.train(),model.eval()对batch normalization产生的影响

使用model.eval():bn中的均值,方差,不发生改变

# 1.导入所需的库:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# 2.定义数据集的转换方法。MNIST数据集是由28x28像素的手写数字组成的图像,将其转换为torch张量并进行标准化处理:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 3.下载MNIST数据集并进行转换:
trainset = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,download=True, transform=transform)# 4.创建数据加载器:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=0)# 5.现在你可以使用trainloader和testloader来获取训练集和测试集的批次数据了。例如,可以使用迭代器遍历数据集中的批次:
#dataiter = iter(trainloader)
#images, labels = dataiter.next()# 上述代码将返回一个批次的图像和对应的标签。可以使用images和labels来进行模型的训练和评估。
# 这就是使用torch库自带的MNIST数据集的基本流程。根据需要,你还可以添加其他的数据处理和增强步骤。# 定义模型
class Model(nn.Module):def __init__(self, hidden_num=32, out_num=10):super().__init__()self.fc1 = nn.Linear(28*28, hidden_num)self.bn  = nn.BatchNorm1d(hidden_num)self.fc2 = nn.Linear(hidden_num, out_num)self.softmax = nn.Softmax()def forward(self, inputs, **kwargs):x = inputs.flatten(1)x = self.fc1(x)print("========= bn之前存的数据: =========")print(self.bn.running_mean, self.bn.running_var)print()print("========= 当前 Batch 的数据: =========")x_mean = torch.mean(x,0)x_variance = torch.mean((x - x_mean)*(x - x_mean),0)print(x_mean, x_variance)print()print("========= torch官方计算之后的bn新数据: =========")x = self.bn(x)print(self.bn.running_mean, self.bn.running_var)print()# x = self.dropout(x)x = self.fc2(x)x = self.softmax(x)return xtorch.manual_seed(1)
model = Model()
#model.train()
model.eval()
for img, label in trainloader:label = nn.functional.one_hot(label.flatten(), 10)out = model(img)break

在这里插入图片描述
使用model.train():bn中的均值,方差,通过滑动平均地方式发生改变,

torch.manual_seed(1)
model = Model()
model.train()
#model.eval()
for img, label in trainloader:label = nn.functional.one_hot(label.flatten(), 10)out = model(img)break

在这里插入图片描述
不使用model.train()和model.eval():默认bn中的均值,方差,通过滑动平均地方式发生改变,
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/321318.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用双指针解决问题题集(二)

1. 有效三角形的个数 给定一个包含非负整数的数组 nums ,返回其中可以组成三角形三条边的三元组个数。 示例 1: 输入: nums [2,2,3,4] 输出: 3 解释:有效的组合是: 2,3,4 (使用第一个 2) 2,3,4 (使用第二个 2) 2,2,3 示例 2: 输入: nums [4,2,3,4] 输出: 4 题解&a…

set_input_delay的理解

1,set_input_delay约束理解 input_delay是指输入的数据到达FPGA的pad引脚时相对于时钟边沿的延迟有多大,单位是ns,数值可以是正,也可以是负。通过set_input_delay约束告诉编译器输入时钟和输入数据的相位关系。如下图所示假设时钟…

C语言猜数字游戏

用C语言实现猜数字游戏&#xff0c;电脑随机给出一个范围内的数字&#xff0c;用户在终端输入数字&#xff0c;去猜大小&#xff1b;对比数字&#xff0c;电脑给出提示偏大还是偏小&#xff1b;不断循环&#xff0c;直到正确 #include <stdio.h> #include <time.h>…

C++ Primer 总结索引 | 第十四章:重载运算与类型转换

1、C语言定义了 大量运算符 以及 内置类型的自动转换规则 当运算符 被用于 类类型的对象时&#xff0c;C语言允许我们 为其指定新的含义&#xff1b;也能自定义类类型之间的转换规则 例&#xff1a;可以通过下述形式输出两个Sales item的和&#xff1a; cout << item1 …

快速找出存(不存在)在某个(或多个)文件的文件夹

首先&#xff0c;需要用到的这个工具&#xff1a; 度娘网盘 提取码&#xff1a;qwu2 蓝奏云 提取码&#xff1a;2r1z 想要找出有下面这个文件存在的文件夹 切换到批量文件复制版块&#xff0c;快捷键Ctrl5 右侧&#xff0c;搜索添加 选定范围&#xff0c;勾选搜索文件夹、包…

重生奇迹mu套装大全

1.战士 汉斯的皮套装&#xff1a;冰之指环,皮护腿,皮盔,皮护手,皮靴,皮铠,流星槌 汉斯的青铜套装&#xff1a;青铜护腿,青铜靴,青铜铠 汉斯的翡翠套装&#xff1a;雷之项链,翡翠护腿,翡翠盔,翡翠铠,远古之盾 汉斯的黄金套装&#xff1a;火之项链,黄金护腿,黄金护手,黄金靴,…

Skywalking数据持久化与自定义链路追踪

学习本篇文章之前首先要了解一下Sky walking的基础知识 分布式链路追踪工具Skywalking详解 一&#xff0c;Sky walking数据持久化 Sky walking提供了es&#xff0c;MySQL等数据持久化方案&#xff0c;默认使用h2基于内存的数据库&#xff0c;重启之后数据即会丢失。 在实际工…

renren-fast开源快速开发代码生成器

简介 renrenfast框架介绍 renren-fast是一个轻量级的Spring Boot快速开发平台&#xff0c;能快速开发项目并交付.完善的XSS防范及脚本过滤&#xff0c;彻底杜绝XSS攻击实现前后端分离&#xff0c;通过token进行数据交互 使用流程 项目地址 https://gitee.com/renrenio/ren…

k8s部署skywalking(helm)

官方文档 官方文档说明&#xff1a;Backend setup | Apache SkyWalking官方helm源码&#xff1a;apache/skywalking-helm官方下载&#xff08;包括agent、apm&#xff09;:Downloads | Apache SkyWalking 部署 根据官方helm提示&#xff0c;选择你自己部署的方式&#xff0c…

【LAMMPS学习】八、基础知识(5.8)LAMMPS 中热化 Drude 振荡器教程

8. 基础知识 此部分描述了如何使用 LAMMPS 为用户和开发人员执行各种任务。术语表页面还列出了 MD 术语&#xff0c;以及相应 LAMMPS 手册页的链接。 LAMMPS 源代码分发的 examples 目录中包含的示例输入脚本以及示例脚本页面上突出显示的示例输入脚本还展示了如何设置和运行各…

SpringBoot中HandlerInterceptor拦截器的构建详细教程

作用范围&#xff1a;拦截器主要作用于Spring MVC的DispatcherServlet处理流程中&#xff0c;针对进入Controller层的请求进行拦截处理。它基于Java的反射机制&#xff0c;通过AOP&#xff08;面向切面编程&#xff09;的思想实现&#xff0c;因此它能够访问Spring容器中的Bean…

如何解决3D模型变黑或贴图不显示的问题---模大狮模型网

在进行3D建模和视觉渲染时&#xff0c;经常会遇到模型表面变黑或贴图不显示的问题&#xff0c;这可能严重影响最终视觉效果的质量。这些问题通常与材质设置、光照配置或文件路径错误有关。本文将探讨几种常见原因及其解决方法&#xff0c;帮助3D艺术家和开发者更有效地处理这些…

TinyXML-2介绍

1.简介 TinyXML-2 是一个简单、小巧的 C XML 解析库&#xff0c;它是 TinyXML 的一个改进版本&#xff0c;专注于易用性和性能。TinyXML-2 用于读取、修改和创建 XML 文档。它不依赖于外部库&#xff0c;并且可以很容易地集成到项目中。 tinyXML-2 的主要特点包括&#xff1a…

回归预测 | Matlab实现基于CNN-SE-Attention-ITCN多特征输入回归组合预测算法

回归预测 | Matlab实现基于CNN-SE-Attention-ITCN多特征输入回归组合预测算法 目录 回归预测 | Matlab实现基于CNN-SE-Attention-ITCN多特征输入回归组合预测算法预测效果基本介绍程序设计参考资料 预测效果 基本介绍 【模型简介】CNN-SE_Attention结合了卷积神经网络&#xff…

武汉星起航:精准布局,卓越服务——运营交付团队领跑亚马逊

在全球电商浪潮中&#xff0c;亚马逊平台以其独特的商业模式和全球化的市场布局&#xff0c;吸引了无数商家和创业者的目光。在这个充满机遇的市场中&#xff0c;武汉星起航电子商务有限公司凭借其专业的运营交付团队&#xff0c;以其独特的五对一服务体系和精准的战略布局&…

Azure AKS日志查询KQL表达式

背景需求 Azure&#xff08;Global&#xff09; AKS集群中&#xff0c;需要查询部署服务的历史日志&#xff0c;例如&#xff1a;我部署了服务A&#xff0c;但服务A的上一个版本Pod已经被杀掉由于版本的更新迭代&#xff0c;而我在命令行中只能看到当前版本的pod日志&#xff…

Git推送本地项目到gitee远程仓库

Git 是一个功能强大的分布式版本控制系统&#xff0c;它允许多人协作开发项目&#xff0c;同时有效管理代码的历史版本。开发者可以克隆一个公共仓库到本地&#xff0c;进行更改后将更新推送回服务器&#xff0c;或从服务器拉取他人更改&#xff0c;实现代码的同步和版本控制。…

普洱茶泡多少茶叶才算淡茶?

普洱茶淡茶一般放几克茶叶&#xff0c;品深茶官网根据多年专业研究与实践结果&#xff0c;制定了淡茶冲泡标准。在冲泡普洱茶淡茶时&#xff0c;茶叶的投放量是关键因素之一。淡茶冲泡标准旨在保持茶汤的清爽口感&#xff0c;同时充分展现普洱茶的独特风味。 根据《品深淡茶冲…

题目:吃奶酪

问题描述&#xff1a; 解题思路&#xff1a; 枚举每种吃奶酪顺序&#xff0c;并计算其距离&#xff0c;选择最小的距离即答案。v数组&#xff1a;记录顺序。 注意点&#xff1a;1. 每次用于min的s需要重置为0。 2. 实数包括小数&#xff0c;所以结构体内x,y为double类型。 3. 第…

C++变量的作用域与存储类型

一 变量的作用域和存储类型 1 变量的作用域(Scope) 指在源程序中定义变量的位置及其能被读写访问的范围分为局部变量(Local Variable)和全局变量(Global Variable) 1&#xff09;局部变量(Local Variable) 在语句块内定义的变量 形参也是局部变量 特点&#xff1a; 生存期是…