昇思MindSpore学习入门-自动混合精度

混合精度(Mix Precision)训练是指在训练时,对神经网络不同的运算采用不同的数值精度的运算策略。在神经网络运算中,部分运算对数值精度不敏感,此时使用较低精度可以达到明显的加速效果(如conv、matmul等);而部分运算由于输入和输出的数值差异大,通常需要保留较高精度以保证结果的正确性(如log、softmax等)。

当前的AI加速卡通常通过针对计算密集、精度不敏感的运算设计了硬件加速模块,如NVIDIA GPU的TensorCore、Ascend NPU的Cube等。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比。

mindspore.amp模块提供了便捷的自动混合精度接口,用户可以在不同的硬件后端通过简单的接口调用获得训练加速。下面我们对混合精度计算原理进行简介,而后通过实例介绍MindSpore的自动混合精度用法。

混合精度计算原理

浮点数据类型主要分为双精度(FP64)、单精度(FP32)、半精度(FP16)。在神经网络模型的训练过程中,一般默认采用单精度(FP32)浮点数据类型,来表示网络模型权重和其他参数。在了解混合精度训练之前,这里简单了解浮点数据类型。

根据IEEE二进制浮点数算术标准(IEEE 754)的定义,浮点数据类型分为双精度(FP64)、单精度(FP32)、半精度(FP16)三种,其中每一种都有三个不同的位来表示。FP64表示采用8个字节共64位,来进行的编码存储的一种数据类型;同理,FP32表示采用4个字节共32位来表示;FP16则是采用2字节共16位来表示。如图所示:

从图中可以看出,与FP32相比,FP16的存储空间是FP32的一半。类似地,FP32则是FP64的一半。因此使用FP16进行运算具备以下优势:

  • 减少内存占用:FP16的位宽是FP32的一半,因此权重等参数所占用的内存也是原来的一半,节省下来的内存可以放更大的网络模型或者使用更多的数据进行训练。
  • 计算效率更高:在特殊的AI加速芯片如华为Atlas训练系列产品和Atlas 200/300/500推理产品系列,或者NVIDIA VOLTA架构的GPU上,使用FP16的执行运算性能比FP32更加快。
  • 加快通讯效率:针对分布式训练,特别是在大模型训练的过程中,通讯的开销制约了网络模型训练的整体性能,通讯的位宽少了意味着可以提升通讯性能,减少等待时间,加快数据的流通。

但是使用FP16同样会带来一些问题:

  • 数据溢出:FP16的有效数据表示范围为 [5.9×10−8,65504],FP32的有效数据表示范围为 [1.4×10−45,1.7×1038]。可见FP16相比FP32的有效范围要窄很多,使用FP16替换FP32会出现上溢(Overflow)和下溢(Underflow)的情况。而在深度学习中,需要计算网络模型中权重的梯度(一阶导数),因此梯度会比权重值更加小,往往容易出现下溢情况。
  • 舍入误差:Rounding Error是指当网络模型的反向梯度很小,一般FP32能够表示,但是转换到FP16会小于当前区间内的最小间隔,会导致数据溢出。如0.00006666666在FP32中能正常表示,转换到FP16后会表示成为0.000067,不满足FP16最小间隔的数会强制舍入。
  • 因此,在使用混合精度获得训练加速和内存节省的同时,需要考虑FP16引入问题的解决。Loss Scale损失缩放,FP16类型数据下溢问题的解决方案,其主要思想是在计算损失值loss的时候,将loss扩大一定的倍数。根据链式法则,梯度也会相应扩大,然后在优化器更新权重时再缩小相应的倍数,从而避免了数据下溢。
  • 根据上述原理介绍,典型的混合精度计算流程如下图所示:

  1. 参数以FP32存储;
  2. 正向计算过程中,遇到FP16算子,需要把算子输入和参数从FP32 cast成FP16进行计算;
  3. 将Loss层设置为FP32进行计算;
  4. 反向计算过程中,首先乘以Loss Scale值,避免反向梯度过小而产生下溢;
  5. FP16参数参与梯度计算,其结果将被cast回FP32;
  6. 除以Loss scale值,还原被放大的梯度;
  7. 判断梯度是否存在溢出,如果溢出则跳过更新,否则优化器以FP32对原始参数进行更新。

下面我们通过导入快速入门中的手写数字识别模型及数据集,演示MindSpore的自动混合精度实现。

类型转换

