transformer学习笔记-神经网络原理

在深度学习领域,transformer可以说是在传统的神经网络的基础上发展而来,着重解决传统神经网络长距离关联、顺序处理、模型表达能力等问题。
在学习transformer之前,我想,有必要先对传统的神经网络做简要的了解。

一、神经网络基本结构

在这里插入图片描述
输入层: 神经网络的第一层,直接接收原始输入数据,传递到隐藏层。
隐藏层: 位于输入层和输出层之间,可以有一个或多个隐藏层,多层隐藏层逐步提取更高层次的抽象特征。
输出层: 最后一层,负责生成最终的预测结果。
下面以CNN(卷积神经网络)为例,大致说明下各层间的关系:

1.1、从原始数据到输入层

在这里插入图片描述

假设一张经过灰度处理的数字6、大小N * N个像素的图片,每个像素点的取值经过归一化处理(取值0-1),将N * N个像素点展开成神经网络输入层的每个神经元。

1.2、隐藏层特征提取

在这里插入图片描述

  • 当输入层的各个节点数据传输到第一个隐藏层时,提取多个最基础的形状特征的神经元,
  • 这些神经元在传输给第一个隐藏层后,第二个隐藏层提取更高层次的形状或物体结构。也就是更高层次的特征。
  • 最后,第二个隐藏层的神经元通过权重计算,给每个输出层的特征节点赋值,取值最高的特征节点,就是输入图片最终匹配的预定特征。

那么,每层之间的数据传递到底是怎样的?我们以输入层到第一个隐藏层为例:
在这里插入图片描述

  • 设n = N*N ,那么输入层就有a0-an 个神经元,每个连线代表输入层每个输入对下一层某个神经元的权重,n个输入就有n个权重(全连接的情况,全连接就是每个输入都对下一层的每个节点有影响)。
  • 那么下一层的单个神经元就等于上一层所有神经元的加权和,即 :

    a0ω0 + a1ω1 + … +an*ωn

  • 在识别特征过程中,可能需要对不明显的特征进行过滤,这也就要求加权和需要大于某个数,这个值我们称之为偏置量,在线性表达上,偏置可以是模型更好拟合数据,可以表示不仅仅穿过原点的线性关系。

    (a0ω0 + a1ω1 + … +an*ωn ) +(-b)

  • 当然,有时为了优化模型的表达,比如提高收敛速度、避免梯度消失,需要将下一层神经元的取值限制在有界范围,这时候需要通过激活函数做一些非线性的操作,比如使用SIGMOD函数,将值域限定在(0,1):

    σ((a0ω0 + a1ω1 + … +an*ωn ) +(-b)) f = σ(…)为激活函数

上面仅仅是表示了第二层的一个神经元的计算,那第二层所有的神经元该如何表达:
在这里插入图片描述

加上偏置和激活函数表示如下:
在这里插入图片描述

后面隐藏层到输出层,结构也类似。通过线性表达则是:

y = W* a + b; 其中y为输出矩阵,W为权重矩阵,a为输入矩阵,b为偏置矩阵

那么W和b 如何而来,当然是通过大量的数据训练而来,后面我们将逐步学习如何通过损失函数训练,

二、损失函数

以最后一个隐藏层到输出层为例:假定这个过程的权重矩阵初始化为W,偏置初始化为b,预测值为y′。那么真实值y与预测值之间的误差就可以如下表示(以均方差为例):
假设10个样本:
在这里插入图片描述

更一般的表示: MSE= 1 n ∑ i = 1 n ( y ′ i − y i ) 2 \frac{1}{n} ∑_{i=1}^{n} (y′_i - y_i )^2 n1i=1n(yiyi)2 或 L = 1 n ( y ′ − y ) 2 \frac{1}{n}(y′ - y)^2 n1(yy)2
n为样本数量

除了均方差损失函数,常用的损失函数还有交叉熵损失函数,通常用于多分类概率分布:
在这里插入图片描述
二分类可以看做多分类的一种特例:
在这里插入图片描述

三、梯度下降法与反向传播

3.1 利用损失函数求梯度

得到损失函数后,我们得根据损失函数反向推导哪个权重影响比较大,权重矩阵就重点关注影响损失函数较大的权重:
在这里插入图片描述
那回到损失函数MSE= 1 n ∑ i = 1 n ( y ′ i − y i ) 2 \frac{1}{n} ∑_{i=1}^{n} (y′_i - y_i )^2 n1i=1n(yiyi)2 ,则是对其关于W和b求偏导以获取梯度,梯度越大,斜率越大,对误差影响越大。
由于 y′ = W* a + b ,因此关于W和b求偏导,是个复合函数求偏导:

关于W的梯度: ∂ M S E ∂ W = ∂ M S E ∂ y ′ ∗ ∂ y ′ ∂ W = 2 n ∑ i = 1 n ( y ′ i − y i ) ∗ a T \frac{\partial MSE}{\partial W} =\frac{\partial MSE}{\partial y′} * \frac{\partial y′}{\partial W} = \frac{2}{n} ∑_{i=1}^{n} (y′_i - y_i ) * a^T WMSE=yMSEWy=n2i=1n(yiyi)aT
对于单个y’(即单行W与单列a的内积,忽略偏置) 基于W求偏导,得到的应该是1行n列的矩阵,由于a是n行一列的矩阵,因此需要对a做转置处理,例如:
对 ω 0 求导: ∂ y ′ ∂ ω 0 = ∂ a 0 ∗ ω 0 + a 1 ∗ ω 1 + . . . . . . + a n ∗ ω n ω 0 = a 0 对ω0求导:\frac{\partial y′}{\partial ω0} =\frac{\partial a0*ω0 + a1*ω1 + ...... +an*ωn}{ω0} = a0 ω0求导:ω0y=ω0a0ω0+a1ω1+......+anωn=a0
对 ω 1 求导: ∂ y ′ ∂ ω 1 = ∂ a 0 ∗ ω 0 + a 1 ∗ ω 1 + . . . . . . + a n ∗ ω n ω 1 = a 1 对ω1求导:\frac{\partial y′}{\partial ω1} =\frac{\partial a0*ω0 + a1*ω1 + ...... +an*ωn}{ω1} = a1 ω1求导:ω1y=ω1a0ω0+a1ω1+......+anωn=a1
对 ω n 求导: ∂ y ′ ∂ ω n = ∂ a 0 ∗ ω 0 + a 1 ∗ ω 1 + . . . . . . + a n ∗ ω n ω n = a n 对ωn求导:\frac{\partial y′}{\partial ωn} =\frac{\partial a0*ω0 + a1*ω1 + ...... +an*ωn}{ωn} = an ωn求导:ωny=ωna0ω0+a1ω1+......+anωn=an
因此 ∂ y ′ ∂ W = [ a 0 , a 1 , . . . , a n ] = a T 因此\frac{\partial y′}{\partial W} = [ a0,a1,...,an] = a^T 因此Wy=[a0,a1,...,an]=aT
多行W与多列a以此类推。

关于b的梯度:
∂ M S E ∂ b = 2 n ∑ i = 1 n ( y ′ i − y i ) \frac{\partial MSE}{\partial b} = \frac{2}{n} ∑_{i=1}^{n} (y′_i - y_i ) bMSE=n2i=1n(yiyi)

因此根据求解的梯度,反向传播,使用梯度下降法修改权重与偏置矩阵

权重梯度调整: W = W − η ∂ M S E ∂ W 权重梯度调整:W = W - η \frac{\partial MSE}{\partial W} 权重梯度调整:W=WηWMSE
偏置梯度调整: b = b − η ∂ M S E ∂ b 偏置梯度调整:b = b - η \frac{\partial MSE}{\partial b} 偏置梯度调整:b=bηbMSE
η表示学习率(步长),η越大,收敛速度越快,但是越容易错过最佳值,η越小收敛速度越慢。

通过多轮训练优化,最终得到能够准确预测的权重矩阵。
看公式可能很头疼,码农肯定看代码更容易理解:

 import numpy as np# 定义均方差损失函数
def mean_squared_error(y_true, y_pred):return np.mean((y_true - y_pred) ** 2)# 定义均方差损失函数的梯度
def mse_gradient(y_true, y_pred, X):n = len(y_true)gradient_w = 2 * np.dot(X.T, (y_true - y_pred)) / ngradient_b = 2 * np.mean(y_true - y_pred)return gradient_w, gradient_b# 示例数据
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5]])  # 输入特征
y_true = np.array([2, 3, 4, 5])  # 真实值
w = np.array([0.5, 0.5])  # 初始权重
b = 0.0  # 初始偏置# 计算预测值
y_pred = np.dot(X, w) + b# 计算均方差损失
mse = mean_squared_error(y_true, y_pred)
print("Mean Squared loss:", mse)# 计算梯度
gradient_w, gradient_b = mse_gradient(y_true, y_pred, X)
print("Gradient of weights:", gradient_w)
print("Gradient of bias:", gradient_b)

3.2 交叉熵损失函数

