PyTorch 自动混合精度AMP Grad Scaler 源码解析:_unscale_grads_ 与 unscale_ 函数

PyTorch AMP Grad Scaler 源码解析:_unscale_grads_ 与 unscale_ 函数

引言

本文详细解析 PyTorch 自动混合精度(AMP)模块中 grad_scaler.py 文件的两个关键函数:_unscale_grads_unscale_。这些函数在梯度缩放与反缩放过程中起到了关键作用,特别适用于训练大规模深度学习模型时的数值稳定性优化。我们还将给出详细的示例与数值模拟,帮助理解其具体应用。


1. _unscale_grads_ 函数解析

def _unscale_grads_(self,optimizer: torch.optim.Optimizer,inv_scale: torch.Tensor,found_inf: torch.Tensor,allow_fp16: bool,) -> Dict[torch.device, torch.Tensor]:per_device_inv_scale = _MultiDeviceReplicator(inv_scale)per_device_found_inf = _MultiDeviceReplicator(found_inf)# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.# There could be hundreds of grads, so we'd like to iterate through them just once.# However, we don't know their devices or dtypes in advance.# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict# Google says mypy struggles with defaultdicts type annotations.per_device_and_dtype_grads: Dict[torch.device, Dict[torch.dtype, List[torch.Tensor]]] = defaultdict(lambda: defaultdict(list))with torch.no_grad():for group in optimizer.param_groups:for param in group["params"]:assert isinstance(param, torch.Tensor)if param.grad is None:continueif (not allow_fp16) and param.grad.dtype == torch.float16:raise ValueError("Attempting to unscale FP16 gradients.")if param.grad.is_sparse:# is_coalesced() == False means the sparse grad has values with duplicate indices.# coalesce() deduplicates indices and adds all values that have the same index.# For scaled fp16 values, there's a good chance coalescing will cause overflow,# so we should check the coalesced _values().if param.grad.dtype is torch.float16:param.grad = param.grad.coalesce()to_unscale = param.grad._values()else:to_unscale = param.grad# TODO: is there a way to split by device and dtype without appending in the inner loop?per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)for device, per_dtype_grads in per_device_and_dtype_grads.items():for grads in per_dtype_grads.values():torch._amp_foreach_non_finite_check_and_unscale_(grads,per_device_found_inf.get(device),per_device_inv_scale.get(device),)return per_device_found_inf._per_device_tensors

1.1 函数定义

def _unscale_grads_(self,optimizer: torch.optim.Optimizer,inv_scale: torch.Tensor,found_inf: torch.Tensor,allow_fp16: bool,) -> Dict[torch.device, torch.Tensor]:

该函数主要用于将梯度从缩放状态恢复到原始大小,同时检查是否存在数值溢出情况。

1.2 参数说明

  • optimizer:优化器对象,包含训练过程中使用的所有参数。
  • inv_scale:缩放因子的倒数,用于恢复梯度。
  • found_inf:用于记录是否存在无穷大或 NaN 值。
  • allow_fp16:是否允许 FP16 精度的梯度反缩放,默认设置为 False。

1.3 核心实现步骤

  1. 按设备与数据类型分类梯度:

    • 将优化器中的参数按设备和数据类型进行分组,便于批量处理。
    • 使用 defaultdict 对分组存储。
  2. 检查梯度并分类:

    • 遍历每个参数,如果存在稀疏梯度,使用 coalesce() 消除重复索引。关于这个方法, 可以参考笔者的另一篇博客:PyTorch 中 coalesce() 函数详解与应用示例
    • 将梯度分组存储到 per_device_and_dtype_grads 中。
  3. 调用 PyTorch 内部函数反缩放梯度:

    • 使用 torch._amp_foreach_non_finite_check_and_unscale_() 批量反缩放梯度并检查是否存在 NaN 或无穷大值。 这个具体解析请参考笔者的另一篇博客:PyTorch源码_amp_foreach_non_finite_check_and_unscale_cpu_kernel 函数解析:自动混合精度AMP的一部分
  4. 返回各设备上的溢出检查结果:

    • 输出包含各设备是否发现溢出的布尔值张量。