混合精度计算需要将需要使用低精度的运算进行类型转换,将其输入转为FP16类型,得到输出后进将其重新转回FP32类型。MindSpore同时提供了自动和手动类型转换的方法,满足对易用性和灵活性的不同需求,下面我们分别对其进行介绍。

自动类型转换

mindspore.amp.auto_mixed_precision 接口提供对网络做自动类型转换的功能。自动类型转换遵循黑白名单机制,根据常用的运算精度习惯配置了4个等级,分别为:

  • ‘O0’:神经网络保持FP32;
  • ‘O1’:按白名单将运算cast为FP16;
  • ‘O2’:按黑名单保留FP32,其余运算cast为FP16;
  • ‘O3’:神经网络完全cast为FP16。

下面是使用自动类型转换的示例:

from mindspore.amp import auto_mixed_precision

model = Network()

model = auto_mixed_precision(model, 'O2')

手动类型转换

通常情况下自动类型转换可以通过满足大部分混合精度训练的需求,但当用户需要精细化控制神经网络不同部分的运算精度时,可以通过手动类型转换的方式进行控制。

Cell粒度类型转换

nn.Cell类提供了to_float方法,可以一键配置该模块的运算精度,自动将模块输入cast为指定的精度:

class NetworkFP16(nn.Cell):

    def __init__(self):

        super().__init__()

        self.flatten = nn.Flatten()

        self.dense_relu_sequential = nn.SequentialCell(

            nn.Dense(28*28, 512).to_float(ms.float16),

            nn.ReLU(),

            nn.Dense(512, 512).to_float(ms.float16),

            nn.ReLU(),

            nn.Dense(512, 10).to_float(ms.float16)

        )

    def construct(self, x):

        x = self.flatten(x)

        logits = self.dense_relu_sequential(x)

        return logits

自定义粒度类型转换

当用户需要在单个运算,或多个模块组合配置运算精度时,Cell粒度往往无法满足,此时可以直接通过对输入数据的类型进行cast来达到自定义粒度控制的目的。

class NetworkFP16Manual(nn.Cell):

    def __init__(self):

        super().__init__()

        self.flatten = nn.Flatten()

        self.dense_relu_sequential = nn.SequentialCell(

            nn.Dense(28*28, 512),

            nn.ReLU(),

            nn.Dense(512, 512),

            nn.ReLU(),

            nn.Dense(512, 10)

        )

    def construct(self, x):

        x = self.flatten(x)

        x = x.astype(ms.float16)

        logits = self.dense_relu_sequential(x)

        logits = logits.astype(ms.float32)

        return logits

损失缩放

MindSpore中提供了两种Loss Scale的实现,分别为StaticLossScaler和DynamicLossScaler,其差异为损失缩放值scale value是否进行动态调整。下面以DynamicLossScalar为例,根据混合精度计算流程实现神经网络训练逻辑。

首先,实例化LossScaler,并在定义前向网络时,手动放大loss值。

from mindspore.amp import DynamicLossScaler

# Instantiate loss function and optimizer

loss_fn = nn.CrossEntropyLoss()

optimizer = nn.SGD(model.trainable_params(), 1e-2)

# Define LossScaler

loss_scaler = DynamicLossScaler(scale_value=2**16, scale_factor=2, scale_window=50)

def forward_fn(data, label):

    logits = model(data)

    loss = loss_fn(logits, label)

    # scale up the loss value

    loss = loss_scaler.scale(loss)

return loss, logits

接下来进行函数变换,获得梯度函数。

grad_fn = value_and_grad(forward_fn, None, model.trainable_params())

定义训练step:计算当前梯度值并恢复损失。使用 all_finite 判断是否出现梯度下溢问题,如果无溢出,恢复梯度并更新网络权重;如果溢出,跳过此step。

from mindspore.amp import all_finite

@ms.jit

def train_step(data, label):

    (loss, _), grads = grad_fn(data, label)

    loss = loss_scaler.unscale(loss)

    is_finite = all_finite(grads)

    if is_finite:

        grads = loss_scaler.unscale(grads)

        optimizer(grads)

    loss_scaler.adjust(is_finite)

    return loss

最后,我们训练1个epoch,观察使用自动混合精度训练的loss收敛情况。

size = train_dataset.get_dataset_size()

model.set_train()

