深度学习中batch_size

Batch size调整和epoch/iteration的关系

训练数据集总共有1000个样本。若batch_size=10,那么训练完全体样本集需要100次迭代,1次epoch。
训练样本10000条,batchsize设置为20,将所有的训练样本在同一个模型中训练5遍,则epoch=5,batchsize=20, iteration=10000/20=500(即迭代次数表示有多个个batch会过模型)

分布式训练时的batch_size设置:需要将batch_size/num_process。divide the batch size by the number of replicas in order to maintain the overall batch size of 需要的值.[PyTorch:模型训练-分布式训练]

Batch size设置经验

1 一定条件下,batchsize越大训练效果越好。但是batchsize越大,内存gpu消耗越大。梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize ‘变相’ 扩大了8倍,是解决显存受限的一个不错的trick。

[使用梯度累加的batch训练函数]

2 经验法则是,如果mini-batch size加倍,那么学习率就加倍。
在这里插入图片描述

[神经网络中 warmup 策略为什么有效;有什么理论解释么? - 知乎]

在前面“如果mini-batch size加倍,那么学习率就加倍"中,我们的假设在什么时候可能不成立呢?两种情况:

1)在训练的开始阶段,模型权重迅速改变
2)mini-batch size较小,样本方差较大

第一种情况很好理解,可以认为,刚开始模型对数据的“分布”理解为零,或者是说“均匀分布”(当然这取决于你的初始化);在第一轮训练的时候,每个数据点对模型来说都是新的,模型会很快地进行数据分布修正,如果这时候学习率就很大,极有可能导致开始的时候就对该数据“过拟合”,后面要通过多轮训练才能拉回来,浪费时间。当训练了一段时间(比如两轮、三轮)后,模型已经对每个数据点看过几遍了,或者说对当前的batch而言有了一些正确的先验,较大的学习率就不那么容易会使模型学偏,所以可以适当调大学习率。这个过程就可以看做是warmup。那么为什么之后还要decay呢?当模型训到一定阶段后(比如十个epoch),模型的分布就已经比较固定了,或者说能学到的新东西就比较少了。如果还沿用较大的学习率,就会破坏这种稳定性,用我们通常的话说,就是已经接近loss的local optimal了,为了靠近这个point,我们就要慢慢来。

第二种情况其实和第一种情况是紧密联系的。在训练的过程中,如果有mini-batch内的数据分布方差特别大,这就会导致模型学习剧烈波动,使其学得的权重很不稳定,这在训练初期最为明显,最后期较为缓解(所以我们要对数据进行scale也是这个道理)。

说明,在上面两种情况下,我们并不能单纯地成倍增长lr η̂ =kη。要么改变学习率增长方法,要么设法解决上面两个问题。

[神经网络中 warmup 策略为什么有效;有什么理论解释么?]

所以就有了下面的warmup策略和学习率衰减方法:

学习率 warm-up 策略

训练神经网络时,在初始使用较大学习率而后期切换为较小学习率是一种广为使用的做法。而 warmup 策略则与上述 scheme 有些矛盾,warmup 需要在训练最初使用较小的学习率来启动,并很快切换到大学习率而后进行常见的 decay,那么最开始的这一步 warmup 为什么有效呢?

warmup_lr 的初始值是跟训练预料的大小成反比的,也就是说训练预料越大,那么warmup_lr 初值越小,随后增长到我们预设的超参 initial_learning_rate相同的量级,再接下来又通过 decay_rates 逐步下降。

这样做的原因前面已经说明了,还有什么好处?

1)这样可以使得学习率适应不同的训练集合size实验的时候经常需要先使用小的数据集训练验证模型,然后换大的数据集做生成环境模型训练。

2)即使不幸学习率设置得很大,那么也能通过warmup机制看到合适的学习率区间(即训练误差先降后升的关键位置附近),以便后续验证。

原文:https://blog.csdn.net/pipisorry/article/details/109192443

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/495753.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用“NodeMCU”、“红外模块”实现空调控制

