神经网络:深度学习基础

1.反向传播算法(BP)的概念及简单推导

反向传播(Backpropagation,BP)算法是一种与最优化方法(如梯度下降法)结合使用的,用来训练人工神经网络的常见算法。BP算法对网络中所有权重计算损失函数的梯度,并将梯度反馈给最优化方法,用来更新权值以最小化损失函数。该算法会先按前向传播方式计算(并缓存)每个节点的输出值,然后再按反向传播遍历图的方式计算损失函数值相对于每个参数的偏导数

接下来我们以全连接层,使用sigmoid激活函数,Softmax+MSE作为损失函数的神经网络为例,推导BP算法逻辑。由于篇幅限制,这里只进行简单推导,后续Rocky将专门写一篇PB算法完整推导流程,大家敬请期待。

首先,我们看看sigmoid激活函数的表达式及其导数:

s i g m o i d 表达式: σ ( x ) = 1 1 + e − x sigmoid表达式:\sigma(x) = \frac{1}{1+e^{-x}} sigmoid表达式:σ(x)=1+ex1
s i g m o i d 导数: d d x σ ( x ) = σ ( x ) − σ ( x ) 2 = σ ( 1 − σ ) sigmoid导数:\frac{d}{dx}\sigma(x) = \sigma(x) - \sigma(x)^2 = \sigma(1- \sigma) sigmoid导数:dxdσ(x)=σ(x)σ(x)2=σ(1σ)

可以看到sigmoid激活函数的导数最终可以表达为输出值的简单运算。

我们再看MSE损失函数的表达式及其导数:

M S E 损失函数的表达式: L = 1 2 ∑ k = 1 K ( y k − o k ) 2 MSE损失函数的表达式:L = \frac{1}{2}\sum^{K}_{k=1}(y_k - o_k)^2 MSE损失函数的表达式:L=21k=1K(ykok)2

其中 y k y_k yk 代表ground truth(gt)值, o k o_k ok 代表网络输出值。

M S E 损失函数的偏导: ∂ L ∂ o i = ( o i − y i ) MSE损失函数的偏导:\frac{\partial L}{\partial o_i} = (o_i - y_i) MSE损失函数的偏导:oiL=(oiyi)

由于偏导数中单且仅当 k = i k = i k=i 时才会起作用,故进行了简化。

接下来我们看看全连接层输出的梯度:

M S E 损失函数的表达式: L = 1 2 ∑ i = 1 K ( o i 1 − t i ) 2 MSE损失函数的表达式:L = \frac{1}{2}\sum^{K}_{i=1}(o_i^1 - t_i)^2 MSE损失函数的表达式:L=21i=1K(oi1ti)2

M S E 损失函数的偏导: ∂ L ∂ w j k = ( o k − t k ) o k ( 1 − o k ) x j MSE损失函数的偏导:\frac{\partial L}{\partial w_{jk}} = (o_k - t_k)o_k(1-o_k)x_j MSE损失函数的偏导:wjkL=(oktk)ok(1ok)xj

我们用 δ k = ( o k − t k ) o k ( 1 − o k ) \delta_k = (o_k - t_k)o_k(1-o_k) δk=(oktk)ok(1ok) ,则能再次简化:

M S E 损失函数的偏导: d L d w j k = δ k x j MSE损失函数的偏导:\frac{dL}{dw_{jk}} = \delta_kx_j MSE损失函数的偏导:dwjkdL=δkxj

最后,我们看看那PB算法中每一层的偏导数:

输出层:
∂ L ∂ w j k = δ k K o j \frac{\partial L}{\partial w_{jk}} = \delta_k^K o_j wjkL=δkKoj
δ k K = ( o k − t k ) o k ( 1 − o k ) \delta_k^K = (o_k - t_k)o_k(1-o_k) δkK=(oktk)ok(1ok)

倒数第二层:
∂ L ∂ w i j = δ j J o i \frac{\partial L}{\partial w_{ij}} = \delta_j^J o_i wijL=δjJoi
δ j J = o j ( 1 − o j ) ∑ k δ k K w j k \delta_j^J = o_j(1 - o_j) \sum_{k}\delta_k^Kw_{jk} δjJ=oj(1oj)kδkKwjk

倒数第三层:
∂ L ∂ w n i = δ i I o n \frac{\partial L}{\partial w_{ni}} = \delta_i^I o_n wniL=δiIon
δ i I = o i ( 1 − o i ) ∑ j δ j J w i j \delta_i^I = o_i(1 - o_i) \sum_{j}\delta_j^Jw_{ij} δiI=oi(1oi)jδjJwij

像这样依次往回推导,再通过梯度下降算法迭代优化网络参数,即可走完PB算法逻辑。

