pytorch笔记:自动混合精度(AMP)

1 理论部分

1.1 FP16 VS FP32

  • FP32具有八个指数位和23个小数位,而FP16具有五个指数位和十个小数位
  • Tensor内核支持混合精度数学,即输入为半精度(FP16),输出为全精度(FP32)

1.1.1 使用FP16的优缺点

  • 优点
    • FP16需要较少的内存,因此更易于训练和部署大型神经网络,同时还减少了数据移动(同时可以使用更大的batch)
    • 数学运算的运行速度大大降低了
      • NVIDIA提供的Volta GPU的确切数量是:FP16中为125 TFlops,而FP32中为15.7 TFlops(加速8倍)
  • 缺点:
    • 从FP32转到FP16时,必然会降低精度
      • 但有的时候,这个精度的降低可以忽略不计
      • FP16实际上可以很好地表示大多数权重和渐变。
      • ——>拥有存储和使用FP32所需的所有这些额外位只是浪费。
    • 溢出错误
      • 由于FP16的动态范围比FP32位的狭窄很多,因此,在计算过程中很容易出现上溢出和下溢出
      • 溢出之后就会出现"NaN"的问题

1.2 解决上述FP16的问题

1.2.1 混合精度训练

  • 用FP16做储存和乘法,而用FP32做累加避免舍入误差
  • ——>混合精度训练的策略有效地缓解了舍入误差的问题

1.2.2 损失放大(Loss scaling)

  • 即使使用了混合精度训练,还是存在无法收敛的情况
    • 原因是激活梯度的值太小,造成了溢出。
  • ——>通过使用torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的下溢出
    • 只在BP时传递梯度信息使用,真正更新权重时还是要把放大的梯度再unscale回去
      • 反向传播前,将损失变化手动增大2^k倍

        • 因此反向传播时得到的中间变量(激活函数梯度)不会溢出;

      • 反向传播后,将权重梯度缩小2^k倍,恢复正常值。

2 torch.cuda.amp

  • AMP(自动混合精度)的关键词有两个:
    • 自动
      • Tensor的dtype类型会自动变化,框架按需自动调整tensor的dtype,当然有些地方还需手动干预
    • 混合精度
      • 采用不止一种精度的Tensor,torch.FloatTensor和torch.HalfTensor

2.1 Pytorch中不同类型的tensor

类型名称位数
torch.DoubleTensor64bit
torch.LongTensor64bit
torch.FloatTensor(默认)32bit
torch.IntTensor32bit
torch.HalfTensor16bit
torch.BFloat16Tensor16bit
torch.ShortTensor16bit
torch.ByteTensor(无符号)8bit
torch.CharTensor8bit
torch.BoolTensorBoolean

2.2 在AMP上下文中,被自动转化为半精度浮点型的参数:

__matmul__
addbmm
addmm
addmv
addr
baddbmm
bmm
chain_matmul
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
linear
matmul
mm
mv
prelu

2.3 autocast

from torch.cuda.amp import autocast as autocastmodel = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensoroptimizer = optim.SGD(model.parameters(), ...)for input, target in data:optimizer.zero_grad()with autocast():output = model(input)loss = loss_fn(output, target)'''自动混合精度环境包含了前向过程(模型的输出)和loss的计算把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算进入autocast的上下文时,tensor可以是任何类型不需要在model或者input上手工调用.half() ,框架会自动做'''loss.backward()optimizer.step()# 反向传播在autocast上下文之外

 2.4 GradScaler

在2.3的基础上增加,反向传播时增加梯度,以防止下溢出

from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScalermodel = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensoroptimizer = optim.SGD(model.parameters(), ...)scaler = GradScaler()
# 在训练最开始之前实例化一个GradScaler对象for epoch in epochs:for input, target in data:optimizer.zero_grad()with autocast():output = model(input)loss = loss_fn(output, target)'''自动混合精度环境包含了前向过程(模型的输出)和loss的计算把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算进入autocast的上下文时,tensor可以是任何类型不需要在model或者input上手工调用.half() ,框架会自动做'''scaler.scale(loss).backward()# Scales loss. 为了梯度放大,防止下溢出# 代替原来的loss.backward()scaler.step(optimizer)'''scaler.step() 首先把梯度的值unscale回来.如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新(不被破坏)'''scaler.update()'''准备着,看是否要增大scaler'''
  •  scaler的大小在每次迭代中动态的估计
    • 为了尽可能的减少梯度underflow,scaler应该更大
    • 但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。
  • ——>动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值

