深度学习优化-Gradient Checkpointing

数学原理参考:

梯度检查点技术(Gradient Checkpointing)详细介绍:中英双语-CSDN博客

视频讲解参考:

用梯度检查点来节省显存 gradient checkpointing_哔哩哔哩_bilibili

Gradient Checkpointing(梯度检查点

Gradient Checkpointing 是一种用于优化深度学习模型训练的技术,旨在减少训练过程中显存的占用。在深度神经网络训练中,通常需要存储每一层的激活值以用于反向传播计算梯度。然而,对于层数较多或参数量较大的模型,这些激活值会占用大量显存。

Gradient Checkpointing 的核心思想是在前向传播时选择性地保存部分激活值(称为检查点),而丢弃其他激活值。在反向传播时,如果需要这些被丢弃的激活值,则重新计算它们。通过这种方式,显存使用量可以从 O(L) 降低到 O(K),其中 L 是网络层数,K 是选择的检查点层数。

工作原理

  1. 选择检查点:在前向传播时,选择某些层作为检查点,保存这些层的激活值。

  2. 丢弃激活值:对于未被选为检查点的层,丢弃其激活值。

  3. 反向传播时重新计算:在反向传播时,如果需要被丢弃的激活值,则通过重新计算它们来获取,从而计算梯度。

a1和a3被丢弃,反向传播时,如果需要被丢弃的激活值,则需要重新计算

a1 = x * w1,

a3 = a2 * w3

优点与缺点

优点

  • 显著减少显存占用,使训练更大规模的模型成为可能。

  • 在显存受限的环境中,可以提高训练效率。

  • 允许使用更大的批量大小,从而加速训练。

缺点

  • 增加了计算开销,因为需要在反向传播时重新计算激活值。

  • 实现复杂度增加,需要修改代码来管理检查点。

  • 可能导致训练时间延长。

实现方法

在 PyTorch 中,可以通过 torch.utils.checkpoint 模块实现 Gradient Checkpointing。例如:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpointclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.layer1 = nn.Linear(256, 256)self.layer2 = nn.Linear(256, 256)self.layer3 = nn.Linear(256, 10)def forward(self, x):x = checkpoint.checkpoint(self.layer1, x)  # 应用梯度检查点x = checkpoint.checkpoint(self.layer2, x)x = self.layer3(x)  # 最后一层不需要检查点return x

在 DeepSpeed 中,可以通过配置文件启用 Gradient Checkpointing:

{"train_batch_size": 16,"gradient_accumulation_steps": 4,"zero_optimization": {"stage": 2,"contiguous_gradients": true},"gradient_checkpointing": true
}

应用场景

Gradient Checkpointing 广泛应用于以下场景:

  • 训练大规模深度学习模型,如 7B 或 10B 参数的模型。

  • 在 GPU 显存有限的环境中优化训练。

  • 提高训练效率,同时减少硬件成本。

通过合理使用 Gradient Checkpointing,可以在有限的硬件资源下训练更大规模的模型,同时平衡显存和计算开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/33043.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sql靶场-时间盲注(第九、十关)保姆级教程

目录 时间盲注(第九、十关) 1.判断 2.确认时间盲注 2.手工尝试时间盲注 数据库名长度 数据库名字符 表数 表名长度 表名字符 字段数 字段名长度 字段名字符 4.脚本时间盲注注入 5.第十关 时间盲注(第九、十关) 1.判…

小米路由器SSH下安装DDNS-GO

文章目录 前言一、下载&安装DDNS-GO二、配置ddns-go设置开机启动 前言 什么是DDNS? DDNS(Dynamic Domain Name Server)是动态域名服务的缩写。 目前路由器拨号上网获得的多半都是动态IP,DDNS可以将路由器变化的外网I…

Flutter_学习记录_device_info_plus 插件获取设备信息

引入三方库device_info_plus导入头文件 import package:device_info_plus/device_info_plus.dart;获取设备信息的主要代码 DeviceInfoPlugin deviceInfoPlugin DeviceInfoPlugin(); BaseDeviceInfo deviceInfo await deviceInfoPlugin.deviceInfo;完整案例 import package…

【现代深度学习技术】卷积神经网络05:汇聚层

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重…

Amazon RDS ProxySQL 探索(一)

:::info 💡 在日常开发中,开发者们会涉及到数据库的连接,在使用Amazon RDS数据库时,若使用集群模式或者多数据库时,会出现一写多读多个Endpoint,在实际开发中, 开发者们配置数据库连接通常希望走…

Appium高级操作--ActionChains类、Toast元素识别、Hybrid App操作、手机系统API的操作

书接上回Appium高级操作--从源码角度解析--模拟复杂手势操作-CSDN博客文章浏览阅读712次,点赞24次,收藏6次。下面总结Appium模拟复杂手势整体流程创建类实例action时,一定要传入WebDriver实例参数,创建实例成功后,调用…

媲美Deepseek R1 671B的千问QwQ32B本地部署与远程访问实测流程

文章目录 前言1. 环境准备2.QwQ 32B模型安装与运行测试3. 安装Open WebUI图形化界面3.1 安装Open WebUI3.2 添加QWQ32B模型 4. 安装内网穿透工具5. 配置固定公网地址总结 前言 近日,阿里千问发布了最新推理模型QwQ32B !并表示“它只有 320亿参数&#x…

SpringBoot整合RabbitMq

1.引入依赖 <!--RabbitMq相关--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId> </dependency>2.application.yml文件配置 spring:rabbitmq:host: 192.168.101.129…

2024年第十五届蓝桥杯软件C/C++大学A组——五子棋对弈

蓝桥杯原题&#xff1a; 题目描述&#xff1a; “在五子棋的对弈中&#xff0c;友谊的小船说翻就翻&#xff1f; ” 不&#xff01;对小蓝和小桥来说&#xff0c;五子棋不仅是棋盘上的较量&#xff0c;更是心与心之间的沟通。这两位挚友秉承着 “ 友谊第一&#xff0c;比赛第二…

PyQt基础——简单的图形化界面(窗口)

一、代码展示 import sysfrom PyQt6.QtGui import QPixmap from PyQt6.QtWidgets import QWidget, QApplication, QLabel, QLineEdit, QPushButton from PyQt6 import uic from PyQt6.QtCore import Qt# 封装一个我的窗口类 class MyWidget(QWidget):def __init__(self):supe…

大语言模型-1.2-大模型技术基础

简介 本博客内容是《大语言模型》一书的读书笔记&#xff0c;该书是中国人民大学高瓴人工智能学院赵鑫教授团队出品&#xff0c;覆盖大语言模型训练与使用的全流程&#xff0c;从预训练到微调与对齐&#xff0c;从使用技术到评测应用&#xff0c;帮助学员全面掌握大语言模型的…

【MATLAB例程】AOA(到达角度)法,多个目标定位算法,三维空间、锚点数量自适应(附完整代码)

给出AOA方法下的多目标定位,适用三维空间,锚点数量>3即可,可自定义目标和锚点的数量、坐标等。 文章目录 运行结果源代码代码讲解概述功能代码结构运行结果 10个锚点、4个目标的情况: 100个锚点、10个目标的情况: 修改方便,只需调节下面的两个数字即可: 源代码 …

[CVE-2017-10271]Weblogic--WLS Security反序列化漏洞复现

文章目录 靶机地址靶机说明开启并访问靶机POCexp反弹shell后续查找flag过程略 靶机地址 网站地址 cyberstrikelab.com 靶机地址CVE-2017-10271 靶机说明 Weblogic的WLS Security组件对外提供webservice服务&#xff0c;其中使用了XMLDecoder来解析用户传入的XML数据&#xf…

如何在vscode中编译linux中的c++文件

方式一 在终端打开进行连接编译 指令含义&#xff1a;将 muduo_server.cpp 源文件编译成一个可执行文件 server&#xff0c;并且在链接过程中使用 muduo_net、muduo_base 库以及 pthread 库 方式二 在vscode中修改配置文件 按F1打开配置文件搜索栏&#xff0c;输入C/C 打开…

Unity中刚体撞墙抖动的原因和本质

当我们制作角色移动的时候我们都知道使用设置位置的方法来移动一个带有刚体和碰撞体的物体&#xff0c;遇到碰撞体的时候就会抖动。 上网查找原因&#xff0c;都说是和物理系统冲突导致的&#xff0c;然后再也找不到其他线索。 这个说法&#xff0c;对&#xff0c;但它并不是最…

文件解析漏洞靶场通关合集

一、IIS解析漏洞 &#xff08;一&#xff09;iis6的目录解析漏洞(.asp目录中的所有文件都会被当做asp文件执行) 第一步&#xff1a;在网站根目录下创建了一个x.asp文件夹&#xff0c;并在文件夹中创建一个名为1.txt的文本文档 第二步&#xff1a;文本文档中输入<% now()%&…

【Linux】浅谈冯诺依曼和进程

一、冯诺依曼体系结构 冯诺依曼由 输入设备、输出设备、运算器、控制器、存储器 五部分组成。 冯诺依曼的设计特点 二进制表示 所有数据&#xff08;包括程序指令&#xff09;均以二进制形式存储和运算&#xff0c;简化了硬件逻辑设计&#xff0c;提高了可靠性。 存储程序原理…

技术聚焦:Debezium 如何将数据库数据精准注入 Kafka

#作者&#xff1a;任少近 文章目录 第一章 Debezium抽取mysql数据给kafka原理第二章 Debezium 与kafka抽取方法及验证2.1 debezium2.0kafka3.3.1mysql82.2 debezium2.0kafka2.6.1mysql82.3 debezium2.0kafka2.6.1mysql5.7 第一章 Debezium抽取mysql数据给kafka原理 debezium的…

SpringBoot学生宿舍管理系统的设计与开发

项目概述 幽络源分享的《SpringBoot学生宿舍管理系统的设计与开发》是一款专为校园宿舍管理设计的智能化系统&#xff0c;基于SpringBoot框架开发&#xff0c;功能全面&#xff0c;操作便捷。该系统涵盖管理员、宿管员和学生三大角色&#xff0c;分别提供宿舍管理、学生信息管…

LLM剪枝代码解释与实现

LLM剪枝代码解释与实现 目录 LLM剪枝代码解释与实现函数概述函数参数函数实现步骤1. 遍历模型的所有参数2. 筛选权重参数3. 计算参数的绝对值4. 计算阈值5. 创建掩码6. 应用掩码7. 返回剪枝后的模型总结可运行代码注意安装包的版本信息 transformers adapter-transformers函数概…