从零开始实现大语言模型(四):简单自注意力机制

1. 前言

理解大语言模型结构的关键在于理解自注意力机制(self-attention)。自注意力机制可以判断输入文本序列中各个token与序列中所有token之间的相关性,并生成包含这种相关性信息的context向量。

本文介绍一种不包含训练参数的简化版自注意力机制——简单自注意力机制(simplified self-attention),后续三篇文章将分别介绍缩放点积注意力机制(scaled dot-product attention),因果注意力机制(causal attention),多头注意力机制(multi-head attention),并最终实现OpenAI的GPT系列大语言模型中MultiHeadAttention类。

2. 从循环神经网络到自注意力机制

解决机器翻译等多对多(many-to-many)自然语言处理任务最常用的模型是sequence-to-sequence模型。Sequence-to-sequence模型包含一个编码器(encoder)和一个解码器(decoder),编码器将输入序列信息编码成信息向量,解码器用于解码信息向量,生成输出序列。在Transformer模型出现之前,编码器和解码器一般都是一个循环神经网络(RNN, recurrent neural network)。

RNN是一种非常适合处理文本等序列数据的神经网络架构。Encoder RNN对输入序列进行处理,将输入序列信息压缩到一个向量中。状态向量 h 0 h_0 h0包含第一个token x 0 x_0 x0的信息, h 1 h_1 h1包含前两个tokens x 0 x_0 x0 x 1 x_1 x1的信息。以此类推, Encoder RNN最后一个状态 h m h_m hm是整个输入序列的概要,包含了整个输入序列的信息。Decoder RNN的初始状态等于Encoder RNN最后一个状态 h m h_m hm h m h_m hm包含了输入序列的信息,Decoder RNN可以通过 h m h_m hm知道输入序列的信息。Decoder RNN可以将 h m h_m hm中包含的信息解码,逐个元素地生成输出序列。

RNN的神经网络结构及计算方法使Encoder RNN必须用一个隐藏状态向量 h m h_m hm记住整个输入序列的全部信息。当输入序列很长时,隐藏状态向量 h m h_m hm对输入序列中前面部分的tokens的偏导数(如对 x 0 x_0 x0的偏导数 ∂ h m x 0 \frac{\partial h_m}{x_0} x0hm)会接近0。输入不同的 x 0 x_0 x0,隐藏状态向量 h m h_m hm几乎不会发生变化,即RNN会遗忘输入序列前面部分的信息。

本文不会详细介绍RNN的原理,大语言模型的神经网络中没有循环结构,RNN的原理及结构与大语言模型没有关系。对RNN的原理感兴趣读者可以参见本人的博客专栏:自然语言处理。

2014年,文章Neural Machine Translation by Jointly Learning to Align and Translate提出了一种改进sequence-to-sequence模型的方法,使Decoder每次更新状态时会查看Encoder所有状态,从而避免RNN遗忘的问题,而且可以让Decoder关注Encoder中最相关的信息,这也是attention名字的由来。

2017年,文章Attention Is All You Need指出可以剥离RNN,仅保留attention,且attention并不局限于sequence-to-sequence模型,可以直接用在输入序列数据上,构建self-attention,并提出了基于attention的sequence-to-sequence架构模型Transformer。

3. 简单自注意力机制

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。如下图所示,简单自注意力机制生成context向量的计算步骤如下:

  1. 计算注意力分数(attention score):简单注意力机制使用向量的点积(dot product)作为注意力分数,注意力分数可以衡量两个向量的相关性;
  2. 计算注意力权重(attention weight):将注意力分数归一化得到注意力权重,序列中每个token与序列中所有tokens之间的注意力权重之和等于1;
  3. 计算context向量:简单注意力机制将所有tokens对应Embedding向量的加权和作为context向量,每个token对应Embedding向量的权重等于其相应的注意力权重。

图一

3.1 计算注意力分数

对输入文本序列 “Your journey starts with one step.” 做tokenization,将文本中每个单词分割成一个token,并转换成Embedding向量,得到 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6。自注意力机制分别计算 x i x_i xi x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力权重,进而计算 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6与其相应注意力权重的加权和,得到context向量 z i z_i zi

