机器学习扫盲系列(2)- 深入浅出“反向传播”-1

系列文章目录

机器学习扫盲系列(1)- 序
机器学习扫盲系列(2)- 深入浅出“反向传播”-1


文章目录

  • 前言
  • 一、神经网络的本质
  • 二、线性问题
    • 解析解的不可行性
    • 梯度下降与随机梯度下降
    • 链式法则
  • 三、非线性问题
    • 激活函数


前言

反向传播(Backpropagation) 是神经网络中重要且难以理解的概念之一,甚至可以说如果理解了反向传播的工作机制,你就基本理解了神经网络的工作原理。所以我们从“反向传播”切入,由此揭开神经网络的神秘面纱。学习完之后你会发现,反向传播只是一个非常简单的过程,它告诉我们在神经网络中需要怎么改变参数。


一、神经网络的本质

神经网络的本质其实就是根据数据集(点)拟合预测/推理函数(曲线),所以简单来说神经网络其实是一个“极为复杂的曲线拟合机器”。 如下图所示,左图的点作为训练数据,神经网络经过训练之后会拟合处右图的曲线,这样对新数据x,可以预测/推理出y的值。
在这里插入图片描述

二、线性问题

为了简单起见,我们使用一个线性回归模型作为示例来演示。显然,我们在对训练数据拟合成函数(曲线)之前,需要先假设一个预测的函数,如 y = weight * x + bias。有了预测函数之后,我们就可以对每个数据计算损失(实际值和预测值的偏差)。这里我们使用 MSE(mean squared error loss 均方误差) 作为损失函数, 针对预测函数中参数的不同值,我们都可以计算出这批训练数据的总的损失。
在这里插入图片描述
如下图,直线是我们假设的预测函数,图上的点是训练数据。我们试着调整weight 和 bias 的值,看看损失如何变化。
在这里插入图片描述
上面提到,我们要让损失函数的值最小才能得到最完美的预测函数, 那我们先看看损失函数的图形,如下图。x、y轴分别表示 weight 和 bias, z 轴的值就是对应的损失值。大家想想,怎样才能找到损失值最小的位置(weight 和 bias)?
在这里插入图片描述
由此引出梯度(导数,在某个位置的瞬时变化率)的定义。下面的公式给的是只有一个变量x 的情况。因此我们还是先从简单的图形入手。
在这里插入图片描述
假设损失函数 loss = x^2 - 1, 我们得到下面的图形:
在这里插入图片描述
虽然肉眼很容易能看出来最小值在哪个位置,但是对于计算机来说却很难,尤其是函数非常复杂的时候。那怎么才能找到让 loss 最(极)小的位置呢? 首先,我们观察这个图形,在什么情况下 loss 值最小?是不是当梯度为0或者梯度接近于0的时候 loss 最(极)小。这个时候,有人可能会说,那直接令 梯度=0,解析出这个函数的变量x值就行了吧(求解析解)?
其实不然,为什么?

解析解的不可行性

1. 解析解的不可行性,数学复杂度
对于绝大多数模型(如神经网络、支持向量机等),损失函数是高维非线性的,其梯度方程(∇L(θ)=0)常无法分解为闭式解(closed-form solution)。例如:

  • 线性回归可以求得解析解(θ=(XᵀX)⁻¹Xᵀy),但仅因模型是凸且线性。
  • 对复杂模型(如深度学习),损失函数的高度非线性导致方程求解需要多项式时间之外的计算量。

2. 计算资源限制, 维数灾难,高维参数空间的矩阵运算代价极高。例如:

  • 当参数维度为n时,求解线性方程组的计算复杂度为O(n³),而n=1e⁶时需1e¹⁸次运算。
  • 实际深度学习模型的参数规模可达1e⁹量级(如GPT-3),直接求解完全不现实。

3. 非凸优化问题,局部极小值与鞍点

  • 非凸损失函数(如神经网络的损失函数)的梯度为零点可能是局部极小值或鞍点,而全局极小值难以定位。
  • 鞍点尤其在高维空间中普遍存在(概率随维度指数级增长),直接求解梯度为零可能落入低质量解。

4. 数据驱动动态优化,在线学习与大数据场景

  • 当数据集过大时,直接求解需一次性加载全体数据(内存不足问题)。