1.4 关键代码片段

with torch.no_grad():for group in optimizer.param_groups:for param in group["params"]:if param.grad is None:continueif (not allow_fp16) and param.grad.dtype == torch.float16:raise ValueError("Attempting to unscale FP16 gradients.")to_unscale = param.grad._values() if param.grad.is_sparse else param.gradper_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)for device, per_dtype_grads in per_device_and_dtype_grads.items():for grads in per_dtype_grads.values():torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), per_device_inv_scale.get(device))

2. unscale_ 函数解析

def unscale_(self, optimizer: torch.optim.Optimizer) -> None:"""Divides ("unscales") the optimizer's gradient tensors by the scale factor.:meth:`unscale_` is optional, serving cases where you need to:ref:`modify or inspect gradients<working-with-unscaled-gradients>`between the backward pass(es) and :meth:`step`.If :meth:`unscale_` is not called explicitly,  gradients will be unscaled  automatically during :meth:`step`.Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::...scaler.scale(loss).backward()scaler.unscale_(optimizer)torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)scaler.step(optimizer)scaler.update()Args:optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled... note:::meth:`unscale_` does not incur a CPU-GPU sync... warning:::meth:`unscale_` should only be called once per optimizer per :meth:`step` call,and only after all gradients for that optimizer's assigned parameters have been accumulated.Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError... warning:::meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute."""if not self._enabled:returnself._check_scale_growth_tracker("unscale_")optimizer_state = self._per_optimizer_states[id(optimizer)]if optimizer_state["stage"] is OptState.UNSCALED:raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")elif optimizer_state["stage"] is OptState.STEPPED:raise RuntimeError("unscale_() is being called after step().")# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.assert self._scale is not Noneinv_scale = self._scale.double().reciprocal().float()found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)optimizer_state["stage"] = OptState.UNSCALED

2.1 函数定义

def unscale_(self, optimizer: torch.optim.Optimizer) -> None:

该函数是 PyTorch AMP 提供的外部接口,供用户调用以解除梯度缩放。

2.2 参数说明

  • optimizer:包含所有待训练参数的优化器对象。

2.3 核心实现步骤

  1. 状态检查:
    • 检查是否已经调用过 unscale_step
  2. 计算反缩放因子:
    • 使用 FP64 精度计算缩放因子的倒数,以避免精度误差。reciprocal这是取倒数的函数,具体可以参考笔者的另一篇博客:PyTorch 中 reciprocal(取倒数)函数的深入解析:分析底层实现CPP代码
  3. 调用内部函数 _unscale_grads_
    • 执行反缩放过程,包含稀疏梯度与 NaN 检查。
  4. 更新状态记录:
    • 将优化器状态更新为 “UNSCALED”。

2.4 关键代码片段

if optimizer_state["stage"] is OptState.UNSCALED:raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED

3. 使用示例与数值模拟

3.1 示例代码

import torch
from torch.cuda.amp import GradScaler, autocast# 创建模型和优化器
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()# 模拟训练循环
for epoch in range(2):for step in range(5):data = torch.randn(16, 10).cuda()target = torch.randn(16, 1).cuda()optimizer.zero_grad()# 使用混合精度训练with autocast():output = model(data)loss = torch.nn.functional.mse_loss(output, target)# 缩放梯度scaler.scale(loss).backward()# 手动解除梯度缩放scaler.unscale_(optimizer)# 使用梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)# 更新权重与缩放器scaler.step(optimizer)scaler.update()print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")

3.2 数值模拟分析

  1. 梯度缩放影响:
    缩放因子 = 65536 时,梯度放大至 10^4 量级,有助于 FP16 避免下溢问题。
  2. 反缩放结果验证:
    对比反缩放前后的梯度值,可观察到恢复精度并避免溢出错误。
  3. 梯度裁剪测试:
    执行 torch.nn.utils.clip_grad_norm_(),确认反缩放后的梯度值能够被安全裁剪。

4. 注意事项与总结

  1. 注意 API 使用顺序:
    调用 unscale_ 应在反向传播完成后、优化器更新前进行。
  2. 防止重复调用:
    多次调用可能导致状态不一致,应确保每轮训练仅调用一次。
  3. 稀疏梯度支持:
    自动处理稀疏梯度的特殊情况,避免溢出。

