pytorch的backward()的底层实现逻辑

自动微分是一种计算张量(tensors)的梯度(gradients)的技术,它在深度学习中非常有用。自动微分的基本思想是:

  • 自动微分会记录数据(张量)和所有执行的操作(以及产生的新张量)在一个由函数(Function)对象组成的有向无环图(DAG)中。在这个图中,叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点追踪这个图,可以使用链式法则(chain rule)自动地计算梯度。
  • 在前向传播(forward pass)中,自动微分同时做两件事:
    • 运行请求的操作来计算一个结果张量,以及
    • 在 DAG 中保留操作的梯度函数。  
    • 在 DAG 中保留操作的梯度函数,这就是说,当你给自动微分一个张量和一个操作,它不仅会计算出结果张量,还会记住这个操作的梯度函数,也就是这个操作对输入张量的导数。例如,如果你给自动微分一个张量 x = [1, 2, 3] 和一个操作 y = x + 1,它不仅会计算出 y = [2, 3, 4],还会记住这个操作的梯度函数是 dy/dx = 1,也就是说,y 对 x 的导数是 1。这样,当你需要计算梯度时,自动微分就可以根据这个梯度函数来计算出结果张量对输入张量的梯度。
  • 在PyTorch中,DAG是动态的。需要注意的一点是,图是从头开始重新创建的;在每个 .backward() 调用之后,autograd开始填充一个新的图。
  • 后向传播开始于当在 DAG 的根节点上调用 .backward() 方法。这个方法会触发自动微分开始计算梯度。
  • 自动微分会从每个 .grad_fn 中计算梯度,这个 .grad_fn 是一个函数对象,它保存了操作的梯度函数。例如,如果一个操作是 y = x + 1,那么它的 .grad_fn 就是 dy/dx = 1。
  • 自动微分会将计算出的梯度累加到相应张量的 .grad 属性中,这个 .grad 属性是一个张量,它保存了结果张量对输入张量的梯度。例如,如果一个结果张量是 y = [2, 3, 4],那么它的 .grad 属性就是 [1, 1, 1],表示 y 对 x 的梯度是 1。
  • 使用链式法则(chain rule),自动微分会一直向后传播,直到到达叶子张量。链式法则是一种数学公式,它可以将复合函数的梯度分解为简单函数的梯度的乘积。例如,如果一个复合函数是 z = f(g(x)),那么它的梯度是 dz/dx = dz/dg * dg/dx。

import torch
import torch.nn as nn
M = nn.Linear(2, 2) # neural network module
M.eval() # set M to evaluation mode
with torch.no_grad(): # disable gradient computationfor param in M.parameters(): # loop over all parametersparam.fill_(1) # fill the parameter with 1
M.requires_grad_(False)a = torch.tensor([1., 2.], requires_grad=True) # leaf node
b = torch.tensor([13., 32.], requires_grad=True) # leaf node
c = M(a) # non-leaf node
c2 = M(b) # non-leaf node
d = c * 2  # non-leaf node
d.sum().backward() # compute gradients
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(M.weight.grad) # None

构建计算图:当我们调用backward()方法时,PyTorch会自动构建从叶子节点a到损失值d.sum()的计算图,这是一个有向无环图,表示了各个张量之间的运算关系。计算图中还包含了两个中间变量c和d,它们是由a经过M模型的前向传播得到的。计算图的作用是记录反向传播的路径,以便于计算梯度。 计算梯度:在计算图中,每个张量都有一个属性grad,用于存储它的梯度值。当我们调用backward()方法时,PyTorch会沿着计算图按照链式法则计算并填充每个张量的grad属性。由于我们只对叶子节点a的梯度感兴趣,所以只有a的grad属性会被计算出来,而中间变量c和d的grad属性会被忽略。a的grad属性的值是损失值d.sum()对a的偏导数,表示了a的变化对损失值的影响。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/199086.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是Mock?为什么要使用Mock呢?

1、前言 在日常开发过程中,大家经常都会遇到:新需求来了,但是需要跟第三方接口来对接,第三方服务还没好,我们自己的功能设计如何继续呢?这里,给大家推荐一下Mock方案。 2、场景示例 2.1、场景一…

HTTP四种请求方式,状态码,请求和响应报文

1.get请求 一般用于获取数据请求参数在URL后面请求参数的大小有限制 2.post请求 一般用于修改数据提交的数据在请求体中提交数据的大小没有限制 3.put请求 一般用于添加数据 4.delete请求 一般用于删除数据 5.一次完整的http请求过程 域名解析:使用DNS协议…

【技术指南资料】编码器与正交译码器

我想提出一个关于PicoScope7新的译码器功能讨论。它已经推出一段时间,但你可能不知道这在汽车领域是扮演相当重要的角色。 正交译码器被用在转子位置传感器来转换关于旋转轴角度及方向的信息。 举例来说,它在电机上采用一对二进制的信号型式。 这种传感器…

渗透测试流程是什么?7个步骤给你讲清楚!

在学习渗透测试之初,有必要先系统了解一下它的流程,静下心来阅读一下,树立一个全局观,一步一步去建设并完善自己的专业领域,最终实现从懵逼到牛逼的华丽转变。渗透测试是通过模拟恶意黑客的攻击方法,同时也…

