【Pytorch笔记】4.梯度计算

深度之眼官方账号 - 01-04-mp4-计算图与动态图机制

前置知识:计算图
可以参考我的笔记:
【学习笔记】计算机视觉与深度学习(2.全连接神经网络)

计算图

在这里插入图片描述
以这棵计算图为例。这个计算图中,叶子节点为x和w。

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)# 调用backward()方法,开始反向求梯度
y.backward()
print(w.grad)print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

输出:

tensor([5.])
is_leaf:True True False False False
gradient:tensor([5.]) tensor([2.]) None None None

由此可见,非叶子节点在最后不会被保留梯度。这是出于节省空间的需要而这样设计的。实际的计算图会非常大,如果每个节点都保留梯度,会占用非常大的存储空间,而这些节点的梯度对于我们学习并没有什么帮助。

如果非要看他们的梯度,可以这样操作:在a = torch.add(w, x)的后面加上一句a.retain_grad(),这样a的梯度就会被存储起来。
输出会变成:

tensor([5.])
is_leaf:True True False False False
gradient:tensor([5.]) tensor([2.]) tensor([2.]) None None

对于节点,还可以看这些节点进行的运算。grad_fn,gradient function的缩写,表示这个节点的tensor是什么运算产生的。加一句:

print("gradient function:\n", w.grad_fn, '\n', x.grad_fn, '\n', a.grad_fn, '\n', b.grad_fn, '\n', y.grad_fn)

会输出

gradient function:NoneNone<AddBackward0 object at 0x000001B1DA3651C0><AddBackward0 object at 0x000001B1DA3651F0><MulBackward0 object at 0x000001B1DA3515B0>

retain_graph

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
a.retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)# 调用backward()方法,开始反向求梯度
y.backward()
y.backward()

连续两次调用backward()方法,会报这样的错误:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因是我们进行第一次backward()后,计算图就被自动释放掉了,进行第二次backward()时,没有计算图可以计算梯度,于是报错。

解决方案:backward内部添加一个参数:retain_graph=True,意思是计算完梯度后保留计算图。

# 调用backward()方法,开始反向求梯度
y.backward(retain_graph=True)
y.backward()

这样就不会报错了。

gradient

当计算图末部的节点有1个以上时,有时我们会希望他们之间的梯度有一个权重关系。这时就会用上gradient

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)# 不难看出,y0和y1是两个互不干扰的末部节点
y0 = torch.mul(a, b)
y1 = torch.add(a, b)# 将两个末部节点打包起来
loss = torch.cat([y0, y1], dim=0)
grad_tensors = torch.tensor([1., 2.])# 将grad_tensors中的内容作为权重,变成y0+2y1
loss.backward(gradient=grad_tensors)print(w.grad)

输出

tensor([9.])

如果把grad_tensors改成:

grad_tensors = torch.tensor([1., 3.])

输出变成:

tensor([11.])

torch.autograd.grad()

除了加减乘除法,我们还可以对torch进行求导操作。求的是 d ( o u t p u t s ) d ( i n p u t s ) \frac{d(outputs)}{d(inputs)} d(inputs)d(outputs)

torch.autograd.grad(outputs,inputs,grad_outputs=None,retain_graph=None,create_graph=False)

outputs和inputs已在上述定义中给出;
grad_outputs:多梯度权重;
retain_graph:保留计算图;
create_graph:创建计算图。

import torch# y = x ** 2
x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2)# grad_1 = dy / dx = 2x = 6
grad_1 = torch.autograd.grad(y, x, create_graph=True)
print(grad_1)# grad_2 = d(dy / dx) / dx = 2
grad_2 = torch.autograd.grad(grad_1, x)
print(grad_2)

输出

(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

autograd注意事项

1.梯度不会自动清零

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)for i in range(4):a = torch.add(w, x)b = torch.mul(w, x)y = torch.mul(a, b)y.backward()print("w's grad: ", w.grad)# w.grad.zero_()

输出:

w's grad:  tensor([8.])
w's grad:  tensor([16.])
w's grad:  tensor([24.])
w's grad:  tensor([32.])

由此可以看出,在不加上注释掉的那一行时,梯度在w处是不断累积的。而如果我们把print后面的那句w.grad.zero_()加上,输出就会变成:

w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])
w's grad:  tensor([8.])

w.grad.zero_()的意思就是把w处积累的梯度清零。

2.依赖于叶子节点的节点,requires_grad默认为True

