pytorch之梯度累加

1.什么是梯度?

梯度可以理解为一个多变量函数的变化率,它告诉我们在某一点上,函数的输出如何随输入的变化而变化。更直观地说,梯度指示了最优化方向。

  • 在机器学习中的作用:在训练模型时,我们的目标是最小化损失函数,以提高模型的准确性。损失函数是衡量模型预测值与真实值之间差距的函数。梯度告诉我们如何调整模型参数,以使损失函数的值减小。

2. 模型参数的优化

考虑一个简单的线性模型:

y=wx+b

  • 其中,yy 是输出,xx 是输入,ww 是权重,bb 是偏置。
  • 为了训练模型,我们使用损失函数(例如均方误差)来衡量模型输出与真实输出之间的差距。损失函数通常定义为:

Loss=1/N∑i=1N(ypred,i−ytrue,i)^2

  • 这里 NN 是样本数,ypred是模型计算的预测值,ytrue是真实值。

3. 反向传播

反向传播是一种高效计算梯度的算法,尤其在深度学习中使用广泛。

3.1 前向传播

在前向传播中,我们将输入数据通过模型传递,计算出预测结果,并基于预测结果与真实结果计算损失。

3.2 计算梯度

反向传播通过链式法则计算梯度,以更新模型参数。通过反向传播,我们可以得到损失函数对每个参数(比如 ww 和 bb)的导数,这些导数就是梯度。

  • 链式法则:假设有一个复合函数 z=f(g(x)),则其导数为:

dz/dx=dz/dg⋅dg/dx

这个法则帮助我们逐层计算梯度。

4. 梯度的累加

4.1 为什么会累加?

在训练过程中,我们可能会处理多个训练样本进行参数更新。如果连续调用多次 loss.backward(),每次都会将计算的梯度值加到之前的梯度上。

  • 这意味着如果我们不清零梯度,梯度会随着样本数的增加而不断增加,可能导致参数更新的步幅变得非常大,影响模型的收敛。
4.2 样例代码

让我们通过具体的代码示例来更好地理解梯度的累加和为什么需要清零。

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单线性模型
model = nn.Linear(1, 1)  # 线性模型:1个输入,1个输出
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用 SGD 优化器# 模拟一些数据
x = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=False)  # 输入
y = torch.tensor([[2.0], [3.0], [4.0]], requires_grad=False)  # 目标输出# 训练循环
for epoch in range(5):  # 假设训练 5 轮for i in range(len(x)):  # 遍历每个训练样本optimizer.zero_grad()  # 清零梯度,确保只考虑当前样本# 前向传播output = model(x[i])  # 计算当前样本的输出# 计算损失loss = (output - y[i]) ** 2  # 均方误差损失# 反向传播loss.backward()  # 计算梯度,累加到 model 的参数中# 更新参数optimizer.step()  # 使用累加的梯度更新参数print(f"Epoch: {epoch}, Sample: {i}, Loss: {loss.item()}, W: {model.weight.data}, b: {model.bias.data}")

5. 源代码解释

  1. 清零梯度:在每次处理新的训练样本前,调用 optimizer.zero_grad() 清空梯度。这是为了确保每个训练样本只对当前的梯度产生影响。

  2. 前向传播:计算当前输入的输出。

  3. 损失计算:计算输出与真实值之间的差距。

  4. 反向传播:通过 loss.backward() 计算当前样本对模型参数的梯度并将其累加到 model.parameters() 的 grad 属性上。

  5. 参数更新:调用 optimizer.step() 进行参数更新。

在 PyTorch 中,梯度的累加是一种非常重要且实用的特性,其设计有几个原因:

6. 支持小批量(Mini-batch)训练

在实践中,由于计算资源的限制,通常使用小批量数据进行训练。这意味着我们不会一次性使用整个数据集来更新模型,而是对一小部分数据频繁进行计算。

  • 梯度累加允许我们在多个小批量上计算梯度,并在适当的时候一并更新模型参数。这种策略被称为“累积梯度”。

例如,如果我们有一个较大的数据集,可以将其分为多个小批量,然后在每个小批量上计算梯度。在所有小批量处理完成后,再进行一次参数更新。这种方法可以模拟使用更大批量数据的效果,提高模型的表现。

7. 提高训练灵活性

梯度累加允许用户在特定情况下有效地控制参数更新的频率。例如:

  • 如果处理每个样本时都立即更新权重,可能会导致训练过程不稳定。而通过在多个样本上累加梯度,可以缓解这种波动性,平滑参数的更新过程。

  • 用户可以决定什么情况下清零梯度,例如只有在处理完一个完整的训练周期(epoch)后,或在经历多个小批量后再更新一次参数。这种控制在很多情况下可以提高性能和收敛性。

8. 节省内存

