Pytorch的自动求导模块

文章目录

  • torch.autograd.backward()
    • 基本用法
    • 非标量张量的反向传播
    • 保留计算图
    • 指定输入张量
    • 高阶梯度计算
  • 与 y.backward() 的区别
  • torch.autograd.grad()
    • 基本用法
    • 非标量张量的梯度
    • 高阶梯度计算
    • 多输入、多输出的梯度计算
    • 未使用的输入张量
    • 保留计算图
  • 与 backward() 的区别

torch.autograd.backward()

该函数实现自动求导梯度,函数如下:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=False, create_graph=False, inputs=None)

参数介绍:

  • tensors: 需要对其进行反向传播的目标张量(或张量列表),例如:loss。
    这些张量通常是计算图的最终输出。
  • grad_tensors:与 tensors 对应的梯度权重(或权重列表)。
    如果 tensors 是标量张量(单个值),可以省略此参数。
    如果 tensors 是非标量张量(如向量或矩阵),则必须提供 grad_tensors,表示每个张量的梯度权重。例如:当有多个loss需要计算梯度时,需要设置每个loss的权值。
  • retain_graph:是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次反向传播,需设置为 True。
  • create_graph: 是否创建一个新的计算图,用于高阶梯度计算
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • inputs: 指定需要计算梯度的输入张量(或张量列表)。
    如果指定了此参数,只有这些张量的 .grad 属性会被更新,而不是整个计算图中的所有张量。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.backward() 触发反向传播  
torch.autograd.backward(y)  # 查看梯度  
print(x.grad)  # 输出:4.0 (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的反向传播

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_tensors 权重  
grad_tensors = torch.tensor([1.0, 1.0, 1.0])  # 权重  
torch.autograd.backward(y, grad_tensors=grad_tensors)  # 查看梯度  
print(x.grad)  # 输出:[2.0, 4.0, 6.0] (dy/dx = 2x)

保留计算图

如果需要多次调用反向传播,可以设置 retain_graph=True。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2, 当 x=2 时,dy/dx=12)  # 第二次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:24.0 (梯度累积,12.0 + 12.0)

指定输入张量

通过 inputs 参数,可以只计算指定张量的梯度,而忽略其他张量。

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2 + z ** 3  # y = x^2 + z^3  # 只计算 x 的梯度  
torch.autograd.backward(y, inputs=[x])  
print(x.grad)  # 输出:4.0 (dy/dx = 2x)  
print(z.grad)  # 输出:None (未计算 z 的梯度)

高阶梯度计算

通过设置 create_graph=True,可以构建新的计算图,用于计算高阶梯度。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播,创建新的计算图  
torch.autograd.backward(y, create_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2)  # 计算二阶梯度  
x_grad = x.grad  
x_grad.backward()  
print(x.grad)  # 输出:18.0 (d^2y/dx^2 = 6x)

与 y.backward() 的区别

  • 灵活性:

    • torch.autograd.backward() 更灵活,可以对多个张量同时进行反向传播,并指定梯度权重。
    • y.backward() 是对单个张量的简单封装,适合常见场景。对多个loss求导时,需要指定gradient和grad_outputs相同作用。
  • 梯度权重:

    • torch.autograd.backward() 需要显式提供 grad_tensors 参数(如果目标张量是非标量)。
    • y.backward() 会自动处理标量张量,非标量张量需要手动传入权重。
  • 输入控制:

    • torch.autograd.backward() 可以通过 inputs 参数指定只计算某些张量的梯度。
    • y.backward() 无法直接控制,只会更新计算图中所有相关张量的 .grad。

torch.autograd.grad()

torch.autograd.grad() 是 PyTorch 中用于计算张量梯度的函数,与 backward() 不同的是,它不会更新张量的 .grad 属性,而是直接返回计算的梯度值。它适用于需要手动获取梯度值而不修改计算图中张量的 .grad 属性的场景。