可以从上面的代码中发现,我们只有在定义w和x两个tensor时,设置requires_grad为True。这个参数在定义tensor时默认为False。后面我们的a、b、y都没有设置这个参数。

如果我们定义w和x的时候不加上requires_grad=True,那么y.backward()这一步就会报错,因为我们的预设,这两个tensor不需要梯度,于是就无法求梯度。而w和x是我们计算图上的叶子节点,所以必须加上requires_grad=True。

而后面通过w和x延伸定义出的a、b、y,由于依赖的w、x的requires_grad是True,那么a、b、y的这个参数也被默认设置为了True,不需要我们手动添加。

3.叶子节点不可执行in-place操作

计算图上叶子节点处的tensor不能进行原地修改。

什么是in-place操作?
t = torch.tensor([1., 2.])
t.add_(3.)
print(t)

输出

tensor([4., 5.])

torch.Tensor.add_就是torch.add的in-place版本。所谓in-place,就是在tensor上进行原地修改。大部分的torch.tensor的运算,名字后面加一个下划线,就变成inplace操作了。

再比如求绝对值:

t = torch.tensor([-1., -2.])
t.abs_()
print(t)

输出

tensor([1., 2.])

知道什么是in-place操作后,我们尝试一下在requires_grad=True的叶子节点上原地修改,代码如下:

import torchw = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.mul(w, x)
y = torch.mul(a, b)w.add_(1)y.backward()

报错信息:

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/147529.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分类预测 | MATLAB实现SSA-FS-SVM麻雀算法同步优化特征选择结合支持向量机分类预测

分类预测 | MATLAB实现SSA-FS-SVM麻雀算法同步优化特征选择结合支持向量机分类预测 目录 分类预测 | MATLAB实现SSA-FS-SVM麻雀算法同步优化特征选择结合支持向量机分类预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 MATLAB实现SSA-FS-SVM麻雀算法同步优化特征选择结…

快速选择排序

"你经过我每个灿烂时刻&#xff0c;我才真正学会如你般自由" 前些天有些无聊&#xff0c;想试试自己写的快排能否过leetcode上的排序算法题。结果是&#xff0c;不用截图可想而知&#xff0c;肯定是没过的&#xff0c;否则也不会有这篇文章的产出。 这份快排算法代码…

HCQ1-1300-D【高速输入】

因为我的PLC固件比较旧。所以有些限制。【比如&#xff1a;编译不报错&#xff0c;下载PLC程序就报故障】我的PLC的高速输入类型只能是【hsi_ref】 所以&#xff0c;程序添加的高速输入模块只能是【1.0.1.0】版本 如果固件版本低&#xff0c;看下固件能支持的类型。选错的话&am…

苹果曾考虑基于定位控制AirPods Pro自适应音频

在一次最近的采访中&#xff0c;苹果公司的高管Ron Huang和Eric Treski透露&#xff0c;他们在开发AirPods Pro自适应音频功能时&#xff0c;曾考虑使用GPS信号来控制音频级别。这个有趣的细节打破了我们对AirPods Pro的固有认知&#xff0c;让我们对苹果的创新思维有了更深的…

Python学习笔记之运算符的使用

Python学习笔记之运算符的使用 整型&#xff1a;二进制0b100十进制4、八进制0o100十进制64、十进制100、十六进制0x100十进制256浮点型&#xff1a;123.456&#xff0c;1.23456e2字符串型&#xff1a;‘Hello’&#xff0c;“Hello”布尔型&#xff1a;True、False复数型&…

电子计算机核心发展(继电器-真空管-晶体管)

目录 继电器 最大的机电计算机之一——哈弗Mark1号&#xff0c;IBM1944年 背景 组成 性能 核心——继电器 简介 缺点 速度 齿轮磨损 Bug的由来 真空管诞生 组成 控制开关电流 继电器对比 磨损 速度 缺点 影响 代表 第一个可编程计算机 第一个真正通用&am…

ssm+vue的图书馆书库管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频&#xff1a; ssmvue的图书馆书库管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;ssm vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构…

OpenCV 基础图像处理

1、生成图像 cv2.imread是OpenCV库中的一个函数&#xff0c;用于读取图像文件。它接受一个参数&#xff0c;即要读取的图像文件的路径&#xff0c;返回一个多维数组&#xff0c; 表示图像的像素值。该函数的常用参数包括&#xff1a;flags&#xff1a;指定读取图像的方式&#…

python——Django框架

一、基本介绍 Django 是一个由 Python 编写的一个开放源代码的 Web 应用框架。 使用 Django&#xff0c;只要很少的代码&#xff0c;Python 的程序开发人员就可以轻松地完成一个正式网站所需要的大部分内容&#xff0c;并进一步开发出全功能的 Web 服务 Django 本身基于 MVC …

