Pytorch自动混合精度的计算:torch.cuda.amp.autocast

1 autocast介绍

1.1 什么是AMP?

默认情况下,大多数深度学习框架都采用32位浮点算法进行训练。2017年,NVIDIA研究了一种用于混合精度训练的方法,该方法在训练网络时将单精度(FP32)与半精度(FP16)结合在一起,并使用相同的超参数实现了与FP32几乎相同的精度。

FP16也即半精度是一种计算机使用的二进制浮点数据类型,使用2字节存储。而FLOAT就是FP32。

1.2 autocast作用

torch.cuda.amp.autocast是PyTorch中一种混合精度的技术(仅在GPU上训练时可使用),可在保持数值精度的情况下提高训练速度和减少显存占用。

    def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.float16, cache_enabled : bool = True):

它是一个自动类型转换器,可以根据输入数据的类型自动选择合适的精度进行计算,从而使得计算速度更快,同时也能够节省显存的使用。使用autocast可以避免在模型训练过程中手动进行类型转换,减少了代码实现的复杂性。

在深度学习中,通常会使用浮点数进行计算,但是浮点数需要占用更多的显存,而低精度数值可以在减少精度的同时,减少缓存使用量。因此,对于正向传播和反向传播中的大多数计算,可以使用低精度型的数值,提高内存使用效率,进而提高模型的训练速度。

1.3 autocast原理

autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:

  • 在PyTorch算子调用栈上某一层插入处理函数
  • 在处理函数中对算子的输入进行必要操作

核心代码:autocast_mode.cpp

2 autocast优缺点

PyTorch中的autocast功能是一个性能优化工具,它可以自动调整某些操作的数据类型以提高效率。具体来说,它允许自动将数据类型从32位浮点(float32)转换为16位浮点(float16),这通常在使用深度学习模型进行训练时使用。

2.1 autocast优点

  • 提高性能:使用16位浮点数(half precision)进行计算可以在支持的硬件上显著提高性能,特别是在最新的GPU上。

  • 减少内存占用:16位浮点数占用的内存比32位少,这意味着在相同的内存限制下可以训练更大的模型或使用更大的批量大小。

  • 自动管理autocast能够自动管理何时使用16位浮点数,何时使用32位浮点数,这降低了手动管理数据类型的复杂性。

  • 保持精度:尽管使用了较低的精度,但autocast通常能够维持足够的数值精度,对最终模型的准确度影响不大。

2.2 autocast缺点

  • 硬件要求:并非所有的GPU都支持16位浮点数的高效运算。在不支持或优化不足的硬件上,使用autocast可能不会带来性能提升。

  • 精度问题:虽然在大多数情况下精度损失不显著,但在某些应用中,尤其是涉及到小数值或非常大的数值范围时,降低精度可能会导致问题。

  • 调试复杂性:由于autocast在模型的不同部分自动切换数据类型,这可能会在调试时增加额外的复杂性。

  • 算法限制:某些特定的算法或操作可能不适合在16位精度下运行,或者在半精度下的实现可能还不成熟。

  • 兼容性问题:某些PyTorch的特性或第三方库可能还不完全支持半精度运算。

在实际应用中,是否使用autocast通常取决于特定任务的需求、所使用的硬件以及对性能和精度的权衡。通常,对于大多数现代深度学习应用,特别是在使用最新的GPU时,使用autocast可以带来显著的性能优势。

3 使用示例

3.1 autocast混合精度计算

with autocast(): 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算,并且这里使用了GPU进行加速。使用示例如下:

# 导入相关库
import torch
from torch.cuda.amp import autocast# 定义一个模型
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = torch.nn.Linear(10, 1)def forward(self, x):with autocast():x = self.linear(x)return x# 初始化数据和模型
x = torch.randn(1, 10).cuda()
model = MyModel().cuda()# 进行前向传播
with autocast():output = model(x)# 计算损失
loss = output.sum()# 反向传播
loss.backward()

3.2 autocast与GradScaler一起使用

因为autocast会损失部分精度,从而导致梯度消失的问题,并且经过中间层时可能计算得到inf导致最终loss出现nan。所以我们通常将GradScaler与autocast配合使用来对梯度值进行一些放缩,来缓解上述的一些问题。

