PyTorch AMP 混合精度中grad_scaler.py的scale函数解析

PyTorch AMP 混合精度中的 scale 函数解析

混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale 函数是 GradScaler 的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。

本文将详细解析 scale 函数的作用、代码逻辑,以及 apply_scale 子函数的递归作用。


函数代码回顾

以下是 scale 函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

torch 2.4.0+cu121版本

def scale(self,outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:"""Multiplies ('scales') a tensor or list of tensors by the scale factor.Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returnedunmodified.Args:outputs (Tensor or iterable of Tensors):  Outputs to scale."""if not self._enabled:return outputs# Short-circuit for the common case.if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)# Invoke the more complex machinery only if we're treating multiple outputs.stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scaledef apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)

1. 函数作用

scale 函数的主要作用是将输出张量(outputs)按当前的缩放因子(self._scale)进行缩放。它支持以下两种输入:

  1. 单个张量:直接将缩放因子乘以张量。
  2. 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。

当 AMP 功能未启用时(即 self._enabledFalse),scale 函数会直接返回原始的 outputs,不执行任何缩放操作。

使用场景

  • 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
  • 支持多设备:通过 _MultiDeviceReplicator 支持张量分布在多个设备(如多 GPU)的场景。

2. 核心代码解析

(1) 短路处理单个张量

当输入 outputs 是单个张量(torch.Tensor)时,函数直接对其进行缩放:

if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
  1. 如果缩放因子 self._scale 尚未初始化,则调用 _lazy_init_scale_growth_tracker 方法在指定设备上初始化缩放因子。
  2. 使用 outputs * self._scale 对张量进行缩放。这里使用了 to(device=outputs.device) 确保缩放因子与张量在同一设备上。

这是单个张量输入的快速路径处理。


(2) 多张量递归处理逻辑

当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale 进行递归缩放:

stash: List[_MultiDeviceReplicator] = []  # 用于存储缩放因子对象def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)
apply_scale 子函数的作用
  1. 张量处理

    • 如果 val 是单个张量,检查 stash 是否为空。
    • 如果为空,初始化缩放因子对象 _MultiDeviceReplicator,并存储在 stash 中。
    • 使用 stash[0].get(val.device) 获取对应设备上的缩放因子,并对张量进行缩放。
  2. 递归处理可迭代对象

    • 如果 val 是一个可迭代对象,调用 map(apply_scale, val),对其中的每个元素递归地调用 apply_scale
    • 如果输入是 listtuple,则保持其原始类型。
  3. 类型检查

    • 如果 val 既不是张量也不是可迭代对象,抛出错误。

3. apply_scale 是递归函数吗?

是的,apply_scale 是一个递归函数。

递归逻辑

  • 当输入为嵌套结构(如张量的列表或列表中的列表)时,apply_scale 会递归调用自身,将缩放因子应用到最底层的张量。
  • 递归的终止条件是 val 为单个张量(torch.Tensor)。
示例:

假设输入为嵌套张量列表:

outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)

递归处理过程如下:

  1. outputs 调用 apply_scale

    • 第一个元素是张量 torch.tensor([1.0, 2.0]),直接缩放。
    • 第二个元素是列表,递归调用 apply_scale
  2. 进入嵌套列表 [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]

    • 第一个元素是张量 torch.tensor([3.0]),缩放。
    • 第二个元素是张量 torch.tensor([4.0, 5.0]),缩放。

4. _MultiDeviceReplicator 的作用

_MultiDeviceReplicator 是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。

  • 当张量分布在多个设备(如 GPU)时,_MultiDeviceReplicator 可以高效地为每个设备提供所需的缩放因子,避免重复初始化。

总结

scale 函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale,该函数能够处理嵌套的张量结构,同时支持多设备场景。

关键点总结:

  1. 快速路径:单张量输入的情况下,直接进行缩放。
  2. 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
  3. 设备管理:通过 _MultiDeviceReplicator 支持多设备场景。

通过 scale 函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。

后记

