《动手学深度学习(PyTorch版)》笔记4.7

Chapter4 Multilayer Perceptron

4.7 Forward/Backward Propagation and Computational Graphs

本节将通过一些基本的数学和计算图,深入探讨反向传播的细节。首先,我们将重点放在带权重衰减( L 2 L_2 L2正则化)的单隐藏层多层感知机上。

4.7.1 Forward Propagation

前向传播(forward propagation或forward pass)指的是按顺序(从输入层到输出层)计算和存储神经网络中每层的结果。

我们将一步步研究单隐藏层神经网络的机制,为了简单起见,我们假设输入样本是 x ∈ R d \mathbf{x}\in \mathbb{R}^d xRd,并且我们的隐藏层不包括偏置项。这里的中间变量是:

z = W ( 1 ) x , \mathbf{z}= \mathbf{W}^{(1)} \mathbf{x}, z=W(1)x,

其中 W ( 1 ) ∈ R h × d \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d} W(1)Rh×d是隐藏层的权重参数。将中间变量 z ∈ R h \mathbf{z}\in \mathbb{R}^h zRh通过激活函数 ϕ \phi ϕ后,我们得到长度为 h h h的隐藏激活向量:

h = ϕ ( z ) . \mathbf{h}= \phi (\mathbf{z}). h=ϕ(z).

隐藏变量 h \mathbf{h} h也是一个中间变量。假设输出层的参数只有权重 W ( 2 ) ∈ R q × h \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h} W(2)Rq×h,我们可以得到输出层变量,它是一个长度为 q q q的向量:

o = W ( 2 ) h . \mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}. o=W(2)h.

假设损失函数为 l l l,样本标签为 y y y,我们可以计算单个数据样本的损失项,

L = l ( o , y ) . L = l(\mathbf{o}, y). L=l(o,y).

根据 L 2 L_2 L2正则化的定义,给定超参数 λ \lambda λ,正则化项为

s = λ 2 ( ∥ W ( 1 ) ∥ F 2 + ∥ W ( 2 ) ∥ F 2 ) , s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right), s=2λ(W(1)F2+W(2)F2),

∥ X ∥ F \|\mathbf{X}\|_F XF表示矩阵的Frobenius范数:
∥ X ∥ F = ∑ i = 1 m ∑ j = 1 n x i j 2 . \|\mathbf{X}\|_F = \sqrt{\sum_{i=1}^m \sum_{j=1}^n x_{ij}^2}. XF=i=1mj=1nxij2 .
最后,模型在给定数据样本上的正则化损失为:

J = L + s . J = L + s. J=L+s.

在下面的讨论中,我们将 J J J称为目标函数(objective function)。

下图是与上述简单网络相对应的计算图,其中正方形表示变量,圆圈表示操作符。

在这里插入图片描述

4.7.2 Backward Propagation

反向传播(backward propagation或backpropagation)指的是计算神经网络参数梯度的方法,该方法根据链式规则,按相反的顺序从输出层到输入层遍历网络。该算法存储了计算某些参数梯度时所需的任何中间变量(偏导数)。
假设我们有函数 Y = f ( X ) \mathsf{Y}=f(\mathsf{X}) Y=f(X) Z = g ( Y ) \mathsf{Z}=g(\mathsf{Y}) Z=g(Y),其中输入和输出 X , Y , Z \mathsf{X}, \mathsf{Y}, \mathsf{Z} X,Y,Z是任意形状的张量。利用链式法则,我们可以计算 Z \mathsf{Z} Z关于 X \mathsf{X} X的导数:

∂ Z ∂ X = prod ( ∂ Z ∂ Y , ∂ Y ∂ X ) . \frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right). XZ=prod(YZ,XY).

在这里,我们使用 prod \text{prod} prod运算符在执行必要的操(如换位和交换输入位置)后将其参数相乘。对于高维张量,我们使用适当的对应项。