torch.autograd.grad(  outputs,   inputs,   grad_outputs=None,   retain_graph=False,   create_graph=False,   only_inputs=True,   allow_unused=False  
)

参数介绍:

  • outputs:
    目标张量(或张量列表),即需要对其进行求导的输出张量。
  • inputs:
    需要计算梯度的输入张量(或张量列表)。
    这些张量必须启用了 requires_grad=True。
  • grad_outputs:
    与 outputs 对应的梯度权重(或权重列表)。
    如果 outputs 是标量张量,可以省略此参数;如果是非标量张量,则需要提供权重,表示每个输出张量的梯度权重。
  • retain_graph:
    是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次计算梯度,需设置为 True。
  • create_graph:
    是否创建一个新的计算图,用于高阶梯度计算。
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • only_inputs:
    是否只对 inputs 中的张量计算梯度。
    默认值为 True,表示只计算 inputs 的梯度。
  • allow_unused:
    是否允许 inputs 中的某些张量未被 outputs 使用。
    默认值为 False,如果某些 inputs 未被 outputs 使用,会抛出错误。如果设置为 True,未使用的张量的梯度会返回 None。

返回值:

  • 返回一个元组,包含 inputs 中每个张量的梯度值。
  • 如果某个输入张量未被 outputs 使用,且 allow_unused=True,则对应的梯度为 None。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.grad() 计算梯度  
grad = torch.autograd.grad(y, x)  
print(grad)  # 输出:(4.0,) (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的梯度

当目标张量是非标量时,需要提供 grad_outputs 参数:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_outputs 权重  
grad_outputs = torch.tensor([1.0, 1.0, 1.0])  # 权重  
grad = torch.autograd.grad(y, x, grad_outputs=grad_outputs)  
print(grad)  # 输出:(tensor([2.0, 4.0, 6.0]),) (dy/dx = 2x)

高阶梯度计算

通过设置 create_graph=True,可以计算高阶梯度:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad = torch.autograd.grad(y, x, create_graph=True)  
print(grad)  # 输出:(12.0,) (dy/dx = 3x^2)  # 计算二阶梯度  
grad2 = torch.autograd.grad(grad[0], x)  
print(grad2)  # 输出:(6.0,) (d^2y/dx^2 = 6x)

多输入、多输出的梯度计算

可以对多个输入和输出同时计算梯度:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y1 = x ** 2 + z ** 3  # y1 = x^2 + z^3  
y2 = x * z  # y2 = x * z  # 对多个输入计算梯度  
grads = torch.autograd.grad([y1, y2], [x, z], grad_outputs=[torch.tensor(1.0), torch.tensor(1.0)])  
print(grads)  # 输出:(7.0, 11.0) (dy1/dx + dy2/dx, dy1/dz + dy2/dz)

未使用的输入张量

如果某些输入张量未被目标张量使用,需设置 allow_unused=True:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2  # y = x^2  # z 未被 y 使用  
grad = torch.autograd.grad(y, [x, z], allow_unused=True)  
print(grad)  # 输出:(4.0, None) (dy/dx = 4, z 未被使用,梯度为 None)

保留计算图

如果需要多次计算梯度,可以设置 retain_graph=True:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad1 = torch.autograd.grad(y, x, retain_graph=True)  
print(grad1)  # 输出:(12.0,)  # 第二次计算梯度  
grad2 = torch.autograd.grad(y, x)  
print(grad2)  # 输出:(12.0,)

与 backward() 的区别

  • 梯度存储
    • torch.autograd.grad() 不会修改张量的 .grad 属性,而是直接返回梯度值。
    • backward() 会将计算的梯度累积到 .grad 属性中。
  • 灵活性:
    • torch.autograd.grad() 可以对多个输入和输出同时计算梯度,并支持未使用的输入张量。
    • backward() 只能对单个输出张量进行反向传播。
  • 高阶梯度:
    • torch.autograd.grad() 支持通过 create_graph=True 计算高阶梯度。
    • backward() 也支持高阶梯度,但需要手动设置 create_graph=True。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/501489.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习之模型评估——混淆矩阵,交叉验证与数据标准化