from torch.cuda.amp import autocast, GradScalerdataloader = ...
model = Model.cuda(0)
optimizer = ...
scheduler = ...
scaler = GradScaler()  # 新建GradScale对象,用于放缩
for epoch_idx in range(epochs):for batch_idx, (dataset) in enumerate(dataloader):optimizer.zero_grad()dataset = dataset.cuda(0)with autocast():  # 自动混精度logits = model(dataset)loss = ...scaler.scale(loss).backward()  # scaler实现的反向误差传播scaler.step(optimizer)  # 优化器中的值也需要放缩scaler.update()  # 更新scalerscheduler.step()
...

4 可能出现的问题

使用autocast技术进行混精度训练时loss经常会出现'nan',有以下三种可能原因:

  • 精度损失,有效位数减少,导致输出时数据末位的值被省去,最终出现nan的现象。该情况可以使用GradScaler(上文所示)来解决。
  • 损失函数中使用了log等形式的函数,或是变量出现在了分母中,并且训练时,该数值变得非常小时,混精度可能会让该值更接近0或是等于0,导致了数学上的log(0)或是x/0的情况出现,从而出现'inf'或'nan'的问题。这种时候需要针对该问题设置一个确定值。例如:当log(x)出现-inf的时候,我们直接将输出中该位置的-inf设置为-100,即可解决这一问题。
  • 模型内部存在的问题,比如模型过深,本身梯度回传时值已经非常小。这种问题难以解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/193896.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何解决网页中的pdf文件无法下载?pdf打印显示空白怎么办?

问题描述 偶然间,遇到这样一个问题,一个网页上的附件pdf想要下载打印下来,奈何尝试多种办法都不能将其下载下载,点击打印出现的也是一片空白 百度搜索了一些解决方案都不太行,主要解决方案如:https://zh…

Open3D-ML点云语义分割【RandLA-Net】

作为点云 Open3D-ML 实验的一部分,我撰写了文章解释如何使用 Tensorflow 和 PyTorch 支持安装此库。 为了测试安装,我解释了如何运行一个简单的 Python 脚本来可视化名为 SemanticKITTI 的语义分割标记数据集。 在本文中,我将回顾在任何点云上…

uniapp h5发行

前端使用uniapp开发项目完成后,需要将页面打包,生成H5的静态文件,部署在服务器上。 这样通过服务器链接地址,直接可以在手机上点开来访问。 打包全步骤如下: 首先在manifest.json文件中进行基础配置,获取…

【SpringBoot篇】分页查询 | 扩展SpringMvc的消息转换器

文章目录 🛸什么是分页查询🌹代码实现⭐问题🎄解决方法 做了几个项目,发现在这几个项目里面,都实现了分页查询效果,所以就总结一下,方便学习 我们基于黑马程序员的苍穹外卖来讲解分页查询的要点…

黑马程序员微服务 分布式搜索引擎3

分布式搜索引擎03 0.学习目标 1.数据聚合 **聚合(aggregations)**可以让我们极其方便的实现对数据的统计、分析、运算。例如: 什么品牌的手机最受欢迎?这些手机的平均价格、最高价格、最低价格?这些手机每月的销售…

哪种猫罐头比较好?推荐给新手养猫的5款好口碑猫罐头!

新手养猫很容易陷入疯狂购买的模式,但有些品牌真的不能乱买!现在的大环境不太好,我们需要学会控制自己的消费欲望,把钱花在刀刃上!哪种猫罐头比较好?现在宠物市场真的很内卷,很多品牌都在比拼产…

字母不重复的子串-第15届蓝桥第二次STEMA测评Scratch真题精选

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第158讲。 第15届蓝桥第2次STEMA测评已于2023年10月29日落下帷幕,编程题一共有6题,分别如下&am…

无需云盘,不限流量实现Zotero跨平台同步:内网穿透+私有WebDAV服务器

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《Linux》《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 无需云盘,不限流量实现Zotero跨平台同步:内网穿透私有WebDAV服务器 文章目…

2-10岁女童冬季穿搭怎么选?麻麻们看这里

分享适合女宝的羽绒服穿搭 这种黄色真的超好看 吸睛显白怎么穿都好看 长款连帽设计,精致走线 冬天穿时尚又好看!!