在上面的计算图中单隐藏层简单网络的参数是 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2),反向传播的目的是计算梯度 ∂ J / ∂ W ( 1 ) \partial J/\partial \mathbf{W}^{(1)} J/W(1) ∂ J / ∂ W ( 2 ) \partial J/\partial \mathbf{W}^{(2)} J/W(2),计算的顺序与前向传播中执行的顺序相反,具体如下:

∂ J ∂ L = 1 and ∂ J ∂ s = 1. \frac{\partial J}{\partial L} = 1 \; \text{and} \; \frac{\partial J}{\partial s} = 1. LJ=1andsJ=1.

∂ J ∂ o = prod ( ∂ J ∂ L , ∂ L ∂ o ) = ∂ L ∂ o ∈ R q . \frac{\partial J}{\partial \mathbf{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \mathbf{o}}\right) = \frac{\partial L}{\partial \mathbf{o}} \in \mathbb{R}^q. oJ=prod(LJ,oL)=oLRq.

∂ s ∂ W ( 1 ) = λ W ( 1 ) , ∂ s ∂ W ( 2 ) = λ W ( 2 ) . \frac{\partial s}{\partial \mathbf{W}^{(1)}} = \lambda \mathbf{W}^{(1)} \; \text{,} \; \frac{\partial s}{\partial \mathbf{W}^{(2)}} = \lambda \mathbf{W}^{(2)}. W(1)s=λW(1),W(2)s=λW(2).

∂ J ∂ W ( 2 ) = prod ( ∂ J ∂ o , ∂ o ∂ W ( 2 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 2 ) ) = ∂ J ∂ o h ⊤ + λ W ( 2 ) ∈ R q × h . \frac{\partial J}{\partial \mathbf{W}^{(2)}}= \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(2)}}\right)= \frac{\partial J}{\partial \mathbf{o}} \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}\in \mathbb{R}^{q \times h}. W(2)J=prod(oJ,W(2)o)+prod(sJ,W(2)s)=oJh+λW(2)Rq×h.

∂ J ∂ h = prod ( ∂ J ∂ o , ∂ o ∂ h ) = W ( 2 ) ⊤ ∂ J ∂ o ∈ R h . \frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}\in \mathbb{R}^h. hJ=prod(oJ,ho)=W(2)oJRh.

由于激活函数 ϕ \phi ϕ是按元素计算的,计算中间变量 z \mathbf{z} z的梯度需要使用按元素乘法运算符,我们用 ⊙ \odot 表示:

∂ J ∂ z = prod ( ∂ J ∂ h , ∂ h ∂ z ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) ∈ R h . \frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\in \mathbb{R}^h. zJ=prod(hJ,zh)=hJϕ(z)Rh.

∂ J ∂ W ( 1 ) = prod ( ∂ J ∂ z , ∂ z ∂ W ( 1 ) ) + prod ( ∂ J ∂ s , ∂ s ∂ W ( 1 ) ) = ∂ J ∂ z x ⊤ + λ W ( 1 ) = ∂ J ∂ h ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) = ( W ( 2 ) ⊤ ∂ J ∂ o ) ⊙ ϕ ′ ( z ) x ⊤ + λ W ( 1 ) . \begin{align*} \frac{\partial J}{\partial \mathbf{W}^{(1)}} &= \text{prod}\left(\frac{\partial J}{\partial \mathbf{z}}, \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(1)}}\right) \\ &= \frac{\partial J}{\partial \mathbf{z}} \mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)} \\ &= ({\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}})\odot \phi'\left(\mathbf{z}\right)\mathbf{x}^\top + \lambda \mathbf{W}^{(1)}. \end{align*} W(1)J=prod(zJ,W(1)z)+prod(sJ,W(1)s)=zJx+λW(1)=hJϕ(z)x+λW(1)=(W(2)oJ)ϕ(z)x+λW(1).

4.7.3 Training Neural Networks

