pytorch的梯度图与autograd.grad和二阶求导

前向与反向


  • 这里我们从 一次计算 开始比如 z=f(x,y) 讨论
  • 若我们把任意对于tensor的计算都看为函数(如将 a*b(数值) 看为 mul(a,b)),那么都可以将其看为2个过程:forward-前向,backward-反向
  • 在pytorch中我们通过继承torch.autograd.Function来实现这2个过程,详细的用法和扩展参考:https://pytorch.org/docs/stable/notes/extending.html
  • 【例子】比如我们要实现一个数值乘法z= ( x ∗ y ) 2 (x*y)^2 (xy)2

前向

  • 在前向过程中,我们主要干的事情为:1. 通过输入计算得到输出。2.保存反向传播求导数所需要的tensor到ctx(在反向传播的时候会对应的传入)
  • 【例子】1.为了反向传播求梯度保存x,y(因为我们知道 d z / d x = 2 ∗ x ∗ y 2 , d z / d y = 2 ∗ y ∗ x 2 dz/dx=2*x*y^2, dz/dy=2*y*x^2 dz/dx=2xy2,dz/dy=2yx2), 2.return ( x ∗ y ) 2 (x*y)^2 (xy)2

反向

  • 在反向传播的时候我们干的事情就是,将传入的导数(梯度)和我们在前向过程中保存的tensor进行加工,最终返回每个输入变量的梯度
  • 【例子】此时反向时应该返回 2 ∗ x ∗ x ∗ y 与 2 ∗ x ∗ y ∗ y 2*x*x*y与2*x*y*y 2xxy2xyy,分别对应 d z / d x , d z / d y dz/dx, dz/dy dz/dx,dz/dy

计算图


  • 那么对于一次计算的讨论完了,现在我们来讨论多次计算,即梯度(导数)是如何一步步的从最终的z=f(a), a=g(b), b=w©… 一层层的 传回x的。
  • 那么在pytorch中,其使用了图的数据结构,在一开始 z = x ∗ x ∗ y ∗ y z=x*x*y*y z=xxyy的例子中,z会指向x与y,方便反向传播求梯度(导数),现在若 w = z ∗ z w=z*z w=zz(关于z的函数),那么w会指向z
  • 那么 d w / d x = d w / d z ∗ d z / d x dw/dx = dw/dz * dz/dx dw/dx=dw/dzdz/dx
  • 在这里插入图片描述

需要梯度?,require_grad

  • 但是很多tensor在计算时是不需要梯度的,而保存上面那种梯度图又很费空间,pytroch默认你创建的tensor是不需要梯度的
  • 如当你使用线性层时,实际上是 w ∗ x + b w*x+b wx+b,但是其实这里传入x是需要梯度的(它又不需要学习),而w与b是模型的参数是nn.Parameters,所以他需要学习,自然需要梯度,
  • 而这里pytorch就使用一个bool标记来说明这个tensor需要梯度嘛,若他需要梯度,那么基于他的计算才会有指针指向它
  • 比如, z = x ∗ y z=x*y z=xy,若x,y都require_grad=False,则根本不会建立计算图,若x的requires_grad=True,则该计算会建立z->x的计算图

梯度函数,grad_fn

  • 当你进行了建立了计算图的计算,比如x.requires_grad=True, z = x ∗ y z=x*y z=xy, 那么z.grad_fn就会有函数指针指向反向传播的计算,这里就是这个 x ∗ y x*y xy
  • 在上图中一个节点虽然向回指向多个变量,但其实对应函数指针,其实是指向一个函数 x ∗ x ∗ y ∗ y x*x*y*y xxyy,2个箭头对应的是2个返回值 ( d z / d x , d z / d y ) (dz/dx,dz/dy) (dz/dx,dz/dy),函数指针可以在运算完了后在tensor.grad_fn看到

梯度,grad

  • 当你在正向计算时,构建完了上述的这样一个计算图,你就可以对最终得到的tensor调用backward函数,那么整个计算图就会从最后一个变量还是反向一步一步的将梯度传给每个需要保存梯度的tensor,这时可以在tensor.grad中看到,此时默认情况tensor.grad_fn会被清空。

torch.autograd.grad


  • 大部分情况下,我们都是得到loss,然后loss.backward(),模型的参数对应的每个tensor就会的到梯度,这个时候opt.step()就会根据学习率优化参数
  • 但有时候我们需要手动求导,可以使用 torch.autograd.grad函数

自变量,input

  • x,即对什么求导,当然该tensor必须requires_grad=True,在因变量的同一梯度图的后继

因变量,output

  • y,即被求导的变量,这里结合x相当于求dy/dx