场景交互与场景漫游-对象选取(8-2)

对象选取示例的代码如程序清单8-11所示: /******************************************* 对象选取示例 *************************************/ // 对象选取事件处理器 class PickHandler :public osgGA::GUIEventHandler { public:PickHandler() :_mx(0.0f), _my…

TableUtilCache:针对CSV表格进行的缓存

TableUtilCache:针对CSV表格进行的缓存 文件结构 首先来看下CSV文件的结构,如下图: 第一行是字段类型,第二行是字段名字;再往下是数据。每个元素之间都是使用逗号分隔。 看一下缓存里面存储所有表数据的字段 如下图&#xff…

【心得】基于flask的SSTI个人笔记

目录 计算PIN码 例题1 SSTI的引用链 例题2 SSTI利用条件: 渲染字符串可控,也就说模板的内容可控 我们通过模板 语法 {{ xxx }}相当于变相的执行了服务器上的python代码 利用render_template_string函数参数可控,或者部分可控 render_…

基于SSM的供电公司安全生产考试系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…

算法之路(二)

🖊作者 : D. Star. 📘专栏 : 算法小能手 😆今日分享 : 你知道北极熊的皮肤是什么颜色的吗?(文章结尾有答案哦!) 文章目录 力扣的209题✔解题思路✔代码:✔总结: 力扣的3题✔解题思路&#xff1a…

【计算机视觉】24-Object Detection

文章目录 24-Object Detection1. Introduction2. Methods2.1 Sliding Window2.2 R-CNN: Region-Based CNN2.3 Fast R-CNN2.4 Faster R-CNN: Learnable Region Proposals2.5 Results of objects detection 3. SummaryReference 24-Object Detection 1. Introduction Task Defin…

Android Studio常见问题

Run一直是上次的apk 内存占用太大,导致闪退

基于Python3的scapy解析SSL报文

scapy对于SSL的支持个人觉得不太好,至少在构造报文方面没有HTTP或者DNS这种常见的报文有效方便,但是scapy对于SSL的解析还是可以的。下面我们以一个典型的HTTPS的报文为例,展示scapy解析SSL报文。 一:解析ClientHello报文 from sc…

【zabbix监控三】zabbix之部署代理服务器

一、部署代理服务器 分布式监控的作用: 分担server的几种压力解决多机房之间的网络延时问题 1、搭建proxy主机 1.1 关闭防火墙,修改主机名 systemctl disbale --now firewalld setenforce 0 hostnamectl set-hostname zbx-proxy su1.2 设置zabbix下…

Docker 可视化面板 ——Portainer

Portainer 是一个非常好用的 Docker 可视化面板,可以让你轻松地管理你的 Docker 容器。 官网:Portainer: Container Management Software for Kubernetes and Docker 【Docker系列】超级好用的Docker可视化工具——Portainer_哔哩哔哩_bilibili 环境 …

服务注册发现 springcloud netflix eureka

文章目录 前言角色(三个) 工程说明基础运行环境工程目录说明启动顺序(建议):运行效果注册与发现中心服务消费者: 代码说明服务注册中心(Register Service)服务提供者(Pro…

三十二、W5100S/W5500+RP2040树莓派Pico<UPnP示例>

文章目录 1 前言2 简介2 .1 什么是UPnP?2.2 UPnP的优点2.3 UPnP数据交互原理2.4 UPnP应用场景 3 WIZnet以太网芯片4 UPnP示例概述以及使用4.1 流程图4.2 准备工作核心4.3 连接方式4.4 主要代码概述4.5 结果演示 5 注意事项6 相关链接 1 前言 随着智能家居、物联网等…

ESP32-BLE基础知识

一、存储模式 两种存储模式: 大端存储:低地址存高字节,如将0x1234存成[0x12,0x34]。小端存储:低地址存低字节,如将0x1234存成[0x34,0x12]。 一般来说,我们看到的一些字符串形式的数字都是大端存储形式&a…

服务器端请求伪造(SSRF)

概念 SSRF(Server-Side Request Forgery,服务器端请求伪造) 是一种由攻击者构造形成的由服务端发起请求的一个安全漏洞。一般情况下,SSRF是要攻击目标网站的内部系统。(因为内部系统无法从外网访问,所以要把目标网站当做中间人来…

Flutter 中在单个屏幕上实现多个列表

今天,我将提供一个实际的示例,演示如何在单个页面上实现多个列表,这些列表可以水平排列、网格格式、垂直排列,甚至是这些常用布局的组合。 下面是要做的: 实现 让我们从创建一个包含产品所有属性的产品模型开始。 …

Android描边外框stroke边线、rotate旋转、circle圆形图的简洁通用方案,基于Glide与ShapeableImageView,Kotlin

Android描边外框stroke边线、rotate旋转、circle圆形图的简洁通用方案,基于Glide与ShapeableImageView,Kotlin 利用ShapeableImageView专门处理圆形和外框边线的特性,通过Glide加载图片装载到ShapeableImageView。注意,因为要描边…