在训练神经网络时,前向传播和反向传播相互依赖。以上述简单网络为例:一方面,在前向传播期间计算正则项取决于模型参数 W ( 1 ) \mathbf{W}^{(1)} W(1) W ( 2 ) \mathbf{W}^{(2)} W(2)的当前值。它们是由优化算法根据最近迭代的反向传播给出的。另一方面,反向传播期间参数的梯度计算,取决于由前向传播给出的隐藏变量 h \mathbf{h} h的当前值。

因此,在训练神经网络时,我们交替使用前向传播和反向传播,利用反向传播给出的梯度来更新模型参数。注意,反向传播重复利用前向传播中存储的中间值,以避免重复计算。这带来的影响之一是我们需要保留中间值,直到反向传播完成,这也是训练比单纯的预测需要更多的内存(显存)的原因之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/247095.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

opencv#33 边缘检测

边缘检测原理 图像的每一行每一列都可以看成是一个连续的信号经过离散后得到的数值,例如上图左侧给出的图像由黑色到白色的一个信号,也就是图像中某一行像素变化是由黑色逐渐到白色,我们将其对应在一个坐标轴中,将像素值的大小对应…

【Java基础】聊聊你不知道反射的那些事

在编程语言中,反射是一个绕不过的一个话题,反射、注解、动态代理是支撑框架运行的核心技术。 在Spring中,IOC利用反射实现,创建对象。AOP利用动态代理实现,实现切面编程,配置利用注解实现。所以继上一篇&am…

代码随想录算法训练营第32天 | 122.买卖股票的最佳时机II 55.跳跃游戏 45.跳跃游戏II

买卖股票的最佳时机II 贪心思路 要想使用贪心算法解决此问题,意识到利润是可分解的很关键。比如[1,2,3,4,5]这个输入,最大利润为第一天买入,第五天卖出。这等效于第一天买入,第二天卖出,第二天再买入。。。 prices[4]…

HCS-华为云Stack-FusionSphere

HCS-华为云Stack-FusionSphere FusionSphere是华为面向多行业客户推出的云操作系统解决方案。 FusionSphere基于开放的OpenStack架构,并针对企业云计算数据中心场景进行设计和优化,提供了强大的虚拟化功能和资源池管理能力、丰富的云基础服务组件和工具…

MYSQL基本查询(CURD:创建、读取、更新、删除)

文章目录 前言一、Create1.全列插入2.指定列插入3.插入否则更新4.替换 二、Retrieve1.SELECT列2.WHERE条件3.结果排序4.筛选分页结果 三、Update四、Delete1.删除数据2.截断表 五、插入查询结果六、聚合函数 前言 操作关系型数据库的编程语言,定义了一套操作关系型…

kali系统入侵电脑windows(win11系统)渗透测试,骇入电脑教学