加权,grad_outputs

  • pytorch中求导的因变量必须是一个shape为[1]的tensor,所以比如当backward时,我们往往取loss.sum() 或者mean(), 那么这里y是个大小不定的tensor,那么这个参数就是和y的shape一样,先令(g代表grad_outputs) L = ∑ g i j ∗ y i j L=∑g_{ij}*y_{ij} L=gijyij, 然后L在对x求导,这里求和往往我们取g=torch.ones_like(y), 相当于y.sum()
  • 一般情况下,不同batch之间的计算是独立的,所以得到的y就算sum后,每个x的得到的梯度其实是batch独立的,但是batch_norm除外,因为batch_norm,不同batch的x会与整个batch的均值做运算, 除非你手写batch_norm,并将数据均值对应的tensor mu detach掉,此时mu对于整个梯度图就是一个常数,否则mu会指向不同batch的x,导致每个x的得到的梯度其实不是batch独立的

输出

  • 输出的shape和x一样,即最终的L对于x每个位置的梯度,这也同样解释了为什么必须要对y求和得L,否则每个x中的每个位置其实对应整个tensor,y

二阶求导,create_graph

  • 那么若想二阶求导,我们举个例子z = f(x,y) , z1 = ∂f/∂x, z2 = ∂z1 / ∂y,在代码里,其实x对应X[:,0], y对应X[:,1](也可以其他对应),其实还是一个tensor
  • 那么首先把torch.autograd.grad整个过程又看为一次新的计算,在反向求导求z1时,程序会按照上述过程一步一步反向传播,而反向传播得到梯度时,其实这里又可以形成新的梯度图,z1处于整个新的梯度图顶端
  • 那么对应代码里,使用函数参数create_graph=True来告诉autograd,我这次得到的tensor是需要产生梯度图的(因为可能进一步求导)
  • 那么在代码里,我们相当于
 z1 = torch.autograd.grad(output=y,input=X,...)[0][:, 0]z2 = torch.autograd.grad(output=z1,input=X,...)[0][:, 1]若你的输入X是将x与y合成一体的,如x=X[:,0], y=X[:,1],那么你求导也只能将整体X作为输入,再从答案中获得对应的列
若你直接在求导时input=X[:,i] 这里实际上你创建了一个新tensor X[:,i] -> X, 而Z->X, 并未指向X[:,i], 故不行

在这里插入图片描述

  • 补充: 画2阶导数的梯度图

将一阶导数的梯度图的每一条边作为新的一个节点,例如s-[q]-> e, q作为一个新节点指向e,然后与其他同样从边引申出来的节点进行相乘链接(导数链式法则)一下是 Q = ( x ∗ x ∗ y ) 2 Q=(x*x*y)^2 Q=(xxy)2的例子
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/274580.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.7号freeRtoS

1. 串口通信 配置串口为异步通信 设置波特率,数据位,校验位,停止位,数据的方向 同步通信 在同步通信中,数据的传输是在发送端和接收端之间通过一个共享的时钟信号进行同步的。这意味着发送端和接收端的时钟需要保持…

进电子厂了,感触颇多...

作者:三哥 个人网站:https://j3code.cn 本文已收录到语雀:https://www.yuque.com/j3code/me-public-note/lpgzm6y2nv9iw8ec 是的,真进电子厂了,但主人公不是我。 虽然我不是主人公,但是我经历的过程是和主…

Qt 实现诈金花的牌面值分析工具

诈金花是很多男人最爱的卡牌游戏 , 每当你拿到三张牌的时候, 生活重新充满了期待和鸟语花香. 那么我们如果判断手中的牌在所有可能出现的牌中占据的百分比位置呢. 这是最终效果: 这是更多的结果: 在此做些简单的说明: 炸弹(有些地方叫豹子) > 同花顺 > 同花 > 顺…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的常见车型识别系统(Python+PySide6界面+训练代码)

摘要:本文深入探讨了如何应用深度学习技术开发一个先进的常见车型识别系统。该系统核心采用最新的YOLOv8算法,并与早期的YOLOv7、YOLOv6、YOLOv5等版本进行性能比较,主要评估指标包括mAP和F1 Score等。详细解析了YOLOv8的工作机制&#xff0c…

Qt/QML编程之路:openglwidget和倒车影像的切换(43)

关于如何实现一个基于OpenGL的3d 图形,这个有很多专门的介绍,我在开发中遇到了这么一个问题: 如何实现一个倒车影像的video显示与一个3D物体显示的切换,因为开窗在同样的一个位置,如果车子倒车启动,则需要将原本显示3D的地方切换为视频图像的显示。 class testOpenGl : …

SpringMVC04、Controller 及 RestFul

4、Controller 及 RestFul 4.1、控制器Controller 控制器复杂提供访问应用程序的行为,通常通过接口定义或注解定义两种方法实现。控制器负责解析用户的请求并将其转换为一个模型。在Spring MVC中一个控制器类可以包含多个方法在Spring MVC中,对于Contr…

【嵌入式】嵌入式系统稳定性建设:最后的防线

🧑 作者简介:阿里巴巴嵌入式技术专家,深耕嵌入式人工智能领域,具备多年的嵌入式硬件产品研发管理经验。 📒 博客介绍:分享嵌入式开发领域的相关知识、经验、思考和感悟。提供嵌入式方向的学习指导、简历面…

ChatGPT 结合实际地图实现问答式地图检索功能基于Function calling