均方差相比比较好理解,但是对于概率等较小的预测值,反应不灵敏,有时不能有效反应误差,如何对较小的值如(0,1)更好的做损失分析呢,聪明的你肯定可以想到用log(x):
在这里插入图片描述
但从取现观察看,在x∈(0,1)时,x越小,斜率的绝对值越大,说明log(x)对越小的值越敏感,从概率考虑,也就说,概率越小的事件,越不该发生,发生了越应该被注意到,越应该调整它的权重,减少它对误差的影响。
交叉熵损失函数,就是应用了这一原理。交叉熵损失函数应用在softmax()处理后形成了概率分布的输出层,softmax处理后的输出即为预测的概率分布。这里不对交叉熵损失函数求梯度做进一步学习。

3.3优化器

在我们实际的开发过程中,可以直接使用torch提供的优化器完成反向传播梯度优化。
常见的优化器有如下SGD(随机梯度)、动量(Momentum)、RMSProp、Adam等几种,极大方便模型开发训练工作,后续有机会再认真学习优化器的原理。再第五章的示例代码中,也将简单应用优化器完成模型训练。

四、激活函数

激活函数主要非线性的操作,用来对神经元的输出限制值域,以求模型有更好的表现力,能够更好学习拟合复杂的线性关系,解决梯度消失,平滑分布,矩阵稀疏等问题。
以下是常见的一些激活函数和使用场景:

4.1 SIGMOD函数

σ(x) = 1 1 + e − x \frac{1}{1+e^{-x}} 1+ex1

特点: 输出范围(0,1),可用于表示二分类概率,也可用于平滑梯度,但在两端附近的值梯度趋零,易导致梯度消失,因此深层网络通常采用ReLU函数。
在这里插入图片描述

4.2 ReLU函数

f(x)=max(0,x)

特点: 输出范围在 [0, ∞) 之间,对于正输入值保持不变,有助于缓解梯度下降;而对于负输入值输出为0。可以产生稀疏激活,对于负输入值,输出为0,减少模型的复杂度和过拟合风险,但也会导致输出为0的神经元,后续不会再被激活。因此可引入ReLU函数的变体,如:

f(x)=max(αx,x); α通常取非常小的正数,如0.01

在这里插入图片描述

当然还有如Tanh 函数、Softmax 函数、ELU 函数等激活函数,这里就不一一介绍了。Softmax 函数主要用来生成概率分布,在自注意力机制中也有用到,到时再单独做简要说明。

五、完整示例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/479463.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【前端】JavaScript中的字面量概念与应用详解

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: 前端 文章目录 💯前言💯字面量1. 数字字面量2. 字符串字面量3. 布尔字面量4. 空值字面量(null)5. 对象字面量6. 数组字面量7. 正则表达式字面量8. 特殊值字面量9. 函数字…

字节跳动青训营刷题笔记19

问题描述 小R正在组织一个比赛,比赛中有 n 支队伍参赛。比赛遵循以下独特的赛制: 如果当前队伍数为 偶数,那么每支队伍都会与另一支队伍配对。总共进行 n / 2 场比赛,且产生 n / 2 支队伍进入下一轮。如果当前队伍数为 奇数&…

Python中的简单爬虫

文章目录 一. 基于FastAPI之Web站点开发1. 基于FastAPI搭建Web服务器2. Web服务器和浏览器的通讯流程3. 浏览器访问Web服务器的通讯流程4. 加载图片资源代码 二. 基于Web请求的FastAPI通用配置1. 目前Web服务器存在问题2. 基于Web请求的FastAPI通用配置 三. Python爬虫介绍1. 什…

【ArcGISPro】使用AI提取要素-土地分类(sentinel2)

Sentinel2数据处理 【ArcGISPro】Sentinel-2数据处理-CSDN博客 土地覆盖类型分类 处理结果

WinForm 的Combox下拉框 在FlatStyle.Flat的边框设置