for batch, (data, label) in enumerate(train_dataset.create_tuple_iterator()):

    loss = train_step(data, label)

    if batch % 100 == 0:

        loss, current = loss.asnumpy(), batch

        print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

可以看到loss收敛趋势正常,没有出现溢出问题。

Cell配置自动混合精度

MindSpore支持使用Cell封装完整计算图的编程范式,此时可以使用mindspore.amp.build_train_network接口,自动进行类型转换,并将Loss Scale传入,作为整图计算的一部分。 此时仅需要配置混合精度等级和LossScaleManager即可获得配置好自动混合精度的计算图。

FixedLossScaleManager和DynamicLossScaleManager是Cell配置自动混合精度的Loss scale管理接口,分别与StaticLossScalar和DynamicLossScalar对应,具体详见mindspore.amp。

from mindspore.amp import build_train_network, FixedLossScaleManager

model = Network()

loss_scale_manager = FixedLossScaleManager()

model = build_train_network(model, optimizer, loss_fn, level="O2", loss_scale_manager=loss_scale_manager)

Model 配置自动混合精度

mindspore.train.Model是神经网络快速训练的高阶封装,其将mindspore.amp.build_train_network封装在内,因此同样只需要配置混合精度等级和LossScaleManager,即可进行自动混合精度训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/386583.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OSI七层模型详解