数据结构—LinkedList与链表

目录 一、链表 1. 链表的概念及结构 1. 单向或者双向 2. 带头或者不带头 3. 循环或者非循环 二.LinkedList的使用 1.LinkedList概念及结构 2. LinkedList的构造 3. LinkedList的方法 三. ArrayList和LinkedList的区别 一、链表 1. 链表的概念及结构 链表是一种 物理…

基于Qt 多线程(继承 QObject 的线程)

​ 继承 QThread 类是创建线程的一种方法,另一种就是继承QObject 类。继承 QObject 类更加灵活。它通过 QObject::moveToThread()方法,将一个 QObeject的类转移到一个线程里执行。恩,不理解的话,我们下面也画个图捋一下。 通过上面的图不难理解,首先我们写一个类继承 QObj…

Ladybug 全景相机, 360°球形成像,带来全方位的视觉体验

360无死角全景照片总能给人带来强烈的视觉震撼,有着大片的既视感。那怎么才能拍出360球形照片呢?它的拍摄原理是通过图片某个点位为中心将图片其他部位螺旋式、旋转式处理,从而达到沉浸式体验的效果。俗话说“工欲善其事,必先利其…

Maven编译报错:javacTask: 源发行版 1.8 需要目标发行版 1.8

报错截图: IDEA中的jdk检查都正常设置的1.8一点毛病没有。参考其他帖子链接如下: https://blog.csdn.net/zhishidi/article/details/131480199https://blog.51cto.com/u_16213460/7197764https://blog.csdn.net/lck_csdn/article/details/125387878 逐…

第四代智能井盖传感器,万宾科技助力城市安全

在迈向更为智能化、相互联系更为紧密的城市发展过程中,智能创新产品无疑扮演了一种重要的角色。智能井盖传感器作为新型科学技术产物,不仅解决传统井盖管理难的问题,也让城市变得更加安全美好,是城市生命线的一层重要保障。这些平…

Mindomo Desktop for Mac(免费思维导图软件)下载

Mindomo Desktop for Mac是一款免费的思维导图软件,适用于Mac电脑用户。它可以帮助你轻松创建、编辑和共享思维导图,让你的思维更加清晰、有条理。 首先,Mindomo Desktop for Mac具有直观易用的界面。它采用了Mac独特的用户界面设计&#xf…

避免defer陷阱:拆解延迟语句,掌握正确使用方法

基本概念 Go语言的延迟语句defer有哪些特点?通常在什么情况下使用? Go语言的延迟语句(defer statement)具有以下特点: 延迟执行:延迟语句会在包含它的函数执行结束前执行,无论函数是正常返回还是…

科研学习|研究方法——Python计量Logit模型

一、离散选择模型 莎士比亚曾经说过:To be, or not to be, that is the question,这就是典型的离散选择模型。如果被解释变量时离散的,而非连续的,称为“离散选择模型”。例如,消费者在购买汽车的时候通常会比较几个不…

Vue 3.0 + vite + axios+PHP跨域问题的解决办法

最后一个Web项目,采用前后端分离。 前端:Vue 3.0 viteelement plus 后端:PHP 运行时前端和后端是两个程序,前端需要时才向后端请求数据。由于是两个程序,这就会出现跨域问题。 比如前端某个地方需要请求的接口如下…

内网穿透的应用-通过内网穿透快速搭建公网可访问的Spring Boot接口调试环境

文章目录 前言1. 本地环境搭建1.1 环境参数1.2 搭建springboot服务项目 2. 内网穿透2.1 安装配置cpolar内网穿透2.1.1 windows系统2.1.2 linux系统 2.2 创建隧道映射本地端口2.3 测试公网地址 3. 固定公网地址3.1 保留一个二级子域名3.2 配置二级子域名3.2 测试使用固定公网地址…

Python自动化测试之request库详解(三)

做过接口测试的都会发现,现在的接口都是HTTPS协议了,今天就写一篇如何通过request发送https请求。 什么是HTTPS HTTPS 的全称是Hyper Text Transfer Protocol over Secure Socket Layer ,是以安全为目标的HTTP通道,简单的讲是HTT…