这两个函数是 AMP 核心模块,提供了稳定高效的混合精度训练支持。通过示例与数值分析,开发者可以更好地理解 AMP 工作原理并优化深度学习模型训练过程。


后记

2025年1月2日18点49分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/503337.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI Development Notes 1 - introduction with the OpenAI API Development

Official document&#xff1a;https://platform.openai.com/docs/api-reference/chat/create 1. Use APIfox to call APIs 2.Use PyCharm to call APIs 2.1-1 WIN OS.Configure the Enviorment variable #HK代理环境&#xff0c;不需要科学上网(价格便宜、有安全风险&#…

ComfyUI节点安装笔记

AI高速发展&#xff0c;版本更新相当快&#xff08;11月25日才安装的版本v.0.3.4&#xff0c;27日版本就已经更新到v.0.3.5了&#xff09;&#xff0c;在遇到问题&#xff0c;找到问题原因所在的过程中&#xff0c;ComfyUI版本、python版本、节点对环境版本的依赖&#xff0c;本…

小白学Pytorch

小白学Pytorch 发现一个比较好的教程&#xff0c;对于自己来说比较合适&#xff0c;适合从零开始的教程。 1、搭建一个简单的网络 https://www.cnblogs.com/PythonLearner/p/13587092.html 搭建网络这步说的比较清楚&#xff1a; 我们使用nn包中的Sequential搭建网络&#…

如何查看服务器上的MySQL/Redis等系统服务状态和列表

如果呢你知道系统服务名称&#xff0c;要看状态很简单&#xff1a; systemctl status server-name 比如 systemctl status nginxsystemctl status redis # 等 这是一个nginx的示例&#xff1a; 那问题是 当你不知道服务名称时该怎么办。举个例子&#xff0c;比如mysql在启动…

ubuntu开机启动服务

需求背景&#xff1a; 需要监控日志&#xff0c;每次都是手动启动 nohup ./prometheus >/dev/null & nohub ./node_exporter >/dev/null & 需求目标&#xff1a; 重启后系统自动启动服务

路由组件与一般组件的区别

路由组件与一般组件的区别 1. 基本概念 1.1 路由组件 路由组件是指通过路由规则映射的组件&#xff0c;通常放在 pages 或 views 文件夹中。 1.2 一般组件 一般组件是指通过 import 导入后直接使用的组件&#xff0c;通常放在 components 文件夹中。 2. 主要区别 2.1 存…

Qt天气预报系统设计界面布局第四部分右边

Qt天气预报系统 1、第四部分右边的第一部分1.1添加控件 2、第四部分右边的第二部分2.1添加控件 3、第四部分右边的第三部分3.1添加控件3.2修改控件名字 1、第四部分右边的第一部分 1.1添加控件 拖入一个widget&#xff0c;改名为widget04r作为第四部分的右边 往widget04r再拖…

数据库系统概论期末复习

期末考试题型&#xff1a; 选择题 20题 20分 判断题 10题 10分 简答题 4题 20分 SQL语句&#xff1a; &#xff08;select delete update&#xff09;30分 设计题&#xff1a;ER图 和关系模式 ER转关系模式&#xff0c;注意主码&#xff0c;外码的标注 15分 应用题&#xff1a;…

uni-app 页面生命周期及组件生命周期汇总(Vue2、Vue3)

文章目录 一、前言&#x1f343;二、页面生命周期三、Vue2 页面及组件生命周期流程图四、Vue3 页面及组件生命周期流程图4.1 页面加载时序介绍4.2 页面加载常见问题4.3 onShow 和 onHide4.4 onInit4.5 onLoad4.6 onReachBottom4.7 onPageScroll4.8 onBackPress4.9 onTabItemTap…

微信小程序中 “页面” 和 “非页面” 的区别

微信小程序中 “页面” 和 “非页面” 的区别&#xff0c;并用表格进行对比。 核心概念&#xff1a; 页面 (Page)&#xff1a; 页面是微信小程序中用户可以直接交互的视图层&#xff0c;也是小程序的基本组成部分。每个页面都有自己的 WXML 结构、WXSS 样式和 JavaScript 逻辑…