OSI七层模型 OSI(Open System Interconnect),即开放式系统互连。 一般都叫OSI参考模型,是ISO组织在1985年研究的网络互连模型。该体系结构标准定义了网络互连的七层框架(物理层、数据链路层、网络层、传输层、会话层、…

[玄机]流量特征分析-常见攻击事件 tomcat

[玄机]流量特征分析-常见攻击事件 tomcat 题目做法及思路解析(个人分享) Tomcat是一个开源的Java Servlet容器,它实现了Java Servlet和JavaServer Pages (JSP) 技术,提供了一个运行这些应用程序的Web服务器环境。Tomcat由A…

go程序在windows服务中优雅开启和关闭

本篇主要是讲述一个go程序,如何在windows服务中优雅开启和关闭,废话不多说,开搞!!!   使用方式:go程序 net服务启动 Ⅰ 开篇不利 Windows go进程编译后,为一个.exe文件,直接执行即…

语言转文字

因为工作原因需要将语音转化为文字,经常搜索终于找到一个免费的好用工具,记录下使用方法 安装Whisper 搜索Colaboratory 右上方链接服务 执行 !pip install githttps://github.com/openai/whisper.git !sudo apt update && sudo apt install f…

NSSRound#4 Team

[NSSRound#4 SWPU]1zweb 考察&#xff1a;phar的反序列化 1.打开环境&#xff0c;审计代码 1.非预期解 直接用file伪协议读取flag,或直接读取flag file:///flag /flag 2.正常解法 用读取文件读取index.php,upload.php的源码 index.php: <?php class LoveNss{publi…

hadoop学习(一)

一.hadoop概述 1.1hadoop优势 1&#xff09;高可靠性&#xff1a;Hadoop底层维护多个数据副本&#xff0c;即使Hadoop某个计算元素或存储出现故障&#xff0c;也不会导致数据的丢失。 2&#xff09;高扩展性&#xff1a;在集群间分配任务数据&#xff0c;可方便扩展数以千计…

c++ 内存管理(newdeletedelete[])

因为在c里面新增了类&#xff0c;所以我们在有时候会用malloc来创建类&#xff0c;但是这种创建只是单纯的开辟空间&#xff0c;没有什么默认构造的。同时free也是free的表面&#xff0c;如果类里面带有指针指向堆区的成员变量就会free不干净。 所以我们c增加了new delete和de…

Python --Pandas库基础方法(2)

文章目录 Pandas 变量类型的转换查看各列数据类型改变数据类型 重置索引删除行索引和切片seriesDataFrame取列按行列索引选择loc与iloc获取 isin()选择query()的使用排序用索引排序使用变量值排序 修改替换变量值对应数值的替换 数据分组基于拆分进行筛选 分组汇总引用自定义函…

springcloud RocketMQ 客户端是怎么走到消费业务逻辑的 - debug step by step

springcloud RocketMQ &#xff0c;一个mq消息发送后&#xff0c;客户端是怎么一步步拿到消息去消费的&#xff1f;我们要从代码层面探究这个问题。 找的流程图&#xff0c;有待考究。 以下我们开始debug&#xff1a; 拉取数据的线程&#xff1a; PullMessageService.java 本…

126M全球手机基站SHP数据分享

数据是GIS的血液&#xff01; 我们在《2.8亿东亚五国建筑数据分享》一文中&#xff0c;为你分享过东亚五国建筑数据。 现在再为你分享全球手机基站SHP数据&#xff0c;你可以在文末查看该数据的领取方法。 全球手机基站SHP数据 全球手机基站数据是OpenCelliD团队创建由社区…

【Spring Cloud】Sleuth +Zinkin 实现链路追踪并持久化的解决方案

文章目录 前言链路追踪介绍Sleuth入门Sleuth介绍TraceSpanAnnotation Sleuth入门1、引入依赖2、修改配置文件3、网关路由配置4、演示 Zipkin的集成ZipKin介绍ZipKin服务端安装Zipkin客户端集成1、添加依赖2、添加配置3、访问微服务4、演示 Zipkin数据持久化使用mysql实现数据持…

现代Java开发:使用jjwt实现JWT认证

前言 jjwt 库 是一个流行的 Java 库&#xff0c;用于创建和解析 JWT。我在学习spring security 的过程中看到了很多关于jwt的教程&#xff0c;其中最流行的就是使用jjwt实现jwt认证&#xff0c;但是教程之中依然使用的旧版的jjwt库&#xff0c;许多的类与方法已经标记弃用或者…

多家隧道代理价格:阿布云、快代理、小象代理、熊猫代理和亿牛云……

随着奥运的热度攀升&#xff0c;各大品牌也在抓紧时机赶上这波奥运热潮&#xff0c;随之而来的大量数据信息收集和分析工作也接踵而至&#xff0c;在这一数据采集过程中&#xff0c;HTTP代理的质量和价格对企业的效率和成本调控重要性不言而喻。我们大部分人在日常购买产品的时…

Revit中如何添加剖面?快速实现剖面图的方法汇总

Revit中创建剖面以及剖面视图一般有两种方法&#xff0c;一是使用Revit原生的剖面功能&#xff0c;而是使用Revit插件BIM建模助手进行便捷的剖面操作以及剖面视图创建。 Revit原生的剖面功能&#xff0c;点击后可以自由拉伸剖面方向、范围&#xff0c;放置完剖面符号后&#xf…

【ROS 最简单教程 003/300】ROS 快速体验:Hello World

开始自己的第一次尝试叭 ~ Hello World 本篇是 C 版本&#xff0c;如需 python 版本 &#x1f449; python 版本指路 ROS 中程序的实现流程&#xff1a; 创建工作空间 ( &#x1f499; 如 tutu_ws) &#xff0c;进入并编译 mkdir -p tutu_ws/src cd tutu_ws catkin_make在 src …

【C语言】结构体详解 -《探索C语言的 “小宇宙” 》

目录 C语言结构体&#xff08;struct&#xff09;详解结构体概览表1. 结构体的基本概念1.1 结构体定义1.2 结构体变量声明 2. 结构体成员的访问2.1 使用点运算符&#xff08;.&#xff09;访问成员输出 2.2 使用箭头运算符&#xff08;->&#xff09;访问成员输出 3. 结构体…

springboot使用Gateway做网关并且配置全局拦截器

一、为什么要用网关 统一入口&#xff1a; 作用&#xff1a;作为所有客户端请求的统一入口。说明&#xff1a;所有客户端请求都通过网关进行路由&#xff0c;网关负责将请求转发到后端的微服务 路由转发&#xff1a; 作用&#xff1a;根据请求的URL、方法等信息将请求路由到…

Hive之扩展函数(UDF)

Hive之扩展函数(UDF) 1、概念讲解 当所提供的函数无法解决遇到的问题时&#xff0c;我们通常会进行自定义函数&#xff0c;即&#xff1a;扩展函数。Hive的扩展函数可分为三种&#xff1a;UDF,UDTF,UDAF。 UDF&#xff1a;一进一出 UDTF&#xff1a;一进多出 UDAF&#xff1a…

作业帮6-19笔试-选填题

可以看到10在第一位&#xff0c;说明用的是挖坑法快速排序&#xff0c;过程如下&#xff1a; 右指针从最右边开始&#xff0c;找到第一个比30小的数10&#xff0c;与30交换。 10、15、40、28、50、30、70 左指针从位置1开始&#xff0c;找到40&#xff0c;与30互换。 10、15、3…