ChatGPT 结合实际地图实现问答式地图检索功能基于Function calling ChatGPT结合实际业务,主要是研发多函数调用(Function Calling)功能模块,将自定义函数通过ChatGPT 问答结果,实现对应函数执行,再次将结果…

k8s-生产级的k8s高可用(2) 25

部署containerd k8s2、k8s3、k8s4在配置前需要重置节点(reset)在上一章已完成 禁用所有节点docker和cri-docker服务 所有节点清除iptables规则 重置后全部节点重启 由于之前部署过docker,因此containerd默认已安装 修改配置 启动containe…

vue 总结

1.vue 的生命周期 1. es6 2. vue 基本属性指令 <template><div><!--<h1>vue基本指令的使用方式</h1><a :href"url">v-bind使用链接</a><img :src"srcUrl" /><div>解决闪烁问题<p v-cloak>{{…

安装zabbix

部署Zabbix监控平台 部署一台Zabbix监控服务器&#xff0c;一台被监控主机&#xff0c;为进一步执行具体的监控任务做准备&#xff1a; 安装LNMP环境源码安装Zabbix安装监控端主机&#xff0c;修改基本配置初始化Zabbix监控Web页面修改PHP配置文件&#xff0c;满足Zabbix需求…

Vue3全家桶 - Pinia - 【1】(安装与使用 + Store + State + Getters + Actions)

Pinia pinia 是 Vue 的专属状态管理库&#xff0c;它允许你跨组件或跨页面共享状态&#xff1b; 一、 安装与使用 pinia 安装语法&#xff1a;yarn add pinia npm install pinia创建一个 pinia &#xff08;根存储&#xff09;并将其传递给应用程序&#xff1a; 目标文件&am…

PaddlePaddle----基于paddlehub的OCR识别

Paddlehub介绍 PaddleHub是一个基于PaddlePaddle深度学习框架开发的预训练模型库和工具集&#xff0c;提供了丰富的功能和模型&#xff0c;包括但不限于以下几种&#xff1a; 1.文本相关功能&#xff1a;包括文本分类、情感分析、文本生成、文本相似度计算等预训练模型和工具。…

【力扣hot100】刷题笔记Day25

前言 这几天搞工作处理数据真是类似我也&#xff0c;还被老板打电话push压力有点大的&#xff0c;还好搞的差不多了&#xff0c;明天再汇报&#xff0c;赶紧偷闲再刷几道题&#xff08;可恶&#xff0c;被打破连更记录了&#xff09;这几天刷的是动态规划&#xff0c;由于很成…

共基法律考点大默写

法是由国家制定或认可的&#xff0c;&#xff0c;能够反应统治阶级意志&#xff0c;反映着被一定物质生活条件决定的统治阶级&#xff08;在社会主义社会是工人阶级为首的广大人民&#xff09;的意志。 指引作用。法律为人们提供既定的行为模式&#xff0c;指引人们在法律范围内…

Qt插件之输入法插件的构建和使用(一)

文章目录 输入法概述输入法插件实现及调用输入键盘搭建定义样式自定义按钮实现自定义可拖动标签数字符号键盘候选显示控件滑动控件手绘输入控件输入法概述 常见的输入法有三种形式: 1.系统级输入法 2.普通程序输入法 3.程序自带的输入法 系统级输入法就是咱们通常意义上的输入…

Vue3全家桶 - Vue3 - 【8】模板引用【ref】(访问模板引用 + v-for中的模板引用 + 组件上的ref)

模板引用【ref】 Vue3官网-模板引用&#xff1b;如果我们需要直接访问组件中的底层DOM元素&#xff0c;可使用vue提供特殊的ref属性来访问&#xff1b; 一、 访问模板引用 在视图元素上采用ref属性来设置需要访问的DOM元素&#xff1a; 该 ref 属性可采用 字符串 值的执行设…

蝙蝠避障:我生活中的一道光

盲人的世界&#xff0c;是无尽的黑暗。看不见光&#xff0c;看不见色彩&#xff0c;甚至看不见自己的手。但在这个黑暗的世界里&#xff0c;我找到了一个光明的出口&#xff1a;一款可以障碍物实时检测的名为蝙蝠避障的盲人软件。 这款软件就像是我的一双眼睛。它通过先进的激光…

第五十六回 徐宁教使钩镰枪 宋江大破连环马-飞桨图像分类套件PaddleClas初探

宋江等人学会了钩镰枪&#xff0c;大胜呼延灼。呼延灼损失了很多人马&#xff0c;不敢回京&#xff0c;一个人去青州找慕容知府。一天在路上住店&#xff0c;马被桃花山的人偷走了&#xff0c;于是到了青州&#xff0c;带领官兵去打莲花山。 莲花山的周通打不过呼延灼&#xf…

【日常记录】【工具】随机生成图片的网站 Lorem Picsum

文章目录 1、介绍2、获取固定宽高的图片3、处理图片缓存4、 Emmet 缩写语法 1、介绍 Lorem Picsum 是一个免费的图片占位符服务&#xff0c;可以用于网站、应用程序或任何需要占位符图片的地方。它提供了一个简单的 API&#xff0c;可以通过 HTTP 请求获取随机图片&#xff0c;…