对于一些深度学习模型,特别是当模型较大,或者在训练过程中使用大量数据时,清零梯度后再进行反向传播通常需要的内存较少。没有累加的梯度能够避免内存的额外消耗,进而提高整个训练过程的效率。

9. 灵活的梯度管理

开发者可以基于需求自定义梯度累加的策略。例如,有时我们可能希望实现一些特殊的训练策略,比如调整学习率、动态更改模型的训练方式等。在这些情况下,梯度的管理就显得至关重要。

10. 应用在不同的训练模式

在一些变种训练方式中,如强化学习或一些优化器的特殊需求,可能需要在更新权重前手动控制梯度。这对开发者提供了更大的灵活性和更丰富的训练策略。

数学推导

假设我们有一个线性回归模型,其数学表达式为:

y = W \cdot x + b

均方误差(MSE)损失函数:

\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

梯度计算:

对于权重W

\frac{\partial \mathcal{L}}{\partial W} = \frac{2}{N} \sum_{i=1}^{N} x_i (\hat{y}_i - y_i)


对于偏置 b

\frac{\partial \mathcal{L}}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)

假设我们将数据集分成若干小批量,每个小批量包含 m个样本。我们累加这些梯度,并在累积一定数量的小批量k后更新参数。

对于第j个小批量,计算梯度:

对于权重W

\nabla W_j = \frac{2}{m} \sum_{i=1}^{m} x_i (\hat{y}_i - y_i)

对于偏置 b

\nabla b_j = \frac{2}{m} \sum_{i=1}^{m} (\hat{y}_i - y_i)

梯度累加

\nabla W_{\text{accumulated}} = \sum_{j=1}^{k} \nabla W_j

\nabla b_{\text{accumulated}} = \sum_{j=1}^{k} \nabla b_j

g更新参数

W \leftarrow W - \eta \cdot \nabla W_{\text{accumulated}}

b \leftarrow b - \eta \cdot \nabla b_{\text{accumulated}}

import torch
import torch.nn as nn
import torch.optim as optim# 创建数据集
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0], [10.0]])# 简单的线性回归模型
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):return self.linear(x)# 初始化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义小批量大小和累积步数
batch_size = 2
accumulation_steps = 2# 训练过程
for epoch in range(5):optimizer.zero_grad()  # 清零梯度for i in range(0, len(x_train), batch_size):# 获取小批量数据x_batch = x_train[i:i + batch_size]y_batch = y_train[i:i + batch_size]# 前向传播outputs = model(x_batch)loss = criterion(outputs, y_batch)# 反向传播,累加梯度loss.backward()# 每处理完指定的累积步骤后,更新参数并清零梯度if (i // batch_size + 1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad()# 打印损失print(f'Epoch [{epoch + 1}/5], Loss: {loss.item():.4f}')# 打印模型参数
print(f'Final Parameters: W: {model.linear.weight.item():.4f}, b: {model.linear.bias.item():.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/438647.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TransFormer 视频笔记

TransFormer BasicsAttention单头注意力 single head attentionQ: query 查寻矩阵 128*12288K key matrix 128*12288SoftMax 归一 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/19e3cf1ea28442eca60d5fc1303921f4.png)Value matrix 12288*12288 MLP Bas…

【Linux】进程地址空间、环境变量:从理论到实践(三)

🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 🚀 前言一:🔥 环境变量 🥝 基本概念🥝 常见环境变量🥝 查看环境变量方法 二:🔥 测试 &…

前端算法合集-1(含面试题)