3 一些tips

  • 为了保证计算不溢出,首先保证人工设定的常数不溢出。如epsilon,INF等
  • Dimension最好是8的倍数:维度是8的倍数,性能最好
  • 涉及sum的操作要小心,容易溢出
    • 比如softmax操作,建议用官方API,并定义成layer写在模型初始化里
  • 如果遇到以下的报错:
    • RuntimeError: expected scalar type float but found c10::Half
    • 需要手动在tensor上调用.float()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/345091.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机组成原理复习笔记

前言 就是按照考试的题型写的总结 非常应试版 题型 一、进制转换 只考 十进制 二进制 十六进制 之间的相互转换 一个个看 (1)十进制 转其他 转二进制:除以2 从小到大取余数(0或1) 转十六进制 : 除以1…

摆脱Jenkins - 使用google cloudbuild 部署 java service 到 compute engine VM

在之前 介绍 cloud build 的文章中 初探 Google 云原生的CICD - CloudBuild 已经介绍过, 用cloud build 去部署1个 spring boot service 到 cloud run 是很简单的, 因为部署cloud run 无非就是用gcloud 去部署1个 GAR 上的docker image 到cloud run 容…

墨雨云间王星越雨中情深

墨雨云间:王星越的雨中情深,吻上萧蘅,宿命之恋在烟雨朦胧的《墨雨云间》中,王星越饰演的角色,以其深邃的眼神和细腻的演技,将一段宿命之恋演绎得淋漓尽致。当镜头聚焦于他与阿狸在雨中的那一幕,…

Django学习三:views业务层中通过models对实体对象进行的增、删、改、查操作。

文章目录 前言一、Django ORM介绍二、项目快速搭建三、操作1、view.pya、增加操作b、删除操作c、修改操作d、查询操作 2、urls.py 前言 上接博文:Django学习二:配置mysql,创建model实例,自动创建数据库表,对mysql数据…

暴雨推出基于英特尔® 至强® 6 能效核处理服务器

随着人工智能技术的快速发展,大模型的应用越来越广泛。据预测,到2024年年底,我国将有5%-8%的企业大模型参数从千亿级跃升至万亿级,算力需求增速将达到320%,这进一步推动了数据中心的持续变革。 超凡性能,绿…

STM32中ADC在cubemx基础配置界面介绍

ADCx的引脚,对应的不同I/O口,可以复用。 Temperature :温度传感器通道。 Vrefint :内部参照电压。 Conversion Trigger: 转换触发器。 IN0 至 IN15,是1ADC1的16个外部通道。本示例中输出连接的是ADC2的IN5通道,所以只勾选IN5.Temperature Sensor Cha…

百度/迅雷/夸克,网盘免费加速,已破!

哈喽,各位小伙伴们好,我是给大家带来各类黑科技与前沿资讯的小武。 之前给大家安利了百度网盘及迅雷的加速方法,详细方法及获取参考之前文章: 刚刚!度盘、某雷已破!速度50M/s! 本次主要介绍夸…

单元测试之CppTest测试框架