如下图所示,将context向量 z i z_i zi对应的向量 x i x_i xi称为query向量,计算query向量 x 2 x_2 x2对应的context向量 z 2 z_2 z2的第一步是计算注意力分数。将query向量 x 2 x_2 x2分别点乘向量 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6,得到实数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26,其中 ω 2 i \omega_{2i} ω2i是query向量 x 2 x_2 x2与向量 x i x_i xi的注意力分数,可以衡量 x 2 x_2 x2对应token与 x i x_i xi对应token之间的相关性。

图二

两个向量的点积等于这两个向量相同位置元素的乘积之和。假如向量 x 1 = ( x 11 , x 12 , x 13 ) x_1=(x_{11}, x_{12}, x_{13}) x1=(x11,x12,x13),向量 x 2 = ( x 21 , x 22 , x 23 ) x_2=(x_{21}, x_{22}, x_{23}) x2=(x21,x22,x23),则向量 x 1 x_1 x1 x 2 x_2 x2的点积等于 x 11 × x 21 + x 12 × x 22 + x 13 × x 23 x_{11}\times x_{21} + x_{12}\times x_{22} + x_{13}\times x_{23} x11×x21+x12×x22+x13×x23

可以使用如下代码计算query向量 x 2 x_2 x2 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力分数:

import torch
inputs = torch.tensor([[0.43, 0.15, 0.89], # Your     (x^1)[0.55, 0.87, 0.66], # journey  (x^2)[0.57, 0.85, 0.64], # starts   (x^3)[0.22, 0.58, 0.33], # with     (x^4)[0.77, 0.25, 0.10], # one      (x^5)[0.05, 0.80, 0.55]] # step     (x^6)
)query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

执行上面代码,打印结果如下:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

3.2 计算注意力权重

如下图所示,将注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化可得到注意力权重 α 21 , α 22 , ⋯ , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。每个注意力权重 α 2 i \alpha_{2i} α2i的值均介于0到1之间,所有注意力权重的和 ∑ i α 2 i = 1 \sum_i\alpha_{2i}=1 iα2i=1。可以用注意力权重 α 2 i \alpha_{2i} α2i表示 x i x_i xi对当前context向量 z 2 z_2 z2的重要性占比,注意力权重 α 2 i \alpha_{2i} α2i越大,表示 x i x_i xi x 2 x_2 x2的相关性越强,context向量 z 2 z_2 z2 x i x_i xi的信息量比例应该越高。使用注意力权重对 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6加权求和计算context向量,可以使context向量的数值分布范围始终与 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6一致。这种数值分布范围的一致性可以使大语言模型训练过程更稳定,模型更容易收敛。

图三

可以使用softmax函数将注意力分数归一化得到注意力权重:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

执行上面代码,打印结果如下:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

3.3 计算context向量

简单注意力机制使用所有tokens对应Embedding向量的加权和作为context向量,context向量 z 2 = ∑ i α 2 i x i z_2=\sum_i\alpha_{2i}x_i z2=iα2ixi

图四

可以使用如下代码计算context向量 z 2 z_2 z2

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

执行上面代码,打印结果如下:

tensor([0.4419, 0.6515, 0.5683])

3.4 计算所有tokens对应的context向量

将向量 x 2 x_2 x2作为query向量,按照3.1所述方法,可以计算出注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26。使用softmax函数将注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化,可以得到注意力权重 α 21 , α 22 , ⋯ , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。Context向量 z 2 z_2 z2是使用注意力权重对 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的加权和。

计算所有tokens对应的context向量,可以使用矩阵乘法运算,分别将各个 x i x_i xi作为query向量,一次性批量计算注意力分数及注意力权重,并最终得到context向量 z i z_i zi

如下面代码所示,可以使用矩阵乘法,一次性计算出所有注意力分数:

attn_scores = inputs @ inputs.T
print(attn_scores)

执行上面代码,打印结果如下:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

@操作符是PyTorch中的矩阵乘法运算符号,与函数torch.matmul运算逻辑相同。

将一个 n n n m m m列的矩阵 A A A与另一个 m m m n n n B B B的矩阵相乘,结果 C C C是一个 n n n n n n列的矩阵。其中矩阵 C C C i i i j j j列元素等于矩阵 A A A的第 i i i行与矩阵 B B B的第 j j j列两个向量的内积。

如下面代码所示,使用softmax函数注意力分数归一化,可以一次批量计算出所有注意力权重:

attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

执行上面代码,打印结果如下:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

可以同样使用矩阵乘法运算,一次性批量计算出所有context向量:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

执行上面代码,打印结果如下:

tensor([[0.4421, 0.5931, 0.5790],[0.4419, 0.6515, 0.5683],[0.4431, 0.6496, 0.5671],[0.4304, 0.6298, 0.5510],[0.4671, 0.5910, 0.5266],[0.4177, 0.6503, 0.5645]])

4. 结束语

自注意力机制是大语言模型神经网络结构中最复杂的部分。为降低自注意力机制原理的理解门槛,本文介绍了一种不带任何训练参数的简化版自注意力机制。

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。简单自注意力机制生成context向量共3个步骤,首先计算注意力分数,然后使用softmax函数将注意力分数归一化得到注意力权重,最后使用注意力权重对所有tokens对应的Embedding向量加权求和得到context向量。

接下来,该去看看大语言模型中真正使用到的注意力机制了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/374959.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app/vue项目如何封装全局消息提示组件

效果图&#xff1a; 第一步&#xff1a;封装组件和方法&#xff0c;采用插件式注册&#xff01; 在项目目录下新建components文件夹&#xff0c;里面放两个文件&#xff0c;分别是index.vue和index.js. index.vue&#xff1a; <template><div class"toast&quo…

【Linux杂货铺】2.进程优先级

1.进程优先级基本概念 进程优先级是操作系统中用于确定进程调度顺序的一个指标。每个进程都会被分配一个优先级&#xff0c;优先级较高的进程会在调度时优先被执行。进程优先级的设定通常根据进程的重要性、紧急程度、资源需求等因素来确定。操作系统会根据进程的优先级来决定进…

nuPlan 是一个针对自动驾驶车辆的闭环机器学习(ML-based)规划基准测试

nuPlan: A closed-loop ML-based planning benchmark for autonomous vehicles nuPlan 是一个针对自动驾驶车辆的闭环机器学习&#xff08;ML-based&#xff09;规划基准测试 Abstract In this work, we propose the world’s first closed-loop ML-based planning benchmar…

【JavaScript】解决 JavaScript 语言报错:Uncaught ReferenceError: XYZ is not defined

文章目录 一、背景介绍常见场景 二、报错信息解析三、常见原因分析1. 变量未声明2. 拼写错误3. 块级作用域4. 使用未定义的函数或对象5. 代码执行顺序 四、解决方案与预防措施1. 确保变量已声明2. 检查拼写错误3. 注意块级作用域4. 定义和调用函数5. 正确的代码执行顺序 五、示…

tkinter-TinUI-xml实战(11)多功能TinUIxml编辑器

引言 在TinUIXml简易编辑器中&#xff0c;我们通过TinUI搭建了一个简易的针对TinUIXml布局的编辑器&#xff0c;基本掌握了TinUIXml布局和TinUIXml的导入与导出。现在&#xff0c;就在此基础上&#xff0c;对编辑器进行升级。 本次升级的功能&#xff1a; 更合理的xml编辑与…

Java设计模式---(创建型模式)工厂、单例、建造者、原型

目录 前言一、工厂模式&#xff08;Factory&#xff09;1.1 工厂方法模式&#xff08;Factory Method&#xff09;1.1.1 普通工厂方法模式1.1.2 多个工厂方法模式1.1.3 静态工厂方法模式 1.2 抽象工厂模式&#xff08;Abstract Factory&#xff09; 二、单例模式&#xff08;Si…

浅析Kafka-Stream消息流式处理流程及原理

以下结合案例&#xff1a;统计消息中单词出现次数&#xff0c;来测试并说明kafka消息流式处理的执行流程 Maven依赖 <dependencies><dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-streams</artifactId><exclusio…

【密码学】大整数分解问题和离散对数问题

公钥密码体制的主要思想是通过一种非对称性&#xff0c;即正向计算简单&#xff0c;逆向计算复杂的加密算法设计&#xff0c;来解决安全通信。本文介绍两种在密码学领域内最为人所熟知、应用最为广泛的数学难题——大整数分解问题与离散对数问题 一、大整数分解问题 &#xf…

thinkphp 生成邀请推广二维码,保存到服务器并接口返回给前端

根据每个人生成自己的二维码图片,接口返回二维码图片地址 生成在服务器的二维码图片 控制器 public function createUserQRcode(){$uid = input(uid);if

传言称 iPhone 16 Pro 将支持 40W 快速充电和 20W MagSafe

目前&#xff0c;iPhone 15 和 iPhone 15 Pro 机型使用合适的 USB-C 电源适配器可实现高达 27W 的峰值充电速度&#xff0c;而 Apple 和授权第三方的官方 MagSafe 充电器可以高达 15W 的功率为 iPhone 15 机型进行无线充电。所有四款 iPhone 15 机型均可使用 20W 或更高功率的电…

FPGA学习笔记(一) FPGA最小系统

文章目录 前言一、FPGA最小系统总结 前言 今天学习下FPGA的最小系统一、FPGA最小系统 FPGA最小系统与STM32最小系统类似&#xff0c;由供电电源&#xff0c;时钟电路晶振&#xff0c;复位和调试接口JTAG以及FLASH配置芯片组成&#xff0c;其与STM32最大的不同之处就是必须要有…

Appium自动化测试系列: 2. 使用Appium启动APP(真机)

历史文章&#xff1a;Appium自动化测试系列: 1. Mac安装配置Appium_mac安装appium-CSDN博客 一、准备工作 1. 安卓测试机打开调试模式&#xff0c;然后使用可以传输数据的数据线连接上你的电脑。注意&#xff1a;你的数据线一定要支持传输数据&#xff0c;有的数据线只支持充…

《数据结构:C语言实现顺序表》

文章目录 一、顺序表1、静态顺序表2、动态顺序表 二、动态顺序表实现1、创建自定义类型2、完成顺序表的创建&#xff0c;测试功能需求3、完成顺序表的初始化和销毁功能4、顺序表插入数据和打印数据5、删除数据 三、顺序表完成最终的代码test.c文件中的代码&#xff1a;用来测试…

新手教学系列——MongoDB聚合查询的进阶用法

引言 MongoDB的聚合查询是其最强大的功能之一。无论是汇总、平均值、计数等标准操作,还是处理复杂的数据集合,MongoDB的聚合框架都能提供高效且灵活的解决方案。本文将通过几个实例,详细讲解如何在实际项目中使用MongoDB进行聚合查询。 标准应用:汇总、平均值、计数等 在…

k8s集群部署mysql8主备

一、搜索mysql8版本 # helm search repo mysql# helm pull bitnami/mysql --version:11.1.2# tar -zxf mysql-11.1.2.tgz# cd mysql 二、修改value.ysqml文件 动态存储类自己提前搭建。 # helm install mysql8 -n mysql-cluster ./ -f values.yaml NAME: mysql8 LAST DEPLOYED…

Neo4j安装

下载地址&#xff1a;Neo4j Deployment Center - Graph Database & Analytics 1.安装jdk&#xff0c;Neo4j 3.0需要jdk8&#xff0c;2.3.0之前的版本建议jdk7。Neo4j最新版本5.21.2&#xff0c;对应jdk版本17 2.将下载的zip文件解压到合适路径。 3.设置环境变量NEO4J_H…

【机器学习】朴素贝叶斯算法详解与实战扩展

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ 引言 朴素贝叶斯算法是一种基于概率统计的分类方法&#xff0c;它利用贝叶斯定理和特征条件独立假设来预测样本的类别。尽管其假设特征之间相互独立在现实中往往不成立&#xff0c;但朴素贝叶斯分类器因其计算…

卤味江湖中,周黑鸭究竟该抓住什么赛点?

近年来&#xff0c;卤味江湖的决斗从未停止。 随着休闲卤味、佐餐卤味等细分赛道逐渐形成&#xff0c;“卤味三巨头”&#xff08;周黑鸭、绝味食品、煌上煌&#xff09;的牌桌上有了更多新对手&#xff0c;赛道变挤了&#xff0c;“周黑鸭们”也到了转型关键期。 这个夏天&a…

linux系统操作/基本命令/vim/权限修改/用户建立

Linux的目录结构&#xff1a; 一&#xff1a;在Linux系统中&#xff0c;路径之间的层级关系&#xff0c;使用:/来表示 注意:1、开头的/表示根目录 2、后面的/表示层级关系 二&#xff1a;在windows系统中&#xff0c;路径之间的层级关系&#xff0c;使用:\来表示 注意:1、D:表示…

【web前端HTML+CSS+JS】--- JS学习笔记03

一、JS介绍 可以在前端页面上进行逻辑处理&#xff0c;来解决表单的验证等问题&#xff0c;提升效率&#xff0c;直接在前端提示问题&#xff0c;减少服务器压力 应用1&#xff1a;可以做静态验证和动态验证&#xff08;进行异步请求&#xff09; 应用2&#xff1a;可以解析后…