net.eval()和net.trasin()的用法

当构建神经网络使用到dropout层等时,网络的正向传播后反向传播神经元的系数会有所不同,因此需要用.eval()和.train()来指定模型方向。

net.train()

  • 作用:将模型设置为训练模式。
  • 影响:
    • 启用 Dropout 层:Dropout 会随机丢弃一部分神经元(以一定概率),用于防止过拟合。
    • 启用 Batch Normalization 的训练行为:BatchNorm 会使用当前批次的均值和方差来进行归一化,并更新其内部的运行均值和方差。
  • 使用场景:
    • 在训练模型时需要调用 net.train()(通常是默认状态)。
    • 确保模型的所有层都处于训练模式。

net.eval()

  • 作用:将模型设置为评估模式。
  • 影响:
    • 禁用 Dropout 层:Dropout 不会随机丢弃神经元,而是使用所有神经元的完整输出。
    • 禁用 Batch Normalization 的训练行为:BatchNorm 会使用训练过程中记录的全局均值和方差,而不是当前批次的均值和方差。
  • 使用场景:
    • 在验证或测试模型时需要调用 net.eval()。
    • 确保模型的所有层都处于评估模式,避免因 Dropout 或 BatchNorm 的行为导致结果不一致。
      在这里插入图片描述

示例

model.train()  # 切换到训练模式  
for data, target in train_loader:  optimizer.zero_grad()  output = model(data)  loss = loss_fn(output, target)  loss.backward()  optimizer.step()
model.eval()  # 切换到评估模式  
with torch.no_grad():  # 禁用梯度计算(加速推理并节省内存)  for data, target in test_loader:  output = model(data)  # 计算验证或测试指标

torch.no_grad() 与 eval() 的区别

  • torch.no_grad():禁用梯度计算,用于加速推理和节省内存,但不会改变模型的模式。
  • eval():切换模型到评估模式,但不会禁用梯度计算。
  • 通常在评估时,两者会结合使用

一个完整示例

import torch  
import torch.nn as nn  
import torch.optim as optim  # 定义一个简单的模型  
class SimpleModel(nn.Module):  def __init__(self):  super(SimpleModel, self).__init__()  self.fc = nn.Linear(10, 1)  self.dropout = nn.Dropout(0.5)  self.bn = nn.BatchNorm1d(10)  def forward(self, x):  x = self.bn(x)  x = self.dropout(x)  x = self.fc(x)  return x  # 初始化模型、损失函数和优化器  
model = SimpleModel()  
loss_fn = nn.MSELoss()  
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 训练阶段  
model.train()  # 切换到训练模式  
for epoch in range(5):  for data, target in train_loader:  # 假设 train_loader 已定义  optimizer.zero_grad()  output = model(data)  loss = loss_fn(output, target)  loss.backward()  optimizer.step()  # 验证阶段  
model.eval()  # 切换到评估模式  
with torch.no_grad():  # 禁用梯度计算  for data, target in val_loader:  # 假设 val_loader 已定义  output = model(data)  val_loss = loss_fn(output, target)  print(f"Validation Loss: {val_loss.item()}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/498192.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构与算法-目录

音视频流媒体开发-目录 iOS知识点-目录 Android-目录 Flutter-目录 数据结构与算法-目录 恋上数据结构与算法一 【恋上数据结构与算法一】(一)复杂度 【恋上数据结构与算法一】(二)动态数组 【恋上数据结构与算法一】(三)链表 【恋上数据结构与算法一】(四)栈 【恋上数据结构与…

算法训练营Day28 | leetcode 122.买卖股票的最佳时机II 55.跳跃游戏 45.跳跃游戏II

122.买卖股票的最佳时机II 本题首先要清楚两点: 只有一只股票!当前只有买股票或者卖股票的操作 想获得利润至少要两天为一个交易单元。 贪心算法 这道题目可能我们只会想,选一个低的买入,再选个高的卖,再选一个低…

da白话讲深度学习-卷积网络

卷积神经网络(CNN)是指至少在网络的一层中使用卷积运算来代替一般的矩阵乘法运算的神经网络,因此名为为卷积神经网络(对于神经网络的发展与类型,可以学习站内的相关文章) 1.什么是卷积? 既然是卷积神经网络&#xff…

搭建android开发环境 android studio

1、环境介绍 在进行安卓开发时,需要掌握java,需要安卓SDK,需要一款编辑器,还需要软件的测试环境(真机或虚拟机)。 早起开发安卓app,使用的是eclipse加安卓SDK,需要自行搭建。 目前开…

12.30 linux 文件操作,磁盘分区挂载

ubuntu 在linux 对文件的相关操作【压缩,打包,软链接,文件权限】【head,tail,管道符,通配符,find,grep,cut等】脑图-CSDN博客 1.文件操作 在家目录下创建目录文件&#…

Python Celery快速入门教程

Celery 是一个简单、灵活且可靠的分布式任务队列框架,用于处理大量的异步任务、定时任务等。它允许你将任务发送到消息队列,然后由后台的工作进程(worker)来执行这些任务,并且支持多种消息中间件,如 Rabbit…

Unity WebGL 部署IIS