目录 1 背景2 设计3 实现4 使用4.1 主函数4.2 测试用例4.2.1 定义4.2.2 实现 4.3 运行 1 背景 前面文章CppTest实战演示中讲述如何使用CppTest库。其主函数如下: int main(int argc, char *argv[]) {Test::Suite mainSuite;Test::TextOutput output(Test::TextOut…

在离线单机或内网环境中快速安装Visual Studio 2022并还原用户设定

20240606 By wdhuag 目录 前言 参考: 在外网环境下载离线安装包 1、在已安装好VS的电脑上用Visual Studio Installer导出配置.vsconfig 2、下载在线安装包VisualStudioSetup_Enterprise_2022.exe到D:\VisualStudio\ 3、使用cmd定位到VisualStudioSetup_Enter…

什么是 AOF 重写?AOF 重写机制的流程是什么?

引言:在Redis中,持久化是确保数据持久性和可恢复性的重要机制之一。除了常见的RDB(Redis Database)持久化方式外,AOF(Append Only File)也是一种常用的持久化方式。AOF持久化通过记录Redis服务器…

使用 Django 和 MQTT 构建实时数据传输应用

文章目录 什么是 MQTT?Django 中的 MQTT结论 在现代的 Web 应用程序开发中,实时数据传输变得越来越重要。MQTT(Message Queuing Telemetry Transport)是一种轻量级的发布/订阅消息传输协议,而 Django 是一个流行的 Pyt…

【Linux】进程5——进程优先级

1.进程优先级 1.1.什么是进程优先级 cpu资源分配的先后顺序,就是指进程的优先权(priority)。优先权高的进程有优先执行权利。配置进程优先权对多任务环境的linux很有用,可以改善系统性能。还可以把进程运行到指定的CPU上&#x…

高考之后第一张大流量卡应该怎么选?

高考之后第一张大流量卡应该怎么选? 高考结束后,选择一张合适的大流量卡对于准大学生来说非常重要,因为假期期间流量的使用可能会暴增。需要综合考虑多个因素,以确保选到最适合自己需求、性价比较高且稳定的套餐。以下是一些建议…

二次规划问题(Quadratic Programming, QP)原理例子

二次规划(Quadratic Programming, QP) 二次规划(Quadratic Programming, QP)是优化问题中的一个重要类别,它涉及目标函数为二次函数并且线性约束条件的优化问题。二次规划在控制系统、金融优化、机器学习等领域有广泛应用。下面详细介绍二次规划问题的原理和求解过程 二…

Nginx的https功能和防盗链

目录 一.HTTPS功能简介 二.https自签名证书 三.防盗链 一.HTTPS功能简介 Web网站的登录页面都是使用https加密传输的,加密数据以保障数据的安全,HTTPS能够加密信息,以免敏感信息被第三方获取,所以很多银行网站或电子邮箱等等安…

postgresql根据某个字段去重获取整行数据

背景:在一些情况下我们需要对数据进行去重统计,group by就行,但是一些特殊情况下我们要根据某个字段去重之后获取非聚合字段的值,这个时候在mysql非严格模式下可以直接执行,但是在严格模式和postgresql里面是直接报错的…

蓝桥杯--跑步计划

问题描述 小蓝计划在某天的日期中出现 11 时跑 55 千米,否则只跑 11 千米。注意日期中出现 11 不仅指年月日也指星期。 请问按照小蓝的计划,20232023 年小蓝总共会跑步锻炼多少千米?例如,55 月 11 日、11 月 1313 日、1111 月 55 日、44 月…

使用 Scapy 库编写 TCP RST 攻击脚本

一、介绍 TCP RST攻击是一种拒绝服务攻击(Denial-of-Service, DoS)类型,攻击者通过伪造TCP重置(RST)包,中断目标主机与其他主机之间的TCP连接。该攻击利用了TCP协议中的重置机制,强制关闭合法的…

数据结构:旋转数组

方法1 &#xff08;三次逆置法&#xff09;&#xff1a; void reverse(int* nums, int start, int end) {while (start < end) {int temp nums[start];nums[start] nums[end];nums[end] temp;start;end--;} }void rotate(int* nums, int numsSize, int k) {k k % numsS…

Android14 WMS-窗口绘制之relayoutWindow流程(一)-Client端

Android14 WMS-窗口添加流程(一)-Client端-CSDN博客 Android14 WMS-窗口添加流程(二)-Server端-CSDN博客 经过上述两个流程后&#xff0c;窗口的信息都已经传入了WMS端。 1. ViewRootImpl#setView 在窗口添加流程(一)中&#xff0c;有这个方法&#xff1a; http://aospxref…