梯度下降与随机梯度下降

ok,那我们想想如果不用解析解,你会怎么找到最小值?显然,聪明的你肯定会想到这样做:
Prepare阶段:使用训练数据求出损失函数(正向传播)

  1. 随机初始化一个位置值 x1
  2. 计算x1处的梯度d1
  3. 计算 步长=d1 * 学习率(通过调参设置)(这里你会发现一个很巧妙的地方,离最优点越远,步长越长)
  4. 更新位置 x2 = x1 - 步长 = x1 - d1 * 学习率
  5. 循环执行,直到梯度接近0 或者 步数达到最大值

恭喜,你发明了梯度下降算法。

在继续之前,大家想想这个算法有没有什么问题?如果训练数据量非常大,我们的损失函数就会异常复杂,因为计算损失函数时需要将参数加载到内存/显存中,过大的训练数据显然无法运行。怎么办?

把全量数据(epoch)按固定大小随机分组,每次只拿这一组的数据(batch)计算损失函数。

恭喜,你发明了随机梯度下降算法。

以上只是一个参数的情况,如果涉及两个参数比如 weight 和 bias,该怎么处理?
还是让我们先回到一开始的预测函数 y = weight * x + bias。 这里的两个参数weight 和 bias 都会影响 loss值 ,那我们需要计算每个参数对loss的影响程度,即偏导数(梯度),然后根据偏导数不断迭代更新相应参数的值,从而找到最优解(loss 最小)。我们看看有两个参数时,loss的变化情况,为简单起见,我们用二维等高线来画。

在这里插入图片描述
用loss的颜色深浅表示loss值。从图上能很明显看出当 weight = -1, bias = 5时loss为0,也就是我们的目标优化位置。

举个例子,顺便复习一下刚刚梯度下降的步骤,让我们试着从图上随机取点来优化我们的参数(初始化)。

第二步 需要分别计算两个参数在这个位置的偏导数(梯度)
在这里插入图片描述

第三步:分别计算两个参数的步长,第四步: 更新w 和 b
在这里插入图片描述
回顾以上步骤,大家再想想哪一步是比较难的?

对,第3步,怎么计算损失函数的梯度值?

链式法则

先看下这张图:
在这里插入图片描述
如果你想知道蓝色值(loss)如何被粉色值(参数)影响,你会怎么观察?
先从蓝色值开始:

  1. 观察蓝色值被橘色值影响的程度
  2. 观察橘色值被绿色值影响的程度
  3. 观察绿色值被粉色值影响的程度

恭喜,你发明了链式法则! 看公式:

在这里插入图片描述
注意,这里的neuron就是预测函数y。当你想知道loss被w的影响程度(loss 关于 w的梯度), 你可以先计算loss被neuron(预测值)的影响程度,再计算neuron(预测值)被w的影响程度,两个相乘,你就得到了loss 关于 w的梯度。 再观察下,我们是从后往前计算梯度(loss -> 预测值 -> 参数),这也是反向传播这一名词的由来,这样做的好处是可以利用前向传播过程中的计算结果,而不需要重复计算,节约资源和时间。

ok,来看一个具体的示例。
对于 y=wx + b, 我们假设有一条训练数据(x=2.1, y=4),w=1,b=0
在这里插入图片描述
正向传播:
在这里插入图片描述
反向传播:
在这里插入图片描述
在这里插入图片描述
分别更新参数,这里我们设置学习率为 0.1 (lr = 0.1),学习率的更新也是人工/自动调参的一部分:
在这里插入图片描述
看看更新参数后loss值:
在这里插入图片描述
3.61 -> 2.87, loss 下降了!

三、非线性问题

还是回到一开始的图,如果是这些点,你会怎么选择函数来拟合呢?
在这里插入图片描述
聪明的你可能会想到用一个相对复杂的嵌套函数来拟合这个曲线:
y= log(1 + e^(w11 * x + b11)) * w21 + log(1 + e^(w12 * x + b12)) * w22 + b2

恭喜,你发明了神经网络

也就是说,上面的公式其实可以用一个简单的神经网络结构来表示:
在这里插入图片描述
那么它的正向传播过程就可以表示为:
在这里插入图片描述
而反向传播计算梯度的过程就可以这样表示,这里就以w21和w11为例:

在这里插入图片描述
注意: 如果在计算某个参数w11的梯度时用到了另一个参数w21的值,w21的值应该取当前值,而不是优化后的值。等计算完梯度,再统一对所有参数进行迭代更新。

到这里,我们应该能发现,为了拟合更加复杂的曲线,我们可以在每层添加更多神经元、更多层以及在输出中添加非线性(激活函数)来实现。如果我们把整个神经网络视为一个函数,那么通过添加更多神经元和更多层,我们就可以创建一个嵌套更多的函数。这样做有几大好处:

  1. 创建了更多可以调整的参数来拟合输出结果
  2. 保证可微: 可以计算梯度,也就是没有尖锐拐角或断层
  3. 通过添加具有失活点的非线性激活函数,某些神经元可能对输出没有影响,而其他神经元可能变得更加活跃,从而导致输出不必遵循线性约束。

激活函数

我们这里使用的是softplus作为激活函数
在这里插入图片描述
还有比较常见的激活函数:
常见激活函数:

  • Relu:
    在这里插入图片描述
  • sigmoid:
    在这里插入图片描述

未完待续。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/35439.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LabVIEW 线性拟合

该 LabVIEW 程序实现了 线性拟合(Linear Fit),用于计算给定一组数据点的斜率(Slope)和截距(Intercept),并将结果可视化于 XY Graph 中。本案例适用于数据拟合、实验数据分析、传感器…

XSS漏洞靶场---(复现)

XSS漏洞靶场—(复现) 反射型 XSS 的特点是攻击者诱导用户点击包含恶意脚本的 URL,服务器接收到请求后将恶意脚本反射回响应页面,浏览器执行该脚本从而造成攻击,恶意脚本不会在服务器端存储。 Level 1(反射型XSS) 此漏…

优选算法系列(2.滑动窗口 _ 上)

目录 解法⼀(暴力求解)(不会超时,可以通过):一.长度最小的子数组(medium) 题目链接209. 长度最小的子数组 - 力扣(LeetCode) 解法: 代码&#…

ELK(Elasticsearch、Logstash、Kbana)安装及Spring应用

Elasticsearch安装及Spring应用 一、引言二、基本概念1.索引(Index)2.类型(Type)3.文档(Document)4.分片(Shard)5.副本(Replica) 二、ELK搭建1.创建挂载的文件…

Redis,从数据结构到集群的知识总结

Redis基础部分 2. 数据结构 redis底层使用C语言实现,这里主要分析底层数据结构 2.1 动态字符串(SDS) 由于C底层的字符串数组一旦遇到’\0’就会认为这个字符串数组已经结束,意味着无法存储二进制数据(如图片、音频等)&#xff…

【redis】Jedis 操作 Redis 基础指令(下)