2.滑动平均的相关概念

滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving avergae),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关

变量 v v v t t t 时刻记为 v t v_{t} vt θ t \theta_{t} θt 为变量 v v v t t t 时刻训练后的取值,当不使用滑动平均模型时 v t = θ t v_{t} = \theta_{t} vt=θt ,在使用滑动平均模型后, v t v_{t} vt 的更新公式如下:

上式中, β ϵ [ 0 , 1 ) \beta\epsilon[0,1) βϵ[0,1) β = 0 \beta = 0 β=0 相当于没有使用滑动平均。

t t t 时刻变量 v v v 的滑动平均值大致等于过去 1 / ( 1 − β ) 1/(1-\beta) 1/(1β) 个时刻 θ \theta θ 值的平均。并使用bias correction将 v t v_{t} vt 除以 ( 1 − β t ) (1 - \beta^{t}) (1βt) 修正对均值的估计。

加入Bias correction后, v t v_{t} vt v b i a s e d t v_{biased_{t}} vbiasedt 的更新公式如下:

t t t 越大, 1 − β t 1 - \beta^{t} 1βt 越接近1,则公式(1)和(2)得到的结果( v t v_{t} vt v b i a s e d 1 v_{biased_{1}} vbiased1 )将越来越接近。

β \beta β 越大时,滑动平均得到的值越和 θ \theta θ 的历史值相关。如果 β = 0.9 \beta = 0.9 β=0.9 ,则大致等于过去10个 θ \theta θ 值的平均;如果 β = 0.99 \beta = 0.99 β=0.99 ,则大致等于过去100个 θ \theta θ 值的平均。

下图代表不同方式计算权重的结果:

如上图所示,滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某种次的异常取值而使得滑动平均值波动很大

滑动平均的优势: 占用内存少,不需要保存过去10个或者100个历史 θ \theta θ 值,就能够估计其均值。滑动平均虽然不如将历史值全保存下来计算均值准确,但后者占用更多内存,并且计算成本更高。

为什么滑动平均在测试过程中被使用?

滑动平均可以使模型在测试数据上更鲁棒(robust)

采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。

训练中对神经网络的权重 w e i g h t s weights weights 使用滑动平均,之后在测试过程中使用滑动平均后的 w e i g h t s weights weights 作为测试时的权重,这样在测试数据上效果更好。因为滑动平均后的 w e i g h t s weights weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远。比如假设decay=0.999,一个更直观的理解,在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加鲁棒。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/223103.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决IDEA编译/启动报错:Abnormal build process termination

报错信息 报错信息如下: Abnormal build process termination: "D:\Software\Java\jdk\bin\java" -Xmx3048m -Djava.awt.headlesstrue -Djava.endorsed.dirs\"\" -Djdt.compiler.useSingleThreadtrue -Dpreload.project.path………………很纳…

Python并行计算和分布式任务全面指南

更多Python学习内容:ipengtao.com 大家好,我是彭涛,今天为大家分享 Python并行计算和分布式任务全面指南。全文2900字,阅读大约8分钟 并发编程是现代软件开发中不可或缺的一部分,它允许程序同时执行多个任务&#xff0…

CloudPulse:一款针对AWS云环境的SSL证书搜索与分析引擎

关于CloudPulse CloudPulse是一款针对AWS云环境的SSL证书搜索与分析引擎,广大研究人员可以使用该工具简化并增强针对SSL证书数据的检索和分析过程。 在网络侦查阶段,我们往往需要收集与目标相关的信息,并为目标创建一个专用文档&#xff0c…

【Linux】进程周边007之进程控制

👀樊梓慕:个人主页 🎥个人专栏:《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》 🌝每一个不曾起舞的日子,都是对生命的辜负 目录 前言 1.进程创建 2.进程终止 2.…

Java设计模式-原型模式

目录 一、克隆羊问题 二、传统方式解决 三、基本介绍 四、浅拷贝和深拷贝 (一)浅拷贝介绍 (二)深拷贝 五、原型模式深拷贝 (一)重写clone方法 (二)对象序列化 六、注意事项…

CV算法面试题学习

本文记录了CV算法题的学习。 CV算法面试题学习 点在多边形内(point in polygon)高斯滤波器 点在多边形内(point in polygon) 参考自文章1,其提供的代码没有考虑一些特殊情况,所以做了改进。 做法&#xff…

k8s 中部署Jenkins

创建namespace apiVersion: v1 kind: Namespace metadata:name: jenkins创建pv以及pvc kind: PersistentVolume apiVersion: v1 metadata:name: jenkins-pv-volumenamespace: jenkinslabels:type: localapp: jenkins spec:#storageClassName: manualcapacity:storage: 5Giacc…