【已解决】opencv 交叉编译 ffmpeg选项始终为NO

一、opencv 交叉编译没有 ffmpeg &#xff0c;会导致视频打不开 在交叉编译时候&#xff0c;发现在 pc 端能用 opencv 打开的视频&#xff0c;但是在 rv1126 上打不开。在网上查了很久&#xff0c;原因可能是 交叉编译过程 ffmpeg 造成的。之前 ffmpeg 是直接用 apt 安装的&am…

asp.net core 远程调试

大概说下过程&#xff1a; 1、站点发布使用Debug模式 2、拷贝到远程服务器&#xff0c;以及iis创建站点。 3、本地的VS2022的安装目录&#xff1a;C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\IDE下找Remote Debugger 你的服务器是64位就拷贝x64的目…

【从入门到起飞】IO高级流(1)(缓冲流,转换流,序列化流,反序列化流)

&#x1f38a;专栏【JavaSE】 &#x1f354;喜欢的诗句&#xff1a;天行健&#xff0c;君子以自强不息。 &#x1f386;音乐分享【如愿】 &#x1f384;欢迎并且感谢大家指出小吉的问题&#x1f970; 文章目录 &#x1f384;缓冲流&#x1f354;字节缓冲流&#x1f6f8;一次读取…

【通意千问】大模型GitHub开源工程学习笔记(1)--依赖库

9月25日&#xff0c;阿里云开源通义千问140亿参数模型Qwen-14B及其对话模型Qwen-14B-Chat,免费可商用。 立马就到了GitHub去fork。 GitHub&#xff1a; GitHub - QwenLM/Qwen: The official repo of Qwen (通义千问) chat & pretrained large language model proposed b…

协议栈——收发数据(拼接网络包,自动重发,滑动窗口机制)

目录 协议栈何时发送数据&#xff5e; 数据长度 IP模块的分片功能 发送频率 网络包序号&#xff5e;利用syn拼接网络包ack确认网络包完整 确定偏移量 服务器ack确定收到数据总长度 序号作用 双端告知各自序号 协议栈自动重发机制 大致流程 ack等待时间如何调整 是…

Leetcode202. 快乐数

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为&#xff1a; 对于一个正整数&#xff0c;每一次将该数替换为它每个位置上的数字的平方和。然后重复这个过程直到这个数变为 1&#xff0…

【单片机】13-实时时钟DS1302

1.RTC的简介 1.什么是实时时钟&#xff08;RTC&#xff09; &#xff08;rtc for real time clock) &#xff08;1&#xff09;时间点和时间段的概念区分 &#xff08;2&#xff09;单片机为什么需要时间点【一定的时间点干什么事情】 &#xff08;3&#xff09;RTC如何存在于…

Redis中Hash类的操作

Redis中Hash类型是键值对的形式保存数据&#xff0c;其中键被称为字段&#xff08;field&#xff09;&#xff0c;值称为字段值&#xff08;value&#xff09;。在一个key中&#xff0c;字段不能重复&#xff0c;而值可以重复。无论是字段还是值都是无序的&#xff08;保存的次…

Java 大厂八股文面试专题-JVM相关面试题 垃圾回收算法 GC JVM调优

Java 大厂八股文面试专题-JVM相关面试题 类加载器_软工菜鸡的博客-CSDN博客 3 垃圾收回 3.1 简述Java垃圾回收机制&#xff1f;&#xff08;GC是什么&#xff1f;为什么要GC&#xff09; 难易程度&#xff1a;☆☆☆ 出现频率&#xff1a;☆☆☆ 为了让程序员更专注于代码的实现…

【Flink】

事件驱动型应用 核心目标&#xff1a;数据流上的有状态计算 Apache Flink是一个框架和分布式处理引擎&#xff0c;用于对无界或有界数据流进行有状态计算。 运行逻辑 状态 把流处理需要的额外数据保存成一个“状态”,然后针对这条数据进行处理,并且更新状态。这就是所谓的“…

MySQL explain SQL分析工具详解与最佳实践

目录 一、explain工具介绍二、添加示例表和数据用于后续演示三、explain中的列3.1、id列3.2、select_type列3.3、table列3.4、partitions列3.5、type列NULLsystemconsteq_refrefrangeindexALL 3.6、possible_keys列3.7、key列3.8、key_len列3.9、ref列3.10、rows列3.11、filter…