列表操作 lpush/rpush 和 lpop/rpop 将一个或者多个元素从左/右侧放入(头/尾插)到 list 中 依次头插 从 list 左/右侧取出元素(即头/尾删) public static void test1(Jedis jedis) { jedis.flushAll(); long n jedis.lpush(…

基于消失点标定前视相机外参

1. 消失点 艺术家&工程师在纸上表现立体图时,常用一种透视法,这种方法源于人们的视觉经验:近大远小,且平行的直线都消失于无穷远处同一个点。就像我们观察两条平行的铁轨时会觉得他们相交于远处的一点,我们把这个点称为消失点。 图1 铁轨组成的消失点 2. 在标定中的应…

TypeScript接口 interface 高级用法完全解析

TypeScript接口 interface 高级用法完全解析 mindmaproot(TypeScript接口高级应用)基础强化可选属性只读属性函数类型高级类型索引签名继承与合并泛型约束设计模式策略模式工厂模式适配器模式工程实践声明合并类型守卫装饰器集成一、接口核心机制深度解析 1.1 类型兼容性原理 …

Vue3 Pinia $subscribe localStorage的用法 Store的组合式写法

Vue3 Pinia $subscribe 可以用来监视Stroe数据的变化 localStorage的用法 localStorage中只能存字符串,所有对象要选转成json字符串 定义store时,从localStorage中读取数据talkList可能是字符串也可能是空数组 Store的组合式写法 直接使用reactiv…

新版AndroidStudio / IDEA上传项目到Gitee

目录 1.Gitee创建仓库 2.填写仓库的信息 3.创建成功后复制仓库的地址 4.检查AndroidStudio是否配置Git 5.点击测试 6.之后Create Git Repository 7.添加到本地仓库 8.提交项目 9.添加上传仓库的地址 10.上传成功 11.去Gitee上刷新检查 1.Gitee创建仓库 2.填写仓库的…

用 Vue 3.5 TypeScript 重新开发3年前甘特图的核心组件

回顾 3年前曾经用 Vue 2.0 开发了一个甘特图组件,如今3年过去了,计划使用Vue 3.5 TypeScript 把组件重新开发,有机会的话再开发一个React版本。 关于之前的组件以前文章 Vue 2.0 甘特图组件 下面录屏是是 用 Vue 3.5 TypeScript 开发的目前…

C语言【数据结构】:时间复杂度和空间复杂度.详解

引言 详细介绍什么是时间复杂度和空间复杂度。 前言:为什么要学习时间复杂度和空间复杂度 算法在编写成可执行程序后,运行时需要耗费时间资源和空间(内存)资源。因此衡量一个算法的好坏,一般是从时间和空间两个维度来衡量的,即时…

Matlab 基于专家pid控制的时滞系统

1、内容简介 Matlab 185-基于专家pid控制的时滞系统 可以交流、咨询、答疑 2、内容说明 略 在处理时滞系统(Time Delay Systems)时,使用传统的PID控制可能会面临挑战,因为时滞会导致系统的不稳定或性能下降。专家PID控制通过结…

MyBatis源码分析のSql执行流程

文章目录 前言一、准备工作1.1、newExecutor 二、执行Sql2.1、getMappedStatement2.2、query 三、Cache装饰器的执行时机四、补充总结 前言 本篇主要介绍MyBatis解析配置文件完成后,执行sql的相关逻辑: public class Main {public static void main(Str…

【MySQL】数据库基础

目录 一、什么是数据库1.1 为什么要有数据库1.2 数据库的本质是什么1.3 在Linux下看一下数据库 二、主流数据库三、基本使用3.1 连接服务器3.2 服务器,数据库,表关系 四、MySQL架构五、SQL分类六、存储引擎6.1 存储引擎是什么6.2 查看存储引擎6.3 存储引…

算是解决可以访问github但无法clone的问题

本文的前提是使用了**且可以正常访问github 查看代理的端口 将其配置到git 首先查看git配置 git config --list然后添加配置,我这边使用的是Hiddfy默认的端口是12334,如果是clash应该是7890 git config --global http.proxy 127.0.0.1:12334其他 删除…

SpringBoot第三站:配置嵌入式服务器使用外置的Servlet容器

目录 1. 配置嵌入式服务器 1.1 如何定制和修改Servlet容器的相关配置 1.server.port8080 2. server.context-path/tx 3. server.tomcat.uri-encodingUTF-8 1.2 注册Servlet三大组件【Servlet,Filter,Listener】 1. servlet 2. filter 3. 监听器…

AdaLoRA 参数 配置:CAUSAL_LM“ 表示因果语言模型任务

AdaLoRA 参数 配置:CAUSAL_LM" 表示因果语言模型任务 config = AdaLoraConfig( init_r=16, # 增加 LoRA 矩阵的初始秩 lora_alpha=32, target_modules=[“q_proj”, “v_proj”], lora_dropout=0.1, bias=“none”, task_type=“CAUSAL_LM” ) 整体功能概述 AdaLoraCon…

IP 协议

文章目录 IP 协议概述数据包格式首部校验和实例分析实例一 分片抓包分析参考 本文为笔者学习以太网对网上资料归纳整理所做的笔记,文末均附有参考链接,如侵权,请联系删除。 IP 协议 概述 IP 协议是 TCP/IP 协议簇中的核心协议,也…

日常开发记录-radioGroup组件

日常开发记录-radioGroup组件 1.前提2.问题&#xff1a;无限循环调用3.解释Vue 事件传播机制分析与无限循环原因解释4.解决 1.前提 在上一章的&#xff0c;我们实现了radio组件。从这进入了解 新增个radioGroup组件呢。 <template><divclass"q-radio-group&quo…