2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/501352.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RAG实战:本地部署ragflow+ollama(linux)

1.部署ragflow 1.1安装配置docker 因为ragflow需要诸如elasticsearch、mysql、redis等一系列三方依赖,所以用docker是最简便的方法。 docker安装可参考Linux安装Docker完整教程,安装后修改docker配置如下: vim /etc/docker/daemon.json {…

56.在 Vue 3 中使用 OpenLayers 通过 moveend 事件获取地图左上和右下的坐标信息

前言 在现代 Web 开发中,地图应用越来越成为重要的组成部分。OpenLayers 是一个功能强大的 JavaScript 地图库,它提供了丰富的地图交互和操作功能,而 Vue 3 是当前流行的前端框架之一。在本篇文章中,我们将介绍如何在 Vue 3 中集…

Codigger集成Copilot:智能编程助手

在信息技术的快速发展中,编程效率和创新能力的提升成为了开发者们追求的目标。Codigger平台通过集成Copilot智能编程助手,为开发者提供了一个强大的工具,以增强其生产力、创新力和技能水平。本文将深入探讨Codigger与Copilot的集成如何为IT专…

用uniapp写一个播放视频首页页面代码

效果如下图所示 首页有导航栏&#xff0c;搜索框&#xff0c;和视频列表&#xff0c; 导航栏如下图 搜索框如下图 视频列表如下图 文件目录 视频首页页面代码如下 <template> <view class"video-home"> <!-- 搜索栏 --> <view class…

Java高频面试之SE-08

hello啊&#xff0c;各位观众姥爷们&#xff01;&#xff01;&#xff01;本牛马baby今天又来了&#xff01;哈哈哈哈哈嗝&#x1f436; 成员变量和局部变量的区别有哪些&#xff1f; 在 Java 中&#xff0c;成员变量和局部变量是两种不同类型的变量&#xff0c;它们在作用域…

在Typora中实现自动编号

文章目录 在Typora中实现自动编号1. 引言2. 准备工作3. 自动编号的实现3.1 文章大纲自动编号3.2 主题目录&#xff08;TOC&#xff09;自动编号3.3 文章内容自动编号3.4 完整代码 4. 应用自定义CSS5. 结论 在Typora中实现自动编号 1. 引言 Typora是一款非常流行的Markdown编辑…

Oracle exp和imp命令导出导入dmp文件

目录 一. 安装 instantclient-tools 工具包二. exp 命令导出数据三. imp 命令导入数据四. expdp 和 impdp 命令 一. 安装 instantclient-tools 工具包 ⏹官方网站 https://www.oracle.com/cn/database/technologies/instant-client/linux-x86-64-downloads.html ⏹因为我们在…

小程序发版后,强制更新为最新版本

为什么要强制更新为最新版本&#xff1f; 在小程序的开发和运营过程中&#xff0c;强制用户更新到最新版本是一项重要的策略&#xff0c;能够有效提升用户体验并保障系统的稳定性与安全性。以下是一些主要原因&#xff1a; 1. 功能兼容 新功能或服务通常需要最新版本的支持&…

设计模式 创建型 原型模式(Prototype Pattern)与 常见技术框架应用 解析

原型模式&#xff08;Prototype Pattern&#xff09;是一种创建型设计模式&#xff0c;其核心思想在于通过复制现有的对象&#xff08;原型&#xff09;来创建新的对象&#xff0c;而非通过传统的构造函数或类实例化方式。这种方式在需要快速创建大量相似对象时尤为高效&#x…

办公 三之 Excel 数据限定录入与格式变换

开始-----条件格式------管理规则 IF($A4"永久",1,0) //如果A4包含永久&#xff0c;条件格式如下&#xff1a; OR($D5<60,$E5<60,$F5<60) 求取任意科目不及格数据 AND($D5<60,$E5<60,$F5<60) 若所有科目都不及格 显示为红色 IF($H4<EDATE…

黑马JavaWeb开发跟学(十四).SpringBootWeb原理

黑马JavaWeb开发跟学 十四.SpringBootWeb原理 SpingBoot原理1. 配置优先级2. Bean管理2.1 获取Bean2.2 Bean作用域2.3 第三方Bean 3. SpringBoot原理3.1 起步依赖3.2 自动配置3.2.1 概述3.2.2 常见方案3.2.2.1 概述3.2.2.2 方案一3.2.2.3 方案二 3.2.3 原理分析3.2.3.1 源码跟踪…

linux-26 文件管理(四)install

说一个命令&#xff0c;叫install&#xff0c;man install&#xff0c;install是什么意思&#xff1f;安装&#xff0c;install表示安装的意思&#xff0c;那你猜install是用来干什么的&#xff1f;猜一猜干什么的&#xff1f;安装软件&#xff0c;安装第三方软件&#xff0c;错…

Win11+WLS Ubuntu 鸿蒙开发环境搭建(二)

参考文章 penHarmony南向开发笔记&#xff08;一&#xff09;开发环境搭建 OpenHarmony&#xff08;鸿蒙南向开发&#xff09;——标准系统移植指南&#xff08;一&#xff09; OpenHarmony&#xff08;鸿蒙南向开发&#xff09;——小型系统芯片移植指南&#xff08;二&…

多文件比对

要比对多个存储目录下的文件是否存在重复文件&#xff0c;可以通过以下步骤实现 MD5 值的比对&#xff1a; 1. 提取文件路径 首先从你的目录结构中获取所有文件的路径&#xff0c;可以使用 find 命令递归列出所有文件路径&#xff1a;find /traixxxnent/zpxxxxx -type f >…

46. Three.js案例-创建颜色不断变化的立方体模型

46. Three.js案例-创建颜色不断变化的立方体模型 实现效果 知识点 Three.js基础组件 WebGLRenderer THREE.WebGLRenderer是Three.js提供的用于渲染场景的WebGL渲染器。它支持抗锯齿处理&#xff0c;可以设置渲染器的大小和背景颜色。 构造器 antialias: 是否开启抗锯齿&am…

【51单片机零基础-chapter6:LCD1602调试工具】

实验0-用显示屏LCD验证自己的猜想 如同c的cout,前端的console.log() #include <REGX52.H> #include <INTRINS.H> #include "LCD1602.h" int var0; void main() {LCD_Init();LCD_ShowNum(1,1,var211,5);while(1){;} }实验1-编写LCD1602液晶显示屏驱动函…

【GO基础学习】gin的使用

文章目录 模版使用流程参数传递路由分组数据解析和绑定gin中间件 模版使用流程 package mainimport ("net/http""github.com/gin-gonic/gin" )func main() {// 1.创建路由r : gin.Default()// 2.绑定路由规则&#xff0c;执行的函数// gin.Context&#x…

杰盛微 JSM4056 1000mA单节锂电池充电器芯片 ESOP8封装

JSM4056 1000mA单节锂电池充电器芯片 JSM4056是一款单节锂离子电池恒流/恒压线性充电器&#xff0c;简单的外部应用电路非常适合便携式设备应用&#xff0c;适合USB电源和适配器电源工作&#xff0c;内部采用防倒充电路&#xff0c;不需要外部隔离二极管。热反馈可对充电电流进…

Linux实验报告14-Linux内存管理实验

目录 一&#xff1a;实验目的 二&#xff1a;实验内容 1、编辑模块的源代码mm_viraddr.c 2、编译模块 3、编写测试程序mm_test.c 4、编译测试程序mm_test.c 5、在后台运行mm_test 6、验证mm_viraddr模块 一&#xff1a;实验目的 (1)掌握内核空间、用户空间&#xff…

供需平台信息发布付费查看小程序系统开发方案

供需平台信息发布付费查看小程序系统主要是为了满足个人及企业用户的供需信息发布与匹配需求。 一、目标用户群体 个人用户&#xff1a;寻找兼职工作、二手物品交换、本地服务&#xff08;如家政、维修&#xff09;等。 小微企业&#xff1a;推广产品和服务&#xff0c;寻找合…