目录 混淆矩阵 交叉验证 数据标准化 0-1标准化 z 标准化 混淆矩阵 混淆矩阵(Confusion Matrix)是一种用于评估分类模型性能的工具。 它是一个二维表格,其中行表示实际的类别,列表示模型预测的类别。 假设我们有一个二分类问题&…

第R3周:RNN-心脏病预测

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 文章目录 一、前言二、代码流程1、导入包,设置GPU2、导入数据3、数据处理4、构建RNN模型5、编译模型6、模型训练7、模型评估 电脑环境:…

40% 降本:多点 DMALL x StarRocks 的湖仓升级实战

小编导读: 多点 DMALL 成立于2015年,持续深耕零售业,为企业提供一站式全渠道数字零售解决方案 DMALL OS。作为 DMALL OS 数字化能力的技术底座,大数据平台历经多次迭代平稳支撑了公司 To B 业务的快速开展。随着国家产业升级和云原…

C语言——字符函数和内存函数

目录 前言 字符函数 1strlen 模拟实现 2strcpy 模拟实现 3strcat 模拟实现 4strcmp 模拟实现 5strncpy 模拟实现 6strncat 模拟实现 7strncmp 模拟实现 8strstr 模拟实现 9strtok 10strerror 11大小写字符转换函数 内存函数 1memcpy 模拟实现 2…

职场常用Excel基础04-二维表转换

大家好,今天和大家一起分享一下excel的二维表转换相关内容~ 在Excel中,二维表(也称为矩阵或表格)是一种组织数据的方式,其中数据按照行和列的格式进行排列。然而,在实际的数据分析过程中,我们常…

软考教材重点内容 信息安全工程师 第 12 章网络安全审计技术原理与应用

12.1.1 网络安全审计概念 网络安全审计是指对网络信息系统的安全相关活动信息进行获取、记录、存储、分析和利用的工作。网络安全审计的作用在于建立“事后”安全保障措施,保存网络安全事件及行为信息,为网络安全事件分析提供线索及证据,以便…

TT100K数据集, YOLO格式, COCO格式

TT100K交通标志数据集, 标签txt,图像已经分好了测试集,验证集,训练集 1️⃣可以直接导入YOLO进行训练,没有细分类,里面有的类, 闲鱼9.9 解君愁 ,明人不说暗话 https://m.tb.cn/h.T7Ossey?tk…

更改element-plus的table样式

表头样式&#xff1a; <el-table :data"props.tableData" style"width: 100%" :header-cell-style"headerCellStyle" :cell-style"cellStyle"> </el-table>样式&#xff1a; // 表头样式 const headerCellStyle {backgro…

“善弈者”也需妙手,Oclean欧可林:差异化不是说说而已

作者 | 曾响铃 文 | 响铃说 俗话说&#xff0c;“牙痛不是病&#xff0c;痛起来要人命”。这话意思大家都知道&#xff0c;牙痛虽不是什么大病&#xff0c;可一旦发作却是极难忍受。 前几日&#xff0c;Oclean欧可林举办了一场AirPump A10氧气啵啵冲牙器新品品鉴会&#xff…

数字货币支付系统开发搭建:构建未来的区块链支付生态

随着数字货币的迅猛发展&#xff0c;越来越多的企业和机构开始关注如何搭建一个高效、安全、可扩展的数字货币支付系统。区块链技术因其去中心化、安全性高、透明性强等优势&#xff0c;已成为开发数字货币支付系统的首选技术。本文将深入探讨数字货币支付系统的开发和搭建过程…

K8s高可用集群之Kubernetes集群管理平台、命令补全工具、资源监控工具部署、常用命令

K8s高可用集群之Kubernetes管理平台、补全命令工具、资源监控工具部署 1.Kuboard可视化管理平台2.kubectl命令tab补全工具3.MetricsServer资源监控工具4.Kubernetes常用命令 1.Kuboard可视化管理平台 可以选择安装k8s官网的管理平台&#xff1b;我这里是安装的其他开源平台Kub…