Unity WebGL 部署IIS iis添加网站WebGL配置文件WebGL Gzip模式浏览器加载速度优化iis添加网站 第一步在配置好IIS并且添加网站 WebGL配置文件 在web包Build文件夹同级创建web.config文件 web.config文件内容 <?xml version="1.0" encoding="UTF-8"?…

基于西湖大学强化学习课程的笔记

放在前面 课程链接 2024年12月30日 前言&#xff1a;强化学习有原理部分的学习&#xff0c;也有与实践相关的编程部分。我认为实践部分应该是更适合我的&#xff0c;不过原理部分也很重要&#xff0c;我目前是准备先过一过原理。 应该花多少时间学习这部分呢&#xff1f; 但是这…

CannotRetrieveUpdates alert in disconnected OCP 4 cluster解决

环境&#xff1a; Red Hat OpenShift Container Platform (RHOCP) 4 问题&#xff1a; Cluster Version Operator 不断发送警报&#xff0c;表示在受限网络/断开连接的 OCP 4 集群中无法接收更新。 在隔离的 OpenShift 4 集群中看到 CannotRetrieveUpdates 警报&#xff1a; …

Redis--持久化策略(AOF与RDB)

持久化策略&#xff08;AOF与RDB&#xff09; 持久化Redis如何实现数据不丢失&#xff1f;RDB 快照是如何实现的呢&#xff1f;执行时机RDB原理执行快照时&#xff0c;数据能被修改吗&#xff1f; AOF持久化是怎么实现的&#xff1f;AOF原理三种写回策略AOF重写机制 RDB和AOF合…

【数据结构】链表(1):单向链表和单向循环链表

链表 链表是一种经典的数据结构&#xff0c;它通过节点的指针将数据元素有序地链接在一起&#xff0c;在链表中&#xff0c;每个节点存储数据以及指向其他节点的指针&#xff08;或引用&#xff09;。链表具有动态性和灵活性的特点&#xff0c;适用于频繁插入、删除操作的场景…

开源电子书转有声书整合包ebook2audiobookV2.0.0

ebook2audiobook&#xff1a;将电子书转换为有声书的开源项目 项目地址 GitHub - DrewThomasson/ebook2audiobook 整合包下载 更新至v2.0.0 https://pan.quark.cn/s/22956c5559d6 修改:页面已转为中文 项目简介 ebook2audiobook 是一个开源项目&#xff0c;它能够将电子…

NSSCTFpwn刷题

[SWPUCTF 2021 新生赛]nc签到 打开附件里面内容 import osart (( "####!!$$ ))#####!$$ ))(( ####!!$:(( ,####!!$: )).###!!$:##!$:#!!$!# #!$: #$#$ #!$: !!!$:\ "!$: /\ !: /"\ : /"-."-/\\\-."//.-"…

java里classpath都包含哪些范围?

什么是 classpath &#xff1f; classpath 等价于 main/java main/resources 第三方jar包的根目录 「引」SpringBoot中的classpath都包含啥

Docker+Portainer 离线安装

1. Docker安装 步骤一&#xff1a;官网下载 docker 安装包 步骤二&#xff1a;解压安装包; tar -zxvf docker-24.0.6.tgz 步骤三&#xff1a;将解压之后的docker文件移到 /usr/bin目录下; cp docker/* /usr/bin/ 步骤四&#xff1a;将docker注册成系统服务; vim /etc/sy…

#渗透测试#红蓝攻防#红队打点web服务突破口总结01

免责声明 本教程仅为合法的教学目的而准备&#xff0c;严禁用于任何形式的违法犯罪活动及其他商业行为&#xff0c;在使用本教程前&#xff0c;您应确保该行为符合当地的法律法规&#xff0c;继续阅读即表示您需自行承担所有操作的后果&#xff0c;如有异议&#xff0c;请立即停…

Java:190 基于SSM的药品管理系统

作者主页&#xff1a;舒克日记 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 系统的用户分管理员和销售两个角色的权限子模块。 管理员统计药品销售量&#xff0c;可以导出药品出入库记录&#xff0c;管理药品以及报损信息。 销…

Quo Vadis, Anomaly Detection? LLMs and VLMs in the Spotlight 论文阅读

文章信息&#xff1a; 原文链接&#xff1a;https://arxiv.org/abs/2412.18298 Abstract 视频异常检测&#xff08;VAD&#xff09;通过整合大语言模型&#xff08;LLMs&#xff09;和视觉语言模型&#xff08;VLMs&#xff09;取得了显著进展&#xff0c;解决了动态开放世界…

VUE echarts 教程二 折线堆叠图

VUE echarts 教程一 折线图 import * as echarts from echarts;var chartDom document.getElementById(main); var myChart echarts.init(chartDom); var option {title: {text: Stacked Line},tooltip: {trigger: axis},legend: {data: [Email, Union Ads, Video Ads, Dir…

001__VMware软件和ubuntu系统安装(镜像)

[ 基本难度系数 ]:★☆☆☆☆ 一、Vmware软件和Ubuntu系统说明&#xff1a; a、Vmware软件的说明&#xff1a; 官网&#xff1a; 历史版本&#xff1a; 如何下载&#xff1f; b、Ubuntu系统的说明&#xff1a; 4、linux系统的其他版本&#xff1a;红旗(redhat)、dibian、cent…