鸿蒙原生应用再添新丁!喜马拉雅入局鸿蒙

鸿蒙原生应用再添新丁!喜马拉雅入局鸿蒙 来自 HarmonyOS 微博12月20日消息, #喜马拉雅正式完成鸿蒙原生应用版本适配#,作为音频业巨头的喜马拉雅 ,将基于#HarmonyOS NEXT#创造更丰富、更智慧的全场景“声音宇宙”!#鸿…

Spring security之授权

前言 本篇为大家带来Spring security的授权,首先要理解一些概念,有关于:权限、角色、安全上下文、访问控制表达式、方法级安全性、访问决策管理器 一.授权的基本介绍 Spring Security 中的授权分为两种类型: 基于角色的授权&…

mathtype公式章节编号

1. word每章标题后插入章节符 如果插入后显示章节符,需要进行隐藏 开始->样式->MTEquationSection->修改样式->字体,勾选隐藏 2. 设置mathtype公式编号格式 插入编号->格式化->设置格式

Ansible自动化工具之Playbook剧本编写

目录 Playbook的组成部分 实例模版 切换用户 指定声明用户 声明和引用变量,以及外部传参变量 playbook的条件判断 ​编辑 习题 ​编辑 ansible-playbook的循环 item的循环 ​编辑 list循环 ​编辑 together的循环(列表对应的列&#xff0…

【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet 1. ResNet (残差网络)基础知识2. 感受野3. 手写体数字识别3. 0 数据集(训练与测试集)3. 1 数据加载3. 2 函数实现:3. 3 训练及其测试: 1. ResNet &…

数据库编程大赛:一条SQL计算扑克牌24点

你是否在寻找一个平台,能让你展示你的SQL技能,与同行们一较高下?你是否渴望在实战中提升你的SQL水平,开阔你的技术视野?如果你对这些都感兴趣,那么本次由NineData主办的《数据库编程大赛》,将是…

IAR安装注册

IAR安装注册 一、 概述 此文记录IAR 开发工具 EWARM-CD-8321-18631 安装注册过程,以及工程编译、调试、烧录下载方法。 二、安装 安装文件: 选择安装 IAR: 一路 NEXT 遇到对话框点击 ”是” 即可安装完成, 图标如下: 三、 注册 先断开网络链接, 不能链接…

Shell编程从入门到实战

Shell 概述 (1)Linux 提供的 Shell 解析器有 [rootflinkTenxun ~]# cat /etc/shells(2)bash 和 sh 的关系 [rootflinkTenxun bin]# ll | grep bash(3)Centos 默认的解析器是 bash [rootflinkTenxun bin]…

【JavaWeb学习笔记】14 - 三大组件其二 Listener Filter

项目代码 https://github.com/yinhai1114/JavaWeb_LearningCode/tree/main/listener https://github.com/yinhai1114/JavaWeb_LearningCode/tree/main/filter API文档JAVA_EE_api_中英文对照版 Listener 一、监听器Listener 1. Listener监听器它是JavaWeb的三大组件之一。 J…

【数据结构】栈的使用|模拟实现|应用|栈与虚拟机栈和栈帧的区别

目录 一、栈(Stack) 1.1 概念 1.2 栈的使用 1.3 栈的模拟实现 1.4 栈的应用场景 1. 改变元素的序列 2. 将递归转化为循环 3. 括号匹配 4. 逆波兰表达式求值 5. 出栈入栈次序匹配 6. 最小栈 1.5 概念区分 一、栈(Stack) 1.1 概念 栈:一种特殊的线性表&…

【【迭代16次的CORDIC算法-verilog实现】】

迭代16次的CORDIC算法-verilog实现 -32位迭代16次verilog代码实现 CORDIC.v module cordic32#(parameter DATA_WIDTH 8d32 , // we set data widthparameter PIPELINE 5d16 // Optimize waveform)(input …

十大经典排序算法(个人总结C语言版)

文章目录 一、前言二、对比1.排序算法相关概念1.1 时间复杂度1.2 空间复杂度1.3 排序方式1.4 稳定度 2.表格比较3.算法推荐3.1 小规模数据3.2 中等规模数据3.3 大规模数据3.4 特殊需求 三、排序算法1.冒泡排序(Bubble Sort)1.1 简介1.2 示例代码&#xf…

【模式识别】探秘判别奥秘:Fisher线性判别算法的解密与实战

​🌈个人主页:Sarapines Programmer🔥 系列专栏:《模式之谜 | 数据奇迹解码》⏰诗赋清音:云生高巅梦远游, 星光点缀碧海愁。 山川深邃情难晤, 剑气凌云志自修。 目录 🌌1 初识模式识…