【Linux】传输层协议UDP

目录 再谈端口号 端口号范围划分 UDP协议 UDP协议端格式 UDP的特点 UDP的缓冲区 UDP注意事项 进一步深刻理解 再谈端口号 在上图中&#xff0c;有两个客户端A和B&#xff0c;客户端A打开了两个浏览器&#xff0c;这两个客户端都访问同一个服务器&#xff0c;都访问服务…

大数据架构演变

一、离线数仓 缺点&#xff1a; ETL计算、存储、时间成本高数据处理链路过长无法支持实时、近实时的数据分析数据采集对业务库造成影响 二、Lambda架构&#xff0c;离线实时分开 缺点&#xff1a; 组件多&#xff0c;不方便管理很难保证数据一致数据探查困难&#xff0c;出现…

进程间通讯

简介&#xff1a; 进程间通讯方式有&#xff1a; 1.内存映射&#xff08;mmap&#xff09;&#xff1a; 使用mmap函数将磁盘空间映射到内存 2.管道 3.信号 4.套接字&#xff08;socket&#xff09; 5.信号机制 通过进程中kill函数&#xff0c;去给另一个函数发送信号&a…

毕业项目推荐:基于yolov8/yolov5的行人检测识别系统(python+卷积神经网络)

文章目录 概要一、整体资源介绍技术要点功能展示&#xff1a;功能1 支持单张图片识别功能2 支持遍历文件夹识别功能3 支持识别视频文件功能4 支持摄像头识别功能5 支持结果文件导出&#xff08;xls格式&#xff09;功能6 支持切换检测到的目标查看 二、数据集三、算法介绍1. YO…

[桌面运维]windows自动设置浅深色主题

设置自动浅色/深色主题 我看很多up主的教程过于繁琐&#xff0c;需要添加四个功能&#xff0c;并且有些还不能生效&#xff01; 大多数都是教程&#xff1a; 自动任务栏浅色 add HKCUSOFTWAREMicrosoftWindowsCurrentVersionThemesPersonalize/v SystemUsesLightTheme /t …

LQ quarter 5th

目录 B. 开赛主题曲 C. BlueAI E. 精准难度 B. 开赛主题曲 &#xff08;1&#xff09;两层循环枚举所有子串。第一层子串长度&#xff0c;第二层子串起点 &#xff08;2&#xff09;判子串是否合法还要一个 for&#xff0c;26 * 26 * 2e5 快要超时&#xff0c;因此计算每个字母…

Directx12 chapter4

官方的初始化需要的组件 Initialize 初始化涉及到首次设置全局变量和类&#xff0c;initialize 函数必须准备管道和资产。 初始化管道。 启用调试层。创建设备。创建命令队列。创建交换链。创建渲染器目标视图 (RTV) 描述符堆。 备注 可将描述符堆视为描述符的数组。 其中…

STM32 软件I2C读写

单片机学习&#xff01; 目录 前言 一、软件I2C读写代码框架 二、I2C初始化 三、六个时序基本单元 3.1 引脚操作的封装和改名 3.2 起始条件执行逻辑 3.3 终止条件执行逻辑 3.4 发送一个字节 3.5 接收一个字节 3.5 发送应答&接收应答 3.5.1 发送应答 3.5.2 接…

计算机网络--UDP和TCP课后习题

【5-05】 试举例说明有些应用程序愿意采用不可靠的UDP, 而不愿意采用可靠的TCP。 解答&#xff1a; 这可能有以下几种情况。 首先&#xff0c;在互联网上传输实时数据的分组时&#xff0c;有可能会出现差错甚至丢失。如果利用 TCP 协议对这些出错或丢失的分组进行重传&…

【C++】B2099 矩阵交换行

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 &#x1f4af;前言&#x1f4af;题目描述题目描述输入格式输出格式输入输出样例输入 #1输出 #1 &#x1f4af;题目分析&#x1f4af;不同解法分析我的做法实现步骤&#xff1a;优点&#xff1a;不足&#…