(这是我面试一家中厂公司的二面算法题) 数组去重并按出现次数排序 题目描述: 给定一个包含重复元素的数组,请你编写一个函数对数组进行去重,并按元素出现的次数从高到低排序。如果次数相同,则按元素值从小到大排序。 let arr [2, 11,10, 1…

GPTQ vs AWQ vs GGUF(GGML) 速览和 GGUF 文件命名规范

简单介绍一下四者的区别。 参考链接:GPTQ - 2210.17323 | AWQ - 2306.00978 | GGML | GGUF - docs | What is GGUF and GGML? 文章目录 GPTQ vs AWQ vs GGUF(GGML) 速览GGUF 文件命名GGUF 文件结构文件名解析答案 附录GGUF 文件命名GGUF 文件…

15分钟学 Python 第35天 :Python 爬虫入门(一)

Day 35 : Python 爬虫简介 1.1 什么是爬虫? 网页爬虫(Web Crawler)是自动访问互联网并提取所需信息的程序。爬虫的主要功能是模拟用户通过浏览器访问网页的操作,从而实现对网页内容的批量访问与信息提取。它们广泛应用于数据收集…

JAVA并发编程系列(13)Future、FutureTask异步小王子

美团本地生活面试:模拟外卖订单处理,客户支付提交订单后,查询订单详情,后台需要查询店铺备餐进度、以及外卖员目前位置信息后再返回。 时间好快,一转眼不到一个月时间,已经完成分享synchronized、volatile、…

【VUE】案例:商场会员管理系统

编写vuedfr实现对会员进行基本增删改查 1. drf项目初始化 请求: POST http://127/0.0.0.1:8000/api/auth/ {"username":"cqn", "password":"123"}返回: {"username":"cqn", "token&q…

读论文、学习时 零碎知识点记录01

1.入侵检测技术 2.深度学习、机器学习相关的概念 ❶注意力机制 ❷池化 ❸全连接层 ❹Dropout层 ❺全局平均池化 3.神经网络中常见的层

.NET Core 集成 MiniProfiler性能分析工具

前言: 在日常开发中,应用程序的性能是我们需要关注的一个重点问题。当然我们有很多工具来分析程序性能:如:Zipkin等;但这些过于复杂,需要单独搭建。 MiniProfiler就是一款简单,但功能强大的应用…

Unraid的cache使用btrfs或zfs?

Unraid的cache使用btrfs或zfs? 背景:由于在unraid中添加了多个docker和虚拟机,因此会一直访问硬盘。然而,单个硬盘实在难以让人放心。在阵列盘中,可以通过添加校验盘进行数据保护,在cache中无法使用xfs格式…

深入挖掘C++中的特性之一 — 继承

目录 1.继承的概念 2.举个继承的例子 3.继承基类成员访问方式的变化 1.父类成员的访问限定符对在子类中访问父类成员的影响 2.父类成员的访问限定符子类的继承方式对在两个类外访问子类中父类成员的影响 4.继承类模版(注意事项) 5.父类与子类间的转…

数据结构——计数、桶、基数排序

目录 引言 计数排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 桶排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 基数排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 排序算法的稳定性 1.稳定性的概念 2.各个排序算法的稳定性 结束语 引…

C++(string类的实现)

1. 迭代器、返回capacity、返回size、判空、c_str、重载[]和clear的实现 string类的迭代器的功能就类似于一个指针,所以我们可以直接使用一个指针来实现迭代器,但如下图可见迭代器有两个,一个是指向的内容可以被修改,另一个则是指…

Pytorch最最适合研究生的入门教程,Q3 开始训练

文章目录 Pytorch最最适合研究生的入门教程Q3 开始训练3.1 训练的见解3.2 Pytorch基本训练框架work Pytorch最最适合研究生的入门教程 Q3 开始训练 3.1 训练的见解 如何理解深度学习能够完成任务? 考虑如下回归问题 由函数 y f ( x ) yf(x) yf(x)采样得到的100个…

【安当产品应用案例100集】018-Vmware Horizon如何通过安当ASP身份认证系统增强登录安全性

启用Radius认证是提高VMware Horizon环境安全性的有效方法,特别是在需要满足复杂安全要求的场景中。 启用Radius认证对于VMware Horizon具有以下几个关键优势: 增强安全性:Radius认证支持多种认证方法,包括PAP、CHAP、MS-CHAPv1…

web前端面试中拍摄的真实js面试题(真图)

web前端面试中拍摄的真实js面试题(真图) WechatIMG258.jpeg WechatIMG406.jpeg WechatIMG407.jpeg WechatIMG922.jpeg WechatIMG1063.jpeg © 著作权归作者所有,转载或内容合作请联系作者 喜欢的朋友记得点赞、收藏、关注哦!!…

TypeScript 算法手册 - 【冒泡排序】

文章目录 TypeScript 算法手册 - 冒泡排序1. 冒泡排序简介1.1 冒泡排序定义1.2 冒泡排序特点 2. 冒泡排序步骤过程拆解2.1 比较相邻元素2.2 交换元素2.3 重复过程 3. 冒泡排序的优化3.1 提前退出3.2 记录最后交换位置案例代码和动态图 4. 冒泡排序的优点5. 冒泡排序的缺点总结 …

【SpringBoot详细教程】-09-Redis详细教程以及SpringBoot整合Redis【持续更新】

🌲 Redis 简介 🌾 什么是Redis Redis 是C语言开发的一个开源高性能键值对的内存数据库,可以用来做数据库、缓存、消息中间件等场景,是一种NoSQL(not-only sql,非关系型数据库)的数据库 Redis是互联网技术领域使用最为广泛的存储中间件,它是「Remote DictionaryServic…

TARA分析方法论——威胁分析和风险评估方法

一、什么是TARA分析方法论 威胁分析和风险评估(Threat Analysis and Risk Assessment) 通过识别整车/项目的网络安全资产,分析其中的潜在的安全威胁,综合考虑威胁攻击可行性、危害影响等因素,识别出整车/项目可能存在…

Python并发编程(2)——初始Python多线程

左手编程,右手年华。大家好,我是一点,关注我,带你走入编程的世界。 公众号:一点sir,关注领取python编程资料 前言 什么是多线程? 为什么需要多线程? 多线程的优点和缺点&#xff1f…