本次渗透测试将使用kali虚拟机(攻击机)对本机(靶机)进行入侵并监控屏幕 声明:本篇仅仅是将本机作为靶机的一次简易渗透测试,实际情况中基本不可能出现如此简单的木马骇入(往往在上传木马时就被防…

Android App开发-简单控件(4)——按钮触控和图像显示

3.4 按钮触控 本节介绍了按钮控件的常见用法,包括:如何设置大小写属性与点击属性,如何响应按钮的点击事件和长按事件,如何禁用按钮又该如何启用按钮,等等。 3.4.1 按钮控件Button 除了文本视图之外,按钮…

clickhouse 安装与入门(单节点安装)

1、简介 Clickhouse 是一个开源的面向联机分析处理(OLAP, On-Line Analytical Processing)的列式存储数据库管理系统。写入快、查询快,支持sql向量化、并行和分布式查询;但是不支持事务,不支持二级索引等。由俄罗斯的Y…

5_机械臂运动学基础_矩阵

上次说的向量空间是为矩阵服务的。 1、学科回顾 从科技实践中来的数学问题无非分为两类:一类是线性问题,一类是非线性问题。线性问题是研究最久、理论最完善的;而非线性问题则可以在一定基础上转化为线性问题求解。 线性变换: 数域…

【jetson笔记】解决vscode远程调试qt.qpa.xcb: could not connect to display报错

配置x11转发 jetson远程安装x11转发 安装Xming Xming下载 安装完成后打开安装目录C:\Program Files (x86)\Xming 用记事本打开X0.hosts文件,添加jetson IP地址 后续IP改变需要重新修改配置文件 localhost 192.168.107.57打开Xlaunch Win菜单搜索Xlaundch打开 一…

openssl3.2 - 测试程序的学习 - test\acvp_test.c

文章目录 openssl3.2 - 测试程序的学习 - test\acvp_test.c概述笔记要单步学习的测试函数备注END openssl3.2 - 测试程序的学习 - test\acvp_test.c 概述 openssl3.2 - 测试程序的学习 将test*.c 收集起来后, 就不准备看makefile和make test的日志参考了. 按照收集的.c, 按照…

【java面试】常见问题(超详细)

目录 一、java常见问题JDK和JRE的区别是什么?Java中的String类是可变的还是不可变的?Java中的equals方法和hashCode方法有什么关系?Java中什么是重载【Overloading】?什么是覆盖【Overriding】?它们有什么区别&#xf…

【计算机网络】概述|分层体系结构|OSI参考模型|TCP/IP参考模型|网络协议、层次、接口

目录 一、思维导图 二、计算机网络概述 1.计算机网络定义、组成、功能 2.计算机网络分类 3.计算机网络发展历史 (1)计算机网络发展历史1:ARPANET->互联网 (2)计算机网络发展历史2:三级结构因特网 …

【JavaWeb】日程管理系统 添加过滤器登录校验 第三期

文章目录 过滤器控制登录校验创建过滤器类修改login原业务方法 总结 过滤器控制登录校验 未添加过滤器 可以直接访问 showShedule.html 需求说明: 未登录状态下不允许访问showShedule.html和SysScheduleController相关增删改处理,重定向到login.html,登录成功后可以自由访问 创…

RabbitMQ进阶篇【理解➕应用】

🥳🥳Welcome 的Huihuis Code World ! !🥳🥳 接下来看看由辉辉所写的关于RabbitMQ的相关操作吧 目录 🥳🥳Welcome 的Huihuis Code World ! !🥳🥳 一.什么是交换机 1.概念释义 2.例…

web前端-------伪类和伪元素

但是,网页中一些特殊的样式,需要用到特殊的CSS选择器来设置。在CSS中,我们把这类选择器称为伪选择器。 伪选择器,分为伪类选择器和伪元素选择器两个大类。 伪类选择器,简称伪类;…

【贪吃蛇:C语言实现】

文章目录 前言1.了解Win32API相关知识1.1什么是Win32API1.2设置控制台的大小、名称1.3控制台上的光标1.4 GetStdHandle(获得控制台信息)1.5 SetConsoleCursorPosition(设置光标位置)1.6 GetConsoleCursorInfo(获得光标…

【DeepLearning-8】MobileViT模块配置

完整代码: import torch import torch.nn as nn from einops import rearrange def conv_1x1_bn(inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, biasFalse),nn.BatchNorm2d(oup),nn.SiLU()) def conv_nxn_bn(inp, oup, kernal_size3, stride1):re…

接口测试入门,如何划分接口文档

1.首先最主要的就是要分析接口测试文档,每一个公司的测试文档都是不一样的。具体的就要根据自己公司的接口而定,里面缺少的内容自己需要与开发进行确认。 我认为一针对于测试而言的主要的接口测试文档应该包含的内容分为以下几个方面。 a.具体的一个业…

一文深度解读多模态大模型视频检索技术的实现与使用

当视频检索叠上大模型Buff。 万乐乐|技术作者 视频检索,俗称“找片儿”,即通过输入一段文本,找出最符合该文本描述的视频。 随着视频社会化趋势以及各类视频平台的快速兴起与发展,「视频检索」越来越成为用户和视频平…