项目思路 空调遥控器之所以能够实现对空调的控制,是因为它能够向空调发射出特定的红外信号。从理论上来说,任何能够发射出这种相同红外信号的红外发射器,都可以充当空调遥控器(这也正是手机能够控制多种不同品牌空调的原因所在&a…

Git--tag标签远程管理

目录 一、git 标签 tag管理 1.创建一个轻量级标签 2.创建一个带有附注的标签 3.删除标签 二、标签推送 1.再创建两个分支 2.把多个标签推送到远程 三、标签拉取 四、删除远程标签 1.命令 2.查看远程仓库,标签被删除 3.远程标签删除后本地标签不会消失&a…

43. Three.js案例-绘制100个立方体

43. Three.js案例-绘制100个立方体 实现效果 知识点 WebGLRenderer(WebGL渲染器) WebGLRenderer是Three.js中最常用的渲染器之一,用于将3D场景渲染到网页上。 构造器 WebGLRenderer(parameters : Object) 参数类型描述parametersObject…

【论文阅读笔记】Scalable, Detailed and Mask-Free Universal Photometric Stereo

【论文阅读笔记】Scalable, Detailed and Mask-Free Universal Photometric Stereo 前言摘要引言Task 相关工作方法SDM-UniPS预处理尺度不变的空间光特征编码器像素采样变压器的非局部交互 PS-Mix数据集 实验结果训练细节评估和时间: 消融实验定向照明下的评估没有对…

【LuaFramework】服务器模块相关知识

目录 一、客户端代码 二、本地服务器代码 三、解决服务器无法多次接收客户端消息问题 一、客户端代码 连接本地服务器127.0.0.1:2012端口(如何创本地服务器,放最后说),连接成功后会回调 协议号Connect是101,其他如下…

攻防世界 robots

开启场景 根据提示访问/robots.txt,发现了 f1ag_1s_h3re.php 拼接访问 /f1ag_1s_h3re.php 发现了 flag cyberpeace{d8b7025ed93ed79d44f64e94f2527a17}

gitlab克隆仓库报错fatal: unable to access ‘仓库地址xxxxxxxx‘

首次克隆仓库,失效了,上网查方法,都说是网络代理的问题,各种清理网络代理后都无效,去问同事: 先前都是直接复制的网页url当做远端url,或者点击按钮‘使用http克隆’ 这次对于我来说有效的远端u…

数字IC后端设计实现十大精华主题分享

今天小编给大家分享下吾爱IC社区星球上周十大后端精华主题。 Q1:星主,请教个问题,长tree的时候发现这个scan的tree 的skew差不多400p,我高亮了整个tree的schematic,我在想是不是我在这一系列mux前边打断,设置ignore p…

华为管理变革之道:管理制度创新

目录 华为崛起两大因素:管理制度创新和组织文化。 管理是科学,150年来管理史上最伟大的创新是流程 为什么要变革? 向世界标杆学习,是变革第一方法论 体系之一:华为的DSTE战略管理体系(解决&#xff1a…

CentOS7下的vsftpd服务器和客户端

目录 1、安装vsftpd服务器和ftp客户端; 2、配置vsftpd服务器,允许普通用户登录、下载、上传文件; 3、配置vsftpd服务器,允许anonymous用户登录、下载、上传文件; 4、配置vsftpd服务器,允许root用户登录…

Python Polars快速入门指南:LazyFrames

前文已经介绍了Polars的Dataframe, Contexts 和 Expressions,本文继续介绍Polars的惰性API。惰性API是该库最强大的功能之一,使用惰性API可以设定一系列操作,而无需立即运行它们。相反,这些操作被保存为计算图,只在必要…

【论文阅读笔记】IC-Light

SCALING IN-THE-WILD TRAINING FOR DIFFUSION-BASED ILLUMINATION HARMONIZATION AND EDITING BY IMPOSING CONSISTENT LIGHT TRANSPORT 通过施加一致的光线传输来扩展基于扩散模型的真实场景光照协调与编辑训练 前言摘要引言相关工作基于学习的基于扩散模型的外观和光照操纵光…

学习threejs,THREE.CircleGeometry 二维平面圆形几何体

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.CircleGeometry 圆形…

AWS Certified AI Practitioner 自学考试心得

学习目标: 考取 AWS Certified AI Practitioner 那什么是 AWS Certified AI Practitioner 认证 是基础级的认证 比较简单 — 学习内容: 1. AWS网站自学网站 极客时间免费课程:http://gk.link/a/12sJL 配合极客时间课程的章节测试检验自…

WebRTC服务质量(07)- 重传机制(04) 接收NACK消息

WebRTC服务质量(01)- Qos概述 WebRTC服务质量(02)- RTP协议 WebRTC服务质量(03)- RTCP协议 WebRTC服务质量(04)- 重传机制(01) RTX NACK概述 WebRTC服务质量(…

RAGFlow 基于深度文档理解构建的开源 RAG引擎 - 安装部署

RAGFlow 基于深度文档理解构建的开源 RAG引擎 - 安装部署 flyfish 1. 确保 vm.max_map_count ≥ 262144 这是指要调整Linux内核参数vm.max_map_count,以确保其值至少为262144。这个参数控制着进程可以映射的最大内存区域数量。对于某些应用程序(如Ela…

图神经网络_图嵌入_SDNE

0 提出背景 SDNE:Structural Deep Network Embedding 之前的DeepWalk、LINE、node2vec、struc2vec都使用了浅层结构,浅层模型往往不能捕获高度非线性的网络结构。 SDNE方法使用了多个非线性层来捕获节点的embedding。 1 预备知识 1阶相似度衡量的是…

redis——岁月云实战

单线程序,基于IO多路复用,基于内存和c语言编写,性能高。redis官方命令 1 数据结构 1.1 key的层级 redis的key可以通过冒号(:)来划分层级,如下图mms:company:order,但系统中可以看到有不少没有…

参数名在不同的SpringBoot版本中,处理方案不同

参数名在不同的SpringBoot版本中,处理方案还不同: 在springBoot的2.x版本(保证参数名一致) springBoot的父工程对compiler编译插件进行了默认的参数parameters配置,使得在编译时,会在生成的字节码文件中…

五、Swagger 介绍(Flask+Flasgger的应用)

Swagger 介绍 0. 引言1. Swagger 介绍2. Flasgger 介绍3. Flasgger效果3.1 原始flask代码3.2 转化成Flasgger形式3.3 使用Try it out调试3.4 多个url接口自动生成和调试 4. 使用教程4.1 使用 docstrings 作为规范4.2 使用外部 YAML 文件4.3 使用 Python 字典作为原始规范 5. 和…