cka考试-02-节点维护

一.解答答案 kubectl config use-context ek8s kubectl cordon k8s-node1 kubectl drain k8s-node1 --delete-emptydir-data --ignore-daemonsets --force 二.解答思路 记住这2个cordon,drain,使用kubectl -h 查询使用方法 [root@master ~]# kubectl -h |grep -E cordon…

【pytorch】现代循环神经网络-2

1 双向循环神经网络&#xff08;Bi-RNN&#xff09; 具有单个隐藏层的双向循环神经网络的架构如图所示&#xff1a; 对于任意时间步t&#xff0c;给定一个小批量的输入数据 Xt ∈ Rnd &#xff08;样本数n&#xff0c;每个示例中的输入数d&#xff09;&#xff0c;并且令隐藏层…

服务器等保测评日志策略配置

操作系统日志 /var/log/message 系统启动后的信息和错误日志&#xff0c;是Red Hat Linux中最常用的日志之一 /var/log/secure 与安全相关的日志信息 /var/log/maillog 与邮件相关的日志信息 /var/log/cron 与定时任务相关的日志信息 /var/log/spooler 与UUCP和news设备相关的…

Flutter-插件 scroll-to-index 实现 listView 滚动到指定索引位置

scroll-to-index 简介 scroll_to_index 是一个 Flutter 插件&#xff0c;用于通过索引滚动到 ListView 中的某个特定项。这个库对复杂滚动需求&#xff08;如动态高度的列表项&#xff09;非常实用&#xff0c;因为它会自动计算需要滚动的目标位置。 使用 安装插件 flutte…

我用AI学Android Jetpack Compose之开篇

最近突发奇想&#xff0c;想学一下Jetpack Compose&#xff0c;打算用Ai学&#xff0c;学最新的技术应该要到官网学&#xff0c;不过Compose已经出来一段时间了&#xff0c;Ai肯定学过了&#xff0c;用Ai来学&#xff0c;应该问题不大&#xff0c;学习过程记录下来&#xff0c;…

PHP框架+gatewayworker实现在线1对1聊天--发送消息(6)

文章目录 发送消息原理说明发送功能实现html部分javascript代码PHP代码 发送消息原理说明 接下来我们发送聊天的文本信息。点击发送按钮的时候&#xff0c;会自动将文本框里的内容发送出去。过程是我们将信息发送到服务器&#xff0c;服务器再转发给对方。文本框的id为msgcont…

网络安全 | 信息安全管理体系(ISMS)认证与实施

网络安全 | 信息安全管理体系&#xff08;ISMS&#xff09;认证与实施 一、前言二、信息安全管理体系&#xff08;ISMS&#xff09;概述2.1 ISMS 的定义与内涵2.2 ISMS 的核心标准 ——ISO/IEC 27001 三、信息安全管理体系&#xff08;ISMS&#xff09;认证3.1 认证的意义与价值…

服务器数据恢复—服务器硬盘亮黄灯的数据恢复案例

服务器硬盘指示灯闪烁黄灯是一种警示&#xff0c;意味着服务器硬盘出现故障即将下线。发现这种情况建议及时更换硬盘。 一旦服务器上有大量数据频繁读写&#xff0c;硬盘指示灯会快速闪烁。服务器上某个硬盘的指示灯只有黄灯亮着&#xff0c;而其他颜色的灯没有亮的话&#xff…

AfuseKt1.4.4 | 刮削视频播放器,支持阿里云盘和自动海报墙

AfuseKt是一款功能强大的安卓端在线视频播放器&#xff0c;广泛兼容多种平台如阿里云盘、Alist、WebDAV、Emby、Jellyfin等&#xff0c;同时也支持本地存储视频文件的播放。其特色功能包括自动抓取影片信息生成海报墙展示&#xff0c;充分利用设备硬件进行高清视频流畅播放&…