现象:Combox在设置FlatStyle.Flat时边框不见了 效果: 解决问题思路封装新控件: public class DBorderComboBox : ComboBox {private const int WM_PAINT 0xF;[Browsable(true)][Category("Appearance")][Description("边框…

Python 爬虫入门教程:从零构建你的第一个网络爬虫

网络爬虫是一种自动化程序,用于从网站抓取数据。Python 凭借其丰富的库和简单的语法,是构建网络爬虫的理想语言。本文将带你从零开始学习 Python 爬虫的基本知识,并实现一个简单的爬虫项目。 1. 什么是网络爬虫? 网络爬虫&#x…

使用UE5.5的Animator Kit变形器

UE5.5版本更新了AnimatorKit内置插件,其中包含了一些内置变形器,可以辅助我们的动画制作。 操作步骤 首先打开UE5.5,新建第三人称模板场景以便测试,并开启AnimatorKit组件。 新建Sequence,放入测试角色 点击角色右…

Uniapp 安装安卓、IOS模拟器并调试

一、安装Android模拟器并调试 1. 下载并安装 Android Studio 首先下载 Mac 环境下的 Android Studio 的安装包,为dmg 格式。 下载完将Android Studio 向右拖拽到Applications中,接下来等待安装完成就OK啦! 打开过程界面如下图所示&#xf…

shell(5)字符串运算符和逻辑运算符

声明! 学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&#…

【金蝶双线指标】以看资金进出操作为主,兼顾波段跟踪和短线低吸

如上图,个股副图指标,大佬资金监控短线低吸攻击线操盘线趋势红蝴蝶,五大功能于一体。下面慢慢给大家仔细分享。 大佬资金监控指标,红绿进出,绿色缩小到极致,接近零轴,红绿柱分界线,为…

多输入多输出 | Matlab实现TCN-GRU时间卷积神经网络结合门控循环单元多输入多输出预测

多输入多输出 | Matlab实现TCN-GRU时间卷积神经网络结合门控循环单元多输入多输出预测 目录 多输入多输出 | Matlab实现TCN-GRU时间卷积神经网络结合门控循环单元多输入多输出预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 多输入多输出 | Matlab实现TCN-GRU时间卷积…

HCIA笔记4--VLAN划分

1. vlan是什么 vlan: virtual lan; 虚拟局域网的简称。 主要目的是隔离广播域。 2. vlan报文格式 在普通的以太网数据帧开关的12字节后添加4字节的vlan tag。而来区分vlan的是其中的vid部分12个比特位,范围自然就是0~2^12-1(0~4095); 0 4095保留使用。实际使用的是…

蓝牙定位的MATLAB仿真程序|基于信号强度的定位,平面、四个蓝牙基站(附源代码)

这段代码通过RSSI信号强度实现了蓝牙定位,展示了如何使用锚点位置和测量的信号强度来估计未知点的位置。它涵盖了信号衰减模型、距离计算和最小二乘法估计等基本概念。通过图形化输出,用户可以直观地看到真实位置与估计位置的关系。 文章目录 蓝牙定位原…

基于Springboot企业级工位管理系统【附源码】

基于Springboot企业级工位管理系统 效果如下: 系统登录页面 员工主页面 部门信息页面 员工管理页面 部门信息管理页面 工位信息管理页面 工位分配管理页面 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所。…

Spring Boot教程之十: 使用 Spring Boot 实现从数据库动态下拉列表

使用 Spring Boot 实现从数据库动态下拉列表 动态下拉列表(或依赖下拉列表)的概念令人兴奋,但编写起来却颇具挑战性。动态下拉列表意味着一个下拉列表中的值依赖于前一个下拉列表中选择的值。一个简单的例子是三个下拉框,分别显示…

SpringBoot源码-spring boot启动入口ruan方法主线分析(一)

一、SpringBoot启动的入口 1.当我们启动一个SpringBoot项目的时候,入口程序就是main方法,而在main方法中就执行了一个run方法。 SpringBootApplication public class StartApp {public static void main(String[] args) {// testSpringApplication.ru…

AI 助力开发新篇章:云开发 Copilot 深度体验与技术解析

本文 一、引言:技术浪潮中的个人视角1.1 AI 和低代码的崛起1.2 为什么选择云开发 Copilot? 二、云开发 Copilot 的核心功能解析2.1 自然语言驱动的低代码开发2.1.1 自然语言输入示例2.1.2 代码生成的模块化支持 2.2 实时预览与调整2.2.1 实时预览窗口功能…

vscode的markdown扩展问题

使用vscode编辑markdown文本时,我是用的是Office Viewer(Markdown Editor)这个插件 今天突然发现不能用了,点击切换编辑视图按钮时会弹出报错信息: command office.markdown.switch not found 在网上找了很久发现没有有关这个插件的文章………

从零开始学 Maven:简化 Java 项目的构建与管理

一、关于Maven 1.1 简介 Maven 是一个由 Apache 软件基金会开发的项目管理和构建自动化工具。它主要用在 Java 项目中,但也可以用于其他类型的项目。Maven 的设计目标是提供一种更加简单、一致的方法来构建和管理项目,它通过使用一个标准的目录布局和一…

去哪儿大数据面试题及参考答案

Hadoop 工作原理是什么? Hadoop 是一个开源的分布式计算框架,主要由 HDFS(Hadoop 分布式文件系统)和 MapReduce 计算模型两部分组成 。 HDFS 工作原理 HDFS 采用主从架构,有一个 NameNode 和多个 DataNode。NameNode 负…